From 51352f1adda032d80b063fa7287619ecee5856d2 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Mon, 6 May 2024 19:25:53 +0200 Subject: [PATCH] Add metric evaluation --- scripts/get_tracking_results.py | 53 ++++++++++++---- scripts/test_ctc_metric.py | 108 ++++++++++++++++++++++++-------- 2 files changed, 123 insertions(+), 38 deletions(-) diff --git a/scripts/get_tracking_results.py b/scripts/get_tracking_results.py index 4d4b817..17d11d2 100644 --- a/scripts/get_tracking_results.py +++ b/scripts/get_tracking_results.py @@ -13,18 +13,42 @@ def load_tracking_segmentation(experiment): result_dir = os.path.join(ROOT, "results") - if experiment == "vit_l": - seg_path = os.path.join(result_dir, "vit_l.tif") - elif experiment == "vit_l_lm": - seg_path = os.path.join(result_dir, "vit_l_lm.tif") - elif experiment == "vit_l_specialist": - seg_path = os.path.join(result_dir, "vit_l_lm_specialist.tif") - elif experiment == "trackmate_stardist": - seg_path = os.path.join(result_dir, "trackmate_stardist", "every_3rd_fr_result.tif") - else: - raise ValueError(experiment) - return imageio.imread(seg_path) + if experiment.startswith("vit"): + if experiment == "vit_l": + seg_path = os.path.join(result_dir, "vit_l.tif") + seg = imageio.imread(seg_path) + # HACK + ignore_labels = [8, 44, 57, 102, 50] + + elif experiment == "vit_l_lm": + seg_path = os.path.join(result_dir, "vit_l_lm.tif") + seg = imageio.imread(seg_path) + # HACK + ignore_labels = [] + + elif experiment == "vit_l_specialist": + seg_path = os.path.join(result_dir, "vit_l_lm_specialist.tif") + seg = imageio.imread(seg_path) + # HACK + ignore_labels = [88, 45, 30, 46] + + # elif experiment == "trackmate_stardist": + # seg_path = os.path.join(result_dir, "trackmate_stardist", "every_3rd_fr_result.tif") + # seg = imageio.imread(seg_path) + + else: + raise ValueError(experiment) + + # HACK: + # we remove some labels as they have a weird lineage, is creating issues for creating the graph + # (e.g. frames where the object exists: 1, 2, 4, 5, 6) + seg[np.isin(seg, ignore_labels)] = 0 + + return seg + + else: # return the result directory for stardist + return os.path.join(result_dir, "trackmate_stardist", "01_RES") def check_tracking_results(raw, labels, curr_lineages, chosen_frames): @@ -93,4 +117,11 @@ def get_tracking_data(): curr_frames = v["frames"] v["frames"] = [frmaps[frval] for frval in curr_frames if frval in chosen_frames] + # HACK: + # we remove label with id 62 as it has a weird lineage, is creating issues for creating the graph + ignore_labels = [62, 87, 92, 99, 58] + labels[np.isin(labels, ignore_labels)] = 0 + for _label in ignore_labels: + curr_lineages.pop(_label) + return raw, labels, curr_lineages, chosen_frames diff --git a/scripts/test_ctc_metric.py b/scripts/test_ctc_metric.py index 68c6227..dc06637 100644 --- a/scripts/test_ctc_metric.py +++ b/scripts/test_ctc_metric.py @@ -1,41 +1,56 @@ +import os import numpy as np +import pandas as pd from deepcell_tracking.isbi_utils import trk_to_isbi -from traccuracy.loaders._ctc import _get_node_attributes +from traccuracy import run_metrics +from traccuracy._tracking_graph import TrackingGraph +from traccuracy.matchers import CTCMatcher, IOUMatcher +from traccuracy.metrics import CTCMetrics, DivisionMetrics +from traccuracy.loaders._ctc import _get_node_attributes, ctc_to_graph, _check_ctc, load_ctc_data from get_tracking_results import get_tracking_data, load_tracking_segmentation +def mark_potential_split(frames, last_frame, idx): + if frames.max() == last_frame: # object is tracked until the last frame + split_frame = None # they can't split in this case + prev_parent_id = None + else: # object either goes out of frame or splits + split_frame = frames.max() # let's assume that it splits, we will know if it does or not + prev_parent_id = idx + return split_frame, prev_parent_id + + def extract_df_from_segmentation(segmentation): track_ids = np.unique(segmentation)[1:] last_frame = segmentation.shape[0] - 1 all_tracks = [] - splits = 0 - for idx in track_ids: + prev_parent_id = None + for idx in track_ids: frames = np.unique(np.where(segmentation == idx)[0]) if frames.min() == 0: # object starts at first frame - if frames.max() == last_frame: # object is tracked until the last frame - pid = 0 - have_fam = None # they can't split in this case - else: # object either goes out of frame or splits - pid = 0 - have_fam = frames.max() # let's assume that it splits, we will know if it does or not + pid = 0 + split_frame, prev_parent_id = mark_potential_split(frames, last_frame, idx) else: - if have_fam is not None: # takes the parent information from above - pid = have_fam - splits += 1 - - if splits > 2: # assumes every mother cell splits into two daughter cells - print("The mother cell has made enough daughter splits, hence this is a new object.") - splits = 0 - # pid = 0 # this is the case where an objects appears at nth frame and has no parent id + if split_frame is not None: # takes the parent information from above + # have fam is the end frame of the potential parent, so our frame has to be the next frame + if split_frame + 1 == frames.min(): + pid = prev_parent_id + + # otherwise we just have some track that starts so it's not the child + else: + pid = 0 + split_frame, prev_parent_id = mark_potential_split(frames, last_frame, idx) + else: pid = 0 # assumes that it was an object that started at a random frame + split_frame, prev_parent_id = mark_potential_split(frames, last_frame, idx) track_dict = { "Cell_ID": idx, @@ -44,30 +59,69 @@ def extract_df_from_segmentation(segmentation): "Parent_ID": pid, } - print(track_dict) - all_tracks.append(track_dict) + all_tracks.append(pd.DataFrame.from_dict([track_dict])) - breakpoint() + pred_tracks_df = pd.concat(all_tracks) + return pred_tracks_df -def evaluate_tracking(raw, labels, curr_lineages, chosen_frames, segmentation_method): +def evaluate_tracking(labels, curr_lineages, segmentation_method): seg = load_tracking_segmentation(segmentation_method) + if os.path.isdir(seg): # for trackmate stardist + seg_T = load_ctc_data( + data_dir=seg, + track_path=os.path.join(seg, 'res_track.txt'), + name=f'DynamicNuclearNet-{segmentation_method}' + ) + + else: # for micro-sam + seg_nodes = _get_node_attributes(seg) + seg_df = extract_df_from_segmentation(seg) + seg_G = ctc_to_graph(seg_df, seg_nodes) + _check_ctc(seg_df, seg_nodes, seg) + seg_T = TrackingGraph(seg_G, segmentation=seg, name=f"DynamicNuclearNet-{segmentation_method}") + + breakpoint() + # calcuates node attributes for each detection - gt_df = _get_node_attributes(labels) - seg_df = _get_node_attributes(seg) + gt_nodes = _get_node_attributes(labels) + + # converts inputs to isbi-tracking format - the version expected as inputs in traccuracy + gt_df = trk_to_isbi(curr_lineages, path=None) + + # creates graphs from ctc-type info (isbi-type? probably means the same thing) + gt_G = ctc_to_graph(gt_df, gt_nodes) + + # OPTIONAL: This tests if inputs (images, dfs and node attributes) to create tracking graphs are as expected + _check_ctc(gt_df, gt_nodes, labels) + + gt_T = TrackingGraph(gt_G, segmentation=labels, name="DynamicNuclearNet-GT") + + ctc_results = run_metrics( + gt_data=gt_T, + pred_data=seg_T, + matcher=CTCMatcher(), + metrics=[CTCMetrics(), DivisionMetrics(max_frame_buffer=0)], + ) + print(ctc_results) - # converts inputs to isbi-track format - the version expected as inputs in traccuracy - output = trk_to_isbi(curr_lineages, path=None) + breakpoint() - df = extract_df_from_segmentation(seg) + iou_results = run_metrics( + gt_data=gt_T, + pred_data=seg_T, + matcher=IOUMatcher(iou_threshold=0.1), + metrics=[DivisionMetrics(max_frame_buffer=0)], + ) + print(iou_results) def main(): raw, labels, curr_lineages, chosen_frames = get_tracking_data() segmentation_method = "vit_l_specialist" - evaluate_tracking(raw, labels, curr_lineages, chosen_frames, segmentation_method) + evaluate_tracking(labels, curr_lineages, segmentation_method) if __name__ == "__main__":