Skip to content

Commit

Permalink
wip: progress on tests and tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Aug 27, 2024
1 parent aa87cc0 commit 41685e1
Show file tree
Hide file tree
Showing 7 changed files with 285 additions and 199 deletions.
106 changes: 77 additions & 29 deletions jaxley/io/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ def build_module_scaffold(
return_type = infer_module_type_from_inds(idxs)

comp = jx.Compartment()
build_cache["compartment"] = comp
build_cache["compartment"] = [comp]

if return_type in return_types[1:]:
nsegs = idxs["branch_index"].value_counts().iloc[0]
branch = jx.Branch([comp for _ in range(nsegs)])
build_cache["branch"] = branch
build_cache["branch"] = [branch]

if return_type in return_types[2:]:
for cell_id, cell_groups in idxs.groupby("cell_index"):
Expand Down Expand Up @@ -95,7 +95,7 @@ def to_graph(module: jx.Module) -> nx.DiGraph:
module._update_nodes_with_xyz() # make xyz coords attr of nodes

# add global attrs
module_graph.graph["type"] = module.__class__.__name__
module_graph.graph["type"] = module.__class__.__name__.lower()
for attr in [
"nseg",
"initialized_morph",
Expand Down Expand Up @@ -130,10 +130,17 @@ def to_graph(module: jx.Module) -> nx.DiGraph:
rec_states = group["state"].values
module_graph.add_node(rec_index, **{"recordings": rec_states})

# add currents to nodes
if module.currents is not None:
for index, currents in zip(module.current_inds.index, module.currents):
module_graph.add_node(index, **{"currents": currents})
# add externals to nodes
if module.externals is not None:
for key, inds in module.external_inds.items():
unique_inds = np.unique(inds.flatten())
for i in unique_inds:
which = np.where(inds == i)[0]
if "externals" not in module_graph.nodes[i]:
module_graph.nodes[i]["externals"] = {}
module_graph.nodes[i]["externals"].update(
{key: module.externals[key][which]}
)

# add trainable params to nodes
if module.trainable_params:
Expand Down Expand Up @@ -376,7 +383,7 @@ def simulate_swc_trace_errors(

def make_jaxley_compatible(
graph: nx.DiGraph,
nseg: int = 8,
nseg: int = 4,
max_branch_len: float = 2000.0,
source_node: Union[str, int] = 0,
ignore_swc_trace_errors: bool = True,
Expand Down Expand Up @@ -429,15 +436,33 @@ def make_jaxley_compatible(
is returned by `to_graph` when exporting a module.
"""

graph = add_edge_lens(graph) # add edge lengths to graph just in case
available_keys = graph.nodes[0].keys()
defaults = {
"id": 0,
"x": float("nan"),
"y": float("nan"),
"z": float("nan"),
"r": 1,
}
# add defaults if not present
for key in set(defaults.keys()).difference(available_keys):
nx.set_node_attributes(graph, defaults[key], key)

# add edge lengths to graph just in case
graph = add_edge_lens(graph)
if np.isnan(next(iter(graph.edges(data=True)))[2]["l"]):
nx.set_edge_attributes(graph, 1, "l")
branches = trace_branches(graph, source_node=source_node)

if not ignore_swc_trace_errors:
breaks = find_swc_trace_errors(graph)
branches = simulate_swc_trace_errors(branches, breaks)

# ensures singular root branch
if source_node != "leaf":
if (
source_node != "leaf"
and graph.out_degree(source_node) + graph.in_degree(source_node) > 1
):
branches = [np.array([[source_node, source_node]])] + branches
graph.add_edge(source_node, source_node)
graph.edges[source_node, source_node]["l"] = 0.01
Expand Down Expand Up @@ -500,17 +525,21 @@ def make_jaxley_compatible(

branch_roots_and_leafs = np.stack([np.array(b)[[0, -1]] for b in branch_nodes])
is_branch_parent_of_child = np.equal(*np.meshgrid(*(branch_roots_and_leafs.T)))
edges_between_branches = np.stack(list(zip(*np.where(is_branch_parent_of_child))))
branch_parents_and_children = list(zip(*np.where(is_branch_parent_of_child)))

comps_in_branches = jaxley_comps.groupby("branch_index")["comp_index"]
intra_branch_edges = sum([branch_n2e(c) for i, c in comps_in_branches], [])

branch_roots = comps_in_branches.first().values
branch_leafs = comps_in_branches.last().values

inter_branch_children = branch_roots[edges_between_branches[:, 1]]
inter_branch_parents = branch_leafs[edges_between_branches[:, 0]]
inter_branch_edges = np.stack([inter_branch_parents, inter_branch_children]).T
if len(branch_parents_and_children) > 0:
edges_between_branches = np.stack(branch_parents_and_children)
inter_branch_children = branch_roots[edges_between_branches[:, 1]]
inter_branch_parents = branch_leafs[edges_between_branches[:, 0]]
inter_branch_edges = np.stack([inter_branch_parents, inter_branch_children]).T
else:
inter_branch_edges = []

comp_graph = nx.DiGraph()
comp_graph.add_edges_from(inter_branch_edges, type="inter_branch")
Expand All @@ -533,7 +562,7 @@ def make_jaxley_compatible(
def from_graph(
graph: nx.DiGraph,
nseg: int = 4,
max_branch_len: float = 300.0,
max_branch_len: float = 2000.0,
assign_groups: bool = True,
) -> Union[jx.Network, jx.Cell, jx.Branch, jx.Compartment]:
"""Build a module from a networkx graph.
Expand All @@ -544,7 +573,7 @@ def from_graph(
edges are considered synapse edges. These are added to the Module.branch_edges and
Module.edges attributes, respectively. Additionally, the graph can contain
global attributes, which are added as attrs, i.e. to the module instance and
optionally can store recordings, currents, groups, and trainables. These are
optionally can store recordings, externals, groups, and trainables. These are
imported from the node attributes of the graph. See `to_graph` for how they
are formatted.
Expand Down Expand Up @@ -588,7 +617,7 @@ def from_graph(
- y: float
- z: float
- recordings: list[str]
- currents: list[float]
- externals: list[float]
- trainable: dict[str, float]
- edges:
- type: str ("synapse" or "inter_branch" / "intra_branch" or None)
Expand All @@ -613,7 +642,14 @@ def from_graph(
### Make the graph jaxley compatible ###
########################################

if graph.graph["type"] == "swc":
if "type" not in graph.graph:
try:
graph = make_jaxley_compatible(
graph, nseg=nseg, max_branch_len=max_branch_len
)
except:
raise Exception("Graph appears to be incompatible with jaxley.")
elif graph.graph["type"] == "swc":
graph = make_jaxley_compatible(graph, nseg=nseg, max_branch_len=max_branch_len)

#################################
Expand All @@ -627,7 +663,12 @@ def from_graph(
) # ensure index == comp_index
edge_type = nx.get_edge_attributes(graph, "type")
edges = pd.DataFrame(edge_type.values(), index=edge_type.keys(), columns=["type"])
edges = edges.reset_index(names=["pre", "post"])

if edges.empty: # handles comp without edges
edges = pd.DataFrame(graph.edges, columns=["pre", "post", "type"], dtype=int)
else:
edges = edges.reset_index(names=["pre", "post"])

is_synapse = edges["type"] == "synapse"
is_inter_branch = edges["type"] == "inter_branch"
inter_branch_edges = edges.loc[is_inter_branch][["pre", "post"]].values
Expand All @@ -636,8 +677,6 @@ def from_graph(
nodes["branch_index"].values[inter_branch_edges],
columns=["parent_branch_index", "child_branch_index"],
)
# branch_graph = nx.Graph((r.values for i,r in branch_edges.iterrows()))
# branch_edges = pd.DataFrame([(k,v) for k,v in nx.dfs_successors(branch_graph, source=0).items()], columns=["parent_branch_index", "child_branch_index"]).explode("child_branch_index")

edge_params = nx.get_edge_attributes(graph, "parameters")
edge_params = {k: v for k, v in edge_params.items() if k in synapse_edges}
Expand All @@ -653,7 +692,7 @@ def from_graph(
acc_parents.append([-1] + parents.tolist())

# drop special attrs from nodes and ignore error if col does not exist
optional_attrs = ["recordings", "currents", "groups", "trainable"]
optional_attrs = ["recordings", "externals", "groups", "trainable"]
nodes.drop(columns=optional_attrs, inplace=True, errors="ignore")

# build module
Expand All @@ -665,11 +704,17 @@ def from_graph(
setattr(module, k, v)

module.nodes[nodes.columns] = nodes # set column-wise. preserves cols not in nodes.
module.edges = synapse_edges
module.edges = synapse_edges.T

module.membrane_current_names = [c.current_name for c in module.channels]
module.synapse_names = [s._name for s in module.synapses]
get_names = lambda x: [list(s.__dict__[x].keys()) for s in module.synapses]
module.synapse_param_names = sum(get_names("synapse_params"), [])
module.synapse_state_names = sum(get_names("synapse_states"), [])

# Add optional attributes if they can be found in nodes
recordings = pd.DataFrame(nx.get_node_attributes(graph, "recordings"))
currents = pd.DataFrame(nx.get_node_attributes(graph, "currents"))
externals = pd.DataFrame(nx.get_node_attributes(graph, "externals")).T
groups = pd.DataFrame(nx.get_node_attributes(graph, "groups").items())
trainables = pd.DataFrame(nx.get_node_attributes(graph, "trainable"), dtype=float)

Expand All @@ -678,11 +723,14 @@ def from_graph(
recordings = recordings.rename(columns={"level_1": "rec_index", 0: "state"})
module.recordings = recordings

if not currents.empty:
current_inds = nodes.loc[currents.T.index]
currents = jnp.vstack(currents.values).T
module.currents = currents
module.current_inds = current_inds
if not externals.empty:
cached_external_inds = {}
cached_externals = {}
for key, data in externals.items():
cached_externals[key] = np.stack(data[~data.isna()].explode().values)
cached_external_inds[key] = data[~data.isna()].explode().index.to_numpy()
module.externals = cached_externals
module.external_inds = cached_external_inds

if not trainables.empty:
# trainables require special handling, since some of them are shared
Expand Down
9 changes: 7 additions & 2 deletions jaxley/io/swc.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,13 @@ def swc_to_graph(fname, num_lines=None, sort=True) -> nx.DiGraph:
i_id_xyzr_p = np.loadtxt(fname)[:num_lines]

graph = nx.DiGraph()
graph.add_nodes_from(((int(i), {"id": int(id), "x": x, "y": y, "z": z, "r": r}) for i, id, x, y, z, r, p in i_id_xyzr_p))
graph.add_edges_from([(p, i) for p, i in i_id_xyzr_p[:,[-1,0]] if p != -1])
graph.add_nodes_from(
(
(int(i), {"id": int(id), "x": x, "y": y, "z": z, "r": r})
for i, id, x, y, z, r, p in i_id_xyzr_p
)
)
graph.add_edges_from([(p, i) for p, i in i_id_xyzr_p[:, [-1, 0]] if p != -1])
graph = nx.relabel_nodes(graph, {i: i - 1 for i in graph.nodes})
graph.graph["type"] = "swc"
return graph
Expand Down
7 changes: 5 additions & 2 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@
convert_point_process_to_distributed,
interpolate_xyz,
loc_of_index,
recursive_compare,
)
from jaxley.utils.debug_solver import compute_morphology_indices, convert_to_csc
from jaxley.utils.misc_utils import childview, concat_and_ignore_empty
from jaxley.utils.misc_utils import (
childview,
concat_and_ignore_empty,
recursive_compare,
)
from jaxley.utils.plot_utils import plot_morph


Expand Down
42 changes: 0 additions & 42 deletions jaxley/utils/cell_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,46 +390,4 @@ def group_and_sum(
return group_sums


def recursive_compare(a, b):
if type(a) != type(b):
return False
if isinstance(a, (int, float)):
if a != b and not (np.isnan(a) and np.isnan(b)):
return False
elif isinstance(a, str):
if a != b:
return False
elif isinstance(a, (np.ndarray, jnp.ndarray)):
if a.size > 1:
for i in range(len(a)):
if not recursive_compare(a[i], b[i]):
return False
else:
if not recursive_compare(a.item(), b.item()):
return False
elif isinstance(a, (list, tuple)):
if len(a) != len(b):
return False
for i in range(len(a)):
if not recursive_compare(a[i], b[i]):
return False
elif isinstance(a, dict):
if len(a) != len(b) and len(a) != 0:
return False
if set(a.keys()) != set(b.keys()):
return False
for k in a.keys():
if not recursive_compare(a[k], b[k]):
return False
elif isinstance(a, pd.DataFrame):
if not recursive_compare(a.to_dict(), b.to_dict()):
return False
elif a is None or b is None:
if not (a is None and b is None):
return False
else:
raise ValueError(f"Type {type(a)} not supported")
return True


v_interp = vmap(jnp.interp, in_axes=(None, None, 1))
41 changes: 41 additions & 0 deletions jaxley/utils/misc_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Optional, Union

import jax.numpy as jnp
import numpy as np
import pandas as pd

Expand Down Expand Up @@ -30,3 +31,43 @@ def childview(
if child_name != "/":
return module.__getattr__(child_name)(index)
raise AttributeError("Compartment does not support indexing")


def recursive_compare(a, b):
if isinstance(a, (int, float)):
if abs(a - b) > 1e-5 and not (np.isnan(a) and np.isnan(b)):
return False
elif isinstance(a, str):
if a != b:
return False
elif isinstance(a, (np.ndarray, jnp.ndarray)):
if a.size > 1:
for i in range(len(a)):
if not recursive_compare(a[i], b[i]):
return False
else:
if not recursive_compare(a.item(), b.item()):
return False
elif isinstance(a, (list, tuple)):
if len(a) != len(b):
return False
for i in range(len(a)):
if not recursive_compare(a[i], b[i]):
return False
elif isinstance(a, dict):
if len(a) != len(b) and len(a) != 0:
return False
if set(a.keys()) != set(b.keys()):
return False
for k in a.keys():
if not recursive_compare(a[k], b[k]):
return False
elif isinstance(a, pd.DataFrame):
if not recursive_compare(a.to_dict(), b.to_dict()):
return False
elif a is None or b is None:
if not (a is None and b is None):
return False
else:
raise ValueError(f"Type {type(a)} not supported")
return True
Loading

0 comments on commit 41685e1

Please sign in to comment.