diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index d303beb6..d16f5ccd 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -19,7 +19,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest]
+        os: [ubuntu-latest, windows-latest, macos-latest]
         python-version: ["3.10"]

     steps:
diff --git a/development/benchmark.py b/development/benchmark.py
index b3d3ab15..ff5a50e7 100644
--- a/development/benchmark.py
+++ b/development/benchmark.py
@@ -180,7 +180,7 @@ def main():
     args = parser.parse_args()

     model_type = args.model_type
-    device = util._get_device(args.device)
+    device = util.get_device(args.device)

     print("Running benchmarks for", model_type)
     print("with device:", device)
diff --git a/development/seg_with_decoder.py b/development/seg_with_decoder.py
new file mode 100644
index 00000000..4186bd7c
--- /dev/null
+++ b/development/seg_with_decoder.py
@@ -0,0 +1,31 @@
+import imageio.v3 as imageio
+import napari
+
+from micro_sam.instance_segmentation import (
+    load_instance_segmentation_with_decoder_from_checkpoint, mask_data_to_segmentation
+)
+from micro_sam.util import precompute_image_embeddings
+
+checkpoint = "./for_decoder/best.pt"
+segmenter = load_instance_segmentation_with_decoder_from_checkpoint(checkpoint, model_type="vit_b")
+
+image_path = "/home/pape/Work/data/incu_cyte/livecell/images/livecell_train_val_images/A172_Phase_A7_1_02d00h00m_1.tif"
+image = imageio.imread(image_path)
+
+embedding_path = "./for_decoder/A172_Phase_A7_1_02d00h00m_1.zarr"
+image_embeddings = precompute_image_embeddings(
+    segmenter._predictor, image, embedding_path,
+)
+# image_embeddings = None
+
+print("Start segmentation ...")
+segmenter.initialize(image, image_embeddings)
+masks = segmenter.generate(output_mode="binary_mask")
+segmentation = mask_data_to_segmentation(masks, with_background=True)
+print("Segmentation done")
+
+v = napari.Viewer()
+v.add_image(image)
+# v.add_image(segmenter._foreground)
+v.add_labels(segmentation)
+napari.run()
diff --git a/doc/finetuned_models.md b/doc/finetuned_models.md
index ceedfc71..7ed52384 100644
--- a/doc/finetuned_models.md
+++ b/doc/finetuned_models.md
@@ -1,13 +1,14 @@
 # Finetuned models

-We provide models that were finetuned on microscopy data using `micro_sam.training`. They are hosted on zenodo. We currently offer the following models:
+In addition to the original Segment Anything models, we provide models that were finetuned on microscopy data using the functionality from `micro_sam.training`.
+The models are hosted on zenodo. We currently offer the following models:
 - `vit_h`: Default Segment Anything model with vit-h backbone.
 - `vit_l`: Default Segment Anything model with vit-l backbone.
 - `vit_b`: Default Segment Anything model with vit-b backbone.
-- `vit_h_lm`: Finetuned Segment Anything model for cells and nuclei in light microscopy data with vit-h backbone.
+- `vit_t`: Segment Anything model with vit-tiny backbone. From the [mobile sam publication](https://arxiv.org/abs/2306.14289).
 - `vit_b_lm`: Finetuned Segment Anything model for cells and nuclei in light microscopy data with vit-b backbone.
-- `vit_h_em`: Finetuned Segment Anything model for neurites and cells in electron microscopy data with vit-h backbone.
-- `vit_b_em`: Finetuned Segment Anything model for neurites and cells in electron microscopy data with vit-b backbone.
+- `vit_b_em_organelles`: Finetuned Segment Anything model for mitochondria and nuclei in electron microscopy data with vit-b backbone.
+- `vit_b_em_boundaries`: Finetuned Segment Anything model for neurites and cells in electron microscopy data with vit-b backbone.

 See the two figures below of the improvements through the finetuned model for LM and EM data.

@@ -20,17 +21,32 @@ You can select which of the models is used in the annotation tools by selecting

 To use a specific model in the python library you need to pass the corresponding name as value to the `model_type` parameter exposed by all relevant functions.
-See for example the [2d annotator example](https://github.com/computational-cell-analytics/micro-sam/blob/master/examples/annotator_2d.py#L62) where `use_finetuned_model` can be set to `True` to use the `vit_h_lm` model.
+See for example the [2d annotator example](https://github.com/computational-cell-analytics/micro-sam/blob/master/examples/annotator_2d.py#L62) where `use_finetuned_model` can be set to `True` to use the `vit_b_lm` model.
+
+Note that we are still working on improving these models and may update them from time to time. All older models will stay available for download on zenodo, see [model sources](#model-sources) below.
+
+
 ## Which model should I choose?

 As a rule of thumb:
-- Use the `_lm` models for segmenting cells or nuclei in light microscopy.
-- Use the `_em` models for segmenting cells or neurites in electron microscopy.
-  - Note that this model does not work well for segmenting mitochondria or other organelles because it is biased towards segmenting the full cell / cellular compartment.
-For other cases use the default models.
+- Use the `vit_b_lm` model for segmenting cells or nuclei in light microscopy.
+- Use the `vit_b_em_organelles` model for segmenting mitochondria, nuclei or other organelles in electron microscopy.
+- Use the `vit_b_em_boundaries` model for segmenting cells or neurites in electron microscopy.
+- For other use-cases use one of the default models.

 See also the figures above for examples where the finetuned models work better than the vanilla models.

 Currently the model `vit_h` is used by default.

-We are working on releasing more fine-tuned models, in particular for mitochondria and other organelles in EM.
+We are working on further improving these models and adding new models for other biomedical imaging domains.
+
+
+## Model Sources
+
+Here is an overview of all finetuned models we have released to zenodo so far:
+- [vit_b_em_boundaries](https://zenodo.org/records/10524894): for segmenting compartments delineated by boundaries such as cells or neurites in EM.
+- [vit_b_em_organelles](https://zenodo.org/records/10524828): for segmenting mitochondria, nuclei or other organelles in EM.
+- [vit_b_lm](https://zenodo.org/records/10524791): for segmenting cells and nuclei in LM.
+- [vit_h_em](https://zenodo.org/records/8250291): this model is outdated.
+- [vit_h_lm](https://zenodo.org/records/8250299): this model is outdated.
+
+Some of these models contain multiple versions.
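The model selection documented above comes down to passing the model name as the `model_type` argument. A minimal sketch of this usage, based on the 2d annotator shown later in this patch (the image and embedding paths here are placeholders):

import imageio.v3 as imageio

from micro_sam.sam_annotator import annotator_2d

# Load a 2d light microscopy image (placeholder path).
image = imageio.imread("example_image.tif")

# Image embeddings are cached at this (placeholder) path, so they are only computed once.
embedding_path = "./embeddings/example-vit_b_lm.zarr"

# Passing the model name via `model_type` selects the finetuned light microscopy model.
annotator_2d(image, embedding_path, model_type="vit_b_lm")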
diff --git a/doc/images/model-type-selector.png b/doc/images/model-type-selector.png index cab3b077..9bc5a434 100644 Binary files a/doc/images/model-type-selector.png and b/doc/images/model-type-selector.png differ diff --git a/environment_cpu.yaml b/environment_cpu.yaml index ef1fb2c3..29a56b06 100644 --- a/environment_cpu.yaml +++ b/environment_cpu.yaml @@ -8,13 +8,13 @@ dependencies: - napari - pip - pooch + - python-xxhash - python-elf >=0.4.8 - pytorch - segment-anything - torchvision - - torch_em >=0.5.1 + - torch_em >=0.6.0 - tqdm - timm - pip: - git+https://github.com/ChaoningZhang/MobileSAM.git - # - git+https://github.com/facebookresearch/segment-anything.git diff --git a/environment_gpu.yaml b/environment_gpu.yaml index 900b6bf5..3fc4990c 100644 --- a/environment_gpu.yaml +++ b/environment_gpu.yaml @@ -8,14 +8,14 @@ dependencies: - napari - pip - pooch + - python-xxhash - python-elf >=0.4.8 - pytorch - pytorch-cuda>=11.7 # you may need to update the cuda version to match your system - segment-anything - torchvision - - torch_em >=0.5.1 + - torch_em >=0.6.0 - tqdm - timm - pip: - git+https://github.com/ChaoningZhang/MobileSAM.git - # - git+https://github.com/facebookresearch/segment-anything.git diff --git a/examples/annotator_2d.py b/examples/annotator_2d.py index 8f4930f2..6ab15870 100644 --- a/examples/annotator_2d.py +++ b/examples/annotator_2d.py @@ -1,6 +1,13 @@ +import os + import imageio.v3 as imageio from micro_sam.sam_annotator import annotator_2d from micro_sam.sample_data import fetch_hela_2d_example_data, fetch_livecell_example_data, fetch_wholeslide_example_data +from micro_sam.util import get_cache_directory + +DATA_CACHE = os.path.join(get_cache_directory(), "sample_data") +EMBEDDING_CACHE = os.path.join(get_cache_directory(), "embeddings") +os.makedirs(EMBEDDING_CACHE, exist_ok=True) def livecell_annotator(use_finetuned_model): @@ -8,14 +15,14 @@ def livecell_annotator(use_finetuned_model): See https://doi.org/10.1038/s41592-021-01249-6 for details on the data. """ - example_data = fetch_livecell_example_data("./data") + example_data = fetch_livecell_example_data(DATA_CACHE) image = imageio.imread(example_data) if use_finetuned_model: - embedding_path = "./embeddings/embeddings-livecell-vit_h_lm.zarr" - model_type = "vit_h_lm" + embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell-vit_b_lm.zarr") + model_type = "vit_b_lm" else: - embedding_path = "./embeddings/embeddings-livecell.zarr" + embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell.zarr") model_type = "vit_h" annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type) @@ -24,14 +31,14 @@ def livecell_annotator(use_finetuned_model): def hela_2d_annotator(use_finetuned_model): """Run the 2d annotator for an example image form the cell tracking challenge HeLa 2d dataset. 
""" - example_data = fetch_hela_2d_example_data("./data") + example_data = fetch_hela_2d_example_data(DATA_CACHE) image = imageio.imread(example_data) if use_finetuned_model: - embedding_path = "./embeddings/embeddings-hela2d-vit_h_lm.zarr" - model_type = "vit_h_lm" + embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-hela2d-vit_b_lm.zarr") + model_type = "vit_b_lm" else: - embedding_path = "./embeddings/embeddings-hela2d.zarr" + embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-hela2d.zarr") model_type = "vit_h" annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type, precompute_amg_state=True) @@ -43,29 +50,28 @@ def wholeslide_annotator(use_finetuned_model): See https://neurips22-cellseg.grand-challenge.org/ for details on the data. """ - example_data = fetch_wholeslide_example_data("./data") + example_data = fetch_wholeslide_example_data(DATA_CACHE) image = imageio.imread(example_data) if use_finetuned_model: - embedding_path = "./embeddings/whole-slide-embeddings-vit_h_lm.zarr" - model_type = "vit_h_lm" + embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings-vit_b_lm.zarr") + model_type = "vit_b_lm" else: - embedding_path = "./embeddings/whole-slide-embeddings.zarr" + embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings.zarr") model_type = "vit_h" annotator_2d(image, embedding_path, tile_shape=(1024, 1024), halo=(256, 256), model_type=model_type) def main(): - # whether to use the fine-tuned SAM model - # this feature is still experimental! - use_finetuned_model = False + # Whether to use the fine-tuned SAM model for light microscopy data. + use_finetuned_model = True # 2d annotator for livecell data - # livecell_annotator(use_finetuned_model) + livecell_annotator(use_finetuned_model) # 2d annotator for cell tracking challenge hela data - hela_2d_annotator(use_finetuned_model) + # hela_2d_annotator(use_finetuned_model) # 2d annotator for a whole slide image # wholeslide_annotator(use_finetuned_model) diff --git a/examples/annotator_3d.py b/examples/annotator_3d.py index a279b8de..25e58cfd 100644 --- a/examples/annotator_3d.py +++ b/examples/annotator_3d.py @@ -1,33 +1,45 @@ +import os + from elf.io import open_file from micro_sam.sam_annotator import annotator_3d from micro_sam.sample_data import fetch_3d_example_data +from micro_sam.util import get_cache_directory + +DATA_CACHE = os.path.join(get_cache_directory(), "sample_data") +EMBEDDING_CACHE = os.path.join(get_cache_directory(), "embeddings") +os.makedirs(EMBEDDING_CACHE, exist_ok=True) -def em_3d_annotator(use_finetuned_model): +def em_3d_annotator(finetuned_model): """Run the 3d annotator for an example EM volume.""" # download the example data - example_data = fetch_3d_example_data("./data") + example_data = fetch_3d_example_data(DATA_CACHE) # load the example data (load the sequence of tif files as 3d volume) with open_file(example_data) as f: raw = f["*.png"][:] - if use_finetuned_model: - embedding_path = "./embeddings/embeddings-lucchi-vit_h_em.zarr" - model_type = "vit_h_em" - else: - embedding_path = "./embeddings/embeddings-lucchi.zarr" + if not finetuned_model: + embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-lucchi.zarr") model_type = "vit_h" + else: + assert finetuned_model in ("organelles", "boundaries") + embedding_path = os.path.join(EMBEDDING_CACHE, f"embeddings-lucchi-vit_b_em_{finetuned_model}.zarr") + model_type = f"vit_b_em_{finetuned_model}" + print(embedding_path) # start the annotator, cache the embeddings - 
annotator_3d(raw, embedding_path, model_type=model_type, show_embeddings=False) + annotator_3d(raw, embedding_path, model_type=model_type) def main(): - # whether to use the fine-tuned SAM model - # this feature is still experimental! - use_finetuned_model = False - - em_3d_annotator(use_finetuned_model) + # Whether to use the fine-tuned SAM model for mitochondria (organelles) or boundaries. + # valid choices are: + # - None / False (will use the vanilla model) + # - "organelles": will use the model for mitochondria and other organelles + # - "boundaries": will use the model for boundary based structures + finetuned_model = "boundaries" + + em_3d_annotator(finetuned_model) if __name__ == "__main__": diff --git a/examples/annotator_tracking.py b/examples/annotator_tracking.py index 89f5cdb8..7c810e23 100644 --- a/examples/annotator_tracking.py +++ b/examples/annotator_tracking.py @@ -1,32 +1,38 @@ +import os + from elf.io import open_file from micro_sam.sam_annotator import annotator_tracking from micro_sam.sample_data import fetch_tracking_example_data +from micro_sam.util import get_cache_directory + +DATA_CACHE = os.path.join(get_cache_directory(), "sample_data") +EMBEDDING_CACHE = os.path.join(get_cache_directory(), "embeddings") +os.makedirs(EMBEDDING_CACHE, exist_ok=True) def track_ctc_data(use_finetuned_model): """Run interactive tracking for data from the cell tracking challenge. """ # download the example data - example_data = fetch_tracking_example_data("./data") + example_data = fetch_tracking_example_data(DATA_CACHE) # load the example data (load the sequence of tif files as timeseries) with open_file(example_data, mode="r") as f: timeseries = f["*.tif"] if use_finetuned_model: - embedding_path = "./embeddings/embeddings-ctc-vit_h_lm.zarr" - model_type = "vit_h_lm" + embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc-vit_b_lm.zarr") + model_type = "vit_b_lm" else: - embedding_path = "./embeddings/embeddings-ctc.zarr" + embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc.zarr") model_type = "vit_h" # start the annotator with cached embeddings - annotator_tracking(timeseries, embedding_path=embedding_path, show_embeddings=False, model_type=model_type) + annotator_tracking(timeseries, embedding_path=embedding_path, model_type=model_type) def main(): - # whether to use the fine-tuned SAM model - # this feature is still experimental! - use_finetuned_model = False + # Whether to use the fine-tuned SAM model. 
+ use_finetuned_model = True track_ctc_data(use_finetuned_model) diff --git a/examples/annotator_with_custom_model.py b/examples/annotator_with_custom_model.py index ceb8b2cb..e39a8c11 100644 --- a/examples/annotator_with_custom_model.py +++ b/examples/annotator_with_custom_model.py @@ -1,8 +1,24 @@ +import os + +import imageio import h5py import micro_sam.sam_annotator as annotator + from micro_sam.util import get_sam_model +from micro_sam.util import get_cache_directory +from micro_sam.sample_data import fetch_hela_2d_example_data + + +DATA_CACHE = os.path.join(get_cache_directory(), "sample_data") + + +def annotator_2d_with_custom_model(): + example_data = fetch_hela_2d_example_data(DATA_CACHE) + image = imageio.imread(example_data) -# TODO add an example for the 2d annotator with a custom model + custom_model = "/home/pape/Downloads/exported_models/vit_b_lm.pth" + predictor = get_sam_model(checkpoint_path=custom_model, model_type="vit_b") + annotator.annotator_2d(image, predictor=predictor) def annotator_3d_with_custom_model(): @@ -16,7 +32,8 @@ def annotator_3d_with_custom_model(): def main(): - annotator_3d_with_custom_model() + annotator_2d_with_custom_model() + # annotator_3d_with_custom_model() if __name__ == "__main__": diff --git a/examples/finetuning/finetune_hela.py b/examples/finetuning/finetune_hela.py index 58a34b51..e9726d2c 100644 --- a/examples/finetuning/finetune_hela.py +++ b/examples/finetuning/finetune_hela.py @@ -58,7 +58,7 @@ def get_dataloader(split, patch_shape, batch_size): patch_shape=patch_shape, batch_size=batch_size, ndim=2, is_seg_dataset=True, rois=roi, label_transform=torch_em.transform.label.connected_components, - num_workers=8, shuffle=True, + num_workers=8, shuffle=True, raw_transform=sam_training.identity, ) return loader diff --git a/examples/image_series_annotator.py b/examples/image_series_annotator.py index 1632d793..456ea727 100644 --- a/examples/image_series_annotator.py +++ b/examples/image_series_annotator.py @@ -1,5 +1,12 @@ +import os + from micro_sam.sam_annotator import image_folder_annotator from micro_sam.sample_data import fetch_image_series_example_data +from micro_sam.util import get_cache_directory + +DATA_CACHE = os.path.join(get_cache_directory(), "sample_data") +EMBEDDING_CACHE = os.path.join(get_cache_directory(), "embeddings") +os.makedirs(EMBEDDING_CACHE, exist_ok=True) def series_annotation(use_finetuned_model): @@ -7,23 +14,22 @@ def series_annotation(use_finetuned_model): """ if use_finetuned_model: - embedding_path = "./embeddings/series-embeddings-vit_h_lm" - model_type = "vit_h_lm" + embedding_path = os.path.join(EMBEDDING_CACHE, "series-embeddings-vit_b_lm") + model_type = "vit_b_lm" else: - embedding_path = "./embeddings/series-embeddings" + embedding_path = os.path.join(EMBEDDING_CACHE, "series-embeddings") model_type = "vit_h" - example_data = fetch_image_series_example_data("./data") + example_data = fetch_image_series_example_data(DATA_CACHE) image_folder_annotator( - example_data, "./data/series-segmentation-result", embedding_path=embedding_path, + example_data, "./series-segmentation-result", embedding_path=embedding_path, pattern="*.tif", model_type=model_type, precompute_amg_state=True, ) def main(): - # whether to use the fine-tuned SAM model - # this feature is still experimental! + # Whether to use the fine-tuned SAM model. 
use_finetuned_model = False series_annotation(use_finetuned_model) diff --git a/finetuning/.gitignore b/finetuning/.gitignore index d6643f86..b078b91d 100644 --- a/finetuning/.gitignore +++ b/finetuning/.gitignore @@ -2,4 +2,5 @@ checkpoints/ logs/ sam_embeddings/ results/ -*.sh \ No newline at end of file +iterative_prompting_results/ +*.sh diff --git a/finetuning/generalists/training/electron_microscopy/boundaries/obtain_boundaries_em_datasets.py b/finetuning/generalists/training/electron_microscopy/boundaries/obtain_boundaries_em_datasets.py new file mode 100644 index 00000000..b46744bc --- /dev/null +++ b/finetuning/generalists/training/electron_microscopy/boundaries/obtain_boundaries_em_datasets.py @@ -0,0 +1,136 @@ +import os +import numpy as np + +from elf.io import open_file + +from skimage.measure import label +from skimage.segmentation import watershed +from scipy.ndimage import distance_transform_edt + +from torch_em import get_data_loader +from torch_em.transform.label import PerObjectDistanceTransform +from torch_em.data import ConcatDataset, MinInstanceSampler, datasets + +from micro_sam.training import identity +from micro_sam.training.util import ResizeRawTrafo, ResizeLabelTrafo + + +def compute_platy_rois(root, sample_ids, ignore_label, file_template, label_key): + rois = {} + for sample_id in sample_ids: + path = os.path.join(root, (file_template % sample_id)) + with open_file(path, "r") as f: + labels = f[label_key][:] + valid_coordinates = np.where(labels != ignore_label) + roi = tuple(slice( + int(coord.min()), int(coord.max()) + 1 + ) for coord in valid_coordinates) + rois[sample_id] = roi + return rois + + +def axondeepseg_label_trafo(labels): + # after checking, labels look like this : 0 is bg, 1 is myelins and 2 is axons + foreground_seeds = label((labels == 2)) + boundary_prediction = (labels == 1) + + # use the distance to the myelinated axons as height map to assign pixels to nearest myelinated axon + hmap = distance_transform_edt(labels != 2) + seg = watershed(image=hmap, markers=foreground_seeds, mask=(foreground_seeds + boundary_prediction) > 0) + + dist_trafo = PerObjectDistanceTransform( + distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=0 + ) + seg = dist_trafo(seg) + return seg + + +def _check_dataset_available_for_rois(path, patch_shape): + """This function checks whether or not all the expected datasets are available, else downloads them + We do this only for "platynereis - cells", "cremi" datasets - as we expect specific RoIs only from them + """ + datasets.get_cremi_dataset(path=os.path.join(path, "cremi"), patch_shape=patch_shape, download=True) + datasets.get_platynereis_cell_dataset( + path=os.path.join(path, "platynereis"), patch_shape=patch_shape, download=True + ) + print("All the datasets are available for RoI splitting") + + +def get_concat_boundaries_datasets(input_path, patch_shape): + _check_dataset_available_for_rois(path=input_path, patch_shape=patch_shape) + + sampler = MinInstanceSampler() + standard_label_trafo = PerObjectDistanceTransform( + distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=0 + ) + + # cremi dataset parameters + cremi_train_rois = {"A": np.s_[0:75, :, :], "B": np.s_[0:75, :, :], "C": np.s_[0:75, :, :]} + cremi_val_rois = {"A": np.s_[75:100, :, :], "B": np.s_[75:100, :, :], "C": np.s_[75:100, :, :]} + + # platynereis cell dataset parameters + platy_root = os.path.join(input_path, "platynereis") + 
platy_cell_template = "membrane/train_data_membrane_%02i.n5" + platy_cell_label_key = "volumes/labels/segmentation/s1" + + platy_cell_train_samples = [1, 2, 3, 4, 5, 6] + platy_cell_train_rois = compute_platy_rois(platy_root, platy_cell_train_samples, ignore_label=0, + file_template=platy_cell_template, label_key=platy_cell_label_key) + platy_cell_val_samples = [7, 8] + platy_cell_val_rois = compute_platy_rois(platy_root, platy_cell_val_samples, ignore_label=0, + file_template=platy_cell_template, label_key=platy_cell_label_key) + + def cremi_dataset(rois, n_samples): + return datasets.get_cremi_dataset( + path=os.path.join(input_path, "cremi"), patch_shape=patch_shape, label_transform=standard_label_trafo, + n_samples=n_samples, rois=rois, sampler=sampler, ndim=2, defect_augmentation_kwargs=None, + download=True, raw_transform=identity + ) + + cremi_train_dataset = cremi_dataset(cremi_train_rois, n_samples=750) # taking ~50% of all training samples + cremi_val_dataset = cremi_dataset(cremi_val_rois, n_samples=250) # # taking ~50% of all val samples + + def platy_cell_dataset(rois, sample_ids): + return datasets.get_platynereis_cell_dataset( + path=platy_root, sample_ids=sample_ids, patch_shape=patch_shape, download=True, sampler=sampler, + ndim=2, rois=rois, raw_transform=ResizeRawTrafo(patch_shape[1:], do_rescaling=False), + label_transform=ResizeLabelTrafo(patch_shape[1:]) + ) + + platy_cell_train_dataset = platy_cell_dataset(platy_cell_train_rois, platy_cell_train_samples) + platy_cell_val_dataset = platy_cell_dataset(platy_cell_val_rois, platy_cell_val_samples) + + def axondeepseg_dataset(split): + # train is oversampled by ~10 times and val by ~15 times + n_samples = 500 if split == "train" else 100 + return datasets.get_axondeepseg_dataset( + path=os.path.join(input_path, "axondeepseg"), name=["sem"], patch_shape=patch_shape[1:], + label_transform=axondeepseg_label_trafo, sampler=sampler, split=split, + raw_transform=identity, download=True, val_fraction=0.1, n_samples=n_samples + ) + + axondeepseg_train_dataset = axondeepseg_dataset("train") + axondeepseg_val_dataset = axondeepseg_dataset("val") + + train_datasets = [cremi_train_dataset, platy_cell_train_dataset, axondeepseg_train_dataset] + val_datasets = [cremi_val_dataset, platy_cell_val_dataset, axondeepseg_val_dataset] + + generalist_em_train_dataset = ConcatDataset(*train_datasets) + generalist_em_val_dataset = ConcatDataset(*val_datasets) + + return generalist_em_train_dataset, generalist_em_val_dataset + + +def get_generalist_boundaries_loaders(input_path, patch_shape): + """This returns the concatenated electron microscopy datasets implemented in torch_em: + https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets + It will automatically download all the datasets + NOTE: to remove / replace the datasets with another dataset, you need to add the datasets (for train and val splits) + in `get_concat_lm_dataset`. The labels have to be in a label mask instance segmentation format. + i.e. the tensors (inputs & masks) should be of same spatial shape, with each object in the mask having it's own ID. + IMPORTANT: the ID 0 is reserved for background, and the IDs must be consecutive. 
+ """ + generalist_train_dataset, generalist_val_dataset = get_concat_boundaries_datasets(input_path, patch_shape) + train_loader = get_data_loader(generalist_train_dataset, batch_size=2, num_workers=16, shuffle=True) + val_loader = get_data_loader(generalist_val_dataset, batch_size=1, num_workers=16, shuffle=True) + return train_loader, val_loader diff --git a/finetuning/generalists/training/electron_microscopy/boundaries/train_boundaries_em_generalist.py b/finetuning/generalists/training/electron_microscopy/boundaries/train_boundaries_em_generalist.py new file mode 100644 index 00000000..64400ca3 --- /dev/null +++ b/finetuning/generalists/training/electron_microscopy/boundaries/train_boundaries_em_generalist.py @@ -0,0 +1,129 @@ +import os +import argparse + +import torch + +from torch_em.model import UNETR +from torch_em.loss import DiceBasedDistanceLoss + +import micro_sam.training as sam_training +from micro_sam.util import export_custom_sam_model + +from obtain_boundaries_em_datasets import get_generalist_boundaries_loaders + + +def finetune_boundaries_em_generalist(args): + """Code for finetuning SAM on boundary structures in electron microscopy datasets""" + # override this (below) if you have some more complex set-up and need to specify the exact gpu + device = "cuda" if torch.cuda.is_available() else "cpu" + + # training settings: + model_type = args.model_type + checkpoint_path = None # override this to start training from a custom checkpoint + patch_shape = (1, 512, 512) # the patch shape for training + n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled + freeze_parts = None # override this to freeze one or more of these backbones + + # get the trainable segment anything model + model = sam_training.get_trainable_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + freeze=freeze_parts + ) + model.to(device) + + # let's get the UNETR model for automatic instance segmentation pipeline + unetr = UNETR( + backbone="sam", + encoder=model.sam.image_encoder, + out_channels=3, + use_sam_stats=True, + final_activation="Sigmoid", + use_skip_connection=False, + resize_input=True + ) + unetr.to(device) + + # let's get the parameters for SAM and the decoder from UNETR + joint_model_params = [params for params in model.parameters()] # sam parameters + for name, params in unetr.named_parameters(): # unetr's decoder parameters + if not name.startswith("encoder"): + joint_model_params.append(params) + + # all the stuff we need for training + optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=20, verbose=True) + train_loader, val_loader = get_generalist_boundaries_loaders(input_path=args.input_path, patch_shape=patch_shape) + + # this class creates all the training data for a batch (inputs, prompts and labels) + convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) + + checkpoint_name = f"{args.model_type}/boundaries_em_generalist_sam" + + # the trainer which performs the joint training and validation (implemented using "torch_em") + trainer = sam_training.JointSamTrainer( + name=checkpoint_name, + save_root=args.save_root, + train_loader=train_loader, + val_loader=val_loader, + model=model, + optimizer=optimizer, + device=device, + lr_scheduler=scheduler, + logger=sam_training.JointSamLogger, + log_image_interval=100, + mixed_precision=True, + 
convert_inputs=convert_inputs, + n_objects_per_batch=n_objects_per_batch, + n_sub_iteration=8, + compile_model=False, + mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training + unetr=unetr, + instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), + instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True) + ) + trainer.fit(args.iterations, save_every_kth_epoch=args.save_every_kth_epoch) + if args.export_path is not None: + checkpoint_path = os.path.join( + "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" + ) + export_custom_sam_model( + checkpoint_path=checkpoint_path, + model_type=model_type, + save_path=args.export_path, + ) + + +def main(): + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the Boundary Structures EM datasets.") + parser.add_argument( + "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/", + help="The filepath to all the respective EM datasets. If the data does not exist yet it will be downloaded" + ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_b, vit_l or vit_h." + ) + parser.add_argument( + "--save_root", "-s", + help="Where to save the checkpoint and logs. By default they will be saved where this script is run from." + ) + parser.add_argument( + "--iterations", type=int, default=int(25e4), + help="For how many iterations should the model be trained? By default 250k." + ) + parser.add_argument( + "--export_path", "-e", + help="Where to export the finetuned model to. The exported model can be used in the annotation tools." + ) + parser.add_argument( + "--save_every_kth_epoch", type=int, default=None, + help="To save every kth epoch while fine-tuning. Expects an integer value." 
+ ) + args = parser.parse_args() + finetune_boundaries_em_generalist(args) + + +if __name__ == "__main__": + main() diff --git a/finetuning/generalists/training/electron_microscopy/mito_nuc/obtain_mito_nuc_em_datasets.py b/finetuning/generalists/training/electron_microscopy/mito_nuc/obtain_mito_nuc_em_datasets.py new file mode 100644 index 00000000..95bc6fc4 --- /dev/null +++ b/finetuning/generalists/training/electron_microscopy/mito_nuc/obtain_mito_nuc_em_datasets.py @@ -0,0 +1,127 @@ +import os +import numpy as np + +from elf.io import open_file + +from torch_em import get_data_loader +from torch_em.transform.label import PerObjectDistanceTransform +from torch_em.data import ConcatDataset, MinInstanceSampler, datasets + +from micro_sam.training import identity +from micro_sam.training.util import ResizeRawTrafo, ResizeLabelTrafo + + +def compute_platy_rois(root, sample_ids, ignore_label, file_template, label_key): + rois = {} + for sample_id in sample_ids: + path = os.path.join(root, (file_template % sample_id)) + with open_file(path, "r") as f: + labels = f[label_key][:] + valid_coordinates = np.where(labels != ignore_label) + roi = tuple(slice( + int(coord.min()), int(coord.max()) + 1 + ) for coord in valid_coordinates) + rois[sample_id] = roi + return rois + + +def _check_dataset_available_for_rois(path, patch_shape): + """This function checks whether or not all the expected datasets are available, else downloads them + We do this only for "platynereis - nuclei", "mitoem" datasets - as we expect specific RoIs only from them + """ + datasets.get_mitoem_dataset( + path=os.path.join(path, "mitoem"), patch_shape=patch_shape, download=True, splits="train" + ) + datasets.get_platynereis_nuclei_dataset( + path=os.path.join(path, "platynereis"), patch_shape=patch_shape, download=True + ) + print("All the datasets are available for RoI splitting") + + +def get_concat_mito_nuc_datasets(input_path, patch_shape, with_cem=False): + _check_dataset_available_for_rois(path=input_path, patch_shape=patch_shape) + + sampler = MinInstanceSampler() + standard_label_trafo = PerObjectDistanceTransform( + distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=0 + ) + + # mitoem parameters + mitoem_train_rois = [np.s_[100:110, :, :], np.s_[100:110, :, :]] + mitoem_val_rois = [np.s_[0:5, :, :], np.s_[0:5, :, :]] + + # platynereis nuclei dataset parameters + platy_root = os.path.join(input_path, "platynereis") + platy_nuclei_template = "nuclei/train_data_nuclei_%02i.h5" + platy_nuclei_label_key = "volumes/labels/nucleus_instance_labels" + + platy_nuclei_train_samples = [1, 2, 3, 4, 5, 6, 7, 8] + platy_nuclei_train_rois = compute_platy_rois(platy_root, platy_nuclei_train_samples, ignore_label=-1, + file_template=platy_nuclei_template, label_key=platy_nuclei_label_key) + platy_nuclei_val_samples = [9, 10] + platy_nuclei_val_rois = compute_platy_rois(platy_root, platy_nuclei_val_samples, ignore_label=-1, + file_template=platy_nuclei_template, label_key=platy_nuclei_label_key) + + def mitoem_dataset(split, roi_choice): + return datasets.get_mitoem_dataset( + path=os.path.join(input_path, "mitoem"), splits=split, download=True, patch_shape=patch_shape, + rois=roi_choice, label_transform=standard_label_trafo, ndim=2, raw_transform=identity, + sampler=MinInstanceSampler(min_num_instances=5) + ) + + mitoem_train_dataset = mitoem_dataset("train", mitoem_train_rois) + mitoem_val_dataset = mitoem_dataset("val", mitoem_val_rois) + + def platy_nuclei_dataset(roi_choice, 
sample_ids): + return datasets.get_platynereis_nuclei_dataset( + path=platy_root, patch_shape=patch_shape, download=True, sampler=sampler, ndim=2, + label_transform=ResizeLabelTrafo(patch_shape[1:]), rois=roi_choice, + raw_transform=ResizeRawTrafo(patch_shape[1:], do_rescaling=False), sample_ids=sample_ids + ) + + platy_nuclei_train_dataset = platy_nuclei_dataset(platy_nuclei_train_rois, sample_ids=platy_nuclei_train_samples) + platy_nuclei_val_dataset = platy_nuclei_dataset(platy_nuclei_val_rois, sample_ids=platy_nuclei_val_samples) + + train_datasets = [mitoem_train_dataset, platy_nuclei_train_dataset] + val_datasets = [mitoem_val_dataset, platy_nuclei_val_dataset] + + if with_cem: + def cem_dataset(split): + # 10% of the total training set, 1/3 of the total val set + n_samples = 1620 if split == "train" else 600 + return datasets.cem.get_mitolab_dataset( + path=os.path.join(input_path, "mitolab"), split=split, val_fraction=0.1, sampler=sampler, + raw_transform=ResizeRawTrafo(patch_shape[1:], do_rescaling=False), patch_shape=patch_shape[1:], + label_transform=ResizeLabelTrafo(patch_shape[1:]), n_samples=n_samples + ) + + train_datasets.append(cem_dataset("train")) + val_datasets.append(cem_dataset("val")) + + for train_dataset in train_datasets: + train_dataset.max_sampling_attempts = 5000 + + for val_dataset in val_datasets: + val_dataset.max_sampling_attempts = 5000 + + generalist_em_train_dataset = ConcatDataset(*train_datasets) + generalist_em_val_dataset = ConcatDataset(*val_datasets) + + return generalist_em_train_dataset, generalist_em_val_dataset + + +def get_generalist_mito_nuc_loaders(input_path, patch_shape, with_cem=False): + """This returns the concatenated electron microscopy datasets implemented in torch_em: + https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets + It will automatically download all the datasets + NOTE: to remove / replace the datasets with another dataset, you need to add the datasets (for train and val splits) + in `get_concat_lm_dataset`. The labels have to be in a label mask instance segmentation format. + i.e. the tensors (inputs & masks) should be of same spatial shape, with each object in the mask having it's own ID. + IMPORTANT: the ID 0 is reserved for background, and the IDs must be consecutive. 
+ """ + generalist_train_dataset, generalist_val_dataset = get_concat_mito_nuc_datasets( + input_path, patch_shape, with_cem=with_cem + ) + train_loader = get_data_loader(generalist_train_dataset, batch_size=2, shuffle=True, num_workers=16) + val_loader = get_data_loader(generalist_val_dataset, batch_size=1, shuffle=True, num_workers=16) + return train_loader, val_loader diff --git a/finetuning/generalists/training/electron_microscopy/mito_nuc/train_mito_nuc_em_generalist.py b/finetuning/generalists/training/electron_microscopy/mito_nuc/train_mito_nuc_em_generalist.py new file mode 100644 index 00000000..d73555f0 --- /dev/null +++ b/finetuning/generalists/training/electron_microscopy/mito_nuc/train_mito_nuc_em_generalist.py @@ -0,0 +1,137 @@ +import os +import argparse + +import torch + +from torch_em.model import UNETR +from torch_em.loss import DiceBasedDistanceLoss + +import micro_sam.training as sam_training +from micro_sam.util import export_custom_sam_model + +from obtain_mito_nuc_em_datasets import get_generalist_mito_nuc_loaders + + +def finetune_mito_nuc_em_generalist(args): + """Code for finetuning SAM on mitochondria and nuclei electron microscopy datasets""" + # override this (below) if you have some more complex set-up and need to specify the exact gpu + device = "cuda" if torch.cuda.is_available() else "cpu" + + # training settings: + model_type = args.model_type + checkpoint_path = None # override this to start training from a custom checkpoint + patch_shape = (1, 512, 512) # the patch shape for training + n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled + freeze_parts = None # override this to freeze one or more of these backbones + + # get the trainable segment anything model + model = sam_training.get_trainable_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + freeze=freeze_parts + ) + model.to(device) + + # let's get the UNETR model for automatic instance segmentation pipeline + unetr = UNETR( + backbone="sam", + encoder=model.sam.image_encoder, + out_channels=3, + use_sam_stats=True, + final_activation="Sigmoid", + use_skip_connection=False, + resize_input=True + ) + unetr.to(device) + + # let's get the parameters for SAM and the decoder from UNETR + joint_model_params = [params for params in model.parameters()] # sam parameters + for name, params in unetr.named_parameters(): # unetr's decoder parameters + if not name.startswith("encoder"): + joint_model_params.append(params) + + # all the stuff we need for training + optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=15, verbose=True) + train_loader, val_loader = get_generalist_mito_nuc_loaders( + input_path=args.input_path, patch_shape=patch_shape, with_cem=args.with_cem + ) + + # this class creates all the training data for a batch (inputs, prompts and labels) + convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) + + checkpoint_name = f"{args.model_type}/" + checkpoint_name += "with_cem/" if args.with_cem else "without_cem/" + checkpoint_name += "mito_nuc_em_generalist_sam" + + # the trainer which performs the joint training and validation (implemented using "torch_em") + trainer = sam_training.JointSamTrainer( + name=checkpoint_name, + save_root=args.save_root, + train_loader=train_loader, + val_loader=val_loader, + model=model, + optimizer=optimizer, + device=device, + 
lr_scheduler=scheduler, + logger=sam_training.JointSamLogger, + log_image_interval=100, + mixed_precision=True, + convert_inputs=convert_inputs, + n_objects_per_batch=n_objects_per_batch, + n_sub_iteration=8, + compile_model=False, + mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training + unetr=unetr, + instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), + instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True) + ) + trainer.fit(args.iterations, save_every_kth_epoch=args.save_every_kth_epoch) + if args.export_path is not None: + checkpoint_path = os.path.join( + "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" + ) + export_custom_sam_model( + checkpoint_path=checkpoint_path, + model_type=model_type, + save_path=args.export_path, + ) + + +def main(): + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the Mito. & Nuclei EM datasets.") + parser.add_argument( + "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/", + help="The filepath to all the respective EM datasets. If the data does not exist yet it will be downloaded" + ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_b, vit_l or vit_h." + ) + parser.add_argument( + "--save_root", "-s", + help="Where to save the checkpoint and logs. By default they will be saved where this script is run from." + ) + parser.add_argument( + "--iterations", type=int, default=int(25e4), + help="For how many iterations should the model be trained? By default 250k." + ) + parser.add_argument( + "--export_path", "-e", + help="Where to export the finetuned model to. The exported model can be used in the annotation tools." + ) + parser.add_argument( + "--save_every_kth_epoch", type=int, default=None, + help="To save every kth epoch while fine-tuning. Expects an integer value." + ) + parser.add_argument( + "--with_cem", action="store_true", + help="To train the Mito-Nuc EM generalist using the MitoLab CEM dataset." 
+ ) + args = parser.parse_args() + finetune_mito_nuc_em_generalist(args) + + +if __name__ == "__main__": + main() diff --git a/finetuning/generalists/training/electron_microscopy/obtain_em_datasets.py b/finetuning/generalists/training/electron_microscopy/obtain_em_datasets.py deleted file mode 100644 index 9b78d011..00000000 --- a/finetuning/generalists/training/electron_microscopy/obtain_em_datasets.py +++ /dev/null @@ -1,211 +0,0 @@ -import os -import numpy as np -from math import ceil, floor - -from elf.io import open_file -from skimage.measure import label -from skimage.segmentation import watershed - -from torch_em import get_data_loader -from torch_em.transform.raw import standardize -from torch_em.transform.label import label_consecutive -from torch_em.data import ConcatDataset, MinInstanceSampler, datasets - - -def axondeepseg_label_trafo(labels): - # after checking, labels look like this : 0 is bg, 1 is myelins and 2 is axons - foreground_seeds = label((labels == 2)) - boundary_prediction = (labels == 1) - seg = watershed(boundary_prediction, markers=foreground_seeds, mask=(foreground_seeds + boundary_prediction) > 0) - seg = label_consecutive(seg) - return seg - - -def raw_trafo_for_padding(raw, desired_shape=(512, 512)): - raw = standardize(raw) - tmp_ddim = (desired_shape[0] - raw.shape[0], desired_shape[1] - raw.shape[1]) - ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2) - raw = np.pad(raw, - pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))), - mode="reflect") - assert raw.shape == desired_shape - return raw - - -def label_trafo_for_padding(labels, desired_shape=(512, 512)): - labels = label(labels) - labels = label_consecutive(labels) - tmp_ddim = (desired_shape[0] - labels.shape[0], desired_shape[1] - labels.shape[1]) - ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2) - labels = np.pad( - labels, - pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))), - mode="constant" - ) - assert labels.shape == desired_shape - return labels - - -def standard_label_trafo(labels): - labels = label(labels) - labels = label_consecutive(labels) - return labels - - -def compute_platy_rois(root, sample_ids, ignore_label, file_template, label_key): - rois = {} - for sample_id in sample_ids: - path = os.path.join(root, (file_template % sample_id)) - with open_file(path, "r") as f: - labels = f[label_key][:] - valid_coordinates = np.where(labels != ignore_label) - roi = tuple(slice( - int(coord.min()), int(coord.max()) + 1 - ) for coord in valid_coordinates) - rois[sample_id] = roi - return rois - - -def _check_dataset_available_for_rois(path, patch_shape): - """This function checks whether or not all the expected datasets are available, else downloads them - We do this only for "platynereis", "cremi" and "mitoem" datasets - as we expect specific RoIs only from them - """ - datasets.get_cremi_dataset(path=os.path.join(path, "cremi"), patch_shape=patch_shape, download=True) - datasets.get_mitoem_dataset(path=os.path.join(path, "mitoem"), patch_shape=patch_shape, download=True, splits="train") - datasets.get_platynereis_cell_dataset(path=os.path.join(path, "platynereis"), patch_shape=patch_shape, download=True) - datasets.get_platynereis_nuclei_dataset(path=os.path.join(path, "platynereis"), patch_shape=patch_shape, download=True) - datasets.get_platynereis_cilia_dataset(path=os.path.join(path, "platynereis"), patch_shape=patch_shape, download=True) - print("All the datasets are available for RoI splitting") - - -def get_concat_em_datasets(input_path, 
patch_shape): - _check_dataset_available_for_rois(path=input_path, patch_shape=patch_shape) - - sampler = MinInstanceSampler() - - # cremi dataset parameters - cremi_train_rois = {"A": np.s_[0:75, :, :], "B": np.s_[0:75, :, :], "C": np.s_[0:75, :, :]} - cremi_val_rois = {"A": np.s_[75:100, :, :], "B": np.s_[75:100, :, :], "C": np.s_[75:100, :, :]} - - # mitoem parameters - mitoem_train_rois = [np.s_[100:120, :, :], np.s_[100:120, :, :]] - mitoem_val_rois = [np.s_[0:20, :, :], np.s_[0:20, :, :]] - - # platynereis cell dataset parameters - platy_root = os.path.join(input_path, "platynereis") - platy_cell_template = "membrane/train_data_membrane_%02i.n5" - platy_cell_label_key = "volumes/labels/segmentation/s1" - - platy_cell_train_samples = [1, 2, 3, 4, 5, 6] - platy_cell_train_rois = compute_platy_rois(platy_root, platy_cell_train_samples, ignore_label=0, - file_template=platy_cell_template, label_key=platy_cell_label_key) - platy_cell_val_samples = [7, 8] - platy_cell_val_rois = compute_platy_rois(platy_root, platy_cell_val_samples, ignore_label=0, - file_template=platy_cell_template, label_key=platy_cell_label_key) - - # platynereis cilia dataset parameters - platy_cilia_template = "cilia/train_data_cilia_%02i.h5" - platy_cilia_label_key = "volumes/labels/segmentation" - - platy_cilia_train_samples = [1, 2] - platy_cilia_train_rois = compute_platy_rois(platy_root, platy_cilia_train_samples, ignore_label=-1, - file_template=platy_cilia_template, label_key=platy_cilia_label_key) - platy_cilia_val_samples = [3] - platy_cilia_val_rois = compute_platy_rois(platy_root, platy_cilia_val_samples, ignore_label=-1, - file_template=platy_cilia_template, label_key=platy_cilia_label_key) - - # platynereis nuclei dataset parameters - platy_nuclei_template = "nuclei/train_data_nuclei_%02i.h5" - platy_nuclei_label_key = "volumes/labels/nucleus_instance_labels" - - platy_nuclei_train_samples = [1, 2, 3, 4, 5, 6, 7, 8] - platy_nuclei_train_rois = compute_platy_rois(platy_root, platy_nuclei_train_samples, ignore_label=-1, - file_template=platy_nuclei_template, label_key=platy_nuclei_label_key) - platy_nuclei_val_samples = [9, 10] - platy_nuclei_val_rois = compute_platy_rois(platy_root, platy_nuclei_val_samples, ignore_label=-1, - file_template=platy_nuclei_template, label_key=platy_nuclei_label_key) - - generalist_em_train_dataset = ConcatDataset( - datasets.get_cremi_dataset( - path=os.path.join(input_path, "cremi"), patch_shape=patch_shape, download=True, - label_transform=standard_label_trafo, rois=cremi_train_rois, sampler=sampler, ndim=2, - defect_augmentation_kwargs=None - ), - datasets.get_platynereis_cell_dataset( - path=platy_root, sample_ids=platy_cell_train_samples, patch_shape=patch_shape, - download=True, sampler=sampler, ndim=2, label_transform=label_trafo_for_padding, - rois=platy_cell_train_rois, raw_transform=raw_trafo_for_padding - ), - datasets.get_platynereis_cilia_dataset( - path=platy_root, download=True, patch_shape=patch_shape, ndim=2, rois=platy_cilia_train_rois, - raw_transform=raw_trafo_for_padding, label_transform=label_trafo_for_padding, sampler=sampler, - sample_ids=platy_cilia_train_samples - ), - datasets.get_mitoem_dataset( - path=os.path.join(input_path, "mitoem"), splits="train", download=True, patch_shape=patch_shape, - rois=mitoem_train_rois, label_transform=standard_label_trafo, ndim=2, - sampler=MinInstanceSampler(min_num_instances=5) - ), - datasets.get_axondeepseg_dataset( - path=os.path.join(input_path, "axondeepseg"), name=["sem"], patch_shape=patch_shape[1:], 
download=True, - label_transform=axondeepseg_label_trafo, sampler=sampler, data_fraction=0.9, split="train" - ), - datasets.get_uro_cell_dataset( - path=os.path.join(input_path, "uro_cell"), target="mito", patch_shape=patch_shape, download=True, - sampler=sampler, label_transform=label_trafo_for_padding, ndim=2, raw_transform=raw_trafo_for_padding - ), - datasets.get_platynereis_nuclei_dataset( - path=platy_root, patch_shape=patch_shape, download=True, sampler=sampler, ndim=2, - label_transform=label_trafo_for_padding, rois=platy_nuclei_train_rois, raw_transform=raw_trafo_for_padding, - sample_ids=platy_nuclei_train_samples - ) - ) - - generalist_em_val_dataset = ConcatDataset( - datasets.get_cremi_dataset( - path=os.path.join(input_path, "cremi"), patch_shape=patch_shape, - download=True, label_transform=standard_label_trafo, rois=cremi_val_rois, - sampler=sampler, ndim=2, defect_augmentation_kwargs=None - ), - datasets.get_platynereis_cell_dataset( - path=platy_root, sample_ids=platy_cell_val_samples, patch_shape=patch_shape, - download=True, sampler=sampler, ndim=2, label_transform=label_trafo_for_padding, - rois=platy_cell_val_rois, raw_transform=raw_trafo_for_padding - ), - datasets.get_platynereis_cilia_dataset( - path=platy_root, download=True, patch_shape=patch_shape, ndim=2, rois=platy_cilia_val_rois, - raw_transform=raw_trafo_for_padding, label_transform=label_trafo_for_padding, sampler=sampler, - sample_ids=platy_cilia_val_samples - ), - datasets.get_mitoem_dataset( - path=os.path.join(input_path, "mitoem"), splits="val", download=True, patch_shape=patch_shape, - rois=mitoem_val_rois, label_transform=standard_label_trafo, ndim=2, - sampler=MinInstanceSampler(min_num_instances=5) - ), - datasets.get_axondeepseg_dataset( - path=os.path.join(input_path, "axondeepseg"), name=["sem"], patch_shape=patch_shape[1:], download=True, - label_transform=axondeepseg_label_trafo, sampler=sampler, data_fraction=0.1, split="val" - ), - datasets.get_platynereis_nuclei_dataset( - path=platy_root, patch_shape=patch_shape, download=True, sampler=sampler, ndim=2, - label_transform=label_trafo_for_padding, rois=platy_nuclei_val_rois, raw_transform=raw_trafo_for_padding, - sample_ids=platy_nuclei_val_samples - ) - ) - - return generalist_em_train_dataset, generalist_em_val_dataset - - -def get_generalist_em_loaders(input_path, patch_shape): - """This returns the concatenated electron microscopy datasets implemented in torch_em: - https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets - It will automatically download all the datasets - NOTE: to remove / replace the datasets with another dataset, you need to add the datasets (for train and val splits) - in `get_concat_lm_dataset`. The labels have to be in a label mask instance segmentation format. - i.e. the tensors (inputs & masks) should be of same spatial shape, with each object in the mask having it's own ID. - IMPORTANT: the ID 0 is reserved for background, and the IDs must be consecutive. 
- """ - generalist_train_dataset, generalist_val_dataset = get_concat_em_datasets(input_path, patch_shape) - train_loader = get_data_loader(generalist_train_dataset, batch_size=2, shuffle=True, num_workers=16) - val_loader = get_data_loader(generalist_val_dataset, batch_size=1, shuffle=True, num_workers=16) - return train_loader, val_loader diff --git a/finetuning/generalists/training/electron_microscopy/train_em_generalist.py b/finetuning/generalists/training/electron_microscopy/train_em_generalist.py deleted file mode 100644 index adb87568..00000000 --- a/finetuning/generalists/training/electron_microscopy/train_em_generalist.py +++ /dev/null @@ -1,97 +0,0 @@ -import os -import argparse - -import torch -from torch_em.loss import DiceLoss - -import micro_sam.training as sam_training -from micro_sam.util import export_custom_sam_model - -from obtain_em_datasets import get_generalist_em_loaders - - -def finetune_em_generalist(args): - """Example code for finetuning SAM on multiple electron microscopy datasets""" - # override this (below) if you have some more complex set-up and need to specify the exact gpu - device = "cuda" if torch.cuda.is_available() else "cpu" - - # training settings - model_type = args.model_type - checkpoint_path = None # override this to start training from a custom checkpoint - patch_shape = (1, 512, 512) # the patch shape for training - n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled - freeze_parts = None # override this to freeze one or more of these backbones - - # get the trainable segment anything model - model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, freeze_parts) - - # all stuff needed for training - optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=10, verbose=True) - train_loader, val_loader = get_generalist_em_loaders(input_path=args.input_path, patch_shape=patch_shape) - - # this class creates all the training data for a batch (inputs, prompts and labels) - convert_inputs = sam_training.ConvertToSamInputs() - - checkpoint_name = "generalist_em_sam" - # the trainer which performs training and validation (implemented using "torch-em") - trainer = sam_training.SamTrainer( - name=checkpoint_name, - save_root=args.save_root, - train_loader=train_loader, - val_loader=val_loader, - model=model, - optimizer=optimizer, - # currently we compute loss batch-wise, else we pass channelwise True - loss=DiceLoss(channelwise=False), - metric=DiceLoss(), - device=device, - lr_scheduler=scheduler, - logger=sam_training.SamLogger, - log_image_interval=100, - mixed_precision=True, - convert_inputs=convert_inputs, - n_objects_per_batch=n_objects_per_batch, - n_sub_iteration=8, - compile_model=False - ) - trainer.fit(iterations=args.iterations) - if args.export_path is not None: - checkpoint_path = os.path.join( - "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" - ) - export_custom_sam_model( - checkpoint_path=checkpoint_path, - model_type=model_type, - save_path=args.export_path, - ) - - -def main(): - parser = argparse.ArgumentParser(description="Finetune Segment Anything for the EM datasets.") - parser.add_argument( - "--input_path", "-i", default="/scratch/usr/nimanwai/data/", - help="The filepath to all the respective EM datasets. 
If the data does not exist yet it will be downloaded" - ) - parser.add_argument( - "--model_type", "-m", default="vit_b", - help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l." - ) - parser.add_argument( - "--save_root", "-s", - help="Where to save the checkpoint and logs. By default they will be saved where this script is run from." - ) - parser.add_argument( - "--iterations", type=int, default=int(1e5), - help="For how many iterations should the model be trained? By default 100k." - ) - parser.add_argument( - "--export_path", "-e", - help="Where to export the finetuned model to. The exported model can be used in the annotation tools." - ) - args = parser.parse_args() - finetune_em_generalist(args) - - -if __name__ == "__main__": - main() diff --git a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py index 3e08205e..474a913f 100644 --- a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py +++ b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py @@ -1,98 +1,77 @@ import os import numpy as np -from math import ceil, floor import torch import torch_em import torch_em.data.datasets as datasets -from torch_em.transform.label import label_consecutive from torch_em.data import MinInstanceSampler, ConcatDataset -from torch_em.transform.raw import standardize, normalize_percentile +from torch_em.transform.label import PerObjectDistanceTransform +from torch_em.transform.raw import normalize_percentile, normalize + +from micro_sam.training import identity +from micro_sam.training.util import ResizeRawTrafo, ResizeLabelTrafo def neurips_raw_trafo(raw): raw = datasets.neurips_cell_seg.to_rgb(raw) # ensures 3 channels for the neurips data raw = normalize_percentile(raw) raw = np.mean(raw, axis=0) - raw = standardize(raw) + raw = normalize(raw) + raw = raw * 255 return raw -def tissuenet_raw_trafo(raw, desired_shape=(512, 512)): - raw = normalize_percentile(raw, axis=(1, 2)) - raw = np.mean(raw, axis=0) - raw = standardize(raw) - - tmp_ddim = (desired_shape[0] - raw.shape[0], desired_shape[1] - raw.shape[1]) - ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2) - raw = np.pad(raw, pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))), mode="reflect") - assert raw.shape == desired_shape - return raw - - -def tissuenet_label_trafo(labels, desired_shape=(512, 512)): - labels = label_consecutive(labels) - - tmp_ddim = (desired_shape[0] - labels.shape[0], desired_shape[1] - labels.shape[1]) - ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2) - labels = np.pad( - labels, - pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))), - mode="constant" - ) - assert labels.shape == desired_shape - return labels - - -def raw_padding_trafo(raw, desired_shape=(512, 512)): - tmp_ddim = (desired_shape[0] - raw.shape[0], desired_shape[1] - raw.shape[1]) - ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2) - raw = np.pad(raw, pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))), mode="reflect") - assert raw.shape == desired_shape +def deepbacs_raw_trafo(raw): + raw = normalize(raw) + raw = raw * 255 return raw def get_concat_lm_datasets(input_path, patch_shape, split_choice): assert split_choice in ["train", "val"] - label_dtype = torch.int64 - label_transform = label_consecutive + label_dtype = torch.float32 + label_transform = PerObjectDistanceTransform( + distances=True, boundary_distances=True, directed_distances=False, 
foreground=True, instances=True, min_size=0 + ) sampler = MinInstanceSampler() generalist_dataset = ConcatDataset( datasets.get_tissuenet_dataset( - path=os.path.join(input_path, "tissuenet"), split=split_choice, download=True, - patch_shape=patch_shape if split_choice == "train" else (256, 256), raw_channel="rgb", - label_channel="cell", raw_transform=tissuenet_raw_trafo, label_transform=tissuenet_label_trafo, - sampler=sampler, label_dtype=label_dtype, n_samples=1000 if split_choice == "train" else 100 + path=os.path.join(input_path, "tissuenet"), split=split_choice, download=True, patch_shape=patch_shape, + raw_channel="rgb", label_channel="cell", sampler=sampler, label_dtype=label_dtype, + raw_transform=ResizeRawTrafo(patch_shape), label_transform=ResizeLabelTrafo(patch_shape, min_size=0), + n_samples=1000 if split_choice == "train" else 100 ), datasets.get_livecell_dataset( path=os.path.join(input_path, "livecell"), split=split_choice, patch_shape=patch_shape, - label_transform=label_transform, sampler=sampler, label_dtype=label_dtype, + label_transform=label_transform, sampler=sampler, label_dtype=label_dtype, raw_transform=identity, n_samples=1000 if split_choice == "train" else 100, download=True ), datasets.get_deepbacs_dataset( - path=os.path.join(input_path, "deepbacs"), split=split_choice if split_choice == "train" else "test", - patch_shape=patch_shape, label_transform=label_transform, sampler=sampler, label_dtype=label_dtype, - download=True + path=os.path.join(input_path, "deepbacs"), split=split_choice, patch_shape=patch_shape, + raw_transform=deepbacs_raw_trafo, label_transform=label_transform, label_dtype=label_dtype, + download=True, sampler=MinInstanceSampler(min_num_instances=4) ), datasets.get_neurips_cellseg_supervised_dataset( - root=os.path.join(input_path, "neurips-cell-seg"), split=split_choice, patch_shape=patch_shape, - raw_transform=neurips_raw_trafo, label_transform=label_transform, label_dtype=label_dtype, sampler=sampler + root=os.path.join(input_path, "neurips-cell-seg"), split=split_choice, + patch_shape=patch_shape, raw_transform=neurips_raw_trafo, label_transform=label_transform, + label_dtype=label_dtype, sampler=MinInstanceSampler(min_num_instances=3) ), datasets.get_dsb_dataset( path=os.path.join(input_path, "dsb"), split=split_choice if split_choice == "train" else "test", - patch_shape=(1, patch_shape[0], patch_shape[1]), label_transform=label_transform, sampler=sampler, - label_dtype=label_dtype, download=True + patch_shape=patch_shape, label_transform=label_transform, sampler=sampler, + label_dtype=label_dtype, download=True, raw_transform=identity ), datasets.get_plantseg_dataset( path=os.path.join(input_path, "plantseg"), name="root", sampler=MinInstanceSampler(min_num_instances=10), - label_transform=tissuenet_label_trafo, ndim=2, split=split_choice, label_dtype=label_dtype, - raw_transform=raw_padding_trafo, patch_shape=(1, patch_shape[0], patch_shape[1]), download=True, + patch_shape=(1, *patch_shape), download=True, split=split_choice, ndim=2, label_dtype=label_dtype, + raw_transform=ResizeRawTrafo(patch_shape, do_rescaling=False), + label_transform=ResizeLabelTrafo(patch_shape, min_size=0), n_samples=1000 if split_choice == "train" else 100 ) - ) + # increasing the sampling attempts for the neurips cellseg dataset generalist_dataset.datasets[3].max_sampling_attempts = 5000 return generalist_dataset diff --git a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py 
b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py index 799fde6b..72f358be 100644 --- a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py +++ b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py @@ -2,7 +2,9 @@ import argparse import torch -from torch_em.loss import DiceLoss + +from torch_em.model import UNETR +from torch_em.loss import DiceBasedDistanceLoss import micro_sam.training as sam_training from micro_sam.util import export_custom_sam_model @@ -11,7 +13,7 @@ def finetune_lm_generalist(args): - """Example code for finetuning SAM on multiple light microscopy datasets""" + """Code for finetuning SAM on multiple light microscopy datasets""" # override this (below) if you have some more complex set-up and need to specify the exact gpu device = "cuda" if torch.cuda.is_available() else "cpu" @@ -23,39 +25,65 @@ def finetune_lm_generalist(args): freeze_parts = None # override this to freeze one or more of these backbones # get the trainable segment anything model - model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, freeze_parts) + model = sam_training.get_trainable_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + freeze=freeze_parts + ) + model.to(device) + + # let's get the UNETR model for automatic instance segmentation pipeline + unetr = UNETR( + backbone="sam", + encoder=model.sam.image_encoder, + out_channels=3, + use_sam_stats=True, + final_activation="Sigmoid", + use_skip_connection=False, + resize_input=True + ) + unetr.to(device) - # all stuff needed for training - optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=10, verbose=True) + # let's get the parameters for SAM and the decoder from UNETR + joint_model_params = [params for params in model.parameters()] # sam parameters + for name, params in unetr.named_parameters(): # unetr's decoder parameters + if not name.startswith("encoder"): + joint_model_params.append(params) + + # all the stuff we need for training + optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True) train_loader, val_loader = get_generalist_lm_loaders(input_path=args.input_path, patch_shape=patch_shape) # this class creates all the training data for a batch (inputs, prompts and labels) - convert_inputs = sam_training.ConvertToSamInputs() + convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) + + checkpoint_name = f"{args.model_type}/lm_generalist_sam" - checkpoint_name = "generalist_lm_sam" - # the trainer which performs training and validation (implemented using "torch_em") - trainer = sam_training.SamTrainer( + # the trainer which performs the joint training and validation (implemented using "torch_em") + trainer = sam_training.JointSamTrainer( name=checkpoint_name, save_root=args.save_root, train_loader=train_loader, val_loader=val_loader, model=model, optimizer=optimizer, - # currently we compute loss batch-wise, else we pass channelwise True - loss=DiceLoss(channelwise=False), - metric=DiceLoss(), device=device, lr_scheduler=scheduler, - logger=sam_training.SamLogger, + logger=sam_training.JointSamLogger, log_image_interval=100, mixed_precision=True, convert_inputs=convert_inputs, n_objects_per_batch=n_objects_per_batch, n_sub_iteration=8, - 
compile_model=False + compile_model=False, + mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training + unetr=unetr, + instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), + instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True) ) - trainer.fit(iterations=args.iterations) + trainer.fit(args.iterations, save_every_kth_epoch=args.save_every_kth_epoch) if args.export_path is not None: checkpoint_path = os.path.join( "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" @@ -70,25 +98,29 @@ def finetune_lm_generalist(args): def main(): parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LM datasets.") parser.add_argument( - "--input_path", "-i", default="/scratch/usr/nimanwai/data/", + "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/", help="The filepath to all the respective LM datasets. If the data does not exist yet it will be downloaded" ) parser.add_argument( "--model_type", "-m", default="vit_b", - help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l." + help="The model type to use for fine-tuning. Either vit_b, vit_l or vit_h." ) parser.add_argument( "--save_root", "-s", help="Where to save the checkpoint and logs. By default they will be saved where this script is run from." ) parser.add_argument( - "--iterations", type=int, default=int(1e5), - help="For how many iterations should the model be trained? By default 100k." + "--iterations", type=int, default=int(25e4), + help="For how many iterations should the model be trained? By default 250k." ) parser.add_argument( "--export_path", "-e", help="Where to export the finetuned model to. The exported model can be used in the annotation tools." ) + parser.add_argument( + "--save_every_kth_epoch", type=int, default=None, + help="To save every kth epoch while fine-tuning. Expects an integer value." + ) args = parser.parse_args() finetune_lm_generalist(args) diff --git a/finetuning/livecell/amg/grid_search_and_inference.py b/finetuning/livecell/amg/grid_search_and_inference.py deleted file mode 100644 index 1a04d763..00000000 --- a/finetuning/livecell/amg/grid_search_and_inference.py +++ /dev/null @@ -1,40 +0,0 @@ -import argparse -from micro_sam.evaluation.livecell import run_livecell_amg -from util import DATA_ROOT, get_checkpoint, get_experiment_folder, check_model - - -def run_job(model_name, use_mws): - checkpoint, model_type = get_checkpoint(model_name) - experiment_folder = get_experiment_folder(model_name) - input_folder = DATA_ROOT - - run_livecell_amg( - checkpoint, model_type, input_folder, experiment_folder, - n_val_per_cell_type=25, use_mws=use_mws, - ) - - -# TODO -def check_amg(model_name, use_mws): - pass - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("-n", "--name", required=True) - parser.add_argument("--mws", action="store_true") - parser.add_argument("-c", "--check", action="store_true") - args = parser.parse_args() - - model_name = args.name - use_mws = args.mws - check_model(model_name) - - if args.check: - check_amg(model_name, use_mws) - else: - run_job(model_name, use_mws) - - -if __name__ == "__main__": - main() diff --git a/finetuning/livecell/amg/grid_search_and_inference.sbatch b/finetuning/livecell/amg/grid_search_and_inference.sbatch deleted file mode 100755 index a3692182..00000000 --- a/finetuning/livecell/amg/grid_search_and_inference.sbatch +++ /dev/null @@ -1,9 +0,0 @@ -#! 
/bin/bash -#SBATCH -c 8 -#SBATCH --mem 96G -#SBATCH -t 2880 -#SBATCH -p grete:shared -#SBATCH -G A100:1 - -source activate sam -python grid_search_and_inference.py $@ diff --git a/finetuning/livecell/amg/util.py b/finetuning/livecell/amg/util.py deleted file mode 100644 index eff50eac..00000000 --- a/finetuning/livecell/amg/util.py +++ /dev/null @@ -1,30 +0,0 @@ -import os - -DATA_ROOT = "/scratch/projects/nim00007/data/LiveCELL" -EXPERIMENT_ROOT = "/scratch/projects/nim00007/sam/experiments/livecell" -MODELS = { - "vit_b": "/scratch-grete/projects/nim00007/sam/vanilla/sam_vit_b_01ec64.pth", - "vit_h": "/scratch-grete/projects/nim00007/sam/vanilla/sam_vit_h_4b8939.pth", - "vit_b_specialist": "/scratch-grete/projects/nim00007/sam/LM/LiveCELL/vit_b/best.pt", - "vit_h_specialist": "/scratch-grete/projects/nim00007/sam/LM/LiveCELL/vit_h/best.pt", - "vit_b_generalist": "/scratch-grete/projects/nim00007/sam/LM/generalist/vit_b/best.pt", - "vit_h_generalist": "/scratch-grete/projects/nim00007/sam/LM/generalist/vit_h/best.pt", -} - - -def get_checkpoint(name): - assert name in MODELS, name - ckpt = MODELS[name] - assert os.path.exists(ckpt), ckpt - model_type = name[:5] - assert model_type in ("vit_b", "vit_h"), model_type - return ckpt, model_type - - -def get_experiment_folder(name): - return os.path.join(EXPERIMENT_ROOT, name) - - -def check_model(name): - if name not in MODELS: - raise ValueError(f"Invalid model {name}, expect one of {MODELS.keys()}") diff --git a/finetuning/livecell/evaluation/evaluate_amg.py b/finetuning/livecell/evaluation/evaluate_amg.py new file mode 100644 index 00000000..5fa49bbe --- /dev/null +++ b/finetuning/livecell/evaluation/evaluate_amg.py @@ -0,0 +1,44 @@ +import argparse +import os + +from micro_sam.evaluation.evaluation import run_evaluation +from micro_sam.evaluation.livecell import run_livecell_amg +from util import DATA_ROOT, get_pred_and_gt_paths + + +def run_amg(model_type, checkpoint, experiment_folder): + input_folder = DATA_ROOT + prediction_folder = run_livecell_amg( + checkpoint, + input_folder, + model_type, + experiment_folder, + n_val_per_cell_type=25, + ) + return prediction_folder + + +def eval_amg(prediction_folder, experiment_folder): + print("Evaluating", prediction_folder) + pred_paths, gt_paths = get_pred_and_gt_paths(prediction_folder) + save_path = os.path.join(experiment_folder, "results", "amg.csv") + res = run_evaluation(gt_paths, pred_paths, save_path=save_path) + print(res) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", "--model", type=str, required=True, + help="Provide the model type to initialize the predictor" + ) + parser.add_argument("-c", "--checkpoint", type=str, required=True) + parser.add_argument("-e", "--experiment_folder", type=str, required=True) + args = parser.parse_args() + + prediction_folder = run_amg(args.model, args.checkpoint, args.experiment_folder) + eval_amg(prediction_folder, args.experiment_folder) + + +if __name__ == "__main__": + main() diff --git a/finetuning/livecell/evaluation/evaluate_amg.sbatch b/finetuning/livecell/evaluation/evaluate_amg.sbatch new file mode 100755 index 00000000..c6ce08e2 --- /dev/null +++ b/finetuning/livecell/evaluation/evaluate_amg.sbatch @@ -0,0 +1,12 @@ +#! 
/bin/bash +#SBATCH -c 8 +#SBATCH --mem 96G +#SBATCH -t 6:00:00 +#SBATCH -p grete:shared +#SBATCH -G A100:1 +#SBATCH -A nim00007 + +source activate sam +python evaluate_amg.py -c /scratch/usr/nimanwai/micro-sam/checkpoints/vit_b/livecell_sam/best.pt \ + -m vit_b \ + -e /scratch/projects/nim00007/sam/experiments/new_models/specialists/lm/livecell/vit_b/ diff --git a/finetuning/livecell/evaluation/evaluate_instance_segmentation.py b/finetuning/livecell/evaluation/evaluate_instance_segmentation.py new file mode 100644 index 00000000..c0c94e2a --- /dev/null +++ b/finetuning/livecell/evaluation/evaluate_instance_segmentation.py @@ -0,0 +1,46 @@ +import argparse +import os + +from micro_sam.evaluation.evaluation import run_evaluation +from micro_sam.evaluation.livecell import run_livecell_instance_segmentation_with_decoder +from util import DATA_ROOT, get_pred_and_gt_paths + + +def run_instance_segmentation_with_decoder(model_type, checkpoint, experiment_folder): + input_folder = DATA_ROOT + prediction_folder = run_livecell_instance_segmentation_with_decoder( + checkpoint, + input_folder, + model_type, + experiment_folder, + n_val_per_cell_type=25, + ) + return prediction_folder + + +def eval_instance_segmentation_with_decoder(prediction_folder, experiment_folder): + print("Evaluating", prediction_folder) + pred_paths, gt_paths = get_pred_and_gt_paths(prediction_folder) + save_path = os.path.join(experiment_folder, "results", "instance_segmentation_with_decoder.csv") + res = run_evaluation(gt_paths, pred_paths, save_path=save_path) + print(res) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", "--model", type=str, required=True, + help="Provide the model type to initialize the predictor" + ) + parser.add_argument("-c", "--checkpoint", type=str, required=True,) + parser.add_argument("-e", "--experiment_folder", type=str, required=True) + args = parser.parse_args() + + prediction_folder = run_instance_segmentation_with_decoder( + args.model, args.checkpoint, args.experiment_folder + ) + eval_instance_segmentation_with_decoder(prediction_folder, args.experiment_folder) + + +if __name__ == "__main__": + main() diff --git a/finetuning/livecell/evaluation/evaluate_instance_segmentation.sbatch b/finetuning/livecell/evaluation/evaluate_instance_segmentation.sbatch new file mode 100755 index 00000000..be80d046 --- /dev/null +++ b/finetuning/livecell/evaluation/evaluate_instance_segmentation.sbatch @@ -0,0 +1,12 @@ +#! /bin/bash +#SBATCH -c 8 +#SBATCH --mem 64G +#SBATCH -t 6:00:00 +#SBATCH -p grete:shared +#SBATCH -G A100:1 +#SBATCH -A nim00007 + +source activate sam +python evaluate_instance_segmentation.py -c /scratch/usr/nimanwai/micro-sam/checkpoints/vit_h/livecell_sam/best.pt \ + -m vit_h \ + -e /scratch/projects/nim00007/sam/experiments/new_models/specialists/lm/livecell/vit_h/ diff --git a/finetuning/livecell/evaluation/iterative.sbatch b/finetuning/livecell/evaluation/iterative.sbatch deleted file mode 100644 index c690df38..00000000 --- a/finetuning/livecell/evaluation/iterative.sbatch +++ /dev/null @@ -1,11 +0,0 @@ -#! 
/bin/bash -#SBATCH -c 16 -#SBATCH --mem 48G -#SBATCH -t 6:00:00 -#SBATCH -p grete:shared -#SBATCH -G A100:1 -#SBATCH -A nim00007 -#SBATCH --job-name=sam-iterative-prompting - -source activate sam -python iterative_prompting.py $@ \ No newline at end of file diff --git a/finetuning/livecell/evaluation/iterative_prompting.py b/finetuning/livecell/evaluation/iterative_prompting.py index d156afd1..ebd4ab49 100644 --- a/finetuning/livecell/evaluation/iterative_prompting.py +++ b/finetuning/livecell/evaluation/iterative_prompting.py @@ -1,32 +1,20 @@ +import argparse import os -import pandas as pd from glob import glob +import pandas as pd + from micro_sam.evaluation import inference from micro_sam.evaluation.evaluation import run_evaluation +from util import get_paths, get_model, get_pred_and_gt_paths -from util import get_paths, get_checkpoint, MODELS - -LIVECELL_GT_ROOT = "/scratch/projects/nim00007/data/LiveCELL/annotations_corrected/livecell_test_images" -PREDICTION_ROOT = "/scratch/projects/nim00007/sam/iterative_evaluation" - - -def get_prediction_root(start_with_box_prompt, model_description, root_dir=PREDICTION_ROOT): - if start_with_box_prompt: - prediction_root = os.path.join(root_dir, model_description, "start_with_box") - else: - prediction_root = os.path.join(root_dir, model_description, "start_with_point") - - return prediction_root - - -def run_interactive_prompting(predictor, start_with_box_prompt, model_description, prediction_root): - # we organize all the folders with data from this experiment below - embedding_folder = os.path.join(PREDICTION_ROOT, model_description, "embeddings") - os.makedirs(embedding_folder, exist_ok=True) +def run_interactive_prompting(exp_folder, predictor, start_with_box_prompt): + prediction_root = os.path.join( + exp_folder, "start_with_box" if start_with_box_prompt else "start_with_point" + ) + embedding_folder = os.path.join(exp_folder, "embeddings") image_paths, gt_paths = get_paths() - inference.run_inference_with_iterative_prompting( predictor=predictor, image_paths=image_paths, @@ -35,69 +23,51 @@ def run_interactive_prompting(predictor, start_with_box_prompt, model_descriptio prediction_dir=prediction_root, start_with_box_prompt=start_with_box_prompt ) + return prediction_root -def get_pg_paths(pred_folder): - pred_paths = sorted(glob(os.path.join(pred_folder, "*.tif"))) - names = [os.path.split(path)[1] for path in pred_paths] - gt_paths = [ - os.path.join(LIVECELL_GT_ROOT, name.split("_")[0], name) for name in names - ] - assert all(os.path.exists(pp) for pp in gt_paths) - return pred_paths, gt_paths - - -def evaluate_interactive_prompting(prediction_root, start_with_box_prompt, model_description): +def evaluate_interactive_prompting(prediction_root, start_with_box_prompt, exp_folder): assert os.path.exists(prediction_root), prediction_root - csv_save_dir = f"./iterative_prompting_results/{model_description}" - os.makedirs(csv_save_dir, exist_ok=True) - csv_path = os.path.join(csv_save_dir, "start_with_box.csv" if start_with_box_prompt else "start_with_point.csv") - if os.path.exists(csv_path): - print("The evaluated results for the expected setting already exist here:", csv_path) - return - prediction_folders = sorted(glob(os.path.join(prediction_root, "iteration*"))) list_of_results = [] for pred_folder in prediction_folders: print("Evaluating", pred_folder) - pred_paths, gt_paths = get_pg_paths(pred_folder) + pred_paths, gt_paths = get_pred_and_gt_paths(pred_folder) res = run_evaluation(gt_paths, pred_paths, save_path=None) 
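        # run_evaluation scores the predictions of this prompt iteration against the LiveCELL ground truth
        # and returns the result as a pandas dataframe, so the per-iteration results can be concatenated below.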
list_of_results.append(res) print(res) df = pd.concat(list_of_results, ignore_index=True) + + # Save the results in the experiment folder. + result_folder = os.path.join(exp_folder, "results") + os.makedirs(result_folder, exist_ok=True) + csv_path = os.path.join( + result_folder, + "iterative_prompts_start_box.csv" if start_with_box_prompt else "iterative_prompts_start_point.csv" + ) df.to_csv(csv_path) -def main(args): - start_with_box_prompt = args.box # overwrite to start first iters' prompt with box instead of single point - model_description = args.model # overwrite to specify the choice of vanilla / finetuned models +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", "--model", type=str, required=True, help="Provide the model type to initialize the predictor" + ) + parser.add_argument("-c", "--checkpoint", type=str, required=True) + parser.add_argument("-e", "--experiment_folder", type=str, required=True) + parser.add_argument("--box", action="store_true", help="If passed, starts with first prompt as box") + args = parser.parse_args() - # add the root prediction path where you would like to save the iterative prompting results - prediction_root = get_prediction_root(start_with_box_prompt, model_description) + start_with_box_prompt = args.box # overwrite to start first iters' prompt with box instead of single point - # get the model checkpoints and desired model name to initialize the predictor - if args.checkpoint is None and model_description in MODELS.keys(): - checkpoint, model_type = get_checkpoint(model_description) - else: - checkpoint = args.checkpoint - model_type = model_description[:5] # get the predictor to perform inference - predictor = inference.get_predictor(checkpoint, model_type) + predictor = get_model(model_type=args.model, ckpt=args.checkpoint) - run_interactive_prompting(predictor, start_with_box_prompt, model_description, prediction_root) - evaluate_interactive_prompting(prediction_root, start_with_box_prompt, model_description) + prediction_root = run_interactive_prompting(args.experiment_folder, predictor, start_with_box_prompt) + evaluate_interactive_prompting(prediction_root, start_with_box_prompt, args.experiment_folder) if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--box", action="store_true", help="If passed, starts with first prompt as box") - parser.add_argument( - "-m", "--model", type=str, # options: "vit_h", "vit_h_generalist", "vit_h_specialist" - help="Provide the model type to initialize the predictor" - ) - parser.add_argument("-c", "--checkpoint", type=str, default=None) - args = parser.parse_args() - main(args) + main() diff --git a/finetuning/livecell/evaluation/iterative_prompting.sbatch b/finetuning/livecell/evaluation/iterative_prompting.sbatch new file mode 100755 index 00000000..72dc7e60 --- /dev/null +++ b/finetuning/livecell/evaluation/iterative_prompting.sbatch @@ -0,0 +1,12 @@ +#! 
/bin/bash +#SBATCH -c 8 +#SBATCH --mem 64G +#SBATCH -t 6:00:00 +#SBATCH -p grete:shared +#SBATCH -G A100:1 +#SBATCH -A nim00007 + +source activate sam +python iterative_prompting.py -c /scratch/usr/nimanwai/micro-sam/checkpoints/vit_h/livecell_sam/best.pt \ + -m vit_h \ + -e /scratch/projects/nim00007/sam/experiments/new_models/specialists/lm/livecell/vit_h/ diff --git a/finetuning/livecell/evaluation/evaluation.py b/finetuning/livecell/evaluation/old/evaluation.py similarity index 100% rename from finetuning/livecell/evaluation/evaluation.py rename to finetuning/livecell/evaluation/old/evaluation.py diff --git a/finetuning/livecell/evaluation/evaluation.sbatch b/finetuning/livecell/evaluation/old/evaluation.sbatch similarity index 100% rename from finetuning/livecell/evaluation/evaluation.sbatch rename to finetuning/livecell/evaluation/old/evaluation.sbatch diff --git a/finetuning/livecell/evaluation/inference.py b/finetuning/livecell/evaluation/old/inference.py similarity index 100% rename from finetuning/livecell/evaluation/inference.py rename to finetuning/livecell/evaluation/old/inference.py diff --git a/finetuning/livecell/evaluation/inference.sbatch b/finetuning/livecell/evaluation/old/inference.sbatch similarity index 100% rename from finetuning/livecell/evaluation/inference.sbatch rename to finetuning/livecell/evaluation/old/inference.sbatch diff --git a/finetuning/livecell/evaluation/precompute_prompts.py b/finetuning/livecell/evaluation/old/precompute_prompts.py similarity index 100% rename from finetuning/livecell/evaluation/precompute_prompts.py rename to finetuning/livecell/evaluation/old/precompute_prompts.py diff --git a/finetuning/livecell/evaluation/precompute_prompts.sbatch b/finetuning/livecell/evaluation/old/precompute_prompts.sbatch similarity index 100% rename from finetuning/livecell/evaluation/precompute_prompts.sbatch rename to finetuning/livecell/evaluation/old/precompute_prompts.sbatch diff --git a/finetuning/livecell/evaluation/precompute_embeddings.py b/finetuning/livecell/evaluation/precompute_embeddings.py index 3d53ed83..431b80a2 100644 --- a/finetuning/livecell/evaluation/precompute_embeddings.py +++ b/finetuning/livecell/evaluation/precompute_embeddings.py @@ -2,21 +2,26 @@ import os from micro_sam.evaluation import precompute_all_embeddings -from util import get_paths, get_model, get_experiment_folder +from util import get_paths, get_model def main(): parser = argparse.ArgumentParser() - parser.add_argument("-n", "--name", required=True) + parser.add_argument("-m", "--model", type=str, required=True) + parser.add_argument("-c", "--checkpoint", type=str, required=True) + parser.add_argument("-e", "--experiment_folder", type=str, required=True) args = parser.parse_args() - name = args.name - - image_paths, _ = get_paths() - predictor = get_model(name) - exp_folder = get_experiment_folder(name) - embedding_dir = os.path.join(exp_folder, "embeddings") + predictor = get_model(model_type=args.model, ckpt=args.checkpoint) + embedding_dir = os.path.join(args.experiment_folder, "embeddings") os.makedirs(embedding_dir, exist_ok=True) + + # getting the embeddings for the test set + image_paths, _ = get_paths("test") + precompute_all_embeddings(predictor, image_paths, embedding_dir) + + # getting the embeddings for the val set + image_paths, _ = get_paths("val") precompute_all_embeddings(predictor, image_paths, embedding_dir) diff --git a/finetuning/livecell/evaluation/precompute_embeddings.sbatch b/finetuning/livecell/evaluation/precompute_embeddings.sbatch index 
70faec19..44456a1a 100755 --- a/finetuning/livecell/evaluation/precompute_embeddings.sbatch +++ b/finetuning/livecell/evaluation/precompute_embeddings.sbatch @@ -4,7 +4,11 @@ #SBATCH --mem 64G #SBATCH -t 120 #SBATCH -p grete:shared +#SBATCH -A nim00007 #SBATCH -G A100:1 -source activate sam -python precompute_embeddings.py -n $1 +source ~/.bashrc +micromamba activate main +python precompute_embeddings.py -c /scratch/usr/nimanwai/micro-sam/checkpoints/vit_b/lm_generalist_sam/best.pt \ + -m vit_b \ + -e /scratch/projects/nim00007/sam/experiments/new_models/generalists/lm/livecell/vit_b/ diff --git a/finetuning/livecell/evaluation/submit_evaluation.py b/finetuning/livecell/evaluation/submit_evaluation.py new file mode 100644 index 00000000..e7ddacf5 --- /dev/null +++ b/finetuning/livecell/evaluation/submit_evaluation.py @@ -0,0 +1,102 @@ +import os +import shutil +import subprocess +from glob import glob +from datetime import datetime + + +def write_batch_script(env_name, out_path, inference_setup, checkpoint, model_type, experiment_folder, delay=True): + """Writing scripts with different fold-trainings for micro-sam evaluation + """ + batch_script = f"""#!/bin/bash +#SBATCH -c 8 +#SBATCH --mem 128G +#SBATCH -t 6:00:00 +#SBATCH -p grete:shared +#SBATCH -G A100:1 +#SBATCH -A nim00007 +#SBATCH --job-name={inference_setup} + +source ~/.bashrc +mamba activate {env_name} \n""" + + if delay: + batch_script += "sleep 10m \n" + + # python script + python_script = f"python {inference_setup}.py " + + _op = out_path[:-3] + f"_{inference_setup}.sh" + + # add the finetuned checkpoint + python_script += f"-c {checkpoint} " + + # name of the model configuration + python_script += f"-m {model_type} " + + # experiment folder + python_script += f"-e {experiment_folder} " + + # let's add the python script to the bash script + batch_script += python_script + + with open(_op, "w") as f: + f.write(batch_script) + + # we run the first prompt for iterative once starting with point, and then starting with box (below) + if inference_setup == "iterative_prompting": + batch_script += "--box " + + new_path = out_path[:-3] + f"_{inference_setup}_box.sh" + with open(new_path, "w") as f: + f.write(batch_script) + + +def get_batch_script_names(tmp_folder): + tmp_folder = os.path.expanduser(tmp_folder) + os.makedirs(tmp_folder, exist_ok=True) + + script_name = "livecell-inference" + + dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + tmp_name = script_name + dt + batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh") + + return batch_script + + +def submit_slurm(): + """Submit python script that needs gpus with given inputs on a slurm node. 
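+    One batch script is written for each inference setup (precompute_embeddings, evaluate_amg,
+    evaluate_instance_segmentation and iterative_prompting, the latter once starting from a point
+    and once starting from a box prompt), and all of them are submitted via sbatch.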
+ """ + tmp_folder = "./gpu_jobs" + + # parameters to run the inference scripts + env_name = "sam" + model_type = "vit_b" + checkpoint = f"/scratch/usr/nimanwai/micro-sam/checkpoints/{model_type}/lm_generalist_sam/best.pt" + experiment_folder = f"/scratch/projects/nim00007/sam/experiments/new_models/generalists/lm/livecell/{model_type}/" + + all_setups = ["precompute_embeddings", "evaluate_amg", "evaluate_instance_segmentation", "iterative_prompting"] + for current_setup in all_setups: + write_batch_script( + env_name=env_name, + out_path=get_batch_script_names(tmp_folder), + inference_setup=current_setup, + checkpoint=checkpoint, + model_type=model_type, + experiment_folder=experiment_folder, + delay=False if current_setup == "precompute_embeddings" else True + ) + + for my_script in glob(tmp_folder + "/*"): + cmd = ["sbatch", my_script] + subprocess.run(cmd) + + +if __name__ == "__main__": + try: + shutil.rmtree("./gpu_jobs") + except FileNotFoundError: + pass + + submit_slurm() diff --git a/finetuning/livecell/evaluation/train_grid_search_eval.py b/finetuning/livecell/evaluation/train_grid_search_eval.py new file mode 100644 index 00000000..ad57025b --- /dev/null +++ b/finetuning/livecell/evaluation/train_grid_search_eval.py @@ -0,0 +1,94 @@ +import os +from subprocess import run + +import pandas as pd + +# TODO we need to make sure that this has the corrected training data for the proper training +DATA_ROOT = "/scratch/projects/nim00007/data/LiveCELL" +SAVE_ROOT = "/scratch/projects/nim00007/sam/livecell_grid_search" + +LRS = [1e-4, 5e-5, 1e-5, 5e-6] + + +def _get_name_and_checkpoint(lr, use_adamw): + name = f"vit_b-lr{lr}" + if use_adamw: + name += "-adamw" + checkpoint = os.path.join(SAVE_ROOT, "checkpoints", name, "best.pt") + return name, checkpoint + + +def precompute_embeddings(): + for lr in LRS: + for use_adamw in [True, False]: + name, ckpt = _get_name_and_checkpoint(lr, use_adamw) + if not os.path.exists(ckpt): + print("Skipping:", ckpt) + continue + cmd = ["sbatch", "precompute_embeddings.sbatch", "--name", f"livecell_grid_search/{name}", + "-m", "vit_b", "-c", ckpt] + run(cmd) + + +def run_evaluations(): + for lr in LRS: + for use_adamw in [True, False]: + name, ckpt = _get_name_and_checkpoint(lr, use_adamw) + if not os.path.exists(ckpt): + print("Skipping:", ckpt) + continue + # iterative prompting (start with prompt) + cmd = ["sbatch", "iterative_prompting.sbatch", "--name", f"livecell_grid_search/{name}", + "-m", "vit_b", "-c", ckpt] + run(cmd) + # iterative prompting (start with box) + cmd = ["sbatch", "iterative_prompting.sbatch", "--name", f"livecell_grid_search/{name}", + "-m", "vit_b", "-c", ckpt, "--box"] + run(cmd) + # instance segmentation + cmd = ["sbatch", "evaluate_instance_segmentation.sbatch", "--name", f"livecell_grid_search/{name}", + "-m", "vit_b", "-c", ckpt] + run(cmd) + + +def accumulate_results(): + result_root = "/scratch/projects/nim00007/sam/experiments/livecell/livecell_grid_search" + + # TODO add the instance segmentation result + exp_names = [ + "iterative_prompts_start_box.csv", "iterative_prompts_start_point.csv", "instance_segmentation_with_decoder.csv" + ] + + for exp_name in exp_names: + results = [] + for lr in LRS: + for use_adamw in [True, False]: + name = f"vit_b-lr{lr}" + if use_adamw: + name += "-adamw" + + result_path = os.path.join(result_root, name, "results", exp_name) + if not os.path.exists(result_path): + continue + + this_result = pd.read_csv(result_path) + this_result = this_result.rename(columns={"Unnamed: 0": "iteration"}) 
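+                # The evaluation scripts save their CSVs with the default pandas index, so it shows up
+                # as an unnamed first column when read back; for the iterative prompting results this
+                # index corresponds to the prompt iteration, hence the rename to "iteration".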
+ + this_result["lr"] = [lr] * len(this_result) + this_result["optimizer"] = ["adamw" if use_adamw else "adam"] * len(this_result) + + results.append(this_result) + + results = pd.concat(results) + out_path = os.path.join(result_root, exp_name) + results.to_csv(out_path, index=False) + + +def main(): + # precompute_embeddings() + # run_evaluations() + accumulate_results() + + +if __name__ == "__main__": + main() diff --git a/finetuning/livecell/evaluation/util.py b/finetuning/livecell/evaluation/util.py index 76d60d0e..46977ee9 100644 --- a/finetuning/livecell/evaluation/util.py +++ b/finetuning/livecell/evaluation/util.py @@ -1,11 +1,14 @@ import os +from glob import glob from micro_sam.evaluation import get_predictor from micro_sam.evaluation.livecell import _get_livecell_paths +# FIXME make sure this uses the corrected ground-truth!!! DATA_ROOT = "/scratch/projects/nim00007/data/LiveCELL" EXPERIMENT_ROOT = "/scratch/projects/nim00007/sam/experiments/livecell" PROMPT_FOLDER = "/scratch/projects/nim00007/sam/experiments/prompts/livecell" +# TODO update the finetuned models MODELS = { "vit_b": "/scratch/projects/nim00007/sam/vanilla/sam_vit_b_01ec64.pth", "vit_h": "/scratch/projects/nim00007/sam/vanilla/sam_vit_h_4b8939.pth", @@ -16,8 +19,8 @@ } -def get_paths(): - return _get_livecell_paths(DATA_ROOT) +def get_paths(split="test"): + return _get_livecell_paths(DATA_ROOT, split=split) def get_checkpoint(name): @@ -29,8 +32,10 @@ def get_checkpoint(name): return ckpt, model_type -def get_model(name): - ckpt, model_type = get_checkpoint(name) +def get_model(name=None, model_type=None, ckpt=None): + if ckpt is None: + ckpt, model_type = get_checkpoint(name) + assert (ckpt is not None) and (model_type is not None) predictor = get_predictor(ckpt, model_type) return predictor @@ -44,6 +49,17 @@ def check_model(name): raise ValueError(f"Invalid model {name}, expect one of {MODELS.keys()}") +def get_pred_and_gt_paths(prediction_folder): + pred_paths = sorted(glob(os.path.join(prediction_folder, "*.tif"))) + names = [os.path.split(path)[1] for path in pred_paths] + gt_root = os.path.join(DATA_ROOT, "annotations_corrected/livecell_test_images") + gt_paths = [ + os.path.join(gt_root, name.split("_")[0], name) for name in names + ] + assert all(os.path.exists(pp) for pp in gt_paths) + return pred_paths, gt_paths + + def download_livecell(): from torch_em.data.datasets import get_livecell_loader get_livecell_loader(DATA_ROOT, "train", (512, 512), 1, download=True) diff --git a/finetuning/livecell/joint_training/grid_search_train.py b/finetuning/livecell/joint_training/grid_search_train.py new file mode 100644 index 00000000..bee170d5 --- /dev/null +++ b/finetuning/livecell/joint_training/grid_search_train.py @@ -0,0 +1,25 @@ +from subprocess import run + +# TODO we need to make sure that this has the corrected training data for the proper training +DATA_ROOT = "/scratch/projects/nim00007/data/LiveCELL" +SAVE_ROOT = "/scratch-grete/projects/nim00007/sam/livecell_grid_search" + + +def run_grid_search(dry_run): + lrs = [1e-4, 5e-5, 1e-5, 5e-6] + for lr in lrs: + for use_adamw in [True, False]: + name = f"vit_b-lr{lr}" + if use_adamw: + name += "-adamw" + cmd = ["sbatch", "submit_training.sh", "-i", DATA_ROOT, "-s", SAVE_ROOT, + "--iterations", "25000", "--name", name, "--lr", str(lr)] + if use_adamw: + cmd.append("--use_adamw") + if dry_run: + print(cmd) + else: + run(cmd) + + +run_grid_search(False) diff --git a/finetuning/livecell/joint_training/joint_finetuning.py 
b/finetuning/livecell/joint_training/joint_finetuning.py new file mode 100644 index 00000000..a405e6dd --- /dev/null +++ b/finetuning/livecell/joint_training/joint_finetuning.py @@ -0,0 +1,173 @@ +import os +import argparse + +import torch + +from torch_em.model import UNETR +from torch_em.loss import DiceBasedDistanceLoss +from torch_em.data.datasets import get_livecell_loader +from torch_em.transform.label import PerObjectDistanceTransform + +import micro_sam.training as sam_training +from micro_sam.util import export_custom_sam_model + + +def get_dataloaders(patch_shape, data_path, cell_type=None): + """This returns the livecell data loaders implemented in torch_em: + https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/livecell.py + It will automatically download the livecell data. + + Note: to replace this with another data loader you need to return a torch data loader + that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. + The labels have to be in a label mask instance segmentation format. + I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. + Important: the ID 0 is reseved for background, and the IDs must be consecutive + """ + label_transform = PerObjectDistanceTransform( + distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25 + ) + raw_transform = sam_training.identity # the current workflow avoids rescaling the inputs to [-1, 1] + train_loader = get_livecell_loader( + path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16, + cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform, + raw_transform=raw_transform, label_dtype=torch.float32, + ) + val_loader = get_livecell_loader( + path=data_path, patch_shape=patch_shape, split="val", batch_size=4, num_workers=16, + cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform, + raw_transform=raw_transform, label_dtype=torch.float32, + ) + + return train_loader, val_loader + + +def finetune_livecell(args): + """Example code for finetuning SAM on LiveCELL""" + # override this (below) if you have some more complex set-up and need to specify the exact gpu + device = "cuda" if torch.cuda.is_available() else "cpu" + + # training settings: + model_type = args.model_type + checkpoint_path = None # override this to start training from a custom checkpoint + patch_shape = (520, 704) # the patch shape for training + n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled + freeze_parts = args.freeze # override this to freeze different parts of the model + + # get the trainable segment anything model + model = sam_training.get_trainable_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + freeze=freeze_parts + ) + model.to(device) + + # let's get the UNETR model for automatic instance segmentation pipeline + unetr = UNETR( + backbone="sam", + encoder=model.sam.image_encoder, + out_channels=3, + use_sam_stats=True, + final_activation="Sigmoid", + use_skip_connection=False, + resize_input=True, + ) + unetr.to(device) + + # let's get the parameters for SAM and the decoder from UNETR + joint_model_params = [params for params in model.parameters()] # sam parameters + for name, params in unetr.named_parameters(): # unetr's decoder parameters + if not name.startswith("encoder"): + joint_model_params.append(params) + + # Optimizer and learning rate. 
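+    # The learning rate and the choice between Adam and AdamW are exposed as command line options
+    # so that grid_search_train.py (above) can sweep over them; in either case the optimizer is set up
+    # over the joint parameter list, i.e. SAM together with the UNETR decoder.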
+ if args.use_adamw: + print("Use AdamW with lr", args.lr) + optimizer = torch.optim.AdamW(joint_model_params, lr=args.lr) + else: + print("Use Adam with lr", args.lr) + optimizer = torch.optim.Adam(joint_model_params, lr=args.lr) + + # Two different schedulers depending on how long we train. + if args.iterations > 25000: + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.9, patience=10, verbose=True + ) + else: + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=2, verbose=True + ) + train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) + + # this class creates all the training data for a batch (inputs, prompts and labels) + convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) + + trainer = sam_training.JointSamTrainer( + name=args.name, + save_root=args.save_root, + train_loader=train_loader, + val_loader=val_loader, + model=model, + optimizer=optimizer, + device=device, + lr_scheduler=scheduler, + logger=sam_training.JointSamLogger, + log_image_interval=100, + mixed_precision=True, + convert_inputs=convert_inputs, + n_objects_per_batch=n_objects_per_batch, + n_sub_iteration=8, + compile_model=False, + mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training + unetr=unetr, + instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), + instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True) + ) + trainer.fit(args.iterations) + if args.export_path is not None: + checkpoint_path = os.path.join( + "" if args.save_root is None else args.save_root, "checkpoints", args.name, "best.pt" + ) + export_custom_sam_model( + checkpoint_path=checkpoint_path, + model_type=model_type, + save_path=args.export_path, + ) + + +def main(): + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.") + parser.add_argument( + "--input_path", "-i", default="/scratch/usr/nimanwai/data/livecell/", + help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded." + ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l." + ) + parser.add_argument( + "--save_root", "-s", + help="Where to save the checkpoint and logs. By default they will be saved where this script is run." + ) + parser.add_argument( + "--iterations", type=int, default=int(1e5), + help="For how many iterations should the model be trained? By default 100k." + ) + parser.add_argument( + "--export_path", "-e", + help="Where to export the finetuned model to. The exported model can be used in the annotation tools." + ) + parser.add_argument( + "--freeze", type=str, nargs="+", default=None, + help="Which parts of the model to freeze for finetuning." 
+ ) + # Parameter for training grid search + parser.add_argument("--name", default="livecell_sam") + parser.add_argument("--lr", type=float, default=1e-5) + parser.add_argument("--use_adamw", action="store_true") + args = parser.parse_args() + finetune_livecell(args) + + +if __name__ == "__main__": + main() diff --git a/finetuning/livecell/joint_training/unetr_inference.py b/finetuning/livecell/joint_training/unetr_inference.py new file mode 100644 index 00000000..3dfc2e59 --- /dev/null +++ b/finetuning/livecell/joint_training/unetr_inference.py @@ -0,0 +1,133 @@ +import os +import h5py +import argparse +import numpy as np +import pandas as pd +from glob import glob +from tqdm import tqdm +from pathlib import Path +import imageio.v3 as imageio +from collections import OrderedDict + +import torch + +from torch_em.model import UNETR +from torch_em.util import segmentation +from torch_em.util.prediction import predict_with_padding + +from elf.evaluation import mean_segmentation_accuracy + +from micro_sam.util import get_sam_model + + +def get_unetr_model(model_type, checkpoint, device): + # let's get the sam finetuned model + predictor = get_sam_model( + model_type=model_type + ) + + # load the model with the respective unetr model state + model = UNETR( + encoder=predictor.model.image_encoder, + out_channels=3, + use_sam_stats=True, + final_activation="Sigmoid", + use_skip_connection=False + ) + + sam_state = torch.load(checkpoint, map_location="cpu")["model_state"] + # let's get the vit parameters from sam + encoder_state = [] + prune_prefix = "sam.image_" + for k, v in sam_state.items(): + if k.startswith(prune_prefix): + encoder_state.append((k[len(prune_prefix):], v)) + encoder_state = OrderedDict(encoder_state) + + decoder_state = torch.load(checkpoint, map_location="cpu")["decoder_state"] + + unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items())) + model.load_state_dict(unetr_state) + model.to(device) + model.eval() + + return model + + +def predict_for_unetr(inputs, save_dir, model, device): + save_dir = os.path.join(save_dir, "results") + os.makedirs(save_dir, exist_ok=True) + + with torch.no_grad(): + for img_path in tqdm(glob(os.path.join(inputs, "images", "livecell_test_images", "*")), + desc="Run unetr inference"): + fname = Path(img_path).stem + save_path = os.path.join(save_dir, f"{fname}.h5") + if os.path.exists(save_path): + continue + + input_ = imageio.imread(img_path) + + outputs = predict_with_padding(model, input_, device=device, min_divisible=(16, 16)) + fg, cdist, bdist = outputs.squeeze() + dm_seg = segmentation.watershed_from_center_and_boundary_distances( + cdist, bdist, fg, min_size=50, + center_distance_threshold=0.5, + boundary_distance_threshold=0.6, + distance_smoothing=1.0 + ) + + with h5py.File(save_path, "a") as f: + ds = f.require_dataset("segmentation", shape=dm_seg.shape, compression="gzip", dtype=dm_seg.dtype) + ds[:] = dm_seg + + +def evaluation_for_unetr(inputs, save_dir, csv_path): + if os.path.exists(csv_path): + return + + msa_list, sa50_list = [], [] + for gt_path in tqdm(glob(os.path.join(inputs, "annotations", "livecell_test_images", "*", "*")), + desc="Run unetr evaluation"): + gt = imageio.imread(gt_path) + fname = Path(gt_path).stem + + output_file = os.path.join(save_dir, "results", f"{fname}.h5") + with h5py.File(output_file, "r") as f: + instances = f["segmentation"][:] + + msa, sa_acc = mean_segmentation_accuracy(instances, gt, return_accuracies=True) + msa_list.append(msa) + sa50_list.append(sa_acc[0]) + + 
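+    # The scores are averaged over all LiveCELL test images below; sa_acc[0] corresponds to the
+    # segmentation accuracy at an IoU threshold of 0.5, which is reported as SA50.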
res_dict = { + "LiveCELL": "Metrics", + "mSA": np.mean(msa_list), + "SA50": np.mean(sa50_list) + } + df = pd.DataFrame.from_dict([res_dict]) + df.to_csv(csv_path) + + +def main(args): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # let's get the unetr model (initialized with the joint training setup) + model = get_unetr_model(model_type=args.model_type, checkpoint=args.checkpoint, device=device) + + # let's get the predictions + predict_for_unetr(inputs=args.inputs, save_dir=args.save_dir, model=model, device=device) + + # let's evaluate the predictions + evaluation_for_unetr(inputs=args.inputs, save_dir=args.save_dir, csv_path=args.csv_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--inputs", default="/scratch/usr/nimanwai/data/livecell/") + parser.add_argument("-c", "--checkpoint", type=str, required=True) + parser.add_argument("-m", "--model_type", type=str, default="vit_b") + parser.add_argument("--save_dir", type=str, required=True) + parser.add_argument("--csv_path", type=str, default="livecell_joint_training.csv") + args = parser.parse_args() + main(args) diff --git a/finetuning/livecell_finetuning.py b/finetuning/livecell_finetuning.py index 686a6fae..ab54c39a 100644 --- a/finetuning/livecell_finetuning.py +++ b/finetuning/livecell_finetuning.py @@ -1,11 +1,14 @@ -import argparse import os +import argparse -import micro_sam.training as sam_training import torch -import torch_em +from torch_em.model import UNETR +from torch_em.loss import DiceBasedDistanceLoss from torch_em.data.datasets import get_livecell_loader +from torch_em.transform.label import PerObjectDistanceTransform + +import micro_sam.training as sam_training from micro_sam.util import export_custom_sam_model @@ -20,13 +23,21 @@ def get_dataloaders(patch_shape, data_path, cell_type=None): I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. 
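    For example, an image that contains three cells should have the label values 0 (background)
    and 1, 2, 3 (one ID per cell).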
Important: the ID 0 is reseved for background, and the IDs must be consecutive """ - label_transform = torch_em.transform.label.label_consecutive # to ensure consecutive IDs - train_loader = get_livecell_loader(path=data_path, patch_shape=patch_shape, split="train", batch_size=2, - num_workers=16, cell_types=cell_type, download=True, - label_transform=label_transform, shuffle=True) - val_loader = get_livecell_loader(path=data_path, patch_shape=patch_shape, split="val", batch_size=1, - num_workers=16, cell_types=cell_type, download=True, - label_transform=label_transform, shuffle=True) + label_transform = PerObjectDistanceTransform( + distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25 + ) + raw_transform = sam_training.identity # the current workflow avoids rescaling the inputs to [-1, 1] + train_loader = get_livecell_loader( + path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16, + cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform, + raw_transform=raw_transform, label_dtype=torch.float32 + ) + val_loader = get_livecell_loader( + path=data_path, patch_shape=patch_shape, split="val", batch_size=1, num_workers=16, + cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform, + raw_transform=raw_transform, label_dtype=torch.float32 + ) + return train_loader, val_loader @@ -40,42 +51,70 @@ def finetune_livecell(args): checkpoint_path = None # override this to start training from a custom checkpoint patch_shape = (520, 704) # the patch shape for training n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled + freeze_parts = args.freeze # override this to freeze different parts of the model # get the trainable segment anything model - model = sam_training.get_trainable_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path) + model = sam_training.get_trainable_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + freeze=freeze_parts + ) + model.to(device) + + # let's get the UNETR model for automatic instance segmentation pipeline + unetr = UNETR( + backbone="sam", + encoder=model.sam.image_encoder, + out_channels=3, + use_sam_stats=True, + final_activation="Sigmoid", + use_skip_connection=False, + resize_input=True, + ) + unetr.to(device) + + # let's get the parameters for SAM and the decoder from UNETR + joint_model_params = [params for params in model.parameters()] # sam parameters + for name, params in unetr.named_parameters(): # unetr's decoder parameters + if not name.startswith("encoder"): + joint_model_params.append(params) # all the stuff we need for training - optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True) train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) # this class creates all the training data for a batch (inputs, prompts and labels) - convert_inputs = sam_training.ConvertToSamInputs() + convert_inputs = sam_training.ConvertToSamInputs( + transform=model.transform, box_distortion_factor=0.025 + ) + + checkpoint_name = f"{args.model_type}/livecell_sam" - checkpoint_name = "livecell_sam" - # the trainer which performs training and validation (implemented using "torch_em") - trainer = sam_training.SamTrainer( + # the 
trainer which performs the joint training and validation (implemented using "torch_em") + trainer = sam_training.JointSamTrainer( name=checkpoint_name, save_root=args.save_root, train_loader=train_loader, val_loader=val_loader, model=model, optimizer=optimizer, - # currently we compute loss batch-wise, else we pass channelwise True - loss=torch_em.loss.DiceLoss(channelwise=False), - metric=torch_em.loss.DiceLoss(), device=device, lr_scheduler=scheduler, - logger=sam_training.SamLogger, + logger=sam_training.JointSamLogger, log_image_interval=100, mixed_precision=True, convert_inputs=convert_inputs, n_objects_per_batch=n_objects_per_batch, n_sub_iteration=8, compile_model=False, - mask_prob=0.5 # (optional) overwrite to provide the probability of using mask inputs while training + mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training + unetr=unetr, + instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), + instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True) ) - trainer.fit(args.iterations) + trainer.fit(args.iterations, save_every_kth_epoch=args.save_every_kth_epoch) if args.export_path is not None: checkpoint_path = os.path.join( "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" @@ -90,25 +129,33 @@ def finetune_livecell(args): def main(): parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.") parser.add_argument( - "--input_path", "-i", default="/scratch/projects/nim00007/data/LiveCELL/", + "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/", help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded." ) parser.add_argument( "--model_type", "-m", default="vit_b", - help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l." + help="The model type to use for fine-tuning. Either vit_b, vit_l or vit_h." ) parser.add_argument( "--save_root", "-s", help="Where to save the checkpoint and logs. By default they will be saved where this script is run." ) parser.add_argument( - "--iterations", type=int, default=int(1e5), - help="For how many iterations should the model be trained? By default 100k." + "--iterations", type=int, default=int(25e4), + help="For how many iterations should the model be trained? By default 250k." ) parser.add_argument( "--export_path", "-e", help="Where to export the finetuned model to. The exported model can be used in the annotation tools." ) + parser.add_argument( + "--freeze", type=str, nargs="+", default=None, + help="Which parts of the model to freeze for finetuning." + ) + parser.add_argument( + "--save_every_kth_epoch", type=int, default=None, + help="To save every kth epoch while fine-tuning. Expects an integer value." 
+ ) args = parser.parse_args() finetune_livecell(args) diff --git a/finetuning/specialists/training/light_microscopy/deepbacs_finetuning.py b/finetuning/specialists/training/light_microscopy/deepbacs_finetuning.py new file mode 100644 index 00000000..1cdc938b --- /dev/null +++ b/finetuning/specialists/training/light_microscopy/deepbacs_finetuning.py @@ -0,0 +1,177 @@ +import os +import argparse + +import torch + +from torch_em.model import UNETR +from torch_em.data import MinInstanceSampler +from torch_em.transform.raw import normalize +from torch_em.loss import DiceBasedDistanceLoss +from torch_em.data.datasets import get_deepbacs_loader +from torch_em.transform.label import PerObjectDistanceTransform + +import micro_sam.training as sam_training +from micro_sam.util import export_custom_sam_model + + +def deepbacs_raw_trafo(raw): + raw = normalize(raw) + raw = raw * 255 + return raw + + +def get_dataloaders(patch_shape, data_path): + """This returns the deepbacs data loaders implemented in torch_em: + https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/deepbacs.py + It will automatically download the deepbacs data. + + Note: to replace this with another data loader you need to return a torch data loader + that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. + The labels have to be in a label mask instance segmentation format. + I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. + Important: the ID 0 is reseved for background, and the IDs must be consecutive + """ + raw_transform = deepbacs_raw_trafo + label_transform = PerObjectDistanceTransform( + distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25 + ) + sampler = MinInstanceSampler(min_num_instances=4) + label_dtype = torch.float32 + + train_loader = get_deepbacs_loader( + path=data_path, split="train", patch_shape=patch_shape, batch_size=2, + raw_transform=raw_transform, label_transform=label_transform, label_dtype=label_dtype, + sampler=sampler, download=True, num_workers=16, shuffle=True + ) + val_loader = get_deepbacs_loader( + path=data_path, split="val", patch_shape=patch_shape, batch_size=1, + raw_transform=raw_transform, label_transform=label_transform, label_dtype=label_dtype, + sampler=sampler, download=True, num_workers=16, shuffle=True + ) + + return train_loader, val_loader + + +def finetune_deepbacs(args): + """Code for finetuning SAM on DeepBacs""" + # override this (below) if you have some more complex set-up and need to specify the exact gpu + device = "cuda" if torch.cuda.is_available() else "cpu" + + # training settings: + model_type = args.model_type + checkpoint_path = None # override this to start training from a custom checkpoint + patch_shape = (512, 512) # the patch shape for training + n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled + freeze_parts = args.freeze # override this to freeze different parts of the model + + # get the trainable segment anything model + model = sam_training.get_trainable_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + freeze=freeze_parts + ) + model.to(device) + + # let's get the UNETR model for automatic instance segmentation pipeline + unetr = UNETR( + backbone="sam", + encoder=model.sam.image_encoder, + out_channels=3, + use_sam_stats=True, + final_activation="Sigmoid", + use_skip_connection=False, + resize_input=True + ) + unetr.to(device) + + # let's 
get the parameters for SAM and the decoder from UNETR + joint_model_params = [params for params in model.parameters()] # sam parameters + for name, params in unetr.named_parameters(): # unetr's decoder parameters + if not name.startswith("encoder"): + joint_model_params.append(params) + + # all the stuff we need for training + optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.9, patience=50, verbose=True + ) + train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) + + # this class creates all the training data for a batch (inputs, prompts and labels) + convert_inputs = sam_training.ConvertToSamInputs( + transform=model.transform, box_distortion_factor=0.025 + ) + + checkpoint_name = f"{args.model_type}/deepbacs_sam" + + # the trainer which performs the joint training and validation (implemented using "torch_em") + trainer = sam_training.JointSamTrainer( + name=checkpoint_name, + save_root=args.save_root, + train_loader=train_loader, + val_loader=val_loader, + model=model, + optimizer=optimizer, + device=device, + lr_scheduler=scheduler, + logger=sam_training.JointSamLogger, + log_image_interval=100, + mixed_precision=True, + convert_inputs=convert_inputs, + n_objects_per_batch=n_objects_per_batch, + n_sub_iteration=8, + compile_model=False, + mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training + unetr=unetr, + instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), + instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True) + ) + trainer.fit(args.iterations, save_every_kth_epoch=args.save_every_kth_epoch) + if args.export_path is not None: + checkpoint_path = os.path.join( + "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" + ) + export_custom_sam_model( + checkpoint_path=checkpoint_path, + model_type=model_type, + save_path=args.export_path, + ) + + +def main(): + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the DeepBacs dataset.") + parser.add_argument( + "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/deepbacs/", + help="The filepath to the DeepBacs data. If the data does not exist yet it will be downloaded." + ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_b, vit_l or vit_h." + ) + parser.add_argument( + "--save_root", "-s", + help="Where to save the checkpoint and logs. By default they will be saved where this script is run." + ) + parser.add_argument( + "--iterations", type=int, default=int(25e4), + help="For how many iterations should the model be trained? By default 250k." + ) + parser.add_argument( + "--export_path", "-e", + help="Where to export the finetuned model to. The exported model can be used in the annotation tools." + ) + parser.add_argument( + "--freeze", type=str, nargs="+", default=None, + help="Which parts of the model to freeze for finetuning." + ) + parser.add_argument( + "--save_every_kth_epoch", type=int, default=None, + help="To save every kth epoch while fine-tuning. Expects an integer value." 
+ ) + args = parser.parse_args() + finetune_deepbacs(args) + + +if __name__ == "__main__": + main() diff --git a/finetuning/specialists/training/light_microscopy/plantseg_root_finetuning.py b/finetuning/specialists/training/light_microscopy/plantseg_root_finetuning.py new file mode 100644 index 00000000..17c7898c --- /dev/null +++ b/finetuning/specialists/training/light_microscopy/plantseg_root_finetuning.py @@ -0,0 +1,166 @@ +import os +import argparse + +import torch + +from torch_em.model import UNETR +from torch_em.data import MinInstanceSampler +from torch_em.loss import DiceBasedDistanceLoss +from torch_em.data.datasets import get_plantseg_loader + +import micro_sam.training as sam_training +from micro_sam.util import export_custom_sam_model +from micro_sam.training.util import ResizeLabelTrafo, ResizeRawTrafo + + +def get_dataloaders(patch_shape, data_path): + """This returns the plantseg data loaders implemented in torch_em: + https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/plantseg.py + It will automatically download the plantseg (root) data. + + Note: to replace this with another data loader you need to return a torch data loader + that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. + The labels have to be in a label mask instance segmentation format. + I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. + Important: the ID 0 is reseved for background, and the IDs must be consecutive + """ + raw_transform = ResizeRawTrafo(patch_shape, do_rescaling=False) + label_transform = ResizeLabelTrafo(patch_shape) + sampler = MinInstanceSampler(min_num_instances=10) + label_dtype = torch.float32 + + train_loader = get_plantseg_loader( + path=data_path, name="root", split="train", patch_shape=(1, *patch_shape), batch_size=2, + download=True, ndim=2, sampler=sampler, raw_transform=raw_transform, label_transform=label_transform, + num_workers=16, shuffle=True, label_dtype=label_dtype, n_samples=5000 # training w. 
~25% of the total train-set + ) + val_loader = get_plantseg_loader( + path=data_path, name="root", split="val", patch_shape=(1, *patch_shape), batch_size=1, + download=True, ndim=2, sampler=sampler, raw_transform=raw_transform, label_transform=label_transform, + num_workers=16, shuffle=True, label_dtype=label_dtype + ) + + return train_loader, val_loader + + +def finetune_plantseg_root(args): + """Code for finetuning SAM on PlantSeg (root)""" + # override this (below) if you have some more complex set-up and need to specify the exact gpu + device = "cuda" if torch.cuda.is_available() else "cpu" + + # training settings: + model_type = args.model_type + checkpoint_path = None # override this to start training from a custom checkpoint + patch_shape = (512, 512) # the patch shape for training + n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled + freeze_parts = args.freeze # override this to freeze different parts of the model + + # get the trainable segment anything model + model = sam_training.get_trainable_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + freeze=freeze_parts + ) + model.to(device) + + # let's get the UNETR model for automatic instance segmentation pipeline + unetr = UNETR( + backbone="sam", + encoder=model.sam.image_encoder, + out_channels=3, + use_sam_stats=True, + final_activation="Sigmoid", + use_skip_connection=False, + resize_input=True + ) + unetr.to(device) + + # let's get the parameters for SAM and the decoder from UNETR + joint_model_params = [params for params in model.parameters()] # sam parameters + for name, params in unetr.named_parameters(): # unetr's decoder parameters + if not name.startswith("encoder"): + joint_model_params.append(params) + + # all the stuff we need for training + optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True) + train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) + + # this class creates all the training data for a batch (inputs, prompts and labels) + convert_inputs = sam_training.ConvertToSamInputs( + transform=model.transform, box_distortion_factor=0.025 + ) + + checkpoint_name = f"{args.model_type}/plantseg_root_sam" + + # the trainer which performs the joint training and validation (implemented using "torch_em") + trainer = sam_training.JointSamTrainer( + name=checkpoint_name, + save_root=args.save_root, + train_loader=train_loader, + val_loader=val_loader, + model=model, + optimizer=optimizer, + device=device, + lr_scheduler=scheduler, + logger=sam_training.JointSamLogger, + log_image_interval=100, + mixed_precision=True, + convert_inputs=convert_inputs, + n_objects_per_batch=n_objects_per_batch, + n_sub_iteration=8, + compile_model=False, + mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training + unetr=unetr, + instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), + instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True) + ) + trainer.fit(args.iterations, save_every_kth_epoch=args.save_every_kth_epoch) + if args.export_path is not None: + checkpoint_path = os.path.join( + "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" + ) + export_custom_sam_model( + checkpoint_path=checkpoint_path, + model_type=model_type, + save_path=args.export_path, + ) + + +def main(): + parser = 
argparse.ArgumentParser(description="Finetune Segment Anything for the PlantSeg dataset.") + parser.add_argument( + "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/plantseg/", + help="The filepath to the PlantSeg (root) data. If the data does not exist yet it will be downloaded." + ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_b, vit_l or vit_h." + ) + parser.add_argument( + "--save_root", "-s", + help="Where to save the checkpoint and logs. By default they will be saved where this script is run." + ) + parser.add_argument( + "--iterations", type=int, default=int(25e4), + help="For how many iterations should the model be trained? By default 250k." + ) + parser.add_argument( + "--export_path", "-e", + help="Where to export the finetuned model to. The exported model can be used in the annotation tools." + ) + parser.add_argument( + "--freeze", type=str, nargs="+", default=None, + help="Which parts of the model to freeze for finetuning." + ) + parser.add_argument( + "--save_every_kth_epoch", type=int, default=None, + help="To save every kth epoch while fine-tuning. Expects an integer value." + ) + args = parser.parse_args() + finetune_plantseg_root(args) + + +if __name__ == "__main__": + main() diff --git a/finetuning/specialists/training/light_microscopy/tissuenet_finetuning.py b/finetuning/specialists/training/light_microscopy/tissuenet_finetuning.py new file mode 100644 index 00000000..027da672 --- /dev/null +++ b/finetuning/specialists/training/light_microscopy/tissuenet_finetuning.py @@ -0,0 +1,166 @@ +import os +import argparse + +import torch + +from torch_em.model import UNETR +from torch_em.data import MinInstanceSampler +from torch_em.loss import DiceBasedDistanceLoss +from torch_em.data.datasets import get_tissuenet_loader + +import micro_sam.training as sam_training +from micro_sam.util import export_custom_sam_model +from micro_sam.training.util import ResizeRawTrafo, ResizeLabelTrafo + + +def get_dataloaders(patch_shape, data_path): + """This returns the tissuenet data loaders implemented in torch_em: + https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/tissuenet.py + It will automatically download the tissuenet data. + + Note: to replace this with another data loader you need to return a torch data loader + that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. + The labels have to be in a label mask instance segmentation format. + I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. 
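+    (For example, a patch with three objects would contain the label IDs 0 for background and 1, 2, 3 for the objects.)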
+ Important: the ID 0 is reseved for background, and the IDs must be consecutive + """ + raw_transform = ResizeRawTrafo(patch_shape) + label_transform = ResizeLabelTrafo(patch_shape) + sampler = MinInstanceSampler() + label_dtype = torch.float32 + + train_loader = get_tissuenet_loader( + path=data_path, split="train", patch_shape=patch_shape, batch_size=2, raw_channel="rgb", + label_channel="cell", download=True, label_dtype=label_dtype, raw_transform=raw_transform, + label_transform=label_transform, sampler=sampler, num_workers=16, shuffle=True + ) + val_loader = get_tissuenet_loader( + path=data_path, split="val", patch_shape=patch_shape, batch_size=1, raw_channel="rgb", + label_channel="cell", download=True, label_dtype=label_dtype, raw_transform=raw_transform, + label_transform=label_transform, sampler=sampler, num_workers=16, shuffle=True, n_samples=1000 + ) + + return train_loader, val_loader + + +def finetune_tissuenet(args): + """Code for finetuning SAM on TissueNet""" + # override this (below) if you have some more complex set-up and need to specify the exact gpu + device = "cuda" if torch.cuda.is_available() else "cpu" + + # training settings: + model_type = args.model_type + checkpoint_path = None # override this to start training from a custom checkpoint + patch_shape = (512, 512) # the patch shape for training + n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled + freeze_parts = args.freeze # override this to freeze different parts of the model + + # get the trainable segment anything model + model = sam_training.get_trainable_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + freeze=freeze_parts + ) + model.to(device) + + # let's get the UNETR model for automatic instance segmentation pipeline + unetr = UNETR( + backbone="sam", + encoder=model.sam.image_encoder, + out_channels=3, + use_sam_stats=True, + final_activation="Sigmoid", + use_skip_connection=False, + resize_input=True + ) + unetr.to(device) + + # let's get the parameters for SAM and the decoder from UNETR + joint_model_params = [params for params in model.parameters()] # sam parameters + for name, params in unetr.named_parameters(): # unetr's decoder parameters + if not name.startswith("encoder"): + joint_model_params.append(params) + + # all the stuff we need for training + optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True) + train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) + + # this class creates all the training data for a batch (inputs, prompts and labels) + convert_inputs = sam_training.ConvertToSamInputs( + transform=model.transform, box_distortion_factor=0.025 + ) + + checkpoint_name = f"{args.model_type}/tissuenet_sam" + + # the trainer which performs the joint training and validation (implemented using "torch_em") + trainer = sam_training.JointSamTrainer( + name=checkpoint_name, + save_root=args.save_root, + train_loader=train_loader, + val_loader=val_loader, + model=model, + optimizer=optimizer, + device=device, + lr_scheduler=scheduler, + logger=sam_training.JointSamLogger, + log_image_interval=100, + mixed_precision=True, + convert_inputs=convert_inputs, + n_objects_per_batch=n_objects_per_batch, + n_sub_iteration=8, + compile_model=False, + mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training + unetr=unetr, + 
instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), + instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True) + ) + trainer.fit(args.iterations, save_every_kth_epoch=args.save_every_kth_epoch) + if args.export_path is not None: + checkpoint_path = os.path.join( + "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" + ) + export_custom_sam_model( + checkpoint_path=checkpoint_path, + model_type=model_type, + save_path=args.export_path, + ) + + +def main(): + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the TissueNet dataset.") + parser.add_argument( + "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/tissuenet/", + help="The filepath to the TissueNet data. If the data does not exist yet it will be downloaded." + ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_b, vit_l or vit_h." + ) + parser.add_argument( + "--save_root", "-s", + help="Where to save the checkpoint and logs. By default they will be saved where this script is run." + ) + parser.add_argument( + "--iterations", type=int, default=int(25e4), + help="For how many iterations should the model be trained? By default 250k." + ) + parser.add_argument( + "--export_path", "-e", + help="Where to export the finetuned model to. The exported model can be used in the annotation tools." + ) + parser.add_argument( + "--freeze", type=str, nargs="+", default=None, + help="Which parts of the model to freeze for finetuning." + ) + parser.add_argument( + "--save_every_kth_epoch", type=int, default=None, + help="To save every kth epoch while fine-tuning. Expects an integer value." + ) + args = parser.parse_args() + finetune_tissuenet(args) + + +if __name__ == "__main__": + main() diff --git a/micro_sam/evaluation/__init__.py b/micro_sam/evaluation/__init__.py index a63d8597..199ccc97 100644 --- a/micro_sam/evaluation/__init__.py +++ b/micro_sam/evaluation/__init__.py @@ -1,14 +1,15 @@ """Functionality for evaluating Segment Anything models on microscopy data. """ -from .automatic_mask_generation import ( - run_amg_inference, - run_amg_grid_search, - run_amg_grid_search_and_inference, +from .instance_segmentation import ( + run_instance_segmentation_inference, + run_instance_segmentation_grid_search, + run_instance_segmentation_grid_search_and_inference, ) from .evaluation import run_evaluation from .inference import ( get_predictor, + run_inference_with_iterative_prompting, run_inference_with_prompts, precompute_all_embeddings, precompute_all_prompts, diff --git a/micro_sam/evaluation/automatic_mask_generation.py b/micro_sam/evaluation/automatic_mask_generation.py deleted file mode 100644 index 46b12ef5..00000000 --- a/micro_sam/evaluation/automatic_mask_generation.py +++ /dev/null @@ -1,272 +0,0 @@ -"""Inference and evaluation for the automatic instance segmentation functionality. -""" - -import os -from glob import glob -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union - -import imageio.v3 as imageio -import numpy as np -import pandas as pd - -from elf.evaluation import mean_segmentation_accuracy -from segment_anything import SamPredictor -from tqdm import tqdm - -from .. import instance_segmentation -from .. 
import util - - -def _get_range_of_search_values(input_vals, step): - if isinstance(input_vals, list): - search_range = np.arange(input_vals[0], input_vals[1] + step, step) - search_range = [round(e, 2) for e in search_range] - else: - search_range = [input_vals] - return search_range - - -def _grid_search( - amg, gt, image_name, iou_thresh_values, stability_score_values, result_path, amg_generate_kwargs, verbose, -): - net_list = [] - gs_combinations = [(r1, r2) for r1 in iou_thresh_values for r2 in stability_score_values] - - for iou_thresh, stability_thresh in tqdm(gs_combinations, disable=not verbose): - masks = amg.generate( - pred_iou_thresh=iou_thresh, stability_score_thresh=stability_thresh, **amg_generate_kwargs - ) - instance_labels = instance_segmentation.mask_data_to_segmentation( - masks, gt.shape, with_background=True, - min_object_size=amg_generate_kwargs.get("min_mask_region_area", 0), - ) - m_sas, sas = mean_segmentation_accuracy(instance_labels, gt, return_accuracies=True) # type: ignore - - result_dict = { - "image_name": image_name, - "pred_iou_thresh": iou_thresh, - "stability_score_thresh": stability_thresh, - "mSA": m_sas, - "SA50": sas[0], - "SA75": sas[5] - } - tmp_df = pd.DataFrame([result_dict]) - net_list.append(tmp_df) - - img_gs_df = pd.concat(net_list) - img_gs_df.to_csv(result_path, index=False) - - -# ideally we would generalize the parameters that GS runs over -def run_amg_grid_search( - predictor: SamPredictor, - image_paths: List[Union[str, os.PathLike]], - gt_paths: List[Union[str, os.PathLike]], - embedding_dir: Union[str, os.PathLike], - result_dir: Union[str, os.PathLike], - iou_thresh_values: Optional[List[float]] = None, - stability_score_values: Optional[List[float]] = None, - amg_kwargs: Optional[Dict[str, Any]] = None, - amg_generate_kwargs: Optional[Dict[str, Any]] = None, - AMG: instance_segmentation.AMGBase = instance_segmentation.AutomaticMaskGenerator, - verbose_gs: bool = False, -) -> None: - """Run grid search for automatic mask generation. - - The grid search goes over the two most important parameters: - - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model - - `stability_score_thresh`, the theshold for keepong objects according to their stability - - Args: - predictor: The segment anything predictor. - image_paths: The input images for the grid search. - gt_paths: The ground-truth segmentation for the grid search. - embedding_dir: Folder to cache the image embeddings. - result_dir: Folder to cache the evaluation results per image. - iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch. - By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used. - stability_score_values: The values for `stability_score_thresh` used in the gridsearch. - By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used. - amg_kwargs: The keyword arguments for the automatic mask generator class. - amg_generate_kwargs: The keyword arguments for the `generate` method of the mask generator. - This must not contain `pred_iou_thresh` or `stability_score_thresh`. - AMG: The automatic mask generator. By default `micro_sam.instance_segmentation.AutomaticMaskGenerator`. - verbose_gs: Whether to run the gridsearch for individual images in a verbose mode. 
- """ - assert len(image_paths) == len(gt_paths) - amg_kwargs = {} if amg_kwargs is None else amg_kwargs - amg_generate_kwargs = {} if amg_generate_kwargs is None else amg_generate_kwargs - if "pred_iou_thresh" in amg_generate_kwargs or "stability_score_thresh" in amg_generate_kwargs: - raise ValueError("The threshold parameters are optimized in the grid-search. You must not pass them as kwargs.") - - if iou_thresh_values is None: - iou_thresh_values = _get_range_of_search_values([0.6, 0.9], step=0.025) - if stability_score_values is None: - stability_score_values = _get_range_of_search_values([0.6, 0.95], step=0.025) - - os.makedirs(result_dir, exist_ok=True) - amg = AMG(predictor, **amg_kwargs) - - for image_path, gt_path in tqdm( - zip(image_paths, gt_paths), desc="Run grid search for AMG", total=len(image_paths) - ): - image_name = Path(image_path).stem - result_path = os.path.join(result_dir, f"{image_name}.csv") - - # We skip images for which the grid search was done already. - if os.path.exists(result_path): - continue - - assert os.path.exists(image_path), image_path - assert os.path.exists(gt_path), gt_path - - image = imageio.imread(image_path) - gt = imageio.imread(gt_path) - - embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr") - image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2) - amg.initialize(image, image_embeddings) - - _grid_search( - amg, gt, image_name, - iou_thresh_values, stability_score_values, - result_path, amg_generate_kwargs, verbose=verbose_gs, - ) - - -def run_amg_inference( - predictor: SamPredictor, - image_paths: List[Union[str, os.PathLike]], - embedding_dir: Union[str, os.PathLike], - prediction_dir: Union[str, os.PathLike], - amg_kwargs: Optional[Dict[str, Any]] = None, - amg_generate_kwargs: Optional[Dict[str, Any]] = None, - AMG: instance_segmentation.AMGBase = instance_segmentation.AutomaticMaskGenerator, -) -> None: - """Run inference for automatic mask generation. - - Args: - predictor: The segment anything predictor. - image_paths: The input images. - embedding_dir: Folder to cache the image embeddings. - prediction_dir: Folder to save the predictions. - amg_kwargs: The keyword arguments for the automatic mask generator class. - amg_generate_kwargs: The keyword arguments for the `generate` method of the mask generator. - This must not contain `pred_iou_thresh` or `stability_score_thresh`. - AMG: The automatic mask generator. By default `micro_sam.instance_segmentation.AutomaticMaskGenerator`. - """ - amg_kwargs = {} if amg_kwargs is None else amg_kwargs - amg_generate_kwargs = {} if amg_generate_kwargs is None else amg_generate_kwargs - - amg = AMG(predictor, **amg_kwargs) - - for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"): - image_name = os.path.basename(image_path) - - # We skip the images that already have been segmented. 
- prediction_path = os.path.join(prediction_dir, image_name) - if os.path.exists(prediction_path): - continue - - assert os.path.exists(image_path), image_path - image = imageio.imread(image_path) - - embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr") - image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2) - - amg.initialize(image, image_embeddings) - masks = amg.generate(**amg_generate_kwargs) - instances = instance_segmentation.mask_data_to_segmentation( - masks, image.shape, with_background=True, min_object_size=amg_generate_kwargs.get("min_mask_region_area", 0) - ) - - # It's important to compress here, otherwise the predictions would take up a lot of space. - imageio.imwrite(prediction_path, instances, compression=5) - - -def evaluate_amg_grid_search(result_dir: Union[str, os.PathLike], criterion: str = "mSA") -> Tuple[float, float, float]: - """Evaluate gridsearch results. - - Args: - result_dir: The folder with the gridsearch results. - criterion: The metric to use for determining the best parameters. - - Returns: - - The best value for `pred_iou_thresh`. - - The best value for `stability_score_thresh`. - - The evaluation score for the best setting. - """ - - # load all the grid search results - gs_files = glob(os.path.join(result_dir, "*.csv")) - gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files]) - - # contain only the relevant columns and group by the gridsearch columns - gs_col1 = "pred_iou_thresh" - gs_col2 = "stability_score_thresh" - gs_result = gs_result[[gs_col1, gs_col2, criterion]] - - # compute the mean over the grouped columns - grouped_result = gs_result.groupby([gs_col1, gs_col2]).mean() - - # find the best grouped result and return the corresponding thresholds - best_score = grouped_result.max().values[0] - best_result = grouped_result.idxmax() - best_iou_thresh, best_stability_score = best_result.values[0] - return best_iou_thresh, best_stability_score, best_score - - -def run_amg_grid_search_and_inference( - predictor: SamPredictor, - val_image_paths: List[Union[str, os.PathLike]], - val_gt_paths: List[Union[str, os.PathLike]], - test_image_paths: List[Union[str, os.PathLike]], - embedding_dir: Union[str, os.PathLike], - prediction_dir: Union[str, os.PathLike], - result_dir: Union[str, os.PathLike], - iou_thresh_values: Optional[List[float]] = None, - stability_score_values: Optional[List[float]] = None, - amg_kwargs: Optional[Dict[str, Any]] = None, - amg_generate_kwargs: Optional[Dict[str, Any]] = None, - AMG: instance_segmentation.AMGBase = instance_segmentation.AutomaticMaskGenerator, - verbose_gs: bool = True, -) -> None: - """Run grid search and inference for automatic mask generation. - - Args: - predictor: The segment anything predictor. - val_image_paths: The input images for the grid search. - val_gt_paths: The ground-truth segmentation for the grid search. - test_image_paths: The input images for inference. - embedding_dir: Folder to cache the image embeddings. - prediction_dir: Folder to save the predictions. - result_dir: Folder to cache the evaluation results per image. - iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch. - By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used. - stability_score_values: The values for `stability_score_thresh` used in the gridsearch. - By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used. 
- amg_kwargs: The keyword arguments for the automatic mask generator class. - amg_generate_kwargs: The keyword arguments for the `generate` method of the mask generator. - This must not contain `pred_iou_thresh` or `stability_score_thresh`. - AMG: The automatic mask generator. By default `micro_sam.instance_segmentation.AutomaticMaskGenerator`. - verbose_gs: Whether to run the gridsearch for individual images in a verbose mode. - """ - run_amg_grid_search( - predictor, val_image_paths, val_gt_paths, embedding_dir, result_dir, - iou_thresh_values=iou_thresh_values, stability_score_values=stability_score_values, - amg_kwargs=amg_kwargs, amg_generate_kwargs=amg_generate_kwargs, AMG=AMG, verbose_gs=verbose_gs, - ) - - amg_generate_kwargs = {} if amg_generate_kwargs is None else amg_generate_kwargs - best_iou_thresh, best_stability_score, best_msa = evaluate_amg_grid_search(result_dir) - print( - "Best grid-search result:", best_msa, - f"@ iou_thresh = {best_iou_thresh}, stability_score = {best_stability_score}" - ) - amg_generate_kwargs["pred_iou_thresh"] = best_iou_thresh - amg_generate_kwargs["stability_score_thresh"] = best_stability_score - - run_amg_inference( - predictor, test_image_paths, embedding_dir, prediction_dir, amg_kwargs, amg_generate_kwargs, AMG - ) diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index bfa1b160..4568f641 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -137,7 +137,7 @@ def _run_inference_with_prompts_for_image( def get_predictor( checkpoint_path: Union[str, os.PathLike], model_type: str, - device: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, return_state: bool = False, is_custom_model: Optional[bool] = None, ) -> SamPredictor: @@ -146,12 +146,13 @@ def get_predictor( Args: checkpoint_path: The checkpoint filepath. model_type: The type of the model, either vit_h, vit_b or vit_l. + device: The device to use. return_state: Whether to return the complete state of the checkpoint in addtion to the predictor. is_custom_model: Whether this is a custom model or not. Returns: The segment anything predictor. """ - device = util._get_device(device) + device = util.get_device(device) # By default we check if the model follows the torch_em checkpint naming scheme to check whether it is a # custom model or not. This can be over-ridden by passing True or False for is_custom_model. @@ -386,9 +387,8 @@ def run_inference_with_prompts( def _save_segmentation(masks, prediction_path): # masks to segmentation masks = masks.cpu().numpy().squeeze().astype("bool") - shape = masks.shape[-2:] masks = [{"segmentation": mask, "area": mask.sum()} for mask in masks] - segmentation = mask_data_to_segmentation(masks, shape, with_background=True) + segmentation = mask_data_to_segmentation(masks, with_background=True) imageio.imwrite(prediction_path, segmentation, compression=5) diff --git a/micro_sam/evaluation/instance_segmentation.py b/micro_sam/evaluation/instance_segmentation.py new file mode 100644 index 00000000..a07c49dc --- /dev/null +++ b/micro_sam/evaluation/instance_segmentation.py @@ -0,0 +1,364 @@ +"""Inference and evaluation for the automatic instance segmentation functionality. 
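+
+A rough usage sketch for the AMG-based workflow (the checkpoint path and the image / label
+path lists are placeholders; see `run_instance_segmentation_grid_search_and_inference`
+below for the full argument documentation):
+
+    from micro_sam.evaluation.inference import get_predictor
+    from micro_sam.instance_segmentation import AutomaticMaskGenerator
+
+    predictor = get_predictor("/path/to/checkpoint.pt", model_type="vit_b")
+    amg = AutomaticMaskGenerator(predictor)
+    grid_search_values = default_grid_search_values_amg()
+    run_instance_segmentation_grid_search_and_inference(
+        amg, grid_search_values,
+        val_image_paths, val_gt_paths, test_image_paths,
+        embedding_dir="embeddings/", prediction_dir="predictions/", result_dir="grid_search/",
+    )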
+""" + +import os +from glob import glob +from itertools import product +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import imageio.v3 as imageio +import numpy as np +import pandas as pd + +from elf.evaluation import mean_segmentation_accuracy +from elf.io import open_file +from tqdm import tqdm + +from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation +from .. import util + + +def _get_range_of_search_values(input_vals, step): + if isinstance(input_vals, list): + search_range = np.arange(input_vals[0], input_vals[1] + step, step) + search_range = [round(e, 2) for e in search_range] + else: + search_range = [input_vals] + return search_range + + +def default_grid_search_values_amg( + iou_thresh_values: Optional[List[float]] = None, + stability_score_values: Optional[List[float]] = None, +) -> Dict[str, List[float]]: + """Default grid-search parameter for AMG-based instance segmentation. + + Return grid search values for the two most important parameters: + - `pred_iou_thresh`, the threshold for keeping objects according to the IoU predicted by the model. + - `stability_score_thresh`, the theshold for keepong objects according to their stability. + + Args: + iou_thresh_values: The values for `pred_iou_thresh` used in the gridsearch. + By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used. + stability_score_values: The values for `stability_score_thresh` used in the gridsearch. + By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used. + + Returns: + The values for grid search. + """ + if iou_thresh_values is None: + iou_thresh_values = _get_range_of_search_values([0.6, 0.9], step=0.025) + if stability_score_values is None: + stability_score_values = _get_range_of_search_values([0.6, 0.95], step=0.025) + return { + "pred_iou_thresh": iou_thresh_values, + "stability_score_thresh": stability_score_values, + } + + +def default_grid_search_values_instance_segmentation_with_decoder( + center_distance_threshold_values: Optional[List[float]] = None, + boundary_distance_threshold_values: Optional[List[float]] = None, + distance_smoothing_values: Optional[List[float]] = None, + min_size_values: Optional[List[float]] = None, +) -> Dict[str, List[float]]: + """Default grid-search parameter for decoder-based instance segmentation. + + Args: + center_distance_threshold_values: The values for `center_distance_threshold` used in the gridsearch. + By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used. + boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the gridsearch. + By default values in the range from 0.5 to 0.9 with a stepsize of 0.1 will be used. + distance_smoothing_values: The values for `distance_smoothing` used in the gridsearch. + By default values in the range from 1.0 to 2.0 with a stepsize of 0.1 will be used. + min_size_values: The values for `min_size` used in the gridsearch. + By default the values 25, 50, 75, 100 and 200 are used. + + Returns: + The values for grid search. 
+ """ + if center_distance_threshold_values is None: + center_distance_threshold_values = _get_range_of_search_values( + [0.3, 0.7], step=0.1 + ) + if boundary_distance_threshold_values is None: + boundary_distance_threshold_values = _get_range_of_search_values( + [0.3, 0.7], step=0.1 + ) + if distance_smoothing_values is None: + distance_smoothing_values = _get_range_of_search_values( + [1.0, 2.0], step=0.2 + ) + if min_size_values is None: + min_size_values = [50, 100, 200] + return { + "center_distance_threshold": center_distance_threshold_values, + "boundary_distance_threshold": boundary_distance_threshold_values, + "distance_smoothing": distance_smoothing_values, + "min_size": min_size_values, + } + + +def _grid_search_iteration( + segmenter: Union[AMGBase, InstanceSegmentationWithDecoder], + gs_combinations: List[Dict], + gt: np.ndarray, + image_name: str, + fixed_generate_kwargs: Dict[str, Any], + result_path: Optional[Union[str, os.PathLike]], + verbose: bool = False, +) -> pd.DataFrame: + net_list = [] + for gs_kwargs in tqdm(gs_combinations, disable=not verbose): + generate_kwargs = gs_kwargs | fixed_generate_kwargs + masks = segmenter.generate(**generate_kwargs) + + min_object_size = generate_kwargs.get("min_mask_region_area", 0) + if len(masks) == 0: + instance_labels = np.zeros(gt.shape, dtype="uint32") + else: + instance_labels = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size) + m_sas, sas = mean_segmentation_accuracy(instance_labels, gt, return_accuracies=True) # type: ignore + + result_dict = {"image_name": image_name, "mSA": m_sas, "SA50": sas[0], "SA75": sas[5]} + result_dict.update(gs_kwargs) + tmp_df = pd.DataFrame([result_dict]) + net_list.append(tmp_df) + + img_gs_df = pd.concat(net_list) + img_gs_df.to_csv(result_path, index=False) + + return img_gs_df + + +def _load_image(path, key, roi): + if key is None: + im = imageio.imread(path) + if roi is not None: + im = im[roi] + return im + with open_file(path, "r") as f: + im = f[key][:] if roi is None else f[key][roi] + return im + + +def run_instance_segmentation_grid_search( + segmenter: Union[AMGBase, InstanceSegmentationWithDecoder], + grid_search_values: Dict[str, List], + image_paths: List[Union[str, os.PathLike]], + gt_paths: List[Union[str, os.PathLike]], + result_dir: Union[str, os.PathLike], + embedding_dir: Optional[Union[str, os.PathLike]], + fixed_generate_kwargs: Optional[Dict[str, Any]] = None, + verbose_gs: bool = False, + image_key: Optional[str] = None, + gt_key: Optional[str] = None, + rois: Optional[Tuple[slice, ...]] = None, +) -> None: + """Run grid search for automatic mask generation. + + The parameters and their respective value ranges for the grid search are specified via the + 'grid_search_values' argument. For example, to run a grid search over the parameters 'pred_iou_thresh' + and 'stability_score_thresh', you can pass the following: + ``` + grid_search_values = { + "pred_iou_thresh": [0.6, 0.7, 0.8, 0.9], + "stability_score_thresh": [0.6, 0.7, 0.8, 0.9], + } + ``` + All combinations of the parameters will be checked. + + You can use the functions `default_grid_search_values_instance_segmentation_with_decoder` + or `default_grid_search_values_amg` to get the default grid search parameters for the two + respective instance segmentation methods. + + Args: + segmenter: The class implementing the instance segmentation functionality. + grid_search_values: The grid search values for parameters of the `generate` function. 
+        image_paths: The input images for the grid search.
+        gt_paths: The ground-truth segmentation for the grid search.
+        result_dir: Folder to cache the evaluation results per image.
+        embedding_dir: Folder to cache the image embeddings.
+        fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
+        verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
+        image_key: Key for loading the image data from a more complex file format like HDF5.
+            If not given a simple image format like tif is assumed.
+        gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
+            If not given a simple image format like tif is assumed.
+        rois: Regions of interest to restrict the evaluation to.
+    """
+    assert len(image_paths) == len(gt_paths)
+    fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs
+
+    duplicate_params = [gs_param for gs_param in grid_search_values.keys() if gs_param in fixed_generate_kwargs]
+    if duplicate_params:
+        raise ValueError(
+            "You may not pass duplicate parameters in 'grid_search_values' and 'fixed_generate_kwargs'. "
+            f"The parameters {duplicate_params} are duplicated."
+        )
+
+    # Compute all combinations of grid search values.
+    gs_combinations = product(*grid_search_values.values())
+    # Map each combination back to a valid kwarg input.
+    gs_combinations = [
+        {k: v for k, v in zip(grid_search_values.keys(), vals)} for vals in gs_combinations
+    ]
+
+    os.makedirs(result_dir, exist_ok=True)
+    predictor = getattr(segmenter, "_predictor", None)
+
+    for i, (image_path, gt_path) in tqdm(
+        enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
+    ):
+        image_name = Path(image_path).stem
+        result_path = os.path.join(result_dir, f"{image_name}.csv")
+
+        # We skip images for which the grid search was done already.
+        if os.path.exists(result_path):
+            continue
+
+        assert os.path.exists(image_path), image_path
+        assert os.path.exists(gt_path), gt_path
+
+        image = _load_image(image_path, image_key, roi=None if rois is None else rois[i])
+        gt = _load_image(gt_path, gt_key, roi=None if rois is None else rois[i])
+
+        if embedding_dir is None:
+            segmenter.initialize(image)
+        else:
+            assert predictor is not None
+            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
+            image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2)
+            segmenter.initialize(image, image_embeddings)
+
+        _grid_search_iteration(
+            segmenter, gs_combinations, gt, image_name,
+            fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
+        )
+
+
+def run_instance_segmentation_inference(
+    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
+    image_paths: List[Union[str, os.PathLike]],
+    embedding_dir: Union[str, os.PathLike],
+    prediction_dir: Union[str, os.PathLike],
+    generate_kwargs: Optional[Dict[str, Any]] = None,
+) -> None:
+    """Run inference for automatic mask generation.
+
+    Args:
+        segmenter: The class implementing the instance segmentation functionality.
+        image_paths: The input images.
+        embedding_dir: Folder to cache the image embeddings.
+        prediction_dir: Folder to save the predictions.
+        generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
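+            For an AMG-based segmenter this could for example be
+            `{"pred_iou_thresh": 0.88, "stability_score_thresh": 0.95}` (illustrative values).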
+ """ + + generate_kwargs = {} if generate_kwargs is None else generate_kwargs + predictor = segmenter._predictor + min_object_size = generate_kwargs.get("min_mask_region_area", 0) + + for image_path in tqdm(image_paths, desc="Run inference for automatic mask generation"): + image_name = os.path.basename(image_path) + + # We skip the images that already have been segmented. + prediction_path = os.path.join(prediction_dir, image_name) + if os.path.exists(prediction_path): + continue + + assert os.path.exists(image_path), image_path + image = imageio.imread(image_path) + + embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr") + image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2) + + segmenter.initialize(image, image_embeddings) + masks = segmenter.generate(**generate_kwargs) + instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size) + + # It's important to compress here, otherwise the predictions would take up a lot of space. + imageio.imwrite(prediction_path, instances, compression=5) + + +def evaluate_instance_segmentation_grid_search( + result_dir: Union[str, os.PathLike], + grid_search_parameters: List[str], + criterion: str = "mSA" +) -> Tuple[Dict[str, Any], float]: + """Evaluate gridsearch results. + + Args: + result_dir: The folder with the gridsearch results. + grid_search_parameters: The names for the gridsearch parameters. + criterion: The metric to use for determining the best parameters. + + Returns: + The best parameter setting. + The evaluation score for the best setting. + """ + + # Load all the grid search results. + gs_files = glob(os.path.join(result_dir, "*.csv")) + gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files]) + + # Retrieve only the relevant columns and group by the gridsearch columns. + gs_result = gs_result[grid_search_parameters + [criterion]].reset_index() + + # Compute the mean over the grouped columns. + grouped_result = gs_result.groupby(grid_search_parameters).mean().reset_index() + + # Find the best score and corresponding parameters. + best_score, best_idx = grouped_result[criterion].max(), grouped_result[criterion].idxmax() + best_params = grouped_result.iloc[best_idx] + assert np.isclose(best_params[criterion], best_score) + best_kwargs = {k: v for k, v in zip(grid_search_parameters, best_params)} + + return best_kwargs, best_score + + +def run_instance_segmentation_grid_search_and_inference( + segmenter: Union[AMGBase, InstanceSegmentationWithDecoder], + grid_search_values: Dict[str, List], + val_image_paths: List[Union[str, os.PathLike]], + val_gt_paths: List[Union[str, os.PathLike]], + test_image_paths: List[Union[str, os.PathLike]], + embedding_dir: Union[str, os.PathLike], + prediction_dir: Union[str, os.PathLike], + result_dir: Union[str, os.PathLike], + fixed_generate_kwargs: Optional[Dict[str, Any]] = None, + verbose_gs: bool = True, +) -> None: + """Run grid search and inference for automatic mask generation. + + Please refer to the documentation of `run_instance_segmentation_grid_search` + for details on how to specify the grid search parameters. + + Args: + segmenter: The class implementing the instance segmentation functionality. + grid_search_values: The grid search values for parameters of the `generate` function. + val_image_paths: The input images for the grid search. + val_gt_paths: The ground-truth segmentation for the grid search. + test_image_paths: The input images for inference. 
+ embedding_dir: Folder to cache the image embeddings. + prediction_dir: Folder to save the predictions. + result_dir: Folder to cache the evaluation results per image. + fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter. + verbose_gs: Whether to run the gridsearch for individual images in a verbose mode. + """ + run_instance_segmentation_grid_search( + segmenter, grid_search_values, val_image_paths, val_gt_paths, + result_dir=result_dir, embedding_dir=embedding_dir, + fixed_generate_kwargs=fixed_generate_kwargs, verbose_gs=verbose_gs, + ) + + best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys())) + best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items()) + print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str) + + generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs + generate_kwargs.update(best_kwargs) + + run_instance_segmentation_inference( + segmenter, test_image_paths, embedding_dir, prediction_dir, generate_kwargs + ) diff --git a/micro_sam/evaluation/livecell.py b/micro_sam/evaluation/livecell.py index 58a95e07..18b0e810 100644 --- a/micro_sam/evaluation/livecell.py +++ b/micro_sam/evaluation/livecell.py @@ -15,8 +15,8 @@ from segment_anything import SamPredictor from tqdm import tqdm -from ..instance_segmentation import AutomaticMaskGenerator, _EmbeddingMaskGenerator -from . import automatic_mask_generation, inference, evaluation +from ..instance_segmentation import AutomaticMaskGenerator, load_instance_segmentation_with_decoder_from_checkpoint +from . import instance_segmentation, inference, evaluation from .experiments import default_experiment_settings, full_experiment_settings CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"] @@ -29,7 +29,7 @@ def _get_livecell_paths(input_folder, split="test", n_val_per_cell_type=None): assert split in ["val", "test"] - assert os.path.exists(input_folder), "Please download the LIVECell Dataset" + assert os.path.exists(input_folder), f"Data not found at {input_folder}. Please download the LIVECell Dataset" if split == "test": @@ -147,8 +147,7 @@ def run_livecell_amg( stability_score_values: Optional[List[float]] = None, verbose_gs: bool = False, n_val_per_cell_type: int = 25, - use_mws: bool = False, -) -> None: +) -> str: """Run automatic mask generation grid-search and inference for livecell. Args: @@ -162,17 +161,16 @@ def run_livecell_amg( By default values in the range from 0.6 to 0.9 with a stepsize of 0.025 will be used. verbose_gs: Whether to run the gridsearch for individual images in a verbose mode. n_val_per_cell_type: The number of validation images per cell type. - use_mws: Whether to use the mutex watershed based automatic mask generator approach. + + Returns: + The path where the predicted images are stored. 
""" embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved os.makedirs(embedding_folder, exist_ok=True) - if use_mws: - amg_prefix = "amg_mws" - AMG = _EmbeddingMaskGenerator - else: - amg_prefix = "amg" - AMG = AutomaticMaskGenerator + predictor = inference.get_predictor(checkpoint, model_type) + amg = AutomaticMaskGenerator(predictor) + amg_prefix = "amg" # where the predictions are saved prediction_folder = os.path.join(experiment_folder, amg_prefix, "inference") @@ -185,13 +183,68 @@ def run_livecell_amg( val_image_paths, val_gt_paths = _get_livecell_paths(input_folder, "val", n_val_per_cell_type=n_val_per_cell_type) test_image_paths, _ = _get_livecell_paths(input_folder, "test") - predictor = inference.get_predictor(checkpoint, model_type) - automatic_mask_generation.run_amg_grid_search_and_inference( - predictor, val_image_paths, val_gt_paths, test_image_paths, + grid_search_values = instance_segmentation.default_grid_search_values_amg( + iou_thresh_values=iou_thresh_values, + stability_score_values=stability_score_values, + ) + + instance_segmentation.run_instance_segmentation_grid_search_and_inference( + amg, grid_search_values, + val_image_paths, val_gt_paths, test_image_paths, embedding_folder, prediction_folder, gs_result_folder, - iou_thresh_values=iou_thresh_values, stability_score_values=stability_score_values, - AMG=AMG, verbose_gs=verbose_gs, ) + return prediction_folder + + +def run_livecell_instance_segmentation_with_decoder( + checkpoint: Union[str, os.PathLike], + input_folder: Union[str, os.PathLike], + model_type: str, + experiment_folder: Union[str, os.PathLike], + verbose_gs: bool = False, + n_val_per_cell_type: int = 25, +) -> str: + """Run automatic mask generation grid-search and inference for livecell. + + Args: + checkpoint: The segment anything model checkpoint. + input_folder: The folder with the livecell data. + model_type: The type of the segmenta anything model. + experiment_folder: The folder where to save all data associated with the experiment. + verbose_gs: Whether to run the gridsearch for individual images in a verbose mode. + n_val_per_cell_type: The number of validation images per cell type. + + Returns: + The path where the predicted images are stored. 
+ """ + embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved + os.makedirs(embedding_folder, exist_ok=True) + + segmenter = load_instance_segmentation_with_decoder_from_checkpoint( + checkpoint, model_type, + ) + seg_prefix = "instance_segmentation_with_decoder" + + # where the predictions are saved + prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference") + os.makedirs(prediction_folder, exist_ok=True) + + # where the grid-search results are saved + gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search") + os.makedirs(gs_result_folder, exist_ok=True) + + val_image_paths, val_gt_paths = _get_livecell_paths(input_folder, "val", n_val_per_cell_type=n_val_per_cell_type) + test_image_paths, _ = _get_livecell_paths(input_folder, "test") + + grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder() + + instance_segmentation.run_instance_segmentation_grid_search_and_inference( + segmenter, grid_search_values, + val_image_paths, val_gt_paths, test_image_paths, + embedding_dir=embedding_folder, prediction_dir=prediction_folder, + result_dir=gs_result_folder, + ) + return prediction_folder def _run_multiple_prompt_settings(args, prompt_settings): diff --git a/micro_sam/evaluation/model_comparison.py b/micro_sam/evaluation/model_comparison.py index 702f944f..6df9e8f4 100644 --- a/micro_sam/evaluation/model_comparison.py +++ b/micro_sam/evaluation/model_comparison.py @@ -109,7 +109,7 @@ def generate_data_for_model_comparison( output_folder: The folder where the samples will be saved. model_type1: The first model to use for comparison. The value needs to be a valid model_type for `micro_sam.util.get_sam_model`. - model_type1: The second model to use for comparison. + model_type2: The second model to use for comparison. The value needs to be a valid model_type for `micro_sam.util.get_sam_model`. n_samples: The number of samples to draw from the dataloader. """ diff --git a/micro_sam/inference.py b/micro_sam/inference.py index da7e7275..3be14a5f 100644 --- a/micro_sam/inference.py +++ b/micro_sam/inference.py @@ -121,7 +121,6 @@ def batched_inference( # then we need to select the most likely mask (according to the predicted IOU) here. if reduce_multimasking and multimasking: _, max_index = batch_ious.max(axis=1) - # How can this be vectorized??? 
batch_masks = torch.cat([batch_masks[i, max_id][None] for i, max_id in enumerate(max_index)]).unsqueeze(1) batch_ious = torch.cat([batch_ious[i, max_id][None] for i, max_id in enumerate(max_index)]).unsqueeze(1) @@ -144,6 +143,6 @@ def batched_inference( ] if return_instance_segmentation: - masks = mask_data_to_segmentation(masks, image_shape, with_background=False, min_object_size=0) + masks = mask_data_to_segmentation(masks, with_background=False, min_object_size=0) return masks diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index 3abb1f33..b1c301c9 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -4,34 +4,34 @@ https://computational-cell-analytics.github.io/micro-sam/micro_sam.html """ -import multiprocessing as mp +import os +import pickle import warnings from abc import ABC -from concurrent import futures +from collections import OrderedDict from copy import deepcopy -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch import segment_anything.utils.amg as amg_utils import vigra -from elf.segmentation import embeddings as embed -from elf.segmentation.stitching import stitch_segmentation from nifty.tools import blocking - from segment_anything.predictor import SamPredictor -from skimage.transform import resize +from skimage.measure import regionprops from torchvision.ops.boxes import batched_nms, box_area +from torch_em.model import UNETR +from torch_em.util.segmentation import watershed_from_center_and_boundary_distances + try: from napari.utils import progress as tqdm except ImportError: from tqdm import tqdm from . import util -from .prompt_based_segmentation import segment_from_mask from ._vendored import batched_mask_to_box, mask_to_rle_pytorch # @@ -50,7 +50,6 @@ def __getitem__(self, index): def mask_data_to_segmentation( masks: List[Dict[str, Any]], - shape: Tuple[int, ...], with_background: bool, min_object_size: int = 0, max_object_size: Optional[int] = None, @@ -60,7 +59,6 @@ def mask_data_to_segmentation( Args: masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator. Only supports output_mode=binary_mask. - shape: The image shape. with_background: Whether the segmentation has background. If yes this function assures that the largest object in the output will be mapped to zero (the background value). min_object_size: The minimal size of an object in pixels. @@ -70,7 +68,9 @@ def mask_data_to_segmentation( """ masks = sorted(masks, key=(lambda x: x["area"]), reverse=True) - segmentation = np.zeros(shape[:2], dtype="uint32") + # we could also get the shape from the crop box + shape = next(iter(masks))["segmentation"].shape + segmentation = np.zeros(shape, dtype="uint32") def require_numpy(mask): return mask.cpu().numpy() if torch.is_tensor(mask) else mask @@ -681,303 +681,162 @@ def get_amg( # -# Experimental embedding based instance segmentation functionality +# Instance segmentation functionality based on fine-tuned decoder # -class _EmbeddingMaskGenerator(AMGBase): - """Generates an instance segmentation without prompts, using an initial segmentations derived from image embeddings. +class DecoderAdapter(torch.nn.Module): + """Adapter to contain the UNETR decoder in a single module. - Uses an intial segmentation derived from the image embeddings via the Mutex Watershed, - an affinity based sementation method. 
- The computationally expensive steps of the mask generation are decoupled from cheaper post-processing operations, - to enable faster grid search and interactively changing the post-processing. - - Use this class as follows: - ```python - amg = EmbeddingMaskGenerator(predictor) - amg.initialize(image) # Initialize the masks, this takes care of all expensive computations. - masks = amg.generate(pred_iou_thresh=0.8) # Generate the masks. This is fast and enables testing parameters - ``` - - Args: - predictor: The segment anything predictor. - offsets: Offset values for the affinities computed from image embeddings that are used - for the mutex watershed. - min_initial_size: Minimal size of initial segments. - distance_type: The distance function used to turn embeddings into affinities. - bias: Value to bias the initial segmentation towards over-segmentation. - use_box: Whether to use boxes derived from the initial segments as prompts. - use_mask: Whether to use the initial segments as prompts. - use_points: Whether to use points derived from the initial segments as prompts. - box_extension: Factor for extending the bounding box prompts, given in the relative box size. - stability_score_offset: The amount to shift the cutoff when calculating the stability score. + To apply the decoder on top of pre-computed embeddings for + the segmentation functionality. + See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py """ - default_offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3], [-9, 0], [0, -9]] - - def __init__( - self, - predictor: SamPredictor, - offsets: Optional[List[List[int]]] = None, - min_initial_size: int = 0, - distance_type: str = "l2", - bias: float = 0.0, - use_box: bool = True, - use_mask: bool = True, - use_points: bool = False, - box_extension: float = 0.05, - stability_score_offset: float = 1.0, - ): + def __init__(self, unetr): super().__init__() - self._predictor = predictor - self._offsets = self.default_offsets if offsets is None else offsets - self._min_initial_size = min_initial_size - self._distance_type = distance_type - self._bias = bias - self._use_box = use_box - self._use_mask = use_mask - self._use_points = use_points - self._box_extension = box_extension - self._stability_score_offset = stability_score_offset + self.base = unetr.base + self.out_conv = unetr.out_conv + self.deconv_out = unetr.deconv_out + self.decoder_head = unetr.decoder_head + self.final_activation = unetr.final_activation + self.postprocess_masks = unetr.postprocess_masks - # additional state that is set 'initialize' - self._initial_segmentation = None - - def _compute_initial_segmentation(self): - - embeddings = self._predictor.get_image_embedding().squeeze().cpu().numpy() - assert embeddings.shape == (256, 64, 64), f"{embeddings.shape}" - - initial_segmentation = embed.segment_embeddings_mws( - embeddings, distance_type=self._distance_type, offsets=self._offsets, bias=self._bias, - ).astype("uint32") - assert initial_segmentation.shape == (64, 64), f"{initial_segmentation.shape}" - - # filter out small initial objects - if self._min_initial_size > 0: - seg_ids, sizes = np.unique(initial_segmentation, return_counts=True) - initial_segmentation[np.isin(initial_segmentation, seg_ids[sizes < self._min_initial_size])] = 0 - - # resize to 256 x 256, which is the mask input expected by SAM - initial_segmentation = resize( - initial_segmentation, (256, 256), order=0, preserve_range=True, anti_aliasing=False - ).astype(initial_segmentation.dtype) - - return 
initial_segmentation - - def _compute_mask_data(self, initial_segmentation, crop_box, original_size, verbose): - seg_ids = np.unique(initial_segmentation) - if seg_ids[0] == 0: - seg_ids = seg_ids[1:] - - mask_data = amg_utils.MaskData() - # TODO batch this to be more efficient on GPUs - for seg_id in tqdm(seg_ids, disable=not verbose, desc="Compute masks from initial segmentation"): - mask = initial_segmentation == seg_id - masks, iou_preds, _ = segment_from_mask( - self._predictor, mask, original_size=original_size, - multimask_output=False, return_logits=True, return_all=True, - use_box=self._use_box, use_mask=self._use_mask, use_points=self._use_points, - box_extension=self._box_extension, - ) - # bring masks and iou_preds to a format compatible with _to_mask_data - masks, iou_preds = torch.from_numpy(masks[None]), torch.from_numpy(iou_preds[None]) - data = self._to_mask_data(masks, iou_preds, crop_box, original_size) - del masks - mask_data.cat(data) + self.decoder = unetr.decoder + self.deconv1 = unetr.deconv1 + self.deconv2 = unetr.deconv2 + self.deconv3 = unetr.deconv3 + self.deconv4 = unetr.deconv4 - return mask_data + def forward(self, input_, input_shape, original_shape): + z12 = input_ - @torch.no_grad() - def initialize( - self, - image: np.ndarray, - image_embeddings: Optional[util.ImageEmbeddings] = None, - i: Optional[int] = None, - verbose: bool = False - ) -> None: - """Initialize image embeddings and masks for an image. + z9 = self.deconv1(z12) + z6 = self.deconv2(z9) + z3 = self.deconv3(z6) + z0 = self.deconv4(z3) - Args: - image: The input image, volume or timeseries. - image_embeddings: Optional precomputed image embeddings. - See `util.precompute_image_embeddings` for details. - i: Index for the image data. Required if 'image' has three spatial dimensions - or a time dimension and two spatial dimensions. - verbose: Whether to print computation progress. - """ - original_size = image.shape[:2] - self._original_size = original_size + updated_from_encoder = [z9, z6, z3] - # the crop box is always the full image - crop_box = [0, 0, original_size[1], original_size[0]] - self._crop_boxes = [crop_box] + x = self.base(z12) + x = self.decoder(x, encoder_inputs=updated_from_encoder) + x = self.deconv_out(x) - if image_embeddings is None: - image_embeddings = util.precompute_image_embeddings(self._predictor, image,) - util.set_precomputed(self._predictor, image_embeddings, i=i) + x = torch.cat([x, z0], dim=1) + x = self.decoder_head(x) - # compute the initial segmentation via embedding based MWS and then refine the masks - # with the segment anything model - initial_segmentation = self._compute_initial_segmentation() - mask_data = self._compute_mask_data(initial_segmentation, crop_box, original_size, verbose) - # to be compatible with the file format of the super class we have to wrap the mask data in a list - crop_list = [mask_data] + x = self.out_conv(x) + if self.final_activation is not None: + x = self.final_activation(x) - # set the initialized data - self._is_initialized = True - self._initial_segmentation = initial_segmentation - self._crop_list = crop_list + x = self.postprocess_masks(x, input_shape, original_shape) + return x - @torch.no_grad() - def generate( - self, - pred_iou_thresh: float = 0.88, - stability_score_thresh: float = 0.95, - box_nms_thresh: float = 0.7, - min_mask_region_area: int = 0, - output_mode: str = "binary_mask", - ) -> List[Dict[str, Any]]: - """Generate instance segmentation for the currently initialized image. 
- - Args: - pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. - stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask - under changes to the cutoff used to binarize the model prediction. - box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. - min_mask_region_area: Minimal size for the predicted masks. - output_mode: The form masks are returned in. - - Returns: - The instance segmentation masks. - """ - if not self.is_initialized: - raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") - - data = self._postprocess_batch( - data=deepcopy(self.crop_list[0]), crop_box=self.crop_boxes[0], - original_size=self.original_size, - pred_iou_thresh=pred_iou_thresh, - stability_score_thresh=stability_score_thresh, - box_nms_thresh=box_nms_thresh - ) - - data.to_numpy() - masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, box_nms_thresh, output_mode) - return masks - def _resize_segmentation(self, segmentation, shape): - longest_size = max(shape) - longest_shape = (longest_size, longest_size) - resized_segmentation = resize( - segmentation, longest_shape, order=0, preserve_range=True, anti_aliasing=False - ).astype(segmentation.dtype) - crop = tuple(slice(0, sh) for sh in shape) - resized_segmentation = resized_segmentation[crop] - return resized_segmentation - - def get_initial_segmentation(self) -> np.ndarray: - """Get the initial instance segmentation. - - Returns: - The initial instance segmentation. - """ - if not self.is_initialized: - raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") - return self._resize_segmentation(self._initial_segmentation, self.original_size) - - def get_state(self) -> Dict[str, Any]: - """Get the initialized state of the mask generator. - - Returns: - State of the mask generator. - """ - state = super().get_state() - state["initial_segmentation"] = self._initial_segmentation - return state - - def set_state(self, state: Dict[str, Any]) -> None: - """Set the state of the mask generator. - - Args: - state: The state of the mask generator, e.g. from serialized state. - """ - self._initial_segmentation = state["initial_segmentation"] - super().set_state(state) +def load_instance_segmentation_with_decoder_from_checkpoint( + checkpoint: Union[os.PathLike, str], + model_type: str, + device: Optional[Union[str, torch.device]] = None +): + """Load `InstanceSegmentationWithDecoder` from a `training.JointSamTrainer` checkpoint. + Args: + checkpoint: The path to the checkpoint. + model_type: The type of the model, i.e. which image encoder type is used. + device: The device to use (cpu or cuda). -class _TiledEmbeddingMaskGenerator(_EmbeddingMaskGenerator): - """Generates an instance segmentation without prompts, using an initial segmentations derived from image embeddings. + Returns: + InstanceSegmentationWithDecoder + """ + device = util.get_device(device) + + # over-ride the unpickler with our custom one + custom_pickle = pickle + custom_pickle.Unpickler = util._CustomUnpickler + + state = torch.load(checkpoint, map_location=device, pickle_module=custom_pickle) + + # Get the predictor. + model_state = state["model_state"] + sam_prefix = "sam." 
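+ # The JointSamTrainer checkpoint stores the SAM weights under a "sam." prefix, which we strip here.
+ # For example, a key like "sam.image_encoder.patch_embed.proj.weight" (an illustrative key name)
+ # becomes "image_encoder.patch_embed.proj.weight", matching the keys expected by the plain SAM model.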
+ model_state = OrderedDict( + [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()] + ) + + sam = util.sam_model_registry[model_type]() + sam.to(device) + sam.load_state_dict(model_state) + predictor = SamPredictor(sam) + predictor.model_type = model_type + + # Get the decoder. + # NOTE: we hard-code the UNETR settings for now. + # Eventually we may need to find a way to be more flexible. + unetr = UNETR( + backbone="sam", + encoder=predictor.model.image_encoder, + out_channels=3, + use_sam_stats=True, + final_activation="Sigmoid", + use_skip_connection=False, + resize_input=True, + ) + + encoder_state = [] + encoder_prefix = "image_" + encoder_state = OrderedDict( + (k[len(encoder_prefix):], v) for k, v in model_state.items() if k.startswith(encoder_prefix) + ) + + decoder_state = state["decoder_state"] + unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items())) + unetr.load_state_dict(unetr_state) + unetr.to(device) + + decoder = DecoderAdapter(unetr) + + # Instantiate the segmenter. + segmenter = InstanceSegmentationWithDecoder(predictor, decoder) + return segmenter + + +class InstanceSegmentationWithDecoder: + """Generates an instance segmentation without prompts, using a decoder. + + Implements the same interface as `AutomaticMaskGenerator`. - Implements the same logic as `EmbeddingMaskGenerator`, but for tiled image embeddings. + Use this class as follows: + ```python + segmenter = InstanceSegmentationWithDecoder(predictor, decoder) + segmenter.initialize(image) # Predict the image embeddings and decoder outputs. + masks = segmenter.generate(center_distance_threshold=0.75) # Generate the instance segmentation. + ``` Args: predictor: The segment anything predictor. - n_threads: The number of threads used for parallelize operations over the tiles. - with_background: Whether to run segmentation with background. - **kwargs: Keywoard arguments for `EmbeddingMaskGenerator`. + decoder: The decoder to predict intermediate representations + for instance segmentation.
""" def __init__( self, predictor: SamPredictor, - n_threads: int = mp.cpu_count(), - with_background: bool = True, - **kwargs - ): - super().__init__(predictor=predictor, **kwargs) - self._n_threads = n_threads - self._with_background = with_background - - # additional state for 'initialize' - self._tile_shape = None - self._halo = None - - # state for saving the stitched initial segmentation - # (this is quite complex, so we save it to only compute once) - self._stitched_initial_segmentation = None - - def _compute_initial_segmentations(self, image_embeddings, i, n_tiles, verbose): - features = image_embeddings["features"] - - def segment_tile(tile_id): - tile_features = features[tile_id] - tile_image_embeddings = { - "features": tile_features, - "input_size": tile_features.attrs["input_size"], - "original_size": tile_features.attrs["original_size"] - } - util.set_precomputed(self._predictor, tile_image_embeddings, i) - return self._compute_initial_segmentation() - - with futures.ThreadPoolExecutor(self._n_threads) as tp: - initial_segmentations = list(tqdm( - tp.map(segment_tile, range(n_tiles)), disable=not verbose, total=n_tiles, - desc="Tile-based initial segmentation" - )) - - return initial_segmentations + decoder: torch.nn.Module, + ) -> None: + self._predictor = predictor + self._decoder = decoder - def _compute_mask_data_tiled(self, image_embeddings, i, initial_segmentations, n_tiles, verbose): - features = image_embeddings["features"] + # The decoder outputs. + self._foreground = None + self._center_distances = None + self._boundary_distances = None - mask_data = [] - for tile_id in tqdm(range(n_tiles), disable=not verbose, total=n_tiles, desc="Tile-based mask computation"): - tile_features = features[tile_id] - this_tile_shape = tile_features.attrs["original_size"] - tile_image_embeddings = { - "features": tile_features, - "input_size": tile_features.attrs["input_size"], - "original_size": this_tile_shape - } - util.set_precomputed(self._predictor, tile_image_embeddings, i) - this_crop_box = [0, 0, this_tile_shape[1], this_tile_shape[0]] - tile_data = self._compute_mask_data( - initial_segmentations[tile_id], this_crop_box, this_tile_shape, verbose=False - ) - mask_data.append(tile_data) + self._is_initialized = False - return mask_data + @property + def is_initialized(self): + """Whether the mask generator has already been initialized. + """ + return self._is_initialized @torch.no_grad() def initialize( @@ -985,149 +844,116 @@ def initialize( image: np.ndarray, image_embeddings: Optional[util.ImageEmbeddings] = None, i: Optional[int] = None, - tile_shape: Optional[Tuple[int, int]] = None, - halo: Optional[Tuple[int, int]] = None, - verbose: bool = False, - embedding_save_path: Optional[str] = None, ) -> None: - """Initialize image embeddings and masks for an image. + """Initialize image embeddings and decoder predictions for an image. Args: image: The input image, volume or timeseries. image_embeddings: Optional precomputed image embeddings. See `util.precompute_image_embeddings` for details. - i: Index for the image data. Required if 'image' has three spatial dimensions + i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions. - tile_shape: The tile shape for embedding prediction. - halo: The overlap of between tiles. - verbose: Whether to print computation progress. - embedding_save_path: Where to save the image embeddings. 
""" - original_size = image.shape[:2] - image_embeddings, tile_shape, halo = _compute_tiled_embeddings( - self._predictor, image, image_embeddings, embedding_save_path, tile_shape, halo - ) + if image_embeddings is None: + image_embeddings = util.precompute_image_embeddings(self._predictor, image) - tiling = blocking([0, 0], original_size, tile_shape) - n_tiles = tiling.numberOfBlocks - initial_segmentations = self._compute_initial_segmentations(image_embeddings, i, n_tiles, verbose) - mask_data = self._compute_mask_data_tiled(image_embeddings, i, initial_segmentations, n_tiles, verbose) + # This could be made more versatile to also support other decoder inputs, + # e.g. the UNETR with skip connections. + if isinstance(image_embeddings["features"], torch.Tensor): + embeddings = image_embeddings["features"].to(self._predictor.device) + else: + embeddings = torch.from_numpy(image_embeddings["features"]).to(self._predictor.device) + + input_shape = tuple(image_embeddings["input_size"]) + original_shape = tuple(image_embeddings["original_size"]) + output = self._decoder( + embeddings, input_shape, original_shape + ).cpu().numpy().squeeze(0) + + assert output.shape[0] == 3, f"{output.shape}" + + self._foreground = output[0] + self._center_distances = output[1] + self._boundary_distances = output[2] - # set the initialized data self._is_initialized = True - self._tile_shape = tile_shape - self._halo = halo - self._initial_segmentation = initial_segmentations - self._crop_list = mask_data - self._original_size = original_size - # the crop box is always the full local tile - tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)] - self._crop_boxes = [ - [0, 0, tile.end[1] - tile.begin[1], tile.end[0] - tile.begin[0]] for tile in tiles + def _to_masks(self, segmentation, output_mode): + if output_mode != "binary_mask": + raise NotImplementedError + + props = regionprops(segmentation) + ndim = segmentation.ndim + assert ndim in (2, 3) + + shape = segmentation.shape + if ndim == 2: + crop_box = [0, shape[1], 0, shape[0]] + else: + crop_box = [0, shape[2], 0, shape[1], 0, shape[0]] + + # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h] + def to_bbox_2d(bbox): + y0, x0 = bbox[0], bbox[1] + w = bbox[3] - x0 + h = bbox[2] - y0 + return [x0, w, y0, h] + + def to_bbox_3d(bbox): + z0, y0, x0 = bbox[0], bbox[1], bbox[2] + w = bbox[5] - x0 + h = bbox[4] - y0 + d = bbox[3] - y0 + return [x0, w, y0, h, z0, d] + + to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d + masks = [ + { + "segmentation": segmentation == prop.label, + "area": prop.area, + "bbox": to_bbox(prop.bbox), + "crop_box": crop_box, + "seg_id": prop.label, + } for prop in props ] + return masks - @torch.no_grad() + # TODO find good default values (empirically) def generate( self, - pred_iou_thresh: float = 0.88, - stability_score_thresh: float = 0.95, - box_nms_thresh: float = 0.7, - min_mask_region_area: int = 0, - verbose: bool = False - ) -> np.ndarray: + center_distance_threshold: float = 0.5, + boundary_distance_threshold: float = 0.5, + foreground_threshold: float = 0.5, + distance_smoothing: float = 1.6, + min_size: int = 0, + output_mode: Optional[str] = "binary_mask", + ) -> List[Dict[str, Any]]: """Generate instance segmentation for the currently initialized image. Args: - pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. 
- stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask - under changes to the cutoff used to binarize the model prediction. - box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. - min_mask_region_area: Minimal size for the predicted masks. - verbose: Whether to print progress of the computation. + center_distance_threshold: Center distance predictions below this value will be + used to find seeds (intersected with thresholded boundary distance predictions). + boundary_distance_threshold: Boundary distance predictions below this value will be + used to find seeds (intersected with thresholded center distance predictions). + foreground_threshold: Foreground predictions above this value will be used as foreground mask. + distance_smoothing: Sigma value for smoothing the distance predictions. + min_size: Minimal object size in the segmentation result. + output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. Returns: The instance segmentation masks. """ if not self.is_initialized: - raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.") - tiling = blocking([0, 0], self.original_size, self._tile_shape) - - def segment_tile(_, tile_id): - tile = tiling.getBlockWithHalo(tile_id, list(self._halo)).outerBlock - mask_data = deepcopy(self._crop_list[tile_id]) - crop_box = self.crop_boxes[tile_id] - this_tile_shape = tuple(end - beg for beg, end in zip(tile.begin, tile.end)) - mask_data = self._postprocess_batch( - data=mask_data, crop_box=crop_box, original_size=this_tile_shape, - pred_iou_thresh=pred_iou_thresh, - stability_score_thresh=stability_score_thresh, - box_nms_thresh=box_nms_thresh, - ) - mask_data.to_numpy() - mask_data = self._postprocess_masks( - mask_data, 0, box_nms_thresh, box_nms_thresh, output_mode="binary_mask" - ) - mask_data = mask_data_to_segmentation(mask_data, this_tile_shape, with_background=self._with_background) - return mask_data - - input_ = _FakeInput(self.original_size) - segmentation = stitch_segmentation( - input_, segment_tile, self._tile_shape, self._halo, with_background=self._with_background, verbose=verbose + raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.") + + segmentation = watershed_from_center_and_boundary_distances( + self._center_distances, self._boundary_distances, self._foreground, + center_distance_threshold=center_distance_threshold, + boundary_distance_threshold=boundary_distance_threshold, + foreground_threshold=foreground_threshold, + distance_smoothing=distance_smoothing, + min_size=min_size, ) - - if min_mask_region_area > 0: - seg_ids, sizes = np.unique(segmentation, return_counts=True) - segmentation[np.isin(segmentation, seg_ids[sizes < min_mask_region_area])] = 0 - + if output_mode is not None: + segmentation = self._to_masks(segmentation, output_mode) return segmentation - - def get_initial_segmentation(self) -> np.ndarray: - """Get the initial instance segmentation. - - Returns: - The initial instance segmentation. - """ - if not self.is_initialized: - raise RuntimeError("AutomaticMaskGenerator has not been initialized. 
Call initialize first.") - - if self._stitched_initial_segmentation is not None: - return self._stitched_initial_segmentation - - tiling = blocking([0, 0], self.original_size, self._tile_shape) - - def segment_tile(_, tile_id): - tile = tiling.getBlockWithHalo(tile_id, list(self._halo)).outerBlock - this_tile_shape = tuple(end - beg for beg, end in zip(tile.begin, tile.end)) - return self._resize_segmentation(self._initial_segmentation[tile_id], this_tile_shape) - - input_ = _FakeInput(self.original_size) - initial_segmentation = stitch_segmentation( - input_, segment_tile, - self._tile_shape, self._halo, - with_background=self._with_background, verbose=False - ) - - self._stitched_initial_segmentation = initial_segmentation - return initial_segmentation - - def get_state(self) -> Dict[str, Any]: - """Get the initialized state of the mask generator. - - Returns: - State of the mask generator. - """ - state = super().get_state() - state["tile_shape"] = self._tile_shape - state["halo"] = self._halo - return state - - def set_state(self, state: Dict[str, Any]) -> None: - """Set the state of the mask generator. - - Args: - state: The state of the mask generator, e.g. from serialized state. - """ - self._tile_shape = state["tile_shape"] - self._halo = state["halo"] - super().set_state(state) diff --git a/micro_sam/multi_dimensional_segmentation.py b/micro_sam/multi_dimensional_segmentation.py index 49020ef1..c448ba4e 100644 --- a/micro_sam/multi_dimensional_segmentation.py +++ b/micro_sam/multi_dimensional_segmentation.py @@ -192,7 +192,7 @@ def segment_3d_from_slice( seg_z = amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh) seg_z = mask_data_to_segmentation( - seg_z, shape=raw.shape[1:], with_background=True, + seg_z, with_background=True, min_object_size=min_object_size_z, max_object_size=max_object_size_z, ) diff --git a/micro_sam/napari.yaml b/micro_sam/napari.yaml index e8bd09c7..b39db502 100644 --- a/micro_sam/napari.yaml +++ b/micro_sam/napari.yaml @@ -26,6 +26,9 @@ contributions: - id: micro-sam.embedding_widget python_name: micro_sam.sam_annotator._widgets:embedding_widget title: Embedding widget + - id: micro-sam.cachedir_widget + python_name: micro_sam.sam_annotator._widgets:cachedir_widget + title: Set cache directory sample_data: - command: micro-sam.sample_data_image_series display_name: Image series example data @@ -51,3 +54,5 @@ contributions: widgets: - command: micro-sam.embedding_widget display_name: Embedding widget + - command: micro-sam.cachedir_widget + display_name: Set cache directory diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py index 66e6edf5..6f20dd7a 100644 --- a/micro_sam/precompute_state.py +++ b/micro_sam/precompute_state.py @@ -158,7 +158,7 @@ def main(): parser = argparse.ArgumentParser(description="Compute the embeddings for an image.") parser.add_argument("-i", "--input_path", required=True) parser.add_argument("-o", "--output_path", required=True) - parser.add_argument("-m", "--model_type", default="vit_h") + parser.add_argument("-m", "--model_type", default=util._DEFAULT_MODEL) parser.add_argument("-c", "--checkpoint_path", default=None) parser.add_argument("-k", "--key") parser.add_argument( diff --git a/micro_sam/prompt_generators.py b/micro_sam/prompt_generators.py index 429020cd..fae4b681 100644 --- a/micro_sam/prompt_generators.py +++ b/micro_sam/prompt_generators.py @@ -204,7 +204,6 @@ def _sample_points(self, segmentation, bbox_coordinates, center_coordinates): return all_coords, 
all_labels - - # TODO make compatible with exact same input shape def __call__( self, segmentation: torch.Tensor, @@ -220,7 +219,7 @@ def __call__( """Generate the prompts for one object in the segmentation. Args: - segmentation: Instance segmentation masks . + segmentation: The groundtruth segmentation. Expects a float tensor of shape NUM_OBJECTS x 1 x H x W. bbox_coordinates: The precomputed bounding boxes of particular object in the segmentation. center_coordinates: The precomputed center coordinates of particular object in the segmentation. If passed, these coordinates will be used as the first positive point prompt. diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 887aae8a..b048f788 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -1,7 +1,6 @@ -from enum import Enum import os from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Literal from magicgui import magic_factory, widgets from napari.qt.threading import thread_worker @@ -13,29 +12,26 @@ ImageEmbeddings, get_sam_model, precompute_image_embeddings, - _MODEL_URLS, + models, _DEFAULT_MODEL, _available_devices, + get_cache_directory, ) if TYPE_CHECKING: import napari -Model = Enum("Model", _MODEL_URLS) -available_devices_list = ["auto"] + _available_devices() - @magic_factory( pbar={'visible': False, 'max': 0, 'value': 0, 'label': 'working...'}, call_button="Compute image embeddings", - device = {"choices": available_devices_list}, save_path={"mode": "d"}, # choose a directory ) def embedding_widget( pbar: widgets.ProgressBar, image: "napari.layers.Image", - model: Model = Model.__getitem__(_DEFAULT_MODEL), - device = "auto", + model: Literal[tuple(models().urls.keys())] = _DEFAULT_MODEL, + device: Literal[tuple(["auto"] + _available_devices())] = "auto", save_path: Optional[Path] = None, # where embeddings for this image are cached (optional) optional_custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights.
) -> ImageEmbeddings: @@ -52,8 +48,9 @@ def embedding_widget( @thread_worker(connect={'started': pbar.show, 'finished': pbar.hide}) def _compute_image_embedding(state, image_data, save_path, ndim=None, - device="auto", model=Model.__getitem__(_DEFAULT_MODEL), - optional_custom_weights=None): + device="auto", model=_DEFAULT_MODEL, + optional_custom_weights=None, + ): # Make sure save directory exists and is an empty directory if save_path is not None: os.makedirs(save_path, exist_ok=True) @@ -69,16 +66,31 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None, "The user selected 'save_path' is not a zarr array " f"or empty directory: {save_path}" ) + # Initialize the model - state.predictor = get_sam_model(device=device, model_type=model.name, - checkpoint_path=optional_custom_weights) + state.predictor = get_sam_model(device=device, model_type=model, checkpoint_path=optional_custom_weights) # Compute the image embeddings state.image_embeddings = precompute_image_embeddings( - predictor = state.predictor, - input_ = image_data, - save_path = str(save_path), + predictor=state.predictor, + input_=image_data, + save_path=save_path, ndim=ndim, ) return state # returns napari._qt.qthreading.FunctionWorker - return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model, optional_custom_weights=optional_custom_weights) + return _compute_image_embedding( + state, image.data, save_path, ndim=ndim, device=device, model=model, + optional_custom_weights=optional_custom_weights + ) + + +@magic_factory( + call_button="Update settings", + cache_directory={"mode": "d"}, # choose a directory +) +def settings_widget( + cache_directory: Optional[Path] = get_cache_directory(), +): + """Update micro-sam settings.""" + os.environ["MICROSAM_CACHEDIR"] = str(cache_directory) + print(f"micro-sam cache directory set to: {cache_directory}") diff --git a/micro_sam/sam_annotator/annotator_2d.py b/micro_sam/sam_annotator/annotator_2d.py index 9c8e002e..4f8f5357 100644 --- a/micro_sam/sam_annotator/annotator_2d.py +++ b/micro_sam/sam_annotator/annotator_2d.py @@ -81,7 +81,7 @@ def _autosegment_widget( seg = state.amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh) seg = instance_segmentation.mask_data_to_segmentation( - seg, shape, with_background=with_background, min_object_size=min_object_size + seg, with_background=with_background, min_object_size=min_object_size ) assert isinstance(seg, np.ndarray) diff --git a/micro_sam/sam_annotator/annotator_3d.py b/micro_sam/sam_annotator/annotator_3d.py index 0bf8cb3c..176d779a 100644 --- a/micro_sam/sam_annotator/annotator_3d.py +++ b/micro_sam/sam_annotator/annotator_3d.py @@ -181,7 +181,7 @@ def _autosegment_widget( seg = state.amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh) seg = instance_segmentation.mask_data_to_segmentation( - seg, shape, with_background=with_background, min_object_size=min_object_size + seg, with_background=with_background, min_object_size=min_object_size ) assert isinstance(seg, np.ndarray) diff --git a/micro_sam/sam_annotator/annotator_tracking.py b/micro_sam/sam_annotator/annotator_tracking.py index f0683696..6edbc79f 100644 --- a/micro_sam/sam_annotator/annotator_tracking.py +++ b/micro_sam/sam_annotator/annotator_tracking.py @@ -369,7 +369,7 @@ def _commit_tracking_widget(v: Viewer, layer: str = "current_track") -> None: vutil.clear_annotations(v, clear_segmentations=False) -@magicgui(call_button="Clear 
Annotations [Shfit + C]") +@magicgui(call_button="Clear Annotations [Shift + C]") def _clear_widget_tracking(v: Viewer) -> None: _reset_tracking_state() vutil.clear_annotations(v) diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py index 992d2284..c2f39e7a 100644 --- a/micro_sam/sam_annotator/util.py +++ b/micro_sam/sam_annotator/util.py @@ -33,7 +33,7 @@ def clear_annotations(v: napari.Viewer, clear_segmentations=True) -> None: v.layers["current_track"].refresh() -@magicgui(call_button="Clear Annotations [Shfit + C]") +@magicgui(call_button="Clear Annotations [Shift + C]") def _clear_widget(v: napari.Viewer) -> None: clear_annotations(v) diff --git a/micro_sam/sample_data.py b/micro_sam/sample_data.py index 8cdde024..407bc965 100644 --- a/micro_sam/sample_data.py +++ b/micro_sam/sample_data.py @@ -24,6 +24,8 @@ from skimage.measure import label from skimage.transform import resize +from .util import get_cache_directory + def fetch_image_series_example_data(save_directory: Union[str, os.PathLike]) -> str: """Download the sample images for the image series annotator. @@ -66,9 +68,8 @@ def sample_data_image_series(): # Check the documentation for more information about the # add_image_kwargs # https://napari.org/stable/api/napari.Viewer.html#napari.Viewer.add_image - _CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam') - default_base_data_dir = os.path.join(_CACHE_DIR, 'sample_data') - data_directory = fetch_image_series_example_data(default_base_data_dir) + base_data_directory = os.path.join(get_cache_directory(), 'sample_data') + data_directory = fetch_image_series_example_data(base_data_directory) fnames = os.listdir(data_directory) full_filenames = [os.path.join(data_directory, f) for f in fnames] full_filenames.sort() @@ -108,9 +109,8 @@ def sample_data_wholeslide(): # Check the documentation for more information about the # add_image_kwargs # https://napari.org/stable/api/napari.Viewer.html#napari.Viewer.add_image - _CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam') - default_base_data_dir = os.path.join(_CACHE_DIR, 'sample_data') - filename = fetch_wholeslide_example_data(default_base_data_dir) + base_data_directory = os.path.join(get_cache_directory(), 'sample_data') + filename = fetch_wholeslide_example_data(base_data_directory) data = imageio.imread(filename) add_image_kwargs = {"name": "wholeslide"} return [(data, add_image_kwargs)] @@ -147,9 +147,8 @@ def sample_data_livecell(): # Check the documentation for more information about the # add_image_kwargs # https://napari.org/stable/api/napari.Viewer.html#napari.Viewer.add_image - _CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam') - default_base_data_dir = os.path.join(_CACHE_DIR, 'sample_data') - filename = fetch_livecell_example_data(default_base_data_dir) + base_data_directory = os.path.join(get_cache_directory(), 'sample_data') + filename = fetch_livecell_example_data(base_data_directory) data = imageio.imread(filename) add_image_kwargs = {"name": "livecell"} return [(data, add_image_kwargs)] @@ -186,9 +185,8 @@ def sample_data_hela_2d(): # Check the documentation for more information about the # add_image_kwargs # https://napari.org/stable/api/napari.Viewer.html#napari.Viewer.add_image - _CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam') - default_base_data_dir = os.path.join(_CACHE_DIR, 'sample_data') - filename = fetch_hela_2d_example_data(default_base_data_dir) + base_data_directory = 
os.path.join(get_cache_directory(), 'sample_data') + filename = fetch_hela_2d_example_data(base_data_directory) data = imageio.imread(filename) add_image_kwargs = {"name": "hela_2d"} return [(data, add_image_kwargs)] @@ -230,9 +228,8 @@ def sample_data_3d(): # Check the documentation for more information about the # add_image_kwargs # https://napari.org/stable/api/napari.Viewer.html#napari.Viewer.add_image - _CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam') - default_base_data_dir = os.path.join(_CACHE_DIR, 'sample_data') - data_directory = fetch_3d_example_data(default_base_data_dir) + base_data_directory = os.path.join(get_cache_directory(), 'sample_data') + data_directory = fetch_3d_example_data(base_data_directory) fnames = os.listdir(data_directory) full_filenames = [os.path.join(data_directory, f) for f in fnames] full_filenames.sort() @@ -281,9 +278,8 @@ def sample_data_tracking(): # Check the documentation for more information about the # add_image_kwargs # https://napari.org/stable/api/napari.Viewer.html#napari.Viewer.add_image - _CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam') - default_base_data_dir = os.path.join(_CACHE_DIR, 'sample_data') - data_directory = fetch_tracking_example_data(default_base_data_dir) + base_data_directory = os.path.join(get_cache_directory(), 'sample_data') + data_directory = fetch_tracking_example_data(base_data_directory) fnames = os.listdir(data_directory) full_filenames = [os.path.join(data_directory, f) for f in fnames] full_filenames.sort() @@ -328,9 +324,8 @@ def sample_data_segmentation(): # Check the documentation for more information about the # add_image_kwargs # https://napari.org/stable/api/napari.Viewer.html#napari.Viewer.add_image - _CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam') - default_base_data_dir = os.path.join(_CACHE_DIR, 'sample_data') - data_directory = fetch_tracking_segmentation_data(default_base_data_dir) + base_data_directory = os.path.join(get_cache_directory(), 'sample_data') + data_directory = fetch_tracking_segmentation_data(base_data_directory) fnames = os.listdir(data_directory) full_filenames = [os.path.join(data_directory, f) for f in fnames] full_filenames.sort() @@ -344,7 +339,7 @@ def synthetic_data(shape, seed=None): ndim = len(shape) assert ndim in (2, 3) image_shape = shape if ndim == 2 else shape[1:] - image = binary_blobs(length=image_shape[0], blob_size_fraction=0.05, volume_fraction=0.15, seed=seed) + image = binary_blobs(length=image_shape[0], blob_size_fraction=0.05, volume_fraction=0.15, rng=seed) if image_shape[1] != image_shape[0]: image = resize(image, image_shape, order=0, anti_aliasing=False, preserve_range=True).astype(image.dtype) @@ -353,6 +348,7 @@ def synthetic_data(shape, seed=None): image = np.stack([image] * nz) segmentation = label(image) + image = image.astype("uint8") * 255 return image, segmentation diff --git a/micro_sam/training/__init__.py b/micro_sam/training/__init__.py index 225b6568..a50db8af 100644 --- a/micro_sam/training/__init__.py +++ b/micro_sam/training/__init__.py @@ -2,4 +2,5 @@ """ from .sam_trainer import SamTrainer, SamLogger -from .util import ConvertToSamInputs, get_trainable_sam_model +from .util import ConvertToSamInputs, get_trainable_sam_model, identity +from .joint_sam_trainer import JointSamTrainer, JointSamLogger diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py new file mode 100644 index 00000000..e5fed239 --- /dev/null +++ 
b/micro_sam/training/joint_sam_trainer.py @@ -0,0 +1,200 @@ +import os +import time +import numpy as np +from collections import OrderedDict + +import torch +from torchvision.utils import make_grid + +from .sam_trainer import SamTrainer + +from torch_em.trainer.logger_base import TorchEmLogger +from torch_em.trainer.tensorboard_logger import normalize_im + + +class JointSamTrainer(SamTrainer): + def __init__( + self, + unetr: torch.nn.Module, + instance_loss: torch.nn.Module, + instance_metric: torch.nn.Module, + **kwargs + ): + super().__init__(**kwargs) + self.unetr = unetr + self.instance_loss = instance_loss + self.instance_metric = instance_metric + + def save_checkpoint(self, name, best_metric, **extra_save_dict): + current_unetr_state = self.unetr.state_dict() + decoder_state = [] + for k, v in current_unetr_state.items(): + if not k.startswith("encoder"): + decoder_state.append((k, v)) + decoder_state = OrderedDict(decoder_state) + + super().save_checkpoint(name, best_metric, decoder_state=decoder_state, **extra_save_dict) + + def load_checkpoint(self, checkpoint="best"): + save_dict = super().load_checkpoint(checkpoint) + + # let's get the image encoder params from sam + sam_state = save_dict["model_state"] + encoder_state = [] + prune_prefix = "sam.image_" + for k, v in sam_state.items(): + if k.startswith(prune_prefix): + encoder_state.append((k[len(prune_prefix):], v)) + encoder_state = OrderedDict(encoder_state) + + # let's get the decoder params from unetr + decoder_state = save_dict["decoder_state"] + + # now let's merge the two to get the params for the unetr + unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items())) + + self.unetr.load_state_dict(unetr_state) + self.unetr.to(self.device) + return save_dict + + def _instance_iteration(self, x, y, metric_for_val=False): + outputs = self.unetr(x.to(self.device)) + loss = self.instance_loss(outputs, y.to(self.device)) + if metric_for_val: + metric = self.instance_metric(outputs, y.to(self.device)) + return loss, metric + else: + return loss + + def _train_epoch_impl(self, progress, forward_context, backprop): + self.model.train() + + input_check_done = False + + n_iter = 0 + t_per_iter = time.time() + for x, y in self.train_loader: + labels_instances = y[:, 0, ...].unsqueeze(1) + labels_for_unetr = y[:, 1:, ...] + + input_check_done = self._check_input_normalization(x, input_check_done) + + self.optimizer.zero_grad() + + with forward_context(): + # 1. train for the interactive segmentation + (loss, mask_loss, iou_regression_loss, model_iou, + sampled_binary_y) = self._interactive_train_iteration(x, labels_instances) + + backprop(loss) + + self.optimizer.zero_grad() + + with forward_context(): + # 2. 
train for the automatic instance segmentation + unetr_loss = self._instance_iteration(x, labels_for_unetr) + + backprop(unetr_loss) + + if self.logger is not None: + lr = [pm["lr"] for pm in self.optimizer.param_groups][0] + samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None + self.logger.log_train( + self._iteration, loss, lr, x, labels_instances, samples, + mask_loss, iou_regression_loss, model_iou, unetr_loss + ) + + self._iteration += 1 + n_iter += 1 + if self._iteration >= self.max_iteration: + break + progress.update(1) + + t_per_iter = (time.time() - t_per_iter) / n_iter + return t_per_iter + + def _validate_impl(self, forward_context): + self.model.eval() + + input_check_done = False + + val_iteration = 0 + metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0 + + with torch.no_grad(): + for x, y in self.val_loader: + labels_instances = y[:, 0, ...].unsqueeze(1) + labels_for_unetr = y[:, 1:, ...] + + input_check_done = self._check_input_normalization(x, input_check_done) + + with forward_context(): + # 1. validate for the interactive segmentation + (loss, mask_loss, iou_regression_loss, model_iou, + sampled_binary_y, metric) = self._interactive_val_iteration(x, labels_instances, val_iteration) + + with forward_context(): + # 2. validate for the automatic instance segmentation + unetr_loss, unetr_metric = self._instance_iteration(x, labels_for_unetr, metric_for_val=True) + + loss_val += loss.item() + metric_val += metric.item() + (unetr_metric.item() / 3) + model_iou_val += model_iou.item() + val_iteration += 1 + + loss_val /= len(self.val_loader) + metric_val /= len(self.val_loader) + model_iou_val /= len(self.val_loader) + + if self.logger is not None: + self.logger.log_validation( + self._iteration, metric_val, loss_val, x, labels_instances, sampled_binary_y, + mask_loss, iou_regression_loss, model_iou_val, unetr_loss + ) + + return metric_val + + +class JointSamLogger(TorchEmLogger): + """@private""" + def __init__(self, trainer, save_root, **unused_kwargs): + super().__init__(trainer, save_root) + self.log_dir = f"./logs/{trainer.name}" if save_root is None else\ + os.path.join(save_root, "logs", trainer.name) + os.makedirs(self.log_dir, exist_ok=True) + + self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir) + self.log_image_interval = trainer.log_image_interval + + def add_image(self, x, y, samples, name, step): + selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2] + + image = normalize_im(x[selection].cpu()) + + self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step) + self.tb.add_image(tag=f"{name}/target", img_tensor=y[selection], global_step=step) + sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4) + self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step) + + def log_train( + self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss + ): + self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step) + self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step) + self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step) + self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step) + self.tb.add_scalar(tag="train/instance_loss", scalar_value=instance_loss, global_step=step) + self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step) + if step % self.log_image_interval == 0: + 
self.add_image(x, y, samples, "train", step) + + def log_validation( + self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss + ): + self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step) + self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step) + self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step) + self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step) + self.tb.add_scalar(tag="validation/instance_loss", scalar_value=instance_loss, global_step=step) + self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step) + self.add_image(x, y, samples, "validation", step) diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py index 1c0a290c..c251e749 100644 --- a/micro_sam/training/sam_trainer.py +++ b/micro_sam/training/sam_trainer.py @@ -1,6 +1,7 @@ import os import time import random +import warnings from typing import Optional import numpy as np @@ -28,7 +29,6 @@ class SamTrainer(torch_em.trainer.DefaultTrainer): n_objects_per_batch: If not given, we compute the loss for all objects in a sample. Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled. mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU. - sigmoid: The activation function for normalizing the model output. prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`) **kwargs: The keyword arguments of the DefaultTrainer super class. @@ -40,15 +40,16 @@ def __init__( self, n_sub_iteration: int, n_objects_per_batch: Optional[int] = None, mse_loss: torch.nn.Module = torch.nn.MSELoss(), - _sigmoid: torch.nn.Module = torch.nn.Sigmoid(), prompt_generator: PromptGeneratorBase = IterativePromptGenerator(), mask_prob: float = 0.5, **kwargs ): - super().__init__(**kwargs) + # We have to use the Dice Loss with reduce channel set to None. + # Hence we hard-code it here to avoid issues by passing wrong options for the loss. + dice_loss = torch_em.loss.DiceLoss(reduce_channel=None) + super().__init__(loss=dice_loss, metric=dice_loss, **kwargs) self.convert_inputs = convert_inputs self.mse_loss = mse_loss - self._sigmoid = _sigmoid self.n_objects_per_batch = n_objects_per_batch self.n_sub_iteration = n_sub_iteration self.prompt_generator = prompt_generator @@ -107,133 +108,120 @@ def _get_prompt_and_multimasking_choices_for_val(self, current_iteration): return n_pos, n_neg, get_boxes, multimask_output - def _get_dice(self, input_, target): - """Using the default "DiceLoss" called by the trainer from "torch_em" - """ - dice_loss = self.loss(input_, target) - return dice_loss - - def _get_iou(self, pred, true, eps=1e-7): - """Getting the IoU score for the predicted and true labels + def _compute_iou(self, pred, true, eps=1e-7): - """Compute the IoU score between the prediction and target. """ pred_mask = pred > 0.5 # binarizing the output predictions - overlap = pred_mask.logical_and(true).sum() - union = pred_mask.logical_or(true).sum() + overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3)) + union = pred_mask.logical_or(true).sum(dim=(1, 2, 3)) iou = overlap / (union + eps) return iou - def _get_net_loss(self, batched_outputs, y, sampled_ids): - """What do we do here?
two **separate** things - 1. compute the mask loss: loss between the predicted and ground-truth masks - for this we just use the dice of the prediction vs. the gt (binary) mask - 2. compute the mask for the "IOU Regression Head": so we want the iou output from the decoder to - match the actual IOU between predicted and (binary) ground-truth mask. And we use L2Loss / MSE for this. + def _compute_loss(self, batched_outputs, y_one_hot): + """Compute the loss for one iteration. The loss is made up of two components: + - The mask loss: dice score between the predicted masks and targets. + - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target. """ - masks = [m["masks"] for m in batched_outputs] - predicted_iou_values = [m["iou_predictions"] for m in batched_outputs] - with torch.no_grad(): - mean_model_iou = torch.mean(torch.stack([p.mean() for p in predicted_iou_values])) - - mask_loss = 0.0 # this is the loss term for 1. - iou_regression_loss = 0.0 # this is the loss term for 2. - - # outer loop is over the batch (different image/patch predictions) - for m_, y_, ids_, predicted_iou_ in zip(masks, y, sampled_ids, predicted_iou_values): - per_object_dice_scores, per_object_iou_scores = [], [] - - # inner loop is over the channels, this corresponds to the different predicted objects - for i, (predicted_obj, predicted_iou) in enumerate(zip(m_, predicted_iou_)): - predicted_obj = self._sigmoid(predicted_obj).to(self.device) - true_obj = (y_ == ids_[i]).to(self.device) - - # this is computing the LOSS for 1.) - _dice_score = min([self._get_dice(p[None], true_obj) for p in predicted_obj]) - per_object_dice_scores.append(_dice_score) - - # now we need to compute the loss for 2.) - with torch.no_grad(): - true_iou = torch.stack([self._get_iou(p[None], true_obj) for p in predicted_obj]) - _iou_score = self.mse_loss(true_iou, predicted_iou) - per_object_iou_scores.append(_iou_score) - - mask_loss = mask_loss + torch.mean(torch.stack(per_object_dice_scores)) - iou_regression_loss = iou_regression_loss + torch.mean(torch.stack(per_object_iou_scores)) + mask_loss, iou_regression_loss = 0.0, 0.0 + + # Loop over the batch. + for batch_output, targets in zip(batched_outputs, y_one_hot): + + predicted_objects = torch.sigmoid(batch_output["masks"]) + # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop). + # We swap the axes that go into the dice loss so that the object axis + # corresponds to the channel axis. This ensures that the dice is computed + # independently per channel. We do not reduce the channel axis in the dice, + # so that we can take the minimum (best score) of the 1/3 predicted masks per object. + dice_scores = torch.stack([ + self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1)) + for i in range(predicted_objects.shape[1]) + ]) + dice_scores, _ = torch.min(dice_scores, dim=0) + + # Compute the actual IOU between the predicted and true objects. + # The outer loop is for the 1 or 3 predicted masks per true object. + with torch.no_grad(): + true_iou = torch.stack([ + self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1]) + ]) + # Compute the L2 loss between true and predicted IOU. We need to swap the axes so that + # the object axis is back in the first dimension.
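+ # (true_iou is stacked as (num_predicted_masks, n_objects), while "iou_predictions" should have
+ # shape (n_objects, num_predicted_masks), hence the swapaxes before computing the MSE.)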
+ iou_score = self.mse_loss(true_iou.swapaxes(0, 1), batch_output["iou_predictions"]) + + mask_loss = mask_loss + torch.mean(dice_scores) + iou_regression_loss = iou_regression_loss + iou_score loss = mask_loss + iou_regression_loss - return loss, mask_loss, iou_regression_loss, mean_model_iou - - def _postprocess_outputs(self, masks): - """ "masks" look like -> (B, 1, X, Y) - where, B is the number of objects, (X, Y) is the input image shape - """ - instance_labels = [] - for m in masks: - instance_list = [self._sigmoid(_val) for _val in m.squeeze(1)] - instance_label = torch.stack(instance_list, dim=0).sum(dim=0).clip(0, 1) - instance_labels.append(instance_label) - instance_labels = torch.stack(instance_labels).unsqueeze(1) - return instance_labels - - def _get_val_metric(self, batched_outputs, sampled_binary_y): - """ Tracking the validation metric based on the DiceLoss - """ - masks = [m["masks"] for m in batched_outputs] - pred_labels = self._postprocess_outputs(masks) - - # we do the condition below to adapt w.r.t. the multimask output to select the "objectively" best response - if pred_labels.dim() == 5: - metric = min([self.metric(pred_labels[:, :, i, :, :], sampled_binary_y.to(self.device)) - for i in range(pred_labels.shape[2])]) - else: - metric = self.metric(pred_labels, sampled_binary_y.to(self.device)) - - return metric + return loss, mask_loss, iou_regression_loss # - # Update Masks Iteratively while Training + # Functionality for iterative prompting loss # - def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_subiter, multimask_output): - # estimating the image inputs to make the computations faster for the decoder - input_images = torch.stack([self.model.preprocess(x=x["image"].to(self.device)) for x in batched_inputs], dim=0) - image_embeddings = self.model.image_embeddings_oft(input_images) + + def _get_best_masks(self, batched_outputs, batched_iou_predictions): + # Batched mask and logit (low-res mask) predictions. + masks = torch.stack([m["masks"] for m in batched_outputs]) + logits = torch.stack([m["low_res_masks"] for m in batched_outputs]) + + # Determine the best IOU across the multimask prediction axis + # and turn this into a mask we can use for indexing. + # See https://stackoverflow.com/questions/72628000/pytorch-indexing-by-argmax + # for details on the indexing logic. + best_iou_idx = torch.argmax(batched_iou_predictions, dim=2, keepdim=True) + best_iou_idx = torch.zeros_like(batched_iou_predictions).scatter(2, best_iou_idx, value=1).bool() + + # Index the mask and logits with the best iou indices. + # Note that we squash the first two axes (batch x objects) into one when indexing. + # That's why we need to reshape back into (batch x objects) using a view. + # We also keep the multimask axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...) + batch_size, n_objects = masks.shape[:2] + h, w = masks.shape[-2:] + masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w) + + h, w = logits.shape[-2:] + logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w) + + # Binarize the mask. Note that the mask here also contains logits, so we use 0.0 + # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid) + masks = (masks > 0.0).float() + return masks, logits + + def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output): + """Compute the loss for several (sub-)iterations of iterative prompting.
+ In each iteration the prompts are updated based on the previous predictions. + """ + image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0 - # this loop takes care of the idea of sub-iterations, i.e. the number of times we iterate over each batch for i in range(0, num_subiter): - # we do multimasking only in the first sub-iteration as we then pass single prompt - # after the first sub-iteration, we don't do multimasking because we get multiple prompts + # We do multimasking only in the first sub-iteration as we then pass a single prompt. + # After the first sub-iteration, we don't do multimasking because we get multiple prompts. batched_outputs = self.model(batched_inputs, - multimask_output=multimask_output if i == 0 else False, - image_embeddings=image_embeddings) + image_embeddings=image_embeddings, + multimask_output=multimask_output if i == 0 else False) + + # Compute loss for this sub-iteration. + net_loss, net_mask_loss, net_iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot) + + # Compute the mean IOU predicted by the model. We keep track of this in the logger. + batched_iou_predictions = torch.stack([m["iou_predictions"] for m in batched_outputs]) + with torch.no_grad(): + net_mean_model_iou = torch.mean(batched_iou_predictions) - # we want to average the loss and then backprop over the net sub-iterations - net_loss, net_mask_loss, net_iou_regression_loss, net_mean_model_iou = self._get_net_loss(batched_outputs, - y, sampled_ids) loss += net_loss mask_loss += net_mask_loss iou_regression_loss += net_iou_regression_loss mean_model_iou += net_mean_model_iou - masks, logits_masks = [], [] - # the loop below gets us the masks and logits from the batch-level outputs - for m in batched_outputs: - mask, l_mask = [], [] - for _m, _l, _iou in zip(m["masks"], m["low_res_masks"], m["iou_predictions"]): - best_iou_idx = torch.argmax(_iou) - best_mask, best_logits = _m[best_iou_idx][None], _l[best_iou_idx][None] - mask.append(self._sigmoid(best_mask)) - l_mask.append(best_logits) - - mask, l_mask = torch.stack(mask), torch.stack(l_mask) - masks.append(mask) - logits_masks.append(l_mask) - - masks, logits_masks = torch.stack(masks), torch.stack(logits_masks) - masks = (masks > 0.5).to(torch.float32) - - self._get_updated_points_per_mask_per_subiter(masks, sampled_binary_y, batched_inputs, logits_masks) + # Determine the next prompts based on current predictions. + with torch.no_grad(): + # Get the mask and logit predictions corresponding to the predicted object + # (per actual object) with the best IOU.
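+ # The returned masks are already binarized (thresholded at 0.0 on the logits) and the logits
+ # are the corresponding low-res masks, so they can be fed back as mask prompts in the next sub-iteration.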
+ masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions) + batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits) loss = loss / num_subiter mask_loss = mask_loss / num_subiter @@ -242,12 +230,18 @@ def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_su return loss, mask_loss, iou_regression_loss, mean_model_iou - def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batched_inputs, logits_masks): + def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks): # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs") - for x1, x2, _inp, logits in zip(masks, sampled_binary_y, batched_inputs, logits_masks): + for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks): # here, we get each object in the pairs and do the point choices per-object net_coords, net_labels, _, _ = self.prompt_generator(x2, x1) + # convert the point coordinates to the expected resolution for iterative prompting + # NOTE: + # - "only" need to transform the point prompts from the iterative prompting + # - the `logits` are the low res masks (256, 256), hence do not need the transform + net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:]) + updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \ if "point_coords" in _inp.keys() else net_coords updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \ @@ -264,59 +258,85 @@ def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batc else: # remove previously existing mask inputs to avoid using them in next sub-iteration _inp.pop("mask_inputs", None) + return batched_inputs + # # Training Loop # - def _update_samples_for_gt_instances(self, y, n_samples): - num_instances_gt = torch.amax(y, dim=(1, 2, 3)) - num_instances_gt = num_instances_gt.numpy().astype(int) - n_samples = min(num_instances_gt) if n_samples > min(num_instances_gt) else n_samples - return n_samples + def _preprocess_batch(self, batched_inputs, y, sampled_ids): + """Compute one hot target (one mask per channel) for the sampled ids + and restrict the number of sampled objects to the minimal number in the batch. + """ + assert len(y) == len(sampled_ids) + + # Get the minimal number of objects in this batch. + # The number of objects in a patch might be < n_objects_per_batch. + # This is why we need to restrict it here to ensure the same + # number of objects across the batch. + n_objects = min(len(ids) for ids in sampled_ids) + + y = y.to(self.device) + # Compute the one hot targets for the seg-id. + y_one_hot = torch.stack([ + torch.stack([target == seg_id for seg_id in ids[:n_objects]]) + for target, ids in zip(y, sampled_ids) + ]).float() + + # Also restrict the prompts to the number of objects. 
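+ # Only the per-object prompt entries ("point_coords", "point_labels", "boxes") are truncated here;
+ # other entries, e.g. the image data itself, are kept unchanged.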
+ batched_inputs = [ + {k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()} + for inp in batched_inputs + ] + return batched_inputs, y_one_hot + + def _interactive_train_iteration(self, x, y): + n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration) + + batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch) + batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids) + + loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss( + batched_inputs, y_one_hot, + num_subiter=self.n_sub_iteration, multimask_output=multimask_output + ) + return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot + + def _check_input_normalization(self, x, input_check_done): + # The expected data range of the SAM model is 8bit (0-255). + # It can easily happen that data is normalized beforehand in training. + # For some reason we don't fully understand, this still works, but it + # should be avoided, since it is very detrimental in some settings + # (e.g. when freezing the image encoder). + # We check once per epoch if the data seems to be normalized already and + # raise a warning if this is the case. + if not input_check_done: + data_min, data_max = x.min(), x.max() + if (data_min < 0) or (data_max < 1): + warnings.warn( + "It looks like you are normalizing the training data. " + "The SAM model takes care of normalization, so it is better not to do this. " + "We recommend removing the data normalization and providing the input data in the range [0, 255]." + ) + input_check_done = True + + return input_check_done def _train_epoch_impl(self, progress, forward_context, backprop): self.model.train() + input_check_done = False + n_iter = 0 t_per_iter = time.time() for x, y in self.train_loader: + input_check_done = self._check_input_normalization(x, input_check_done) self.optimizer.zero_grad() with forward_context(): - n_samples = self._update_samples_for_gt_instances(y, self.n_objects_per_batch) - - n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration) - - batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, n_samples) - - assert len(y) == len(sampled_ids) - sampled_binary_y = [] - for i in range(len(y)): - _sampled = [torch.isin(y[i], torch.tensor(idx)) for idx in sampled_ids[i]] - sampled_binary_y.append(_sampled) - - # the steps below are done for one reason in a gist: - # to handle images where there aren't enough instances as expected - # (e.g. where one image has only one instance) - obj_lengths = [len(s) for s in sampled_binary_y] - sampled_binary_y = [s[:min(obj_lengths)] for s in sampled_binary_y] - sampled_binary_y = [torch.stack(s).to(torch.float32) for s in sampled_binary_y] - sampled_binary_y = torch.stack(sampled_binary_y) - - # gist for below - while we find the mismatch, we need to update the batched inputs - # else it would still generate masks using mismatching prompts, and it doesn't help us - # with the subiterations again.
hence we clip the number of input points as well - f_objs = sampled_binary_y.shape[1] - batched_inputs = [ - {k: (v[:f_objs] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()} - for inp in batched_inputs - ] - - loss, mask_loss, iou_regression_loss, model_iou = self._update_masks(batched_inputs, y, - sampled_binary_y, sampled_ids, - num_subiter=self.n_sub_iteration, - multimask_output=multimask_output) + (loss, mask_loss, iou_regression_loss, model_iou, + sampled_binary_y) = self._interactive_train_iteration(x, y) backprop(loss) @@ -335,33 +355,42 @@ def _train_epoch_impl(self, progress, forward_context, backprop): t_per_iter = (time.time() - t_per_iter) / n_iter return t_per_iter + def _interactive_val_iteration(self, x, y, val_iteration): + n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration) + + batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch) + batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids) + + image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) + + batched_outputs = self.model( + batched_inputs, + image_embeddings=image_embeddings, + multimask_output=multimask_output, + ) + + loss, mask_loss, iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot) + # We use the dice loss over the masks as metric. + metric = mask_loss + model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs])) + + return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric + def _validate_impl(self, forward_context): self.model.eval() + input_check_done = False + val_iteration = 0 metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0 with torch.no_grad(): for x, y in self.val_loader: - with forward_context(): - n_samples = self._update_samples_for_gt_instances(y, self.n_objects_per_batch) - - (n_pos, n_neg, - get_boxes, multimask_output) = self._get_prompt_and_multimasking_choices_for_val(val_iteration) + input_check_done = self._check_input_normalization(x, input_check_done) - batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, n_samples) - - batched_outputs = self.model(batched_inputs, multimask_output=multimask_output) - - assert len(y) == len(sampled_ids) - sampled_binary_y = torch.stack( - [torch.isin(y[i], torch.tensor(sampled_ids[i])) for i in range(len(y))] - ).to(torch.float32) - - loss, mask_loss, iou_regression_loss, model_iou = self._get_net_loss(batched_outputs, - y, sampled_ids) - - metric = self._get_val_metric(batched_outputs, sampled_binary_y) + with forward_context(): + (loss, mask_loss, iou_regression_loss, model_iou, + sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration) loss_val += loss.item() metric_val += metric.item() diff --git a/micro_sam/training/trainable_sam.py b/micro_sam/training/trainable_sam.py index 99728a1b..81c7d8ca 100644 --- a/micro_sam/training/trainable_sam.py +++ b/micro_sam/training/trainable_sam.py @@ -1,10 +1,11 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Tuple, Union import torch from torch import nn from torch.nn import functional as F from segment_anything.modeling import Sam +from segment_anything.utils.transforms import ResizeLongestSide # simple wrapper around SAM in order to keep things trainable @@ -23,52 +24,62 @@ def __init__( super().__init__() self.sam = sam self.device = device + self.transform = 
ResizeLongestSide(sam.image_encoder.img_size) - def preprocess(self, x: torch.Tensor) -> torch.Tensor: - """Normalize pixel values and pad to a square input. + def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]: + """Resize, normalize pixel values and pad to a square input. Args: x: The input tensor. Returns: - The normalized and padded tensor. + The resized, normalized and padded tensor. + The shape of the image after resizing. """ + + # Resize longest side to match the image encoder. + x = self.transform.apply_image_torch(x) + input_size = x.shape[-2:] + # Normalize colors - x = (x - self.sam.pixel_mean) / self.sam.pixel_std + x = (x - self.sam.pixel_mean.unsqueeze(0)) / self.sam.pixel_std.unsqueeze(0) # Pad h, w = x.shape[-2:] padh = self.sam.image_encoder.img_size - h padw = self.sam.image_encoder.img_size - w x = F.pad(x, (0, padw, 0, padh)) - return x - - def image_embeddings_oft(self, input_images): - """@private""" + return x, input_size + + def image_embeddings_oft(self, batched_inputs): + # Compute the input images. + input_images, input_size = self.preprocess( + torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.device) + ) + # Update the input size for each input in the batch. + for i in range(len(batched_inputs)): + batched_inputs[i]["input_size"] = input_size + # Compute the image embeddings. image_embeddings = self.sam.image_encoder(input_images) - return image_embeddings + return image_embeddings, batched_inputs # batched inputs follow the same syntax as the input to sam.forward def forward( self, batched_inputs: List[Dict[str, Any]], + image_embeddings: torch.Tensor, multimask_output: bool = False, - image_embeddings: Optional[torch.Tensor] = None, ) -> List[Dict[str, Any]]: """Forward pass. Args: batched_inputs: The batched input images and prompts. - multimask_output: Whether to predict mutiple or just a single mask. image_embeddings: The precompute image embeddings. If not passed then they will be computed. + multimask_output: Whether to predict mutiple or just a single mask. Returns: The predicted segmentation masks and iou values. """ - input_images = torch.stack([self.preprocess(x=x["image"].to(self.device)) for x in batched_inputs], dim=0) - if image_embeddings is None: - image_embeddings = self.sam.image_encoder(input_images) - outputs = [] for image_record, curr_embedding in zip(batched_inputs, image_embeddings): if "point_coords" in image_record: @@ -102,7 +113,7 @@ def forward( masks = self.sam.postprocess_masks( low_res_masks, - input_size=image_record["image"].shape[-2:], + input_size=image_record["input_size"], original_size=image_record["original_size"], ) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 52096462..e9839956 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -1,33 +1,52 @@ import os +from math import ceil, floor from typing import List, Optional, Union import numpy as np +import torch + +from segment_anything.utils.transforms import ResizeLongestSide from ..prompt_generators import PointAndBoxPromptGenerator -from ..util import get_centers_and_bounding_boxes, get_sam_model, segmentation_to_one_hot, _get_device +from ..util import get_centers_and_bounding_boxes, get_sam_model, segmentation_to_one_hot, get_device from .trainable_sam import TrainableSAM +from torch_em.transform.label import PerObjectDistanceTransform +from torch_em.transform.raw import normalize_percentile, normalize + + +def identity(x): + """Identity transformation. 
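For reference, here is a simplified, dependency-light sketch of the three preprocessing steps introduced above (resize the longest side, normalize, pad to a square). It uses a single scalar mean and std instead of SAM's per-channel statistics and a plain bilinear resize instead of ResizeLongestSide, so the numbers are only illustrative.

import torch
import torch.nn.functional as F

def preprocess_like_sam(x, img_size=1024, pixel_mean=123.675, pixel_std=58.395):
    # Simplified stand-in for TrainableSAM.preprocess: resize the longest side
    # to `img_size`, normalize, then pad bottom/right to a square input.
    h, w = x.shape[-2:]
    scale = img_size / max(h, w)
    new_h, new_w = int(round(h * scale)), int(round(w * scale))
    x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
    input_size = x.shape[-2:]
    x = (x - pixel_mean) / pixel_std
    x = F.pad(x, (0, img_size - new_w, 0, img_size - new_h))
    return x, input_size

image = torch.rand(1, 3, 520, 704) * 255
padded, input_size = preprocess_like_sam(image)
print(padded.shape, input_size)  # torch.Size([1, 3, 1024, 1024]) torch.Size([756, 1024])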
+ + This is a helper function to skip data normalization when finetuning SAM. + Data normalization is performed within the model and should thus be skipped as + a preprocessing step in training. + """ + return x + def get_trainable_sam_model( model_type: str = "vit_h", - device: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, checkpoint_path: Optional[Union[str, os.PathLike]] = None, freeze: Optional[List[str]] = None, ) -> TrainableSAM: """Get the trainable sam model. Args: - model_type: The type of the segment anything model. + model_type: The segment anything model that should be finetuned. + The weights of this model will be used for initialization, unless a + custom weight file is passed via `checkpoint_path`. + device: The device to use for training. checkpoint_path: Path to a custom checkpoint from which to load the model weights. freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder By default nothing is frozen and the full model is updated. - device: The device to use for training. Returns: The trainable segment anything model. """ # set the device here so that the correct one is passed to TrainableSAM below - device = _get_device(device) + device = get_device(device) _, sam = get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path, return_sam=True) # freeze components of the model if freeze was passed @@ -55,21 +74,35 @@ class ConvertToSamInputs: """Convert outputs of data loader to the expected batched inputs of the SegmentAnything model. Args: + transform: The transformation to resize the prompts. Should be the same transform used in the + model to resize the inputs. If `None` the prompts will not be resized. dilation_strength: The dilation factor. It determines a "safety" border from which prompts are not sampled to avoid ambiguous prompts due to imprecise groundtruth masks. box_distortion_factor: Factor for distorting the box annotations derived from the groundtruth masks. - Not yet implemented. """ def __init__( self, + transform: Optional[ResizeLongestSide], dilation_strength: int = 10, box_distortion_factor: Optional[float] = None, ) -> None: self.dilation_strength = dilation_strength - # TODO implement the box distortion logic - if box_distortion_factor is not None: - raise NotImplementedError + self.transform = identity if transform is None else transform + self.box_distortion_factor = box_distortion_factor + + def _distort_boxes(self, bbox_coordinates, shape): + distorted_boxes = [] + for bbox in bbox_coordinates: + # The bounding box is parametrized by y0, x0, y1, x1. 
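Since ConvertToSamInputs now resizes the prompts with the same ResizeLongestSide transform as the model, here is a small usage sketch of that transform applied to a point prompt; the coordinate values are arbitrary.

import torch
from segment_anything.utils.transforms import ResizeLongestSide

# Point prompts are given in the coordinate system of the original image and have
# to be mapped to the resized image that the SAM image encoder actually sees.
transform = ResizeLongestSide(target_length=1024)
points = torch.tensor([[[100.0, 250.0]]])   # (batch, n_points, 2), x/y order as used by SAM
resized = transform.apply_coords_torch(points, original_size=(512, 512))
print(resized)   # coordinates are scaled by 1024 / 512 = 2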
+ y0, x0, y1, x1 = bbox + ly, lx = y1 - y0, x1 - x0 + y0 = int(round(max(0, y0 - np.random.uniform(0, self.box_distortion_factor) * ly))) + y1 = int(round(min(shape[0], y1 + np.random.uniform(0, self.box_distortion_factor) * ly))) + x0 = int(round(max(0, x0 - np.random.uniform(0, self.box_distortion_factor) * lx))) + x1 = int(round(min(shape[1], x1 + np.random.uniform(0, self.box_distortion_factor) * lx))) + distorted_boxes.append([y0, x0, y1, x1]) + return distorted_boxes def _get_prompt_lists(self, gt, n_samples, prompt_generator): """Returns a list of "expected" prompts subjected to the random input attributes for prompting.""" @@ -87,6 +120,8 @@ def _get_prompt_lists(self, gt, n_samples, prompt_generator): # only keep the bounding boxes for sampled cell ids bbox_coordinates = [bbox_coordinates[sampled_id] for sampled_id in sampled_cell_ids] + if self.box_distortion_factor is not None: + bbox_coordinates = self._distort_boxes(bbox_coordinates, shape=gt.shape[-2:]) # convert the gt to the one-hot-encoded masks for the sampled cell ids object_masks = segmentation_to_one_hot(gt, None if n_samples is None else sampled_cell_ids) @@ -98,7 +133,6 @@ def _get_prompt_lists(self, gt, n_samples, prompt_generator): def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None): """Convert the outputs of dataloader and prompt settings to the batch format expected by SAM. """ - # condition to see if we get point prompts, then we (ofc) use point-prompting # else we don't use point prompting if n_pos == 0 and n_neg == 0: @@ -134,11 +168,69 @@ def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None): batched_input = {"image": image, "original_size": image.shape[1:]} if get_boxes: - batched_input["boxes"] = box_prompts + batched_input["boxes"] = self.transform.apply_boxes_torch( + box_prompts, original_size=gt.shape[-2:] + ) if self.transform is not None else box_prompts if get_points: - batched_input["point_coords"] = point_prompts + batched_input["point_coords"] = self.transform.apply_coords_torch( + point_prompts, original_size=gt.shape[-2:] + ) if self.transform is not None else point_prompts batched_input["point_labels"] = point_label_prompts batched_inputs.append(batched_input) return batched_inputs, batched_sampled_cell_ids_list + + +# +# Raw and Label Transformations for the Generalist and Specialist finetuning +# + + +class ResizeRawTrafo: + def __init__(self, desired_shape, do_rescaling=True, padding="constant"): + self.desired_shape = desired_shape + self.padding = padding + self.do_rescaling = do_rescaling + + def __call__(self, raw): + if self.do_rescaling: + raw = normalize_percentile(raw, axis=(1, 2)) + raw = np.mean(raw, axis=0) + raw = normalize(raw) + raw = raw * 255 + + tmp_ddim = (self.desired_shape[0] - raw.shape[0], self.desired_shape[1] - raw.shape[1]) + ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2) + raw = np.pad( + raw, + pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))), + mode=self.padding + ) + assert raw.shape == self.desired_shape + return raw + + +class ResizeLabelTrafo: + def __init__(self, desired_shape, padding="constant", min_size=0): + self.desired_shape = desired_shape + self.padding = padding + self.min_size = min_size + + def __call__(self, labels): + distance_trafo = PerObjectDistanceTransform( + distances=True, boundary_distances=True, directed_distances=False, + foreground=True, instances=True, min_size=self.min_size + ) + labels = distance_trafo(labels) + + # choosing H and W from labels (4, H, W), from above dist 
trafo outputs + tmp_ddim = (self.desired_shape[0] - labels.shape[1], self.desired_shape[0] - labels.shape[2]) + ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2) + labels = np.pad( + labels, + pad_width=((0, 0), (ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))), + mode=self.padding + ) + assert labels.shape[1:] == self.desired_shape, labels.shape + return labels diff --git a/micro_sam/util.py b/micro_sam/util.py index 28460479..1d0a49f1 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -4,16 +4,15 @@ import hashlib import os +from pathlib import Path import pickle import warnings from collections import OrderedDict -from shutil import copyfileobj from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import imageio.v3 as imageio import numpy as np import pooch -import requests import torch import vigra import zarr @@ -30,106 +29,127 @@ from segment_anything import sam_model_registry, SamPredictor VIT_T_SUPPORT = False +try: + import xxhash + HAS_XXH128 = hasattr(xxhash, 'xxh128') +except ImportError: + HAS_XXH128 = False + try: from napari.utils import progress as tqdm except ImportError: from tqdm import tqdm -_MODEL_URLS = { - # the default segment anything models - "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", - "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", - "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", - # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM - "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", - # first version of finetuned models on zenodo - "vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1", - "vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1", - "vit_h_em": "https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1", - "vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1", -} -_CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam') -_CHECKPOINT_FOLDER = os.path.join(_CACHE_DIR, 'models') -_CHECKSUMS = { - # the default segment anything models - "vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e", - "vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622", - "vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912", - # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM - "vit_t": "6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f", - # first version of finetuned models on zenodo - "vit_h_lm": "9a65ee0cddc05a98d60469a12a058859c89dc3ea3ba39fed9b90d786253fbf26", - "vit_b_lm": "5a59cc4064092d54cd4d92cd967e39168f3760905431e868e474d60fe5464ecd", - "vit_h_em": "ae3798a0646c8df1d4db147998a2d37e402ff57d3aa4e571792fbb911d8a979c", - "vit_b_em": "c04a714a4e14a110f0eec055a65f7409d54e6bf733164d2933a0ce556f7d6f81", -} -# this is required so that the downloaded file is not called 'download' -_DOWNLOAD_NAMES = { - "vit_t": "vit_t_mobile_sam.pth", - "vit_h_lm": "vit_h_lm.pth", - "vit_b_lm": "vit_b_lm.pth", - "vit_h_em": "vit_h_em.pth", - "vit_b_em": "vit_b_em.pth", -} + # this is the default model used in micro_sam # currently set to the default vit_h _DEFAULT_MODEL = "vit_h" +# The valid model types. Each type corresponds to the architecture of the +# vision transformer used within SAM. 
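The ResizeRawTrafo and ResizeLabelTrafo classes above pad to the desired shape by splitting the size difference into a ceil/floor pair, with the extra pixel going to the top/left. A tiny standalone sketch of that padding rule:

from math import ceil, floor
import numpy as np

def pad_to_shape(arr, desired_shape, mode="constant"):
    # Pad a 2d array to `desired_shape`, mirroring the padding used by the trafos above.
    dy = (desired_shape[0] - arr.shape[0]) / 2
    dx = (desired_shape[1] - arr.shape[1]) / 2
    return np.pad(arr, pad_width=((ceil(dy), floor(dy)), (ceil(dx), floor(dx))), mode=mode)

print(pad_to_shape(np.ones((5, 6)), (8, 8)).shape)  # (8, 8)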
+_MODEL_TYPES = ("vit_h", "vit_b", "vit_l", "vit_t") + # TODO define the proper type for image embeddings ImageEmbeddings = Dict[str, Any] """@private""" +def get_cache_directory() -> None: + """Get micro-sam cache directory location. + + Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory. + """ + default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam")) + cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory)) + return cache_directory + + # # Functionality for model download and export # -def _download(url, path, model_type): - with requests.get(url, stream=True, verify=True) as r: - if r.status_code != 200: - r.raise_for_status() - raise RuntimeError(f"Request to {url} returned status code {r.status_code}") - file_size = int(r.headers.get("Content-Length", 0)) - desc = f"Download {url} to {path}" - if file_size == 0: - desc += " (unknown file size)" - with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f: - copyfileobj(r_raw, f) - - # validate the checksum - expected_checksum = _CHECKSUMS[model_type] - if expected_checksum is None: - return - with open(path, "rb") as f: - file_ = f.read() - checksum = hashlib.sha256(file_).hexdigest() - if checksum != expected_checksum: - raise RuntimeError( - "The checksum of the download does not match the expected checksum." - f"Expected: {expected_checksum}, got: {checksum}" - ) - print("Download successful and checksums agree.") +def microsam_cachedir() -> None: + """Return the micro-sam cache directory. + Returns the top level cache directory for micro-sam models and sample data. -def _get_checkpoint(model_type, checkpoint_path=None): - if checkpoint_path is None: - checkpoint_url = _MODEL_URLS[model_type] - checkpoint_name = _DOWNLOAD_NAMES.get(model_type, checkpoint_url.split("/")[-1]) - checkpoint_path = os.path.join(_CHECKPOINT_FOLDER, checkpoint_name) + Every time this function is called, we check for any user updates made to + the MICROSAM_CACHEDIR os environment variable since the last time. + """ + cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam") + return cache_directory - # download the checkpoint if necessary - if not os.path.exists(checkpoint_path): - os.makedirs(_CHECKPOINT_FOLDER, exist_ok=True) - _download(checkpoint_url, checkpoint_path, model_type) - elif not os.path.exists(checkpoint_path): - raise ValueError(f"The checkpoint path {checkpoint_path} that was passed does not exist.") - return checkpoint_path +def models(): + """Return the segmentation models registry. + + We recreate the model registry every time this function is called, + so any user changes to the default micro-sam cache directory location + are respected. + """ + + # Provide hashes in both xxh128 (fast, but not cryptographically secure), + # and as sha256 (as a fallback) to validate if the file has been correctly + # downloaded. + # https://github.com/computational-cell-analytics/micro-sam/issues/283 + # To generate the xxh128 hash + # + # xxh128sum filename + # + # You may need to install xxhash with conda or your system package manager. 
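As the comment above notes, the xxh128 entries can be generated with xxh128sum. An equivalent Python sketch using the xxhash package, with the same streaming approach as the export script further below, would be:

import xxhash

def xxh128_of_file(path, chunk_size=65536):
    # Compute the xxh128 checksum of a local file in streaming fashion,
    # matching the "xxh128:<digest>" format used in the registry below.
    h = xxhash.xxh128()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return f"xxh128:{h.hexdigest()}"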
+ registry_sha256 = { + # the default segment anything models + "vit_h": "sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e", + "vit_l": "sha256:3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622", + "vit_b": "sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912", + # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM + "vit_t": "sha256:6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f", + # first version of finetuned models on zenodo + "vit_b_lm": "sha256:e8f5feb1ad837a7507935409c7f83f7c8af11c6e39cfe3df03f8d3bd4a358449", + "vit_b_em_organelles": "sha256:8fabbe38a427a0c91bbe6518a5c0f103f36b73e6ee6c86fbacd32b4fc66294b4", + "vit_b_em_boundaries": "sha256:d87348b2adef30ab427fb787d458643300eb30624a0e808bf36af21764705f4f", + } + registry_xxh128 = { + # the default segment anything models + "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2", + "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a", + "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b", + # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM + "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0", + # first version of finetuned models on zenodo + "vit_b_lm": "xxh128:6b061eb8684d9d5f55545330d6dce50d", + "vit_b_em_organelles": "xxh128:3919c2b761beba7d3f4ece342c9f5369", + "vit_b_em_boundaries": "xxh128:3099fe6339f5be91ca84db889db1909f", + } + + models = pooch.create( + path=os.path.join(microsam_cachedir(), "models"), + base_url="", + registry=registry_xxh128 if HAS_XXH128 else registry_sha256, + # Now specify custom URLs for some of the files in the registry. + urls={ + # the default segment anything models + "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", + "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", + "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", + # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM + "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", + # first version of finetuned models on zenodo + "vit_b_lm": "https://zenodo.org/records/10524791/files/vit_b_lm.pth?download=1", + "vit_b_em_organelles": "https://zenodo.org/records/10524828/files/vit_b_em_organelles.pth?download=1", + "vit_b_em_boundaries": "https://zenodo.org/records/10524894/files/vit_b_em_boundaries.pth?download=1", + }, + ) + return models def _get_default_device(): + # check that we're in CI and use the CPU if we are + # otherwise the tests may run out of memory on MAC if MPS is used. + if os.getenv("GITHUB_ACTIONS") == "true": + return "cpu" # Use cuda enabled gpu if it's available. if torch.cuda.is_available(): device = "cuda" @@ -144,17 +164,29 @@ def _get_default_device(): return device -def _get_device(device=None): +def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]: + """Get the torch device. + + If no device is passed the default device for your system is used. + Else it will be checked if the device you have passed is supported. + + Args: + device: The input device. + + Returns: + The device. 
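For readers unfamiliar with pooch: the registry maps a model name to a hash and a download URL, and fetch resolves the name to a locally cached file. A reduced sketch with only the vit_b entry, hash and URL copied from the registry above; note that the first call downloads roughly 0.4 GB.

import os
import pooch

models = pooch.create(
    path=os.path.join(str(pooch.os_cache("micro_sam")), "models"),
    base_url="",
    registry={"vit_b": "sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912"},
    urls={"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"},
)
checkpoint_path = models.fetch("vit_b")  # downloaded once, afterwards served from the cache
print(checkpoint_path)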
+ """ if device is None or device == "auto": device = _get_default_device() else: - if device.lower() == "cuda": + device_type = device if isinstance(device, str) else device.type + if device_type.lower() == "cuda": if not torch.cuda.is_available(): raise RuntimeError("PyTorch CUDA backend is not available.") - elif device.lower() == "mps": + elif device_type.lower() == "mps": if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()): raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.") - elif device.lower() == "cpu": + elif device_type.lower() == "cpu": pass # cpu is always available else: raise RuntimeError(f"Unsupported device: {device}\n" @@ -166,7 +198,7 @@ def _available_devices(): available_devices = [] for i in ["cuda", "mps", "cpu"]: try: - device = _get_device(i) + device = get_device(i) except RuntimeError: pass else: @@ -176,15 +208,22 @@ def _available_devices(): def get_sam_model( model_type: str = _DEFAULT_MODEL, - device: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, checkpoint_path: Optional[Union[str, os.PathLike]] = None, return_sam: bool = False, ) -> SamPredictor: r"""Get the SegmentAnything Predictor. - This function will download the required model checkpoint or load it from file if it - was already downloaded. - This location can be changed by setting the environment variable: MICROSAM_CACHEDIR. + This function will download the required model or load it from the cached weight file. + This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. + The name of the requested model can be set via `model_type`. + See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models + for an overview of the available models + + Alternatively this function can also load a model from weights stored in a local filepath. + The corresponding file path is given via `checkpoint_path`. In this case `model_type` + must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for + a SAM model with vit_b encoder. By default the models are downloaded to a folder named 'micro_sam/models' inside your default cache directory, eg: @@ -195,31 +234,53 @@ def get_sam_model( https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html Args: - device: The device for the model. If none is given will use GPU if available. model_type: The SegmentAnything model to use. Will use the standard vit_h model by default. - checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. + To get a list of all available model names you can call `get_model_names`. + device: The device for the model. If none is given will use GPU if available. + checkpoint_path: The path to a file with weights that should be used instead of using the + weights corresponding to `model_type`. If given, `model_type` must match the architecture + corresponding to the weight file. E.g. if you use weights for SAM with vit_b encoder + then `model_type` must be given as "vit_b". return_sam: Return the sam model object as well as the predictor. Returns: The segment anything predictor. """ - checkpoint = _get_checkpoint(model_type, checkpoint_path) - device = _get_device(device) + device = get_device(device) - # Our custom model types have a suffix "_...". This suffix needs to be stripped + # We support passing a local filepath to a checkpoint. 
+ # In this case we do not download any weights but just use the local weight file, + # as it is, without copying it over anywhere or checking it's hashes. + + # checkpoint_path has not been passed, we download a known model and derive the correct + # URL from the model_type. If the model_type is invalid pooch will raise an error. + if checkpoint_path is None: + model_registry = models() + checkpoint = model_registry.fetch(model_type) + # checkpoint_path has been passed, we use it instead of downloading a model. + else: + # Check if the file exists and raise an error otherwise. + # We can't check any hashes here, and we don't check if the file is actually a valid weight file. + # (If it isn't the model creation will fail below.) + if not os.path.exists(checkpoint_path): + raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.") + checkpoint = checkpoint_path + + # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped # before calling sam_model_registry. - model_type_ = model_type[:5] - assert model_type_ in ("vit_h", "vit_b", "vit_l", "vit_t") - if model_type == "vit_t" and not VIT_T_SUPPORT: + abbreviated_model_type = model_type[:5] + if abbreviated_model_type not in _MODEL_TYPES: + raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}") + if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT: raise RuntimeError( "mobile_sam is required for the vit-tiny." "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" ) - sam = sam_model_registry[model_type_](checkpoint=checkpoint) + sam = sam_model_registry[abbreviated_model_type](checkpoint=checkpoint) sam.to(device=device) predictor = SamPredictor(sam) - predictor.model_type = model_type + predictor.model_type = abbreviated_model_type if return_sam: return predictor, sam return predictor @@ -241,7 +302,7 @@ def find_class(self, module, name): def get_custom_sam_model( checkpoint_path: Union[str, os.PathLike], model_type: str = "vit_h", - device: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, return_sam: bool = False, return_state: bool = False, ) -> SamPredictor: @@ -252,8 +313,8 @@ def get_custom_sam_model( Args: checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. + model_type: The SegmentAnything model_type for the given checkpoint. device: The device for the model. If none is given will use GPU if available. - model_type: The SegmentAnything model to use. return_sam: Return the sam model object as well as the predictor. return_state: Return the full state of the checkpoint in addition to the predictor. @@ -266,7 +327,7 @@ def get_custom_sam_model( custom_pickle = pickle custom_pickle.Unpickler = _CustomUnpickler - device = _get_device(device) + device = get_device(device) sam = sam_model_registry[model_type]() # load the model state, ignoring any attributes that can't be found by pickle @@ -302,7 +363,7 @@ def export_custom_sam_model( Args: checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. - model_type: The SegmentAnything model type to use (vit_h, vit_b or vit_l). + model_type: The SegmentAnything model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t). save_path: Where to save the exported model. 
""" _, state = get_custom_sam_model( @@ -317,7 +378,9 @@ def export_custom_sam_model( def get_model_names() -> Iterable: - return _MODEL_URLS.keys() + model_registry = models() + model_names = model_registry.registry.keys() + return model_names # @@ -574,6 +637,7 @@ def precompute_image_embeddings( assert save_path is not None, "Tiled prediction is only supported when the embeddings are saved to file." if save_path is not None: + save_path = str(save_path) data_signature = _compute_data_signature(input_) f = zarr.open(save_path, "a") diff --git a/pyproject.toml b/pyproject.toml index 1ce35bd9..38c4ba01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools>=42.0.0", "wheel"] build-backend = "setuptools.build_meta" [tool.pytest.ini_options] -addopts = "-v --durations=10 --cov=micro_sam --cov-report xml:coverage.xml" +addopts = "-v --durations=10 --cov=micro_sam --cov-report xml:coverage.xml --cov-report term-missing" markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')", "gui: marks GUI tests (deselect with '-m \"not gui\"')", diff --git a/requirements-dev.txt b/requirements-dev.txt index b9591257..69419a61 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,5 @@ coverage line-profiler -memray napari[all] pdoc pytest diff --git a/scripts/.gitignore b/scripts/.gitignore new file mode 100644 index 00000000..56241314 --- /dev/null +++ b/scripts/.gitignore @@ -0,0 +1,2 @@ +new_models/ +exported_models/ diff --git a/scripts/export_models_for_upload.py b/scripts/export_models_for_upload.py new file mode 100644 index 00000000..7d097171 --- /dev/null +++ b/scripts/export_models_for_upload.py @@ -0,0 +1,62 @@ +"""Helper scripts to export models for upload to zenodo. +""" + +import hashlib +import os +import warnings +from glob import glob + +import xxhash +from micro_sam.util import export_custom_sam_model + +BUF_SIZE = 65536 # lets read stuff in 64kb chunks! 
+ + +def export_model(model_path, model_type, export_name): + output_folder = "./exported_models" + os.makedirs(output_folder, exist_ok=True) + + output_path = os.path.join(output_folder, export_name) + if os.path.exists(output_path): + print("The model", export_name, "has already been exported.") + return + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + export_custom_sam_model( + checkpoint_path=model_path, + model_type=model_type, + save_path=output_path, + ) + + print("Exported", export_name) + + sha_checksum = hashlib.sha256() + xxh_checksum = xxhash.xxh128() + + with open(output_path, "rb") as f: + while True: + data = f.read(BUF_SIZE) + if not data: + break + sha_checksum.update(data) + xxh_checksum.update(data) + + print("sha256:", f"sha256:{sha_checksum.hexdigest()}") + print("xxh128:", f"xxh128:{xxh_checksum.hexdigest()}") + + +def export_all_models(): + models = glob(os.path.join("./new_models/*.pt")) + model_type = "vit_b" + for model_path in models: + export_name = os.path.basename(model_path).replace(".pt", ".pth") + export_model(model_path, model_type, export_name) + + +def main(): + export_all_models() + + +if __name__ == "__main__": + main() diff --git a/scripts/plotting/results/benchmarking/mitonet/lucchi/results/mitonet.csv b/scripts/plotting/results/benchmarking/mitonet/lucchi/results/mitonet.csv new file mode 100644 index 00000000..8c6e6aa6 --- /dev/null +++ b/scripts/plotting/results/benchmarking/mitonet/lucchi/results/mitonet.csv @@ -0,0 +1,2 @@ +msa,sa50,sa75 +0.3821131154700797,0.5276421448676556,0.43452946854553287 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_b/results/amg.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_b/results/amg.csv new file mode 100644 index 00000000..87015b49 --- /dev/null +++ b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_b/results/amg.csv @@ -0,0 +1,2 @@ +msa,sa50,sa75 +0.5342480164898857,0.7001000850738243,0.6656686582248956 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_b/results/instance_segmentation_with_decoder.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_b/results/instance_segmentation_with_decoder.csv new file mode 100644 index 00000000..bf1c19cf --- /dev/null +++ b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_b/results/instance_segmentation_with_decoder.csv @@ -0,0 +1,2 @@ +msa,sa50,sa75 +0.5330027656122025,0.7676402990129684,0.652398702013908 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_b/results/iterative_prompts_start_box.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_b/results/iterative_prompts_start_box.csv new file mode 100644 index 00000000..183f69fb --- /dev/null +++ b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_b/results/iterative_prompts_start_box.csv @@ -0,0 +1,9 @@ +,msa,sa50,sa75 +0,0.8261502647422483,0.9962019688699532,0.9587797253125023 +1,0.852095180104265,0.9962019688699532,0.9763074723164177 +2,0.8703469818868097,0.997010049678034,0.9826575126481513 +3,0.8826046761933478,0.997010049678034,0.9884891776893024 +4,0.8952191354356615,0.9983448260128102,0.9921916416185192 +5,0.9042341719712977,0.9972337149016992,0.9933638364270776 
+6,0.9142569676393435,0.9977650017966224,0.9950974168563101 +7,0.9232511662751378,0.99802736569535,0.9952354606307174 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_b/results/iterative_prompts_start_point.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_b/results/iterative_prompts_start_point.csv new file mode 100644 index 00000000..7f5dbbf7 --- /dev/null +++ b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_b/results/iterative_prompts_start_point.csv @@ -0,0 +1,9 @@ +,msa,sa50,sa75 +0,0.6303760953286199,0.9283613770431891,0.7259490496282851 +1,0.7271479131894507,0.9693551556041675,0.8578322992044398 +2,0.7863883129355541,0.9882171776917185,0.9131861817400104 +3,0.8244429165101651,0.992558379257984,0.9423396807990476 +4,0.85076125965045,0.9947186037502244,0.9606351620065199 +5,0.8695850640176421,0.9947186037502244,0.9736614111043315 +6,0.8858293116918424,0.9956510046826252,0.985177437165812 +7,0.9006762354353417,0.9956510046826252,0.9900819880387424 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_h/results/amg.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_h/results/amg.csv new file mode 100644 index 00000000..e4d9d026 --- /dev/null +++ b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_h/results/amg.csv @@ -0,0 +1,2 @@ +msa,sa50,sa75 +0.509165707266242,0.6785237716539984,0.6391965791007382 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_h/results/instance_segmentation_with_decoder.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_h/results/instance_segmentation_with_decoder.csv new file mode 100644 index 00000000..654c38e4 --- /dev/null +++ b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_h/results/instance_segmentation_with_decoder.csv @@ -0,0 +1,2 @@ +msa,sa50,sa75 +0.5148520320549186,0.7408021810940641,0.6386985464822527 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_h/results/iterative_prompts_start_box.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_h/results/iterative_prompts_start_box.csv new file mode 100644 index 00000000..26f5157b --- /dev/null +++ b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_h/results/iterative_prompts_start_box.csv @@ -0,0 +1,9 @@ +,msa,sa50,sa75 +0,0.824577437255196,0.9954761795078002,0.9688739818357288 +1,0.852045249081702,0.9966827507143714,0.9819713915667733 +2,0.8737607670348032,0.9966827507143714,0.9877099894688827 +3,0.889568049537866,0.9972166612482817,0.9915291332880267 +4,0.9024569057881563,0.9980680320996526,0.993835749231006 +5,0.9129546791199118,0.9977650017966224,0.9938528028844233 +6,0.9255270629001708,0.9974908315224521,0.9959251749567957 +7,0.9367785372616947,0.99804179570978,0.9959251749567957 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_h/results/iterative_prompts_start_point.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_h/results/iterative_prompts_start_point.csv new file mode 100644 index 00000000..e51e1b5c --- /dev/null +++ 
b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/with_cem/vit_h/results/iterative_prompts_start_point.csv @@ -0,0 +1,9 @@ +,msa,sa50,sa75 +0,0.579290764506784,0.9145667130527934,0.663061446516525 +1,0.711223164027884,0.9659873292230272,0.8385259072208169 +2,0.7725345959611514,0.9837360973608358,0.9038605610816479 +3,0.8151091853864214,0.9905447363352502,0.9396064560897709 +4,0.8479293384018557,0.9941676395628963,0.967699874922724 +5,0.8715108493515953,0.9947186037502244,0.9793321501149914 +6,0.8867285062731892,0.9955266845583052,0.9836012016091068 +7,0.9017271574781673,0.9955266845583052,0.9907187634499888 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_b/results/amg.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_b/results/amg.csv new file mode 100644 index 00000000..5d246200 --- /dev/null +++ b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_b/results/amg.csv @@ -0,0 +1,2 @@ +msa,sa50,sa75 +0.5038740816911844,0.6916530979897862,0.6234874074658332 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_b/results/instance_segmentation_with_decoder.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_b/results/instance_segmentation_with_decoder.csv new file mode 100644 index 00000000..e6365f61 --- /dev/null +++ b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_b/results/instance_segmentation_with_decoder.csv @@ -0,0 +1,2 @@ +msa,sa50,sa75 +0.4970742496165649,0.7452420673814841,0.5980313648084746 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_b/results/iterative_prompts_start_box.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_b/results/iterative_prompts_start_box.csv new file mode 100644 index 00000000..96b38b5d --- /dev/null +++ b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_b/results/iterative_prompts_start_box.csv @@ -0,0 +1,9 @@ +,msa,sa50,sa75 +0,0.7980817680383172,0.9947186037502244,0.9511948325535252 +1,0.8339658206318005,0.9947186037502244,0.9746828171767048 +2,0.8581413134258775,0.9954761795078002,0.9887249158934203 +3,0.8718715601088602,0.996408580440201,0.9910260790300316 +4,0.883643615729137,0.996408580440201,0.9941915945275629 +5,0.8921355237737988,0.9959251749567957,0.9939148006144054 +6,0.9011553566265291,0.9972337149016992,0.9953742107694677 +7,0.907984560092733,0.9967503094182937,0.9952669442985649 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_b/results/iterative_prompts_start_point.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_b/results/iterative_prompts_start_point.csv new file mode 100644 index 00000000..2656146b --- /dev/null +++ b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_b/results/iterative_prompts_start_point.csv @@ -0,0 +1,9 @@ +,msa,sa50,sa75 +0,0.573705332703672,0.8940105665715254,0.6452623262562155 +1,0.6761231181084724,0.9460567668446008,0.8092124017832902 +2,0.7395975839063142,0.9784934142924697,0.8710757887908471 +3,0.7855403291139634,0.9880705332390376,0.9157942376800113 +4,0.8174039010579577,0.9930026718075396,0.9412149795683246 
+5,0.8408203072436587,0.9941414031730238,0.9566195083303835 +6,0.8572693308653753,0.9941414031730238,0.9665051576643006 +7,0.8727408840069822,0.9941414031730238,0.9765567275646329 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_h/results/amg.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_h/results/amg.csv new file mode 100644 index 00000000..dab776c4 --- /dev/null +++ b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_h/results/amg.csv @@ -0,0 +1,2 @@ +msa,sa50,sa75 +0.4927872984735681,0.6655641793973945,0.6033959119446686 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_h/results/instance_segmentation_with_decoder.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_h/results/instance_segmentation_with_decoder.csv new file mode 100644 index 00000000..edd6102b --- /dev/null +++ b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_h/results/instance_segmentation_with_decoder.csv @@ -0,0 +1,2 @@ +msa,sa50,sa75 +0.5237698595821274,0.7618132683170433,0.6272788566821614 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_h/results/iterative_prompts_start_box.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_h/results/iterative_prompts_start_box.csv new file mode 100644 index 00000000..700b219e --- /dev/null +++ b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_h/results/iterative_prompts_start_box.csv @@ -0,0 +1,9 @@ +,msa,sa50,sa75 +0,0.8129605074293222,0.9949026116615051,0.9558793410451654 +1,0.8406403666254398,0.9960045400361606,0.9692579652540616 +2,0.8612491060128588,0.996278710310331,0.9830309709860364 +3,0.8734588571576237,0.9960045400361606,0.9882613298527589 +4,0.8834336828745714,0.9982149558829402,0.9875464167465418 +5,0.8936631411249588,0.9976377553057396,0.988737963105552 +6,0.9024688981911836,0.9979119255799098,0.990799706076386 +7,0.9109432840207509,0.997360961392582,0.9909966691191988 diff --git a/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_h/results/iterative_prompts_start_point.csv b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_h/results/iterative_prompts_start_point.csv new file mode 100644 index 00000000..edb5386f --- /dev/null +++ b/scripts/plotting/results/generalists/em/lucchi/mito_nuc_em_generalist_sam/without_cem/vit_h/results/iterative_prompts_start_point.csv @@ -0,0 +1,9 @@ +,msa,sa50,sa75 +0,0.6027445912937732,0.9172335935036897,0.6880826178900978 +1,0.6845654240118713,0.9453158488202603,0.8056415909945425 +2,0.7495587044027366,0.9766632940394245,0.8720955095220875 +3,0.7951689558412258,0.9869460582410448,0.9167149034160009 +4,0.8259307828926412,0.9912543586363016,0.9426454354528487 +5,0.8515779928568458,0.9947186037502244,0.9622181346365808 +6,0.8701314103835502,0.9950721391037598,0.9743748905472923 +7,0.8828525577428453,0.9950721391037598,0.9812539447601063 diff --git a/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_b/results/instance_segmentation_with_decoder.csv b/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_b/results/instance_segmentation_with_decoder.csv new file mode 100644 index 
00000000..8e2ac860 --- /dev/null +++ b/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_b/results/instance_segmentation_with_decoder.csv @@ -0,0 +1,2 @@ +msa,sa50,sa75 +0.024888275974812867,0.051031050062742556,0.022553745294398388 diff --git a/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_b/results/iterative_prompts_start_box.csv b/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_b/results/iterative_prompts_start_box.csv new file mode 100644 index 00000000..6a4fc0e5 --- /dev/null +++ b/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_b/results/iterative_prompts_start_box.csv @@ -0,0 +1,9 @@ +,msa,sa50,sa75 +0,0.5931420782859859,0.9586730989120074,0.6636199936782969 +1,0.6549473706989701,0.9785305571354556,0.7666714340362685 +2,0.7037443928609264,0.9908780393225136,0.8448609845020343 +3,0.7385094585595114,0.9942437007521207,0.8896665643790644 +4,0.7630809301178176,0.996636490268866,0.9181397156788542 +5,0.781958651691727,0.9972987051861594,0.9360760514132309 +6,0.796064513295252,0.9981849872388289,0.9458111143061931 +7,0.8066201886751785,0.9975602764380568,0.9525798737233245 diff --git a/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_b/results/iterative_prompts_start_point.csv b/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_b/results/iterative_prompts_start_point.csv new file mode 100644 index 00000000..3141dc73 --- /dev/null +++ b/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_b/results/iterative_prompts_start_point.csv @@ -0,0 +1,9 @@ +,msa,sa50,sa75 +0,0.06781802442209095,0.11035597152164174,0.07099810180687918 +1,0.1250209468809168,0.23051804437947307,0.11938731473367281 +2,0.1873409520013621,0.3700318401256433,0.16774486568194114 +3,0.2678832611487937,0.5346674127976223,0.236210185330558 +4,0.3497785688924579,0.6768315749301804,0.31548831501894625 +5,0.4211657474155045,0.783422439378113,0.397179966286553 +6,0.47556584243485683,0.8457153833801702,0.46163834475853127 +7,0.5181203367816799,0.8844379210426101,0.5171786345056514 diff --git a/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_h/results/instance_segmentation_with_decoder.csv b/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_h/results/instance_segmentation_with_decoder.csv new file mode 100644 index 00000000..971289ce --- /dev/null +++ b/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_h/results/instance_segmentation_with_decoder.csv @@ -0,0 +1,2 @@ +msa,sa50,sa75 +0.03807335230559052,0.06823153347278738,0.03713053569983708 diff --git a/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_h/results/iterative_prompts_start_box.csv b/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_h/results/iterative_prompts_start_box.csv new file mode 100644 index 00000000..b5282927 --- /dev/null +++ b/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_h/results/iterative_prompts_start_box.csv @@ -0,0 +1,9 @@ +,msa,sa50,sa75 +0,0.5409942594514658,0.9480112735657675,0.5578834545303268 +1,0.6022788767036837,0.9710142666740434,0.6690455671965096 +2,0.6509178684028675,0.9867755480702958,0.7595002481769721 +3,0.6891328677716513,0.9916777832075946,0.8259490475202 +4,0.7203093913729243,0.9945756205136492,0.8744511809230406 
+5,0.7439657897385473,0.9972418498393842,0.9041722346654785
+6,0.7631092955876742,0.9974767453381869,0.9241645108314994
+7,0.7780195500590473,0.9980865929335853,0.937149576717324
diff --git a/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_h/results/iterative_prompts_start_point.csv b/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_h/results/iterative_prompts_start_point.csv
new file mode 100644
index 00000000..60e3735b
--- /dev/null
+++ b/scripts/plotting/results/generalists/em/snemi/boundaries_em_generalist_sam/vit_h/results/iterative_prompts_start_point.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.07824609442508737,0.12611560440139125,0.08289226170781663
+1,0.13838063293336428,0.286456374652807,0.12136206294278082
+2,0.2133371840004478,0.4786130644513008,0.16966891976499596
+3,0.31598163094902676,0.6866874093357884,0.25002414635257064
+4,0.41259266496637415,0.8292136996522352,0.3507414958464562
+5,0.4903522188586341,0.9089347376985522,0.45684714233791074
+6,0.551551992713992,0.9492121591061679,0.5559470681920452
+7,0.5989469174458618,0.9718895803685446,0.6365100318648349
diff --git a/scripts/plotting/results/generalists/lm/livecell/vit_b/amg.csv b/scripts/plotting/results/generalists/lm/livecell/vit_b/amg.csv
new file mode 100644
index 00000000..41aab08b
--- /dev/null
+++ b/scripts/plotting/results/generalists/lm/livecell/vit_b/amg.csv
@@ -0,0 +1,2 @@
+msa,sa50,sa75
+0.2917001542286596,0.4567462735547909,0.32728600531874597
diff --git a/scripts/plotting/results/generalists/lm/livecell/vit_b/instance_segmentation_with_decoder.csv b/scripts/plotting/results/generalists/lm/livecell/vit_b/instance_segmentation_with_decoder.csv
new file mode 100644
index 00000000..cedd40f5
--- /dev/null
+++ b/scripts/plotting/results/generalists/lm/livecell/vit_b/instance_segmentation_with_decoder.csv
@@ -0,0 +1,2 @@
+msa,sa50,sa75
+0.3473917092212614,0.5942895027345566,0.37384866020157653
diff --git a/scripts/plotting/results/generalists/lm/livecell/vit_b/iterative_prompts_start_box.csv b/scripts/plotting/results/generalists/lm/livecell/vit_b/iterative_prompts_start_box.csv
new file mode 100644
index 00000000..132f5279
--- /dev/null
+++ b/scripts/plotting/results/generalists/lm/livecell/vit_b/iterative_prompts_start_box.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.6330945228754941,0.9256385601978663,0.7283535456074498
+1,0.6788491708752833,0.9335667451743224,0.7948643289105897
+2,0.7118508958066381,0.9371373857064424,0.8358012406038333
+3,0.7367814333510339,0.9390256219928819,0.860267942997836
+4,0.7557968398827686,0.9398949128830458,0.8750424539582609
+5,0.7710724376385231,0.9407130872380144,0.8845256155291473
+6,0.7836134915519982,0.9411613826723576,0.8911077952612095
+7,0.7937366373436504,0.9415861729807385,0.8951474111639266
diff --git a/scripts/plotting/results/generalists/lm/livecell/vit_b/iterative_prompts_start_point.csv b/scripts/plotting/results/generalists/lm/livecell/vit_b/iterative_prompts_start_point.csv
new file mode 100644
index 00000000..9b88f3cd
--- /dev/null
+++ b/scripts/plotting/results/generalists/lm/livecell/vit_b/iterative_prompts_start_point.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.4022141232291086,0.6839510431966765,0.4261450668119283
+1,0.4985659372951761,0.8161517451114053,0.5335983207155439
+2,0.5713240278477968,0.884423390806401,0.6252716738290416
+3,0.6284546879077642,0.9159000186474198,0.7040252654188905
+4,0.6718298761613183,0.930047558216046,0.7668840282755824
+5,0.7051150282241244,0.9356652770287635,0.8133598371993279
+6,0.7306451606066683,0.9384565789692653,0.8444631305527277
+7,0.750684675551144,0.9399579557052311,0.8644661599045912
diff --git a/scripts/plotting/results/generalists/lm/livecell/vit_h/amg.csv b/scripts/plotting/results/generalists/lm/livecell/vit_h/amg.csv
new file mode 100644
index 00000000..bede4d34
--- /dev/null
+++ b/scripts/plotting/results/generalists/lm/livecell/vit_h/amg.csv
@@ -0,0 +1,2 @@
+msa,sa50,sa75
+0.314884954067805,0.488494700415521,0.3543373658931705
diff --git a/scripts/plotting/results/generalists/lm/livecell/vit_h/instance_segmentation_with_decoder.csv b/scripts/plotting/results/generalists/lm/livecell/vit_h/instance_segmentation_with_decoder.csv
new file mode 100644
index 00000000..f823cec6
--- /dev/null
+++ b/scripts/plotting/results/generalists/lm/livecell/vit_h/instance_segmentation_with_decoder.csv
@@ -0,0 +1,2 @@
+msa,sa50,sa75
+0.36490571170392583,0.6166296447253928,0.39543062334954765
diff --git a/scripts/plotting/results/generalists/lm/livecell/vit_h/iterative_prompts_start_box.csv b/scripts/plotting/results/generalists/lm/livecell/vit_h/iterative_prompts_start_box.csv
new file mode 100644
index 00000000..19a580da
--- /dev/null
+++ b/scripts/plotting/results/generalists/lm/livecell/vit_h/iterative_prompts_start_box.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.6391394987296813,0.9281843442585326,0.7380614045754574
+1,0.6832147373965357,0.9349383820683547,0.8025178396395575
+2,0.7146777330809361,0.9380643307602837,0.8420174341570058
+3,0.7381074376590195,0.9392540039898175,0.8641000494046639
+4,0.7561390805397994,0.9402659533358054,0.8773023816004493
+5,0.7703624716905995,0.9410450049588325,0.8858143730938148
+6,0.782213921723365,0.9416409138272931,0.8911462561917289
+7,0.7919951074798645,0.9420446879483949,0.8949113619323624
diff --git a/scripts/plotting/results/generalists/lm/livecell/vit_h/iterative_prompts_start_point.csv b/scripts/plotting/results/generalists/lm/livecell/vit_h/iterative_prompts_start_point.csv
new file mode 100644
index 00000000..c3118c15
--- /dev/null
+++ b/scripts/plotting/results/generalists/lm/livecell/vit_h/iterative_prompts_start_point.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.4208523420517224,0.7031741040772173,0.4512130744074879
+1,0.5197568673122674,0.8263446227261932,0.5646676671687164
+2,0.5898242907343642,0.889084819013037,0.6523007677376478
+3,0.6420830853457365,0.9171838231626205,0.7242280965307658
+4,0.6823784514828716,0.9303533860529599,0.7808488256519112
+5,0.7129630117238562,0.9357644494593024,0.8211979813504835
+6,0.7360397733510478,0.9384511032451075,0.8479774224042684
+7,0.7542274391598381,0.9400282894022737,0.8660802487555516
diff --git a/scripts/plotting/results/specialists/lm/livecell/vit_b/amg.csv b/scripts/plotting/results/specialists/lm/livecell/vit_b/amg.csv
new file mode 100644
index 00000000..e6bbf541
--- /dev/null
+++ b/scripts/plotting/results/specialists/lm/livecell/vit_b/amg.csv
@@ -0,0 +1,2 @@
+msa,sa50,sa75
+0.340614750788787,0.5331129035333244,0.37988833753941736
diff --git a/scripts/plotting/results/specialists/lm/livecell/vit_b/instance_segmentation_with_decoder.csv b/scripts/plotting/results/specialists/lm/livecell/vit_b/instance_segmentation_with_decoder.csv
new file mode 100644
index 00000000..0c78fea3
--- /dev/null
+++ b/scripts/plotting/results/specialists/lm/livecell/vit_b/instance_segmentation_with_decoder.csv
@@ -0,0 +1,2 @@
+msa,sa50,sa75
+0.4297564320374104,0.7055948450521679,0.4637000029385975
diff --git a/scripts/plotting/results/specialists/lm/livecell/vit_b/iterative_prompts_start_box.csv b/scripts/plotting/results/specialists/lm/livecell/vit_b/iterative_prompts_start_box.csv
new file mode 100644
index 00000000..87ddcedd
--- /dev/null
+++ b/scripts/plotting/results/specialists/lm/livecell/vit_b/iterative_prompts_start_box.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.652875926432221,0.9304299982477234,0.7591236446644317
+1,0.6984752531829158,0.9365733483634476,0.8229925071247954
+2,0.7311799529861006,0.9391771584032967,0.8572039489804477
+3,0.7558905234460626,0.9404128080617948,0.8764408343116248
+4,0.7752863277083895,0.9411568379018722,0.8882348000627075
+5,0.790628509849978,0.9418886153239848,0.8951282269682678
+6,0.8030784764300517,0.9422612806288869,0.8996625962311235
+7,0.8130713632540388,0.942660067245244,0.9026155844461543
diff --git a/scripts/plotting/results/specialists/lm/livecell/vit_b/iterative_prompts_start_point.csv b/scripts/plotting/results/specialists/lm/livecell/vit_b/iterative_prompts_start_point.csv
new file mode 100644
index 00000000..af3e1f35
--- /dev/null
+++ b/scripts/plotting/results/specialists/lm/livecell/vit_b/iterative_prompts_start_point.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.4543124095927029,0.760923434410071,0.4822689837132261
+1,0.5414847398959654,0.8445012464352787,0.592391588514044
+2,0.6109414799207289,0.8946091092905274,0.6844871868934906
+3,0.6645384573918424,0.9206429119147603,0.757255131227229
+4,0.7052299652255886,0.9323508830610393,0.8082713030504648
+5,0.7358773038050391,0.9374804918714502,0.8442197603222338
+6,0.7596728534865324,0.9396986574132883,0.8673887225139528
+7,0.7779175541488949,0.9409888858644881,0.8814549756778675
diff --git a/scripts/plotting/results/specialists/lm/livecell/vit_h/amg.csv b/scripts/plotting/results/specialists/lm/livecell/vit_h/amg.csv
new file mode 100644
index 00000000..4b5c09e2
--- /dev/null
+++ b/scripts/plotting/results/specialists/lm/livecell/vit_h/amg.csv
@@ -0,0 +1,2 @@
+msa,sa50,sa75
+0.35557008098665815,0.5467237623199277,0.39988779872687774
diff --git a/scripts/plotting/results/specialists/lm/livecell/vit_h/instance_segmentation_with_decoder.csv b/scripts/plotting/results/specialists/lm/livecell/vit_h/instance_segmentation_with_decoder.csv
new file mode 100644
index 00000000..5ae1ff20
--- /dev/null
+++ b/scripts/plotting/results/specialists/lm/livecell/vit_h/instance_segmentation_with_decoder.csv
@@ -0,0 +1,2 @@
+msa,sa50,sa75
+0.43562168677003077,0.7101481071041281,0.4724010945117092
diff --git a/scripts/plotting/results/specialists/lm/livecell/vit_h/iterative_prompts_start_box.csv b/scripts/plotting/results/specialists/lm/livecell/vit_h/iterative_prompts_start_box.csv
new file mode 100644
index 00000000..bb86bbef
--- /dev/null
+++ b/scripts/plotting/results/specialists/lm/livecell/vit_h/iterative_prompts_start_box.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.6477525607253328,0.9293435667901185,0.750961836817713
+1,0.693735216118805,0.9361974633095106,0.8151997912690174
+2,0.7268293734241834,0.9391599561244868,0.8522759597490676
+3,0.7517640480899525,0.9404351258631368,0.8732371412634055
+4,0.7712040187354087,0.9411665845807622,0.8852444277313106
+5,0.7867717345396241,0.9419076676393694,0.8933981817940133
+6,0.7995391516875582,0.9423489221390896,0.8984899738970564
+7,0.8098483564344346,0.94266370908642,0.9015961746688059
diff --git a/scripts/plotting/results/specialists/lm/livecell/vit_h/iterative_prompts_start_point.csv b/scripts/plotting/results/specialists/lm/livecell/vit_h/iterative_prompts_start_point.csv
new file mode 100644
index 00000000..0f58be15
--- /dev/null
+++ b/scripts/plotting/results/specialists/lm/livecell/vit_h/iterative_prompts_start_point.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.4580984475415567,0.7598746477541353,0.4899020214302746
+1,0.5457063544178609,0.8461680968138724,0.599651363565181
+2,0.6121136128616247,0.8947849316276938,0.6834677340135189
+3,0.6645826634245824,0.920422044391624,0.7535365889116948
+4,0.7042760038786391,0.9315880494266998,0.8041226927832791
+5,0.7343532791824768,0.9371007503346068,0.8407208009826476
+6,0.7570746329156439,0.9393347094549,0.8635389590850535
+7,0.7751650648450723,0.940732127865808,0.8781410717034517
diff --git a/scripts/plotting/results/vanilla/em/lucchi/vit_b/results/iterative_prompts_start_box.csv b/scripts/plotting/results/vanilla/em/lucchi/vit_b/results/iterative_prompts_start_box.csv
new file mode 100644
index 00000000..fdcdf97d
--- /dev/null
+++ b/scripts/plotting/results/vanilla/em/lucchi/vit_b/results/iterative_prompts_start_box.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.7656302827765239,0.993561578956836,0.8853878921687266
+1,0.7601664912151728,0.9892744808066215,0.864789404942776
+2,0.7700276527037726,0.987682050015106,0.8748885907964866
+3,0.7741639003269413,0.9819913029053864,0.8705519066132903
+4,0.7808159139857336,0.9816377675518511,0.8778927543973372
+5,0.7833099302105511,0.9766198678699199,0.8792434187445083
+6,0.7822769759173137,0.9749789183687665,0.8721603769728383
+7,0.7774603112594476,0.9699668365212364,0.867588977493792
diff --git a/scripts/plotting/results/vanilla/em/lucchi/vit_b/results/iterative_prompts_start_point.csv b/scripts/plotting/results/vanilla/em/lucchi/vit_b/results/iterative_prompts_start_point.csv
new file mode 100644
index 00000000..7288bda1
--- /dev/null
+++ b/scripts/plotting/results/vanilla/em/lucchi/vit_b/results/iterative_prompts_start_point.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.5614468119917114,0.7934874066395642,0.5897764822693498
+1,0.5304145760150958,0.7594411610116851,0.5529174510399478
+2,0.5792585071710316,0.8041979313062088,0.6133951152054842
+3,0.6088889447438407,0.8286735873299749,0.662534955279396
+4,0.6302277525919898,0.8371599541144362,0.6935104344552925
+5,0.6451672740243729,0.8503599801471695,0.7143444911827651
+6,0.6569204477538413,0.8558594681750781,0.7280308830455173
+7,0.6596478614239633,0.8597133606048465,0.7317881364345408
diff --git a/scripts/plotting/results/vanilla/em/lucchi/vit_h/results/iterative_prompts_start_box.csv b/scripts/plotting/results/vanilla/em/lucchi/vit_h/results/iterative_prompts_start_box.csv
new file mode 100644
index 00000000..b6467f21
--- /dev/null
+++ b/scripts/plotting/results/vanilla/em/lucchi/vit_h/results/iterative_prompts_start_box.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.7771430156360238,0.9850042952216866,0.8711583809932459
+1,0.7957591495701103,0.9803550283042813,0.8795331360546182
+2,0.8220071257442851,0.986003910431308,0.8926682626790804
+3,0.8273556273285566,0.9854230481990245,0.9016556294371695
+4,0.8233235442835847,0.98219206530401,0.9035548442320063
+5,0.8157973757350253,0.9804358414586207,0.9034793904009755
+6,0.8089901480523957,0.9779320228994113,0.8997203531725002
+7,0.794904414456581,0.96953345319527,0.8892688144964402
diff --git a/scripts/plotting/results/vanilla/em/lucchi/vit_h/results/iterative_prompts_start_point.csv b/scripts/plotting/results/vanilla/em/lucchi/vit_h/results/iterative_prompts_start_point.csv
new file mode 100644
index 00000000..40ba47be
--- /dev/null
+++ b/scripts/plotting/results/vanilla/em/lucchi/vit_h/results/iterative_prompts_start_point.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.5843784451123525,0.742402008633925,0.6384405961176579
+1,0.6649083935672512,0.8332149192173244,0.7306345499691015
+2,0.7062086978525708,0.8527215979985333,0.7760802376795052
+3,0.7289397369368157,0.8739891641650662,0.8005215570111466
+4,0.7410475895451568,0.8859541913795841,0.818111536730514
+5,0.7448246908647259,0.8895190269734052,0.831093864926383
+6,0.7416472747297068,0.8946441304816755,0.8355423462847196
+7,0.7365541782895114,0.8982808677745392,0.8351754824047912
diff --git a/scripts/plotting/results/vanilla/em/snemi/vit_b/results/iterative_prompts_start_box.csv b/scripts/plotting/results/vanilla/em/snemi/vit_b/results/iterative_prompts_start_box.csv
new file mode 100644
index 00000000..381bcd0d
--- /dev/null
+++ b/scripts/plotting/results/vanilla/em/snemi/vit_b/results/iterative_prompts_start_box.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.5864531179382734,0.961215082332739,0.6315977728026171
+1,0.5859245927147814,0.9550607558657476,0.633053672714021
+2,0.581013572758431,0.9482646873886577,0.6184782018815123
+3,0.561764431159048,0.9298520727615479,0.5843703332632373
+4,0.5372600430393688,0.9100601059716407,0.5452805568138696
+5,0.508726274083199,0.8869450993186412,0.5031270959695873
+6,0.47828570241338664,0.8590922318740937,0.46086986362257554
+7,0.4483151824973858,0.8270499507056996,0.42318453680999174
diff --git a/scripts/plotting/results/vanilla/em/snemi/vit_b/results/iterative_prompts_start_point.csv b/scripts/plotting/results/vanilla/em/snemi/vit_b/results/iterative_prompts_start_point.csv
new file mode 100644
index 00000000..0bb8b945
--- /dev/null
+++ b/scripts/plotting/results/vanilla/em/snemi/vit_b/results/iterative_prompts_start_point.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.23428142380703573,0.4029030042276065,0.24485476664568334
+1,0.2566894396613229,0.46312318215257475,0.25623965227463963
+2,0.2827140167709703,0.507595906218061,0.2814989197316172
+3,0.28207634576580254,0.5045447443703317,0.27769232111142
+4,0.2743372439688219,0.4988233977420705,0.2669684317341389
+5,0.26466101577718254,0.4861838679365709,0.2540447807950442
+6,0.2533408811391572,0.4741947545770861,0.23962698495816775
+7,0.2388504674085298,0.45801335857766245,0.22221753568504488
diff --git a/scripts/plotting/results/vanilla/em/snemi/vit_h/results/iterative_prompts_start_box.csv b/scripts/plotting/results/vanilla/em/snemi/vit_h/results/iterative_prompts_start_box.csv
new file mode 100644
index 00000000..e4c78d5f
--- /dev/null
+++ b/scripts/plotting/results/vanilla/em/snemi/vit_h/results/iterative_prompts_start_box.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.5740633260635087,0.9381900175416302,0.6177918824194981
+1,0.5815737243575148,0.9348995430029294,0.6317414901939573
+2,0.6039575936148656,0.9416230929347721,0.6693148929796194
+3,0.602826617875445,0.9332930636404655,0.6709534337793398
+4,0.5798425679483199,0.9169664178891335,0.6407822560491386
+5,0.5428981903992874,0.8928101011483424,0.5865322350090014
+6,0.4949592283773362,0.8548262594301036,0.5158624616146897
+7,0.4433686362994315,0.8112194229954725,0.438408149917948
diff --git a/scripts/plotting/results/vanilla/em/snemi/vit_h/results/iterative_prompts_start_point.csv b/scripts/plotting/results/vanilla/em/snemi/vit_h/results/iterative_prompts_start_point.csv
new file mode 100644
index 00000000..25f25419
--- /dev/null
+++ b/scripts/plotting/results/vanilla/em/snemi/vit_h/results/iterative_prompts_start_point.csv
@@ -0,0 +1,9 @@
+,msa,sa50,sa75
+0,0.2582243505054961,0.4177350148049307,0.2804356996596167
+1,0.30148648200153916,0.5095384577681294,0.32026679289196813
+2,0.2987767610770364,0.5058682272437036,0.31175713098469426
+3,0.26949819502929817,0.4637636937166766,0.2775752441653023
+4,0.24873420884399566,0.4381133196857306,0.2542456326555962
+5,0.23192409879880055,0.4249288860872832,0.23205356252901832
+6,0.2141601299586851,0.4115564887586721,0.20501413668674823
+7,0.19815111268082503,0.4012047095598585,0.17986626653341994
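All of the result files added above share one layout: the score columns msa, sa50 and sa75 (mean segmentation accuracy, and presumably the segmentation accuracy at IoU thresholds 0.5 and 0.75), with the iterative-prompting files additionally carrying an unnamed index column for the prompt iteration (0 to 7). A minimal sketch of how one of these files could be loaded for plotting, assuming pandas and matplotlib are available; the chosen path is just an example:

    import pandas as pd
    import matplotlib.pyplot as plt

    # Example file; every iterative_prompts_*.csv added above has the same layout.
    path = "scripts/plotting/results/generalists/lm/livecell/vit_b/iterative_prompts_start_box.csv"
    scores = pd.read_csv(path, index_col=0)  # index = prompt iteration, columns = msa, sa50, sa75

    # Plot how the mean segmentation accuracy evolves over the prompt iterations.
    scores["msa"].plot(marker="o")
    plt.xlabel("prompt iteration")
    plt.ylabel("mean segmentation accuracy")
    plt.show()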
diff --git a/test/test_gui.py b/test/test_gui.py
index e82f5610..2fad5319 100644
--- a/test/test_gui.py
+++ b/test/test_gui.py
@@ -1,3 +1,5 @@
+import platform
+
 import numpy as np
 import pytest
@@ -20,6 +22,7 @@ def _check_layer_initialization(viewer):
 @pytest.mark.gui
+@pytest.mark.skipif(platform.system() == "Windows", reason="Gui test is not working on windows.")
 def test_annotator_2d(make_napari_viewer_proxy, tmp_path):
     """Integration test for annotator_2d widget with automatic mask generation.
diff --git a/test/test_instance_segmentation.py b/test/test_instance_segmentation.py
index 4838ec08..4f3d6ba7 100644
--- a/test/test_instance_segmentation.py
+++ b/test/test_instance_segmentation.py
@@ -39,7 +39,7 @@ def write_object(center, radius):
     @staticmethod
     def _get_model(image, model_type):
-        predictor = util.get_sam_model(model_type=model_type)
+        predictor = util.get_sam_model(model_type=model_type, device=util.get_device(None))
         image_embeddings = util.precompute_image_embeddings(predictor, image)
         return predictor, image_embeddings
@@ -71,12 +71,12 @@ def test_automatic_mask_generator(self):
         amg.initialize(image, image_embeddings=image_embeddings, verbose=False)
         predicted = amg.generate()
-        predicted = mask_data_to_segmentation(predicted, image.shape, with_background=True)
+        predicted = mask_data_to_segmentation(predicted, with_background=True)
         self.assertGreater(matching(predicted, mask, threshold=0.75)["segmentation_accuracy"], 0.99)
         # check that regenerating the segmentation works
         predicted2 = amg.generate()
-        predicted2 = mask_data_to_segmentation(predicted2, image.shape, with_background=True)
+        predicted2 = mask_data_to_segmentation(predicted2, with_background=True)
         self.assertTrue(np.array_equal(predicted, predicted2))
         # check that serializing and reserializing the state works
@@ -84,14 +84,14 @@ def test_automatic_mask_generator(self):
         amg = AutomaticMaskGenerator(predictor, points_per_side=10, points_per_batch=16)
         amg.set_state(state)
         predicted3 = amg.generate()
-        predicted3 = mask_data_to_segmentation(predicted3, image.shape, with_background=True)
+        predicted3 = mask_data_to_segmentation(predicted3, with_background=True)
         self.assertTrue(np.array_equal(predicted, predicted3))
     def test_tiled_automatic_mask_generator(self):
         from micro_sam.instance_segmentation import TiledAutomaticMaskGenerator, mask_data_to_segmentation
         # Release all unoccupied cached memory, tiling requires a lot of memory
-        device = util._get_device(None)
+        device = util.get_device(None)
         if device == "cuda":
             import torch.cuda
             torch.cuda.empty_cache()
@@ -107,11 +107,11 @@ def test_tiled_automatic_mask_generator(self):
         amg = TiledAutomaticMaskGenerator(predictor, points_per_side=8)
         amg.initialize(image, image_embeddings=image_embeddings, verbose=False)
         predicted = amg.generate(pred_iou_thresh=pred_iou_thresh)
-        predicted = mask_data_to_segmentation(predicted, image.shape, with_background=True)
+        predicted = mask_data_to_segmentation(predicted, with_background=True)
         self.assertGreater(matching(predicted, mask, threshold=0.75)["segmentation_accuracy"], 0.99)
         predicted2 = amg.generate(pred_iou_thresh=pred_iou_thresh)
-        predicted2 = mask_data_to_segmentation(predicted2, image.shape, with_background=True)
+        predicted2 = mask_data_to_segmentation(predicted2, with_background=True)
         self.assertTrue(np.array_equal(predicted, predicted2))
         # check that serializing and reserializing the state works
@@ -119,73 +119,7 @@ def test_tiled_automatic_mask_generator(self):
         amg = TiledAutomaticMaskGenerator(predictor)
         amg.set_state(state)
         predicted3 = amg.generate(pred_iou_thresh=pred_iou_thresh)
-        predicted3 = mask_data_to_segmentation(predicted3, image.shape, with_background=True)
-        self.assertTrue(np.array_equal(predicted, predicted3))
-
-    @unittest.skip("Experimental functionality")
-    def test_embedding_mask_generator(self):
-        from micro_sam.instance_segmentation import _EmbeddingMaskGenerator, mask_data_to_segmentation
-
-        mask, image = self.mask, self.image
-        predictor, image_embeddings = self.predictor, self.image_embeddings
-        pred_iou_thresh, stability_score_thresh = 0.95, 0.75
-
-        amg = _EmbeddingMaskGenerator(predictor)
-        amg.initialize(image, image_embeddings=image_embeddings, verbose=False)
-        predicted = amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
-        predicted = mask_data_to_segmentation(predicted, image.shape, with_background=True)
-
-        self.assertGreater(matching(predicted, mask, threshold=0.75)["segmentation_accuracy"], 0.99)
-
-        initial_seg = amg.get_initial_segmentation()
-        self.assertEqual(initial_seg.shape, image.shape)
-
-        # check that regenerating the segmentation works
-        predicted2 = amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
-        predicted2 = mask_data_to_segmentation(predicted2, image.shape, with_background=True)
-        self.assertTrue(np.array_equal(predicted, predicted2))
-
-        # check that serializing and reserializing the state works
-        state = amg.get_state()
-        amg = _EmbeddingMaskGenerator(predictor)
-        amg.set_state(state)
-        predicted3 = amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
-        predicted3 = mask_data_to_segmentation(predicted3, image.shape, with_background=True)
-        self.assertTrue(np.array_equal(predicted, predicted3))
-
-    @unittest.skip("Experimental functionality")
-    def test_tiled_embedding_mask_generator(self):
-        from micro_sam.instance_segmentation import _TiledEmbeddingMaskGenerator
-
-        # Release all unoccupied cached memory, tiling requires a lot of memory
-        device = util._get_device(None)
-        if device == "cuda":
-            import torch.cuda
-            torch.cuda.empty_cache()
-        elif device == "mps":
-            import torch.mps
-            torch.mps.empty_cache()
-
-        mask, image = self.large_mask, self.large_image
-        predictor, image_embeddings = self.predictor, self.tiled_embeddings
-        pred_iou_thresh, stability_score_thresh = 0.90, 0.60
-
-        amg = _TiledEmbeddingMaskGenerator(predictor, box_extension=0.1)
-        amg.initialize(image, image_embeddings=image_embeddings)
-        predicted = amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
-        initial_seg = amg.get_initial_segmentation()
-
-        self.assertGreater(matching(predicted, mask, threshold=0.75)["segmentation_accuracy"], 0.99)
-        self.assertEqual(initial_seg.shape, image.shape)
-
-        predicted2 = amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
-        self.assertTrue(np.array_equal(predicted, predicted2))
-
-        # check that serializing and reserializing the state works
-        state = amg.get_state()
-        amg = _TiledEmbeddingMaskGenerator(predictor)
-        amg.set_state(state)
-        predicted3 = amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
+        predicted3 = mask_data_to_segmentation(predicted3, with_background=True)
         self.assertTrue(np.array_equal(predicted, predicted3))
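The instance segmentation tests exercise two API updates: util.get_device is now the public device helper (replacing util._get_device), and mask_data_to_segmentation no longer takes the image shape. A rough usage sketch based on the calls above; the toy image and the vit_b model choice are placeholders, not part of the change:

    from skimage.data import binary_blobs

    from micro_sam import util
    from micro_sam.instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation

    image = binary_blobs(length=256).astype("uint8") * 255  # placeholder image
    device = util.get_device(None)  # picks cuda / mps / cpu automatically
    predictor = util.get_sam_model(model_type="vit_b", device=device)

    amg = AutomaticMaskGenerator(predictor)
    amg.initialize(image, image_embeddings=util.precompute_image_embeddings(predictor, image))
    masks = amg.generate()
    segmentation = mask_data_to_segmentation(masks, with_background=True)  # image shape no longer needed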
diff --git a/test/test_prompt_based_segmentation.py b/test/test_prompt_based_segmentation.py
index 5db7d406..6bde257f 100644
--- a/test/test_prompt_based_segmentation.py
+++ b/test/test_prompt_based_segmentation.py
@@ -20,7 +20,7 @@ def _get_input(shape=(256, 256)):
     @staticmethod
     def _get_model(image, model_type):
-        predictor = util.get_sam_model(model_type=model_type)
+        predictor = util.get_sam_model(model_type=model_type, device=util.get_device(None))
         image_embeddings = util.precompute_image_embeddings(predictor, image)
         util.set_precomputed(predictor, image_embeddings)
         return predictor
diff --git a/test/test_sam_annotator/test_widgets.py b/test/test_sam_annotator/test_widgets.py
index dd4adb6f..dc5e26f1 100644
--- a/test/test_sam_annotator/test_widgets.py
+++ b/test/test_sam_annotator/test_widgets.py
@@ -7,7 +7,7 @@
 import zarr
 from micro_sam.sam_annotator._state import AnnotatorState
-from micro_sam.sam_annotator._widgets import embedding_widget, Model
+from micro_sam.sam_annotator._widgets import embedding_widget
 from micro_sam.util import _compute_data_signature
@@ -22,7 +22,7 @@ def test_embedding_widget(make_napari_viewer, tmp_path):
     layer = viewer.open_sample('napari', 'camera')[0]
     my_widget = embedding_widget()
     # run image embedding widget
-    worker = my_widget(image=layer, model=Model.vit_t, device="cpu", save_path=tmp_path)
+    worker = my_widget(image=layer, model="vit_t", device="cpu", save_path=tmp_path)
     worker.await_workers()  # blocks until thread worker is finished the embedding
     # Check in-memory state - predictor
     assert isinstance(AnnotatorState().predictor, (SamPredictor, MobileSamPredictor))
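Both test helpers now pass an explicit device into util.get_sam_model, and the embedding widget accepts the model name as a plain string ("vit_t") instead of the removed Model enum. Outside the test suite the same predictor setup looks roughly like this sketch, with a random placeholder image standing in for real data:

    import numpy as np

    from micro_sam import util

    image = np.random.randint(0, 255, (256, 256)).astype("uint8")  # placeholder image

    predictor = util.get_sam_model(model_type="vit_b", device=util.get_device(None))
    image_embeddings = util.precompute_image_embeddings(predictor, image)
    util.set_precomputed(predictor, image_embeddings)
    # The predictor can now be used for prompt-based segmentation without re-encoding the image.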
diff --git a/test/test_training.py b/test/test_training.py
index 60a59903..647d5ff6 100644
--- a/test/test_training.py
+++ b/test/test_training.py
@@ -9,7 +9,7 @@
 import torch_em
 from micro_sam.sample_data import synthetic_data
-from micro_sam.util import VIT_T_SUPPORT
+from micro_sam.util import VIT_T_SUPPORT, get_custom_sam_model, SamPredictor
 @unittest.skipUnless(VIT_T_SUPPORT, "Integration test is only run with vit_t support, otherwise it takes too long.")
@@ -56,6 +56,8 @@ def tearDown(self):
         pass
     def _get_dataloader(self, split, patch_shape, batch_size):
+        import micro_sam.training as sam_training
+
         # Create the synthetic training data and get the corresponding folders.
         image_root = os.path.join(self.tmp_folder, "synthetic-data", "images", split)
         label_root = os.path.join(self.tmp_folder, "synthetic-data", "labels", split)
@@ -67,6 +69,7 @@ def _get_dataloader(self, split, patch_shape, batch_size):
             patch_shape=patch_shape, batch_size=batch_size,
             label_transform=torch_em.transform.label.connected_components,
             shuffle=True, num_workers=2, ndim=2, is_seg_dataset=False,
+            raw_transform=sam_training.identity,
         )
         return loader
@@ -74,7 +77,7 @@ def _train_model(self, model_type, device):
         import micro_sam.training as sam_training
         batch_size = 1
-        n_sub_iteration = 4
+        n_sub_iteration = 3
         patch_shape = (512, 512)
         n_objects_per_batch = 2
@@ -83,7 +86,7 @@ def _train_model(self, model_type, device):
         val_loader = self._get_dataloader("val", patch_shape, batch_size)
         model = sam_training.get_trainable_sam_model(model_type=model_type, device=device)
-        convert_inputs = sam_training.ConvertToSamInputs()
+        convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.05)
         optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
             optimizer, mode="min", factor=0.9, patience=10, verbose=True
@@ -94,8 +97,6 @@ def _train_model(self, model_type, device):
             train_loader=train_loader,
             val_loader=val_loader,
             model=model,
-            loss=torch_em.loss.DiceLoss(),
-            metric=torch_em.loss.DiceLoss(),
             optimizer=optimizer,
             lr_scheduler=scheduler,
             device=device,
@@ -133,6 +134,9 @@ def _run_inference_and_check_results(
         inference_function(predictor, image_paths, label_paths, embedding_dir, prediction_dir)
         pred_paths = sorted(glob(os.path.join(prediction_dir, "*.tif")))
+        if len(pred_paths) == 0:  # we need to go to subfolder for iterative inference
+            pred_paths = sorted(glob(os.path.join(prediction_dir, "iteration02", "*.tif")))
+
         self.assertEqual(len(pred_paths), len(label_paths))
         eval_res = evaluation.run_evaluation(label_paths, pred_paths, verbose=False)
         result = eval_res["sa50"].values.item()
@@ -150,6 +154,10 @@ def test_training(self):
         checkpoint_path = os.path.join(self.tmp_folder, "checkpoints", "test", "best.pt")
         self.assertTrue(os.path.exists(checkpoint_path))
+        # Check that the model can be loaded from a custom checkpoint.
+        predictor = get_custom_sam_model(checkpoint_path, model_type=model_type, device=device)
+        self.assertTrue(isinstance(predictor, SamPredictor))
+
         # Export the model.
         export_path = os.path.join(self.tmp_folder, "exported_model.pth")
         self._export_model(checkpoint_path, export_path, model_type)
@@ -157,7 +165,7 @@ def test_training(self):
         # Check the model with inference with a single point prompt.
         prediction_dir = os.path.join(self.tmp_folder, "predictions-points")
-        normal_inference = partial(
+        point_inference = partial(
             evaluation.run_inference_with_prompts,
             use_points=True, use_boxes=False,
             n_positives=1, n_negatives=0,
         )
         self._run_inference_and_check_results(
             export_path, model_type, prediction_dir=prediction_dir,
-            inference_function=normal_inference, expected_sa=0.9
+            inference_function=point_inference, expected_sa=0.9
         )
         # Check the model with inference with a box point prompt.
         prediction_dir = os.path.join(self.tmp_folder, "predictions-boxes")
-        normal_inference = partial(
+        box_inference = partial(
             evaluation.run_inference_with_prompts,
             use_points=False, use_boxes=True,
             n_positives=1, n_negatives=0,
@@ -178,11 +186,20 @@
         )
         self._run_inference_and_check_results(
             export_path, model_type, prediction_dir=prediction_dir,
-            inference_function=normal_inference, expected_sa=0.95,
+            inference_function=box_inference, expected_sa=0.95,
         )
-        # Check the model with interactive inference
-        # TODO
+        # Check the model with interactive inference.
+        prediction_dir = os.path.join(self.tmp_folder, "predictions-iterative")
+        iterative_inference = partial(
+            evaluation.run_inference_with_iterative_prompting,
+            start_with_box_prompt=False,
+            n_iterations=3,
+        )
+        self._run_inference_and_check_results(
+            export_path, model_type, prediction_dir=prediction_dir,
+            inference_function=iterative_inference, expected_sa=0.95,
+        )
 if __name__ == "__main__":
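The extended training test also documents how a finetuned checkpoint is loaded back with get_custom_sam_model and scored with iterative prompting. A sketch of that workflow under the same assumptions as the test; all paths are placeholders and the evaluation import mirrors how the test refers to that module, so treat the exact import path as an assumption:

    from glob import glob

    from micro_sam import evaluation
    from micro_sam.util import get_custom_sam_model

    # Placeholder paths: a checkpoint written by micro_sam.training and a folder of labeled test images.
    predictor = get_custom_sam_model("checkpoints/test/best.pt", model_type="vit_b", device="cpu")
    image_paths = sorted(glob("data/images/*.tif"))
    label_paths = sorted(glob("data/labels/*.tif"))
    embedding_dir, prediction_dir = "embeddings", "predictions-iterative"

    # Mirrors the iterative-prompting call wired up in the test above.
    evaluation.run_inference_with_iterative_prompting(
        predictor, image_paths, label_paths, embedding_dir, prediction_dir,
        start_with_box_prompt=False, n_iterations=3,
    )

As the new check in _run_inference_and_check_results suggests, iterative inference writes its predictions into per-iteration subfolders such as iteration02 rather than directly into the prediction directory.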
diff --git a/test/test_util.py b/test/test_util.py
index 7383977a..c0d4e5ca 100644
--- a/test/test_util.py
+++ b/test/test_util.py
@@ -3,13 +3,16 @@
 from shutil import rmtree
 import numpy as np
+import torch
 import zarr
 from skimage.data import binary_blobs
 from skimage.measure import label
+from micro_sam.util import VIT_T_SUPPORT, SamPredictor, get_cache_directory
 class TestUtil(unittest.TestCase):
+    model_type = "vit_t" if VIT_T_SUPPORT else "vit_b"
     tmp_folder = "tmp-files"
@@ -18,6 +21,24 @@ def setUp(self):
     def tearDown(self):
         rmtree(self.tmp_folder)
+    def test_get_sam_model(self):
+        from micro_sam.util import get_sam_model
+
+        def check_predictor(predictor):
+            self.assertTrue(isinstance(predictor, SamPredictor))
+            self.assertEqual(predictor.model_type, self.model_type)
+
+        # check predictor with download
+        predictor = get_sam_model(model_type=self.model_type)
+        check_predictor(predictor)
+
+        # check predictor with checkpoint path (using the cached model)
+        checkpoint_path = os.path.join(
+            get_cache_directory(), "models", "vit_t" if VIT_T_SUPPORT else "vit_b"
+        )
+        predictor = get_sam_model(model_type=self.model_type, checkpoint_path=checkpoint_path)
+        check_predictor(predictor)
+
     def test_compute_iou(self):
         from micro_sam.util import compute_iou
@@ -63,6 +84,21 @@ def test_segmentation_to_one_hot(self):
         self.assertTrue(np.allclose(mask, expected_mask))
+    def test_get_device(self):
+        from micro_sam.util import get_device
+
+        # check that device without argument works
+        get_device()
+
+        # check passing device as string
+        device = get_device("cpu")
+        self.assertEqual(device, "cpu")
+
+        # check passing device as torch.device works
+        device = get_device(torch.device("cpu"))
+        self.assertTrue(isinstance(device, torch.device))
+        self.assertEqual(device.type, "cpu")
+
 if __name__ == "__main__":
     unittest.main()
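The new util tests double as a usage reference for the model-loading helpers: get_device picks a default device when called without arguments, downloaded models are cached under the "models" subfolder of get_cache_directory(), and get_sam_model accepts an explicit checkpoint_path. A short sketch along those lines; the commented checkpoint path is a placeholder:

    import os

    from micro_sam.util import get_cache_directory, get_device, get_sam_model

    print("micro_sam stores downloaded models in:", os.path.join(get_cache_directory(), "models"))
    print("selected device:", get_device())  # cuda / mps / cpu, depending on what is available

    # Load a model by name; this downloads it into the cache if it is not there yet.
    predictor = get_sam_model(model_type="vit_b")

    # Alternatively, point get_sam_model at a specific checkpoint file.
    # predictor = get_sam_model(model_type="vit_b", checkpoint_path="/path/to/checkpoint")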