From dcded4b81626f3e2c46b4813489b9ffd501d6091 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Fri, 15 Nov 2024 17:13:49 +0100 Subject: [PATCH] add init_params method --- jaxley/channels/hh.py | 12 ++++++++++++ jaxley/modules/base.py | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/jaxley/channels/hh.py b/jaxley/channels/hh.py index c19bf002..9387541f 100644 --- a/jaxley/channels/hh.py +++ b/jaxley/channels/hh.py @@ -24,6 +24,7 @@ def __init__(self, name: Optional[str] = None): f"{prefix}_eNa": 50.0, f"{prefix}_eK": -77.0, f"{prefix}_eLeak": -54.3, + f"celsius": 37.0, } self.channel_states = { f"{prefix}_m": 0.2, @@ -75,6 +76,17 @@ def init_state(self, states, v, params, delta_t): f"{prefix}_h": alpha_h / (alpha_h + beta_h), f"{prefix}_n": alpha_n / (alpha_n + beta_n), } + + def init_params(self, states, v, params): + """Initialize the parameters given the temperature.""" + prefix = self._name + q10 = 2.3 + t = params["celsius"] + gna = q10 ** ((t - 37.0) / 10.0) * params[f"{prefix}_gNa"] + gk = q10 ** ((t - 37.0) / 10.0) * params[f"{prefix}_gK"] + gleak = q10 ** ((t - 37.0) / 10.0) * params[f"{prefix}_gLeak"] + return {f"{prefix}_gNa": gna, f"{prefix}_gK": gk, f"{prefix}_gLeak": gleak} + @staticmethod def m_gate(v): diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 66f8fd0f..1c02709e 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1404,6 +1404,47 @@ def init_states(self, delta_t: float = 0.025): # no issues with overriding states). self.nodes.loc[channel_indices, key] = val + @only_allow_module + def init_params(self): + """Run `channel.init_params()` to initialize parameters.""" + # Update states of the channels. + channel_nodes = self.base.nodes + states = self.base._get_states_from_nodes_and_edges() + + # We do not use any `pstate` for initializing. In principle, we could change + # that by allowing an input `params` and `pstate` to this function. + # `voltage_solver` could also be `jax.sparse` here, because both of them + # build the channel parameters in the same way. + params = self.base.get_all_parameters([], voltage_solver="jaxley.thomas") + + for channel in self.base.channels: + name = channel._name + channel_indices = channel_nodes.loc[channel_nodes[name]][ + "global_comp_index" + ].to_numpy() + voltages = channel_nodes.loc[channel_indices, "v"].to_numpy() + + channel_param_names = list(channel.channel_params.keys()) + channel_state_names = list(channel.channel_states.keys()) + channel_states = query_channel_states_and_params( + states, channel_state_names, channel_indices + ) + channel_params = query_channel_states_and_params( + params, channel_param_names, channel_indices + ) + + init_params = channel.init_params( + channel_states, voltages, channel_params + ) + + # `init_params` might not return all channel states. Only the ones that are + # returned are updated here. + for key, val in init_params.items(): + # Note that we are overriding `self.nodes` here, but `self.nodes` is + # not used above to actually compute the current states (so there are + # no issues with overriding states). + self.nodes.loc[channel_indices, key] = val + def _init_morph_for_debugging(self): """Instandiates row and column inds which can be used to solve the voltage eqs.