Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make recording a mechanism #151

Merged
merged 1 commit into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 2 additions & 26 deletions neurax/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
def integrate(
module: Module,
stimuli: Union[List[Stimulus], Stimuli],
recordings: List[Recording],
params: List[Dict[str, jnp.ndarray]] = [],
t_max: Optional[float] = None,
delta_t: float = 0.025,
Expand All @@ -25,9 +24,7 @@ def integrate(
Solves ODE and simulates neuron model.

Args:
t_max: Duration of the simulation in milliseconds. If `None`, the duration is
inferred from the duration of the stimulus. If it is larger than the
duration of the stimulus, the stimulus is padded with zeros at the end.
t_max: Duration of the simulation in milliseconds.
delta_t: Time step of the solver in milliseconds.
solver: Which ODE solver to use. Either of ["fwd_euler", "bwd_euler", "cranck"].
tridiag_solver: Algorithm to solve tridiagonal systems. The different options
Expand All @@ -47,7 +44,7 @@ def integrate(
assert module.initialized, "Module is not initialized, run `.initialize()`."

i_current, i_inds = prepare_stim(module, stimuli)
rec_inds = prepare_recs(module, recordings)
rec_inds = module.recordings.comp_index.to_numpy()

# Shorten or pad stimulus depending on `t_max`.
if t_max is not None:
Expand Down Expand Up @@ -107,27 +104,6 @@ def _body_fun(state, i_stim):
return jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T


def prepare_recs(module, recordings: List[Recording]):
"""Prepare recordings."""
nseg = module.nseg
cumsum_nbranches = module.cumsum_nbranches

for rec in recordings:
assert rec.cell_ind < len(
module.nbranches_per_cell
), "recording.cell_ind is larger than the number of cells."
assert (
rec.branch_ind < module.nbranches_per_cell[rec.cell_ind]
), "recording.branch_ind is larger than the number of branches in the cell."
assert rec.loc <= 1.0 and rec.loc >= 0.0, "recording.loc must be in [0, 1]."

rec_comp_inds = [index_of_loc(r.branch_ind, r.loc, nseg) for r in recordings]
rec_comp_inds = jnp.asarray(rec_comp_inds)
rec_branch_inds = jnp.asarray([r.cell_ind for r in recordings])
rec_branch_inds = nseg * cumsum_nbranches[rec_branch_inds]
return rec_branch_inds + rec_comp_inds


def prepare_stim(module, stimuli: Union[List[Stimulus], Stimuli]):
"""Prepare stimuli."""
nseg = module.nseg
Expand Down
57 changes: 40 additions & 17 deletions neurax/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def __init__(self):
self.conns: List[Synapse] = None
self.group_views = {}

self.nodes: pd.DataFrame = None
self.syn_edges: pd.DataFrame = None
self.branch_edges: pd.DataFrame = None
self.nodes: Optional[pd.DataFrame] = None
self.syn_edges: Optional[pd.DataFrame] = None
self.branch_edges: Optional[pd.DataFrame] = None

self.cumsum_nbranches: jnp.ndarray = None

Expand All @@ -53,6 +53,9 @@ def __init__(self):
self.trainable_params: List[Dict[str, jnp.ndarray]] = []
self.allow_make_trainable: bool = True

# For recordings.
self.recordings: pd.DataFrame = pd.DataFrame().from_dict({})

def __repr__(self):
return f"{type(self).__name__} with {len(self.channel_nodes)} different channels. Use `.show()` for details."

Expand Down Expand Up @@ -395,6 +398,16 @@ def initialize(self):
self.init_syns()
return self

def record(self):
"""Insert a recording into the given section."""
self._record(self.nodes)

def _record(self, view):
assert (
len(view) == 1
), "Can only record from compartments, not branches, cells, or networks."
self.recordings = pd.concat([self.recordings, view])

def insert(self, channel):
"""Insert a channel."""
self._insert(channel, self.nodes)
Expand Down Expand Up @@ -551,15 +564,30 @@ def show(
states: bool = True,
):
if channel_name is None:
myview = self.view.drop("original_comp_index", axis=1)
myview = myview.drop("original_branch_index", axis=1)
myview = myview.drop("original_cell_index", axis=1)
myview = self.view.drop("global_comp_index", axis=1)
myview = myview.drop("global_branch_index", axis=1)
myview = myview.drop("global_cell_index", axis=1)
return self.pointer._show_base(myview, indices, params, states)
else:
return self.pointer._show_channel(
self.view, channel_name, indices, params, states
)

def set_global_index_and_index(nodes):
"""Use the global compartment, branch, and cell index as the index."""
nodes = nodes.drop("controlled_by_param", axis=1)
nodes = nodes.drop("comp_index", axis=1)
nodes = nodes.drop("branch_index", axis=1)
nodes = nodes.drop("cell_index", axis=1)
nodes = nodes.rename(
columns={
"global_comp_index": "comp_index",
"global_branch_index": "branch_index",
"global_cell_index": "cell_index",
}
)
return nodes

def insert(self, channel):
"""Insert a channel."""
assert not inspect.isclass(
Expand All @@ -568,19 +596,14 @@ def insert(self, channel):
Channel is a class, but it was not initialized. Use `.insert(Channel())`
instead of `.insert(Channel)`.
"""
nodes = self.view.drop("controlled_by_param", axis=1)
nodes = nodes.drop("comp_index", axis=1)
nodes = nodes.drop("branch_index", axis=1)
nodes = nodes.drop("cell_index", axis=1)
nodes = nodes.rename(
columns={
"original_comp_index": "comp_index",
"original_branch_index": "branch_index",
"original_cell_index": "cell_index",
}
)
nodes = self.set_global_index_and_index(self.view)
self.pointer._insert(channel, nodes)

def record(self):
"""Insert a channel."""
nodes = self.set_global_index_and_index(self.view)
self.pointer._record(nodes)

def set_params(self, key: str, val: float):
"""Set parameters of the pointer."""
self.pointer._set_params(key, val, self.view)
Expand Down
8 changes: 0 additions & 8 deletions neurax/recording.py

This file was deleted.

Loading