diff --git a/elf/tracking/tracking_utils.py b/elf/tracking/tracking_utils.py index 4a4c74d..c1ed15b 100644 --- a/elf/tracking/tracking_utils.py +++ b/elf/tracking/tracking_utils.py @@ -8,7 +8,8 @@ import numpy as np from scipy.spatial.distance import cdist -from skimage.measure import regionprops +from skimage.measure import regionprops, label +from scipy.ndimage import binary_closing from skimage.segmentation import relabel_sequential from tqdm import trange @@ -23,7 +24,7 @@ def compute_overlap_between_frames(frame_a, frame_b): source_ids = [src for node_id, ovlp in zip(node_ids, overlaps) for src in [node_id] * len(ovlp[0])] target_ids = [ov for ovlp in overlaps for ov in ovlp[0]] overlap_values = [ov for ovlp in overlaps for ov in ovlp[1]] - assert len(source_ids) == len(target_ids) == len(overlap_values),\ + assert len(source_ids) == len(target_ids) == len(overlap_values), \ f"{len(source_ids)}, {len(target_ids)}, {len(overlap_values)}" edges = [ @@ -50,9 +51,9 @@ def compute_edges_from_centroid_distance(segmentation, max_distance, normalize_d centroids_and_labels = [[prop.centroid[0], prop.centroid[1:], prop.label] for prop in props] centroids, labels = {}, {} - for t, centroid, label in centroids_and_labels: + for t, centroid, label_id in centroids_and_labels: centroids[t] = centroids.get(t, []) + [centroid] - labels[t] = labels.get(t, []) + [label] + labels[t] = labels.get(t, []) + [label_id] centroids = {t: np.stack(np.array(val)) for t, val in centroids.items()} labels = {t: np.array(val) for t, val in labels.items()} @@ -106,3 +107,63 @@ def relabel_segmentation_across_time(segmentation): offset = frame.max() relabeled.append(frame) return np.stack(relabeled) + + +def preprocess_closing(slice_segmentation, gap_closing, verbose=True): + binarized = slice_segmentation > 0 + structuring_element = np.zeros((3, 1, 1)) + structuring_element[:, 0, 0] = 1 + closed_segmentation = binary_closing(binarized, iterations=gap_closing, structure=structuring_element) + + new_segmentation = np.zeros_like(slice_segmentation) + n_slices = new_segmentation.shape[0] + + def process_slice(z, offset): + seg_z = slice_segmentation[z] + + # Closing does not work for the first and last gap slices + if z < gap_closing or z >= (n_slices - gap_closing): + seg_z, _, _ = relabel_sequential(seg_z, offset=offset) + offset = int(seg_z.max()) + 1 + return seg_z, offset + + # Apply connected components to the closed segmentation. + closed_z = label(closed_segmentation[z]) + + # Map objects in the closed and initial segmentation. + # We take objects from the closed segmentation unless they + # have overlap with more than one object from the initial segmentation. + # This indicates wrong merging of closeby objects that we want to prevent. + matches = ngt.overlap(closed_z, seg_z) + matches = {seg_id: matches.overlapArrays(seg_id, sorted=False)[0] + for seg_id in range(1, int(closed_z.max() + 1))} + matches = {k: v[v != 0] for k, v in matches.items()} + + ids_initial, ids_closed = [], [] + for seg_id, matched in matches.items(): + if len(matched) > 1: + ids_initial.extend(matched.tolist()) + else: + ids_closed.append(seg_id) + + seg_new = np.zeros_like(seg_z) + closed_mask = np.isin(closed_z, ids_closed) + seg_new[closed_mask] = closed_z[closed_mask] + + if ids_initial: + initial_mask = np.isin(seg_z, ids_initial) + seg_new[initial_mask] = relabel_sequential(seg_z[initial_mask], offset=seg_new.max() + 1)[0] + + seg_new, _, _ = relabel_sequential(seg_new, offset=offset) + max_z = seg_new.max() + if max_z > 0: + offset = int(max_z) + 1 + + return seg_new, offset + + # Further optimization: parallelize + offset = 1 + for z in trange(n_slices, disable=not verbose): + new_segmentation[z], offset = process_slice(z, offset) + + return new_segmentation