From 55d914c6e8dfd1cbae38535dfe155f65ebc3f4c3 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 15 Oct 2024 16:47:01 +0200 Subject: [PATCH] wip: get methods from jaxley --- new_view.ipynb | 430 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 349 insertions(+), 81 deletions(-) diff --git a/new_view.ipynb b/new_view.ipynb index 644f13b7..39ec24a8 100644 --- a/new_view.ipynb +++ b/new_view.ipynb @@ -67,6 +67,17 @@ "net += cell" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: change base nodes\n", + "# TODO: change inds to global inds in existing\n", + "# TODO: replace comp_index with self._in_view" + ] + }, { "cell_type": "code", "execution_count": 137, @@ -110,8 +121,8 @@ " # different or modified from original module implementation\n", " self.groups = {}\n", " self.edges = pd.DataFrame(columns=[f\"{scope}_{lvl}_index\" for lvl in [\"pre_comp\", \"pre_branch\", \"pre_cell\", \"post_comp\", \"post_branch\", \"post_cell\"] for scope in [\"global\", \"local\"]]+[\"pre_locs\", \"post_locs\", \"type\", \"type_ind\"])\n", - " self.nodes[\"controlled_by_param\"] = 0\n", " self._in_view = self.nodes.index.to_numpy()\n", + " self.nodes[\"controlled_by_param\"] = 0\n", " self._scope = \"local\" # defaults to local scope\n", " self.__class__.__name__ = module.__class__.__name__ # HOTFIX\n", "\n", @@ -165,13 +176,25 @@ " new_indices = np.sort(new_indices) if sorted else new_indices\n", " return View(self, at=new_indices)\n", "\n", - " def set(self, key, value):\n", + " def set(self, key: str, val: Union[float, jnp.ndarray]):\n", + " \"\"\"Set parameter of module (or its view) to a new value.\n", + "\n", + " Note that this function can not be called within `jax.jit` or `jax.grad`.\n", + " Instead, it should be used set the parameters of the module **before** the\n", + " simulation. Use `.data_set()` to set parameters during `jax.jit` or\n", + " `jax.grad`.\n", + "\n", + " Args:\n", + " key: The name of the parameter to set.\n", + " val: The value to set the parameter to. If it is `jnp.ndarray` then it\n", + " must be of shape `(len(num_compartments))`.\n", + " \"\"\"\n", " if key in self.nodes.columns:\n", " not_nan = ~self.nodes[key].isna()\n", - " self.base.nodes.loc[self._in_view[not_nan], key] = value\n", + " self.base.nodes.loc[self._in_view[not_nan], key] = val\n", " elif key in self.edges.columns:\n", " not_nan = ~self.edges[key].isna()\n", - " self.base.edges.loc[self._edges_in_view[not_nan], key] = value\n", + " self.base.edges.loc[self._edges_in_view[not_nan], key] = val\n", " else:\n", " raise KeyError(f\"Key '{key}' not found in nodes or edges\")\n", "\n", @@ -206,8 +229,30 @@ " view = self.comp(idx)\n", " return view\n", " \n", - " def add_group(self, name):\n", - " self.base.groups[name] = self._in_view\n", + " def add_to_group(self, group_name: str):\n", + " \"\"\"Add a view of the module to a group.\n", + "\n", + " Groups can then be indexed. For example:\n", + " ```python\n", + " net.cell(0).add_to_group(\"excitatory\")\n", + " net.excitatory.set(\"radius\", 0.1)\n", + " ```\n", + "\n", + " Args:\n", + " group_name: The name of the group.\n", + " \"\"\"\n", + " self.base.groups[group_name] = self._in_view\n", + "\n", + " def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:\n", + " \"\"\"Get all trainable parameters.\n", + "\n", + " The returned parameters should be passed to `jx.integrate(..., params=params).\n", + "\n", + " Returns:\n", + " A list of all trainable parameters in the form of\n", + " [{\"gNa\": jnp.array([0.1, 0.2, 0.3])}, ...].\n", + " \"\"\"\n", + " return self.trainable_params\n", "\n", " def __getattr__(self, key):\n", " if key.startswith(\"__\"):\n", @@ -239,24 +284,72 @@ " view._set_controlled_by_param(key)\n", " return view\n", " \n", - " def show(self):\n", - " nodes = self.nodes.copy() # prevents this from being edited\n", - " # drop columns with global indices if scope is local\n", - " drop = \"global\" if self._scope == \"local\" else \"local\"\n", - " nodes = nodes.drop(columns=[col for col in nodes.columns if drop in col])\n", - " nodes.columns = [col.replace(f\"{self._scope}_\", \"\") for col in nodes.columns]\n", - " return nodes\n", + " def delete_trainables(self):\n", + " \"\"\"Removes all trainable parameters from the module.\"\"\"\n", + " self.base.indices_set_by_trainables = []\n", + " self.base.trainable_params = []\n", + " self.base.num_trainable_params = 0\n", + "\n", + " def delete_recordings(self):\n", + " \"\"\"Removes all recordings from the module.\"\"\"\n", + " self.base.recordings = pd.DataFrame().from_dict({})\n", + "\n", + " def show(\n", + " self,\n", + " param_names: Optional[Union[str, List[str]]] = None, # TODO.\n", + " *,\n", + " indices: bool = True,\n", + " params: bool = True,\n", + " states: bool = True,\n", + " channel_names: Optional[List[str]] = None,\n", + " ) -> pd.DataFrame:\n", + " \"\"\"Print detailed information about the Module or a view of it.\n", + "\n", + " Args:\n", + " param_names: The names of the parameters to show. If `None`, all parameters\n", + " are shown. NOT YET IMPLEMENTED.\n", + " indices: Whether to show the indices of the compartments.\n", + " params: Whether to show the parameters of the compartments.\n", + " states: Whether to show the states of the compartments.\n", + " channel_names: The names of the channels to show. If `None`, all channels are\n", + " shown.\n", + "\n", + " Returns:\n", + " A `pd.DataFrame` with the requested information.\n", + " \"\"\"\n", + " nodes = self.nodes.copy() # prevents this from being edited\n", + "\n", + " cols = []\n", + " inds = [\"comp_index\", \"branch_index\", \"cell_index\"]\n", + " scopes = [\"local\", \"global\"]\n", + " cols += (\n", + " [f\"{scope}_{idx}\" for idx in inds for scope in scopes] if indices else []\n", + " )\n", + " cols += [ch._name for ch in self.channels] if channel_names else []\n", + " cols += (\n", + " sum([list(ch.channel_params) for ch in self.channels], []) if params else []\n", + " )\n", + " cols += (\n", + " sum([list(ch.channel_states) for ch in self.channels], []) if states else []\n", + " )\n", + "\n", + " if not param_names is None:\n", + " cols = (\n", + " [c for c in cols if c in param_names] if params else list(param_names)\n", + " )\n", + "\n", + " return nodes[cols]\n", " \n", - " def __getitem__(self, idx):\n", + " def __getitem__(self, index):\n", " levels = [\"network\", \"cell\", \"branch\", \"comp\"]\n", - " module = self.base.__class__.__name__.lower() # \n", + " module = self.base.__class__.__name__.lower() #\n", " module = \"comp\" if module == \"compartment\" else module\n", - " \n", - " children = levels[levels.index(module)+1:]\n", - " idx = idx if isinstance(idx, tuple) else (idx,)\n", + "\n", + " children = levels[levels.index(module) + 1 :]\n", + " index = index if isinstance(index, tuple) else (index,)\n", " view = self\n", " for i, child in enumerate(children):\n", - " view = view._at_level(child, idx[i])\n", + " view = view._at_level(child, index[i])\n", " return view\n", " \n", " def _iter_level(self, level):\n", @@ -278,7 +371,14 @@ " yield from self._iter_level(\"comp\") \n", "\n", " @property\n", - " def shape(self):\n", + " def shape(self) -> Tuple[int]:\n", + " \"\"\"Returns the number of submodules contained in a module.\n", + "\n", + " ```\n", + " network.shape = (num_cells, num_branches, num_compartments)\n", + " cell.shape = (num_branches, num_compartments)\n", + " branch.shape = (num_compartments,)\n", + " ```\"\"\"\n", " cols = [\"global_cell_index\", \"global_branch_index\", \"global_comp_index\"]\n", " raw_shape = self.nodes[cols].nunique().to_list()\n", "\n", @@ -286,7 +386,7 @@ " levels = [\"network\", \"cell\", \"branch\", \"comp\"]\n", " module = self.base.__class__.__name__.lower()\n", " module = \"comp\" if module == \"compartment\" else module\n", - " shape = tuple(raw_shape[levels.index(module):])\n", + " shape = tuple(raw_shape[levels.index(module) :])\n", " return shape\n", " \n", " def copy(self, reset_index=False, as_module=False):\n", @@ -302,22 +402,40 @@ " def view(self):\n", " return View(self, self._in_view)\n", "\n", - " def vis(self, dims=[0,1], level=\"branch\", ax=None, type=\"line\", **kwargs):\n", - " if ax is None:\n", - " _, ax = plt.subplots(1, 1, figsize=(3, 3))\n", - " if level == \"branch\":\n", - " for coords_of_branch in self.xyzr:\n", - " x1, x2 = coords_of_branch[:, dims].T\n", - "\n", - " if \"line\" in type.lower():\n", - " _ = ax.plot(x1, x2, **kwargs)\n", - " elif \"scatter\" in type.lower():\n", - " _ = ax.scatter(x1, x2, **kwargs)\n", - " else:\n", - " raise NotImplementedError\n", - " if level == \"comp\":\n", - " x1, x2 = self.nodes[[\"x\", \"y\", \"z\"]].values[:, dims].T\n", - " ax.scatter(x1, x2, **kwargs)\n", + " def vis(\n", + " self,\n", + " ax: Optional[Axes] = None,\n", + " col: str = \"k\",\n", + " dims: Tuple[int] = (0, 1),\n", + " type: str = \"line\",\n", + " morph_plot_kwargs: Dict = {},\n", + " ) -> Axes:\n", + " \"\"\"Visualize the module.\n", + "\n", + " Args:\n", + " ax: An axis into which to plot.\n", + " col: The color for all branches.\n", + " dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n", + " two of them.\n", + " morph_plot_kwargs: Keyword arguments passed to the plotting function.\n", + " \"\"\"\n", + " branches_inds = self.nodes[\"branch_index\"].to_numpy()\n", + " coords = []\n", + " for branch_ind in branches_inds:\n", + " assert not np.any(\n", + " np.isnan(self.xyzr[branch_ind][:, dims])\n", + " ), \"No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`.\"\n", + " coords.append(self.xyzr[branch_ind])\n", + "\n", + " ax = plot_morph(\n", + " coords,\n", + " dims=dims,\n", + " col=col,\n", + " ax=ax,\n", + " type=type,\n", + " morph_plot_kwargs=morph_plot_kwargs,\n", + " )\n", + "\n", " return ax\n", "\n", " def record(self, state, verbose=True):\n", @@ -335,7 +453,9 @@ " # Channel does not yet exist in the `jx.Module` at all.\n", " if name not in [c._name for c in self.base.channels]:\n", " self.base.channels.append(channel)\n", - " self.base.nodes[name] = False # Previous columns do not have the new channel.\n", + " self.base.nodes[name] = (\n", + " False # Previous columns do not have the new channel.\n", + " )\n", "\n", " if channel.current_name not in self.base.membrane_current_names:\n", " self.base.membrane_current_names.append(channel.current_name)\n", @@ -351,19 +471,58 @@ " for key in channel.channel_states:\n", " self.base.nodes.loc[self._in_view, key] = channel.channel_states[key]\n", " \n", - " def stimulate(self, current, verbose=False):\n", - " self._external_input(\"i\", current, verbose)\n", - "\n", - " def _external_input(self, key, values, verbose=False):\n", + " def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True):\n", + " \"\"\"Insert a stimulus into the compartment.\n", + "\n", + " current must be a 1d array or have batch dimension of size `(num_compartments, )`\n", + " or `(1, )`. If 1d, the same stimulus is added to all compartments.\n", + "\n", + " This function cannot be run during `jax.jit` and `jax.grad`. Because of this,\n", + " it should only be used for static stimuli (i.e., stimuli that do not depend\n", + " on the data and that should not be learned). For stimuli that depend on data\n", + " (or that should be learned), please use `data_stimulate()`.\n", + "\n", + " Args:\n", + " current: Current in `nA`.\n", + " \"\"\"\n", + " self._external_input(\"i\", current, verbose=verbose)\n", + "\n", + " def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True):\n", + " \"\"\"Clamp a state to a given value across specified compartments.\n", + "\n", + " Args:\n", + " state_name: The name of the state to clamp.\n", + " state_array (jnp.nd: Array of values to clamp the state to.\n", + " verbose : If True, prints details about the clamping.\n", + "\n", + " This function sets external states for the compartments.\n", + " \"\"\"\n", + " if state_name not in self.nodes.columns:\n", + " raise KeyError(f\"{state_name} is not a recognized state in this module.\")\n", + " self._external_input(state_name, state_array, self.nodes, verbose=verbose)\n", + "\n", + " def _external_input(\n", + " self,\n", + " key: str,\n", + " values: Optional[jnp.ndarray],\n", + " verbose: bool = True,\n", + " ):\n", " values = values if values.ndim == 2 else jnp.expand_dims(values, axis=0)\n", " batch_size = values.shape[0]\n", " num_inserted = len(self._in_view)\n", " is_multiple = num_inserted == batch_size\n", - " values = values if is_multiple else jnp.repeat(values, len(self._in_view), axis=0)\n", - " assert batch_size in [1, num_inserted], \"Number of comps and stimuli do not match.\"\n", + " values = (\n", + " values if is_multiple else jnp.repeat(values, len(self._in_view), axis=0)\n", + " )\n", + " assert batch_size in [\n", + " 1,\n", + " num_inserted,\n", + " ], \"Number of comps and stimuli do not match.\"\n", "\n", " if key in self.base.externals.keys():\n", - " self.base.externals[key] = jnp.concatenate([self.base.externals[key], values])\n", + " self.base.externals[key] = jnp.concatenate(\n", + " [self.base.externals[key], values]\n", + " )\n", " self.base.external_inds[key] = jnp.concatenate(\n", " [self.base.external_inds[key], self._in_view]\n", " )\n", @@ -372,10 +531,9 @@ " self.base.external_inds[key] = self._in_view\n", "\n", " if verbose:\n", - " print(f\"Added {num_inserted} external_states. See `.externals` for details.\")\n", - "\n", - " def clamp(self, state_name, state_array, verbose=False):\n", - " self._external_input(state_name, state_array, verbose=verbose)\n", + " print(\n", + " f\"Added {num_inserted} external_states. See `.externals` for details.\"\n", + " )\n", "\n", " def data_stimulate(self, current, data_stimuli, verbose=False):\n", " current = current if current.ndim == 2 else jnp.expand_dims(current, axis=0)\n", @@ -404,7 +562,21 @@ "\n", " return (currents, inds)\n", "\n", - " def data_set(self, key, val, param_state=None):\n", + " def data_set(\n", + " self,\n", + " key: str,\n", + " val: Union[float, jnp.ndarray],\n", + " param_state: Optional[List[Dict]],\n", + " ):\n", + " \"\"\"Set parameter of module (or its view) to a new value within `jit`.\n", + "\n", + " Args:\n", + " key: The name of the parameter to set.\n", + " val: The value to set the parameter to. If it is `jnp.ndarray` then it\n", + " must be of shape `(len(num_compartments))`.\n", + " param_state: State of the setted parameters, internally used such that this\n", + " function does not modify global state.\n", + " \"\"\"\n", " # Note: `data_set` does not support arrays for `val`.\n", " if key in self.nodes.columns:\n", " not_nan = ~self.nodes[key].isna()\n", @@ -423,29 +595,77 @@ " raise KeyError(\"Key not recognized.\")\n", " return param_state\n", "\n", - " def move(self, x,y,z, update_nodes=True):\n", + " def move(\n", + " self, x: float = 0.0, y: float = 0.0, z: float = 0.0, update_nodes: bool = True\n", + " ):\n", + " \"\"\"Move cells or networks by adding to their (x, y, z) coordinates.\n", + "\n", + " This function is used only for visualization. It does not affect the simulation.\n", + "\n", + " Args:\n", + " x: The amount to move in the x direction in um.\n", + " y: The amount to move in the y direction in um.\n", + " z: The amount to move in the z direction in um.\n", + " update_nodes: Whether `.nodes` should be updated or not. Setting this to\n", + " `False` largely speeds up moving, especially for big networks, but\n", + " `.nodes` or `.show` will not show the new xyz coordinates.\n", + " \"\"\"\n", " indizes = self.nodes[\"global_branch_index\"].unique()\n", " for i in indizes:\n", " self.base.xyzr[i][:, :3] += np.array([x, x, y])\n", " if update_nodes:\n", " self._update_nodes_with_xyz()\n", "\n", - " def move_to(self, x,y,z, update_nodes=True):\n", + " def move_to(\n", + " self,\n", + " x: Union[float, np.ndarray] = 0.0,\n", + " y: Union[float, np.ndarray] = 0.0,\n", + " z: Union[float, np.ndarray] = 0.0,\n", + " update_nodes: bool = True,\n", + " ):\n", + " \"\"\"Move cells or networks to a location (x, y, z).\n", + "\n", + " If x, y, and z are floats, then the first compartment of the first branch\n", + " of the first cell is moved to that float coordinate, and everything else is\n", + " shifted by the difference between that compartment's previous coordinate and\n", + " the new float location.\n", + "\n", + " If x, y, and z are arrays, then they must each have a length equal to the number\n", + " of cells being moved. Then the first compartment of the first branch of each\n", + " cell is moved to the specified location.\n", + "\n", + " Args:\n", + " update_nodes: Whether `.nodes` should be updated or not. Setting this to\n", + " `False` largely speeds up moving, especially for big networks, but\n", + " `.nodes` or `.show` will not show the new xyz coordinates.\n", + " \"\"\"\n", " # Test if any coordinate values are NaN which would greatly affect moving\n", " if np.any(np.concatenate(self.xyzr, axis=0)[:, :3] == np.nan):\n", " raise ValueError(\n", " \"NaN coordinate values detected. Shift amounts cannot be computed. Please run compute_xyzr() or assign initial coordinate values.\"\n", " )\n", - " \n", + "\n", " indizes = self.nodes[\"global_branch_index\"].unique()\n", - " move_by = np.array([x, y, z]).T - self.xyzr[0][0,:3] # move with respect to root idx\n", - " \n", + " move_by = (\n", + " np.array([x, y, z]).T - self.xyzr[0][0, :3]\n", + " ) # move with respect to root idx\n", + "\n", " for idx in indizes:\n", " self.base.xyzr[idx][:, :3] += move_by\n", " if update_nodes:\n", " self._update_nodes_with_xyz()\n", "\n", - " def rotate(self, degrees, rotation_axis=\"xy\", update_nodes=True):\n", + " def rotate(\n", + " self, degrees: float, rotation_axis: str = \"xy\", update_nodes: bool = True\n", + " ):\n", + " \"\"\"Rotate jaxley modules clockwise. Used only for visualization.\n", + "\n", + " This function is used only for visualization. It does not affect the simulation.\n", + "\n", + " Args:\n", + " degrees: How many degrees to rotate the module by.\n", + " rotation_axis: Either of {`xy` | `xz` | `yz`}.\n", + " \"\"\"\n", " degrees = degrees / 180 * np.pi\n", " if rotation_axis == \"xy\":\n", " dims = [0, 1]\n", @@ -467,44 +687,91 @@ " self._update_nodes_with_xyz()\n", "\n", " def _update_nodes_with_xyz(self):\n", - " num_branches = len(self.base.xyzr)\n", - " comp_ends = (\n", - " np.linspace(0, 1, self.nseg + 1).reshape(1, -1).repeat(num_branches, 0)\n", + " \"\"\"Add xyz coordinates of compartment centers to nodes.\n", + "\n", + " Centers are the midpoint between the comparment endpoints on the morphology\n", + " as defined by xyzr.\n", + "\n", + " Note: For sake of performance, interpolation is not done for each branch\n", + " individually, but only once along a concatenated (and padded) array of all branches.\n", + " This means for nsegs = [2,4] and normalized cum_branch_lens of [[0,1],[0,1]] we would\n", + " interpolate xyz at the locations comp_ends = [[0,0.5,1], [0,0.25,0.5,0.75,1]],\n", + " where 0 is the start of the branch and 1 is the end point at the full branch_len.\n", + " To avoid do this in one go we set comp_ends = [0,0.5,1,2,2.25,2.5,2.75,3], and\n", + " norm_cum_branch_len = [0,1,2,3] incrememting and also padding them by 1 to\n", + " avoid overlapping branch_lens i.e. norm_cum_branch_len = [0,1,1,2] for only\n", + " incrementing.\n", + " \"\"\"\n", + " nsegs = (\n", + " self.nodes.groupby(\"global_branch_index\")[\"global_comp_index\"]\n", + " .nunique()\n", + " .to_numpy()\n", " )\n", - " comp_ends = comp_ends + 2 * np.arange(num_branches).reshape(\n", - " -1, 1\n", - " ) # inter-branch padding\n", - " comp_ends = comp_ends.reshape(-1)\n", - " branch_lens = []\n", - " for i, xyzr in enumerate(self.base.xyzr):\n", - " branch_len = np.sqrt(\n", - " np.sum(np.diff(xyzr[:, :3], axis=0) ** 2, axis=1)\n", - " ).cumsum()\n", - " branch_len = np.hstack([np.array([0]), branch_len])\n", - " branch_len = branch_len / branch_len.max() + 2 * i # add padding like above\n", - " branch_len[np.isnan(branch_len)] = 0\n", - " branch_lens.append(branch_len)\n", - " branch_lens = np.hstack(branch_lens)\n", - " xyz = np.vstack(self.base.xyzr)[:, :3]\n", - " xyz = v_interp(comp_ends, branch_lens, xyz).reshape(\n", - " 3, num_branches, self.nseg + 1\n", + "\n", + " comp_ends = np.hstack(\n", + " [np.linspace(0, 1, nseg + 1) + 2 * i for i, nseg in enumerate(nsegs)]\n", " )\n", - " centers = ((xyz[:, :, 1:] + xyz[:, :, :-1]) / 2).reshape(3, -1).T\n", - " self.base.nodes.loc[self._in_view, [\"x\", \"y\", \"z\"]] = centers[self._in_view]\n", + " comp_ends = comp_ends.reshape(-1)\n", + " cum_branch_lens = []\n", + " for i, xyzr in enumerate(self.xyzr):\n", + " branch_len = np.sqrt(np.sum(np.diff(xyzr[:, :3], axis=0) ** 2, axis=1))\n", + " cum_branch_len = np.cumsum(np.concatenate([np.array([0]), branch_len]))\n", + " max_len = cum_branch_len.max()\n", + " # add padding like above\n", + " cum_branch_len = cum_branch_len / (max_len if max_len > 0 else 1) + 2 * i\n", + " cum_branch_len[np.isnan(cum_branch_len)] = 0\n", + " cum_branch_lens.append(cum_branch_len)\n", + " cum_branch_lens = np.hstack(cum_branch_lens)\n", + " xyz = np.vstack(self.xyzr)[:, :3]\n", + " xyz = v_interp(comp_ends, cum_branch_lens, xyz).T\n", + " centers = (xyz[:-1] + xyz[1:]) / 2 # unaware of inter vs intra comp centers\n", + " cum_nsegs = np.cumsum(nsegs)\n", + " # this means centers between comps have to be removed here\n", + " between_comp_inds = (cum_nsegs + np.arange(len(cum_nsegs)))[:-1]\n", + " centers = np.delete(centers, between_comp_inds, axis=0)\n", + " self.base.nodes.loc[self._in_view, [\"x\", \"y\", \"z\"]] = centers\n", " return centers, xyz\n", " \n", - " def make_trainable(self, key, init_val=None, verbose=False):\n", + " def make_trainable(\n", + " self,\n", + " key: str,\n", + " init_val: Optional[Union[float, list]] = None,\n", + " verbose: bool = True,\n", + " ):\n", + " \"\"\"Make a parameter trainable.\n", + "\n", + " If a parameter is made trainable, it will be returned by `get_parameters()`\n", + " and should then be passed to `jx.integrate(..., params=params)`.\n", + "\n", + " Args:\n", + " key: Name of the parameter to make trainable.\n", + " init_val: Initial value of the parameter. If `float`, the same value is\n", + " used for every created parameter. If `list`, the length of the list has\n", + " to match the number of created parameters. If `None`, the current\n", + " parameter value is used and if parameter sharing is performed that the\n", + " current parameter value is averaged over all shared parameters.\n", + " verbose: Whether to print the number of parameters that are added and the\n", + " total number of parameters.\n", + " \"\"\"\n", + " assert (\n", + " self.allow_make_trainable\n", + " ), \"network.cell('all').make_trainable() is not supported. Use a for-loop over cells.\"\n", + "\n", " data = self.nodes if key in self.nodes.columns else None\n", " data = self.edges if key in self.edges.columns else data\n", " assert data is not None, f\"Key '{key}' not found in nodes or edges\"\n", " not_nan = ~data[key].isna()\n", " data = data.loc[not_nan]\n", - " assert len(data) > 0, \"No settable parameters found in the selected compartments.\"\n", + " assert (\n", + " len(data) > 0\n", + " ), \"No settable parameters found in the selected compartments.\"\n", "\n", " grouped_view = data.groupby(\"controlled_by_param\")\n", " # Because of this `x.index.values` we cannot support `make_trainable()` on\n", " # the module level for synapse parameters (but only for `SynapseView`).\n", - " inds_of_comps = list(grouped_view.apply(lambda x: x.index.values, include_groups=False))\n", + " inds_of_comps = list(\n", + " grouped_view.apply(lambda x: x.index.values, include_groups=False)\n", + " )\n", " indices_per_param = jnp.stack(inds_of_comps)\n", " # Sorted inds are only used to infer the correct starting values.\n", " param_vals = jnp.asarray(\n", @@ -615,6 +882,7 @@ " self.nbranches_per_cell = self._nbranches_per_cell_in_view()\n", " self.cumsum_nbranches = np.cumsum(self.nbranches_per_cell)\n", " self.comb_branches_in_each_level = pointer.comb_branches_in_each_level\n", + " self.branch_edges = pointer.branch_edges.loc[self._branch_edges_in_view]\n", "\n", " self.synapse_names = np.unique(self.edges[\"type\"]).tolist()\n", " self.synapses, self.synapse_param_names, self.synapse_state_names = self._synapses_in_view(pointer)\n",