From 2d6a6f02a1a8dbc390cfbc632d5882f355310dea Mon Sep 17 00:00:00 2001 From: Kyra Date: Thu, 7 Nov 2024 15:26:59 +0100 Subject: [PATCH 01/10] Fixed bug in connectivity_matrix_connect, added test, and removed rand selection of compartments --- jaxley/connect.py | 28 ++++++++++++---------------- tests/test_connection.py | 14 ++++++++++++++ 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/jaxley/connect.py b/jaxley/connect.py index caff2267..391c280e 100644 --- a/jaxley/connect.py +++ b/jaxley/connect.py @@ -134,7 +134,8 @@ def connectivity_matrix_connect( Connects pre- and postsynaptic cells according to a custom connectivity matrix. Entries > 0 in the matrix indicate a connection between the corresponding cells. - Connections are from branch 0 location 0 to a randomly chosen branch and loc. + Connections are from branch 0 location 0 on the presynaptic cell to branch 0 + location 0 on the postsynaptic cell. Args: pre_cell_view: View of the presynaptic cell. @@ -142,7 +143,7 @@ def connectivity_matrix_connect( synapse_type: The synapse to append. connectivity_matrix: A boolean matrix indicating the connections between cells. """ - # Get pre- and postsynaptic cell indices. + # Get pre- and postsynaptic cell indices pre_cell_inds = pre_cell_view._cells_in_view post_cell_inds = post_cell_view._cells_in_view @@ -152,24 +153,19 @@ def connectivity_matrix_connect( ), "Connectivity matrix must have shape (num_pre, num_post)." assert connectivity_matrix.dtype == bool, "Connectivity matrix must be boolean." - # get connection pairs from connectivity matrix + # Get connection pairs from connectivity matrix from_idx, to_idx = np.where(connectivity_matrix) - pre_cell_inds = pre_cell_inds[from_idx] - post_cell_inds = post_cell_inds[to_idx] - - # Sample random postsynaptic compartments (global comp indices). - global_post_indices = np.hstack( - [ - sample_comp(post_cell_view.scope("global").cell(cell_idx)) - for cell_idx in post_cell_inds - ] - ) - post_rows = post_cell_view.nodes.loc[global_post_indices] - # Pre-synapse is at the zero-eth branch and zero-eth compartment. + # Pre-synapse at the zero-eth branch and zero-eth compartment global_pre_indices = ( pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() ) # setting scope ensure that this works indep of current scope - pre_rows = pre_cell_view.select(nodes=global_pre_indices[pre_cell_inds]).nodes + pre_rows = pre_cell_view.select(nodes=global_pre_indices[from_idx]).nodes + + # Post-synapse also at the zero-eth branch and zero-eth compartment + global_post_indices = ( + post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() + ) + post_rows = post_cell_view.select(nodes=global_post_indices[to_idx]).nodes pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type) diff --git a/tests/test_connection.py b/tests/test_connection.py index 5178d24b..853c1ac1 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -203,3 +203,17 @@ def test_connectivity_matrix_connect(): comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) assert np.all(cell_inds == incides_of_connected_cells) + + net = jx.Network([cell for _ in range(4 * 4)]) + connectivity_matrix_connect( + net[1:4], net[2:6], TestSynapse(), m_by_n_adjacency_matrix + ) + assert len(net.edges.index) == 5 + nodes = net.nodes.set_index("global_comp_index") + cols = ["global_pre_comp_index", "global_post_comp_index"] + comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] + cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) + # adjust the cell indices based on the view range passed + incides_of_connected_cells[:, 0] += 1 + incides_of_connected_cells[:, 1] += 2 + assert np.all(cell_inds == incides_of_connected_cells) From dbb8ef5cacb2aa1fcbe955e139fcf5c734462753 Mon Sep 17 00:00:00 2001 From: Kyra Kadhim Date: Fri, 8 Nov 2024 10:55:07 +0100 Subject: [PATCH 02/10] More explanatory variable names and test for single comp cells --- jaxley/connect.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/jaxley/connect.py b/jaxley/connect.py index 74e6e124..691d10c4 100644 --- a/jaxley/connect.py +++ b/jaxley/connect.py @@ -144,32 +144,28 @@ def connectivity_matrix_connect( connectivity_matrix: A boolean matrix indicating the connections between cells. """ # Get pre- and postsynaptic cell indices - pre_cell_inds = pre_cell_view._cells_in_view - post_cell_inds = post_cell_view._cells_in_view - # setting scope ensure that this works indep of current scope - pre_nodes = pre_cell_view.scope("local").branch(0).comp(0).nodes - pre_nodes["index"] = pre_nodes.index - pre_cell_nodes = pre_nodes.set_index("global_cell_index") + global_pre_cell_inds = pre_cell_view._cells_in_view + global_post_cell_inds = post_cell_view._cells_in_view assert connectivity_matrix.shape == ( - len(pre_cell_inds), - len(post_cell_inds), + len(global_pre_cell_inds), + len(global_post_cell_inds), ), "Connectivity matrix must have shape (num_pre, num_post)." assert connectivity_matrix.dtype == bool, "Connectivity matrix must be boolean." - # Get connection pairs from connectivity matrix + # Get pre to post connection pairs from connectivity matrix from_idx, to_idx = np.where(connectivity_matrix) # Pre-synapse at the zero-eth branch and zero-eth compartment - global_pre_indices = ( + global_pre_comp_indices = ( pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() ) # setting scope ensure that this works indep of current scope - pre_rows = pre_cell_view.select(nodes=global_pre_indices[from_idx]).nodes + pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes # Post-synapse also at the zero-eth branch and zero-eth compartment - global_post_indices = ( + global_post_comp_indices = ( post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() ) - post_rows = post_cell_view.select(nodes=global_post_indices[to_idx]).nodes + post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type) From cf498f65a15ed4eb3d01b752f9d58bf34fe6adbf Mon Sep 17 00:00:00 2001 From: Kyra Kadhim Date: Fri, 8 Nov 2024 11:40:14 +0100 Subject: [PATCH 03/10] test updates, accidentally not in last commit --- tests/test_connection.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_connection.py b/tests/test_connection.py index 853c1ac1..bc04fb77 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -204,6 +204,7 @@ def test_connectivity_matrix_connect(): cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) assert np.all(cell_inds == incides_of_connected_cells) + # Test with different view ranges net = jx.Network([cell for _ in range(4 * 4)]) connectivity_matrix_connect( net[1:4], net[2:6], TestSynapse(), m_by_n_adjacency_matrix @@ -217,3 +218,18 @@ def test_connectivity_matrix_connect(): incides_of_connected_cells[:, 0] += 1 incides_of_connected_cells[:, 1] += 2 assert np.all(cell_inds == incides_of_connected_cells) + + # Test with single compartment cells + comp = jx.Compartment() + branch = jx.Branch([comp], nseg=1) + cell = jx.Cell([branch], parents=[-1]) + net = jx.Network([cell for _ in range(4 * 4)]) + connectivity_matrix_connect( + net[1:4], net[2:6], TestSynapse(), m_by_n_adjacency_matrix + ) + assert len(net.edges.index) == 5 + nodes = net.nodes.set_index("global_comp_index") + cols = ["global_pre_comp_index", "global_post_comp_index"] + comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] + cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) + assert np.all(cell_inds == incides_of_connected_cells) From 588198ae80a9fd88786ceb6736cf7672db172801 Mon Sep 17 00:00:00 2001 From: Kyra Kadhim Date: Fri, 8 Nov 2024 14:10:30 +0100 Subject: [PATCH 04/10] Removed random selection of post-synaptic compartments, fixed test bug in sparse connect, standardized connection functions --- jaxley/connect.py | 68 +++++++++++++------------- tests/test_connection.py | 103 ++++++++++++++++++++------------------- 2 files changed, 89 insertions(+), 82 deletions(-) diff --git a/jaxley/connect.py b/jaxley/connect.py index 691d10c4..713510e7 100644 --- a/jaxley/connect.py +++ b/jaxley/connect.py @@ -58,19 +58,21 @@ def fully_connect( num_pre = len(pre_cell_view._cells_in_view) num_post = len(post_cell_view._cells_in_view) - # Infer indices of (random) postsynaptic compartments. - global_post_indices = ( - post_cell_view.nodes.groupby("global_cell_index") - .sample(num_pre, replace=True) - .index.to_numpy() - ) - global_post_indices = global_post_indices.reshape((-1, num_pre), order="F").ravel() - post_rows = post_cell_view.nodes.loc[global_post_indices] + # Pre-synapse at the zero-eth branch and zero-eth compartment + global_pre_comp_indices = ( + pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() + ) # setting scope ensure that this works indep of current scope + # Repeat comp indices `num_post` times. See SO 50788508 as before + global_pre_comp_indices = np.repeat(global_pre_comp_indices, num_post) + pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices).nodes - # Pre-synapse is at the zero-eth branch and zero-eth compartment. - pre_rows = pre_cell_view.scope("local").branch(0).comp(0).nodes.copy() - # Repeat rows `num_post` times. See SO 50788508. - pre_rows = pre_rows.loc[pre_rows.index.repeat(num_post)].reset_index(drop=True) + # Post-synapse also at the zero-eth branch and zero-eth compartment + global_post_comp_indices = ( + post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() + ) + # Tile comp indices `num_pre` times + global_post_comp_indices = np.tile(global_post_comp_indices, num_pre) + post_rows = post_cell_view.select(nodes=global_post_comp_indices).nodes pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type) @@ -83,7 +85,7 @@ def sparse_connect( ): """Appends multiple connections which build a sparse, randomly connected layer. - Connections are from branch 0 location 0 to a randomly chosen branch and loc. + Connections are from branch 0 location 0 to branch 0 location 0. Args: pre_cell_view: View of the presynaptic cell. @@ -97,28 +99,28 @@ def sparse_connect( num_pre = len(pre_cell_inds) num_post = len(post_cell_inds) + # Get the indices of connections, like it's from a random connectivity matrix num_connections = np.random.binomial(num_pre * num_post, p) - pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections) - post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections) - - # Sort the synapses only for convenience of inspecting `.edges`. - sorting = np.argsort(pre_syn_neurons) - pre_syn_neurons = pre_syn_neurons[sorting] - post_syn_neurons = post_syn_neurons[sorting] - - # Post-synapse is a randomly chosen branch and compartment. - global_post_indices = [ - sample_comp(post_cell_view.scope("global").cell(cell_idx)) - for cell_idx in post_syn_neurons - ] - global_post_indices = ( - np.hstack(global_post_indices) if len(global_post_indices) > 1 else [] - ) - post_rows = post_cell_view.base.nodes.loc[global_post_indices] + from_idx = np.random.choice(range(0, num_pre), size=num_connections) + to_idx = np.random.choice(range(0, num_post), size=num_connections) + + # Remove duplicate connections + row_inds = np.stack((from_idx, to_idx), axis=1) + row_inds = np.unique(row_inds, axis=0) + from_idx = row_inds[:, 0] + to_idx = row_inds[:, 1] + + # Pre-synapse at the zero-eth branch and zero-eth compartment + global_pre_comp_indices = ( + pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() + ) # setting scope ensure that this works indep of current scope + pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes - # Pre-synapse is at the zero-eth branch and zero-eth compartment. - global_pre_indices = pre_cell_view.base._cumsum_nseg_per_cell[pre_syn_neurons] - pre_rows = pre_cell_view.base.nodes.loc[global_pre_indices] + # Post-synapse also at the zero-eth branch and zero-eth compartment + global_post_comp_indices = ( + post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() + ) + post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes if len(pre_rows) > 0: pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type) diff --git a/tests/test_connection.py b/tests/test_connection.py index bc04fb77..f52f4fc7 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -84,41 +84,42 @@ def test_fully_connect(): fully_connect(net[8:12], net[12:16], TestSynapse()) + # This was previously visually inspected assert all( net.edges.global_post_comp_index == [ - 108, - 135, - 165, + 96, + 120, + 144, 168, - 99, - 123, - 151, - 177, - 115, - 141, - 162, - 172, - 119, - 126, - 156, - 169, - 294, - 329, - 345, - 379, - 295, - 317, - 356, - 365, - 311, - 325, - 355, - 375, - 302, - 320, - 352, - 375, + 96, + 120, + 144, + 168, + 96, + 120, + 144, + 168, + 96, + 120, + 144, + 168, + 288, + 312, + 336, + 360, + 288, + 312, + 336, + 360, + 288, + 312, + 336, + 360, + 288, + 312, + 336, + 360, ] ) @@ -135,27 +136,26 @@ def test_sparse_connect(): sparse_connect(net[8:12], net[12:], TestSynapse(), p=0.5) + # This was previously visually inspected assert all( - [ - 63, - 59, - 65, - 86, - 80, - 58, - 92, - 85, + net.edges.global_post_comp_index + == [ + 48, + 60, + 84, + 60, + 72, + 48, + 72, + 168, + 180, + 144, 168, - 145, - 189, - 153, 180, - 190, - 184, - 163, - 159, - 179, - 182, + 156, + 180, + 144, + 180, ] ) @@ -233,3 +233,8 @@ def test_connectivity_matrix_connect(): comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) assert np.all(cell_inds == incides_of_connected_cells) + + +if __name__ == "__main__": + + test_sparse_connect() From 322e60c8b7a0694aced9532e57dc14688ba242b0 Mon Sep 17 00:00:00 2001 From: Kyra Kadhim Date: Fri, 8 Nov 2024 14:11:16 +0100 Subject: [PATCH 05/10] Black on connect.py --- jaxley/connect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxley/connect.py b/jaxley/connect.py index 713510e7..7264085d 100644 --- a/jaxley/connect.py +++ b/jaxley/connect.py @@ -109,7 +109,7 @@ def sparse_connect( row_inds = np.unique(row_inds, axis=0) from_idx = row_inds[:, 0] to_idx = row_inds[:, 1] - + # Pre-synapse at the zero-eth branch and zero-eth compartment global_pre_comp_indices = ( pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() From b80f462ba3432763a154005bb7ab148a6e32c508 Mon Sep 17 00:00:00 2001 From: Kyra Kadhim Date: Fri, 8 Nov 2024 14:22:32 +0100 Subject: [PATCH 06/10] Standardized fully_connect even a little bit more --- jaxley/connect.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/jaxley/connect.py b/jaxley/connect.py index 7264085d..f04bec83 100644 --- a/jaxley/connect.py +++ b/jaxley/connect.py @@ -58,21 +58,21 @@ def fully_connect( num_pre = len(pre_cell_view._cells_in_view) num_post = len(post_cell_view._cells_in_view) + # Get the indices of the connections, like it's a fully connected connectivity matrix + from_idx = np.repeat(range(0, num_pre), num_post) + to_idx = np.tile(range(0, num_post), num_pre) + # Pre-synapse at the zero-eth branch and zero-eth compartment global_pre_comp_indices = ( pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() ) # setting scope ensure that this works indep of current scope - # Repeat comp indices `num_post` times. See SO 50788508 as before - global_pre_comp_indices = np.repeat(global_pre_comp_indices, num_post) - pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices).nodes + pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes # Post-synapse also at the zero-eth branch and zero-eth compartment global_post_comp_indices = ( post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() ) - # Tile comp indices `num_pre` times - global_post_comp_indices = np.tile(global_post_comp_indices, num_pre) - post_rows = post_cell_view.select(nodes=global_post_comp_indices).nodes + post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type) From 4971c32a1b470c9398f7000e83135ae160418f13 Mon Sep 17 00:00:00 2001 From: Kyra Kadhim Date: Mon, 18 Nov 2024 12:25:23 +0100 Subject: [PATCH 07/10] Added back option to randomly select post-synaptic compartment in standard way without looping --- jaxley/connect.py | 95 +++++++++++++------- tests/jaxley_identical/test_basic_modules.py | 8 +- tests/jaxley_identical/test_grad.py | 8 +- tests/test_connection.py | 9 +- 4 files changed, 77 insertions(+), 43 deletions(-) diff --git a/jaxley/connect.py b/jaxley/connect.py index f04bec83..1faa8670 100644 --- a/jaxley/connect.py +++ b/jaxley/connect.py @@ -44,6 +44,7 @@ def fully_connect( pre_cell_view: "View", post_cell_view: "View", synapse_type: "Synapse", + random_post_comp: bool = False, ): """Appends multiple connections which build a fully connected layer. @@ -60,7 +61,9 @@ def fully_connect( # Get the indices of the connections, like it's a fully connected connectivity matrix from_idx = np.repeat(range(0, num_pre), num_post) - to_idx = np.tile(range(0, num_post), num_pre) + to_idx = np.tile( + range(0, num_post), num_pre + ) # used only if random_post_comp is False # Pre-synapse at the zero-eth branch and zero-eth compartment global_pre_comp_indices = ( @@ -68,11 +71,23 @@ def fully_connect( ) # setting scope ensure that this works indep of current scope pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes - # Post-synapse also at the zero-eth branch and zero-eth compartment - global_post_comp_indices = ( - post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() - ) - post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes + if random_post_comp: + # Randomly sample the post-synaptic compartments + global_post_comp_indices = ( + post_cell_view.nodes.groupby("global_cell_index") + .sample(num_pre, replace=True) + .index.to_numpy() + ) + global_post_comp_indices = global_post_comp_indices.reshape( + (-1, num_pre), order="F" + ).ravel() + else: + # Post-synapse also at the zero-eth branch and zero-eth compartment + global_post_comp_indices = ( + post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() + ) + global_post_comp_indices = global_post_comp_indices[to_idx] + post_rows = post_cell_view.nodes.loc[global_post_comp_indices] pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type) @@ -82,6 +97,7 @@ def sparse_connect( post_cell_view: "View", synapse_type: "Synapse", p: float, + random_post_comp: bool = False, ): """Appends multiple connections which build a sparse, randomly connected layer. @@ -94,21 +110,12 @@ def sparse_connect( p: Probability of connection. """ # Get pre- and postsynaptic cell indices. - pre_cell_inds = pre_cell_view._cells_in_view - post_cell_inds = post_cell_view._cells_in_view - num_pre = len(pre_cell_inds) - num_post = len(post_cell_inds) - - # Get the indices of connections, like it's from a random connectivity matrix - num_connections = np.random.binomial(num_pre * num_post, p) - from_idx = np.random.choice(range(0, num_pre), size=num_connections) - to_idx = np.random.choice(range(0, num_post), size=num_connections) - - # Remove duplicate connections - row_inds = np.stack((from_idx, to_idx), axis=1) - row_inds = np.unique(row_inds, axis=0) - from_idx = row_inds[:, 0] - to_idx = row_inds[:, 1] + num_pre = len(pre_cell_view._cells_in_view) + num_post = len(post_cell_view._cells_in_view) + + # Generate random cxns without duplicates --> respects p but memory intesive if extremely large n cells + connectivity_matrix = np.random.binomial(1, p, (num_pre, num_post)) + from_idx, to_idx = np.where(connectivity_matrix) # Pre-synapse at the zero-eth branch and zero-eth compartment global_pre_comp_indices = ( @@ -116,10 +123,21 @@ def sparse_connect( ) # setting scope ensure that this works indep of current scope pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes - # Post-synapse also at the zero-eth branch and zero-eth compartment - global_post_comp_indices = ( - post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() - ) + if random_post_comp: + # Randomly sample the post-synaptic compartments + global_post_comp_indices = ( + post_cell_view.nodes.groupby("global_cell_index") + .sample(num_pre, replace=True) + .index.to_numpy() + ) + global_post_comp_indices = global_post_comp_indices.reshape( + (-1, num_pre), order="F" + ).ravel() + else: + # Post-synapse also at the zero-eth branch and zero-eth compartment + global_post_comp_indices = ( + post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() + ) post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes if len(pre_rows) > 0: @@ -131,6 +149,7 @@ def connectivity_matrix_connect( post_cell_view: "View", synapse_type: "Synapse", connectivity_matrix: np.ndarray[bool], + random_post_comp: bool = False, ): """Appends multiple connections which build a custom connected network. @@ -146,12 +165,12 @@ def connectivity_matrix_connect( connectivity_matrix: A boolean matrix indicating the connections between cells. """ # Get pre- and postsynaptic cell indices - global_pre_cell_inds = pre_cell_view._cells_in_view - global_post_cell_inds = post_cell_view._cells_in_view + num_pre = len(pre_cell_view._cells_in_view) + num_post = len(post_cell_view._cells_in_view) assert connectivity_matrix.shape == ( - len(global_pre_cell_inds), - len(global_post_cell_inds), + num_pre, + num_post, ), "Connectivity matrix must have shape (num_pre, num_post)." assert connectivity_matrix.dtype == bool, "Connectivity matrix must be boolean." @@ -164,10 +183,20 @@ def connectivity_matrix_connect( ) # setting scope ensure that this works indep of current scope pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes - # Post-synapse also at the zero-eth branch and zero-eth compartment - global_post_comp_indices = ( - post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() - ) + if random_post_comp: + global_post_comp_indices = ( + post_cell_view.nodes.groupby("global_cell_index") + .sample(len(from_idx), replace=True) + .index.to_numpy() + ) + global_post_comp_indices = global_post_comp_indices.reshape( + (-1, len(from_idx)), order="F" + ).ravel() + else: + # Post-synapse also at the zero-eth branch and zero-eth compartment + global_post_comp_indices = ( + post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() + ) post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type) diff --git a/tests/jaxley_identical/test_basic_modules.py b/tests/jaxley_identical/test_basic_modules.py index 8faba4b3..df220697 100644 --- a/tests/jaxley_identical/test_basic_modules.py +++ b/tests/jaxley_identical/test_basic_modules.py @@ -319,13 +319,13 @@ def test_complex_net(voltage_solver: str): _ = np.random.seed(0) pre = net.cell([0, 1, 2]) post = net.cell([3, 4, 5]) - fully_connect(pre, post, IonotropicSynapse()) - fully_connect(pre, post, TestSynapse()) + fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True) + fully_connect(pre, post, TestSynapse(), random_post_comp=True) pre = net.cell([3, 4, 5]) post = net.cell(6) - fully_connect(pre, post, IonotropicSynapse()) - fully_connect(pre, post, TestSynapse()) + fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True) + fully_connect(pre, post, TestSynapse(), random_post_comp=True) area = 2 * pi * 10.0 * 1.0 point_process_to_dist_factor = 100_000.0 / area diff --git a/tests/jaxley_identical/test_grad.py b/tests/jaxley_identical/test_grad.py index 198201bc..582e6249 100644 --- a/tests/jaxley_identical/test_grad.py +++ b/tests/jaxley_identical/test_grad.py @@ -33,13 +33,13 @@ def test_network_grad(): _ = np.random.seed(0) pre = net.cell([0, 1, 2]) post = net.cell([3, 4, 5]) - fully_connect(pre, post, IonotropicSynapse()) - fully_connect(pre, post, TestSynapse()) + fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True) + fully_connect(pre, post, TestSynapse(), random_post_comp=True) pre = net.cell([3, 4, 5]) post = net.cell(6) - fully_connect(pre, post, IonotropicSynapse()) - fully_connect(pre, post, TestSynapse()) + fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True) + fully_connect(pre, post, TestSynapse(), random_post_comp=True) area = 2 * pi * 10.0 * 1.0 point_process_to_dist_factor = 100_000.0 / area diff --git a/tests/test_connection.py b/tests/test_connection.py index f52f4fc7..4ca7d37c 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -142,15 +142,20 @@ def test_sparse_connect(): == [ 48, 60, + 72, 84, 60, - 72, + 84, 48, 72, + 84, + 48, + 60, + 156, 168, 180, 144, - 168, + 156, 180, 156, 180, From 5471c85138b7960abade724914c840bbe5eeaf87 Mon Sep 17 00:00:00 2001 From: Kyra Kadhim Date: Wed, 20 Nov 2024 17:09:37 +0100 Subject: [PATCH 08/10] Fixed docstrings and comments --- jaxley/connect.py | 16 ++++++++++------ tests/test_connection.py | 11 ++--------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/jaxley/connect.py b/jaxley/connect.py index 1faa8670..10296648 100644 --- a/jaxley/connect.py +++ b/jaxley/connect.py @@ -48,12 +48,14 @@ def fully_connect( ): """Appends multiple connections which build a fully connected layer. - Connections are from branch 0 location 0 to a randomly chosen branch and loc. + Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0 + location 0 of the post-synaptic cell unless random_post_comp=True. Args: pre_cell_view: View of the presynaptic cell. post_cell_view: View of the postsynaptic cell. synapse_type: The synapse to append. + random_post_comp: If True, randomly samples the postsynaptic compartments. """ # Get pre- and postsynaptic cell indices. num_pre = len(pre_cell_view._cells_in_view) @@ -101,13 +103,15 @@ def sparse_connect( ): """Appends multiple connections which build a sparse, randomly connected layer. - Connections are from branch 0 location 0 to branch 0 location 0. + Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0 + location 0 of the post-synaptic cell unless random_post_comp=True. Args: pre_cell_view: View of the presynaptic cell. post_cell_view: View of the postsynaptic cell. synapse_type: The synapse to append. p: Probability of connection. + random_post_comp: If True, randomly samples the postsynaptic compartments. """ # Get pre- and postsynaptic cell indices. num_pre = len(pre_cell_view._cells_in_view) @@ -151,18 +155,18 @@ def connectivity_matrix_connect( connectivity_matrix: np.ndarray[bool], random_post_comp: bool = False, ): - """Appends multiple connections which build a custom connected network. + """Appends multiple connections according to a custom connectivity matrix. - Connects pre- and postsynaptic cells according to a custom connectivity matrix. Entries > 0 in the matrix indicate a connection between the corresponding cells. - Connections are from branch 0 location 0 on the presynaptic cell to branch 0 - location 0 on the postsynaptic cell. + Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0 + location 0 of the post-synaptic cell unless random_post_comp=True. Args: pre_cell_view: View of the presynaptic cell. post_cell_view: View of the postsynaptic cell. synapse_type: The synapse to append. connectivity_matrix: A boolean matrix indicating the connections between cells. + random_post_comp: If True, randomly samples the postsynaptic compartments. """ # Get pre- and postsynaptic cell indices num_pre = len(pre_cell_view._cells_in_view) diff --git a/tests/test_connection.py b/tests/test_connection.py index 4ca7d37c..a1ce351f 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -84,7 +84,6 @@ def test_fully_connect(): fully_connect(net[8:12], net[12:16], TestSynapse()) - # This was previously visually inspected assert all( net.edges.global_post_comp_index == [ @@ -136,7 +135,6 @@ def test_sparse_connect(): sparse_connect(net[8:12], net[12:], TestSynapse(), p=0.5) - # This was previously visually inspected assert all( net.edges.global_post_comp_index == [ @@ -209,7 +207,7 @@ def test_connectivity_matrix_connect(): cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) assert np.all(cell_inds == incides_of_connected_cells) - # Test with different view ranges + # Test with different cell views net = jx.Network([cell for _ in range(4 * 4)]) connectivity_matrix_connect( net[1:4], net[2:6], TestSynapse(), m_by_n_adjacency_matrix @@ -219,7 +217,7 @@ def test_connectivity_matrix_connect(): cols = ["global_pre_comp_index", "global_post_comp_index"] comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) - # adjust the cell indices based on the view range passed + # adjust the cell indices based on the views passed incides_of_connected_cells[:, 0] += 1 incides_of_connected_cells[:, 1] += 2 assert np.all(cell_inds == incides_of_connected_cells) @@ -238,8 +236,3 @@ def test_connectivity_matrix_connect(): comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) assert np.all(cell_inds == incides_of_connected_cells) - - -if __name__ == "__main__": - - test_sparse_connect() From 23b31fe1b719f7879628892c6da7653479e736fe Mon Sep 17 00:00:00 2001 From: Kyra Kadhim Date: Mon, 25 Nov 2024 16:50:26 +0100 Subject: [PATCH 09/10] Updated sparse connectivity function to not use full adj mat, removed scoping, and added function for random post-comp selection --- jaxley/connect.py | 104 ++++++++++++++++++++++----------------- tests/test_connection.py | 32 ++---------- 2 files changed, 64 insertions(+), 72 deletions(-) diff --git a/jaxley/connect.py b/jaxley/connect.py index 10296648..bb65667c 100644 --- a/jaxley/connect.py +++ b/jaxley/connect.py @@ -18,6 +18,19 @@ def sample_comp(cell_view: "View", num: int = 1, replace=True) -> "CompartmentVi return np.random.choice(cell_view._comps_in_view, num, replace=replace) +def get_random_post_comps(post_cell_view: "View", num_post: int) -> "CompartmentView": + """Sample global compartment indices from all postsynaptic cells.""" + global_post_comp_indices = ( + post_cell_view.nodes.groupby("global_cell_index") + .sample(num_post, replace=True) + .index.to_numpy() + ) + global_post_comp_indices = global_post_comp_indices.reshape( + (-1, num_post), order="F" + ).ravel() + return global_post_comp_indices + + def connect( pre: "View", post: "View", @@ -61,33 +74,25 @@ def fully_connect( num_pre = len(pre_cell_view._cells_in_view) num_post = len(post_cell_view._cells_in_view) - # Get the indices of the connections, like it's a fully connected connectivity matrix + # Get the indices of the connections (from is the pre cell index) from_idx = np.repeat(range(0, num_pre), num_post) - to_idx = np.tile( - range(0, num_post), num_pre - ) # used only if random_post_comp is False # Pre-synapse at the zero-eth branch and zero-eth compartment global_pre_comp_indices = ( - pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() - ) # setting scope ensure that this works indep of current scope + pre_cell_view.nodes.groupby("global_cell_index").first()["global_comp_index"] + ).to_numpy() pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes if random_post_comp: - # Randomly sample the post-synaptic compartments - global_post_comp_indices = ( - post_cell_view.nodes.groupby("global_cell_index") - .sample(num_pre, replace=True) - .index.to_numpy() - ) - global_post_comp_indices = global_post_comp_indices.reshape( - (-1, num_pre), order="F" - ).ravel() + global_post_comp_indices = get_random_post_comps(post_cell_view, num_pre) else: # Post-synapse also at the zero-eth branch and zero-eth compartment + to_idx = np.tile(range(0, num_post), num_pre) global_post_comp_indices = ( - post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() - ) + post_cell_view.nodes.groupby("global_cell_index").first()[ + "global_comp_index" + ] + ).to_numpy() global_post_comp_indices = global_post_comp_indices[to_idx] post_rows = post_cell_view.nodes.loc[global_post_comp_indices] @@ -117,31 +122,45 @@ def sparse_connect( num_pre = len(pre_cell_view._cells_in_view) num_post = len(post_cell_view._cells_in_view) - # Generate random cxns without duplicates --> respects p but memory intesive if extremely large n cells - connectivity_matrix = np.random.binomial(1, p, (num_pre, num_post)) - from_idx, to_idx = np.where(connectivity_matrix) + # Generate random cxns via Bernoulli trials (no duplicates), done in blocks of the + # connectivity matrix to save memory and time (smaller cut size saves memory, + # larger saves time) + cut_size = 100 # --> (100, 100) dim blocks + pre_cuts, pre_mod = divmod(num_pre, cut_size) + post_cuts, post_mod = divmod(num_post, cut_size) + + x_inds = [] + y_inds = [] + for i in range(pre_cuts + min(1, pre_mod)): + for j in range(post_cuts + min(1, post_mod)): + block = np.random.binomial(1, p, size=(cut_size, cut_size)) + xb, yb = np.where(block) + xb += i * cut_size + yb += j * cut_size + x_inds.append(xb) + y_inds.append(yb) + all_inds = np.stack((np.concatenate(x_inds), np.concatenate(y_inds)), axis=1) + # Filter out connections where either pre or post index is out of range + all_inds = all_inds[(all_inds[:, 0] < num_pre) & (all_inds[:, 1] < num_post)] + from_idx = all_inds[:, 0] + to_idx = all_inds[:, 1] + del all_inds # Pre-synapse at the zero-eth branch and zero-eth compartment global_pre_comp_indices = ( - pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() - ) # setting scope ensure that this works indep of current scope + pre_cell_view.nodes.groupby("global_cell_index").first()["global_comp_index"] + ).to_numpy() pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes if random_post_comp: - # Randomly sample the post-synaptic compartments - global_post_comp_indices = ( - post_cell_view.nodes.groupby("global_cell_index") - .sample(num_pre, replace=True) - .index.to_numpy() - ) - global_post_comp_indices = global_post_comp_indices.reshape( - (-1, num_pre), order="F" - ).ravel() + global_post_comp_indices = get_random_post_comps(post_cell_view, num_pre) else: # Post-synapse also at the zero-eth branch and zero-eth compartment global_post_comp_indices = ( - post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() - ) + post_cell_view.nodes.groupby("global_cell_index").first()[ + "global_comp_index" + ] + ).to_numpy() post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes if len(pre_rows) > 0: @@ -183,24 +202,19 @@ def connectivity_matrix_connect( # Pre-synapse at the zero-eth branch and zero-eth compartment global_pre_comp_indices = ( - pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() - ) # setting scope ensure that this works indep of current scope + pre_cell_view.nodes.groupby("global_cell_index").first()["global_comp_index"] + ).to_numpy() pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes if random_post_comp: - global_post_comp_indices = ( - post_cell_view.nodes.groupby("global_cell_index") - .sample(len(from_idx), replace=True) - .index.to_numpy() - ) - global_post_comp_indices = global_post_comp_indices.reshape( - (-1, len(from_idx)), order="F" - ).ravel() + global_post_comp_indices = get_random_post_comps(post_cell_view, num_pre) else: # Post-synapse also at the zero-eth branch and zero-eth compartment global_post_comp_indices = ( - post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() - ) + post_cell_view.nodes.groupby("global_cell_index").first()[ + "global_comp_index" + ] + ).to_numpy() post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type) diff --git a/tests/test_connection.py b/tests/test_connection.py index 5b0f2847..4bb3711b 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -132,30 +132,8 @@ def test_sparse_connect(SimpleNet): sparse_connect(net[8:12], net[12:], TestSynapse(), p=0.5) assert all( - net.edges.global_post_comp_index - == [ - 48, - 60, - 72, - 84, - 60, - 84, - 48, - 72, - 84, - 48, - 60, - 156, - 168, - 180, - 144, - 156, - 180, - 156, - 180, - 144, - 180, - ] + net.edges.post_global_comp_index + == [64, 80, 96, 112, 64, 96, 64, 80, 112, 208, 224, 240, 192, 208, 240, 208] ) @@ -201,13 +179,13 @@ def test_connectivity_matrix_connect(SimpleNet): assert np.all(cell_inds == incides_of_connected_cells) # Test with different cell views - net = jx.Network([cell for _ in range(4 * 4)]) + net = SimpleNet(4 * 4, 3, 8) connectivity_matrix_connect( net[1:4], net[2:6], TestSynapse(), m_by_n_adjacency_matrix ) assert len(net.edges.index) == 5 nodes = net.nodes.set_index("global_comp_index") - cols = ["global_pre_comp_index", "global_post_comp_index"] + cols = ["pre_global_comp_index", "post_global_comp_index"] comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) # adjust the cell indices based on the views passed @@ -225,7 +203,7 @@ def test_connectivity_matrix_connect(SimpleNet): ) assert len(net.edges.index) == 5 nodes = net.nodes.set_index("global_comp_index") - cols = ["global_pre_comp_index", "global_post_comp_index"] + cols = ["pre_global_comp_index", "post_global_comp_index"] comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) assert np.all(cell_inds == incides_of_connected_cells) From c50468988c3f5ddbb0c64223ddd23119dcb4e460 Mon Sep 17 00:00:00 2001 From: Kyra Date: Mon, 25 Nov 2024 18:42:34 +0100 Subject: [PATCH 10/10] Cleaning up sparse connect a bit --- jaxley/connect.py | 44 +++++++++++++++++++------------------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/jaxley/connect.py b/jaxley/connect.py index bb65667c..bf3a05c2 100644 --- a/jaxley/connect.py +++ b/jaxley/connect.py @@ -74,14 +74,10 @@ def fully_connect( num_pre = len(pre_cell_view._cells_in_view) num_post = len(post_cell_view._cells_in_view) - # Get the indices of the connections (from is the pre cell index) - from_idx = np.repeat(range(0, num_pre), num_post) - - # Pre-synapse at the zero-eth branch and zero-eth compartment - global_pre_comp_indices = ( - pre_cell_view.nodes.groupby("global_cell_index").first()["global_comp_index"] - ).to_numpy() - pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes + # Get a view of the zeroeth compartment of each cell as the pre compartments + pre_comps = pre_cell_view.scope("local").branch(0).comp(0).nodes.copy() + # Repeat rows `num_post` times + pre_rows = pre_comps.loc[pre_comps.index.repeat(num_post)].reset_index(drop=True) if random_post_comp: global_post_comp_indices = get_random_post_comps(post_cell_view, num_pre) @@ -126,25 +122,23 @@ def sparse_connect( # connectivity matrix to save memory and time (smaller cut size saves memory, # larger saves time) cut_size = 100 # --> (100, 100) dim blocks - pre_cuts, pre_mod = divmod(num_pre, cut_size) - post_cuts, post_mod = divmod(num_post, cut_size) - - x_inds = [] - y_inds = [] - for i in range(pre_cuts + min(1, pre_mod)): - for j in range(post_cuts + min(1, post_mod)): + pre_inds, post_inds = [], [] + for i in range((num_pre + cut_size - 1) // cut_size): + for j in range((num_post + cut_size - 1) // cut_size): block = np.random.binomial(1, p, size=(cut_size, cut_size)) - xb, yb = np.where(block) - xb += i * cut_size - yb += j * cut_size - x_inds.append(xb) - y_inds.append(yb) - all_inds = np.stack((np.concatenate(x_inds), np.concatenate(y_inds)), axis=1) + block_pre, block_post = np.where(block) + block_pre += i * cut_size # block inds --> full adj mat inds + block_post += j * cut_size # block inds --> full adj mat inds + pre_inds.append(block_pre) + post_inds.append(block_post) + pre_post_inds = np.stack( + (np.concatenate(pre_inds), np.concatenate(post_inds)), axis=1 + ) # Filter out connections where either pre or post index is out of range - all_inds = all_inds[(all_inds[:, 0] < num_pre) & (all_inds[:, 1] < num_post)] - from_idx = all_inds[:, 0] - to_idx = all_inds[:, 1] - del all_inds + pre_post_inds = pre_post_inds[ + (pre_post_inds[:, 0] < num_pre) & (pre_post_inds[:, 1] < num_post) + ] + from_idx, to_idx = pre_post_inds[:, 0], pre_post_inds[:, 1] # Pre-synapse at the zero-eth branch and zero-eth compartment global_pre_comp_indices = (