diff --git a/jaxley/integrate.py b/jaxley/integrate.py index 06f0454a..f8c5b1a2 100644 --- a/jaxley/integrate.py +++ b/jaxley/integrate.py @@ -232,11 +232,12 @@ def integrate( if module.recordings.empty: raise ValueError("No recordings are set. Please set them.") recording_df = module.recordings.reset_index(drop=True) - rec_states, rec_inds, sort_inds = list(), list(), list() + rec_states, rec_inds, group_inds = list(), list(), list() for state, df_group in recording_df.groupby("state"): rec_states.append(state) rec_inds.append(df_group.rec_index.to_numpy()) - sort_inds.extend(df_group.index.to_list()) + group_inds.extend(df_group.index.to_list()) + sort_inds = jnp.argsort(jnp.asarray(group_inds)) # Shorten or pad stimulus depending on `t_max`. if t_max is not None: @@ -315,5 +316,5 @@ def _body_fun(state, externals): nested_lengths=checkpoint_lengths, ) recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T - # recs = recs[sort_inds, :] # Sort recordings back to order that was set by user. + recs = recs[sort_inds, :] # Sort recordings back to order that was set by user. return (recs, all_states) if return_states else recs