From ca40d466618d233ef4d089b02127f392ec11f2b8 Mon Sep 17 00:00:00 2001 From: Michael Deistler Date: Fri, 3 Nov 2023 17:43:18 +0100 Subject: [PATCH] Allow task training by setting the current from outside --- neurax/__init__.py | 2 +- neurax/integrate.py | 6 +++--- neurax/stimulus.py | 40 ---------------------------------------- 3 files changed, 4 insertions(+), 44 deletions(-) diff --git a/neurax/__init__.py b/neurax/__init__.py index 8bb1fdcc8..3ba8a986b 100644 --- a/neurax/__init__.py +++ b/neurax/__init__.py @@ -2,4 +2,4 @@ from neurax.integrate import integrate from neurax.modules import * from neurax.optimize import ParamTransform -from neurax.stimulus import Stimuli, Stimulus, step_current +from neurax.stimulus import step_current diff --git a/neurax/integrate.py b/neurax/integrate.py index 001f8e7f8..dd04dd6aa 100644 --- a/neurax/integrate.py +++ b/neurax/integrate.py @@ -4,14 +4,14 @@ import jax.numpy as jnp from neurax.modules import Module -from neurax.stimulus import Stimuli, Stimulus -from neurax.utils.cell_utils import index_of_loc from neurax.utils.jax_utils import nested_checkpoint_scan def integrate( module: Module, params: List[Dict[str, jnp.ndarray]] = [], + currents: Optional[jnp.ndarray] = None, + *, t_max: Optional[float] = None, delta_t: float = 0.025, solver: str = "bwd_euler", @@ -41,7 +41,7 @@ def integrate( assert module.initialized, "Module is not initialized, run `.initialize()`." - i_current = module.currents.T + i_current = module.currents.T if currents is None else currents.T i_inds = module.current_inds.comp_index.to_numpy() rec_inds = module.recordings.comp_index.to_numpy() diff --git a/neurax/stimulus.py b/neurax/stimulus.py index 1a6ceb7c8..ea0191ad2 100644 --- a/neurax/stimulus.py +++ b/neurax/stimulus.py @@ -7,46 +7,6 @@ from neurax.utils.cell_utils import index_of_loc -class Stimulus: - """A single stimulus to the network.""" - - def __init__( - self, cell_ind, branch_ind, loc, current: Optional[jnp.ndarray] = None - ): - """ - Args: - current: Time series of the current. - """ - self.cell_ind = cell_ind - self.branch_ind = branch_ind - self.loc = loc - self.current = current - - -class Stimuli: - """Several stimuli to the network. - - Here, the properties of all individual stimuli already get vectorized and put - into arrays. This increases speed for big datasets consisting of dozens or hundreds - of stimuli. - """ - - def __init__( - self, stims: List[Stimulus], nseg_per_branch: int, cumsum_nbranches: jnp.ndarray - ): - self.comp_inds = jnp.asarray( - [index_of_loc(s.branch_ind, s.loc, nseg_per_branch) for s in stims] - ) - cell_inds = jnp.asarray([s.cell_ind for s in stims]) - self.branch_inds = cumsum_nbranches[cell_inds] * nseg_per_branch - self.currents = jnp.asarray([s.current for s in stims]).T # nA - - def set_currents(self, currents: float): - """Rescale the current of the stimulus with a constant value over time.""" - self.currents = currents - return self - - def step_current( i_delay: float, i_dur: float, i_amp: float, time_vec: jnp.asarray, i_offset=0.0 ):