Skip to content

Commit

Permalink
Bugfix: composing compartments had wrong order
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 22, 2023
1 parent 203b80f commit 6909167
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 0 deletions.
3 changes: 3 additions & 0 deletions jaxley/modules/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def __init__(
compartment_list = [compartments for _ in range(nseg)]
else:
compartment_list = compartments
# Compartments are currently defined in reverse. See also #30. This `.reverse`
# is needed to make `tests/test_composability_of_modules.py` pass.
compartment_list.reverse()

self._append_to_params_and_state(compartment_list)
for comp in compartment_list:
Expand Down
91 changes: 91 additions & 0 deletions tests/test_composability_of_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp

import jaxley as jx
from jaxley.channels import HHChannel


def test_compose_branch():
"""Test inserting to comp and composing to branch equals inserting to branch."""
dt = 0.025
t_max = 3.0
current = jx.step_current(1.0, 1.0, 0.1, dt, t_max)

comp1 = jx.Compartment()
comp1.insert(HHChannel())
comp2 = jx.Compartment()
branch1 = jx.Branch([comp1, comp2])
branch1.comp(0.0).record()
branch1.comp(0.0).stimulate(current)

comp = jx.Compartment()
branch2 = jx.Branch(comp, nseg=2)
branch2.comp(0.0).insert(HHChannel())
branch2.comp(0.0).record()
branch2.comp(0.0).stimulate(current)

voltages1 = jx.integrate(branch1, delta_t=dt)
voltages2 = jx.integrate(branch2, delta_t=dt)

assert jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8


def test_compose_cell():
"""Test inserting to branch and composing to cell equals inserting to cell."""
nseg_per_branch = 4
dt = 0.025
t_max = 3.0
current = jx.step_current(1.0, 1.0, 0.1, dt, t_max)

comp = jx.Compartment()

branch1 = jx.Branch(comp, nseg_per_branch)
branch1.insert(HHChannel())
branch2 = jx.Branch(comp, nseg_per_branch)
cell1 = jx.Cell([branch1, branch2], parents=[-1, 0])
cell1.branch(0).comp(0.0).record()
cell1.branch(0).comp(0.0).stimulate(current)

branch = jx.Branch(comp, nseg_per_branch)
cell2 = jx.Cell(branch, parents=[-1, 0])
cell2.branch(0).insert(HHChannel())
cell2.branch(0).comp(0.0).record()
cell2.branch(0).comp(0.0).stimulate(current)

voltages1 = jx.integrate(cell1, delta_t=dt)
voltages2 = jx.integrate(cell2, delta_t=dt)

assert jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8


def test_compose_net():
"""Test inserting to cell and composing to net equals inserting to net."""
nseg_per_branch = 4
dt = 0.025
t_max = 3.0
current = jx.step_current(1.0, 1.0, 0.1, dt, t_max)

comp = jx.Compartment()
branch = jx.Branch(comp, nseg_per_branch)

cell1 = jx.Cell(branch, parents=[-1, 0, 0])
cell1.insert(HHChannel())
cell2 = jx.Cell(branch, parents=[-1, 0, 0])
net1 = jx.Network([cell1, cell2], [])
net1.cell(0).branch(0).comp(0.0).record()
net1.cell(0).branch(0).comp(0.0).stimulate(current)

cell = jx.Cell(branch, parents=[-1, 0, 0])
net2 = jx.Network([cell, cell], [])
net2.cell(0).insert(HHChannel())
net2.cell(0).branch(0).comp(0.0).record()
net2.cell(0).branch(0).comp(0.0).stimulate(current)

voltages1 = jx.integrate(net1, delta_t=dt)
voltages2 = jx.integrate(net2, delta_t=dt)

assert jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8

0 comments on commit 6909167

Please sign in to comment.