diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 71435643..cf9392dd 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -724,6 +724,16 @@ def _gather_channels_from_constituents(self, constituents: List): self.base.nodes.loc[self.nodes[name].isna(), name] = False def _prepare_for_jax(self): + """Prepare the module for simulation with JAX. + + This function has to be run inside or before `to_jax`. It's main purpose is to; + 1. Prepare the lookup of indices of states, parameters and mechanisms. + 2. Add index attributes to mechanisms (i.e. where was it inserted) and also keep + track of states / parameters that are also shared by other mechanisms. + + Adds `_inds_of_state_param(key: str)` to the module and also adds `indices` and + `_jax_inds` to the mechanisms. + """ # prepare lookup of indices of states, parameters and mechanisms global_params = ["radius", "length", "axial_resistivity", "capacitance"] global_states = ["v"] @@ -1262,9 +1272,18 @@ def _get_state_names(self) -> Tuple[List, List]: ) def _iter_states_params( - self, params=False, states=False, currents=False + self, params: bool = False, states: bool = False, currents: bool = False ) -> Tuple[str, np.ndarray]: # type: ignore - # assert that either params or states is True + """Iterate over all states and parameters. + + Args: + params: Whether to iterate over parameters. + states: Whether to iterate over states. + currents: Whether to iterate over currents. + + Yields: + The key and the indices of the states / parameters. + """ assert params or states or currents, "Select either params / states / currents." all_mechs = self.channels + self.synapses @@ -1299,12 +1318,27 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: def _get_all_states_params( self, pstate: List[Dict], - voltage_solver=None, - delta_t=None, - all_params=None, - params=False, - states=False, + voltage_solver: str = None, + delta_t: float = None, + all_params: Dict[str, jnp.ndarray] = None, + params: bool = False, + states: bool = False, ) -> Dict[str, jnp.ndarray]: + """Get all parameters and/or states of the module. + + Common backbone of both `get_all_parameters()` and `get_all_states()`. + + Args: + pstate: The state of the trainable parameters. + voltage_solver: The voltage solver that is used. + delta_t: The stepsize. + all_params: All parameters of the module. + params: Whether to get the parameters. + states: Whether to get the states. + + Returns: + A dictionary of all parameters and/or states of the module. + """ states_params = {} pkeys = {} for i, p in enumerate(pstate):