Skip to content

Commit

Permalink
Add support for boolean indexing of cells/branches/comps (#494)
Browse files Browse the repository at this point in the history
* enh: add support for boolean indexing of cells/branches/comps

* fix: fix order

* fix: make tests pass

* fix: add more tests and support for edge select
  • Loading branch information
jnsbck authored Nov 14, 2024
1 parent 175758d commit ab50924
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 13 deletions.
23 changes: 16 additions & 7 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,10 @@ def _childviews(self) -> List[str]:
I.e. for net -> [cell, branch, comp]. For branch -> [comp]"""
levels = ["network", "cell", "branch", "comp"]
children = levels[levels.index(self._current_view) + 1 :]
return children
if self._current_view in levels:
children = levels[levels.index(self._current_view) + 1 :]
return children
return []

def _has_childview(self, key: str) -> bool:
child_views = self._childviews()
Expand Down Expand Up @@ -383,16 +385,23 @@ def _reformat_index(self, idx: Any, dtype: type = int) -> np.ndarray:
Returns:
array of indices of shape (N,)"""
if is_str_all(idx): # also asserts that the only allowed str == "all"
return idx

np_dtype = np.int64 if dtype is int else np.float64
idx = np.array([], dtype=dtype) if idx is None else idx
idx = np.array([idx]) if isinstance(idx, (dtype, np_dtype)) else idx
idx = np.array(idx) if isinstance(idx, (list, range, pd.Index)) else idx
num_nodes = len(self._nodes_in_view)
idx = np.arange(num_nodes + 1)[idx] if isinstance(idx, slice) else idx
if is_str_all(idx): # also asserts that the only allowed str == "all"
return idx

idx = np.arange(len(self.base.nodes))[idx] if isinstance(idx, slice) else idx
if idx.dtype == bool:
shape = (*self.shape, len(self.edges))
which_idx = len(idx) == np.array(shape)
assert np.any(which_idx), "Index not matching num of cells/branches/comps."
dim = shape[np.where(which_idx)[0][0]]
idx = np.arange(dim)[idx]
assert isinstance(idx, np.ndarray), "Invalid type"
assert idx.dtype == np_dtype, "Invalid dtype"
assert idx.dtype in [np_dtype, bool], "Invalid dtype"
return idx.reshape(-1)

def _set_controlled_by_param(self, key: str):
Expand Down
26 changes: 20 additions & 6 deletions tests/test_viewing.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,10 @@ def test_view_attrs(module: jx.Compartment | jx.Branch | jx.Cell | jx.Network):


comp = jx.Compartment()
branch = jx.Branch([comp] * 4)
cell = jx.Cell([branch] * 4, parents=[-1, 0, 0, 0])
net = jx.Network([cell] * 4)
branch = jx.Branch(nseg=4)
cell = jx.Cell([branch] * 5, parents=[-1, 0, 0, 1, 1])
net = jx.Network([cell] * 2)
connect(net[0, 0, :], net[1, 0, :], TestSynapse())


@pytest.mark.parametrize("module", [comp, branch, cell, net])
Expand All @@ -317,25 +318,38 @@ def test_view_supported_index_types(module):
[0, 1, 2],
np.array([0, 1, 2]),
pd.Index([0, 1, 2]),
np.array([True, False, True, False] * 100)[: len(module.nodes)],
]

# comp.comp is not allowed
all_inds = module.nodes.index.to_numpy()
if not isinstance(module, jx.Compartment):
# `_reformat_index` should always return a np.ndarray
for index in index_types:
assert isinstance(
module._reformat_index(index), np.ndarray
), f"Failed for {type(index)}"

# test indexing into module and view
assert module.comp(index), f"Failed for {type(index)}"
assert View(module).comp(index), f"Failed for {type(index)}"

# for loc test float and list of floats
assert module.loc(0.0), "Failed for float"
assert module.loc([0.0, 0.5, 1.0]), "Failed for List[float]"
expected_inds = all_inds[index]
assert np.all(module.select(nodes=index).nodes.index == expected_inds)

# for loc test float and list of floats
assert module.loc(0.0), "Failed for float"
assert module.loc([0.0, 0.5, 1.0]), "Failed for List[float]"
else:
with pytest.raises(AssertionError):
module.comp(0)

if isinstance(module, jx.Network):
all_inds = module.edges.index.to_numpy()
for index in index_types[:-1] + [np.array([True, False, True, False])]:
expected_inds = all_inds[index]
assert np.all(net.select(edges=index).edges.index == expected_inds)


def test_select():
"""Ensure `select` works correctly and returns expected View of Modules."""
Expand Down

0 comments on commit ab50924

Please sign in to comment.