From f9ba0d3f8f2ff2248e83d724e572c2a06831bbf8 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Wed, 23 Oct 2024 09:36:57 +0200 Subject: [PATCH 01/17] add: add first version of new tests --- tests/test_viewing.py | 141 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 140 insertions(+), 1 deletion(-) diff --git a/tests/test_viewing.py b/tests/test_viewing.py index d869a29a..0c824fce 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -4,6 +4,7 @@ from copy import deepcopy import jax +import pandas as pd import pytest jax.config.update("jax_enable_x64", True) @@ -268,7 +269,7 @@ def test_solve_indexer(): # make sure all attrs in module also have a corresponding attr in view @pytest.mark.parametrize("module", [comp, branch, cell, net]) -def test_view_attrs(module): +def test_view_attrs(module: jx.Compartment | jx.Branch | jx.Cell | jx.Network): # attributes of Module that do not have to exist in View exceptions = ["view"] # TODO: should be added to View in the future @@ -308,3 +309,141 @@ def test_view_attrs(module): # add test local_indexing and global_indexing # add cell.comp (branch is skipped also for param sharing) # add tests for new features i.e. iter, context, scope + +comp = jx.Compartment() +branch = jx.Branch(comp, nseg=4) +cell = jx.Cell([branch] * 4, parents=[-1, 0, 0]) +net = jx.Network([cell] * 4) + +@pytest.mark.parametrize("module", [comp, branch, cell, net]) +def test_different_index_types(module): + # test int, range, slice, list, np.array, pd.Index + index_types = [0, range(3), slice(0, 3), [0, 1, 2], np.array([0, 1, 2]), pd.Index([0, 1, 2])] + for index in index_types: + assert module.comp(index), f"Failed for {type(index)}" + + # for loc test float and list of floats + assert module.loc(0.0), "Failed for float" + assert module.loc([0.0, 0.5, 1.0]), "Failed for List[float]" + +def test_select(): + comp = jx.Compartment() + branch = jx.Branch(comp, nseg=3) + cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) + net = jx.Network([cell] * 3) + connect(net[0, 0, :], net[1, 0, :], TestSynapse()) + + np.random.seed(0) + inds = np.random.choice(net.nodes.index, replace=False, size=5) + view = net.select(nodes=inds) + assert np.all(view.nodes.index == inds), "Selecting nodes by index failed" + + inds = np.random.choice(net.edges.index, replace=False, size=2) + view = net.select(edges=inds) + assert np.all(view.edges.index == inds), "Selecting edges by index failed" + + +def test_arbitrary_selection(): + comp = jx.Compartment() + branch = jx.Branch(comp, nseg=3) + cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) + net = jx.Network([cell] * 3) + + for view, local_targets, global_targets in zip([net.branch(0),net.cell(0).comp(0), net.comp(0), cell.comp(0)], + [], []): + view.nodes["local_comp_index"] = local_targets + view.nodes["global_comp_index"] = global_targets + +def test_scope(): + comp = jx.Compartment() + branch = jx.Branch(comp, nseg=3) + cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) + + view = cell.scope("global").branch(0) + assert view._scope == "global" + view = view.scope("local").comp(0) + assert view.nodes[["global_branch_index", "global_comp_index"]] + + cell.set_scope("global") + assert cell._scope == "global" + view = cell.branch(0).comp(5) + assert np.all(view.nodes[["global_branch_index", "global_comp_index"]] == [0, 5]) + + cell.set_scope("local") + assert cell._scope == "local" + view = cell.branch(0).comp(5) + assert np.all(view.nodes[["global_branch_index", "global_comp_index"]] == [None, None]) + +def test_context_manager(): + comp = 
jx.Compartment() + branch = jx.Branch(comp, nseg=3) + cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) + + with cell.branch(0).comp(0) as comp: + comp.set("v", -71) + comp.set("radius", 0.1) + + with cell.branch(1).comp([0, 1]) as comps: + comps.set("v", -71) + comps.set("radius", 0.1) + + assert np.all(cell.branch(0).comp(0).nodes[["v", "radius"]] == [-71, 0.1]) + assert np.all(cell.branch(1).comp([0, 1]).nodes[["v", "radius"]] == [-71, 0.1]) + +def test_iter(): + comp = jx.Compartment() + branch1 = jx.Branch(comp, nseg=2) + branch2 = jx.Branch(comp, nseg=3) + cell = jx.Cell([branch1, branch1, branch2], parents=[-1, 0, 0]) + net = jx.Network([cell] * 2) + + [len(branch.nodes) for branch in cell.branches] + + for cell in net.cells: + for branch in cell.branches: + for comp in branch.comps: + pass + + for cell in net: + for branch in cell: + for comp in branch: + pass + + [len(comp.nodes) for comp in net[0, 0].comps] + [len(comp.nodes) for comp in net.comps] + + + + + + +# # iterables +# for cell in net.cells: +# for branch in cell.branches: +# for comp in branch.comps: +# comp.set("v", -71) + +# for comp in net.cell(0).branch(0).comps: +# comp.set("v", -72) +# net.show()[["v"]] + + +# # groups +# net.cell(1).branch(0).add_group("group") +# net.group.show() + +# # Channel and Synapse views +# net.HH.show() +# net.cell(0).HH.nodes +# net.HH.cell(0).nodes + +# net.TestSynapse.nodes +# net.cell([0,1]).TestSynapse.nodes +# net.TestSynapse.cell(0).nodes + +# # edges +# net.edge([0,1,2]).edges + +# # copying +# cell0 = net.cell(0).copy() +# cell0.show() \ No newline at end of file From 228102c76d460ff4a621a0517113157219433126 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Wed, 23 Oct 2024 14:32:10 +0200 Subject: [PATCH 02/17] add: add more tests for new view --- tests/test_viewing.py | 263 +++++++++++++++++++++++++++++++----------- 1 file changed, 195 insertions(+), 68 deletions(-) diff --git a/tests/test_viewing.py b/tests/test_viewing.py index 0c824fce..6cf3460e 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -305,145 +305,272 @@ def test_view_attrs(module: jx.Compartment | jx.Branch | jx.Cell | jx.Network): ), f"Type mismatch: {name}, Module type: {type(getattr(module, name))}, View type: {type(getattr(view, name))}" -# TODO: test filter for modules and check for param sharing -# add test local_indexing and global_indexing -# add cell.comp (branch is skipped also for param sharing) -# add tests for new features i.e. 
iter, context, scope - comp = jx.Compartment() -branch = jx.Branch(comp, nseg=4) -cell = jx.Cell([branch] * 4, parents=[-1, 0, 0]) +branch = jx.Branch([comp] * 4) +cell = jx.Cell([branch] * 4, parents=[-1, 0, 0, 0]) net = jx.Network([cell] * 4) + @pytest.mark.parametrize("module", [comp, branch, cell, net]) def test_different_index_types(module): # test int, range, slice, list, np.array, pd.Index - index_types = [0, range(3), slice(0, 3), [0, 1, 2], np.array([0, 1, 2]), pd.Index([0, 1, 2])] + index_types = [ + 0, + range(3), + slice(0, 3), + [0, 1, 2], + np.array([0, 1, 2]), + pd.Index([0, 1, 2]), + ] for index in index_types: + assert isinstance( + module._reformat_index(index), np.ndarray + ), f"Failed for {type(index)}" assert module.comp(index), f"Failed for {type(index)}" - + assert View(module).comp(index), f"Failed for {type(index)}" + # for loc test float and list of floats assert module.loc(0.0), "Failed for float" - assert module.loc([0.0, 0.5, 1.0]), "Failed for List[float]" + assert module.loc([0.0, 0.5, 1.0]), "Failed for List[float]" + def test_select(): comp = jx.Compartment() - branch = jx.Branch(comp, nseg=3) + branch = jx.Branch([comp] * 3) cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) net = jx.Network([cell] * 3) connect(net[0, 0, :], net[1, 0, :], TestSynapse()) np.random.seed(0) + + # select only nodes inds = np.random.choice(net.nodes.index, replace=False, size=5) view = net.select(nodes=inds) assert np.all(view.nodes.index == inds), "Selecting nodes by index failed" + # select only edges inds = np.random.choice(net.edges.index, replace=False, size=2) view = net.select(edges=inds) assert np.all(view.edges.index == inds), "Selecting edges by index failed" - -def test_arbitrary_selection(): + # check if pre and post comps of edges are in nodes + edge_node_inds = np.unique( + view.edges[["global_pre_comp_index", "global_post_comp_index"]] + .to_numpy() + .flatten() + ) + assert np.all( + view.nodes["global_comp_index"] == edge_node_inds + ), "Selecting edges did not yield the correct nodes." + + # select nodes and edges + node_inds = np.random.choice(net.nodes.index, replace=False, size=5) + edge_inds = np.random.choice(net.edges.index, replace=False, size=2) + view = net.select(nodes=node_inds, edges=edge_inds) + assert np.all( + view.nodes.index == node_inds + ), "Selecting nodes and edges by index failed for nodes." + assert np.all( + view.edges.index == edge_inds + ), "Selecting nodes and edges by index failed for edges." 
+ + +def test_viewing(): comp = jx.Compartment() - branch = jx.Branch(comp, nseg=3) + branch = jx.Branch([comp] * 3) cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) net = jx.Network([cell] * 3) - for view, local_targets, global_targets in zip([net.branch(0),net.cell(0).comp(0), net.comp(0), cell.comp(0)], - [], []): - view.nodes["local_comp_index"] = local_targets - view.nodes["global_comp_index"] = global_targets + # test parameter sharing works correctly + nodes1 = net.branch(0).comp("all").nodes + nodes2 = net.branch(0).nodes + nodes3 = net.cell(0).nodes + control_params1 = nodes1.pop("controlled_by_param") + control_params2 = nodes2.pop("controlled_by_param") + control_params3 = nodes3.pop("controlled_by_param") + assert np.all(nodes1 == nodes2), "Nodes are not the same" + assert np.all( + control_params1 == nodes1["global_comp_index"] + ), "Parameter sharing is not correct" + assert np.all( + control_params2 == nodes2["global_branch_index"] + ), "Parameter sharing is not correct" + assert np.all( + control_params3 == nodes3["global_cell_index"] + ), "Parameter sharing is not correct" + + # test local and global indexes match the expected targets + for view, local_targets, global_targets in zip( + [ + net.branch(0), # shows every comp on 0th branch of all cells + cell.branch("all"), # shows all branches and comps of cell + net.cell(0).comp(0), # shows every 0th comp for every branch on 0th cell + net.comp(0), # shows 0th comp of every branch of every cell + cell.comp(0), # shows 0th comp of every branch of cell + ], + [[0, 1, 2] * 3, [0, 1, 2] * 3, [0] * 3, [0] * 9, [0] * 3], + [ + [0, 1, 2, 9, 10, 11, 18, 19, 20], + list(range(9)), + [0, 3, 6], + list(range(0, 27, 3)), + list(range(0, 9, 3)), + ], + ): + assert np.all( + view.nodes["local_comp_index"] == local_targets + ), "Indices do not match that of the target" + assert np.all( + view.nodes["global_comp_index"] == global_targets + ), "Indices do not match that of the target" + + with pytest.raises(ValueError): + net.scope("global").comp(999) # Nothing should be in View + def test_scope(): comp = jx.Compartment() - branch = jx.Branch(comp, nseg=3) + branch = jx.Branch([comp] * 3) cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) - view = cell.scope("global").branch(0) + view = cell.scope("global").branch(1) assert view._scope == "global" view = view.scope("local").comp(0) - assert view.nodes[["global_branch_index", "global_comp_index"]] + assert np.all( + view.nodes[["global_branch_index", "global_comp_index"]] == [1, 3] + ), "Expected [1,3] but got {}".format( + view.nodes[["global_branch_index", "global_comp_index"]] + ) cell.set_scope("global") assert cell._scope == "global" - view = cell.branch(0).comp(5) - assert np.all(view.nodes[["global_branch_index", "global_comp_index"]] == [0, 5]) + view = cell.branch(1).comp(3) + assert np.all( + view.nodes[["global_branch_index", "global_comp_index"]] == [1, 3] + ), "Expected [1,3] but got {}".format( + view.nodes[["global_branch_index", "global_comp_index"]] + ) cell.set_scope("local") assert cell._scope == "local" - view = cell.branch(0).comp(5) - assert np.all(view.nodes[["global_branch_index", "global_comp_index"]] == [None, None]) + view = cell.branch(1).comp(0) + assert np.all( + view.nodes[["global_branch_index", "global_comp_index"]] == [1, 3] + ), "Expected [1,3] but got {}".format( + view.nodes[["global_branch_index", "global_comp_index"]] + ) + def test_context_manager(): comp = jx.Compartment() - branch = jx.Branch(comp, nseg=3) + branch = jx.Branch([comp] * 3) cell = 
jx.Cell([branch] * 3, parents=[-1, 0, 0]) with cell.branch(0).comp(0) as comp: comp.set("v", -71) - comp.set("radius", 0.1) - + comp.set("radius", 0.123) + with cell.branch(1).comp([0, 1]) as comps: comps.set("v", -71) - comps.set("radius", 0.1) + comps.set("radius", 0.123) + + assert np.all( + cell.branch(0).comp(1).nodes[["v", "radius"]] == [-70, 1.0] + ), "Set affected nodes not in context manager View." + assert np.all( + cell.branch(0).comp(0).nodes[["v", "radius"]] == [-71, 0.123] + ), "Context management of View not working." + assert np.all( + cell.branch(1).comp([0, 1]).nodes[["v", "radius"]] == [-71, 0.123] + ), "Context management of View not working." - assert np.all(cell.branch(0).comp(0).nodes[["v", "radius"]] == [-71, 0.1]) - assert np.all(cell.branch(1).comp([0, 1]).nodes[["v", "radius"]] == [-71, 0.1]) def test_iter(): comp = jx.Compartment() - branch1 = jx.Branch(comp, nseg=2) - branch2 = jx.Branch(comp, nseg=3) + branch1 = jx.Branch([comp] * 2) + branch2 = jx.Branch([comp] * 3) cell = jx.Cell([branch1, branch1, branch2], parents=[-1, 0, 0]) net = jx.Network([cell] * 2) - [len(branch.nodes) for branch in cell.branches] + # test iterating over bracnhes with different numbers of compartments + assert np.all( + [ + len(branch.nodes) == expected_len + for branch, expected_len in zip(cell.branches, [2, 2, 3]) + ] + ), "__iter__ failed for branches with different numbers of compartments." + # test iterating using cells, branches, and comps properties + nodes1 = [] for cell in net.cells: for branch in cell.branches: for comp in branch.comps: - pass + nodes1.append(comp.nodes) + assert len(nodes1) == len(net.nodes), "Some compartments were skipped in iteration." + nodes2 = [] for cell in net: for branch in cell: for comp in branch: - pass - - [len(comp.nodes) for comp in net[0, 0].comps] - [len(comp.nodes) for comp in net.comps] - - - - - - -# # iterables -# for cell in net.cells: -# for branch in cell.branches: -# for comp in branch.comps: -# comp.set("v", -71) + nodes2.append(comp.nodes) + assert len(nodes2) == len(net.nodes), "Some compartments were skipped in iteration." + assert np.all( + [np.all(n1 == n2) for n1, n2 in zip(nodes1, nodes2)] + ), "__iter__ is not consistent with [comp.nodes for cell in net.cells for branches in cell.branches for comp in branches.comps]" + + assert np.all( + [len(comp.nodes) for comp in net[0, 0].comps] == [1, 1] + ), "Iterator yielded unexpected number of compartments" + + # 0th comp in every branch (3), 1st comp in every branch (3), 2nd comp in (every) branch (only 1 branch with > 2 comps) + assert np.all( + [len(comp.nodes) for comp in net[0].comps] == [3, 3, 1] + ), "Iterator yielded unexpected number of compartments" + + # 0th comp in every branch for every cell (6), 1st comp in every branch for every cell , 2nd comp in (every) branch for every cell + assert np.all( + [len(comp.nodes) for comp in net.comps] == [6, 6, 2] + ), "Iterator yielded unexpected number of compartments" + + for comp in branch1: + comp.set("v", -72) + assert np.all(branch1.nodes["v"] == -72), "Setting parameters with __iter__ failed." + + # needs to be redefined because cell was overwritten with View object + cell = jx.Cell([branch1, branch1, branch2], parents=[-1, 0, 0]) + for branch in cell: + for comp in branch: + comp.set("v", -73) + assert np.all(cell.nodes["v"] == -73), "Setting parameters with __iter__ failed." 
-# for comp in net.cell(0).branch(0).comps: -# comp.set("v", -72) -# net.show()[["v"]] +def test_synapse_and_channel_filtering(): + comp = jx.Compartment() + branch = jx.Branch([comp] * 3) + cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) + net = jx.Network([cell] * 3) + net.insert(HH()) + connect(net[0, 0, :], net[1, 0, :], TestSynapse()) -# # groups -# net.cell(1).branch(0).add_group("group") -# net.group.show() + assert np.all(net.cell(0).HH.nodes == net.HH.cell(0).nodes) + view1 = net.cell([0, 1]).TestSynapse + nodes1 = view1.nodes + edges1 = view1.edges + view2 = net.TestSynapse.cell([0, 1]) + nodes2 = view2.nodes + edges2 = view2.edges + nodes_control_param1 = nodes1.pop("controlled_by_param") + nodes_control_param2 = nodes2.pop("controlled_by_param") + edges_control_param1 = edges1.pop("controlled_by_param") + edges_control_param2 = edges2.pop("controlled_by_param") -# # Channel and Synapse views -# net.HH.show() -# net.cell(0).HH.nodes -# net.HH.cell(0).nodes + assert np.all(nodes1 == nodes2) + assert np.all(nodes_control_param1 == 0) + assert np.all(nodes_control_param2 == nodes2["global_cell_index"]) -# net.TestSynapse.nodes -# net.cell([0,1]).TestSynapse.nodes -# net.TestSynapse.cell(0).nodes + assert np.all(edges1 == edges2) + assert np.all(net.edge(0).TestSynapse.nodes == net.TestSynapse.edge(0).nodes) + assert np.all(net.edge(0).TestSynapse.edges == net.TestSynapse.edge(0).edges) -# # edges -# net.edge([0,1,2]).edges -# # copying -# cell0 = net.cell(0).copy() -# cell0.show() \ No newline at end of file +# TODO: test copying/extracting views/module parts From 2f660529864ae0c94951d0da937a986a1d14d6cf Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Wed, 23 Oct 2024 15:09:37 +0200 Subject: [PATCH 03/17] fix: add cumsum_nseg to View and remove cell_list in Network after used to init SolveIndexer --- jaxley/modules/base.py | 2 ++ jaxley/modules/network.py | 21 ++++++++++----------- tests/test_viewing.py | 3 --- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index ee4bc998..7154e10d 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -2208,6 +2208,8 @@ def __init__( self.cumsum_nbranches = jnp.cumsum(np.asarray(self.nbranches_per_cell)) self.comb_branches_in_each_level = pointer.comb_branches_in_each_level self.branch_edges = pointer.branch_edges.loc[self._branch_edges_in_view] + self.nseg_per_branch = self.base.nseg_per_branch[self._branches_in_view] + self.cumsum_nseg = cumsum_leading_zero(self.nseg_per_branch) self.synapse_names = np.unique(self.edges["type"]).tolist() self._set_synapses_in_view(pointer) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index b4de9771..10a74f17 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -52,7 +52,7 @@ def __init__( for cell in cells: self.xyzr += deepcopy(cell.xyzr) - self.cells_list = cells # TODO: TEMPORARY FIX, REMOVE BY ADDING ATTRS TO VIEW (solve_indexer.children_in_level) + self._cells_list = cells self.nseg_per_branch = np.concatenate([cell.nseg_per_branch for cell in cells]) self.nseg = int(np.max(self.nseg_per_branch)) self.cumsum_nseg = cumsum_leading_zero(self.nseg_per_branch) @@ -119,18 +119,18 @@ def _init_morph_jaxley_spsolve(self): children_in_level = merge_cells( self.cumsum_nbranches, self.cumsum_nbranchpoints_per_cell, - [cell.solve_indexer.children_in_level for cell in self.cells_list], + [cell.solve_indexer.children_in_level for cell in self._cells_list], exclude_first=False, ) parents_in_level = 
merge_cells( self.cumsum_nbranches, self.cumsum_nbranchpoints_per_cell, - [cell.solve_indexer.parents_in_level for cell in self.cells_list], + [cell.solve_indexer.parents_in_level for cell in self._cells_list], exclude_first=False, ) padded_cumsum_nseg = cumsum_leading_zero( np.concatenate( - [np.diff(cell.solve_indexer.cumsum_nseg) for cell in self.cells_list] + [np.diff(cell.solve_indexer.cumsum_nseg) for cell in self._cells_list] ) ) @@ -171,12 +171,12 @@ def _init_morph_jax_spsolve(self): `type == 4`: child-compartment --> branchpoint """ self._cumsum_nseg_per_cell = cumsum_leading_zero( - jnp.asarray([cell.cumsum_nseg[-1] for cell in self.cells_list]) + jnp.asarray([cell.cumsum_nseg[-1] for cell in self.cells]) ) self._comp_edges = pd.DataFrame() # Add all the internal nodes. - for offset, cell in zip(self._cumsum_nseg_per_cell, self.cells_list): + for offset, cell in zip(self._cumsum_nseg_per_cell, self._cells_list): condition = cell._comp_edges["type"].to_numpy() == 0 rows = cell._comp_edges[condition] self._comp_edges = pd.concat( @@ -188,7 +188,7 @@ def _init_morph_jax_spsolve(self): for offset, offset_branchpoints, cell in zip( self._cumsum_nseg_per_cell, self.cumsum_nbranchpoints_per_cell, - self.cells_list, + self._cells_list, ): offset_within_cell = cell.cumsum_nseg[-1] condition = cell._comp_edges["type"].isin([1, 2]) @@ -210,7 +210,7 @@ def _init_morph_jax_spsolve(self): for offset, offset_branchpoints, cell in zip( self._cumsum_nseg_per_cell, self.cumsum_nbranchpoints_per_cell, - self.cells_list, + self._cells_list, ): offset_within_cell = cell.cumsum_nseg[-1] condition = cell._comp_edges["type"].isin([3, 4]) @@ -228,8 +228,7 @@ def _init_morph_jax_spsolve(self): ignore_index=True, ) - # Note that, unlike in `cell.py`, we cannot delete `self.cells_list` here because - # it is used in plotting. + del self._cells_list # Convert comp_edges to the index format required for `jax.sparse` solvers. n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges) @@ -533,7 +532,7 @@ def build_extents(*subset_sizes): for i, layer in enumerate(layers): graph.add_nodes_from(layer, layer=i) else: - graph.add_nodes_from(range(len(self.cells_list))) + graph.add_nodes_from(range(len(self._cells_list))) pre_cell = self.edges["global_pre_cell_index"].to_numpy() post_cell = self.edges["global_post_cell_index"].to_numpy() diff --git a/tests/test_viewing.py b/tests/test_viewing.py index 6cf3460e..9358c269 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -258,8 +258,6 @@ def test_solve_indexer(): assert np.all(idx.upper(branch_inds) == np.asarray([[0, 1, 2], [7, 8, 9]])) -# TODO: tests - comp = jx.Compartment() branch = jx.Branch(comp, nseg=3) cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) @@ -274,7 +272,6 @@ def test_view_attrs(module: jx.Compartment | jx.Branch | jx.Cell | jx.Network): exceptions = ["view"] # TODO: should be added to View in the future exceptions += [ - "cumsum_nseg", "_internal_node_inds", "par_inds", "child_inds", From ef7dc88b470085eee9d5020a6d531e009ebee249 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Wed, 23 Oct 2024 15:13:05 +0200 Subject: [PATCH 04/17] fix: ammend prev commit --- jaxley/modules/network.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 10a74f17..29909f8f 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -105,7 +105,8 @@ def __init__( # Channels. 
self._gather_channels_from_constituents(cells) - self._initialize() + self.initialize() + del self._cells_list def __repr__(self): return f"{type(self).__name__} with {len(self.channels)} different channels and {len(self.synapses)} synapses. Use `.nodes` or `.edges` for details." @@ -228,8 +229,6 @@ def _init_morph_jax_spsolve(self): ignore_index=True, ) - del self._cells_list - # Convert comp_edges to the index format required for `jax.sparse` solvers. n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges) self._n_nodes = n_nodes From 7e621eefe8e03f6de537948511a874b6f295c5b5 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Wed, 23 Oct 2024 15:16:51 +0200 Subject: [PATCH 05/17] fix: fold todos into funcs --- jaxley/modules/base.py | 52 +++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 7154e10d..cd3093a3 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -476,27 +476,27 @@ def edge(self, idx: Any) -> View: View of the module at the specified edge index.""" return self._at_edges("edge", idx) - # TODO: pre and post could just modify scope - # -> self.scope=self.scope+"_pre" and then call edge? - # def pre(self, idx: Any) -> View: - # """Return a View of the module at the selected pre-synaptic compartments(s). + def pre(self, idx: Any) -> View: + """Return a View of the module at the selected pre-synaptic compartments(s). - # Args: - # idx: index of the edge to view. + Args: + idx: index of the edge to view. - # Returns: - # View of the module filtered by the selected pre-comp index.""" - # return self._at_edges("edge", idx) + Returns: + View of the module filtered by the selected pre-comp index.""" + # TODO: pre and post could just modify scope + # -> self.scope=self.scope+"_pre" and then call edge? + return None # self._at_edges("edge", idx) - # def post(self, idx: Any) -> View: - # """Return a View of the module at the selected post-synaptic compartments(s). + def post(self, idx: Any) -> View: + """Return a View of the module at the selected post-synaptic compartments(s). - # Args: - # idx: index of the edge to view. + Args: + idx: index of the edge to view. - # Returns: - # View of the module filtered by the selected post-comp index.""" - # return self._at_edges("edge", idx) + Returns: + View of the module filtered by the selected post-comp index.""" + return None # self._at_edges("edge", idx) def loc(self, at: Any) -> View: """Return a View of the module at the selected branch location(s). @@ -661,8 +661,8 @@ def _gather_channels_from_constituents(self, constituents: List): name = channel._name self.base.nodes.loc[self.nodes[name].isna(), name] = False - # TODO: Make this work for View? def to_jax(self): + # TODO: Make this work for View? """Move `.nodes` to `.jaxnodes`. Before the actual simulation is run (via `jx.integrate`), all parameters of @@ -690,7 +690,7 @@ def to_jax(self): def show( self, - param_names: Optional[Union[str, List[str]]] = None, # TODO. + param_names: Optional[Union[str, List[str]]] = None, *, indices: bool = True, params: bool = True, @@ -1107,8 +1107,8 @@ def distance(self, endpoint: "View") -> float: end_xyz = endpoint.xyzr[0][0, :3] return np.sqrt(np.sum((start_xyz - end_xyz) ** 2)) - # TODO: MAKE THIS WORK FOR VIEW? def delete_trainables(self): + # TODO: MAKE THIS WORK FOR VIEW? """Removes all trainable parameters from the module.""" assert isinstance(self, Module), "Only supports modules." 
self.base.indices_set_by_trainables = [] @@ -1134,8 +1134,8 @@ def add_to_group(self, group_name: str): np.concatenate([self.base.groups[group_name], self._nodes_in_view]) ) - # TODO: MAKE THIS WORK FOR VIEW? def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: + # TODO: MAKE THIS WORK FOR VIEW? """Get all trainable parameters. The returned parameters should be passed to `jx.integrate(..., params=params). @@ -1146,10 +1146,10 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: """ return self.base.trainable_params - # TODO: MAKE THIS WORK FOR VIEW? def get_all_parameters( self, pstate: List[Dict], voltage_solver: str ) -> Dict[str, jnp.ndarray]: + # TODO: MAKE THIS WORK FOR VIEW? """Return all parameters (and coupling conductances) needed to simulate. Runs `_compute_axial_conductances()` and return every parameter that is needed @@ -1235,10 +1235,10 @@ def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]: states[synapse_states] = self.base.jaxedges[synapse_states] return states - # TODO: MAKE THIS WORK FOR VIEW? def get_all_states( self, pstate: List[Dict], all_params, delta_t: float ) -> Dict[str, jnp.ndarray]: + # TODO: MAKE THIS WORK FOR VIEW? """Get the full initial state of the module from jaxnodes and trainables. Args: @@ -1284,8 +1284,8 @@ def _initialize(self): self._init_morph() return self - # TODO: MAKE THIS WORK FOR VIEW? def init_states(self, delta_t: float = 0.025): + # TODO: MAKE THIS WORK FOR VIEW? """Initialize all mechanisms in their steady state. This considers the voltages and parameters of each compartment. @@ -1413,8 +1413,8 @@ def record(self, state: str = "v", verbose=True): f"Added {len(in_view)-sum(has_duplicates)} recordings. See `.recordings` for details." ) - # TODO: MAKE THIS WORK FOR VIEW? def delete_recordings(self): + # TODO: MAKE THIS WORK FOR VIEW? """Removes all recordings from the module.""" assert isinstance(self, Module), "Only supports modules." self.base.recordings = pd.DataFrame().from_dict({}) @@ -1565,15 +1565,15 @@ def _data_external_input( return (state_name, external_input, inds) - # TODO: MAKE THIS WORK FOR VIEW? def delete_stimuli(self): + # TODO: MAKE THIS WORK FOR VIEW? """Removes all stimuli from the module.""" assert isinstance(self, Module), "Only supports modules." self.base.externals.pop("i", None) self.base.external_inds.pop("i", None) - # TODO: MAKE THIS WORK FOR VIEW? def delete_clamps(self, state_name: str): + # TODO: MAKE THIS WORK FOR VIEW? """Removes all clamps of the given state from the module.""" assert isinstance(self, Module), "Only supports modules." self.base.externals.pop(state_name, None) From 6b74f68f86f6de6a2f4172ab151f3b84daa274f5 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Wed, 23 Oct 2024 19:08:03 +0200 Subject: [PATCH 06/17] wip: save wip, rm pre/post,make delete methods local, fixes --- jaxley/modules/base.py | 130 +++++++++++++++++++++++--------------- jaxley/modules/network.py | 2 +- 2 files changed, 81 insertions(+), 51 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index cd3093a3..5a646081 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -476,28 +476,6 @@ def edge(self, idx: Any) -> View: View of the module at the specified edge index.""" return self._at_edges("edge", idx) - def pre(self, idx: Any) -> View: - """Return a View of the module at the selected pre-synaptic compartments(s). - - Args: - idx: index of the edge to view. 
- - Returns: - View of the module filtered by the selected pre-comp index.""" - # TODO: pre and post could just modify scope - # -> self.scope=self.scope+"_pre" and then call edge? - return None # self._at_edges("edge", idx) - - def post(self, idx: Any) -> View: - """Return a View of the module at the selected post-synaptic compartments(s). - - Args: - idx: index of the edge to view. - - Returns: - View of the module filtered by the selected post-comp index.""" - return None # self._at_edges("edge", idx) - def loc(self, at: Any) -> View: """Return a View of the module at the selected branch location(s). @@ -623,7 +601,7 @@ def copy( @property def view(self): """Return view of the module.""" - return View(self, self._nodes_in_view) + return View(self, self._nodes_in_view, self._edges_in_view) @property def _module_type(self): @@ -1108,12 +1086,19 @@ def distance(self, endpoint: "View") -> float: return np.sqrt(np.sum((start_xyz - end_xyz) ** 2)) def delete_trainables(self): - # TODO: MAKE THIS WORK FOR VIEW? + # TODO: Test that this correctly works for View! """Removes all trainable parameters from the module.""" - assert isinstance(self, Module), "Only supports modules." - self.base.indices_set_by_trainables = [] - self.base.trainable_params = [] - self.base.num_trainable_params = 0 + + if isinstance(self, View): + trainables_and_inds = self._filter_trainables(is_viewed=False) + self.base.indices_set_by_trainables = trainables_and_inds[0] + self.base.trainable_params = trainables_and_inds[1] + self.base.num_trainable_params -= self.num_trainable_params + else: + self.base.indices_set_by_trainables = [] + self.base.trainable_params = [] + self.base.num_trainable_params = 0 + self._update_view() def add_to_group(self, group_name: str): """Add a view of the module to a group. @@ -1135,7 +1120,6 @@ def add_to_group(self, group_name: str): ) def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: - # TODO: MAKE THIS WORK FOR VIEW? """Get all trainable parameters. The returned parameters should be passed to `jx.integrate(..., params=params). @@ -1144,7 +1128,7 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: A list of all trainable parameters in the form of [{"gNa": jnp.array([0.1, 0.2, 0.3])}, ...]. """ - return self.base.trainable_params + return self.trainable_params def get_all_parameters( self, pstate: List[Dict], voltage_solver: str @@ -1400,9 +1384,11 @@ def _init_morph_for_debugging(self): self.base.debug_states["par_inds"] = self.base.par_inds def record(self, state: str = "v", verbose=True): - in_view = ( - self._edges_in_view if state in self.edges.columns else self._nodes_in_view - ) + in_view = None + in_view = self._edges_in_view if state in self.edges.columns else in_view + in_view = self._nodes_in_view if state in self.nodes.columns else in_view + assert in_view is not None, "State not found in nodes or edges." + new_recs = pd.DataFrame(in_view, columns=["rec_index"]) new_recs["state"] = state self.base.recordings = pd.concat([self.base.recordings, new_recs]) @@ -1413,11 +1399,27 @@ def record(self, state: str = "v", verbose=True): f"Added {len(in_view)-sum(has_duplicates)} recordings. See `.recordings` for details." 
) + def _update_view(self): + """Update the attrs of the view after changes in the base module.""" + if isinstance(self, View): + scope = self._scope + current_view = self._current_view + self.__dict__ = View( + self.base, self._nodes_in_view, self._edges_in_view + ).__dict__ + self._scope = scope + self._current_view = current_view + def delete_recordings(self): - # TODO: MAKE THIS WORK FOR VIEW? """Removes all recordings from the module.""" - assert isinstance(self, Module), "Only supports modules." - self.base.recordings = pd.DataFrame().from_dict({}) + if isinstance(self, View): + base_recs = self.base.recordings + self.base.recordings = base_recs[ + ~base_recs.isin(self.recordings).all(axis=1) + ] + self._update_view() + else: + self.base.recordings = pd.DataFrame().from_dict({}) def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True): """Insert a stimulus into the compartment. @@ -1566,18 +1568,26 @@ def _data_external_input( return (state_name, external_input, inds) def delete_stimuli(self): - # TODO: MAKE THIS WORK FOR VIEW? """Removes all stimuli from the module.""" - assert isinstance(self, Module), "Only supports modules." - self.base.externals.pop("i", None) - self.base.external_inds.pop("i", None) + self.delete_clamps("i") def delete_clamps(self, state_name: str): - # TODO: MAKE THIS WORK FOR VIEW? """Removes all clamps of the given state from the module.""" - assert isinstance(self, Module), "Only supports modules." - self.base.externals.pop(state_name, None) - self.base.external_inds.pop(state_name, None) + if state_name in self.externals: + keep_inds = ~np.isin( + self.base.external_inds[state_name], self._nodes_in_view + ) + base_exts = self.base.externals + base_exts_inds = self.base.external_inds + if np.all(~keep_inds): + base_exts.pop(state_name, None) + base_exts_inds.pop(state_name, None) + else: + base_exts[state_name] = base_exts[state_name][keep_inds] + base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds] + self._update_view() + else: + pass # does not have to be deleted if not in externals def insert(self, channel: Channel): """Insert a channel into the module. @@ -2324,7 +2334,14 @@ def _set_externals_in_view(self): self.externals[name] = data[in_view] self.external_inds[name] = inds_in_view - def _set_trainables_in_view(self): + def _filter_trainables( + self, is_viewed: bool = True + ) -> Tuple[List[np.ndarray], List[Dict]]: + """filters the trainables inside and outside of the view + + Args: + is_viewed: Toggles between returning the trainables and inds + currently inside or outside of the scope of View.""" trainable_inds = self.base.indices_set_by_trainables trainable_inds = ( np.unique(np.hstack([inds.reshape(-1) for inds in trainable_inds])) @@ -2340,7 +2357,7 @@ def _set_trainables_in_view(self): for inds, params in zip( self.base.indices_set_by_trainables, self.base.trainable_params ): - in_view = np.isin(inds, trainable_node_inds_in_view) + in_view = is_viewed == np.isin(inds, trainable_node_inds_in_view) completely_in_view = in_view.all(axis=1) índices_set_by_trainables_in_view.append(inds[completely_in_view]) @@ -2358,7 +2375,9 @@ def _set_trainables_in_view(self): # TODO: working but ugly. 
maybe integrate into above loop trainable_names = np.array([next(iter(d)) for d in self.base.trainable_params]) - is_syn_trainable_in_view = np.isin(trainable_names, self.synapse_param_names) + is_syn_trainable_in_view = is_viewed * np.isin( + trainable_names, self.synapse_param_names + ) syn_trainable_names_in_view = trainable_names[is_syn_trainable_in_view] syn_trainable_inds_in_view = np.intersect1d( syn_trainable_names_in_view, trainable_names, return_indices=True @@ -2367,7 +2386,9 @@ def _set_trainables_in_view(self): syn_name = trainable_names[idx].split("_")[0] syn_edges = self.base.edges[self.base.edges["type"] == syn_name] syn_inds = np.arange(len(syn_edges)) - syn_inds_in_view = syn_inds[np.isin(syn_edges.index, self._edges_in_view)] + syn_inds_in_view = syn_inds[ + is_viewed == np.isin(syn_edges.index, self._edges_in_view) + ] syn_trainable_params_in_view = { k: v[syn_inds_in_view] @@ -2379,12 +2400,21 @@ def _set_trainables_in_view(self): ][syn_inds_in_view] índices_set_by_trainables_in_view.append(syn_inds_set_by_trainables_in_view) - self.indices_set_by_trainables = [ + indices_set_by_trainables = [ inds for inds in índices_set_by_trainables_in_view if len(inds) > 0 ] - self.trainable_params = [ + trainable_params = [ p for p in trainable_params_in_view if len(next(iter(p.values()))) > 0 ] + return indices_set_by_trainables, trainable_params + + def _set_trainables_in_view(self): + # TODO: Test this! The two examples below are also buggy! + # net.cell([0,1]).branch(0).comp(0).trainable_params + # net.cell(0).branch(0).comp(0).trainable_params + self.indices_set_by_trainables, self.trainable_params = ( + self._filter_trainables() + ) def _channels_in_view(self, pointer: Union[Module, View]) -> List[Channel]: names = [name._name for name in pointer.channels] diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 29909f8f..1a1a7f86 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -531,7 +531,7 @@ def build_extents(*subset_sizes): for i, layer in enumerate(layers): graph.add_nodes_from(layer, layer=i) else: - graph.add_nodes_from(range(len(self._cells_list))) + graph.add_nodes_from(range(len(self._cells_in_view))) pre_cell = self.edges["global_pre_cell_index"].to_numpy() post_cell = self.edges["global_post_cell_index"].to_numpy() From c5076e90602fbe5f68c820eedf4ba1d940e59010 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 24 Oct 2024 15:22:10 +0200 Subject: [PATCH 07/17] fix: fix trainables and local inds in edges --- jaxley/modules/base.py | 116 +++++++++++++++++++---------------------- 1 file changed, 54 insertions(+), 62 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 5a646081..3326e2a9 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -227,18 +227,29 @@ def reindex_a_by_b( return df index_names = ["cell_index", "branch_index", "comp_index"] # order is important - for obj, prefix in zip( - [self.nodes, self.edges, self.edges], ["", "pre_", "post_"] - ): - global_idx_cols = [f"global_{prefix}{name}" for name in index_names] - local_idx_cols = [f"local_{prefix}{name}" for name in index_names] - idcs = obj[global_idx_cols] - - idcs = reindex_a_by_b(idcs, global_idx_cols[0]) - idcs = reindex_a_by_b(idcs, global_idx_cols[1], global_idx_cols[0]) - idcs = reindex_a_by_b(idcs, global_idx_cols[2], global_idx_cols[:2]) - idcs.columns = [col.replace("global", "local") for col in global_idx_cols] - obj[local_idx_cols] = idcs[local_idx_cols].astype(int) + global_idx_cols = [f"global_{name}" 
for name in index_names] + local_idx_cols = [f"local_{name}" for name in index_names] + local_pre_cols = [f"local_pre_{name}" for name in index_names] + local_post_cols = [f"local_post_{name}" for name in index_names] + idcs = self.nodes[global_idx_cols] + + # update local indices of nodes + idcs = reindex_a_by_b(idcs, global_idx_cols[0]) + idcs = reindex_a_by_b(idcs, global_idx_cols[1], global_idx_cols[0]) + idcs = reindex_a_by_b(idcs, global_idx_cols[2], global_idx_cols[:2]) + idcs.columns = [col.replace("global", "local") for col in global_idx_cols] + self.nodes[local_idx_cols] = idcs[local_idx_cols].astype(int) + + # add local indices of nodes to edges + global_pre_inds = self.edges["global_pre_comp_index"] + global_post_inds = self.edges["global_post_comp_index"] + global_node_inds = self.nodes["global_comp_index"] + flat = lambda x: np.array(x).flatten() + is_pre = flat([np.where(global_node_inds == i)[0] for i in global_pre_inds]) + is_post = flat([np.where(global_node_inds == i)[0] for i in global_post_inds]) + local_node_inds = self.nodes[local_idx_cols] + self.edges.loc[:, local_pre_cols] = local_node_inds.loc[is_pre].to_numpy() + self.edges.loc[:, local_post_cols] = local_node_inds.loc[is_post].to_numpy() # move indices to the front of the dataframe; move controlled_by_param to the end self.nodes = reorder_cols( @@ -1086,7 +1097,6 @@ def distance(self, endpoint: "View") -> float: return np.sqrt(np.sum((start_xyz - end_xyz) ** 2)) def delete_trainables(self): - # TODO: Test that this correctly works for View! """Removes all trainable parameters from the module.""" if isinstance(self, View): @@ -1187,6 +1197,7 @@ def get_all_parameters( # TODO: Longterm this should be gotten rid of. # Instead edges should work similar to nodes (would also allow for # param sharing). 
+ # TODO: URGENT: FIX THIS SHIT if key in self.base.synapse_param_names: syn_name_from_param = key.split("_")[0] syn_edges = self.__getattr__(syn_name_from_param).edges @@ -2342,63 +2353,44 @@ def _filter_trainables( Args: is_viewed: Toggles between returning the trainables and inds currently inside or outside of the scope of View.""" - trainable_inds = self.base.indices_set_by_trainables - trainable_inds = ( - np.unique(np.hstack([inds.reshape(-1) for inds in trainable_inds])) - if len(trainable_inds) > 0 - else [] - ) - trainable_node_inds_in_view = np.intersect1d( - trainable_inds, self._nodes_in_view - ) - índices_set_by_trainables_in_view = [] trainable_params_in_view = [] for inds, params in zip( self.base.indices_set_by_trainables, self.base.trainable_params ): - in_view = is_viewed == np.isin(inds, trainable_node_inds_in_view) - + pkey, pval = next(iter(params.items())) + trainable_inds_in_view = None + if pkey in sum( + [list(c.channel_params.keys()) for c in self.base.channels], [] + ): + trainable_inds_in_view = np.intersect1d(inds, self._nodes_in_view) + elif pkey in sum( + [list(s.synapse_params.keys()) for s in self.base.synapses], [] + ): + trainable_inds_in_view = np.intersect1d(inds, self._edges_in_view) + + in_view = is_viewed == np.isin(inds, trainable_inds_in_view) completely_in_view = in_view.all(axis=1) - índices_set_by_trainables_in_view.append(inds[completely_in_view]) + partially_in_view = in_view.any(axis=1) & ~completely_in_view + trainable_params_in_view.append( {k: v[completely_in_view] for k, v in params.items()} ) - - partially_in_view = in_view.any(axis=1) & ~completely_in_view - índices_set_by_trainables_in_view.append( - inds[partially_in_view][in_view[partially_in_view]] - ) trainable_params_in_view.append( {k: v[partially_in_view] for k, v in params.items()} ) - # TODO: working but ugly. maybe integrate into above loop - trainable_names = np.array([next(iter(d)) for d in self.base.trainable_params]) - is_syn_trainable_in_view = is_viewed * np.isin( - trainable_names, self.synapse_param_names - ) - syn_trainable_names_in_view = trainable_names[is_syn_trainable_in_view] - syn_trainable_inds_in_view = np.intersect1d( - syn_trainable_names_in_view, trainable_names, return_indices=True - )[2] - for idx in syn_trainable_inds_in_view: - syn_name = trainable_names[idx].split("_")[0] - syn_edges = self.base.edges[self.base.edges["type"] == syn_name] - syn_inds = np.arange(len(syn_edges)) - syn_inds_in_view = syn_inds[ - is_viewed == np.isin(syn_edges.index, self._edges_in_view) - ] + índices_set_by_trainables_in_view.append(inds[completely_in_view]) + partial_inds = inds[partially_in_view][in_view[partially_in_view]] - syn_trainable_params_in_view = { - k: v[syn_inds_in_view] - for k, v in self.base.trainable_params[idx].items() - } - trainable_params_in_view.append(syn_trainable_params_in_view) - syn_inds_set_by_trainables_in_view = self.base.indices_set_by_trainables[ - idx - ][syn_inds_in_view] - índices_set_by_trainables_in_view.append(syn_inds_set_by_trainables_in_view) + # the indexing above can lead to inconsistent shapes. 
+ # this is fixed here to return them to the prev shape + if inds.shape[0] > 1 and partial_inds.shape != (0,): + partial_inds = partial_inds.reshape(-1, 1) + if inds.shape[1] > 1 and partial_inds.shape != (0,): + partial_inds = partial_inds.reshape(1, -1) + + índices_set_by_trainables_in_view.append(partial_inds) indices_set_by_trainables = [ inds for inds in índices_set_by_trainables_in_view if len(inds) > 0 @@ -2409,12 +2401,12 @@ def _filter_trainables( return indices_set_by_trainables, trainable_params def _set_trainables_in_view(self): - # TODO: Test this! The two examples below are also buggy! - # net.cell([0,1]).branch(0).comp(0).trainable_params - # net.cell(0).branch(0).comp(0).trainable_params - self.indices_set_by_trainables, self.trainable_params = ( - self._filter_trainables() - ) + trainables = self._filter_trainables() + + # note for `branch.comp(0).make_trainable("X"); branch.make_trainable("X")` + # `view = branch.comp(0)` will have duplicate training params. + self.indices_set_by_trainables = trainables[0] + self.trainable_params = trainables[1] def _channels_in_view(self, pointer: Union[Module, View]) -> List[Channel]: names = [name._name for name in pointer.channels] From 90a0382341560022c93f412b63cebcc9f60bf535 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 24 Oct 2024 17:10:28 +0200 Subject: [PATCH 08/17] fix: fix jit simulate with data_stimulate issues and streamline edges --- jaxley/modules/base.py | 57 +++++++++++++++------------------------ jaxley/modules/network.py | 24 ++++++++++------- 2 files changed, 36 insertions(+), 45 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 3326e2a9..c8a4b362 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -68,19 +68,15 @@ def __init__(self): self._edges_in_view: np.ndarray = None self.edges = pd.DataFrame( - columns=["global_edge_index"] - + [ - f"global_{lvl}_index" - for lvl in [ - "pre_comp", - "pre_branch", - "pre_cell", - "post_comp", - "post_branch", - "post_cell", - ] + columns=[ + "global_edge_index", + "global_pre_comp_index", + "global_post_comp_index", + "pre_locs", + "post_locs", + "type", + "type_ind", ] - + ["pre_locs", "post_locs", "type", "type_ind"] ) self.cumsum_nbranches: Optional[np.ndarray] = None @@ -162,11 +158,18 @@ def __getattr__(self, key): # intercepts calls to synapse types if key in self.base.synapse_names: - syn_inds = self.edges.index[self.edges["type"] == key].to_numpy() + syn_inds = self.edges[self.edges["type"] == key][ + "global_edge_index" + ].to_numpy() + orig_scope = self._scope view = ( - self.edge(syn_inds) if key in self.synapse_names else self.select(None) + self.scope("global").edge(syn_inds).scope(orig_scope) + if key in self.synapse_names + else self.select(None) ) view._set_controlled_by_param(key) # overwrites param set by edge + # Temporary fix for synapse param sharing + view.edges["local_edge_index"] = np.arange(len(view.edges)) return view def _childviews(self) -> List[str]: @@ -229,8 +232,6 @@ def reindex_a_by_b( index_names = ["cell_index", "branch_index", "comp_index"] # order is important global_idx_cols = [f"global_{name}" for name in index_names] local_idx_cols = [f"local_{name}" for name in index_names] - local_pre_cols = [f"local_pre_{name}" for name in index_names] - local_post_cols = [f"local_post_{name}" for name in index_names] idcs = self.nodes[global_idx_cols] # update local indices of nodes @@ -240,17 +241,6 @@ def reindex_a_by_b( idcs.columns = [col.replace("global", "local") for col in global_idx_cols] 
self.nodes[local_idx_cols] = idcs[local_idx_cols].astype(int) - # add local indices of nodes to edges - global_pre_inds = self.edges["global_pre_comp_index"] - global_post_inds = self.edges["global_post_comp_index"] - global_node_inds = self.nodes["global_comp_index"] - flat = lambda x: np.array(x).flatten() - is_pre = flat([np.where(global_node_inds == i)[0] for i in global_pre_inds]) - is_post = flat([np.where(global_node_inds == i)[0] for i in global_post_inds]) - local_node_inds = self.nodes[local_idx_cols] - self.edges.loc[:, local_pre_cols] = local_node_inds.loc[is_pre].to_numpy() - self.edges.loc[:, local_post_cols] = local_node_inds.loc[is_post].to_numpy() - # move indices to the front of the dataframe; move controlled_by_param to the end self.nodes = reorder_cols( self.nodes, @@ -261,8 +251,7 @@ def reindex_a_by_b( ], ) self.nodes = reorder_cols(self.nodes, ["controlled_by_param"], first=False) - self.edges["local_edge_index"] = rerank(self.edges["global_edge_index"]) - self.edges = reorder_cols(self.edges, ["global_edge_index", "local_edge_index"]) + self.edges = reorder_cols(self.edges, ["global_edge_index"]) self.edges = reorder_cols(self.edges, ["controlled_by_param"], first=False) def _init_view(self): @@ -358,7 +347,7 @@ def _set_controlled_by_param(self, key: str): key: key specifying group / view that is in control of the params.""" if key in ["comp", "branch", "cell"]: self.nodes["controlled_by_param"] = self.nodes[f"global_{key}_index"] - self.edges["controlled_by_param"] = self.edges[f"global_pre_{key}_index"] + self.edges["controlled_by_param"] = 0 elif key == "edge": self.edges["controlled_by_param"] = np.arange(len(self.edges)) elif key == "filter": @@ -1197,12 +1186,10 @@ def get_all_parameters( # TODO: Longterm this should be gotten rid of. # Instead edges should work similar to nodes (would also allow for # param sharing). - # TODO: URGENT: FIX THIS SHIT + synapse_inds = self.base.edges.groupby("type").rank()["global_edge_index"] + synapse_inds = (synapse_inds.astype(int) - 1).to_numpy() if key in self.base.synapse_param_names: - syn_name_from_param = key.split("_")[0] - syn_edges = self.__getattr__(syn_name_from_param).edges - inds = syn_edges.loc[inds.reshape(-1)]["local_edge_index"].values - inds = inds.reshape(-1, 1) + inds = synapse_inds[inds] if key in params: # Only parameters, not initial states. # `inds` is of shape `(num_params, num_comps_per_param)`. 
diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 1a1a7f86..7efa2c29 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -471,8 +471,11 @@ def vis( pre_locs = self.edges["pre_locs"].to_numpy() post_locs = self.edges["post_locs"].to_numpy() - pre_branch = self.edges["global_pre_branch_index"].to_numpy() - post_branch = self.edges["global_post_branch_index"].to_numpy() + pre_comp = self.edges["global_pre_comp_index"].to_numpy() + nodes = self.nodes.set_index("global_comp_index") + pre_branch = nodes.loc[pre_comp, "global_branch_index"].to_numpy() + post_comp = self.edges["global_post_comp_index"].to_numpy() + post_branch = nodes.loc[post_comp, "global_branch_index"].to_numpy() dims_np = np.asarray(dims) @@ -533,8 +536,11 @@ def build_extents(*subset_sizes): else: graph.add_nodes_from(range(len(self._cells_in_view))) - pre_cell = self.edges["global_pre_cell_index"].to_numpy() - post_cell = self.edges["global_post_cell_index"].to_numpy() + pre_comp = self.edges["global_pre_comp_index"].to_numpy() + nodes = self.nodes.set_index("global_comp_index") + pre_cell = nodes.loc[pre_comp, "global_cell_index"].to_numpy() + post_comp = self.edges["global_post_comp_index"].to_numpy() + post_cell = nodes.loc[post_comp, "global_cell_index"].to_numpy() inds = np.stack([pre_cell, post_cell]).T graph.add_edges_from(inds) @@ -576,11 +582,10 @@ def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): ) # Define new synapses. Each row is one synapse. - cols = ["comp_index", "branch_index", "cell_index"] - pre_nodes = pre_nodes[[f"global_{col}" for col in cols]] - pre_nodes.columns = [f"global_pre_{col}" for col in cols] - post_nodes = post_nodes[[f"global_{col}" for col in cols]] - post_nodes.columns = [f"global_post_{col}" for col in cols] + pre_nodes = pre_nodes[["global_comp_index"]] + pre_nodes.columns = ["global_pre_comp_index"] + post_nodes = post_nodes[["global_comp_index"]] + post_nodes.columns = ["global_post_comp_index"] new_rows = pd.concat( [ global_edge_index, @@ -589,7 +594,6 @@ def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): ], axis=1, ) - new_rows["local_edge_index"] = new_rows["global_edge_index"] new_rows["type"] = synapse_name new_rows["type_ind"] = type_ind new_rows["pre_locs"] = pre_loc From 5660e94d2c70096804866b8a4f0737fad230a1eb Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 24 Oct 2024 17:27:27 +0200 Subject: [PATCH 09/17] fix: make remaining tests pass --- tests/test_connection.py | 43 ++++++++++++++++------------------------ tests/test_viewing.py | 2 -- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 4d1bd37d..9d790589 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -58,16 +58,13 @@ def test_connect(): # check if all connections are made correctly first_set_edges = net2.edges.iloc[:8] # TODO: VERIFY THAT THIS IS INTENDED BEHAVIOUR! 
@Michael - assert ( - ( - first_set_edges[["global_pre_branch_index", "global_post_branch_index"]] - == (4, 8) - ) - .all() - .all() - ) - assert (first_set_edges["global_pre_cell_index"] == 1).all() - assert (first_set_edges["global_post_cell_index"] == 2).all() + nodes = net2.nodes.set_index("global_comp_index") + cols = ["global_pre_comp_index", "global_post_comp_index"] + comp_inds = nodes.loc[first_set_edges[cols].to_numpy().flatten()] + branch_inds = comp_inds["global_branch_index"].to_numpy().reshape(-1, 2) + cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) + assert np.all(branch_inds == (4, 8)) + assert (cell_inds == (1, 2)).all() assert ( get_comps(first_set_edges["pre_locs"]) == get_comps(first_set_edges["post_locs"]) @@ -181,14 +178,11 @@ def test_connectivity_matrix_connect(): net[:4], net[4:8], TestSynapse(), n_by_n_adjacency_matrix ) assert len(net.edges.index) == 4 - assert ( - ( - net.edges[["global_pre_cell_index", "global_post_cell_index"]] - == incides_of_connected_cells - ) - .all() - .all() - ) + nodes = net.nodes.set_index("global_comp_index") + cols = ["global_pre_comp_index", "global_post_comp_index"] + comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] + cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) + assert np.all(cell_inds == incides_of_connected_cells) m_by_n_adjacency_matrix = np.array( [[0, 1, 1, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=bool @@ -205,11 +199,8 @@ def test_connectivity_matrix_connect(): net[:3], net[:4], TestSynapse(), m_by_n_adjacency_matrix ) assert len(net.edges.index) == 5 - assert ( - ( - net.edges[["global_pre_cell_index", "global_post_cell_index"]] - == incides_of_connected_cells - ) - .all() - .all() - ) + nodes = net.nodes.set_index("global_comp_index") + cols = ["global_pre_comp_index", "global_post_comp_index"] + comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] + cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) + assert np.all(cell_inds == incides_of_connected_cells) diff --git a/tests/test_viewing.py b/tests/test_viewing.py index 9358c269..105bb8e2 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -566,8 +566,6 @@ def test_synapse_and_channel_filtering(): assert np.all(nodes_control_param2 == nodes2["global_cell_index"]) assert np.all(edges1 == edges2) - assert np.all(net.edge(0).TestSynapse.nodes == net.TestSynapse.edge(0).nodes) - assert np.all(net.edge(0).TestSynapse.edges == net.TestSynapse.edge(0).edges) # TODO: test copying/extracting views/module parts From 1b5d42478b99447a79f2a31a7a9f56691621da78 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 24 Oct 2024 17:28:51 +0200 Subject: [PATCH 10/17] fix: ammend last --- tests/test_connection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 9d790589..5178d24b 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -57,7 +57,6 @@ def test_connect(): # check if all connections are made correctly first_set_edges = net2.edges.iloc[:8] - # TODO: VERIFY THAT THIS IS INTENDED BEHAVIOUR! 
@Michael nodes = net2.nodes.set_index("global_comp_index") cols = ["global_pre_comp_index", "global_post_comp_index"] comp_inds = nodes.loc[first_set_edges[cols].to_numpy().flatten()] From b33d9b0ce655b56b4d42dd6c41b12a2713631997 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 24 Oct 2024 17:37:05 +0200 Subject: [PATCH 11/17] doc: prepare for review --- jaxley/modules/base.py | 6 ++++-- tests/test_viewing.py | 20 +++----------------- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index c8a4b362..f1d0a8bb 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -2,6 +2,7 @@ # licensed under the Apache License Version 2.0, see from __future__ import annotations +import warnings from abc import ABC, abstractmethod from copy import deepcopy from itertools import chain @@ -183,7 +184,7 @@ def _childviews(self) -> List[str]: def __getitem__(self, index): supported_lvls = ["network", "cell", "branch"] # cannot index into comp - # TODO: SHOULD WE ALLOW GROUPVIEW TO BE INDEXED? + # TODO FROM #447: SHOULD WE ALLOW GROUPVIEW TO BE INDEXED? # IF YES, UNDER WHICH CONDITIONS? is_group_view = self._current_view in self.groups assert ( @@ -591,7 +592,8 @@ def copy( Returns: A part of the module or a copied view of it.""" view = deepcopy(self) - # TODO: add reset_index, i.e. for parents, nodes, edges etc. such that they + warnings.warn("This method is experimental, use at your own risk.") + # TODO FROM #447: add reset_index, i.e. for parents, nodes, edges etc. such that they # start from 0/-1 and are contiguous if as_module: raise NotImplementedError("Not yet implemented.") diff --git a/tests/test_viewing.py b/tests/test_viewing.py index 105bb8e2..ae37caa1 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -196,27 +196,16 @@ def test_local_indexing(): ["local_cell_index", "local_branch_index", "local_comp_index"] ] idx_cols = ["global_cell_index", "global_branch_index", "global_comp_index"] - # TODO: Write new and more comprehensive test for local indexing! global_index = 0 for cell_idx in range(2): for branch_idx in range(5): for comp_idx in range(4): - - # compview = net[cell_idx, branch_idx, comp_idx].show() - # assert np.all( - # compview[idx_cols].values == [cell_idx, branch_idx, comp_idx] - # ) assert np.all( local_idxs.iloc[global_index] == [cell_idx, branch_idx, comp_idx] ) global_index += 1 -def test_comp_indexing_exception_handling(): - # TODO: Add tests for indexing exceptions - pass - - def test_indexing_a_compartment_of_many_branches(): comp = jx.Compartment() branch1 = jx.Branch(comp, nseg=3) @@ -227,7 +216,7 @@ def test_indexing_a_compartment_of_many_branches(): net = jx.Network([cell1, cell2]) # Indexing a single compartment of multiple branches is not supported with `loc`. - # TODO: Reevaluate what kind of indexing is allowed and which is not! + # TODO FROM #447: Reevaluate what kind of indexing is allowed and which is not! 
# with pytest.raises(NotImplementedError): # net.cell("all").branch("all").loc(0.0) # with pytest.raises(NotImplementedError): @@ -270,7 +259,7 @@ def test_solve_indexer(): def test_view_attrs(module: jx.Compartment | jx.Branch | jx.Cell | jx.Network): # attributes of Module that do not have to exist in View exceptions = ["view"] - # TODO: should be added to View in the future + # TODO FROM #447: should be added to View in the future exceptions += [ "_internal_node_inds", "par_inds", @@ -289,7 +278,7 @@ def test_view_attrs(module: jx.Compartment | jx.Branch | jx.Cell | jx.Network): "cumsum_nbranchpoints_per_cell", "_cumsum_nseg_per_cell", ] # for network - exceptions += ["cumsum_nbranches"] # HOTFIX #TODO: take care of this + exceptions += ["cumsum_nbranches"] # TODO: take care of this for name, attr in module.__dict__.items(): if name not in exceptions: @@ -566,6 +555,3 @@ def test_synapse_and_channel_filtering(): assert np.all(nodes_control_param2 == nodes2["global_cell_index"]) assert np.all(edges1 == edges2) - - -# TODO: test copying/extracting views/module parts From 049cdd8d3d33fe0959066d45e2fc65a1d8900361 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Mon, 28 Oct 2024 17:52:26 +0100 Subject: [PATCH 12/17] wip: save wip --- jaxley/modules/base.py | 89 +++++++++++++++++++++++++++++++++--------- tests/test_groups.py | 14 +------ tests/test_viewing.py | 49 +++++++++++++++++++++++ 3 files changed, 121 insertions(+), 31 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index f1d0a8bb..e3f2d019 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -42,6 +42,18 @@ from jaxley.utils.swc import build_radiuses_from_xyzr +def only_allow_module(func): + """Decorator to only allow the function to be called on Module instances.""" + + def wrapper(self, *args, **kwargs): + assert not isinstance( + self, View + ), "This function can only be called on Module instances" + return func(self, *args, **kwargs) + + return wrapper + + class Module(ABC): """Module base class. @@ -50,7 +62,42 @@ class Module(ABC): Modules can be traversed and modified using the `at`, `cell`, `branch`, `comp`, `edge`, and `loc` methods. The `scope` method can be used to toggle between - global and local indices. + global and local indices. Traversal of Modules will return a `View` of itself, + that has a modified set of attributes, which only consider the part of the Module + that is in view. + + This has consequences for how to operate on Module and which changes take affect + where. The following guidelines should be followed (copied from `View`): + 1. We consider a Module to have everything in view. + 2. Views can display and keep track of how a module is traversed. But(!), + do not support making changes or setting variables. This still has to be + done in the base Module, i.e. `self.base`. In order to enssure that these + changes only affects whatever is currently in view `self._nodes_in_view`, + or `self._edges_in_view` among others have to be used. Operating on nodes + currently in view can for example be done with + `self.base.node.loc[self._nodes_in_view]` + 3. Every attribute of Module that changes based on what's in view, i.e. `xyzr`, + needs to modified when View is instantiated. I.e. `xyzr` of `cell.branch(0)`, + should be `[self.base.xyzr[0]]` This could be achieved via: + `[self.base.xyzr[b] for b in self._branches_in_view]`. 
+ + + Example to make methods of Module compatible with View: + ``` + # use data in view to return something + def count_small_branches(self): + # no need to use self.base.attr + viewed indices, + # since no change is made to the attr in question (nodes) + comp_lens = self.nodes["length"] + branch_lens = comp_lens.groupby("global_branch_index").sum() + return np.sum(branch_lens < 10) + + # change data in view + def change_attr_in_view(self): + # changes to attrs have to be made via self.base.attr + viewed indices + a = func1(self.base.attr1[self._cells_in_view]) + b = func2(self.base.attr2[self._edges_in_view]) + self.base.attr3[self._branches_in_view] = a + b This base class defines the scaffold for all jaxley modules (compartments, branches, cells, networks). @@ -243,16 +290,17 @@ def reindex_a_by_b( self.nodes[local_idx_cols] = idcs[local_idx_cols].astype(int) # move indices to the front of the dataframe; move controlled_by_param to the end + # move indices of current scope to the front and the others to the back + not_scope = "global" if self._scope == "local" else "local" self.nodes = reorder_cols( - self.nodes, - [ - f"{scope}_{name}" - for scope in ["global", "local"] - for name in index_names - ], + self.nodes, [f"{self._scope}_{name}" for name in index_names], first=True ) - self.nodes = reorder_cols(self.nodes, ["controlled_by_param"], first=False) + self.nodes = reorder_cols( + self.nodes, [f"{not_scope}_{name}" for name in index_names], first=False + ) + self.edges = reorder_cols(self.edges, ["global_edge_index"]) + self.nodes = reorder_cols(self.nodes, ["controlled_by_param"], first=False) self.edges = reorder_cols(self.edges, ["controlled_by_param"], first=False) def _init_view(self): @@ -597,7 +645,7 @@ def copy( # start from 0/-1 and are contiguous if as_module: raise NotImplementedError("Not yet implemented.") - # TODO: initialize a new module with the same attributes + # initialize a new module with the same attributes return view @property @@ -641,8 +689,9 @@ def _gather_channels_from_constituents(self, constituents: List): name = channel._name self.base.nodes.loc[self.nodes[name].isna(), name] = False + @only_allow_module def to_jax(self): - # TODO: Make this work for View? + # TODO FROM #447: Make this work for View? """Move `.nodes` to `.jaxnodes`. Before the actual simulation is run (via `jx.integrate`), all parameters of @@ -715,6 +764,7 @@ def show( return nodes[cols] + @only_allow_module def _init_morph(self): """Initialize the morphology such that it can be processed by the solvers.""" self._init_morph_jaxley_spsolve() @@ -826,7 +876,6 @@ def set_ncomp( and len(self._branches_in_view) == len(self.base._branches_in_view) ), "This is not allowed for cells." - # TODO: MAKE THIS NICER # Update all attributes that are affected by compartment structure. view = self.nodes.copy() all_nodes = self.base.nodes @@ -1131,10 +1180,11 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: """ return self.trainable_params + @only_allow_module def get_all_parameters( self, pstate: List[Dict], voltage_solver: str ) -> Dict[str, jnp.ndarray]: - # TODO: MAKE THIS WORK FOR VIEW? + # TODO FROM #447: MAKE THIS WORK FOR VIEW? """Return all parameters (and coupling conductances) needed to simulate. Runs `_compute_axial_conductances()` and return every parameter that is needed @@ -1185,7 +1235,7 @@ def get_all_parameters( # This is needed since SynapseViews worked differently before. # This mimics the old behaviour and tranformes the new indices # to the old indices. 
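The comment above maps the new module-wide `global_edge_index` back to the old per-synapse-type numbering. A small, self-contained pandas sketch of that `groupby(...).rank()` step follows; the data are made up for illustration and only the column names follow the patch.

```
import pandas as pd

# Three edges of two synapse types; `global_edge_index` counts across all types.
edges = pd.DataFrame(
    {
        "type": ["TestSynapse", "OtherSynapse", "TestSynapse"],
        "global_edge_index": [0, 1, 2],
    }
)

# Ranking within each type gives a 1-based per-type numbering; subtracting 1
# recovers the old per-type indices: TestSynapse edges -> 0, 1; OtherSynapse -> 0.
per_type_index = (edges.groupby("type").rank()["global_edge_index"] - 1).astype(int)
assert per_type_index.tolist() == [0, 0, 1]
```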
- # TODO: Longterm this should be gotten rid of. + # TODO FROM #447: Longterm this should be gotten rid of. # Instead edges should work similar to nodes (would also allow for # param sharing). synapse_inds = self.base.edges.groupby("type").rank()["global_edge_index"] @@ -1206,8 +1256,9 @@ def get_all_parameters( ) return params - # TODO: MAKE THIS WORK FOR VIEW? + @only_allow_module def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]: + # TODO FROM #447: MAKE THIS WORK FOR VIEW? """Return states as they are set in the `.nodes` and `.edges` tables.""" self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. states = {"v": self.base.jaxnodes["v"]} @@ -1219,10 +1270,11 @@ def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]: states[synapse_states] = self.base.jaxedges[synapse_states] return states + @only_allow_module def get_all_states( self, pstate: List[Dict], all_params, delta_t: float ) -> Dict[str, jnp.ndarray]: - # TODO: MAKE THIS WORK FOR VIEW? + # TODO FROM #447: MAKE THIS WORK FOR VIEW? """Get the full initial state of the module from jaxnodes and trainables. Args: @@ -1268,8 +1320,9 @@ def _initialize(self): self._init_morph() return self + @only_allow_module def init_states(self, delta_t: float = 0.025): - # TODO: MAKE THIS WORK FOR VIEW? + # TODO FROM #447: MAKE THIS WORK FOR VIEW? """Initialize all mechanisms in their steady state. This considers the voltages and parameters of each compartment. @@ -1617,6 +1670,7 @@ def insert(self, channel: Channel): for key in channel.channel_states: self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key] + @only_allow_module def step( self, u: Dict[str, jnp.ndarray], @@ -2183,7 +2237,6 @@ def change_attr_in_view(self): a = func1(self.base.attr1[self._cells_in_view]) b = func2(self.base.attr2[self._edges_in_view]) self.base.attr3[self._branches_in_view] = a + b - ``` """ def __init__( @@ -2254,7 +2307,7 @@ def __init__( self._current_view = "view" # if not instantiated via `comp`, `cell` etc. self._update_local_indices() - # TODO: + # TODO FROM #447: self.debug_states = pointer.debug_states if len(self.nodes) == 0: diff --git a/tests/test_groups.py b/tests/test_groups.py index a469a214..2987fb78 100644 --- a/tests/test_groups.py +++ b/tests/test_groups.py @@ -29,12 +29,6 @@ def test_subclassing_groups_cell_api(): cell.subtree.branch(0).set("radius", 0.1) cell.subtree.branch(0).comp("all").make_trainable("length") - # TODO: REMOVE THIS IS NOW ALLOWED - # with pytest.raises(KeyError): - # cell.subtree.cell(0).branch("all").make_trainable("length") - # with pytest.raises(KeyError): - # cell.subtree.comp(0).make_trainable("length") - def test_subclassing_groups_net_api(): comp = jx.Compartment() @@ -48,12 +42,6 @@ def test_subclassing_groups_net_api(): net.excitatory.cell(0).set("radius", 0.1) net.excitatory.cell(0).branch("all").make_trainable("length") - # TODO: REMOVE THIS IS NOW ALLOWED - # with pytest.raises(KeyError): - # cell.excitatory.branch(0).comp("all").make_trainable("length") - # with pytest.raises(KeyError): - # cell.excitatory.comp("all").make_trainable("length") - def test_subclassing_groups_net_set_equivalence(): """Test whether calling `.set` on subclasses group is same as on view.""" @@ -89,7 +77,7 @@ def test_subclassing_groups_net_make_trainable_equivalence(): # The following lines are made possible by PR #324. 
# The new behaviour needs changing of the scope to still conform here - # TODO: Rewrite this test / reconsider what behaviour is desired + # TODO FROM #447: Rewrite this test / reconsider what behaviour is desired net1.excitatory.scope("global").cell([0, 3]).scope("local").branch( 0 ).make_trainable("radius") diff --git a/tests/test_viewing.py b/tests/test_viewing.py index ae37caa1..17635b18 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -555,3 +555,52 @@ def test_synapse_and_channel_filtering(): assert np.all(nodes_control_param2 == nodes2["global_cell_index"]) assert np.all(edges1 == edges2) + + +def test_view_equals_module(): + # test that module behaves the same as view for important attributes + comp = jx.Compartment() + branch = jx.Branch([comp] * 3) + + comp.insert(HH()) + branch.comp([0, 1]).insert(HH()) + + comp.set("v", -71.2) + branch.comp(0).set("v", -71.2) + + comp.record("v") + branch.comp([0, 1]).record("v") + + comp.stimulate(np.zeros(100)) + branch.comp([0, 1]).stimulate(np.zeros(100)) + + comp.make_trainable("HH_gNa") + comp.make_trainable("HH_gK") + branch.comp([0, 1]).make_trainable("HH_gNa") + branch.make_trainable("HH_gK") + + # test deleting subset of attributes + branch.comp(1).delete_trainables() + branch.comp(1).delete_recordings() + branch.comp(1).delete_stimuli() + + assert ( + branch.comp(1).trainable_params == [] and branch.comp(0).trainable_params != [] + ) + assert branch.comp(1).recordings.empty and not branch.comp(0).recordings.empty + assert branch.comp(1).externals == {} and branch.comp(0).externals != {} + + # convert to dict so order of cols and index dont matter for __eq__ + assert comp.nodes.to_dict() == branch.comp(0).nodes.to_dict() + + assert comp.trainable_params == branch.comp(0).trainable_params + assert comp.indices_set_by_trainables == branch.comp(0).indices_set_by_trainables + assert np.all(comp.recordings == branch.comp(0).recordings) + assert np.all( + [ + np.all([np.all(v1 == v2), k1 == k2]) + for (k1, v1), (k2, v2) in zip( + comp.externals.items(), branch.comp(0).externals.items() + ) + ] + ) From 192a5d11af0af418c4c93a3a294e237bb23a2cf7 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 29 Oct 2024 10:45:24 +0100 Subject: [PATCH 13/17] fix: all tests passing, address review comments --- jaxley/modules/base.py | 56 +++++++++++++++++++++++++++++++----------- tests/test_viewing.py | 26 +++++++++++++++++--- 2 files changed, 63 insertions(+), 19 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index e3f2d019..d034923e 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -216,7 +216,8 @@ def __getattr__(self, key): else self.select(None) ) view._set_controlled_by_param(key) # overwrites param set by edge - # Temporary fix for synapse param sharing + # Ensure synapse param sharing works with `edge` + # `edge` will be removed as part of #463 view.edges["local_edge_index"] = np.arange(len(view.edges)) return view @@ -229,18 +230,16 @@ def _childviews(self) -> List[str]: return children def __getitem__(self, index): - supported_lvls = ["network", "cell", "branch"] # cannot index into comp + """Lazy indexing of the module.""" + supported_parents = ["network", "cell", "branch"] # cannot index into comp - # TODO FROM #447: SHOULD WE ALLOW GROUPVIEW TO BE INDEXED? - # IF YES, UNDER WHICH CONDITIONS? 
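For context on the `__getitem__` hunk reworked here, a self-contained sketch of the lazy-indexing contract it implements: a tuple index is dispatched to the child views in order (cell → branch → comp). The network setup mirrors the tests in this patch; this is illustrative only and not part of the diff.

```
import jaxley as jx

comp = jx.Compartment()
branch = jx.Branch([comp] * 3)
cell = jx.Cell([branch] * 3, parents=[-1, 0, 0])
net = jx.Network([cell] * 3)

# Tuple indexing is resolved lazily, one child view per entry ...
view_a = net[0, 1, 2]
# ... and matches the explicit traversal.
view_b = net.cell(0).branch(1).comp(2)
assert view_a.nodes["global_comp_index"].tolist() == view_b.nodes["global_comp_index"].tolist()
```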
- is_group_view = self._current_view in self.groups + not_group_view = self._current_view not in self.groups assert ( - self._current_view in supported_lvls or is_group_view - ), "Lazy indexing is not supported for this View/Module." + self._current_view in supported_parents or not_group_view + ), "Lazy indexing is only supported for `Network`, `Cell`, `Branch` and Views thereof." index = index if isinstance(index, tuple) else (index,) - module_or_view = self.base if is_group_view else self - child_views = module_or_view._childviews() + child_views = self._childviews() assert len(index) <= len(child_views), "Too many indices." view = self for i, child in zip(index, child_views): @@ -307,8 +306,8 @@ def _init_view(self): """Init attributes critical for View. Needs to be called at init of a Module.""" - lvl = self.__class__.__name__.lower() - self._current_view = "comp" if lvl == "compartment" else lvl + parent = self.__class__.__name__.lower() + self._current_view = "comp" if parent == "compartment" else parent self._nodes_in_view = self.nodes.index.to_numpy() self._edges_in_view = self.edges.index.to_numpy() self.nodes["controlled_by_param"] = 0 @@ -2205,6 +2204,13 @@ class View(Module): allow to target specific parts of a Module, i.e. setting parameters for parts of a cell. + Almost all methods in View are concerned with updating the attributes of the + base Module, i.e. `self.base`, based on the indices in view. For example, + `_channels_in_view` lists all channels, finds the subset set to `True` in + `self.nodes` (currently in view) and returns the updated list such that we can set + `self.channels = self._channels_in_view()`. + + To allow seamless operation on Views and Modules as if they were the same, the following needs to be ensured: 1. We consider a Module to have everything in view. @@ -2316,7 +2322,7 @@ def __init__( def _set_inds_in_view( self, pointer: Union[Module, View], nodes: np.ndarray, edges: np.ndarray ): - """Set nodes and edge indices that are in view.""" + """Update node and edge indices to list only those currently in view.""" # set nodes and edge indices in view has_node_inds = nodes is not None has_edge_inds = edges is not None @@ -2352,6 +2358,7 @@ def _set_inds_in_view( self._edges_in_view = edges def _jax_arrays_in_view(self, pointer: Union[Module, View]): + """Update jaxnodes/jaxedges to show only those currently in view.""" a_intersects_b_at = lambda a, b: jnp.intersect1d(a, b, return_indices=True)[1] jaxnodes = {} if pointer.jaxnodes is not None else None if self.jaxnodes is not None: @@ -2376,6 +2383,7 @@ def _jax_arrays_in_view(self, pointer: Union[Module, View]): return jaxnodes, jaxedges def _set_externals_in_view(self): + """Update external inputs to show only those currently in view.""" self.externals = {} self.external_inds = {} for (name, inds), data in zip( @@ -2390,7 +2398,17 @@ def _set_externals_in_view(self): def _filter_trainables( self, is_viewed: bool = True ) -> Tuple[List[np.ndarray], List[Dict]]: - """filters the trainables inside and outside of the view + """Filters the trainables inside and outside of the view. + + Trainables are split between `indices_set_by_trainables` and `trainable_params` + and can be shared between mutliple compartments / branches etc, which makes it + difficult to filter them based on the current view w.o. destroying the + original structure. 
+ + This method filters `indices_set_by_trainables` for the indices that are + currently in view (or not in view) and returns the corresponding trainable + parameters and indices such that the sharing behavior is preserved as much as + possible. Args: is_viewed: Toggles between returning the trainables and inds @@ -2425,8 +2443,9 @@ def _filter_trainables( índices_set_by_trainables_in_view.append(inds[completely_in_view]) partial_inds = inds[partially_in_view][in_view[partially_in_view]] - # the indexing above can lead to inconsistent shapes. - # this is fixed here to return them to the prev shape + # the indexing i.e. `inds[partially_in_view]` reshapes `inds`. Since the shape + # determines how parameters are shared, `inds` has to be returned to its + # original shape. if inds.shape[0] > 1 and partial_inds.shape != (0,): partial_inds = partial_inds.reshape(-1, 1) if inds.shape[1] > 1 and partial_inds.shape != (0,): @@ -2443,6 +2462,7 @@ def _filter_trainables( return indices_set_by_trainables, trainable_params def _set_trainables_in_view(self): + """Set `trainable_params` and `indices_set_by_trainables` to show only those in view.""" trainables = self._filter_trainables() # note for `branch.comp(0).make_trainable("X"); branch.make_trainable("X")` @@ -2451,12 +2471,14 @@ def _set_trainables_in_view(self): self.trainable_params = trainables[1] def _channels_in_view(self, pointer: Union[Module, View]) -> List[Channel]: + """Set channels to show only those in view.""" names = [name._name for name in pointer.channels] channel_in_view = self.nodes[names].any(axis=0) channel_in_view = channel_in_view[channel_in_view].index return [c for c in pointer.channels if c._name in channel_in_view] def _set_synapses_in_view(self, pointer: Union[Module, View]): + """Set synapses to show only those in view.""" viewed_synapses = [] viewed_params = [] viewed_states = [] @@ -2478,6 +2500,9 @@ def _nbranches_per_cell_in_view(self) -> np.ndarray: return cell_nodes["global_branch_index"].nunique().to_list() def _xyzr_in_view(self) -> List[np.ndarray]: + """Return xyzr coordinates of every branch that is in `_branches_in_view`. + + If a branch is not completely in view, the coordinates are interpolated.""" xyzr = [self.base.xyzr[i] for i in self._branches_in_view].copy() # Currently viewing with `.loc` will show the closest compartment @@ -2527,6 +2552,7 @@ def _comps_in_view(self) -> np.ndarray: @property def _branch_edges_in_view(self) -> np.ndarray: + """Lists the global branch edge indices which are currently part of the view.""" incl_branches = self.nodes["global_branch_index"].unique() pre = self.base.branch_edges["parent_branch_index"].isin(incl_branches) post = self.base.branch_edges["child_branch_index"].isin(incl_branches) diff --git a/tests/test_viewing.py b/tests/test_viewing.py index 17635b18..e27bdc3e 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -257,8 +257,18 @@ def test_solve_indexer(): # make sure all attrs in module also have a corresponding attr in view @pytest.mark.parametrize("module", [comp, branch, cell, net]) def test_view_attrs(module: jx.Compartment | jx.Branch | jx.Cell | jx.Network): + """Check if all attributes of Module have a corresponding attribute in View. + + To ensure that View behaves like a Module as much as possible, View should support + all attributes of Module. This test checks if all attributes of Module have a + corresponding attribute in View. Also checks if the types of the attributes match. 
+ """ # attributes of Module that do not have to exist in View exceptions = ["view"] + + # TODO: Types are inconsistent between different Modules + exceptions += ["cumsum_nbranches"] + # TODO FROM #447: should be added to View in the future exceptions += [ "_internal_node_inds", @@ -278,7 +288,6 @@ def test_view_attrs(module: jx.Compartment | jx.Branch | jx.Cell | jx.Network): "cumsum_nbranchpoints_per_cell", "_cumsum_nseg_per_cell", ] # for network - exceptions += ["cumsum_nbranches"] # TODO: take care of this for name, attr in module.__dict__.items(): if name not in exceptions: @@ -298,7 +307,8 @@ def test_view_attrs(module: jx.Compartment | jx.Branch | jx.Cell | jx.Network): @pytest.mark.parametrize("module", [comp, branch, cell, net]) -def test_different_index_types(module): +def test_view_supported_index_types(module): + """Check if different ways to index into Modules/Views work correctly.""" # test int, range, slice, list, np.array, pd.Index index_types = [ 0, @@ -308,6 +318,7 @@ def test_different_index_types(module): np.array([0, 1, 2]), pd.Index([0, 1, 2]), ] + # `_reformat_index` should always return a np.ndarray for index in index_types: assert isinstance( module._reformat_index(index), np.ndarray @@ -321,6 +332,7 @@ def test_different_index_types(module): def test_select(): + """Ensure `select` works correctly and returns expected View of Modules.""" comp = jx.Compartment() branch = jx.Branch([comp] * 3) cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) @@ -362,6 +374,7 @@ def test_select(): def test_viewing(): + """Test that the View object is working correctly.""" comp = jx.Compartment() branch = jx.Branch([comp] * 3) cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) @@ -415,6 +428,7 @@ def test_viewing(): def test_scope(): + """Ensure scope has the intended effect for Modules and Views.""" comp = jx.Compartment() branch = jx.Branch([comp] * 3) cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) @@ -448,6 +462,7 @@ def test_scope(): def test_context_manager(): + """Test that context manager works correctly for Module.""" comp = jx.Compartment() branch = jx.Branch([comp] * 3) cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) @@ -472,6 +487,7 @@ def test_context_manager(): def test_iter(): + """Test that __iter__ works correctly for all modules.""" comp = jx.Compartment() branch1 = jx.Branch([comp] * 2) branch2 = jx.Branch([comp] * 3) @@ -531,6 +547,7 @@ def test_iter(): def test_synapse_and_channel_filtering(): + """Test that synapses and channels are filtered correctly by View.""" comp = jx.Compartment() branch = jx.Branch([comp] * 3) cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) @@ -550,7 +567,8 @@ def test_synapse_and_channel_filtering(): edges_control_param1 = edges1.pop("controlled_by_param") edges_control_param2 = edges2.pop("controlled_by_param") - assert np.all(nodes1 == nodes2) + # convert to dict so order of cols and index dont matter for __eq__ + assert nodes1.to_dict() == nodes2.to_dict() assert np.all(nodes_control_param1 == 0) assert np.all(nodes_control_param2 == nodes2["global_cell_index"]) @@ -558,7 +576,7 @@ def test_synapse_and_channel_filtering(): def test_view_equals_module(): - # test that module behaves the same as view for important attributes + """Test that View behaves the same as Module for important attrs and methods.""" comp = jx.Compartment() branch = jx.Branch([comp] * 3) From 2f2ccd1a80c9e0545d3325078f2327ffc78f5b1d Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 29 Oct 2024 11:18:05 +0100 Subject: [PATCH 14/17] fix/rm: rm test for laxy 
indexing into groups, and rebase onto main. --- jaxley/modules/base.py | 9 +++++++-- jaxley/modules/network.py | 2 +- tests/test_groups.py | 31 ------------------------------- 3 files changed, 8 insertions(+), 34 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index d034923e..4ad8e7a4 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -43,12 +43,17 @@ def only_allow_module(func): - """Decorator to only allow the function to be called on Module instances.""" + """Decorator to only allow the function to be called on Module instances. + + Decorates methods of Module that cannot be called on Views of Modules instances. + and have to be called on the Module itself.""" def wrapper(self, *args, **kwargs): + module_name = self.base.__class__.__name__ + method_name = func.__name__ assert not isinstance( self, View - ), "This function can only be called on Module instances" + ), f"{method_name} is currently not supported for Views. Call on the {module_name} base Module." return func(self, *args, **kwargs) return wrapper diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 7efa2c29..c225545e 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -105,7 +105,7 @@ def __init__( # Channels. self._gather_channels_from_constituents(cells) - self.initialize() + self._initialize() del self._cells_list def __repr__(self): diff --git a/tests/test_groups.py b/tests/test_groups.py index 2987fb78..00e22ee5 100644 --- a/tests/test_groups.py +++ b/tests/test_groups.py @@ -101,37 +101,6 @@ def test_subclassing_groups_net_make_trainable_equivalence(): assert jnp.array_equal(inds1, inds2) -def test_subclassing_groups_net_lazy_indexing_make_trainable_equivalence(): - """Test whether groups can be indexing in a lazy way.""" - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell = jx.Cell(branch, [-1, 0]) - net1 = jx.Network([cell for _ in range(10)]) - net2 = jx.Network([cell for _ in range(10)]) - - net1.cell([0, 3, 5]).add_to_group("excitatory") - net2.cell([0, 3, 5]).add_to_group("excitatory") - - # The following lines are made possible by PR #324. - net1.excitatory.cell([0, 3]).branch(0).make_trainable("radius") - net1.excitatory.cell([0, 5]).branch(1).comp("all").make_trainable("length") - net1.excitatory.cell("all").branch(1).comp(2).make_trainable("axial_resistivity") - params1 = jnp.concatenate(jax.tree_util.tree_flatten(net1.get_parameters())[0]) - - # The following lines are made possible by PR #324. 
- net2.excitatory[[0, 3], 0].make_trainable("radius") - net2.excitatory[[0, 5], 1, :].make_trainable("length") - net2.excitatory[:, 1, 2].make_trainable("axial_resistivity") - params2 = jnp.concatenate(jax.tree_util.tree_flatten(net2.get_parameters())[0]) - - assert jnp.array_equal(params1, params2) - - for inds1, inds2 in zip( - net1.indices_set_by_trainables, net2.indices_set_by_trainables - ): - assert jnp.array_equal(inds1, inds2) - - def test_fully_connect_groups_equivalence(): """Test whether groups can be used with `fully_connect`.""" comp = jx.Compartment() From 9f1cd5cbf24bfd0c9e0ce23529bc74c92f2ed367 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 29 Oct 2024 11:40:32 +0100 Subject: [PATCH 15/17] fix: small refactor --- jaxley/modules/base.py | 6 ++++++ tests/test_viewing.py | 28 +++++++++++++++++----------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 4ad8e7a4..af79444d 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -234,6 +234,10 @@ def _childviews(self) -> List[str]: children = levels[levels.index(self._current_view) + 1 :] return children + def _has_childview(self, key: str) -> bool: + child_views = self._childviews() + return key in child_views + def __getitem__(self, index): """Lazy indexing of the module.""" supported_parents = ["network", "cell", "branch"] # cannot index into comp @@ -466,6 +470,8 @@ def _at_nodes(self, key: str, idx: Any) -> View: Keys can be `cell`, `branch`, `comp` and determine which index is used to filter. """ + base_name = self.base.__class__.__name__ + assert self.base._has_childview(key), f"{base_name} does not support {key}." idx = self._reformat_index(idx) idx = self.nodes[self._scope + f"_{key}_index"] if is_str_all(idx) else idx where = self.nodes[self._scope + f"_{key}_index"].isin(idx) diff --git a/tests/test_viewing.py b/tests/test_viewing.py index e27bdc3e..ba09757f 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -318,17 +318,23 @@ def test_view_supported_index_types(module): np.array([0, 1, 2]), pd.Index([0, 1, 2]), ] - # `_reformat_index` should always return a np.ndarray - for index in index_types: - assert isinstance( - module._reformat_index(index), np.ndarray - ), f"Failed for {type(index)}" - assert module.comp(index), f"Failed for {type(index)}" - assert View(module).comp(index), f"Failed for {type(index)}" - - # for loc test float and list of floats - assert module.loc(0.0), "Failed for float" - assert module.loc([0.0, 0.5, 1.0]), "Failed for List[float]" + + # comp.comp is not allowed + if not isinstance(module, jx.Compartment): + # `_reformat_index` should always return a np.ndarray + for index in index_types: + assert isinstance( + module._reformat_index(index), np.ndarray + ), f"Failed for {type(index)}" + assert module.comp(index), f"Failed for {type(index)}" + assert View(module).comp(index), f"Failed for {type(index)}" + + # for loc test float and list of floats + assert module.loc(0.0), "Failed for float" + assert module.loc([0.0, 0.5, 1.0]), "Failed for List[float]" + else: + with pytest.raises(AssertionError): + module.comp(0) def test_select(): From 63e1f59e8a13b1da7c19e10ae77cac94ce9a2e96 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 29 Oct 2024 16:08:25 +0100 Subject: [PATCH 16/17] fix: fix move_to issues. 
caused by not allowing cell.cells --- jaxley/modules/base.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index af79444d..1f90dae0 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -2164,14 +2164,19 @@ def move_to( "NaN coordinate values detected. Shift amounts cannot be computed. Please run compute_xyzr() or assign initial coordinate values." ) - root_xyz_cells = np.array([c.xyzr[0][0, :3] for c in self.cells]) + # can only iterate over cells for networks + # lambda makes sure that generator can be created multiple times + base_is_net = self.base._current_view == "network" + cells = lambda: (self.cells if base_is_net else [self]) + + root_xyz_cells = np.array([c.xyzr[0][0, :3] for c in cells()]) root_xyz = root_xyz_cells[0] if isinstance(x, float) else root_xyz_cells move_by = np.array([x, y, z]).T - root_xyz if len(move_by.shape) == 1: move_by = np.tile(move_by, (len(self._cells_in_view), 1)) - for cell, offset in zip(self.cells, move_by): + for cell, offset in zip(cells(), move_by): for idx in cell._branches_in_view: self.base.xyzr[idx][:, :3] += offset if update_nodes: From 42962bd68a376381969c2c4b42698827e43681bb Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 29 Oct 2024 16:24:34 +0100 Subject: [PATCH 17/17] doc: add comments --- jaxley/modules/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 1f90dae0..12bb8f0d 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1467,9 +1467,13 @@ def _update_view(self): if isinstance(self, View): scope = self._scope current_view = self._current_view + # copy dict of new View. For some reason doing self = View(self) + # did not work. self.__dict__ = View( self.base, self._nodes_in_view, self._edges_in_view ).__dict__ + + # retain the scope and current_view of the previous view self._scope = scope self._current_view = current_view
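A short note on the `cells = lambda: ...` factory introduced in `move_to` above: per the comment in that hunk, `self.cells` behaves like a generator, so the two loops need a callable that produces a fresh iterator on every call rather than a single generator object. The snippet below is plain Python and nothing in it is jaxley-specific; it only illustrates the design choice.

```
def numbers():
    yield from range(3)

gen = numbers()
assert list(gen) == [0, 1, 2]
assert list(gen) == []              # a generator is exhausted after one pass

make_numbers = lambda: numbers()    # factory: each call returns a fresh generator
assert list(make_numbers()) == [0, 1, 2]
assert list(make_numbers()) == [0, 1, 2]
```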