Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat): refactor base AnnData class to use a AnnDataBase abstract class #949

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 63 additions & 83 deletions anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .index import _normalize_indices, _subset, Index, Index1D, get_vector
from .file_backing import AnnDataFileManager, to_memory
from .access import ElementRef
from .anndata_base import AnnDataBase
from .aligned_mapping import (
AxisArrays,
AxisArraysView,
Expand Down Expand Up @@ -125,7 +126,7 @@ def _(anno, length, index_names):
raise ValueError(f"Cannot convert {type(anno)} to DataFrame")


class AnnData(metaclass=utils.DeprecationMixinMeta):
class AnnData(AnnDataBase):
"""\
An annotated data matrix.

Expand Down Expand Up @@ -278,7 +279,7 @@ def __init__(
vidx: Index1D = None,
):
if asview:
if not isinstance(X, AnnData):
if not issubclass(type(X), AnnData):
raise ValueError("`X` has to be an AnnData object.")
self._init_as_view(X, oidx, vidx)
else:
Expand Down Expand Up @@ -357,74 +358,44 @@ def _init_as_view(self, adata_ref: "AnnData", oidx: Index, vidx: Index):
else:
self._raw = None

def _init_as_actual(
self,
X=None,
obs=None,
var=None,
uns=None,
obsm=None,
varm=None,
varp=None,
obsp=None,
raw=None,
layers=None,
dtype=None,
shape=None,
filename=None,
filemode=None,
def _reformat_axes_args_from_X(
self, X, obs, var, uns, obsm, varm, obsp, varp, layers, raw
):
# view attributes
self._is_view = False
self._adata_ref = None
self._oidx = None
self._vidx = None

# ----------------------------------------------------------------------
# various ways of initializing the data
# ----------------------------------------------------------------------

# If X is a data frame, we store its indices for verification
x_indices = []

# init from file
if filename is not None:
self.file = AnnDataFileManager(self, filename, filemode)
else:
self.file = AnnDataFileManager(self, None)

# init from AnnData
if isinstance(X, AnnData):
if any((obs, var, uns, obsm, varm, obsp, varp)):
raise ValueError(
"If `X` is a dict no further arguments must be provided."
)
X, obs, var, uns, obsm, varm, obsp, varp, layers, raw = (
X._X,
X.obs,
X.var,
X.uns,
X.obsm,
X.varm,
X.obsp,
X.varp,
X.layers,
X.raw,
# init from AnnData
if isinstance(X, AnnData):
if any((obs, var, uns, obsm, varm, obsp, varp)):
raise ValueError(
"If `X` is a dict no further arguments must be provided."
)
X, obs, var, uns, obsm, varm, obsp, varp, layers, raw = (
X._X,
X.obs,
X.var,
X.uns,
X.obsm,
X.varm,
X.obsp,
X.varp,
X.layers,
X.raw,
)

# init from DataFrame
elif isinstance(X, pd.DataFrame):
# to verify index matching, we wait until obs and var are DataFrames
if obs is None:
obs = pd.DataFrame(index=X.index)
elif not isinstance(X.index, pd.RangeIndex):
x_indices.append(("obs", "index", X.index))
if var is None:
var = pd.DataFrame(index=X.columns)
elif not isinstance(X.columns, pd.RangeIndex):
x_indices.append(("var", "columns", X.columns))
X = ensure_df_homogeneous(X, "X")

# init from DataFrame
elif isinstance(X, pd.DataFrame):
# to verify index matching, we wait until obs and var are DataFrames
if obs is None:
obs = pd.DataFrame(index=X.index)
elif not isinstance(X.index, pd.RangeIndex):
x_indices.append(("obs", "index", X.index))
if var is None:
var = pd.DataFrame(index=X.columns)
elif not isinstance(X.columns, pd.RangeIndex):
x_indices.append(("var", "columns", X.columns))
X = ensure_df_homogeneous(X, "X")
return (X, obs, var, uns, obsm, varm, obsp, varp, layers, raw, x_indices)

def _assign_X(self, X, shape, dtype):
# ----------------------------------------------------------------------
# actually process the data
# ----------------------------------------------------------------------
Expand Down Expand Up @@ -459,9 +430,18 @@ def _init_as_actual(
X = np.array(X, dtype, copy=False)
# data matrix and shape
self._X = X
self._n_obs, self._n_vars = self._X.shape
else:
self._X = None

def _initialize_indices(self, shape, obs, var):
# ----------------------------------------------------------------------
# actually process the data
# ----------------------------------------------------------------------

# check data type of X
if self._X is not None:
self._n_obs, self._n_vars = self._X.shape
else:
self._n_obs = len([] if obs is None else obs)
self._n_vars = len([] if var is None else var)
# check consistency with shape
Expand All @@ -477,34 +457,38 @@ def _init_as_actual(
if self._n_vars != shape[1]:
raise ValueError("`shape` is inconsistent with `var`")

# annotations
# annotations
def _assign_obs(self, obs):
self._obs = _gen_dataframe(obs, self._n_obs, ["obs_names", "row_names"])
self._var = _gen_dataframe(var, self._n_vars, ["var_names", "col_names"])

# now we can verify if indices match!
for attr_name, x_name, idx in x_indices:
attr = getattr(self, attr_name)
if isinstance(attr.index, pd.RangeIndex):
attr.index = idx
elif not idx.equals(attr.index):
raise ValueError(f"Index of {attr_name} must match {x_name} of X.")
def _assign_var(self, var):
self._var = _gen_dataframe(var, self._n_vars, ["var_names", "col_names"])

# unstructured annotations
# unstructured annotations
def _assign_uns(self, uns):
self.uns = uns or OrderedDict()

# TODO: Think about consequences of making obsm a group in hdf
# TODO: Think about consequences of making obsm a group in hdf
def _assign_obsm(self, obsm):
self._obsm = AxisArrays(self, 0, vals=convert_to_dict(obsm))

def _assign_varm(self, varm):
self._varm = AxisArrays(self, 1, vals=convert_to_dict(varm))

def _assign_obsp(self, obsp):
self._obsp = PairwiseArrays(self, 0, vals=convert_to_dict(obsp))

def _assign_varp(self, varp):
self._varp = PairwiseArrays(self, 1, vals=convert_to_dict(varp))

def _run_checks(self):
# Backwards compat for connectivities matrices in uns["neighbors"]
_move_adj_mtx({"uns": self._uns, "obsp": self._obsp})

self._check_dimensions()
self._check_uniqueness()

def _cleanup_raw_and_uns(self, raw, uns):
if self.filename:
assert not isinstance(
raw, Raw
Expand All @@ -521,6 +505,7 @@ def _init_as_actual(
# clean up old formats
self._clean_up_old_format(uns)

def _assign_layers(self, layers):
# layers
self._layers = Layers(self, layers)

Expand Down Expand Up @@ -581,11 +566,6 @@ def __eq__(self, other):
"instead compare the desired attributes."
)

@property
def shape(self) -> Tuple[int, int]:
"""Shape of data matrix (:attr:`n_obs`, :attr:`n_vars`)."""
return self.n_obs, self.n_vars

@property
def X(self) -> Optional[Union[np.ndarray, sparse.spmatrix, ArrayView]]:
"""Data matrix of shape :attr:`n_obs` × :attr:`n_vars`."""
Expand Down
Loading