Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better viewing #447

Merged
merged 72 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
043d18c
wip: context manager view working and set also. at also works recursi…
jnsbck Sep 18, 2024
c65de28
wip: save wip adding required methods
jnsbck Sep 24, 2024
1bf3cae
wip: context manager view working and set also. at also works recursi…
jnsbck Sep 18, 2024
36186f5
wip: adding in more attrs
jnsbck Sep 24, 2024
921512f
wip: adding more attrs
jnsbck Sep 24, 2024
9192c3a
wip: started on vis and added remaining args
jnsbck Sep 25, 2024
be69b9d
wip: got sidetracked, added new plotting
jnsbck Sep 27, 2024
8e173fc
wip: save wip
jnsbck Sep 29, 2024
a67453b
wip: save wip
jnsbck Sep 30, 2024
77857d8
wip: save wip
jnsbck Oct 1, 2024
3b7f3c3
wip: save wip
jnsbck Oct 1, 2024
86d22a3
wip: connect added
jnsbck Oct 3, 2024
24901ab
wip: add trainables and fix connecting
jnsbck Oct 7, 2024
1cf2375
wip: add channel and synapse views
jnsbck Oct 9, 2024
6f83e62
mv: move view notebook
jnsbck Oct 9, 2024
9ddd43e
chore: restructure notebook. in prep for PR
jnsbck Oct 9, 2024
3bfe610
wip: save wip
jnsbck Oct 9, 2024
c5f2dca
wip: fixing bugs after rebase
jnsbck Oct 10, 2024
8836ee1
wip: PR ready. fixed make_trainables
jnsbck Oct 10, 2024
9c40a56
fix: ammend prev commit
jnsbck Oct 10, 2024
fbc6046
wip: in process of transferring proof of concept to jaxley/modules
jnsbck Oct 15, 2024
accf287
rm: rm old base from tracking
jnsbck Oct 15, 2024
ca9d7b6
fix: small fixes
jnsbck Oct 15, 2024
baba158
fix: save wip
jnsbck Oct 15, 2024
55d914c
wip: get methods from jaxley
jnsbck Oct 15, 2024
a559caf
fix: add latest changes from main
jnsbck Oct 15, 2024
c8420ea
wip: transferring proof of concept after rebase
jnsbck Oct 15, 2024
fa268fe
fix: add branch_inds_in_view back in
jnsbck Oct 15, 2024
19b2ea2
wip: all modules can be initialized
jnsbck Oct 15, 2024
eae31eb
rm: remove old views
jnsbck Oct 15, 2024
5e9d426
wip: new view now working for a lot of things
jnsbck Oct 15, 2024
9121acc
fix: all basic new view functions now work in module
jnsbck Oct 16, 2024
4e8f842
fix: small fixes. xyzr working for view, trainables and reorder updat…
jnsbck Oct 16, 2024
96cf668
fix: plot tests pass
jnsbck Oct 16, 2024
626b5c3
rm: rm new_view notebook
jnsbck Oct 16, 2024
3ab3e6b
fix: fix connect
jnsbck Oct 16, 2024
c1daf00
add: add edge functionality. More tests passing now
jnsbck Oct 16, 2024
c361193
enh/fix: better local indexing. better edge. small fixes to tests
jnsbck Oct 17, 2024
06cb57a
fix: fixes for edge view
jnsbck Oct 17, 2024
8dd3b33
fix: make more tests pass
jnsbck Oct 17, 2024
a7cb956
enh: small improvements to lazy indexing, iteration
jnsbck Oct 17, 2024
9efcef7
fix: refactor edge viewing
jnsbck Oct 17, 2024
7a49969
fix: make set_ncomp work
jnsbck Oct 17, 2024
e81c48a
fix: make set_ncomps work
jnsbck Oct 17, 2024
a688d64
fix: move_to tests pass
jnsbck Oct 17, 2024
1aa9c66
fix: more tests pass (groups)
jnsbck Oct 17, 2024
765286c
fix: connect tests passing
jnsbck Oct 17, 2024
5a72564
fix: even more tests passing now
jnsbck Oct 18, 2024
0d3754e
add: add distance method
jnsbck Oct 18, 2024
c00ff63
fix: roll back fix for making synapses pass make_trainable since it c…
jnsbck Oct 18, 2024
b829f3b
fix: kinda fix indexing tests by commenting out relevant sections and…
jnsbck Oct 18, 2024
80a0ef1
fix: fixed issues when recording synapse states
jnsbck Oct 18, 2024
9350ba3
fix: add fix mimicing old synapse viewing in get_all_parameters. all …
jnsbck Oct 18, 2024
3365b10
fix: ALL TESTS PASSING!!! YAY modified group test to make pass, since…
jnsbck Oct 18, 2024
522f771
doc: add documentation to view and new methods of Module
jnsbck Oct 20, 2024
efac669
doc: add more documentation
jnsbck Oct 21, 2024
c525605
doc: add more documentation
jnsbck Oct 21, 2024
83f5467
fix: add not_implemented to copy
jnsbck Oct 21, 2024
b1677b2
enh: add jaxnodes to View and ideas for new tests
jnsbck Oct 21, 2024
aafa0f2
fix: fix failing tests from prev commit
jnsbck Oct 21, 2024
2130c91
fix: ran black - ammend prev commit
jnsbck Oct 21, 2024
ad3cf0e
fix: fix some of the things listed in https://github.com/jaxleyverse/…
jnsbck Oct 21, 2024
4f2463a
fix: fix issues with jaxedges in View and trainables for synapses in …
jnsbck Oct 22, 2024
3246f6f
fix: fix failing tests
jnsbck Oct 22, 2024
9d87924
fix: address comments
jnsbck Oct 22, 2024
7f9bd2e
fix: allow autapses
jnsbck Oct 22, 2024
e91991a
fix: fix wrong kwarg in View
jnsbck Oct 22, 2024
2b58922
fix: fix type diffs between module and view. Add test for this
jnsbck Oct 22, 2024
6f9f09c
fix: fix printing trainables issue
jnsbck Oct 22, 2024
447c2e6
fix: rename view args to fix failing tests
jnsbck Oct 22, 2024
3f84730
fix: fix type of cumsum_nseg and rename filter to select
jnsbck Oct 22, 2024
9dd73ba
fix: rm autapse failure case from test and add hotfix for cumsum_nbra…
jnsbck Oct 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 49 additions & 61 deletions jaxley/connect.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,26 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Tuple

import numpy as np


def get_pre_post_inds(
pre_cell_view: "CellView", post_cell_view: "CellView"
) -> Tuple[np.ndarray, np.ndarray]:
"""Get the unique cell indices of the pre- and postsynaptic cells."""
pre_cell_inds = np.unique(pre_cell_view.view["cell_index"].to_numpy())
post_cell_inds = np.unique(post_cell_view.view["cell_index"].to_numpy())
return pre_cell_inds, post_cell_inds


def pre_comp_not_equal_post_comp(
pre: "CompartmentView", post: "CompartmentView"
) -> np.ndarray[bool]:
"""Check if pre and post compartments are different."""
cols = ["cell_index", "branch_index", "comp_index"]
return np.any(pre.view[cols].values != post.view[cols].values, axis=1)


def is_same_network(pre: "View", post: "View") -> bool:
"""Check if views are from the same network."""
is_in_net = "network" in pre.pointer.__class__.__name__.lower()
is_in_same_net = pre.pointer is post.pointer
is_in_net = "network" in pre.base.__class__.__name__.lower()
is_in_same_net = pre.base is post.base
return is_in_net and is_in_same_net


def sample_comp(
cell_view: "CellView", cell_idx: int, num: int = 1, replace=True
) -> "CompartmentView":
def sample_comp(cell_view: "View", num: int = 1, replace=True) -> "CompartmentView":
"""Sample a compartment from a cell.

Returns View with shape (num, num_cols)."""
cell_idx_view = lambda view, cell_idx: view[view["cell_index"] == cell_idx]
return cell_idx_view(cell_view.view, cell_idx).sample(num, replace=replace)
return np.random.choice(cell_view._comps_in_view, num, replace=replace)


def connect(
pre: "CompartmentView",
post: "CompartmentView",
pre: "View",
post: "View",
synapse_type: "Synapse",
):
"""Connect two compartments with a chemical synapse.
Expand All @@ -58,16 +36,13 @@ def connect(
assert is_same_network(
pre, post
), "Pre and post compartments must be part of the same network."
assert np.all(
pre_comp_not_equal_post_comp(pre, post)
), "Pre and post compartments must be different."

pre._append_multiple_synapses(pre.view, post.view, synapse_type)
pre.base._append_multiple_synapses(pre.nodes, post.nodes, synapse_type)


def fully_connect(
pre_cell_view: "CellView",
post_cell_view: "CellView",
pre_cell_view: "View",
post_cell_view: "View",
synapse_type: "Synapse",
):
"""Appends multiple connections which build a fully connected layer.
Expand All @@ -80,29 +55,29 @@ def fully_connect(
synapse_type: The synapse to append.
"""
# Get pre- and postsynaptic cell indices.
pre_cell_inds, post_cell_inds = get_pre_post_inds(pre_cell_view, post_cell_view)
num_pre, num_post = len(pre_cell_inds), len(post_cell_inds)
num_pre = len(pre_cell_view._cells_in_view)
num_post = len(post_cell_view._cells_in_view)

# Infer indices of (random) postsynaptic compartments.
global_post_indices = (
post_cell_view.view.groupby("cell_index")
post_cell_view.nodes.groupby("global_cell_index")
.sample(num_pre, replace=True)
.index.to_numpy()
)
global_post_indices = global_post_indices.reshape((-1, num_pre), order="F").ravel()
post_rows = post_cell_view.view.loc[global_post_indices]
post_rows = post_cell_view.nodes.loc[global_post_indices]

# Pre-synapse is at the zero-eth branch and zero-eth compartment.
pre_rows = pre_cell_view[0, 0].view
pre_rows = pre_cell_view.scope("local").branch(0).comp(0).nodes.copy()
jnsbck marked this conversation as resolved.
Show resolved Hide resolved
# Repeat rows `num_post` times. See SO 50788508.
pre_rows = pre_rows.loc[pre_rows.index.repeat(num_post)].reset_index(drop=True)

pre_cell_view._append_multiple_synapses(pre_rows, post_rows, synapse_type)
pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)


def sparse_connect(
pre_cell_view: "CellView",
post_cell_view: "CellView",
pre_cell_view: "View",
post_cell_view: "View",
synapse_type: "Synapse",
p: float,
):
Expand All @@ -117,8 +92,10 @@ def sparse_connect(
p: Probability of connection.
"""
# Get pre- and postsynaptic cell indices.
pre_cell_inds, post_cell_inds = get_pre_post_inds(pre_cell_view, post_cell_view)
num_pre, num_post = len(pre_cell_inds), len(post_cell_inds)
pre_cell_inds = pre_cell_view._cells_in_view
post_cell_inds = post_cell_view._cells_in_view
num_pre = len(pre_cell_inds)
num_post = len(post_cell_inds)

num_connections = np.random.binomial(num_pre * num_post, p)
pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections)
Expand All @@ -131,20 +108,25 @@ def sparse_connect(

# Post-synapse is a randomly chosen branch and compartment.
global_post_indices = [
sample_comp(post_cell_view, cell_idx).index[0] for cell_idx in post_syn_neurons
sample_comp(post_cell_view.scope("global").cell(cell_idx))
for cell_idx in post_syn_neurons
]
post_rows = post_cell_view.view.loc[global_post_indices]
global_post_indices = (
np.hstack(global_post_indices) if len(global_post_indices) > 1 else []
)
post_rows = post_cell_view.base.nodes.loc[global_post_indices]

# Pre-synapse is at the zero-eth branch and zero-eth compartment.
global_pre_indices = pre_cell_view.pointer._cumsum_nseg_per_cell[pre_syn_neurons]
pre_rows = pre_cell_view.view.loc[global_pre_indices]
global_pre_indices = pre_cell_view.base._cumsum_nseg_per_cell[pre_syn_neurons]
pre_rows = pre_cell_view.base.nodes.loc[global_pre_indices]

pre_cell_view._append_multiple_synapses(pre_rows, post_rows, synapse_type)
if len(pre_rows) > 0:
pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)


def connectivity_matrix_connect(
pre_cell_view: "CellView",
post_cell_view: "CellView",
pre_cell_view: "View",
post_cell_view: "View",
synapse_type: "Synapse",
connectivity_matrix: np.ndarray[bool],
):
Expand All @@ -161,11 +143,12 @@ def connectivity_matrix_connect(
connectivity_matrix: A boolean matrix indicating the connections between cells.
"""
# Get pre- and postsynaptic cell indices.
pre_cell_inds, post_cell_inds = get_pre_post_inds(pre_cell_view, post_cell_view)
pre_cell_inds = pre_cell_view._cells_in_view
post_cell_inds = post_cell_view._cells_in_view

assert connectivity_matrix.shape == (
pre_cell_view.shape[0],
post_cell_view.shape[0],
len(pre_cell_inds),
len(post_cell_inds),
), "Connectivity matrix must have shape (num_pre, num_post)."
assert connectivity_matrix.dtype == bool, "Connectivity matrix must be boolean."

Expand All @@ -175,13 +158,18 @@ def connectivity_matrix_connect(
post_cell_inds = post_cell_inds[to_idx]

# Sample random postsynaptic compartments (global comp indices).
global_post_indices = [
sample_comp(post_cell_view, cell_idx).index[0] for cell_idx in post_cell_inds
]
post_rows = post_cell_view.view.loc[global_post_indices]
global_post_indices = np.hstack(
[
sample_comp(post_cell_view.scope("global").cell(cell_idx))
for cell_idx in post_cell_inds
]
)
post_rows = post_cell_view.nodes.loc[global_post_indices]

# Pre-synapse is at the zero-eth branch and zero-eth compartment.
global_pre_indices = pre_cell_view.pointer._cumsum_nseg_per_cell[pre_cell_inds]
pre_rows = pre_cell_view.view.loc[global_pre_indices]
global_pre_indices = (
pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy()
) # setting scope ensure that this works indep of current scope
pre_rows = pre_cell_view.select(nodes=global_pre_indices[pre_cell_inds]).nodes

pre_cell_view._append_multiple_synapses(pre_rows, post_rows, synapse_type)
pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)
8 changes: 4 additions & 4 deletions jaxley/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,23 +75,23 @@ def integrate(
if data_stimuli is not None:
externals["i"] = jnp.concatenate([externals["i"], data_stimuli[1]])
external_inds["i"] = jnp.concatenate(
[external_inds["i"], data_stimuli[2].comp_index.to_numpy()]
[external_inds["i"], data_stimuli[2].global_comp_index.to_numpy()]
)
else:
externals["i"] = data_stimuli[1]
external_inds["i"] = data_stimuli[2].comp_index.to_numpy()
external_inds["i"] = data_stimuli[2].global_comp_index.to_numpy()

# If a clamp is inserted, add it to the external inputs.
if data_clamps is not None:
state_name, clamps, inds = data_clamps
if state_name in module.externals.keys():
externals[state_name] = jnp.concatenate([externals[state_name], clamps])
external_inds[state_name] = jnp.concatenate(
[external_inds[state_name], inds.comp_index.to_numpy()]
[external_inds[state_name], inds.global_comp_index.to_numpy()]
)
else:
externals[state_name] = clamps
external_inds[state_name] = inds.comp_index.to_numpy()
external_inds[state_name] = inds.global_comp_index.to_numpy()

if not externals.keys():
# No stimulus was inserted and no clamp was set.
Expand Down
Loading
Loading