Skip to content

Commit

Permalink
Centralize version comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Jan 14, 2025
1 parent a12862a commit f85b027
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 46 deletions.
6 changes: 2 additions & 4 deletions src/anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from .. import abc
from .._settings import settings
from ..compat import H5Group, SpArray, ZarrArray, ZarrGroup, _read_attr
from ..compat import H5Group, SpArray, ZarrArray, ZarrGroup, _read_attr, is_zarr_v2
from .index import _fix_slice_bounds, _subset, unpack_index

if TYPE_CHECKING:
Expand Down Expand Up @@ -72,10 +72,8 @@ def copy(self) -> ss.csr_matrix | ss.csc_matrix:
return sparse_dataset(self.data.parent).to_memory()
if isinstance(self.data, ZarrArray):
import zarr
from packaging.version import Version

is_zarr_v2 = Version(zarr.__version__) < Version("3.0.0b0")
if is_zarr_v2:
if is_zarr_v2():
sparse_group = zarr.open(
store=self.data.store,
mode="r",
Expand Down
16 changes: 6 additions & 10 deletions src/anndata/_io/specs/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)

from ..._settings import settings
from ...compat import is_zarr_v2
from .registry import _REGISTRY, IOSpec, read_elem, read_elem_partial

if TYPE_CHECKING:
Expand Down Expand Up @@ -413,12 +414,11 @@ def write_basic_dask_zarr(
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
import dask.array as da
import zarr

if Version(zarr.__version__) >= Version("3.0.0b0"):
g = f.require_array(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
else:
if is_zarr_v2():
g = f.require_dataset(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
else:
g = f.require_array(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
da.store(elem, g, lock=GLOBAL_LOCK)


Expand Down Expand Up @@ -513,9 +513,7 @@ def write_vlen_string_array_zarr(
_writer: Writer,
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
import zarr

if Version(zarr.__version__) < Version("3.0.0b0"):
if is_zarr_v2():
import numcodecs

if Version(numcodecs.__version__) < Version("0.13"):
Expand Down Expand Up @@ -1181,12 +1179,10 @@ def write_scalar_zarr(
_writer: Writer,
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
import zarr

# these args are ignored in v2: https://zarr.readthedocs.io/en/v2.18.4/api/hierarchy.html#zarr.hierarchy.Group.create_dataset
# and error out in v3
dataset_kwargs = _remove_scalar_compression_args(dataset_kwargs)
if Version(zarr.__version__) < Version("3.0.0b0"):
if is_zarr_v2():
return f.create_dataset(key, data=np.array(value), shape=(), **dataset_kwargs)
else:
from numcodecs import VLenUTF8
Expand Down
24 changes: 8 additions & 16 deletions src/anndata/_io/specs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
from types import MappingProxyType
from typing import TYPE_CHECKING, Generic, TypeVar

from packaging.version import Version

from anndata._io.utils import report_read_key_on_error, report_write_key_on_error
from anndata._types import Read, ReadDask, _ReadDaskInternal, _ReadInternal
from anndata.compat import DaskArray, ZarrGroup, _read_attr
from anndata.compat import DaskArray, ZarrGroup, _read_attr, is_zarr_v2

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Iterable
Expand Down Expand Up @@ -342,21 +340,15 @@ def write_elem(
return lambda *_, **__: None

# Normalize k to absolute path
is_zarr_group_and_is_zarr_package_v2 = False
if isinstance(store, ZarrGroup):
import zarr

if Version(zarr.__version__) < Version("3.0.0b0"):
is_zarr_group_and_is_zarr_package_v2 = True

if is_zarr_group_and_is_zarr_package_v2 or isinstance(store, h5py.Group):
if not PurePosixPath(k).is_absolute():
k = str(PurePosixPath(store.name) / k)
if (
(isinstance(store, ZarrGroup) and is_zarr_v2())
or isinstance(store, h5py.Group)
and not PurePosixPath(k).is_absolute()
):
k = str(PurePosixPath(store.name) / k)

if k == "/":
if isinstance(store, ZarrGroup) and Version(zarr.__version__) >= Version(
"3.0.0b0"
):
if isinstance(store, ZarrGroup) and not is_zarr_v2():
import asyncio

asyncio.run(store.store.clear())
Expand Down
16 changes: 7 additions & 9 deletions src/anndata/_io/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@
import numpy as np
import pandas as pd
import zarr
from packaging.version import Version
from scipy import sparse

from .._core.anndata import AnnData
from .._settings import settings
from .._warnings import OldFormatWarning
from ..compat import _clean_uns, _from_fixed_length_strings
from ..compat import _clean_uns, _from_fixed_length_strings, is_zarr_v2
from ..experimental import read_dispatched, write_dispatched
from .specs import read_elem
from .utils import _read_legacy_raw, report_read_key_on_error
Expand Down Expand Up @@ -150,14 +149,13 @@ def read_dataframe(group: zarr.Group | zarr.Array) -> pd.DataFrame:
return read_elem(group)


_FMT_PARAM = (
"zarr_version" if Version(zarr.__version__) < Version("3.0.0b0") else "zarr_format"
)


def open_write_group(
store: StoreLike, *, mode: AccessModeLiteral = "w", **kwargs
) -> zarr.Group:
return zarr.open_group(
store, mode=mode, **{_FMT_PARAM: settings.zarr_write_format}, **kwargs
if {"zarr_version", "zarr_format"} & kwargs.keys():
msg = "Don’t specify `zarr_version` or `zarr_format` explicitly."
raise ValueError(msg)
kwargs["zarr_version" if is_zarr_v2() else "zarr_format"] = (
settings.zarr_write_format
)
return zarr.open_group(store, mode=mode, **kwargs)
10 changes: 9 additions & 1 deletion src/anndata/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Mapping
from contextlib import AbstractContextManager
from dataclasses import dataclass, field
from functools import partial, singledispatch, wraps
from functools import cache, partial, singledispatch, wraps
from importlib.util import find_spec
from inspect import Parameter, signature
from pathlib import Path
Expand Down Expand Up @@ -106,6 +106,14 @@ def __repr__():
return "mock zarr.core.Group"


@cache
def is_zarr_v2() -> bool:
import zarr
from packaging.version import Version

return Version(zarr.__version__) < Version("3.0.0b0")


if find_spec("awkward") or TYPE_CHECKING:
import awkward # noqa: F401
from awkward import Array as AwkArray
Expand Down
10 changes: 4 additions & 6 deletions src/anndata/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import numpy as np
import pandas as pd
import pytest
from packaging.version import Version
from pandas.api.types import is_numeric_dtype
from scipy import sparse

Expand All @@ -34,6 +33,7 @@
DaskArray,
SpArray,
ZarrArray,
is_zarr_v2,
)
from anndata.utils import asarray

Expand Down Expand Up @@ -1093,12 +1093,10 @@ def shares_memory_sparse(x, y):
]

if find_spec("zarr") or TYPE_CHECKING:
import zarr

if Version(zarr.__version__) > Version("3.0.0b0"):
from zarr.storage import LocalStore
else:
if is_zarr_v2():
from zarr.storage import DirectoryStore as LocalStore
else:
from zarr.storage import LocalStore
else:

class LocalStore:
Expand Down

0 comments on commit f85b027

Please sign in to comment.