diff --git a/jaxley/connect.py b/jaxley/connect.py index 0d05893d..bf3a05c2 100644 --- a/jaxley/connect.py +++ b/jaxley/connect.py @@ -18,6 +18,19 @@ def sample_comp(cell_view: "View", num: int = 1, replace=True) -> "CompartmentVi return np.random.choice(cell_view._comps_in_view, num, replace=replace) +def get_random_post_comps(post_cell_view: "View", num_post: int) -> "CompartmentView": + """Sample global compartment indices from all postsynaptic cells.""" + global_post_comp_indices = ( + post_cell_view.nodes.groupby("global_cell_index") + .sample(num_post, replace=True) + .index.to_numpy() + ) + global_post_comp_indices = global_post_comp_indices.reshape( + (-1, num_post), order="F" + ).ravel() + return global_post_comp_indices + + def connect( pre: "View", post: "View", @@ -44,33 +57,40 @@ def fully_connect( pre_cell_view: "View", post_cell_view: "View", synapse_type: "Synapse", + random_post_comp: bool = False, ): """Appends multiple connections which build a fully connected layer. - Connections are from branch 0 location 0 to a randomly chosen branch and loc. + Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0 + location 0 of the post-synaptic cell unless random_post_comp=True. Args: pre_cell_view: View of the presynaptic cell. post_cell_view: View of the postsynaptic cell. synapse_type: The synapse to append. + random_post_comp: If True, randomly samples the postsynaptic compartments. """ # Get pre- and postsynaptic cell indices. num_pre = len(pre_cell_view._cells_in_view) num_post = len(post_cell_view._cells_in_view) - # Infer indices of (random) postsynaptic compartments. - global_post_indices = ( - post_cell_view.nodes.groupby("global_cell_index") - .sample(num_pre, replace=True) - .index.to_numpy() - ) - global_post_indices = global_post_indices.reshape((-1, num_pre), order="F").ravel() - post_rows = post_cell_view.nodes.loc[global_post_indices] - - # Pre-synapse is at the zero-eth branch and zero-eth compartment. - pre_rows = pre_cell_view.scope("local").branch(0).comp(0).nodes.copy() - # Repeat rows `num_post` times. See SO 50788508. - pre_rows = pre_rows.loc[pre_rows.index.repeat(num_post)].reset_index(drop=True) + # Get a view of the zeroeth compartment of each cell as the pre compartments + pre_comps = pre_cell_view.scope("local").branch(0).comp(0).nodes.copy() + # Repeat rows `num_post` times + pre_rows = pre_comps.loc[pre_comps.index.repeat(num_post)].reset_index(drop=True) + + if random_post_comp: + global_post_comp_indices = get_random_post_comps(post_cell_view, num_pre) + else: + # Post-synapse also at the zero-eth branch and zero-eth compartment + to_idx = np.tile(range(0, num_post), num_pre) + global_post_comp_indices = ( + post_cell_view.nodes.groupby("global_cell_index").first()[ + "global_comp_index" + ] + ).to_numpy() + global_post_comp_indices = global_post_comp_indices[to_idx] + post_rows = post_cell_view.nodes.loc[global_post_comp_indices] pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type) @@ -80,45 +100,62 @@ def sparse_connect( post_cell_view: "View", synapse_type: "Synapse", p: float, + random_post_comp: bool = False, ): """Appends multiple connections which build a sparse, randomly connected layer. - Connections are from branch 0 location 0 to a randomly chosen branch and loc. + Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0 + location 0 of the post-synaptic cell unless random_post_comp=True. Args: pre_cell_view: View of the presynaptic cell. post_cell_view: View of the postsynaptic cell. synapse_type: The synapse to append. p: Probability of connection. + random_post_comp: If True, randomly samples the postsynaptic compartments. """ # Get pre- and postsynaptic cell indices. - pre_cell_inds = pre_cell_view._cells_in_view - post_cell_inds = post_cell_view._cells_in_view - num_pre = len(pre_cell_inds) - num_post = len(post_cell_inds) - - num_connections = np.random.binomial(num_pre * num_post, p) - pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections) - post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections) - - # Sort the synapses only for convenience of inspecting `.edges`. - sorting = np.argsort(pre_syn_neurons) - pre_syn_neurons = pre_syn_neurons[sorting] - post_syn_neurons = post_syn_neurons[sorting] - - # Post-synapse is a randomly chosen branch and compartment. - global_post_indices = [ - sample_comp(post_cell_view.scope("global").cell(cell_idx)) - for cell_idx in post_syn_neurons - ] - global_post_indices = ( - np.hstack(global_post_indices) if len(global_post_indices) > 1 else [] - ) - post_rows = post_cell_view.base.nodes.loc[global_post_indices] + num_pre = len(pre_cell_view._cells_in_view) + num_post = len(post_cell_view._cells_in_view) - # Pre-synapse is at the zero-eth branch and zero-eth compartment. - global_pre_indices = pre_cell_view.base._cumsum_ncomp_per_cell[pre_syn_neurons] - pre_rows = pre_cell_view.base.nodes.loc[global_pre_indices] + # Generate random cxns via Bernoulli trials (no duplicates), done in blocks of the + # connectivity matrix to save memory and time (smaller cut size saves memory, + # larger saves time) + cut_size = 100 # --> (100, 100) dim blocks + pre_inds, post_inds = [], [] + for i in range((num_pre + cut_size - 1) // cut_size): + for j in range((num_post + cut_size - 1) // cut_size): + block = np.random.binomial(1, p, size=(cut_size, cut_size)) + block_pre, block_post = np.where(block) + block_pre += i * cut_size # block inds --> full adj mat inds + block_post += j * cut_size # block inds --> full adj mat inds + pre_inds.append(block_pre) + post_inds.append(block_post) + pre_post_inds = np.stack( + (np.concatenate(pre_inds), np.concatenate(post_inds)), axis=1 + ) + # Filter out connections where either pre or post index is out of range + pre_post_inds = pre_post_inds[ + (pre_post_inds[:, 0] < num_pre) & (pre_post_inds[:, 1] < num_post) + ] + from_idx, to_idx = pre_post_inds[:, 0], pre_post_inds[:, 1] + + # Pre-synapse at the zero-eth branch and zero-eth compartment + global_pre_comp_indices = ( + pre_cell_view.nodes.groupby("global_cell_index").first()["global_comp_index"] + ).to_numpy() + pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes + + if random_post_comp: + global_post_comp_indices = get_random_post_comps(post_cell_view, num_pre) + else: + # Post-synapse also at the zero-eth branch and zero-eth compartment + global_post_comp_indices = ( + post_cell_view.nodes.groupby("global_cell_index").first()[ + "global_comp_index" + ] + ).to_numpy() + post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes if len(pre_rows) > 0: pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type) @@ -129,49 +166,49 @@ def connectivity_matrix_connect( post_cell_view: "View", synapse_type: "Synapse", connectivity_matrix: np.ndarray[bool], + random_post_comp: bool = False, ): - """Appends multiple connections which build a custom connected network. + """Appends multiple connections according to a custom connectivity matrix. - Connects pre- and postsynaptic cells according to a custom connectivity matrix. Entries > 0 in the matrix indicate a connection between the corresponding cells. - Connections are from branch 0 location 0 to a randomly chosen branch and loc. + Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0 + location 0 of the post-synaptic cell unless random_post_comp=True. Args: pre_cell_view: View of the presynaptic cell. post_cell_view: View of the postsynaptic cell. synapse_type: The synapse to append. connectivity_matrix: A boolean matrix indicating the connections between cells. + random_post_comp: If True, randomly samples the postsynaptic compartments. """ - # Get pre- and postsynaptic cell indices. - pre_cell_inds = pre_cell_view._cells_in_view - post_cell_inds = post_cell_view._cells_in_view - # setting scope ensure that this works indep of current scope - pre_nodes = pre_cell_view.scope("local").branch(0).comp(0).nodes - pre_nodes["index"] = pre_nodes.index - pre_cell_nodes = pre_nodes.set_index("global_cell_index") + # Get pre- and postsynaptic cell indices + num_pre = len(pre_cell_view._cells_in_view) + num_post = len(post_cell_view._cells_in_view) assert connectivity_matrix.shape == ( - len(pre_cell_inds), - len(post_cell_inds), + num_pre, + num_post, ), "Connectivity matrix must have shape (num_pre, num_post)." assert connectivity_matrix.dtype == bool, "Connectivity matrix must be boolean." - # get connection pairs from connectivity matrix + # Get pre to post connection pairs from connectivity matrix from_idx, to_idx = np.where(connectivity_matrix) - pre_cell_inds = pre_cell_inds[from_idx] - post_cell_inds = post_cell_inds[to_idx] - - # Sample random postsynaptic compartments (global comp indices). - global_post_indices = np.hstack( - [ - sample_comp(post_cell_view.scope("global").cell(cell_idx)) - for cell_idx in post_cell_inds - ] - ) - post_rows = post_cell_view.nodes.loc[global_post_indices] - # Pre-synapse is at the zero-eth branch and zero-eth compartment. - global_pre_indices = pre_cell_nodes.loc[pre_cell_inds, "index"].to_numpy() - pre_rows = pre_cell_view.select(nodes=global_pre_indices).nodes + # Pre-synapse at the zero-eth branch and zero-eth compartment + global_pre_comp_indices = ( + pre_cell_view.nodes.groupby("global_cell_index").first()["global_comp_index"] + ).to_numpy() + pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes + + if random_post_comp: + global_post_comp_indices = get_random_post_comps(post_cell_view, num_pre) + else: + # Post-synapse also at the zero-eth branch and zero-eth compartment + global_post_comp_indices = ( + post_cell_view.nodes.groupby("global_cell_index").first()[ + "global_comp_index" + ] + ).to_numpy() + post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type) diff --git a/tests/jaxley_identical/test_basic_modules.py b/tests/jaxley_identical/test_basic_modules.py index 61d201f2..585efda2 100644 --- a/tests/jaxley_identical/test_basic_modules.py +++ b/tests/jaxley_identical/test_basic_modules.py @@ -302,13 +302,13 @@ def test_complex_net(voltage_solver, SimpleNet): _ = np.random.seed(0) pre = net.cell([0, 1, 2]) post = net.cell([3, 4, 5]) - fully_connect(pre, post, IonotropicSynapse()) - fully_connect(pre, post, TestSynapse()) + fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True) + fully_connect(pre, post, TestSynapse(), random_post_comp=True) pre = net.cell([3, 4, 5]) post = net.cell(6) - fully_connect(pre, post, IonotropicSynapse()) - fully_connect(pre, post, TestSynapse()) + fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True) + fully_connect(pre, post, TestSynapse(), random_post_comp=True) area = 2 * pi * 10.0 * 1.0 point_process_to_dist_factor = 100_000.0 / area diff --git a/tests/jaxley_identical/test_grad.py b/tests/jaxley_identical/test_grad.py index bfd1de84..3c541077 100644 --- a/tests/jaxley_identical/test_grad.py +++ b/tests/jaxley_identical/test_grad.py @@ -31,13 +31,13 @@ def test_network_grad(SimpleNet): _ = np.random.seed(0) pre = net.cell([0, 1, 2]) post = net.cell([3, 4, 5]) - fully_connect(pre, post, IonotropicSynapse()) - fully_connect(pre, post, TestSynapse()) + fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True) + fully_connect(pre, post, TestSynapse(), random_post_comp=True) pre = net.cell([3, 4, 5]) post = net.cell(6) - fully_connect(pre, post, IonotropicSynapse()) - fully_connect(pre, post, TestSynapse()) + fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True) + fully_connect(pre, post, TestSynapse(), random_post_comp=True) area = 2 * pi * 10.0 * 1.0 point_process_to_dist_factor = 100_000.0 / area diff --git a/tests/test_connection.py b/tests/test_connection.py index d8277e5a..4bb3711b 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -86,38 +86,38 @@ def test_fully_connect(): assert all( net.edges.post_global_comp_index == [ - 108, - 135, - 165, + 96, + 120, + 144, 168, - 99, - 123, - 151, - 177, - 115, - 141, - 162, - 172, - 119, - 126, - 156, - 169, - 294, - 329, - 345, - 379, - 295, - 317, - 356, - 365, - 311, - 325, - 355, - 375, - 302, - 320, - 352, - 375, + 96, + 120, + 144, + 168, + 96, + 120, + 144, + 168, + 96, + 120, + 144, + 168, + 288, + 312, + 336, + 360, + 288, + 312, + 336, + 360, + 288, + 312, + 336, + 360, + 288, + 312, + 336, + 360, ] ) @@ -132,27 +132,8 @@ def test_sparse_connect(SimpleNet): sparse_connect(net[8:12], net[12:], TestSynapse(), p=0.5) assert all( - [ - 63, - 59, - 65, - 86, - 80, - 58, - 92, - 85, - 168, - 145, - 189, - 153, - 180, - 190, - 184, - 163, - 159, - 179, - 182, - ] + net.edges.post_global_comp_index + == [64, 80, 96, 112, 64, 96, 64, 80, 112, 208, 224, 240, 192, 208, 240, 208] ) @@ -196,3 +177,33 @@ def test_connectivity_matrix_connect(SimpleNet): comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) assert np.all(cell_inds == incides_of_connected_cells) + + # Test with different cell views + net = SimpleNet(4 * 4, 3, 8) + connectivity_matrix_connect( + net[1:4], net[2:6], TestSynapse(), m_by_n_adjacency_matrix + ) + assert len(net.edges.index) == 5 + nodes = net.nodes.set_index("global_comp_index") + cols = ["pre_global_comp_index", "post_global_comp_index"] + comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] + cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) + # adjust the cell indices based on the views passed + incides_of_connected_cells[:, 0] += 1 + incides_of_connected_cells[:, 1] += 2 + assert np.all(cell_inds == incides_of_connected_cells) + + # Test with single compartment cells + comp = jx.Compartment() + branch = jx.Branch([comp], nseg=1) + cell = jx.Cell([branch], parents=[-1]) + net = jx.Network([cell for _ in range(4 * 4)]) + connectivity_matrix_connect( + net[1:4], net[2:6], TestSynapse(), m_by_n_adjacency_matrix + ) + assert len(net.edges.index) == 5 + nodes = net.nodes.set_index("global_comp_index") + cols = ["pre_global_comp_index", "post_global_comp_index"] + comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] + cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) + assert np.all(cell_inds == incides_of_connected_cells)