Skip to content

Commit

Permalink
wip: tests are passing. except for voltages, which only passes in not…
Browse files Browse the repository at this point in the history
…ebook but not in pytest
  • Loading branch information
jnsbck committed Aug 27, 2024
1 parent 41685e1 commit 0bb338c
Show file tree
Hide file tree
Showing 7 changed files with 291 additions and 846 deletions.
37 changes: 27 additions & 10 deletions jaxley/io/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ def from_graph(
nseg: int = 4,
max_branch_len: float = 2000.0,
assign_groups: bool = True,
ignore_swc_trace_errors: bool = True,
) -> Union[jx.Network, jx.Cell, jx.Branch, jx.Compartment]:
"""Build a module from a networkx graph.
Expand Down Expand Up @@ -633,6 +634,8 @@ def from_graph(
assigned yet.
assign_groups: Wether to assign groups to nodes based on the the id or groups
attribute.
ignore_swc_trace_errors: Whether to ignore discontinuities in the swc tracing
order. If False, this will result in split branches at these points.
Returns:
A module instance that is populated with the node and egde attributes of
Expand All @@ -645,12 +648,20 @@ def from_graph(
if "type" not in graph.graph:
try:
graph = make_jaxley_compatible(
graph, nseg=nseg, max_branch_len=max_branch_len
graph,
nseg=nseg,
max_branch_len=max_branch_len,
ignore_swc_trace_errors=ignore_swc_trace_errors,
)
except:
raise Exception("Graph appears to be incompatible with jaxley.")
elif graph.graph["type"] == "swc":
graph = make_jaxley_compatible(graph, nseg=nseg, max_branch_len=max_branch_len)
graph = make_jaxley_compatible(
graph,
nseg=nseg,
max_branch_len=max_branch_len,
ignore_swc_trace_errors=ignore_swc_trace_errors,
)

#################################
### Import graph as jx.Module ###
Expand Down Expand Up @@ -692,19 +703,21 @@ def from_graph(
acc_parents.append([-1] + parents.tolist())

# drop special attrs from nodes and ignore error if col does not exist
optional_attrs = ["recordings", "externals", "groups", "trainable"]
# x,y,z can be re-computed from xyzr if needed
optional_attrs = ["recordings", "externals", "groups", "trainable", "x", "y", "z"]
nodes.drop(columns=optional_attrs, inplace=True, errors="ignore")

# build module
idxs = nodes[["cell_index", "branch_index", "comp_index"]]
module = build_module_scaffold(idxs, graph.graph["type"], acc_parents)

# set global attributes of module
graph.graph.pop("type")
for k, v in graph.graph.items():
setattr(module, k, v)

module.nodes[nodes.columns] = nodes # set column-wise. preserves cols not in nodes.
module.edges = synapse_edges.T
module.edges = synapse_edges.T if not synapse_edges.empty else module.edges

module.membrane_current_names = [c.current_name for c in module.channels]
module.synapse_names = [s._name for s in module.synapses]
Expand All @@ -727,8 +740,12 @@ def from_graph(
cached_external_inds = {}
cached_externals = {}
for key, data in externals.items():
cached_externals[key] = np.stack(data[~data.isna()].explode().values)
cached_external_inds[key] = data[~data.isna()].explode().index.to_numpy()
cached_externals[key] = jnp.array(
np.stack(data[~data.isna()].explode().values)
)
cached_external_inds[key] = jnp.array(
data[~data.isna()].explode().index.to_numpy()
)
module.externals = cached_externals
module.external_inds = cached_external_inds

Expand Down Expand Up @@ -773,10 +790,10 @@ def from_graph(
if not groups.empty and assign_groups:
groups = groups.explode(1).rename(columns={0: "index", 1: "group"})
groups = groups[groups["group"] != "undefined"] # skip undefined comps
group_nodes = {k: nodes.loc[v["index"]] for k, v in groups.groupby("group")}
# module[:] ensure group nodes in module reflect what is shown in view
group_nodes = {
k: module[:].view.loc[v["index"]] for k, v in groups.groupby("group")
}
module.group_nodes = group_nodes
# update group nodes in module to reflect what is shown in view
for group, nodes in module.group_nodes.items():
module.group_nodes[group] = module.__getattr__(group).view

return module
3 changes: 1 addition & 2 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,8 @@ def __init__(self):
def _update_nodes_with_xyz(self):
"""Add xyz coordinates to nodes."""
loc = np.linspace(0.5 / self.nseg, 1 - 0.5 / self.nseg, self.nseg)
jit_interp = jit(interpolate_xyz)
xyz = (
[jit_interp(loc, xyzr).T for xyzr in self.xyzr]
[interpolate_xyz(loc, xyzr).T for xyzr in self.xyzr]
if len(loc) > 0
else [self.xyzr]
)
Expand Down
12 changes: 6 additions & 6 deletions jaxley/utils/cell_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ def remap_to_consecutive(arr):
return inverse_indices


v_interp = vmap(jnp.interp, in_axes=(None, None, 1))


def interpolate_xyz(loc: float, coords: np.ndarray):
"""Perform a linear interpolation between xyz-coordinates.
Expand All @@ -288,9 +291,9 @@ def interpolate_xyz(loc: float, coords: np.ndarray):
Return:
Interpolated xyz coordinate at `loc`, shape `(3,).
"""
return vmap(lambda x: jnp.interp(loc, jnp.linspace(0, 1, len(x)), x), in_axes=(1,))(
coords[:, :3]
)
lens = np.cumsum(np.sqrt(np.sum(np.diff(coords[:, :3], axis=0) ** 2, axis=1)))
lens = np.insert(lens, 0, 0)
return v_interp(loc * lens[-1], lens, coords[:, :3])


def params_to_pstate(
Expand Down Expand Up @@ -388,6 +391,3 @@ def group_and_sum(
group_sums = group_sums.at[inds_to_group_by].add(values_to_sum)

return group_sums


v_interp = vmap(jnp.interp, in_axes=(None, None, 1))
58 changes: 39 additions & 19 deletions jaxley/utils/misc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,41 +33,61 @@ def childview(
raise AttributeError("Compartment does not support indexing")


def recursive_compare(a, b):
if isinstance(a, (int, float)):
def recursive_compare(a, b, verbose=False):
def verbose_comp(a, b, type):
if verbose:
print(f"{type} {a} and {b} are not equal.")
return False

if type(a) != type(b):
return verbose_comp(a, b, "Types")

if isinstance(a, (float, int)):
if abs(a - b) > 1e-5 and not (np.isnan(a) and np.isnan(b)):
return False
return verbose_comp(a, b, "Floats/Ints")

elif isinstance(a, str):
if a != b:
return False
return verbose_comp(a, b, "Strings")

elif isinstance(a, (np.ndarray, jnp.ndarray)):
if a.size > 1:
for i in range(len(a)):
if not recursive_compare(a[i], b[i]):
return False
else:
if not recursive_compare(a.item(), b.item()):
return False
if a.dtype.kind in "biufc": # is numeric
if not np.allclose(a, b, equal_nan=True):
return verbose_comp(a, b, "Arrays")
elif not np.all(a == b):
return verbose_comp(a, b, "Arrays")

elif isinstance(a, (list, tuple)):
if len(a) != len(b):
return False
return verbose_comp(a, b, "Lists/Tuples")

for i in range(len(a)):
if not recursive_compare(a[i], b[i]):
return False
return verbose_comp(a[i], b[i], "Lists/Tuples")

elif isinstance(a, dict):
if len(a) != len(b) and len(a) != 0:
return False
return verbose_comp(a, b, "Dicts")

if set(a.keys()) != set(b.keys()):
return False
return verbose_comp(a, b, "Dicts")

for k in a.keys():
if not recursive_compare(a[k], b[k]):
return False
return verbose_comp(a[k], b[k], "Dicts")

elif isinstance(a, pd.DataFrame):
if not recursive_compare(a.to_dict(), b.to_dict()):
return False
return verbose_comp(a, b, "DataFrames")

elif a is None or b is None:
if not (a is None and b is None):
return False
return verbose_comp(a, b, "None")
else:
raise ValueError(f"Type {type(a)} not supported")
try:
if not a == b:
return verbose_comp(a, b, "Other")

except AttributeError:
raise ValueError(f"Type {type(a)} not supported")
return True
22 changes: 11 additions & 11 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from neuron import h
import numpy as np
import pandas as pd

Expand Down Expand Up @@ -38,14 +37,14 @@ def get_segment_xyzrL(section, comp_idx=None, loc=None, nseg=8):
return x, y, z, r, L[-1]/nseg


def jaxley2neuron_by_coords(jx_cell, neuron_secs, branch_loc=0.05, nseg=8):
neuron_coords = {i: np.vstack(get_segment_xyzrL(sec, branch_loc, nseg=nseg))[:,:3].T for i, sec in enumerate(neuron_secs)}
def jaxley2neuron_by_coords(jx_cell, neuron_secs, comp_idx=None, loc=None, nseg=8):
neuron_coords = {i: np.vstack(get_segment_xyzrL(sec, comp_idx=comp_idx, loc=loc, nseg=nseg))[:3].T for i, sec in enumerate(neuron_secs)}
neuron_coords = np.vstack([np.hstack([k*np.ones((v.shape[0], 1)), v]) for k,v in neuron_coords.items()])
neuron_coords = pd.DataFrame(neuron_coords, columns=["branch_index", "x", "y", "z"])
neuron_coords["branch_index"] = neuron_coords["branch_index"].astype(int)

neuron_loc_xyz = neuron_coords.groupby("branch_index").mean()
jaxley_loc_xyz = jx_cell.branch("all").loc(branch_loc).show().set_index("branch_index")[["x", "y", "z"]]
jaxley_loc_xyz = jx_cell.branch("all").loc(loc).show().set_index("branch_index")[["x", "y", "z"]]

jaxley2neuron_inds = {}
for i,xyz in enumerate(jaxley_loc_xyz.to_numpy()):
Expand All @@ -54,13 +53,13 @@ def jaxley2neuron_by_coords(jx_cell, neuron_secs, branch_loc=0.05, nseg=8):
return jaxley2neuron_inds


def jaxley2neuron_by_group(cell, neuron_secs, branch_loc=0.05, nseg=8, num_apical=20, num_tuft=20, num_basal=10):
y_apical = cell.apical.show().groupby("branch_index").mean()["y"].abs().sort_values()
def jaxley2neuron_by_group(jx_cell, neuron_secs, comp_idx=None, loc=None, nseg=8, num_apical=20, num_tuft=20, num_basal=10):
y_apical = jx_cell.apical.show().groupby("branch_index").mean()["y"].abs().sort_values()
trunk_inds = y_apical.index[:num_apical].tolist()
tuft_inds = y_apical.index[-num_tuft:].tolist()
basal_inds = cell.basal.show()["branch_index"].unique()[:num_basal].tolist()
basal_inds = jx_cell.basal.show()["branch_index"].unique()[:num_basal].tolist()

jaxley2neuron = jaxley2neuron_by_coords(cell, neuron_secs, loc=branch_loc, nseg=nseg)
jaxley2neuron = jaxley2neuron_by_coords(jx_cell, neuron_secs, comp_idx=comp_idx, loc=loc, nseg=nseg)

neuron_trunk_inds = [jaxley2neuron[i] for i in trunk_inds]
neuron_tuft_inds = [jaxley2neuron[i] for i in tuft_inds]
Expand All @@ -70,12 +69,13 @@ def jaxley2neuron_by_group(cell, neuron_secs, branch_loc=0.05, nseg=8, num_apica
jaxley_inds = {"trunk": trunk_inds, "tuft": tuft_inds, "basal": basal_inds}
return neuron_inds, jaxley_inds

def match_stim_loc(jx_cell, neuron_sec, loc=0.05, nseg=8):
stim_coords = get_segment_xyzrL(neuron_sec, loc=loc, nseg=nseg)[:,:3]
def match_stim_loc(jx_cell, neuron_sec, comp_idx=None, loc=None, nseg=8):
stim_coords = get_segment_xyzrL(neuron_sec, comp_idx=comp_idx, loc=loc, nseg=nseg)[:3]
stim_idx = ((jx_cell.nodes[["x", "y", "z"]]-stim_coords)**2).sum(axis=1).argmin()
return stim_idx

def import_neuron_morph(fname, nseg=8):
from neuron import h
_ = h.load_file("stdlib.hoc")
_ = h.load_file("import3d.hoc")
nseg = 8
Expand All @@ -91,4 +91,4 @@ def import_neuron_morph(fname, nseg=8):

for sec in h.allsec():
sec.nseg = nseg
return cell
return h, cell
Loading

0 comments on commit 0bb338c

Please sign in to comment.