Skip to content

Commit

Permalink
fix: move some things around
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 17, 2024
1 parent d9b1125 commit 4f4d61c
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 79 deletions.
147 changes: 74 additions & 73 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def __str__(self):

def __dir__(self):
base_dir = object.__dir__(self)
synapses = [s._name for s in self.synapses]
synapses = [s.name for s in self.synapses]
groups = [] if len(self.groups) == 0 else list(self.groups.keys())
return sorted(base_dir + synapses + groups)

Expand All @@ -205,16 +205,16 @@ def __getattr__(self, key):
return view

# intercepts calls to channels
if key in [c._name for c in self.base.channels]:
channel_names = [c._name for c in self.channels]
if key in [c.name for c in self.base.channels]:
channel_names = [c.name for c in self.channels]
inds = self.nodes.index[self.nodes[key]].to_numpy()
view = self.select(inds) if key in channel_names else self.select(None)
view._set_controlled_by_param(key)
return view

# intercepts calls to synapse types
base_syn_names = [s._name for s in self.base.synapses]
syn_names = [s._name for s in self.synapses]
base_syn_names = [s.name for s in self.base.synapses]
syn_names = [s.name for s in self.synapses]
if key in base_syn_names:
syn_inds = self.edges[self.edges["type"] == key][
"global_edge_index"
Expand Down Expand Up @@ -714,15 +714,60 @@ def _gather_channels_from_constituents(self, constituents: List):
"""
for module in constituents:
for channel in module.channels:
if channel._name not in [c._name for c in self.channels]:
if channel.name not in [c.name for c in self.channels]:
self.base.channels.append(channel)
if channel.current_name not in self.membrane_current_names:
self.base.membrane_current_names.append(channel.current_name)
# Setting columns of channel names to `False` instead of `NaN`.
for channel in self.base.channels:
name = channel._name
name = channel.name
self.base.nodes.loc[self.nodes[name].isna(), name] = False

def _prepare_for_jax(self):
# prepare lookup of indices of states, parameters and mechanisms
global_params = ["radius", "length", "axial_resistivity", "capacitance"]
global_states = ["v"]

current_names = self.membrane_current_names + self.synapse_current_names
global_states_params = global_states + global_params + current_names

channel_names = [c.name for c in self.channels]
syn_names = [s.name for s in self.synapses]

node_attrs = self.nodes.columns.to_list() + current_names + channel_names

def inds_of_key(key: str) -> np.ndarray:
"""Return the indices for params, states, mechanisms and currents."""
data = self.nodes if key in node_attrs else pd.DataFrame()
data = self.edges if key in self.edges.columns or key in syn_names else data

if key in channel_names + syn_names:
where = data["type"] == key if key in syn_names else data[key]
elif key in data.columns:
where = ~data[key].isna()
elif key in global_states_params:
where = pd.Index([True] * len(data))
else:
raise ValueError(f"Key '{key}' not found in nodes or edges")
return data.index[where].to_numpy()

# expose the lookup function to the class with precomputed attrs in scope
self._inds_of_state_param = inds_of_key

# add index attrs to mechansisms (i.e. where was it inserted) and also keep track
# of states / parameters that are also shared by other mechanisms.
for mech in self.channels + self.synapses:
mech.indices = self._inds_of_state_param(mech.name)
mech._jax_inds = {}
current = {mech.current_name: None} if isinstance(mech, Channel) else {}

for param_state in {**mech.params, **mech.states, **current}:
is_global = not param_state.startswith(f"{mech.name}_")
if is_global:
global_inds = self._inds_of_state_param(param_state)
local_inds = np.where(np.isin(global_inds, mech.indices))[0]
mech._jax_inds[param_state] = local_inds

def to_jax(self):
"""Move `.nodes` to `.jaxnodes`.
Expand Down Expand Up @@ -784,7 +829,7 @@ def show(
scopes = ["local", "global"]
inds = [f"{s}_{i}" for i in inds for s in scopes] if indices else []
cols += inds
cols += [ch._name for ch in self.channels] if channel_names else []
cols += [ch.name for ch in self.channels] if channel_names else []
cols += sum([list(ch.params) for ch in self.channels], []) if params else []
cols += sum([list(ch.states) for ch in self.channels], []) if states else []

Expand Down Expand Up @@ -914,7 +959,7 @@ def set_ncomp(
all_nodes = self.base.nodes
start_idx = self.nodes["global_comp_index"].to_numpy()[0]
ncomp_per_branch = self.base.ncomp_per_branch
channel_names = [c._name for c in self.base.channels]
channel_names = [c.name for c in self.base.channels]
channel_param_names = list(chain(*[c.params for c in self.base.channels]))
channel_state_names = list(chain(*[c.states for c in self.base.channels]))
radius_generating_fns = self.base._radius_generating_fns
Expand Down Expand Up @@ -1216,33 +1261,22 @@ def _get_state_names(self) -> Tuple[List, List]:
synapse_states + self.synapse_current_names,
)

def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:
"""Get all trainable parameters.
The returned parameters should be passed to `jx.integrate(..., params=params).
Returns:
A list of all trainable parameters in the form of
[{"gNa": jnp.array([0.1, 0.2, 0.3])}, ...].
"""
return self.trainable_params

def _iter_states_params(
self, params=False, states=False, currents=False
) -> Tuple[str, np.ndarray]: # type: ignore
) -> Tuple[str, np.ndarray]: # type: ignore
# assert that either params or states is True
assert params or states or currents, "Select either params / states / currents."
all_mechs = self.channels + self.synapses

if params:
global_params = ["radius", "length", "axial_resistivity", "capacitance"]
all_params = sum([list(m.params) for m in all_mechs], []) + global_params
all_params = [p for m in all_mechs for p in m.params] + global_params
for key in all_params:
yield key, self._inds_of_state_param(key)

if states:
global_states = ["v"]
all_states = sum([list(m.states) for m in all_mechs], []) + global_states
all_states = [s for m in all_mechs for s in m.states] + global_states
for key in all_states:
yield key, self._inds_of_state_param(key)

Expand All @@ -1251,49 +1285,16 @@ def _iter_states_params(
for key in current_names:
yield key, self._inds_of_state_param(key)

def _prepare_for_jax(self):
# prepare lookup of indices of states, parameters and mechanisms
global_params = ["radius", "length", "axial_resistivity", "capacitance"]
global_states = ["v"]

current_names = self.membrane_current_names + self.synapse_current_names
global_states_params = global_states + global_params + current_names

channel_names = [c._name for c in self.channels]
syn_names = [s._name for s in self.synapses]

node_attrs = self.nodes.columns.to_list() + current_names + channel_names
def inds_of_key(key: str) -> np.ndarray:
"""Return the indices for params, states, mechanisms and currents."""
data = self.nodes if key in node_attrs else pd.DataFrame()
data = self.edges if key in self.edges.columns or key in syn_names else data

if key in channel_names + syn_names:
where = data["type"] == key if key in syn_names else data[key]
elif key in data.columns:
where = ~data[key].isna()
elif key in global_states_params:
where = pd.Index([True] * len(data))
else:
raise ValueError(f"Key '{key}' not found in nodes or edges")
return data.index[where].to_numpy()

# expose the lookup function to the class with precomputed attrs in scope
self._inds_of_state_param = inds_of_key
def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:
"""Get all trainable parameters.
# add index attrs to mechansisms (i.e. where was it inserted) and also keep track
# of states / parameters that are also shared by other mechanisms.
for mech in self.channels + self.synapses:
mech.indices = self._inds_of_state_param(mech._name)
mech._jax_inds = {}
current = {mech.current_name: None} if isinstance(mech, Channel) else {}
The returned parameters should be passed to `jx.integrate(..., params=params).
for param_state in {**mech.params, **mech.states, **current}:
is_global = not param_state.startswith(f"{mech._name}_")
if is_global:
global_inds = self._inds_of_state_param(param_state)
local_inds = np.where(np.isin(global_inds, mech.indices))[0]
mech._jax_inds[param_state] = local_inds
Returns:
A list of all trainable parameters in the form of
[{"gNa": jnp.array([0.1, 0.2, 0.3])}, ...].
"""
return self.trainable_params

def _get_all_states_params(
self,
Expand Down Expand Up @@ -1751,10 +1752,10 @@ def insert(self, channel: Channel):
Args:
channel: The channel to insert."""
name = channel._name
name = channel.name

# Channel does not yet exist in the `jx.Module` at all.
if name not in [c._name for c in self.base.channels]:
if name not in [c.name for c in self.base.channels]:
self.base.channels.append(channel)
self.base.nodes[name] = (
False # Previous columns do not have the new channel.
Expand All @@ -1779,9 +1780,9 @@ def delete_channel(self, channel: Channel):
Args:
channel: The channel to remove."""
name = channel._name
channel_names = [c._name for c in self.channels]
all_channel_names = [c._name for c in self.base.channels]
name = channel.name
channel_names = [c.name for c in self.channels]
all_channel_names = [c.name for c in self.base.channels]
if name in channel_names:
channel_cols = list({**channel.params, **channel.states}.keys())
self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan")
Expand Down Expand Up @@ -2615,15 +2616,15 @@ def _set_trainables_in_view(self):

def _channels_in_view(self, pointer: Union[Module, View]) -> List[Channel]:
"""Set channels to show only those in view."""
names = [name._name for name in pointer.channels]
names = [c.name for c in pointer.channels]
channel_in_view = self.nodes[names].any(axis=0)
channel_in_view = channel_in_view[channel_in_view].index
return [deepcopy(c) for c in pointer.channels if c._name in channel_in_view]
return [deepcopy(c) for c in pointer.channels if c.name in channel_in_view]

def _synapses_in_view(self, pointer: Union[Module, View]):
"""Set synapses to show only those in view."""
names = self.edges["type"].unique()
return [deepcopy(syn) for syn in pointer.synapses if syn._name in names]
return [deepcopy(syn) for syn in pointer.synapses if syn.name in names]

def _nbranches_per_cell_in_view(self) -> np.ndarray:
cell_nodes = self.nodes.groupby("global_cell_index")
Expand Down
12 changes: 6 additions & 6 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def _synapse_currents(
diff = 1e-3

num_comp = len(voltages)
synapse_current_states = {f"i_{s._name}": zeros for s in syn_channels}
synapse_current_states = {f"i_{s.name}": zeros for s in syn_channels}
for i, group in edges.groupby("type_ind"):
synapse = syn_channels[i]
pre_inds = group["pre_global_comp_index"].to_numpy()
Expand Down Expand Up @@ -340,15 +340,15 @@ def _synapse_currents(
syn_const_terms = syn_const_terms.at[:].add(-gathered_syn_currents[1])
# Save the current (for the unperturbed voltage) as a state that will
# also be passed to the state update.
synapse_current_states[f"i_{synapse._name}"] = (
synapse_current_states[f"i_{synapse._name}"]
synapse_current_states[f"i_{synapse.name}"] = (
synapse_current_states[f"i_{synapse.name}"]
.at[post_inds]
.add(synapse_currents_dist[0])
)

# Copy the currents into the `state` dictionary such that they can be
# recorded and used by `Channel.update_states()`.
for name in [s._name for s in self.synapses]:
for name in [s.name for s in self.synapses]:
states[f"i_{name}"] = synapse_current_states[f"i_{name}"]
return states, (syn_voltage_terms, syn_const_terms)

Expand Down Expand Up @@ -474,14 +474,14 @@ def vis(
return ax

def _infer_synapse_type_ind(self, synapse_name):
syn_names = [s._name for s in self.base.synapses]
syn_names = [s.name for s in self.base.synapses]
is_new_type = False if synapse_name in syn_names else True
type_ind = len(syn_names) if is_new_type else syn_names.index(synapse_name)
return type_ind, is_new_type

def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type):
# Add synapse types to the module and infer their unique identifier.
synapse_name = synapse_type._name
synapse_name = synapse_type.name
synapse_current_name = f"i_{synapse_name}"
type_ind, is_new = self._infer_synapse_type_ind(synapse_name)
if is_new: # synapse is not known
Expand Down

0 comments on commit 4f4d61c

Please sign in to comment.