Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simpler edge indexing #487

Open
wants to merge 26 commits into
base: v1_0_release
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
dbb1b34
fix: v1 of new get_all_parameters and to_jax
jnsbck Oct 29, 2024
d5a7d03
enh: simplified and refactored steping currents in synapses and chann…
jnsbck Nov 7, 2024
4107a46
fix: cleanup
jnsbck Nov 7, 2024
050debc
fix: ran isort
jnsbck Nov 7, 2024
3ffda84
wip: rename channel and synapse params and enh to_jax
jnsbck Dec 4, 2024
f63b6ee
wip: make get_all_params work with new indexing
jnsbck Dec 5, 2024
b9dd411
wip: more tests passing, small refactor
jnsbck Dec 5, 2024
2bf99a8
wip: more tests passing some fixes
jnsbck Dec 5, 2024
0eaea44
wip: new lookup table added
jnsbck Dec 6, 2024
4b7395e
wip: more fixes
jnsbck Dec 6, 2024
150cf64
wip: save wip, bug hunting in _synapse_current voltages
jnsbck Dec 6, 2024
f1b0e1c
fix: fixed indexing
jnsbck Dec 7, 2024
0524151
fix: fix remaining indexing issues, tests passing (I think)
jnsbck Dec 7, 2024
8820561
wip: wip fixing multiple mechs with same param / state
jnsbck Dec 9, 2024
f139f15
wip: more refactoring in light of recent discussion about new channel…
jnsbck Dec 13, 2024
44c4a80
fix: fix jitting issues of to_jax!
jnsbck Dec 13, 2024
cd4f664
fix: all tests finally passing
jnsbck Dec 16, 2024
8bbc6ba
fix: ammend last commit
jnsbck Dec 16, 2024
5e31be8
fix: small fixes and comments added
jnsbck Dec 16, 2024
cc23b2b
fix: move some things around
jnsbck Dec 16, 2024
efee504
doc: add documentation
jnsbck Dec 16, 2024
b760e90
fix: fix param sharing
jnsbck Dec 17, 2024
62351b8
doc: update changelog
jnsbck Dec 23, 2024
983817c
fix: refactor of shared states and got rid of prepare_for_jax and oth…
jnsbck Jan 11, 2025
6f7f389
fix: major refactor of jaxnodes and fix regression tests.
jnsbck Jan 13, 2025
a204f63
fix: ammend prev commit
jnsbck Jan 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ net.vis()

- changelog added to CI (#537, #558, @jnsbck)

- Refactor of channel and synapse stepping internals and how the model is transferred to jax for more efficient and readable code (#487, @jnsbck).

# 0.5.0

### API changes
Expand Down
8 changes: 4 additions & 4 deletions jaxley/channels/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,22 @@ def change_name(self, new_name: str):
new_prefix = new_name + "_"

self._name = new_name
self.channel_params = {
self.params = {
(
new_prefix + key[len(old_prefix) :]
if key.startswith(old_prefix)
else key
): value
for key, value in self.channel_params.items()
for key, value in self.params.items()
}

self.channel_states = {
self.states = {
(
new_prefix + key[len(old_prefix) :]
if key.startswith(old_prefix)
else key
): value
for key, value in self.channel_states.items()
for key, value in self.states.items()
}
return self

Expand Down
4 changes: 2 additions & 2 deletions jaxley/channels/hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gNa": 0.12,
f"{prefix}_gK": 0.036,
f"{prefix}_gLeak": 0.0003,
f"{prefix}_eNa": 50.0,
f"{prefix}_eK": -77.0,
f"{prefix}_eLeak": -54.3,
}
self.channel_states = {
self.states = {
f"{prefix}_m": 0.2,
f"{prefix}_h": 0.2,
f"{prefix}_n": 0.2,
Expand Down
24 changes: 12 additions & 12 deletions jaxley/channels/pospischil.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gLeak": 1e-4,
f"{prefix}_eLeak": -70.0,
}
self.channel_states = {}
self.states = {}
self.current_name = f"i_{prefix}"

def update_states(
Expand Down Expand Up @@ -77,12 +77,12 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gNa": 50e-3,
"eNa": 50.0,
"vt": -60.0, # Global parameter, not prefixed with `Na`.
}
self.channel_states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2}
self.states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2}
self.current_name = f"i_Na"

def update_states(
Expand Down Expand Up @@ -148,12 +148,12 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gK": 5e-3,
"eK": -90.0,
"vt": -60.0, # Global parameter, not prefixed with `Na`.
}
self.channel_states = {f"{prefix}_n": 0.2}
self.states = {f"{prefix}_n": 0.2}
self.current_name = f"i_K"

def update_states(
Expand Down Expand Up @@ -204,12 +204,12 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gKm": 0.004e-3,
f"{prefix}_taumax": 4000.0,
f"eK": -90.0,
}
self.channel_states = {f"{prefix}_p": 0.2}
self.states = {f"{prefix}_p": 0.2}
self.current_name = f"i_K"

def update_states(
Expand Down Expand Up @@ -261,11 +261,11 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gCaL": 0.1e-3,
"eCa": 120.0,
}
self.channel_states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2}
self.states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2}
self.current_name = f"i_Ca"

def update_states(
Expand Down Expand Up @@ -329,12 +329,12 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gCaT": 0.4e-4,
f"{prefix}_vx": 2.0,
"eCa": 120.0, # Global parameter, not prefixed with `CaT`.
}
self.channel_states = {f"{prefix}_u": 0.2}
self.states = {f"{prefix}_u": 0.2}
self.current_name = f"i_Ca"

def update_states(
Expand Down
Loading