Skip to content

Commit

Permalink
Plotting of comp morphology (#432)
Browse files Browse the repository at this point in the history
* add: add new plot type to show the compartment

* enh/doc: working version. add documentation. rewrite comp centers

* fix: incorporate feedback. Fix plotting of comps/views

* fix: allow to plot comps as volume

* add: add tests

* fix: fix typo
  • Loading branch information
jnsbck authored Oct 2, 2024
1 parent 4db0a8f commit 4fc75a1
Show file tree
Hide file tree
Showing 4 changed files with 276 additions and 16 deletions.
50 changes: 36 additions & 14 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from jaxley.utils.debug_solver import compute_morphology_indices
from jaxley.utils.misc_utils import childview, concat_and_ignore_empty
from jaxley.utils.plot_utils import plot_morph
from jaxley.utils.plot_utils import plot_comps, plot_morph
from jaxley.utils.solver_utils import convert_to_csc


Expand Down Expand Up @@ -115,23 +115,37 @@ def __init__(self):
self.debug_states = {}

def _update_nodes_with_xyz(self):
"""Add xyz coordinates to nodes."""
"""Add xyz coordinates of compartment centers to nodes.
Note: For sake of performance, interpolation is not done for each branch,
but once along a concatenated (and padded) array of all branches.
"""
num_branches = len(self.xyzr)
x = np.linspace(
0.5 / self.nseg,
(num_branches * 1 - 0.5 / self.nseg),
num_branches * self.nseg,
comp_ends = (
np.linspace(0, 1, self.nseg + 1).reshape(1, -1).repeat(num_branches, 0)
)
x += np.arange(num_branches).repeat(
self.nseg
) # add offset to prevent branch loc overlap
xp = np.hstack(
[np.linspace(0, 1, x.shape[0]) + 2 * i for i, x in enumerate(self.xyzr)]
comp_ends = comp_ends + 2 * np.arange(num_branches).reshape(
-1, 1
) # inter-branch padding
comp_ends = comp_ends.reshape(-1)
branch_lens = []
for i, xyzr in enumerate(self.xyzr):
branch_len = np.sqrt(
np.sum(np.diff(xyzr[:, :3], axis=0) ** 2, axis=1)
).cumsum()
branch_len = np.hstack([np.array([0]), branch_len])
branch_len = branch_len / branch_len.max() + 2 * i # add padding like above
branch_len[np.isnan(branch_len)] = 0
branch_lens.append(branch_len)
branch_lens = np.hstack(branch_lens)
xyz = np.vstack(self.xyzr)[:, :3]
xyz = v_interp(comp_ends, branch_lens, xyz).reshape(
3, num_branches, self.nseg + 1
)
xyz = v_interp(x, xp, np.vstack(self.xyzr)[:, :3])
centers = ((xyz[:, :, 1:] + xyz[:, :, :-1]) / 2).reshape(3, -1).T
idcs = self.nodes["comp_index"]
self.nodes.loc[idcs, ["x", "y", "z"]] = xyz.T
return xyz.T
self.nodes.loc[idcs, ["x", "y", "z"]] = centers
return centers, xyz

def __repr__(self):
return f"{type(self).__name__} with {len(self.channels)} different channels. Use `.show()` for details."
Expand Down Expand Up @@ -1226,6 +1240,12 @@ def _vis(
morph_plot_kwargs: Dict,
) -> Axes:
branches_inds = view["branch_index"].to_numpy()

if type == "volume":
return plot_comps(
self, view, dims=dims, ax=ax, col=col, **morph_plot_kwargs
)

coords = []
for branch_ind in branches_inds:
assert not np.any(
Expand Down Expand Up @@ -1692,6 +1712,8 @@ def vis(
Args:
ax: An axis into which to plot.
col: The color for all branches.
type: Whether to plot as points ("scatter"), a line ("line") or the
actual volume of the compartment("volume").
dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of
two of them.
morph_plot_kwargs: Keyword arguments passed to the plotting function.
Expand Down
21 changes: 21 additions & 0 deletions jaxley/modules/compartment.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,31 @@ def vis(
self,
ax: Optional[Axes] = None,
col: str = "k",
type: str = "scatter",
dims: Tuple[int] = (0, 1),
morph_plot_kwargs: Dict = {},
) -> Axes:
"""Visualize the compartment.
Args:
ax: An axis into which to plot.
col: The color for all branches.
type: Whether to plot as point ("scatter") or the projected volume ("volume").
dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of
two of them.
morph_plot_kwargs: Keyword arguments passed to the plotting function.
"""
nodes = self.set_global_index_and_index(self.view)
if type == "volume":
return self.pointer._vis(
ax=ax,
col=col,
dims=dims,
view=nodes,
type="volume",
morph_plot_kwargs=morph_plot_kwargs,
)

return self.pointer._scatter(
ax=ax,
col=col,
Expand Down
203 changes: 201 additions & 2 deletions jaxley/utils/plot_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from numpy import ndarray
from scipy.spatial import ConvexHull

from jaxley.utils.cell_utils import v_interp


def plot_morph(
Expand All @@ -14,7 +19,7 @@ def plot_morph(
ax: Optional[Axes] = None,
type: str = "line",
morph_plot_kwargs: Dict = {},
):
) -> Axes:
"""Plot morphology.
Args:
Expand All @@ -38,3 +43,197 @@ def plot_morph(
raise NotImplementedError

return ax


def extract_outline(points: ndarray) -> ndarray:
"""Get the outline of a 2D/3D shape.
Extracts the subset of points which form the convex hull, i.e. the outline of
the input points.
Args:
points: An array of points / corrdinates.
Returns:
An array of points which form the convex hull.
"""
hull = ConvexHull(points)
hull_points = points[hull.vertices]
return hull_points


def compute_rotation_matrix(axis: ndarray, angle: float) -> ndarray:
"""
Return the rotation matrix associated with counterclockwise rotation about
the given axis by the given angle.
Can be used to rotate a coordinate vector by multiplying it with the rotation
matrix.
Args:
axis: The axis of rotation.
angle: The angle of rotation in radians.
Returns:
A 3x3 rotation matrix.
"""
axis = axis / np.sqrt(np.dot(axis, axis))
a = np.cos(angle / 2.0)
b, c, d = -axis * np.sin(angle / 2.0)
aa, bb, cc, dd = a * a, b * b, c * c, d * d
bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
return np.array(
[
[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
[2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
[2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc],
]
)


def plot_cylinder_projection(
orientation: ndarray,
length: float,
radius: float,
center: ndarray,
dims: Tuple[int],
ax: Axes = None,
**kwargs,
) -> Axes:
"""Plot the 2D projection of a cylinder on a cardinal plane.
Project the projection of a cylinder that is oriented in 3D space.
- Create cylinder mesh
- rotate cylinder mesh to orient it lengthwise along a given orientation vector.
- move its center
- project onto plane
- compute outline of projected mesh.
- fill area inside the outline
Args:
orientation: orientation vector. The cylinder will be oriented along this vector.
length: The length of the cylinder.
radius: The radius of the cylinder.
center: The x,y,z coordinates of the center of the cylinder.
dims: The dimensions to project the cylinder onto, i.e. [0,1] xy-plane.
ax: The matplotlib axis to plot on.
Returns:
Plot of the cylinder projection.
"""
if ax is None:
fig = plt.figure(figsize=(3, 3))
ax = fig.add_subplot(111)

# Normalize axis vector
orientation = np.array(orientation)
orientation = orientation / np.linalg.norm(orientation)

# Create a rotation matrix to align the cylinder with the given axis
z_axis = np.array([0, 0, 1])
rotation_axis = np.cross(z_axis, orientation)
rotation_angle = np.arccos(np.dot(z_axis, orientation))

if np.allclose(rotation_axis, 0):
rotation_matrix = np.eye(3)
else:
rotation_matrix = compute_rotation_matrix(rotation_axis, rotation_angle)

# Define cylinder
resolution = 100
t = np.linspace(0, 2 * np.pi, resolution)
z = np.linspace(-length / 2, length / 2, resolution)
T, Z = np.meshgrid(t, z)

X = radius * np.cos(T)
Y = radius * np.sin(T)

# Rotate cylinder
points = np.dot(rotation_matrix, np.array([X.flatten(), Y.flatten(), Z.flatten()]))
X = points.reshape(3, -1)

# project onto plane and move
X = X[dims]
X += np.array(center)[dims, np.newaxis]

# get outline of cylinder mesh
X = extract_outline(X.T).T

ax.fill(X[0].flatten(), X[1].flatten(), **kwargs)
return ax


def plot_comps(
module_or_view: Union["jx.Module", "jx.View"],
view: "jx.View",
dims: Tuple[int] = (0, 1),
col: str = "k",
ax: Optional[Axes] = None,
comp_plot_kwargs: Dict = {},
true_comp_length: bool = True,
) -> Axes:
"""Plot compartmentalized neural mrophology.
Plots the projection of the cylindrical compartments.
Args:
module_or_view: The module or view to plot.
view: The view of the module.
dims: The dimensions to project the cylinder onto, i.e. [0,1] xy-plane.
ax: The matplotlib axis to plot on.
comp_plot_kwargs: The plot kwargs for plt.fill.
true_comp_length: If True, the length of the compartment is used, i.e. the
length of the traced neurite. This means for zig-zagging neurites the
cylinders will be longer than the straight-line distance between the
start and end point of the neurite. This can lead to overlapping and
miss-aligned cylinders. Setting this False will use the straight-line
distance instead for nicer plots.
Returns:
Plot of the compartmentalized morphology.
"""
if ax is None:
fig = plt.figure(figsize=(3, 3))
ax = fig.add_subplot(111)

module = (
module_or_view.pointer
if "pointer" in module_or_view.__dict__
else module_or_view
)
assert not np.any(np.isnan(module.xyzr[0][:, :3])), "missing xyz coordinates."
if "x" not in module.nodes.columns:
module._update_nodes_with_xyz()
view[["x", "y", "z"]] = module.nodes.loc[view.index, ["x", "y", "z"]]

branches_inds = np.unique(view["branch_index"].to_numpy())
for idx in branches_inds:
locs = module.xyzr[idx][:, :3]
if locs.shape[0] == 1: # assume spherical comp
radius = module.xyzr[idx][:, -1]
ax.add_artist(plt.Circle(locs[0, dims], radius, color=col))
else:
lens = np.sqrt(np.nansum(np.diff(locs, axis=0) ** 2, axis=1))
lens = np.cumsum([0] + lens.tolist())
comp_ends = v_interp(
np.linspace(0, lens[-1], module.nseg + 1), lens, locs
).T
axes = np.diff(comp_ends, axis=0)
cylinder_lens = np.sqrt(np.sum(axes**2, axis=1))

branch_df = view[view["branch_index"] == idx]
for l, axis, (i, comp) in zip(cylinder_lens, axes, branch_df.iterrows()):
center = comp[["x", "y", "z"]]
radius = comp["radius"]
length = comp["length"] if true_comp_length else l
ax = plot_cylinder_projection(
axis,
length,
radius,
center,
np.array(dims),
ax,
color=col,
**comp_plot_kwargs,
)
return ax
18 changes: 18 additions & 0 deletions tests/test_plotting_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,21 @@ def test_mixed_network():
assert np.allclose(b[:, 1], a[:, 0], atol=1e-6)

_ = net.vis(detail="full")


def test_volume_plotting():
comp = jx.Compartment()
comp.compute_xyz()
branch = jx.Branch(comp, 4)
branch.compute_xyz()
cell = jx.Cell([branch] * 3, [-1, 0, 0])
cell.compute_xyz()
net = jx.Network([cell] * 4)
net.compute_xyz()

fig, ax = plt.subplots()
for module in [comp, branch, cell, net]:
module.vis(type="volume", ax=ax)
if not isinstance(module, jx.Compartment):
module[0].vis(type="volume", ax=ax)
plt.close(fig)

0 comments on commit 4fc75a1

Please sign in to comment.