diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 4d41eb70..150e7869 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1055,14 +1055,9 @@ def shape(self): ``` network.shape = (num_cells, num_branches, num_compartments) - cell.shape = (num_branches, num_compartments) - branch.shape = (num_compartments,) + cell.shape = (1, num_branches, num_compartments) + branch.shape = (1, 1, num_compartments,) ```""" - mod_name = self.__class__.__name__.lower() - if "comp" in mod_name: - return (1,) - elif "branch" in mod_name: - return self[:].shape[1:] return self[:].shape def _childview(self, index: Union[int, str, list, range, slice]): @@ -1087,8 +1082,7 @@ def __getitem__(self, index): return self._childview(index) def __iter__(self): - for i in range(self.shape[0]): - yield self[i] + raise NotImplementedError def _local_inds_to_global( self, cell_inds: np.ndarray, branch_inds: np.ndarray, comp_inds: np.ndarray @@ -1293,8 +1287,7 @@ def __getitem__(self, index): return self._childview(index) def __iter__(self): - for i in range(self.shape[0]): - yield self[i] + raise NotImplementedError def rotate(self, degrees: float, rotation_axis: str = "xy"): """Rotate jaxley modules clockwise. Used only for visualization. diff --git a/jaxley/modules/branch.py b/jaxley/modules/branch.py index c006765d..464a24ae 100644 --- a/jaxley/modules/branch.py +++ b/jaxley/modules/branch.py @@ -82,6 +82,10 @@ def __getattr__(self, key): else: raise KeyError(f"Key {key} not recognized.") + def __iter__(self): + for i in range(self.shape[2]): + yield self[i] + def init_conds(self, params): conds = self.init_branch_conds( params["axial_resistivity"], params["radius"], params["length"], self.nseg @@ -142,6 +146,10 @@ def __call__(self, index: float): new_view = super().adjust_view("branch_index", index) return new_view + def __iter__(self): + for i in range(self.shape[2]): + yield self[i] + def __getattr__(self, key): assert key in ["comp", "loc"] compview = CompartmentView(self.pointer, self.view) diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 7f901bd4..85058f81 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -115,6 +115,10 @@ def __getattr__(self, key): else: raise KeyError(f"Key {key} not recognized.") + def __iter__(self): + for i in range(self.shape[1]): + yield self[i] + def init_morph(self): """Initialize morphology.""" parents = self.comb_parents @@ -248,6 +252,10 @@ def __call__(self, index: float): def __getattr__(self, key): assert key == "branch" return BranchView(self.pointer, self.view) + + def __iter__(self): + for i in range(self.shape[1]): + yield self[i] def rotate(self, degrees: float, rotation_axis: str = "xy"): """Rotate jaxley modules clockwise. Used only for visualization. diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index cc1471e0..175bccbf 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -111,6 +111,10 @@ def __getattr__(self, key): else: raise KeyError(f"Key {key} not recognized.") + def __iter__(self): + for i in range(self.shape[0]): + yield self[i] + def init_morph(self): self.nbranches_per_cell = [cell.total_nbranches for cell in self.cells] self.total_nbranches = sum(self.nbranches_per_cell)