From 4fc75a12f78e9403ea5890a857aad30b56c31bba Mon Sep 17 00:00:00 2001 From: jnsbck <65561470+jnsbck@users.noreply.github.com> Date: Wed, 2 Oct 2024 12:26:50 +0200 Subject: [PATCH] Plotting of comp morphology (#432) * add: add new plot type to show the compartment * enh/doc: working version. add documentation. rewrite comp centers * fix: incorporate feedback. Fix plotting of comps/views * fix: allow to plot comps as volume * add: add tests * fix: fix typo --- jaxley/modules/base.py | 50 ++++++--- jaxley/modules/compartment.py | 21 ++++ jaxley/utils/plot_utils.py | 203 +++++++++++++++++++++++++++++++++- tests/test_plotting_api.py | 18 +++ 4 files changed, 276 insertions(+), 16 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 54e4fade..c3e6a34c 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -33,7 +33,7 @@ ) from jaxley.utils.debug_solver import compute_morphology_indices from jaxley.utils.misc_utils import childview, concat_and_ignore_empty -from jaxley.utils.plot_utils import plot_morph +from jaxley.utils.plot_utils import plot_comps, plot_morph from jaxley.utils.solver_utils import convert_to_csc @@ -115,23 +115,37 @@ def __init__(self): self.debug_states = {} def _update_nodes_with_xyz(self): - """Add xyz coordinates to nodes.""" + """Add xyz coordinates of compartment centers to nodes. + + Note: For sake of performance, interpolation is not done for each branch, + but once along a concatenated (and padded) array of all branches. + """ num_branches = len(self.xyzr) - x = np.linspace( - 0.5 / self.nseg, - (num_branches * 1 - 0.5 / self.nseg), - num_branches * self.nseg, + comp_ends = ( + np.linspace(0, 1, self.nseg + 1).reshape(1, -1).repeat(num_branches, 0) ) - x += np.arange(num_branches).repeat( - self.nseg - ) # add offset to prevent branch loc overlap - xp = np.hstack( - [np.linspace(0, 1, x.shape[0]) + 2 * i for i, x in enumerate(self.xyzr)] + comp_ends = comp_ends + 2 * np.arange(num_branches).reshape( + -1, 1 + ) # inter-branch padding + comp_ends = comp_ends.reshape(-1) + branch_lens = [] + for i, xyzr in enumerate(self.xyzr): + branch_len = np.sqrt( + np.sum(np.diff(xyzr[:, :3], axis=0) ** 2, axis=1) + ).cumsum() + branch_len = np.hstack([np.array([0]), branch_len]) + branch_len = branch_len / branch_len.max() + 2 * i # add padding like above + branch_len[np.isnan(branch_len)] = 0 + branch_lens.append(branch_len) + branch_lens = np.hstack(branch_lens) + xyz = np.vstack(self.xyzr)[:, :3] + xyz = v_interp(comp_ends, branch_lens, xyz).reshape( + 3, num_branches, self.nseg + 1 ) - xyz = v_interp(x, xp, np.vstack(self.xyzr)[:, :3]) + centers = ((xyz[:, :, 1:] + xyz[:, :, :-1]) / 2).reshape(3, -1).T idcs = self.nodes["comp_index"] - self.nodes.loc[idcs, ["x", "y", "z"]] = xyz.T - return xyz.T + self.nodes.loc[idcs, ["x", "y", "z"]] = centers + return centers, xyz def __repr__(self): return f"{type(self).__name__} with {len(self.channels)} different channels. Use `.show()` for details." @@ -1226,6 +1240,12 @@ def _vis( morph_plot_kwargs: Dict, ) -> Axes: branches_inds = view["branch_index"].to_numpy() + + if type == "volume": + return plot_comps( + self, view, dims=dims, ax=ax, col=col, **morph_plot_kwargs + ) + coords = [] for branch_ind in branches_inds: assert not np.any( @@ -1692,6 +1712,8 @@ def vis( Args: ax: An axis into which to plot. col: The color for all branches. + type: Whether to plot as points ("scatter"), a line ("line") or the + actual volume of the compartment("volume"). dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them. morph_plot_kwargs: Keyword arguments passed to the plotting function. diff --git a/jaxley/modules/compartment.py b/jaxley/modules/compartment.py index 9a3571c3..6cd24f9f 100644 --- a/jaxley/modules/compartment.py +++ b/jaxley/modules/compartment.py @@ -152,10 +152,31 @@ def vis( self, ax: Optional[Axes] = None, col: str = "k", + type: str = "scatter", dims: Tuple[int] = (0, 1), morph_plot_kwargs: Dict = {}, ) -> Axes: + """Visualize the compartment. + + Args: + ax: An axis into which to plot. + col: The color for all branches. + type: Whether to plot as point ("scatter") or the projected volume ("volume"). + dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of + two of them. + morph_plot_kwargs: Keyword arguments passed to the plotting function. + """ nodes = self.set_global_index_and_index(self.view) + if type == "volume": + return self.pointer._vis( + ax=ax, + col=col, + dims=dims, + view=nodes, + type="volume", + morph_plot_kwargs=morph_plot_kwargs, + ) + return self.pointer._scatter( ax=ax, col=col, diff --git a/jaxley/utils/plot_utils.py b/jaxley/utils/plot_utils.py index fca91c53..efa863c8 100644 --- a/jaxley/utils/plot_utils.py +++ b/jaxley/utils/plot_utils.py @@ -1,10 +1,15 @@ # This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is # licensed under the Apache License Version 2.0, see -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple, Union import matplotlib.pyplot as plt +import numpy as np from matplotlib.axes import Axes +from numpy import ndarray +from scipy.spatial import ConvexHull + +from jaxley.utils.cell_utils import v_interp def plot_morph( @@ -14,7 +19,7 @@ def plot_morph( ax: Optional[Axes] = None, type: str = "line", morph_plot_kwargs: Dict = {}, -): +) -> Axes: """Plot morphology. Args: @@ -38,3 +43,197 @@ def plot_morph( raise NotImplementedError return ax + + +def extract_outline(points: ndarray) -> ndarray: + """Get the outline of a 2D/3D shape. + + Extracts the subset of points which form the convex hull, i.e. the outline of + the input points. + + Args: + points: An array of points / corrdinates. + + Returns: + An array of points which form the convex hull. + """ + hull = ConvexHull(points) + hull_points = points[hull.vertices] + return hull_points + + +def compute_rotation_matrix(axis: ndarray, angle: float) -> ndarray: + """ + Return the rotation matrix associated with counterclockwise rotation about + the given axis by the given angle. + + Can be used to rotate a coordinate vector by multiplying it with the rotation + matrix. + + Args: + axis: The axis of rotation. + angle: The angle of rotation in radians. + + Returns: + A 3x3 rotation matrix. + """ + axis = axis / np.sqrt(np.dot(axis, axis)) + a = np.cos(angle / 2.0) + b, c, d = -axis * np.sin(angle / 2.0) + aa, bb, cc, dd = a * a, b * b, c * c, d * d + bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d + return np.array( + [ + [aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], + [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], + [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc], + ] + ) + + +def plot_cylinder_projection( + orientation: ndarray, + length: float, + radius: float, + center: ndarray, + dims: Tuple[int], + ax: Axes = None, + **kwargs, +) -> Axes: + """Plot the 2D projection of a cylinder on a cardinal plane. + + Project the projection of a cylinder that is oriented in 3D space. + - Create cylinder mesh + - rotate cylinder mesh to orient it lengthwise along a given orientation vector. + - move its center + - project onto plane + - compute outline of projected mesh. + - fill area inside the outline + + Args: + orientation: orientation vector. The cylinder will be oriented along this vector. + length: The length of the cylinder. + radius: The radius of the cylinder. + center: The x,y,z coordinates of the center of the cylinder. + dims: The dimensions to project the cylinder onto, i.e. [0,1] xy-plane. + ax: The matplotlib axis to plot on. + + Returns: + Plot of the cylinder projection. + """ + if ax is None: + fig = plt.figure(figsize=(3, 3)) + ax = fig.add_subplot(111) + + # Normalize axis vector + orientation = np.array(orientation) + orientation = orientation / np.linalg.norm(orientation) + + # Create a rotation matrix to align the cylinder with the given axis + z_axis = np.array([0, 0, 1]) + rotation_axis = np.cross(z_axis, orientation) + rotation_angle = np.arccos(np.dot(z_axis, orientation)) + + if np.allclose(rotation_axis, 0): + rotation_matrix = np.eye(3) + else: + rotation_matrix = compute_rotation_matrix(rotation_axis, rotation_angle) + + # Define cylinder + resolution = 100 + t = np.linspace(0, 2 * np.pi, resolution) + z = np.linspace(-length / 2, length / 2, resolution) + T, Z = np.meshgrid(t, z) + + X = radius * np.cos(T) + Y = radius * np.sin(T) + + # Rotate cylinder + points = np.dot(rotation_matrix, np.array([X.flatten(), Y.flatten(), Z.flatten()])) + X = points.reshape(3, -1) + + # project onto plane and move + X = X[dims] + X += np.array(center)[dims, np.newaxis] + + # get outline of cylinder mesh + X = extract_outline(X.T).T + + ax.fill(X[0].flatten(), X[1].flatten(), **kwargs) + return ax + + +def plot_comps( + module_or_view: Union["jx.Module", "jx.View"], + view: "jx.View", + dims: Tuple[int] = (0, 1), + col: str = "k", + ax: Optional[Axes] = None, + comp_plot_kwargs: Dict = {}, + true_comp_length: bool = True, +) -> Axes: + """Plot compartmentalized neural mrophology. + + Plots the projection of the cylindrical compartments. + + Args: + module_or_view: The module or view to plot. + view: The view of the module. + dims: The dimensions to project the cylinder onto, i.e. [0,1] xy-plane. + ax: The matplotlib axis to plot on. + comp_plot_kwargs: The plot kwargs for plt.fill. + true_comp_length: If True, the length of the compartment is used, i.e. the + length of the traced neurite. This means for zig-zagging neurites the + cylinders will be longer than the straight-line distance between the + start and end point of the neurite. This can lead to overlapping and + miss-aligned cylinders. Setting this False will use the straight-line + distance instead for nicer plots. + + Returns: + Plot of the compartmentalized morphology. + """ + if ax is None: + fig = plt.figure(figsize=(3, 3)) + ax = fig.add_subplot(111) + + module = ( + module_or_view.pointer + if "pointer" in module_or_view.__dict__ + else module_or_view + ) + assert not np.any(np.isnan(module.xyzr[0][:, :3])), "missing xyz coordinates." + if "x" not in module.nodes.columns: + module._update_nodes_with_xyz() + view[["x", "y", "z"]] = module.nodes.loc[view.index, ["x", "y", "z"]] + + branches_inds = np.unique(view["branch_index"].to_numpy()) + for idx in branches_inds: + locs = module.xyzr[idx][:, :3] + if locs.shape[0] == 1: # assume spherical comp + radius = module.xyzr[idx][:, -1] + ax.add_artist(plt.Circle(locs[0, dims], radius, color=col)) + else: + lens = np.sqrt(np.nansum(np.diff(locs, axis=0) ** 2, axis=1)) + lens = np.cumsum([0] + lens.tolist()) + comp_ends = v_interp( + np.linspace(0, lens[-1], module.nseg + 1), lens, locs + ).T + axes = np.diff(comp_ends, axis=0) + cylinder_lens = np.sqrt(np.sum(axes**2, axis=1)) + + branch_df = view[view["branch_index"] == idx] + for l, axis, (i, comp) in zip(cylinder_lens, axes, branch_df.iterrows()): + center = comp[["x", "y", "z"]] + radius = comp["radius"] + length = comp["length"] if true_comp_length else l + ax = plot_cylinder_projection( + axis, + length, + radius, + center, + np.array(dims), + ax, + color=col, + **comp_plot_kwargs, + ) + return ax diff --git a/tests/test_plotting_api.py b/tests/test_plotting_api.py index 1dc5c490..e0046c31 100644 --- a/tests/test_plotting_api.py +++ b/tests/test_plotting_api.py @@ -153,3 +153,21 @@ def test_mixed_network(): assert np.allclose(b[:, 1], a[:, 0], atol=1e-6) _ = net.vis(detail="full") + + +def test_volume_plotting(): + comp = jx.Compartment() + comp.compute_xyz() + branch = jx.Branch(comp, 4) + branch.compute_xyz() + cell = jx.Cell([branch] * 3, [-1, 0, 0]) + cell.compute_xyz() + net = jx.Network([cell] * 4) + net.compute_xyz() + + fig, ax = plt.subplots() + for module in [comp, branch, cell, net]: + module.vis(type="volume", ax=ax) + if not isinstance(module, jx.Compartment): + module[0].vis(type="volume", ax=ax) + plt.close(fig)