Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simpler edge indexing #487

Open
wants to merge 26 commits into
base: v1_0_release
Choose a base branch
from
Open

Simpler edge indexing #487

wants to merge 26 commits into from

Conversation

jnsbck
Copy link
Contributor

@jnsbck jnsbck commented Nov 7, 2024

Currently synapse and channel parameters / states are handled differently. While channel params are referred to with global indices, synapse params are referenced on a per synapse basis. This leads to very different implementations of synapse and channel updates. This PR makes the synapse indexing global, which simplifies several aspects of the code and allows for more function reuse.

This could potentially be simplified even further, such that channels and synapses can be handled through mostly the same functions.

@jnsbck jnsbck mentioned this pull request Nov 8, 2024
@jnsbck
Copy link
Contributor Author

jnsbck commented Nov 8, 2024

Copy link
Contributor

@michaeldeistler michaeldeistler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

I am a bit worried about memory consumption though. If every jaxedge has every state, then there will a lot of NaNs each of which contains memory. I think that, in the worst case, memory increases by N**2, where is the number of types of synapses.

Can we have a zoom about this? E.g. next week?

jaxley/modules/base.py Show resolved Hide resolved
jaxley/modules/network.py Show resolved Hide resolved
@jnsbck
Copy link
Contributor Author

jnsbck commented Nov 13, 2024

True! I did not think about this. I have two thoughts though. Why not do this for nodes on a per mechanism basis as well. And is there not a more straight forward way to avoid nans while still using global indexes?
But yes, maybe a quick zoom would he great :)

@michaeldeistler
Copy link
Contributor

I agree that we should also do it like this for nodes. Let's zoom next week!

@jnsbck jnsbck force-pushed the simpler_edge_indexing branch from f3ec50a to aa4ae5f Compare December 5, 2024 13:20
Comment on lines +764 to +774
def dtype_aware_concat(dfs):
concat_df = pd.concat(dfs, ignore_index=True)
# replace nans with Nones
# this correctly casts float(None) -> NaN, bool(None) -> NaN, etc.
concat_df[concat_df.isna()] = None
for col in concat_df.columns[concat_df.dtypes == "object"]:
for df in dfs:
if col in df.columns:
concat_df[col] = concat_df[col].astype(df[col].dtype)
break # first match is sufficient
return concat_df
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixes a bug, where dtype of bool columns was changed upon concatenation in:

comp1 = jx.Compartment()
comp2 = jx.Compartment()
comp2.insert(HH())
branch = jx.Branch([comp1, comp2)] #-> branch.nodes["HH"].dtype was `object` in this case

Comment on lines -148 to -150
self.synapse_param_names = []
self.synapse_state_names = []
self.synapse_names = []
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

obsolete

@@ -189,7 +185,9 @@ def __str__(self):

def __dir__(self):
base_dir = object.__dir__(self)
return sorted(base_dir + self.synapse_names + list(self.group_nodes.keys()))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

group_nodes was removed in #447.

@jnsbck jnsbck force-pushed the simpler_edge_indexing branch from 7b4e771 to b58c3dd Compare December 17, 2024 00:18
@jnsbck
Copy link
Contributor Author

jnsbck commented Dec 17, 2024

Main changes:

  • rename channel_params... synapse_states to params...states (allows to reuse stuff for both synapses and channels)
  • to_jax stores params/states on a per mechanism basis now -> no more NaNs. To do this some prep is needed in order to keep track of what goes where to find it again later. Hence I had to write a _prepare_for_jax method, that pre-computes indices and adds a lookup method for all states / params in the module. This should also make this easily adaptable if we add the global_param API change.
  • refactor of get_all_parameters and get_all_states
  • refactor of channel and synapse updates.

In general I think the code is a bit cleaner and more consolidated now (-150 line diff). Will also run regression tests. Maybe its even faster as well. I am not super happy about the _prepare_for_jax solution tbh, but it seems to work well fwiw. If you have feedback, I'd be very happy about any suggestions!

Lemme know what you think or if you have any questions.

The only thing I touched in terms of channels / synapse rewrite is the params/states rename. Does this warrant a seperate branch?

Copy link
Contributor

@michaeldeistler michaeldeistler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot @jnsbck!

I am already off for the holidays so will not be able to do a detailed review, but:

  1. I think all your high-level explanations sound good!
  2. Yes, let's make a new v1.0 branch and merge it there because it breaks all channel models.

And yes, please run regression tests and ensure that all is good! Good to go then! Thanks!

@jnsbck jnsbck force-pushed the simpler_edge_indexing branch from a0183be to e17f925 Compare December 22, 2024 23:37
@jnsbck
Copy link
Contributor Author

jnsbck commented Dec 22, 2024

/test_regression

Copy link
Contributor

github-actions bot commented Dec 22, 2024

Regression Test Results

✅ Process completed

test_runtime(num_cells=1, artificial=False, connect=False, connection_prob=0.0, voltage_solver=jaxley.stone)
🟢 build_time: (-42.62% vs 1.651s).
🟢 compile_time: (-3.17% vs 19.652s).
🟢 run_time: (-16.57% vs 2.891s).

test_runtime(num_cells=1, artificial=False, connect=False, connection_prob=0.0, voltage_solver=jax.sparse)
🟢 build_time: (-3.26% vs 0.313s).
🟢 compile_time: (-30.37% vs 3.523s).
🟢 run_time: (-16.81% vs 2.454s).

test_runtime(num_cells=10, artificial=False, connect=True, connection_prob=0.1, voltage_solver=jaxley.stone)
🟢 build_time: (-36.81% vs 3.856s).
🟢 compile_time: (-15.62% vs 29.145s).
🟢 run_time: (-19.73% vs 18.551s).

test_runtime(num_cells=10, artificial=False, connect=True, connection_prob=0.1, voltage_solver=jax.sparse)
🟢 build_time: (-31.42% vs 1.825s).
🟢 compile_time: (-22.40% vs 21.890s).
🟢 run_time: (-18.60% vs 20.018s).

test_runtime(num_cells=1000, artificial=True, connect=True, connection_prob=0.001, voltage_solver=jaxley.stone)
🟢 build_time: (-4.67% vs 113.937s).
🟢 compile_time: (-22.13% vs 45.211s).
🟢 run_time: (-21.45% vs 41.206s).

test_runtime(num_cells=1000, artificial=True, connect=True, connection_prob=0.001, voltage_solver=jax.sparse)
🟢 build_time: (-2.18% vs 108.073s).
🟢 compile_time: (-19.37% vs 49.108s).
🟢 run_time: (-17.58% vs 46.364s).

@jnsbck jnsbck force-pushed the simpler_edge_indexing branch from e17f925 to 62351b8 Compare December 23, 2024 01:04
@jnsbck jnsbck changed the base branch from main to v1_0_release December 23, 2024 01:06
@jnsbck
Copy link
Contributor Author

jnsbck commented Jan 11, 2025

/test_regression

Copy link
Contributor

github-actions bot commented Jan 11, 2025

Regression Test Results

❌ Process completed

test_runtime(num_cells=1, artificial=False, connect=False, connection_prob=0.0, voltage_solver=jaxley.stone)
🟢 build_time: (-48.74% vs 1.665s).
🟢 compile_time: (-0.98% vs 19.612s).
🟢 run_time: (-2.90% vs 2.847s).

test_runtime(num_cells=1, artificial=False, connect=False, connection_prob=0.0, voltage_solver=jax.sparse)
🟢 build_time: (-8.37% vs 0.310s).
🟢 compile_time: (-17.82% vs 3.442s).
🟢 run_time: (-3.14% vs 2.487s).

@jnsbck
Copy link
Contributor Author

jnsbck commented Jan 11, 2025

I have done some refactoring and am now much happier with how the code looks. Locally tests pass, but since this is not a pull request to main (but v1.0), tests are not being run here. I also started another regression tests (see above), which I hope will be as quick as prev, i.e. faster than whats in main. EDIT: This did not work. Looking into it.

Would still be great if you could have a more thorough look before I merge it, when your back in the office.
Also, this PR should work with shared states for synapses as well now, which we did not work prior and which we do not check. Lemme know if I should add a test

@jnsbck
Copy link
Contributor Author

jnsbck commented Jan 13, 2025

...more refactoring and consolidation imo. Somewhat forward thinking, now to_jax moves the jx.Module to a single pytree, and makes handling global params and states much simpler now. Regression test should also pass again.

/test_regression

Copy link
Contributor

github-actions bot commented Jan 13, 2025

Regression Test Results

❌ Process completed

test_runtime(num_cells=1, artificial=False, connect=False, connection_prob=0.0, voltage_solver=jaxley.stone)
🟢 build_time: (-44.27% vs 1.643s).
🔴 compile_time: (+3.26% vs 20.012s).
🟢 run_time: (-5.27% vs 2.930s).

test_runtime(num_cells=1, artificial=False, connect=False, connection_prob=0.0, voltage_solver=jax.sparse)
🟢 build_time: (-7.21% vs 0.312s).
🟢 compile_time: (-10.35% vs 3.490s).
🟢 run_time: (-1.19% vs 2.461s).

test_runtime(num_cells=10, artificial=False, connect=True, connection_prob=0.1, voltage_solver=jaxley.stone)
🟢 build_time: (-37.05% vs 3.802s).
🟢 compile_time: (-3.07% vs 29.811s).
🟢 run_time: (-1.75% vs 18.879s).

test_runtime(num_cells=10, artificial=False, connect=True, connection_prob=0.1, voltage_solver=jax.sparse)
🟢 build_time: (-48.92% vs 1.833s).
🟢 compile_time: (-1.50% vs 21.809s).
🟢 run_time: (-3.93% vs 20.784s).

test_runtime(num_cells=1000, artificial=True, connect=True, connection_prob=0.001, voltage_solver=jaxley.stone)
🟢 build_time: (-1.26% vs 107.721s).
🟢 compile_time: (-2.78% vs 45.691s).
🔴 run_time: (+0.24% vs 40.729s).

test_runtime(num_cells=1000, artificial=True, connect=True, connection_prob=0.001, voltage_solver=jax.sparse)
🟢 build_time: (-0.03% vs 104.756s).
🟢 compile_time: (-3.64% vs 50.023s).
🟢 run_time: (-1.50% vs 46.854s).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants