Skip to content

Commit

Permalink
Improve aligned mapping errors (#1252)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored Dec 8, 2023
1 parent 4745b1d commit 49ca3bd
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 33 deletions.
42 changes: 23 additions & 19 deletions anndata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import (
TYPE_CHECKING,
ClassVar,
Literal,
TypeVar,
Union,
)
Expand Down Expand Up @@ -68,21 +69,24 @@ def _validate_value(self, val: V, key: str) -> V:
# stacklevel=3,
)
for i, axis in enumerate(self.axes):
if self.parent.shape[axis] != dim_len(val, i):
right_shape = tuple(self.parent.shape[a] for a in self.axes)
actual_shape = tuple(dim_len(val, a) for a, _ in enumerate(self.axes))
if actual_shape[i] is None and isinstance(val, AwkArray):
raise ValueError(
f"The AwkwardArray is of variable length in dimension {i}.",
f"Try ak.to_regular(array, {i}) before including the array in AnnData",
)
else:
raise ValueError(
f"Value passed for key {key!r} is of incorrect shape. "
f"Values of {self.attrname} must match dimensions "
f"{self.axes} of parent. Value had shape {actual_shape} while "
f"it should have had {right_shape}."
)
if self.parent.shape[axis] == dim_len(val, i):
continue
right_shape = tuple(self.parent.shape[a] for a in self.axes)
actual_shape = tuple(dim_len(val, a) for a, _ in enumerate(self.axes))
# Build one human-readable message describing the shape mismatch, then raise.
if actual_shape[i] is None and isinstance(val, AwkArray):
    # `i` is the position among the value's own dimensions (the index used by
    # `dim_len(val, i)` and by the `ak.to_regular` hint below), whereas `axis`
    # is the parent axis that dimension is aligned to. The obs/var label has
    # to come from `axis`, the same mapping the `else` branch applies to
    # `self.axes`.
    dim = ("obs", "var")[axis]
    # No comma between the literals: they must concatenate into a single str.
    # With a comma `msg` becomes a tuple and the exception text is its repr.
    msg = (
        f"The AwkwardArray is of variable length in dimension {dim}. "
        f"Try ak.to_regular(array, {i}) before including the array in AnnData"
    )
else:
    # Name the parent dimensions the value is aligned to instead of printing
    # raw axis numbers.
    dims = tuple(("obs", "var")[ax] for ax in self.axes)
    msg = (
        f"Value passed for key {key!r} is of incorrect shape. "
        f"Values of {self.attrname} must match dimensions {dims} of parent. "
        f"Value had shape {actual_shape} while it should have had {right_shape}."
    )
raise ValueError(msg)

if not self._allow_df and isinstance(val, pd.DataFrame):
name = self.attrname.title().rstrip("s")
Expand All @@ -97,7 +101,7 @@ def attrname(self) -> str:

@property
@abstractmethod
def axes(self) -> tuple[int, ...]:
def axes(self) -> tuple[Literal[0, 1], ...]:
"""Which axes of the parent is this aligned to?"""
pass

Expand Down Expand Up @@ -222,7 +226,7 @@ def attrname(self) -> str:
return f"{self.dim}m"

@property
def axes(self) -> tuple[int]:
def axes(self) -> tuple[Literal[0, 1]]:
"""Axes of the parent this is aligned to"""
return (self._axis,)

Expand Down Expand Up @@ -256,7 +260,7 @@ def _validate_value(self, val: V, key: str) -> V:
try:
pd.testing.assert_index_equal(val.index, self.dim_names)
except AssertionError as e:
msg = f"value.index does not match parent’s axis {self.axes[0]} names:\n{e}"
msg = f"value.index does not match parent’s {self.dim} names:\n{e}"
raise ValueError(msg) from None
else:
msg = "Index.equals and pd.testing.assert_index_equal disagree"
Expand Down Expand Up @@ -357,7 +361,7 @@ def attrname(self) -> str:
return f"{self.dim}p"

@property
def axes(self) -> tuple[int, int]:
def axes(self) -> tuple[Literal[0], Literal[0]] | tuple[Literal[1], Literal[1]]:
"""Axes of the parent this is aligned to"""
return self._axis, self._axis

Expand Down
13 changes: 13 additions & 0 deletions anndata/tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,16 @@ def test_copy():
bdata = adata.copy()
adata.layers["L"] += 10
assert np.all(adata.layers["L"] != bdata.layers["L"]) # 201


def test_shape_error():
    """A layer with one observation too many is rejected with a descriptive error."""
    adata = AnnData(X=X)
    n_obs, n_var = X.shape[0], X.shape[1]
    expected_message = (
        r"Value passed for key 'L' is of incorrect shape\. "
        r"Values of layers must match dimensions \('obs', 'var'\) of parent\. "
        r"Value had shape \(4, 3\) while it should have had \(3, 3\)\."
    )
    with pytest.raises(ValueError, match=expected_message):
        adata.layers["L"] = np.zeros((n_obs + 1, n_var))
26 changes: 19 additions & 7 deletions anndata/tests/test_obsmvarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
from scipy import sparse

import anndata
from anndata import AnnData

M, N = (100, 100)

Expand All @@ -19,10 +19,10 @@ def adata():
index=[f"cell{i:03d}" for i in range(N)],
)
var = pd.DataFrame(index=[f"gene{i:03d}" for i in range(N)])
return anndata.AnnData(X, obs=obs, var=var)
return AnnData(X, obs=obs, var=var)


def test_assignment_dict(adata):
def test_assignment_dict(adata: AnnData):
d_obsm = dict(
a=pd.DataFrame(
dict(a1=np.ones(M), a2=[f"a{i}" for i in range(M)]),
Expand All @@ -45,7 +45,7 @@ def test_assignment_dict(adata):
assert np.all(adata.varm[k] == v)


def test_setting_ndarray(adata):
def test_setting_ndarray(adata: AnnData):
adata.obsm["a"] = np.ones((M, 10))
adata.varm["a"] = np.ones((N, 10))
assert np.all(adata.obsm["a"] == np.ones((M, 10)))
Expand All @@ -63,7 +63,7 @@ def test_setting_ndarray(adata):
assert h == joblib.hash(adata)


def test_setting_dataframe(adata):
def test_setting_dataframe(adata: AnnData):
obsm_df = pd.DataFrame(dict(b_1=np.ones(M), b_2=["a"] * M), index=adata.obs_names)
varm_df = pd.DataFrame(dict(b_1=np.ones(N), b_2=["a"] * N), index=adata.var_names)

Expand All @@ -83,7 +83,7 @@ def test_setting_dataframe(adata):
adata.varm["c"] = bad_varm_df


def test_setting_sparse(adata):
def test_setting_sparse(adata: AnnData):
obsm_sparse = sparse.random(M, 100)
adata.obsm["a"] = obsm_sparse
assert not np.any((adata.obsm["a"] != obsm_sparse).data)
Expand All @@ -105,7 +105,7 @@ def test_setting_sparse(adata):
assert h == joblib.hash(adata)


def test_setting_daskarray(adata):
def test_setting_daskarray(adata: AnnData):
import dask.array as da

adata.obsm["a"] = da.ones((M, 10))
Expand All @@ -125,3 +125,15 @@ def test_setting_daskarray(adata):
with pytest.raises(ValueError):
adata.varm["b"] = da.ones((int(N * 2), 10))
assert h == joblib.hash(adata)


def test_shape_error(adata: AnnData):
    """An ``obsm`` entry whose first dimension is one too long raises a clear error."""
    n_obs = adata.shape[0]
    expected_message = (
        r"Value passed for key 'b' is of incorrect shape\. "
        r"Values of obsm must match dimensions \('obs',\) of parent\. "
        r"Value had shape \(101,\) while it should have had \(100,\)\."
    )
    with pytest.raises(ValueError, match=expected_message):
        adata.obsm["b"] = np.zeros((n_obs + 1, n_obs))
26 changes: 19 additions & 7 deletions anndata/tests/test_obspvarp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
from scipy import sparse

import anndata
from anndata import AnnData
from anndata.tests.helpers import gen_typed_df_t2_size
from anndata.utils import asarray

Expand All @@ -24,10 +24,10 @@ def adata():
index=[f"cell{i:03d}" for i in range(M)],
)
var = pd.DataFrame(index=[f"gene{i:03d}" for i in range(N)])
return anndata.AnnData(X, obs=obs, var=var)
return AnnData(X, obs=obs, var=var)


def test_assigmnent_dict(adata):
def test_assigmnent_dict(adata: AnnData):
d_obsp = dict(
a=pd.DataFrame(np.ones((M, M)), columns=adata.obs_names, index=adata.obs_names),
b=np.zeros((M, M)),
Expand All @@ -46,7 +46,7 @@ def test_assigmnent_dict(adata):
assert np.all(asarray(adata.varp[k]) == asarray(v))


def test_setting_ndarray(adata):
def test_setting_ndarray(adata: AnnData):
adata.obsp["a"] = np.ones((M, M))
adata.varp["a"] = np.ones((N, N))
assert np.all(adata.obsp["a"] == np.ones((M, M)))
Expand All @@ -64,7 +64,7 @@ def test_setting_ndarray(adata):
assert h == joblib.hash(adata)


def test_setting_sparse(adata):
def test_setting_sparse(adata: AnnData):
obsp_sparse = sparse.random(M, M)
adata.obsp["a"] = obsp_sparse
assert not np.any((adata.obsp["a"] != obsp_sparse).data)
Expand Down Expand Up @@ -95,7 +95,7 @@ def test_setting_sparse(adata):
],
ids=["heterogeneous", "homogeneous"],
)
def test_setting_dataframe(adata, field, dim, homogenous, df, dtype):
def test_setting_dataframe(adata: AnnData, field, dim, homogenous, df, dtype):
if homogenous:
with pytest.warns(UserWarning, match=rf"{field.title()} 'df'.*dtype object"):
getattr(adata, field)["df"] = df(dim)
Expand All @@ -107,7 +107,7 @@ def test_setting_dataframe(adata, field, dim, homogenous, df, dtype):
assert np.issubdtype(getattr(adata, field)["df"].dtype, dtype)


def test_setting_daskarray(adata):
def test_setting_daskarray(adata: AnnData):
import dask.array as da

adata.obsp["a"] = da.ones((M, M))
Expand All @@ -127,3 +127,15 @@ def test_setting_daskarray(adata):
with pytest.raises(ValueError):
adata.varp["b"] = da.ones((N, int(N * 2)))
assert h == joblib.hash(adata)


def test_shape_error(adata: AnnData):
    """An ``obsp`` entry whose first dimension is one too long raises a clear error."""
    n_obs = adata.shape[0]
    expected_message = (
        r"Value passed for key 'a' is of incorrect shape\. "
        r"Values of obsp must match dimensions \('obs', 'obs'\) of parent\. "
        r"Value had shape \(201, 200\) while it should have had \(200, 200\)\."
    )
    with pytest.raises(ValueError, match=expected_message):
        adata.obsp["a"] = np.zeros((n_obs + 1, n_obs))
1 change: 1 addition & 0 deletions docs/release-notes/0.10.4.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

```{rubric} Documentation
```
* Improve aligned mapping error messages {pr}`1252` {user}`flying-sheep`

```{rubric} Performance
```

0 comments on commit 49ca3bd

Please sign in to comment.