Skip to content

Commit

Permalink
wip: more tests passing, small refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 5, 2024
1 parent e3e2000 commit aa4ae5f
Showing 1 changed file with 58 additions and 54 deletions.
112 changes: 58 additions & 54 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
compute_axial_conductances,
compute_current_density,
compute_levels,
interpolate_xyz,
loc_of_index,
interpolate_xyzr,
params_to_pstate,
query_states_and_params,
v_interp,
Expand Down Expand Up @@ -228,6 +227,7 @@ def __getattr__(self, key):
view._set_controlled_by_param(key) # overwrites param set by edge
# Ensure synapse param sharing works with `edge`
# `edge` will be removed as part of #463
view.edges["local_edge_index"] = np.arange(len(view.edges))
return view

def _childviews(self) -> List[str]:
Expand Down Expand Up @@ -1198,9 +1198,9 @@ def _get_state_names(self) -> Tuple[List, List]:
"""Collect all recordable / clampable states in the membrane and synapses.
Returns states seperated by comps and edges."""
channel_states = [name for c in self.channels for name in c.channel_states]
channel_states = [name for c in self.channels for name in c.states]
synapse_states = [
name for s in self.synapses if s is not None for name in s.synapse_states
name for s in self.synapses if s is not None for name in s.states
]
membrane_states = ["v", "i"] + self.membrane_current_names
return (
Expand All @@ -1219,6 +1219,26 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:
"""
return self.trainable_params

@only_allow_module
def _iter_states_or_params(self, type="states") -> Dict[str, jnp.ndarray]:
# TODO FROM #447: MAKE THIS WORK FOR VIEW?
"""Return states as they are set in the `.nodes` and `.edges` tables."""
morph_params = ["radius", "length", "axial_resistivity", "capacitance"]
global_states = ["v"]
global_states_or_params = morph_params if type == "params" else global_states
for key in global_states_or_params:
yield key, self.base.jaxnodes["index"], self.base.jaxnodes[key]

# Join node and edge states into a single state dictionary.
for jax_arrays, mechs in zip(
[self.base.jaxnodes, self.base.jaxedges],
[self.base.channels, self.base.synapses],
):
for mech in mechs:
mech_inds = jax_arrays[mech._name]
for key in mech.__dict__[type]:
yield key, mech_inds, jax_arrays[key]

@only_allow_module
def get_all_parameters(
self, pstate: List[Dict], voltage_solver: str
Expand Down Expand Up @@ -1255,34 +1275,24 @@ def get_all_parameters(
Returns:
A dictionary of all module parameters.
"""
params = {}
morph_params = ["radius", "length", "axial_resistivity", "capacitance"]
for key in ["v"] + morph_params:
params[key] = self.base.jaxnodes[key]
pstate_inds = {d["key"]: i for i, d in enumerate(pstate)}

for jax_arrays, data, mechs in zip(
[self.base.jaxnodes, self.base.jaxedges],
[self.base.nodes, self.base.edges],
[self.base.channels, self.base.synapses],
):
for mech in mechs:
inds = jax_arrays[mech._name]
for mech_param in mech.params:
params[mech_param] = data[mech_param].to_numpy()
params[mech_param][inds] = jax_arrays[mech_param]
params[mech_param] = jnp.asarray(params[mech_param])
params = {}
for key, mech_inds, jax_array in self._iter_states_or_params("params"):
params[key] = jax_array

# Override with those parameters set by `.make_trainable()`.
for parameter in pstate:
key = parameter["key"]
inds = parameter["indices"]
set_param = parameter["val"]
# Override with those parameters set by `.make_trainable()`.
if key in pstate_inds:
idx = pstate_inds[key]
key = pstate[idx]["key"]
inds = pstate[idx]["indices"]
set_param = pstate[idx]["val"]

if key in params: # Only parameters, not initial states.
# `inds` is of shape `(num_params, num_comps_per_param)`.
# `set_param` is of shape `(num_params,)`
# We need to unsqueeze `set_param` to make it `(num_params, 1)` for the
# `.set()` to work. This is done with `[:, None]`.
# We need to unsqueeze `set_param` to make it `(num_params, 1)`
# for the `.set()` to work. This is done with `[:, None]`.
inds = np.searchsorted(mech_inds, inds)
params[key] = params[key].at[inds].set(set_param[:, None])

# Compute conductance params and add them to the params dictionary.
Expand All @@ -1291,20 +1301,6 @@ def get_all_parameters(
)
return params

@only_allow_module
def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]:
# TODO FROM #447: MAKE THIS WORK FOR VIEW?
"""Return states as they are set in the `.nodes` and `.edges` tables."""
self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.
states = {"v": self.base.jaxnodes["v"]}
# Join node and edge states into a single state dictionary.
for channel in self.base.channels:
for channel_states in channel.states:
states[channel_states] = self.base.jaxnodes[channel_states]
for synapse_states in self.base.synapse_state_names:
states[synapse_states] = self.base.jaxedges[synapse_states]
return states

@only_allow_module
def get_all_states(
self, pstate: List[Dict], all_params, delta_t: float
Expand All @@ -1320,18 +1316,23 @@ def get_all_states(
Returns:
A dictionary of all states of the module.
"""
states = self.base._get_states_from_nodes_and_edges()

# Override with the initial states set by `.make_trainable()`.
for parameter in pstate:
key = parameter["key"]
inds = parameter["indices"]
set_param = parameter["val"]
if key in states: # Only initial states, not parameters.
# `inds` is of shape `(num_params, num_comps_per_param)`.
# `set_param` is of shape `(num_params,)`
# We need to unsqueeze `set_param` to make it `(num_params, 1)` for the
# `.set()` to work. This is done with `[:, None]`.
pstate_inds = {d["key"]: i for i, d in enumerate(pstate)}
states = {}
for key, mech_inds, jax_array in self._iter_states_or_params("states"):
states[key] = jax_array

# Override with those parameters set by `.make_trainable()`.
if key in pstate_inds:
idx = pstate_inds[key]
key = pstate[idx]["key"]
inds = pstate[idx]["indices"]
set_param = pstate[idx]["val"]

# `inds` is of shape `(num_states, num_comps_per_param)`.
# `set_param` is of shape `(num_states,)`
# We need to unsqueeze `set_param` to make it `(num_states, 1)`
# for the `.set()` to work. This is done with `[:, None]`.
inds = np.searchsorted(mech_inds, inds)
states[key] = states[key].at[inds].set(set_param[:, None])

# Add to the states the initial current through every channel.
Expand Down Expand Up @@ -1366,8 +1367,11 @@ def init_states(self, delta_t: float = 0.025):
delta_t: Passed on to `channel.init_state()`.
"""
# Update states of the channels.
self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.
channel_nodes = self.base.nodes
states = self.base._get_states_from_nodes_and_edges()
states = {}
for key, _, jax_array in self._iter_states_or_params("states"):
states[key] = jax_array

# We do not use any `pstate` for initializing. In principle, we could change
# that by allowing an input `params` and `pstate` to this function.
Expand Down

0 comments on commit aa4ae5f

Please sign in to comment.