Skip to content

Commit

Permalink
Better viewing (#447)
Browse files Browse the repository at this point in the history
* wip: context manager view working and set also. at also works recursively

* wip: save wip adding required methods

* wip: context manager view working and set also. at also works recursively

* wip: adding in more attrs

* wip: adding more attrs

* wip: started on vis and added remaining args

* wip: got sidetracked, added new plotting

* wip: save wip

* wip: save wip

* wip: save wip

* wip: save wip

* wip: connect added

* wip: add trainables and fix connecting

* wip: add channel and synapse views

* mv: move view notebook

* chore: restructure notebook. in prep for PR

* wip: save wip

* wip: fixing bugs after rebase

* wip: PR ready. fixed make_trainables

* fix: ammend prev commit

* wip: in process of transferring proof of concept to jaxley/modules

* rm: rm old base from tracking

* fix: small fixes

* fix: save wip

* wip: get methods from jaxley

* fix: add latest changes from main

* wip: transferring proof of concept after rebase

* fix: add branch_inds_in_view back in

* wip: all modules can be initialized

* rm: remove old views

* wip: new view now working for a lot of things

* fix: all basic new view functions now work in module

* fix: small fixes. xyzr working for view, trainables and reorder update_local_inds

* fix: plot tests pass

* rm: rm new_view notebook

* fix: fix connect

* add: add edge functionality. More tests passing now

* enh/fix: better local indexing. better edge. small fixes to tests

* fix: fixes for edge view

* fix: make more tests pass

* enh: small improvements to lazy indexing, iteration

* fix: refactor edge viewing

* fix: make set_ncomp work

* fix: make set_ncomps work

* fix: move_to tests pass

* fix: more tests pass (groups)

* fix: connect tests passing

* fix: even more tests passing now

* add: add distance method

* fix: roll back fix for making synapses pass make_trainable since it causes issues elsewhere.

* fix: kinda fix indexing tests by commenting out relevant sections and adding todos for new tests

* fix: fixed issues when recording synapse states

* fix: add fix mimicing old synapse viewing in get_all_parameters. all but 1 tests pass

* fix: ALL TESTS PASSING!!! YAY modified group test to make pass, since new behaviour differs to prev

* doc: add documentation to view and new methods of Module

* doc: add more documentation

* doc: add more documentation

* fix: add not_implemented to copy

* enh: add jaxnodes to View and ideas for new tests

* fix: fix failing tests from prev commit

* fix: ran black - ammend prev commit

* fix: fix some of the things listed in https://github.com/jaxleyverse/jaxley/pull/447\#issuecomment-2427128474

* fix: fix issues with jaxedges in View and trainables for synapses in view

* fix: fix failing tests

* fix: address comments

* fix: allow autapses

* fix: fix wrong kwarg in View

* fix: fix type diffs between module and view. Add test for this

* fix: fix printing trainables issue

* fix: rename view args to fix failing tests

* fix: fix type of cumsum_nseg and rename filter to select

* fix: rm autapse failure case from test and add hotfix for cumsum_nbranches
  • Loading branch information
jnsbck authored Oct 23, 2024
1 parent d73202b commit 26acd32
Show file tree
Hide file tree
Showing 24 changed files with 1,508 additions and 1,729 deletions.
110 changes: 49 additions & 61 deletions jaxley/connect.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,26 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Tuple

import numpy as np


def get_pre_post_inds(
pre_cell_view: "CellView", post_cell_view: "CellView"
) -> Tuple[np.ndarray, np.ndarray]:
"""Get the unique cell indices of the pre- and postsynaptic cells."""
pre_cell_inds = np.unique(pre_cell_view.view["cell_index"].to_numpy())
post_cell_inds = np.unique(post_cell_view.view["cell_index"].to_numpy())
return pre_cell_inds, post_cell_inds


def pre_comp_not_equal_post_comp(
pre: "CompartmentView", post: "CompartmentView"
) -> np.ndarray[bool]:
"""Check if pre and post compartments are different."""
cols = ["cell_index", "branch_index", "comp_index"]
return np.any(pre.view[cols].values != post.view[cols].values, axis=1)


def is_same_network(pre: "View", post: "View") -> bool:
"""Check if views are from the same network."""
is_in_net = "network" in pre.pointer.__class__.__name__.lower()
is_in_same_net = pre.pointer is post.pointer
is_in_net = "network" in pre.base.__class__.__name__.lower()
is_in_same_net = pre.base is post.base
return is_in_net and is_in_same_net


def sample_comp(
cell_view: "CellView", cell_idx: int, num: int = 1, replace=True
) -> "CompartmentView":
def sample_comp(cell_view: "View", num: int = 1, replace=True) -> "CompartmentView":
"""Sample a compartment from a cell.
Returns View with shape (num, num_cols)."""
cell_idx_view = lambda view, cell_idx: view[view["cell_index"] == cell_idx]
return cell_idx_view(cell_view.view, cell_idx).sample(num, replace=replace)
return np.random.choice(cell_view._comps_in_view, num, replace=replace)


def connect(
pre: "CompartmentView",
post: "CompartmentView",
pre: "View",
post: "View",
synapse_type: "Synapse",
):
"""Connect two compartments with a chemical synapse.
Expand All @@ -58,16 +36,13 @@ def connect(
assert is_same_network(
pre, post
), "Pre and post compartments must be part of the same network."
assert np.all(
pre_comp_not_equal_post_comp(pre, post)
), "Pre and post compartments must be different."

pre._append_multiple_synapses(pre.view, post.view, synapse_type)
pre.base._append_multiple_synapses(pre.nodes, post.nodes, synapse_type)


def fully_connect(
pre_cell_view: "CellView",
post_cell_view: "CellView",
pre_cell_view: "View",
post_cell_view: "View",
synapse_type: "Synapse",
):
"""Appends multiple connections which build a fully connected layer.
Expand All @@ -80,29 +55,29 @@ def fully_connect(
synapse_type: The synapse to append.
"""
# Get pre- and postsynaptic cell indices.
pre_cell_inds, post_cell_inds = get_pre_post_inds(pre_cell_view, post_cell_view)
num_pre, num_post = len(pre_cell_inds), len(post_cell_inds)
num_pre = len(pre_cell_view._cells_in_view)
num_post = len(post_cell_view._cells_in_view)

# Infer indices of (random) postsynaptic compartments.
global_post_indices = (
post_cell_view.view.groupby("cell_index")
post_cell_view.nodes.groupby("global_cell_index")
.sample(num_pre, replace=True)
.index.to_numpy()
)
global_post_indices = global_post_indices.reshape((-1, num_pre), order="F").ravel()
post_rows = post_cell_view.view.loc[global_post_indices]
post_rows = post_cell_view.nodes.loc[global_post_indices]

# Pre-synapse is at the zero-eth branch and zero-eth compartment.
pre_rows = pre_cell_view[0, 0].view
pre_rows = pre_cell_view.scope("local").branch(0).comp(0).nodes.copy()
# Repeat rows `num_post` times. See SO 50788508.
pre_rows = pre_rows.loc[pre_rows.index.repeat(num_post)].reset_index(drop=True)

pre_cell_view._append_multiple_synapses(pre_rows, post_rows, synapse_type)
pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)


def sparse_connect(
pre_cell_view: "CellView",
post_cell_view: "CellView",
pre_cell_view: "View",
post_cell_view: "View",
synapse_type: "Synapse",
p: float,
):
Expand All @@ -117,8 +92,10 @@ def sparse_connect(
p: Probability of connection.
"""
# Get pre- and postsynaptic cell indices.
pre_cell_inds, post_cell_inds = get_pre_post_inds(pre_cell_view, post_cell_view)
num_pre, num_post = len(pre_cell_inds), len(post_cell_inds)
pre_cell_inds = pre_cell_view._cells_in_view
post_cell_inds = post_cell_view._cells_in_view
num_pre = len(pre_cell_inds)
num_post = len(post_cell_inds)

num_connections = np.random.binomial(num_pre * num_post, p)
pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections)
Expand All @@ -131,20 +108,25 @@ def sparse_connect(

# Post-synapse is a randomly chosen branch and compartment.
global_post_indices = [
sample_comp(post_cell_view, cell_idx).index[0] for cell_idx in post_syn_neurons
sample_comp(post_cell_view.scope("global").cell(cell_idx))
for cell_idx in post_syn_neurons
]
post_rows = post_cell_view.view.loc[global_post_indices]
global_post_indices = (
np.hstack(global_post_indices) if len(global_post_indices) > 1 else []
)
post_rows = post_cell_view.base.nodes.loc[global_post_indices]

# Pre-synapse is at the zero-eth branch and zero-eth compartment.
global_pre_indices = pre_cell_view.pointer._cumsum_nseg_per_cell[pre_syn_neurons]
pre_rows = pre_cell_view.view.loc[global_pre_indices]
global_pre_indices = pre_cell_view.base._cumsum_nseg_per_cell[pre_syn_neurons]
pre_rows = pre_cell_view.base.nodes.loc[global_pre_indices]

pre_cell_view._append_multiple_synapses(pre_rows, post_rows, synapse_type)
if len(pre_rows) > 0:
pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)


def connectivity_matrix_connect(
pre_cell_view: "CellView",
post_cell_view: "CellView",
pre_cell_view: "View",
post_cell_view: "View",
synapse_type: "Synapse",
connectivity_matrix: np.ndarray[bool],
):
Expand All @@ -161,11 +143,12 @@ def connectivity_matrix_connect(
connectivity_matrix: A boolean matrix indicating the connections between cells.
"""
# Get pre- and postsynaptic cell indices.
pre_cell_inds, post_cell_inds = get_pre_post_inds(pre_cell_view, post_cell_view)
pre_cell_inds = pre_cell_view._cells_in_view
post_cell_inds = post_cell_view._cells_in_view

assert connectivity_matrix.shape == (
pre_cell_view.shape[0],
post_cell_view.shape[0],
len(pre_cell_inds),
len(post_cell_inds),
), "Connectivity matrix must have shape (num_pre, num_post)."
assert connectivity_matrix.dtype == bool, "Connectivity matrix must be boolean."

Expand All @@ -175,13 +158,18 @@ def connectivity_matrix_connect(
post_cell_inds = post_cell_inds[to_idx]

# Sample random postsynaptic compartments (global comp indices).
global_post_indices = [
sample_comp(post_cell_view, cell_idx).index[0] for cell_idx in post_cell_inds
]
post_rows = post_cell_view.view.loc[global_post_indices]
global_post_indices = np.hstack(
[
sample_comp(post_cell_view.scope("global").cell(cell_idx))
for cell_idx in post_cell_inds
]
)
post_rows = post_cell_view.nodes.loc[global_post_indices]

# Pre-synapse is at the zero-eth branch and zero-eth compartment.
global_pre_indices = pre_cell_view.pointer._cumsum_nseg_per_cell[pre_cell_inds]
pre_rows = pre_cell_view.view.loc[global_pre_indices]
global_pre_indices = (
pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy()
) # setting scope ensure that this works indep of current scope
pre_rows = pre_cell_view.select(nodes=global_pre_indices[pre_cell_inds]).nodes

pre_cell_view._append_multiple_synapses(pre_rows, post_rows, synapse_type)
pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)
8 changes: 4 additions & 4 deletions jaxley/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,23 +75,23 @@ def integrate(
if data_stimuli is not None:
externals["i"] = jnp.concatenate([externals["i"], data_stimuli[1]])
external_inds["i"] = jnp.concatenate(
[external_inds["i"], data_stimuli[2].comp_index.to_numpy()]
[external_inds["i"], data_stimuli[2].global_comp_index.to_numpy()]
)
else:
externals["i"] = data_stimuli[1]
external_inds["i"] = data_stimuli[2].comp_index.to_numpy()
external_inds["i"] = data_stimuli[2].global_comp_index.to_numpy()

# If a clamp is inserted, add it to the external inputs.
if data_clamps is not None:
state_name, clamps, inds = data_clamps
if state_name in module.externals.keys():
externals[state_name] = jnp.concatenate([externals[state_name], clamps])
external_inds[state_name] = jnp.concatenate(
[external_inds[state_name], inds.comp_index.to_numpy()]
[external_inds[state_name], inds.global_comp_index.to_numpy()]
)
else:
externals[state_name] = clamps
external_inds[state_name] = inds.comp_index.to_numpy()
external_inds[state_name] = inds.global_comp_index.to_numpy()

if not externals.keys():
# No stimulus was inserted and no clamp was set.
Expand Down
Loading

0 comments on commit 26acd32

Please sign in to comment.