Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix __iter__ and .shape #330

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 3 additions & 26 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
flip_comp_indices,
interpolate_xyz,
loc_of_index,
get_local_indices
)
from jaxley.utils.plot_utils import plot_morph

Expand Down Expand Up @@ -1249,29 +1250,6 @@ def adjust_view(self, key: str, index: Union[int, str, list, range, slice]):
self.view["controlled_by_param"] -= self.view["controlled_by_param"].iloc[0]
return self

def _get_local_indices(self):
"""Computes local from global indices.

#cell_index, branch_index, comp_index
0, 0, 0 --> 0, 0, 0 # 1st compartment of 1st branch of 1st cell
0, 0, 1 --> 0, 0, 1 # 2nd compartment of 1st branch of 1st cell
0, 1, 2 --> 0, 1, 0 # 1st compartment of 2nd branch of 1st cell
0, 1, 3 --> 0, 1, 1 # 2nd compartment of 2nd branch of 1st cell
1, 2, 4 --> 1, 0, 0 # 1st compartment of 1st branch of 2nd cell
1, 2, 5 --> 1, 0, 1 # 2nd compartment of 1st branch of 2nd cell
1, 3, 6 --> 1, 1, 0 # 1st compartment of 2nd branch of 2nd cell
1, 3, 7 --> 1, 1, 1 # 2nd compartment of 2nd branch of 2nd cell
"""

def reindex_a_by_b(df, a, b):
df.loc[:, a] = df.groupby(b)[a].rank(method="dense").astype(int) - 1
return df

idcs = self.view[["cell_index", "branch_index", "comp_index"]]
idcs = reindex_a_by_b(idcs, "branch_index", "cell_index")
idcs = reindex_a_by_b(idcs, "comp_index", ["cell_index", "branch_index"])
return idcs

def _childview(self, index: Union[int, str, list, range, slice]):
"""Return the child view of the current view.

Expand All @@ -1293,7 +1271,7 @@ def __getitem__(self, index):
return self._childview(index)

def __iter__(self):
for i in range(self.shape[0]):
for i in range(self.shape[1]):
yield self[i]

def rotate(self, degrees: float, rotation_axis: str = "xy"):
Expand All @@ -1309,8 +1287,7 @@ def rotate(self, degrees: float, rotation_axis: str = "xy"):

@property
def shape(self):
local_idcs = self._get_local_indices()
return tuple(local_idcs.nunique())
raise NotImplementedError

def _append_multiple_synapses(
self, pre_rows: pd.DataFrame, post_rows: pd.DataFrame, synapse_type: Synapse
Expand Down
14 changes: 12 additions & 2 deletions jaxley/modules/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from jaxley.modules.base import GroupView, Module, View
from jaxley.modules.compartment import Compartment, CompartmentView
from jaxley.utils.cell_utils import compute_coupling_cond
from jaxley.utils.cell_utils import compute_coupling_cond, get_local_indices


class Branch(Module):
Expand Down Expand Up @@ -82,6 +82,11 @@ def __getattr__(self, key):
else:
raise KeyError(f"Key {key} not recognized.")

@property
def shape(self):
local_idcs = get_local_indices(self.nodes)
return tuple(local_idcs.nunique())[2:]

def init_conds(self, params):
conds = self.init_branch_conds(
params["axial_resistivity"], params["radius"], params["length"], self.nseg
Expand Down Expand Up @@ -134,7 +139,7 @@ def __init__(self, pointer, view):
super().__init__(pointer, view)

def __call__(self, index: float):
local_idcs = self._get_local_indices()
local_idcs = get_local_indices(self.view)
self.view[local_idcs.columns] = (
local_idcs # set indexes locally. enables net[0:2,0:2]
)
Expand All @@ -146,3 +151,8 @@ def __getattr__(self, key):
assert key in ["comp", "loc"]
compview = CompartmentView(self.pointer, self.view)
return compview if key == "comp" else compview.loc

@property
def shape(self):
local_idcs = get_local_indices(self.view)
return tuple(local_idcs.nunique())[1:]
13 changes: 12 additions & 1 deletion jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
compute_coupling_cond,
compute_levels,
loc_of_index,
get_local_indices
)
from jaxley.utils.swc import swc_to_jaxley

Expand Down Expand Up @@ -115,6 +116,11 @@ def __getattr__(self, key):
else:
raise KeyError(f"Key {key} not recognized.")

@property
def shape(self):
local_idcs = get_local_indices(self.nodes)
return tuple(local_idcs.nunique())[1:]

def init_morph(self):
"""Initialize morphology."""
parents = self.comb_parents
Expand Down Expand Up @@ -236,7 +242,7 @@ def __init__(self, pointer, view):
super().__init__(pointer, view)

def __call__(self, index: float):
local_idcs = self._get_local_indices()
local_idcs = get_local_indices(self.view)
self.view[local_idcs.columns] = (
local_idcs # set indexes locally. enables net[0:2,0:2]
)
Expand All @@ -249,6 +255,11 @@ def __getattr__(self, key):
assert key == "branch"
return BranchView(self.pointer, self.view)

@property
def shape(self):
local_idcs = get_local_indices(self.view)
return tuple(local_idcs.nunique())

def rotate(self, degrees: float, rotation_axis: str = "xy"):
"""Rotate jaxley modules clockwise. Used only for visualization.

Expand Down
11 changes: 10 additions & 1 deletion jaxley/modules/compartment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from matplotlib.axes import Axes

from jaxley.modules.base import Module, View
from jaxley.utils.cell_utils import index_of_loc, interpolate_xyz, loc_of_index
from jaxley.utils.cell_utils import index_of_loc, interpolate_xyz, loc_of_index, get_local_indices


class Compartment(Module):
Expand Down Expand Up @@ -45,6 +45,10 @@ def __init__(self):
# Coordinates.
self.xyzr = [float("NaN") * np.zeros((2, 4))]

@property
def shape(self):
return ()

def init_conds(self, params):
cond_params = {
"branch_conds_fwd": jnp.asarray([]),
Expand Down Expand Up @@ -72,6 +76,11 @@ def __call__(self, index: int):
"'CompartmentView' object has no attribute 'comp' or 'loc'."
)

@property
def shape(self):
local_idcs = get_local_indices(self.view)
return tuple(local_idcs.nunique())[2:]

def loc(self, loc: float):
if loc != "all":
assert (
Expand Down
6 changes: 6 additions & 0 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
convert_point_process_to_distributed,
flip_comp_indices,
merge_cells,
get_local_indices,
)
from jaxley.utils.syn_utils import gather_synapes, prepare_syn

Expand Down Expand Up @@ -111,6 +112,11 @@ def __getattr__(self, key):
else:
raise KeyError(f"Key {key} not recognized.")

@property
def shape(self):
local_idcs = get_local_indices(self.nodes)
return tuple(local_idcs.nunique())

def init_morph(self):
self.nbranches_per_cell = [cell.total_nbranches for cell in self.cells]
self.total_nbranches = sum(self.nbranches_per_cell)
Expand Down
24 changes: 24 additions & 0 deletions jaxley/utils/cell_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,27 @@ def convert_point_process_to_distributed(
area = 2 * pi * radius * length
current /= area # nA / um^2
return current * 100_000 # Convert (nA / um^2) to (uA / cm^2)


def get_local_indices(view):
"""Computes local from global indices.

#cell_index, branch_index, comp_index
0, 0, 0 --> 0, 0, 0 # 1st compartment of 1st branch of 1st cell
0, 0, 1 --> 0, 0, 1 # 2nd compartment of 1st branch of 1st cell
0, 1, 2 --> 0, 1, 0 # 1st compartment of 2nd branch of 1st cell
0, 1, 3 --> 0, 1, 1 # 2nd compartment of 2nd branch of 1st cell
1, 2, 4 --> 1, 0, 0 # 1st compartment of 1st branch of 2nd cell
1, 2, 5 --> 1, 0, 1 # 2nd compartment of 1st branch of 2nd cell
1, 3, 6 --> 1, 1, 0 # 1st compartment of 2nd branch of 2nd cell
1, 3, 7 --> 1, 1, 1 # 2nd compartment of 2nd branch of 2nd cell
"""

def reindex_a_by_b(df, a, b):
df.loc[:, a] = df.groupby(b)[a].rank(method="dense").astype(int) - 1
return df

idcs = view[["cell_index", "branch_index", "comp_index"]]
idcs = reindex_a_by_b(idcs, "branch_index", "cell_index")
idcs = reindex_a_by_b(idcs, "comp_index", ["cell_index", "branch_index"])
return idcs
Loading