Skip to content

Commit

Permalink
fix sorting
Browse files Browse the repository at this point in the history
  • Loading branch information
ntolley committed Dec 20, 2024
1 parent 1aa2523 commit 9595a21
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions jaxley/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,12 @@ def integrate(
if module.recordings.empty:
raise ValueError("No recordings are set. Please set them.")
recording_df = module.recordings.reset_index(drop=True)
rec_states, rec_inds, sort_inds = list(), list(), list()
rec_states, rec_inds, group_inds = list(), list(), list()
for state, df_group in recording_df.groupby("state"):
rec_states.append(state)
rec_inds.append(df_group.rec_index.to_numpy())
sort_inds.extend(df_group.index.to_list())
group_inds.extend(df_group.index.to_list())
sort_inds = jnp.argsort(jnp.asarray(group_inds))

# Shorten or pad stimulus depending on `t_max`.
if t_max is not None:
Expand Down Expand Up @@ -315,5 +316,5 @@ def _body_fun(state, externals):
nested_lengths=checkpoint_lengths,
)
recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T
# recs = recs[sort_inds, :] # Sort recordings back to order that was set by user.
recs = recs[sort_inds, :] # Sort recordings back to order that was set by user.
return (recs, all_states) if return_states else recs

0 comments on commit 9595a21

Please sign in to comment.