From 6909167bb3d637166278224072361f9313e89f06 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 22 Nov 2023 12:28:19 +0100 Subject: [PATCH] Bugfix: composing compartments had wrong order --- jaxley/modules/branch.py | 3 + tests/test_composability_of_modules.py | 91 ++++++++++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 tests/test_composability_of_modules.py diff --git a/jaxley/modules/branch.py b/jaxley/modules/branch.py index f5b85fa3..92bc25a6 100644 --- a/jaxley/modules/branch.py +++ b/jaxley/modules/branch.py @@ -31,6 +31,9 @@ def __init__( compartment_list = [compartments for _ in range(nseg)] else: compartment_list = compartments + # Compartments are currently defined in reverse. See also #30. This `.reverse` + # is needed to make `tests/test_composability_of_modules.py` pass. + compartment_list.reverse() self._append_to_params_and_state(compartment_list) for comp in compartment_list: diff --git a/tests/test_composability_of_modules.py b/tests/test_composability_of_modules.py new file mode 100644 index 00000000..36c33b89 --- /dev/null +++ b/tests/test_composability_of_modules.py @@ -0,0 +1,91 @@ +import jax + +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") + +import jax.numpy as jnp + +import jaxley as jx +from jaxley.channels import HHChannel + + +def test_compose_branch(): + """Test inserting to comp and composing to branch equals inserting to branch.""" + dt = 0.025 + t_max = 3.0 + current = jx.step_current(1.0, 1.0, 0.1, dt, t_max) + + comp1 = jx.Compartment() + comp1.insert(HHChannel()) + comp2 = jx.Compartment() + branch1 = jx.Branch([comp1, comp2]) + branch1.comp(0.0).record() + branch1.comp(0.0).stimulate(current) + + comp = jx.Compartment() + branch2 = jx.Branch(comp, nseg=2) + branch2.comp(0.0).insert(HHChannel()) + branch2.comp(0.0).record() + branch2.comp(0.0).stimulate(current) + + voltages1 = jx.integrate(branch1, delta_t=dt) + voltages2 = jx.integrate(branch2, delta_t=dt) + + assert jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8 + + +def test_compose_cell(): + """Test inserting to branch and composing to cell equals inserting to cell.""" + nseg_per_branch = 4 + dt = 0.025 + t_max = 3.0 + current = jx.step_current(1.0, 1.0, 0.1, dt, t_max) + + comp = jx.Compartment() + + branch1 = jx.Branch(comp, nseg_per_branch) + branch1.insert(HHChannel()) + branch2 = jx.Branch(comp, nseg_per_branch) + cell1 = jx.Cell([branch1, branch2], parents=[-1, 0]) + cell1.branch(0).comp(0.0).record() + cell1.branch(0).comp(0.0).stimulate(current) + + branch = jx.Branch(comp, nseg_per_branch) + cell2 = jx.Cell(branch, parents=[-1, 0]) + cell2.branch(0).insert(HHChannel()) + cell2.branch(0).comp(0.0).record() + cell2.branch(0).comp(0.0).stimulate(current) + + voltages1 = jx.integrate(cell1, delta_t=dt) + voltages2 = jx.integrate(cell2, delta_t=dt) + + assert jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8 + + +def test_compose_net(): + """Test inserting to cell and composing to net equals inserting to net.""" + nseg_per_branch = 4 + dt = 0.025 + t_max = 3.0 + current = jx.step_current(1.0, 1.0, 0.1, dt, t_max) + + comp = jx.Compartment() + branch = jx.Branch(comp, nseg_per_branch) + + cell1 = jx.Cell(branch, parents=[-1, 0, 0]) + cell1.insert(HHChannel()) + cell2 = jx.Cell(branch, parents=[-1, 0, 0]) + net1 = jx.Network([cell1, cell2], []) + net1.cell(0).branch(0).comp(0.0).record() + net1.cell(0).branch(0).comp(0.0).stimulate(current) + + cell = jx.Cell(branch, parents=[-1, 0, 0]) + net2 = jx.Network([cell, cell], []) + net2.cell(0).insert(HHChannel()) + net2.cell(0).branch(0).comp(0.0).record() + net2.cell(0).branch(0).comp(0.0).stimulate(current) + + voltages1 = jx.integrate(net1, delta_t=dt) + voltages2 = jx.integrate(net2, delta_t=dt) + + assert jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8