Skip to content

Commit

Permalink
fix: fix synapse_terminals (#545)
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck authored Dec 5, 2024
1 parent 36bab7b commit 79f311c
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,29 +428,28 @@ def vis(
color: str = "k",
synapse_color: str = "b",
dims: Tuple[int] = (0, 1),
type: str = "line",
cell_plot_kwargs: Dict = {},
synapse_plot_kwargs: Dict = {},
synapse_scatter_kwargs: Dict = {},
**kwargs, # absorb add. kwargs, i.e. to enable net.cell(0).vis(type="line")
) -> Axes:
"""Visualize the module.
Args:
detail: Either of [point, full]. `point` visualizes every neuron in the
network as a dot.
`full` plots the full morphology of every neuron. It requires that
`compute_xyz()` has been run and allows for indivual neurons to be
moved with `.move()`.
color: The color in which cells are plotted. Only takes effect if
`detail='full'`.
type: Either `line` or `scatter`. Only takes effect if `detail='full'`.
synapse_color: The color in which synapses are plotted. Only takes effect if
`detail='full'`.
`compute_xyz()` has been run.
color: The color in which cells are plotted.
synapse_color: The color in which synapses are plotted.
dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of
two of them.
cell_plot_kwargs: Keyword arguments passed to the plotting function for
cell morphologies. Only takes effect for `detail='full'`.
synapse_plot_kwargs: Keyword arguments passed to the plotting function for
syanpses. Only takes effect for `detail='full'`.
syanpses.
synapse_scatter_kwargs: Keyword arguments passed to the scatter function for
syanpse terminals.
"""
xyz0 = self.cell(0).xyzr[0][:, :3]
same_xyz = np.all([np.all(xyz0 == cell.xyzr[0][:, :3]) for cell in self.cells])
Expand All @@ -472,9 +471,7 @@ def vis(
pos = cell_to_point_xyz(cell)[dims_np]
ax.scatter(*pos, color=color, **cell_plot_kwargs)
elif detail == "full":
ax = super().vis(
dims=dims, color=color, ax=ax, type=type, **cell_plot_kwargs
)
ax = super().vis(dims=dims, color=color, ax=ax, **cell_plot_kwargs)
else:
raise ValueError("detail must be in {full, point}.")

Expand All @@ -485,7 +482,7 @@ def vis(
loc, comp = edge[[prepost + "_locs", prepost + "_global_comp_index"]]
branch = nodes.loc[comp, "global_branch_index"]
cell = nodes.loc[comp, "global_cell_index"]
branch_xyz = self.xyzr[branch]
branch_xyz = self.xyzr[branch][:, :3]

xyz_loc = branch_xyz
if detail == "point":
Expand All @@ -501,8 +498,10 @@ def vis(

prepost_locs.append(xyz_loc)
prepost_locs = np.stack(prepost_locs).T

ax.plot(*prepost_locs[dims_np], color=synapse_color, **synapse_plot_kwargs)
ax.scatter(
*prepost_locs[dims_np, 1], color=synapse_color, **synapse_scatter_kwargs
)

return ax

Expand Down

0 comments on commit 79f311c

Please sign in to comment.