diff --git a/elf/tracking/mamut.py b/elf/tracking/mamut.py index 5ae1c1f..2a91d80 100644 --- a/elf/tracking/mamut.py +++ b/elf/tracking/mamut.py @@ -1,8 +1,10 @@ +from typing import Tuple + import xml.etree.ElementTree as ET import numpy as np # -# adapted from code shared by @wolny +# Adapted from code shared by @wolny. # @@ -29,31 +31,38 @@ def _extract_tracks(root, flatten_spots=True): return tracks -def extract_tracks_as_volume(path, timepoint, shape, voxel_size, binary=False): +def extract_tracks_as_volume( + path: str, + timepoint: int, + shape: Tuple[int, int, int], + voxel_size: Tuple[float, float, float], + binary: bool = False +) -> np.ndarray: """Extract tracks as volume from MaMuT xml. Arguments: - path [str] - path to the xml file with tracks - timepoint [int] - timepoint for which to extract the tracks - shape [tuple[int]] - shape of the output volume - voxel_size [tuple[float]] - voxel size + path: Path to the xml file with tracks stored in MaMuT format. + timepoint: Timepoint for which to extract the tracks. + shape: Shape of the output volume. + voxel_size: Voxel size of the volume. + binary: Whether to return the volume as binary labels and not instance ids. Returns: - np.ndarray - + The volume with instance ids or binary ids. 
""" # get root XML element root = ET.parse(path).getroot() # retrieve all of the spots - all_spots = root.find('Model').find('AllSpots') + all_spots = root.find("Model").find("AllSpots") # get all spots for a given time frame - spots = next((s for s in all_spots if int(s.attrib['frame']) == timepoint), None) + spots = next((s for s in all_spots if int(s.attrib["frame"]) == timepoint), None) if spots is None: - raise RuntimeError('Could not find spots for time frame:', timepoint) + raise RuntimeError("Could not find spots for time frame:", timepoint) # get pixel coordinates pixel_coordinates = np.array([_to_zyx_coordinates( - [spot.attrib['POSITION_Z'], spot.attrib['POSITION_Y'], spot.attrib['POSITION_X']], + [spot.attrib["POSITION_Z"], spot.attrib["POSITION_Y"], spot.attrib["POSITION_X"]], np.array([vsize for vsize in voxel_size]) ) for spot in spots]) z = pixel_coordinates[:, 0] @@ -62,13 +71,13 @@ def extract_tracks_as_volume(path, timepoint, shape, voxel_size, binary=False): # extract the volume as binary if binary: - spot_mask = np.zeros(shape, dtype='bool') + spot_mask = np.zeros(shape, dtype="bool") spot_mask[z, y, x] = 1 return spot_mask # extract volume with track ids - spot_ids = [int(spot.attrib['ID']) for spot in spots] - track_volume = np.zeros(shape, dtype='uint32') + spot_ids = [int(spot.attrib["ID"]) for spot in spots] + track_volume = np.zeros(shape, dtype="uint32") tracks = _extract_tracks(root) spots_to_tracks = {spot: track for track, spots in tracks.items() for spot in spots} diff --git a/elf/tracking/motile_tracking.py b/elf/tracking/motile_tracking.py index 945625e..4ca5b93 100644 --- a/elf/tracking/motile_tracking.py +++ b/elf/tracking/motile_tracking.py @@ -1,6 +1,7 @@ """Functionality for tracking microscopy data with [motile](https://github.com/funkelab/motile). 
""" from copy import deepcopy +from typing import Dict, List, Optional, Tuple, Union import motile import networkx as nx @@ -14,17 +15,18 @@ # -# Utility functionality for motile: -# Solution parsing and visualization. +# Utility functionality for motile: Solution parsing and visualization. # def parse_result(solver, graph): + """@private + """ lineage_graph = nx.DiGraph() node_indicators = solver.get_variables(motile.variables.NodeSelected) edge_indicators = solver.get_variables(motile.variables.EdgeSelected) - # build new graphs that contain the selected nodes and tracking / lineage results + # Build new graphs that contain the selected nodes and tracking / lineage results. for node, index in node_indicators.items(): if solver.solution[index] > 0.5: lineage_graph.add_node(node, **graph.nodes[node]) @@ -33,26 +35,28 @@ def parse_result(solver, graph): if solver.solution[index] > 0.5: lineage_graph.add_edge(*edge, **graph.edges[edge]) - # use connected components to find the lineages + # Use connected components to find the lineages. lineages = nx.weakly_connected_components(lineage_graph) lineages = {lineage_id: list(lineage) for lineage_id, lineage in enumerate(lineages, 1)} return lineage_graph, lineages def lineage_graph_to_track_graph(lineage_graph, lineages): - # create a new graph that only contains the tracks by not connecting nodes with a degree of 2 + """@private + """ + # Create a new graph that only contains the tracks by not connecting nodes with a degree of 2. track_graph = nx.DiGraph() track_graph.add_nodes_from(lineage_graph.nodes) - # iterate over the edges to find splits and end tracks there + # Iterate over the edges to find splits and end tracks there. 
for (u, v), features in lineage_graph.edges.items(): out_edges = lineage_graph.out_edges(u) - # normal track continuation + # Normal track continuation if len(out_edges) == 1: track_graph.add_edge(u, v) - # otherwise track ends at division and we don't continue + # Otherwise track ends at division and we don't continue. - # use connected components to find the tracks + # Use connected components to find the tracks. tracks = nx.weakly_connected_components(track_graph) tracks = {track_id: list(track) for track_id, track in enumerate(tracks, 1)} @@ -60,12 +64,14 @@ def lineage_graph_to_track_graph(lineage_graph, lineages): def get_node_assignment(node_ids, assignments): - # generate a dictionary that maps each node id (= segment id) to its assignment + """@private + """ + # Generate a dictionary that maps each node id (= segment id) to its assignment. node_assignment = { node_id: assignment_id for assignment_id, nodes in assignments.items() for node_id in nodes } - # everything that was not selected gets mapped to 0 + # Everything that was not selected gets mapped to 0. not_selected = list(set(node_ids) - set(node_assignment.keys())) node_assignment.update({not_select: 0 for not_select in not_selected}) @@ -73,59 +79,87 @@ def get_node_assignment(node_ids, assignments): def recolor_segmentation(segmentation, node_to_assignment): - # we need to add a value for mapping 0, otherwise the function fails + """@private + """ + # We need to add a value for mapping 0, otherwise the function fails. node_to_assignment_ = deepcopy(node_to_assignment) node_to_assignment_[0] = 0 recolored_segmentation = takeDict(node_to_assignment_, segmentation) return recolored_segmentation -def create_data_for_track_layer(segmentation, lineage_graph, node_to_track): - # compute regionpros and extract centroids +def create_data_for_track_layer(segmentation, lineage_graph, node_to_track, skip_zero=True): + """@private + """ + # Compute regionprops and extract centroids. 
props = regionprops(segmentation) centroids = {prop.label: prop.centroid for prop in props} - # create the track data representation for napari - track_data = [ + # Create the track data representation for napari, which expects: + # track_id, timepoint, (z), y, x + track_data = np.array([ [node_to_track[node_id]] + list(centroid) for node_id, centroid in centroids.items() if node_id in node_to_track - ] + ]) + if skip_zero: + track_data = track_data[track_data[:, 0] != 0] - # create the parent graph for tracks + # Order the track data by track_id and timepoint. + sorted_indices = np.lexsort((track_data[:, 1], track_data[:, 0])) + track_data = track_data[sorted_indices] + + # Create the parent graph for the tracks. parent_graph = {} for (u, v), features in lineage_graph.edges.items(): out_edges = lineage_graph.out_edges(u) if len(out_edges) == 2: track_u, track_v = node_to_track[u], node_to_track[v] + if skip_zero and (track_u == 0 or track_v == 0): + continue parent_graph[track_v] = parent_graph.get(track_v, []) + [track_u] return track_data, parent_graph # -# Utility functions for constructing motile tracking problems +# Utility functions for constructing motile tracking problems. # -# TODO exppose the relevant weights and constants! +# TODO expose the relevant weights and constants! def construct_problem( - segmentation, - node_costs, - edges_and_costs, - max_parents=1, - max_children=2, -): + segmentation: np.ndarray, + node_costs: np.ndarray, + edges_and_costs: List[Dict[str, Union[int, float]]], + max_parents: int = 1, + max_children: int = 2, +) -> Tuple[motile.solver.Solver, motile.track_graph.TrackGraph]: + """Construct a motile tracking problem from a segmentation timeseries. + + Args: + segmentation: The segmentation timeseries. + node_costs: The node selection costs. + edges_and_costs: The edge selection costs. + max_parents: The maximal number of parents. + Corresponding to the maximal number of edges to the previous time point. 
+ max_children: The maximal number of children. + Corresponding to the maximal number of edges to the next time point. + + Returns: + The motile solver. + The motile tracking graph. + """ node_ids, indexes = np.unique(segmentation, return_index=True) indexes = np.unravel_index(indexes, shape=segmentation.shape) timeframes = indexes[0] - # get rid of 0 + # Get rid of 0. if node_ids[0] == 0: node_ids, timeframes = node_ids[1:], timeframes[1:] assert len(node_ids) == len(timeframes) graph = nx.DiGraph() - # if the node function is not passed then we assume that all nodes should be selected + # If the node function is not passed then we assume that all nodes should be selected. assert len(node_costs) == len(node_ids) nodes = [ {"id": node_id, "score": score, "t": t} for node_id, score, t in zip(node_ids, node_costs, timeframes) @@ -134,28 +168,27 @@ def construct_problem( graph.add_nodes_from([(node["id"], node) for node in nodes]) graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges_and_costs]) - # construct da graph + # Get the tracking graph and the motile solver. graph = motile.TrackGraph(graph) solver = motile.Solver(graph) - # we can do linear reweighting of the costs: a * x + b - # where: a=weight, b=constant - solver.add_costs(costs.NodeSelection(weight=-1.0, attribute="score", constant=0)) - solver.add_costs(costs.EdgeSelection(weight=-1.0, attribute="score", constant=0)) + # We can do linear reweighting of the costs: a * x + b, where: a=weight, b=constant. + solver.add_cost(costs.NodeSelection(weight=-1.0, attribute="score", constant=0)) + solver.add_cost(costs.EdgeSelection(weight=-1.0, attribute="score", constant=0)) - # add the constraints: we allow for divisions (max childeren = 2) - solver.add_constraints(constraints.MaxParents(max_parents)) - solver.add_constraints(constraints.MaxChildren(max_children)) + # Add the constraints: we allow for divisions (max children = 2). 
+ solver.add_constraint(constraints.MaxParents(max_parents)) + solver.add_constraint(constraints.MaxChildren(max_children)) - # add costs for appearance and divisions - solver.add_costs(costs.Appear(constant=1.0)) - solver.add_costs(costs.Split(constant=1.0)) + # Add costs for appearance and divisions. + solver.add_cost(costs.Appear(constant=1.0)) + solver.add_cost(costs.Split(constant=1.0)) return solver, graph # -# Motile based tracking +# Motile based tracking. # @@ -167,75 +200,114 @@ def _track_with_motile_impl( node_selection_cost=0.95, **problem_kwargs, ): - # relabel sthe segmentation so that the ids are unique across time. - # if `relabel_segmentation is False` the segmentation has to be in the correct format already + # Relabel the segmentation so that the ids are unique across time. + # If `relabel_segmentation is False` the segmentation has to be in the correct format already. if relabel_segmentation: segmentation = utils.relabel_segmentation_across_time(segmentation) - # compute the node selection costs. - # if `node_cost_function` is passed it is used to compute the costs. - # otherwise we set a fixed node selection cost. + # Compute the node selection costs. + # If `node_cost_function` is passed it is used to compute the costs. + # Otherwise we set a fixed node selection cost. if node_cost_function is None: n_nodes = int(segmentation.max()) node_costs = np.full(n_nodes, node_selection_cost) else: node_costs = node_cost_function(segmentation) - # compute the edges and edge selection cost. - # if `edge_cost_function` is not given we use the default approach. - # (currently based on overlap of adjacent slices) + # Compute the edges and edge selection cost. + # If `edge_cost_function` is not given we use the default approach, based on overlap of adjacent slices. 
if edge_cost_function is None: edge_cost_function = utils.compute_edges_from_overlap edges_and_costs = edge_cost_function(segmentation) - # construct the problem + # Construct and solve the tracking problem. solver, graph = construct_problem(segmentation, node_costs, edges_and_costs, **problem_kwargs) - - # solver the problem solver.solve() - return solver, graph + return solver, graph, segmentation def track_with_motile( - segmentation, - relabel_segmentation=True, - node_cost_function=None, - edge_cost_function=None, - node_selection_cost=0.95, + segmentation: np.ndarray, + relabel_segmentation: bool = True, + node_cost_function: Optional[callable] = None, + edge_cost_function: Optional[callable] = None, + node_selection_cost: float = 0.95, **problem_kwargs, -): +) -> Tuple[np.ndarray, nx.DiGraph, Dict[int, List[int]], nx.DiGraph, Dict[int, List[int]]]: """Track segmented objects across time with motile. - Note: this will relabel the segmentation unless `relabel_segmentation=False` + Args: + segmentation: The input segmentation. + relabel_segmentation: Whether to relabel the segmentation so that ids are unique across time. + If set to False, then unique ids across time have to be ensured in the input. + node_cost_function: Function for computing costs for node selection. + If not given, then the constant factor `node_selection_cost` is used. + edge_cost_function: Function for computing costs for edge selection. + If not given, then the function `utils.compute_edges_from_overlap` is used. + node_selection_cost: Node selection cost. + problem_kwargs: Additional keyword arguments for constructing the tracking problem. + + Returns: + The input segmentation after relabeling. + The lineage graph, a directed graph that connects track ids across divisions or fusions. + Map of lineage ids to track ids. + The track graph, a directed graph that connects segmentation ids across time points. + Map of track ids to segmentation ids. 
""" - - solver, graph = _track_with_motile_impl( + solver, graph, segmentation = _track_with_motile_impl( segmentation, relabel_segmentation, node_cost_function, edge_cost_function, node_selection_cost, **problem_kwargs, ) - # parse solution lineage_graph, lineages = parse_result(solver, graph) track_graph, tracks = lineage_graph_to_track_graph(lineage_graph, lineages) return segmentation, lineage_graph, lineages, track_graph, tracks -def get_representation_for_napari(segmentation, lineage_graph, lineages, tracks, color_by_lineage=True): - +def get_representation_for_napari( + segmentation: np.ndarray, + lineage_graph: nx.DiGraph, + lineages: Dict[int, List[int]], + tracks: Dict[int, List[int]], + color_by_lineage: bool = True, +) -> Tuple[np.ndarray, np.ndarray, Dict[int, List[int]]]: + """Convert tracking result from motile into representation for napari. + + The output of this function can be passed to `napari.add_tracks` like this: + ``` + tracking_result, track_data, parent_graph = get_representation_for_napari(...) + viewer = napari.Viewer() + viewer.add_labels(tracking_result) + viewer.add_tracks(track_data, graph=parent_graph) + napari.run() + ``` + + Args: + segmentation: The input segmentation after relabeling. + lineage_graph: The lineage graph result from tracking. + lineages: The lineage assignment result from tracking. + tracks: The track assignment result from tracking. + color_by_lineage: Whether to color the tracking result by lineage id or by track id. + + Returns: + The relabeled segmentation, where each segment id is either colored by the lineage id or track id. + The track data for the napari tracks layer, which is a table containing track_id, timepoint, (z), y, x. + The parent graph, which maps each track id to its parent id, if it exists. 
+ """ node_ids = np.unique(segmentation)[1:] node_to_track = get_node_assignment(node_ids, tracks) node_to_lineage = get_node_assignment(node_ids, lineages) - # create label layer and track data for visualization in napari + # Create label layer and track data for visualization in napari. tracking_result = recolor_segmentation( segmentation, node_to_lineage if color_by_lineage else node_to_track ) - # create the track data and corresponding parent graph + # Create the track data and corresponding parent graph. track_data, parent_graph = create_data_for_track_layer( - segmentation, lineage_graph, node_to_track + segmentation, lineage_graph, node_to_track, skip_zero=True ) return tracking_result, track_data, parent_graph diff --git a/elf/tracking/tracking_utils.py b/elf/tracking/tracking_utils.py index c1ed15b..d6a090c 100644 --- a/elf/tracking/tracking_utils.py +++ b/elf/tracking/tracking_utils.py @@ -4,6 +4,8 @@ motile or with other python tracking libraries. """ +from typing import Dict, List, Union + import nifty.ground_truth as ngt import numpy as np @@ -14,7 +16,17 @@ from tqdm import trange -def compute_edges_from_overlap(segmentation, verbose=True): +def compute_edges_from_overlap(segmentation: np.ndarray, verbose: bool = True) -> List[Dict[str, Union[int, float]]]: + """Compute the edges between segmented objects in adjacent frames, based on their overlap. + + Args: + segmentation: The input segmentation. + verbose: Whether to be verbose in the computation. + + Returns: + The edges, represented as a dictionary containing source ids, target ids, and corresponding overlap. 
+ """ + def compute_overlap_between_frames(frame_a, frame_b): overlap_function = ngt.overlap(frame_a, frame_b) @@ -42,10 +54,27 @@ def compute_overlap_between_frames(frame_a, frame_b): next_frame = segmentation[t + 1] frame_edges = compute_overlap_between_frames(this_frame, next_frame) edges.extend(frame_edges) + return edges -def compute_edges_from_centroid_distance(segmentation, max_distance, normalize_distances=True, verbose=True): +def compute_edges_from_centroid_distance( + segmentation: np.ndarray, + max_distance: float, + normalize_distances: bool = True, + verbose: bool = True, +) -> List[Dict[str, Union[int, float]]]: + """Compute the edges between segmented objects in adjacent frames, based on their centroid distances. + + Args: + segmentation: The input segmentation. + max_distance: The maximal distance for taking an edge into account. + normalize_distances: Whether to normalize the distances. + verbose: Whether to be verbose in the computation. + + Returns: + The edges, represented as a dictionary containing source ids, target ids, and corresponding distance. + """ nt = segmentation.shape[0] props = regionprops(segmentation) centroids_and_labels = [[prop.centroid[0], prop.centroid[1:], prop.label] for prop in props] @@ -72,7 +101,6 @@ def compute_dist_between_frames(t): assert len(distance_values) == len(source_ids) == len(target_ids) return source_ids, target_ids, distance_values - # return edges source_ids, target_ids, distances = [], [], [] for t in trange(nt - 1, disable=not verbose, desc="Compute edges via centroid distance"): @@ -91,14 +119,35 @@ def compute_dist_between_frames(t): return edges -# TODO does this work for 4d data (time + 3d)? 
if no we need to iterate over the time axis -def compute_node_costs_from_foreground_probabilities(segmentation, probabilities, cost_attribute="mean_intensity"): +def compute_node_costs_from_foreground_probabilities( + segmentation: np.ndarray, + probabilities: np.ndarray, + cost_attribute: str = "mean_intensity", +) -> List[float]: + """Derive the node selection cost from a foreground probability map. + + Args: + segmentation: The segmentation. + probabilities: The foreground probability map. + cost_attribute: The attribute of regionprops to use for the selection cost. + + Returns: + The selection cost for each node in the segmentation. + """ props = regionprops(segmentation, probabilities) costs = [getattr(prop, cost_attribute) for prop in props] return costs -def relabel_segmentation_across_time(segmentation): +def relabel_segmentation_across_time(segmentation: np.ndarray) -> np.ndarray: + """Relabel the segmentation across time, so that segmentation ids are unique across timepoints. + + Args: + segmentation: The input segmentation. + + Returns: + The relabeled segmentation. + """ offset = 0 relabeled = [] for frame in segmentation: @@ -109,7 +158,17 @@ def relabel_segmentation_across_time(segmentation): return np.stack(relabeled) -def preprocess_closing(slice_segmentation, gap_closing, verbose=True): +def preprocess_closing(slice_segmentation: np.ndarray, gap_closing: int, verbose: bool = True) -> np.ndarray: + """Preprocess a segmentation by applying a closing operation to fill in missing segments in timepoints. + + Args: + slice_segmentation: The input segmentation. + gap_closing: The maximal number of slices to close. + verbose: Whether to be verbose in the computation. + + Returns: + The segmentation with missing segments filled in. + """ binarized = slice_segmentation > 0 structuring_element = np.zeros((3, 1, 1)) structuring_element[:, 0, 0] = 1