multiple parameter values with data_set() #430

Closed
kyralianaka opened this issue Sep 22, 2024 · 2 comments

@kyralianaka
Contributor

kyralianaka commented Sep 22, 2024

I've been having trouble using `data_set` to set the parameters of multiple synapses or cells at once, as shown in this small example:

```python
import jax.numpy as jnp
import jaxley as jx
from jaxley.synapses import IonotropicSynapse

comp = jx.Compartment()
branch = jx.Branch(comp, nseg=1)
cell = jx.Cell([branch], parents=[-1])
network = jx.Network([cell, cell, cell])

jx.fully_connect(network.cell("all"), network.cell("all"), IonotropicSynapse())

# One synaptic conductance value per edge
n_edges = len(network.edges)
params = [{"IonotropicSynapse_gS": jnp.ones(n_edges) * 0.001}]

# n_cells = len(network.nodes)
# params = [{"radius": jnp.ones(n_cells)}]

network.cell(0).record()

def simulate(params):

    # Set the parameters with data_set
    pstates = None
    pstates = network.IonotropicSynapse.data_set("IonotropicSynapse_gS", list(params[0].values())[0], pstates)
    #pstates = network.data_set("radius", list(params[0].values())[0], pstates)

    v = jx.integrate(network, t_max=10, param_state=pstates)
    return v

# Fails: the values and the indices end up with incompatible (transposed) shapes
s = simulate(params)
```

The problem is that the shape of the indices array is incompatible with the shape of the values array (one is the transpose of the other). According to the comments in Jaxley (`base.py`, lines 573-577), `inds` is supposed to have shape `(num_params, num_comps_per_param)`. In the example above, however, it becomes `(1, num_synapses)` or `(1, num_cells)`, while the values array is projected to `(num_params, 1)`, where `num_params` is the number of edges or cells.
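To make the mismatch concrete, here is a minimal plain-JAX sketch that mimics the shapes described above. It does not call Jaxley: `param`, `inds`, and `set_param` are hypothetical stand-ins for a per-edge parameter array, its indices, and the projected values, and 6 edges are assumed purely for illustration.

```python
import jax.numpy as jnp

# Hypothetical stand-in for one per-edge parameter array (e.g. IonotropicSynapse_gS);
# 6 edges are assumed purely for illustration.
num_edges = 6
param = jnp.zeros(num_edges)

# Indices shaped (num_params, num_comps_per_param) end up as (1, num_edges).
inds = jnp.arange(num_edges)[None, :]

# One value per edge, projected with `[:, None]`, ends up as (num_edges, 1).
vals = jnp.full(num_edges, 0.001)
set_param = vals[:, None]
print(inds.shape, set_param.shape)  # (1, 6) (6, 1)

try:
    param.at[inds].set(set_param)
    failed = False
except (ValueError, TypeError) as err:
    # The (6, 1) update cannot be broadcast into the (1, 6) indexed slice.
    failed = True
    print("shape mismatch:", err)

# A single value still works: (1,) -> (1, 1), which broadcasts to (1, 6).
single = jnp.atleast_1d(jnp.asarray(0.001))[:, None]
updated = param.at[inds].set(single)
```

This also shows why the scalar case works while the per-edge case does not: a `(1, 1)` update broadcasts to any `(1, n)` index shape, but an `(n, 1)` update only does so when `n == 1`.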

I have tried to work around this on the user side, and there may be a way I haven't found yet, but I think it would be nice if the code above just worked. I don't have time in the next two weeks, maybe later, but of course if someone else wants to take a look, please do :)

@kyralianaka
Contributor Author

I will do some more experimenting, but I think changing just the following two lines would allow specifying either a single value or multiple parameter values (one per compartment, which conveniently works out to one per cell when each cell has a single compartment). In `_data_set()`, line 396, change
`"val": jnp.atleast_1d(jnp.asarray(val)),` to `"val": jnp.atleast_2d(jnp.asarray(val)),`,
which is the same as what is done with the indices. Then, in `get_all_parameters()` (called by `integrate()`), line 577, change
`params[key] = params[key].at[inds].set(set_param[:, None])` to `params[key] = params[key].at[inds].set(set_param)`.
I think it is more intuitive if the values and the indices can have the same shape. I'll look into it a bit more, though.
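A minimal plain-JAX sketch of the shape logic behind the two proposed line changes, assuming (as in the example above) indices of shape `(1, num_edges)`. `set_values` is a hypothetical helper standing in for the modified lines; it is not Jaxley's actual code, and 6 edges are again assumed for illustration.

```python
import jax.numpy as jnp

num_edges = 6  # assumed for illustration
param = jnp.zeros(num_edges)
# Indices shaped (1, num_edges), as described in the issue.
inds = jnp.atleast_2d(jnp.arange(num_edges))


def set_values(param, inds, val):
    """Hypothetical helper mirroring the two proposed line changes."""
    # `_data_set`: keep the values 2-D (atleast_2d instead of atleast_1d).
    set_param = jnp.atleast_2d(jnp.asarray(val))
    # `get_all_parameters`: set directly, without the extra `[:, None]` axis.
    return param.at[inds].set(set_param)


# One value per edge: shape (1, 6) matches the (1, 6) indices.
per_edge_vals = jnp.linspace(0.001, 0.006, num_edges)
per_edge = set_values(param, inds, per_edge_vals)

# A single shared value: shape (1, 1) broadcasts to (1, 6).
shared = set_values(param, inds, 0.001)
```

The sketch only covers indices with a single row. When `inds` has several rows (one value per row, each row spanning several compartments), the original `[:, None]` projection is what lets a `(num_params,)` value array broadcast row by row, so that case would need to be checked separately.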

kyralianaka self-assigned this Oct 23, 2024
@kyralianaka
Contributor Author

This can now be done by using the `select` method from #447 to define the indices.
