Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] fix __iter__ #329

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,14 +1055,9 @@ def shape(self):

```
network.shape = (num_cells, num_branches, num_compartments)
cell.shape = (num_branches, num_compartments)
branch.shape = (num_compartments,)
cell.shape = (1, num_branches, num_compartments)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not so sure if we should assign dim=1 to cell if you technically have no way to index into it (it is the object being indexed into).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency, net would also have to have shape=(1,...), which would be odd imo.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be added to #43 when settled.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, sorry, I had not meant to request a review yet...

Anyways, __iter__ is broken right now, see the PR description. But let's just talk in person this is not an urgent PR

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Went over the recent PRs and thought I'd just do a pass of this one as well, even though you did not request it.

Alright sounds good.

branch.shape = (1, 1, num_compartments,)
```"""
mod_name = self.__class__.__name__.lower()
if "comp" in mod_name:
return (1,)
elif "branch" in mod_name:
return self[:].shape[1:]
return self[:].shape

def _childview(self, index: Union[int, str, list, range, slice]):
Expand All @@ -1087,8 +1082,7 @@ def __getitem__(self, index):
return self._childview(index)

def __iter__(self):
for i in range(self.shape[0]):
yield self[i]
raise NotImplementedError

def _local_inds_to_global(
self, cell_inds: np.ndarray, branch_inds: np.ndarray, comp_inds: np.ndarray
Expand Down Expand Up @@ -1293,8 +1287,7 @@ def __getitem__(self, index):
return self._childview(index)

def __iter__(self):
for i in range(self.shape[0]):
yield self[i]
raise NotImplementedError

def rotate(self, degrees: float, rotation_axis: str = "xy"):
"""Rotate jaxley modules clockwise. Used only for visualization.
Expand Down
8 changes: 8 additions & 0 deletions jaxley/modules/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def __getattr__(self, key):
else:
raise KeyError(f"Key {key} not recognized.")

def __iter__(self):
for i in range(self.shape[2]):
yield self[i]

def init_conds(self, params):
conds = self.init_branch_conds(
params["axial_resistivity"], params["radius"], params["length"], self.nseg
Expand Down Expand Up @@ -142,6 +146,10 @@ def __call__(self, index: float):
new_view = super().adjust_view("branch_index", index)
return new_view

def __iter__(self):
for i in range(self.shape[2]):
yield self[i]

def __getattr__(self, key):
assert key in ["comp", "loc"]
compview = CompartmentView(self.pointer, self.view)
Expand Down
8 changes: 8 additions & 0 deletions jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ def __getattr__(self, key):
else:
raise KeyError(f"Key {key} not recognized.")

def __iter__(self):
for i in range(self.shape[1]):
yield self[i]

def init_morph(self):
"""Initialize morphology."""
parents = self.comb_parents
Expand Down Expand Up @@ -248,6 +252,10 @@ def __call__(self, index: float):
def __getattr__(self, key):
assert key == "branch"
return BranchView(self.pointer, self.view)

def __iter__(self):
for i in range(self.shape[1]):
yield self[i]

def rotate(self, degrees: float, rotation_axis: str = "xy"):
"""Rotate jaxley modules clockwise. Used only for visualization.
Expand Down
4 changes: 4 additions & 0 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ def __getattr__(self, key):
else:
raise KeyError(f"Key {key} not recognized.")

def __iter__(self):
for i in range(self.shape[0]):
yield self[i]

def init_morph(self):
self.nbranches_per_cell = [cell.total_nbranches for cell in self.cells]
self.total_nbranches = sum(self.nbranches_per_cell)
Expand Down
Loading