Merge pull request #327 from computational-cell-analytics/dev
Changes for new release
constantinpape authored Jan 17, 2024
2 parents e9f7689 + 102c4a4 commit 9aa64a5
Showing 131 changed files with 3,933 additions and 1,769 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
@@ -19,7 +19,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        os: [ubuntu-latest, windows-latest, macos-latest]
        python-version: ["3.10"]

    steps:
2 changes: 1 addition & 1 deletion development/benchmark.py
@@ -180,7 +180,7 @@ def main():
    args = parser.parse_args()

    model_type = args.model_type
    device = util._get_device(args.device)
    device = util.get_device(args.device)
    print("Running benchmarks for", model_type)
    print("with device:", device)

31 changes: 31 additions & 0 deletions development/seg_with_decoder.py
@@ -0,0 +1,31 @@
import imageio.v3 as imageio
import napari

from micro_sam.instance_segmentation import (
    load_instance_segmentation_with_decoder_from_checkpoint, mask_data_to_segmentation
)
from micro_sam.util import precompute_image_embeddings

# Load the instance segmentation model (SAM backbone plus additional decoder)
# from a local checkpoint.
checkpoint = "./for_decoder/best.pt"
segmenter = load_instance_segmentation_with_decoder_from_checkpoint(checkpoint, model_type="vit_b")

image_path = "/home/pape/Work/data/incu_cyte/livecell/images/livecell_train_val_images/A172_Phase_A7_1_02d00h00m_1.tif"
image = imageio.imread(image_path)

# Precompute the image embeddings and cache them in a zarr file.
embedding_path = "./for_decoder/A172_Phase_A7_1_02d00h00m_1.zarr"
image_embeddings = precompute_image_embeddings(
    segmenter._predictor, image, embedding_path,
)
# Uncomment to skip the cache (embeddings are then presumably computed on the fly).
# image_embeddings = None

print("Start segmentation ...")
segmenter.initialize(image, image_embeddings)
masks = segmenter.generate(output_mode="binary_mask")
# Merge the individual masks into a single instance segmentation (label image).
segmentation = mask_data_to_segmentation(masks, with_background=True)
print("Segmentation done")

# Inspect the image and the segmentation result in napari.
v = napari.Viewer()
v.add_image(image)
# v.add_image(segmenter._foreground)
v.add_labels(segmentation)
napari.run()
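As a small follow-up, appended to the script above, the resulting label image could also be written to disk instead of only being inspected in napari (the output path is a placeholder):

```python
import imageio.v3 as imageio

# Save the instance segmentation as a label image; uint16 supports up to
# 65535 instances. The output path is illustrative.
imageio.imwrite("./for_decoder/segmentation.tif", segmentation.astype("uint16"))
```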
36 changes: 26 additions & 10 deletions doc/finetuned_models.md
@@ -1,13 +1,14 @@
# Finetuned models

We provide models that were finetuned on microscopy data using `micro_sam.training`. They are hosted on zenodo. We currently offer the following models:
In addition to the original Segment Anything models, we provide models that were finetuned on microscopy data using the functionality from `micro_sam.training`.
The models are hosted on zenodo. We currently offer the following models:
- `vit_h`: Default Segment Anything model with vit-h backbone.
- `vit_l`: Default Segment Anything model with vit-l backbone.
- `vit_b`: Default Segment Anything model with vit-b backbone.
- `vit_h_lm`: Finetuned Segment Anything model for cells and nuclei in light microscopy data with vit-h backbone.
- `vit_t`: Segment Anything model with vit-tiny backbone. From the [MobileSAM publication](https://arxiv.org/abs/2306.14289).
- `vit_b_lm`: Finetuned Segment Anything model for cells and nuclei in light microscopy data with vit-b backbone.
- `vit_h_em`: Finetuned Segment Anything model for neurites and cells in electron microscopy data with vit-h backbone.
- `vit_b_em`: Finetuned Segment Anything model for neurites and cells in electron microscopy data with vit-b backbone.
- `vit_b_em_organelles`: Finetuned Segment Anything model for mitochondria and nuclei in electron microscopy data with vit-b backbone.
- `vit_b_em_boundaries`: Finetuned Segment Anything model for neurites and cells in electron microscopy data with vit-b backbone.

See the two figures below for the improvements achieved by the finetuned models on LM and EM data.

@@ -20,17 +21,32 @@ You can select which of the models is used in the annotation tools by selecting
<img src="https://raw.githubusercontent.com/computational-cell-analytics/micro-sam/master/doc/images/model-type-selector.png" width="256">

To use a specific model in the python library you need to pass the corresponding name as the value of the `model_type` parameter exposed by all relevant functions.
See for example the [2d annotator example](https://github.com/computational-cell-analytics/micro-sam/blob/master/examples/annotator_2d.py#L62) where `use_finetuned_model` can be set to `True` to use the `vit_h_lm` model.
See for example the [2d annotator example](https://github.com/computational-cell-analytics/micro-sam/blob/master/examples/annotator_2d.py#L62) where `use_finetuned_model` can be set to `True` to use the `vit_b_lm` model.
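For instance, a minimal sketch of selecting a finetuned model programmatically (using `micro_sam.util.get_sam_model`, the same helper the examples use; automatic download of the corresponding weights on first use is assumed):

```python
from micro_sam.util import get_sam_model

# Passing the model name as model_type selects the finetuned weights;
# they are assumed to be fetched automatically on first use.
predictor = get_sam_model(model_type="vit_b_lm")
```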

Note that we are still working on improving these models and may update them from time to time. All older models will stay available for download on zenodo; see [model sources](#model-sources) below.


## Which model should I choose?

As a rule of thumb:
- Use the `_lm` models for segmenting cells or nuclei in light microscopy.
- Use the `_em` models for segmenting cells or neurites in electron microscopy.
- Note that this model does not work well for segmenting mitochondria or other organelles because it is biased towards segmenting the full cell / cellular compartment.
- For other cases use the default models.
- Use the `vit_b_lm` model for segmenting cells or nuclei in light microscopy.
- Use the `vit_b_em_organelles` model for segmenting mitochondria, nuclei or other organelles in electron microscopy.
- Use the `vit_b_em_boundaries` model for segmenting cells or neurites in electron microscopy.
- For other use-cases use one of the default models.

See also the figures above for examples where the finetuned models work better than the vanilla models.
Currently the model `vit_h` is used by default.

We are working on releasing more fine-tuned models, in particular for mitochondria and other organelles in EM.
We are working on further improving these models and adding new models for other biomedical imaging domains.


## Model Sources

Here is an overview of all finetuned models we have released to zenodo so far:
- [vit_b_em_boundaries](https://zenodo.org/records/10524894): for segmenting compartments delineated by boundaries such as cells or neurites in EM.
- [vit_b_em_organelles](https://zenodo.org/records/10524828): for segmenting mitochondria, nuclei or other organelles in EM.
- [vit_b_lm](https://zenodo.org/records/10524791): for segmenting cells and nuclei in LM.
- [vit_h_em](https://zenodo.org/records/8250291): this model is outdated.
- [vit_h_lm](https://zenodo.org/records/8250299): this model is outdated.

Some of these models contain multiple versions.
Binary file modified doc/images/model-type-selector.png
4 changes: 2 additions & 2 deletions environment_cpu.yaml
@@ -8,13 +8,13 @@ dependencies:
  - napari
  - pip
  - pooch
  - python-xxhash
  - python-elf >=0.4.8
  - pytorch
  - segment-anything
  - torchvision
  - torch_em >=0.5.1
  - torch_em >=0.6.0
  - tqdm
  - timm
  - pip:
      - git+https://github.com/ChaoningZhang/MobileSAM.git
      # - git+https://github.com/facebookresearch/segment-anything.git
4 changes: 2 additions & 2 deletions environment_gpu.yaml
@@ -8,14 +8,14 @@ dependencies:
  - napari
  - pip
  - pooch
  - python-xxhash
  - python-elf >=0.4.8
  - pytorch
  - pytorch-cuda>=11.7  # you may need to update the cuda version to match your system
  - segment-anything
  - torchvision
  - torch_em >=0.5.1
  - torch_em >=0.6.0
  - tqdm
  - timm
  - pip:
      - git+https://github.com/ChaoningZhang/MobileSAM.git
      # - git+https://github.com/facebookresearch/segment-anything.git
40 changes: 23 additions & 17 deletions examples/annotator_2d.py
@@ -1,21 +1,28 @@
import os

import imageio.v3 as imageio
from micro_sam.sam_annotator import annotator_2d
from micro_sam.sample_data import fetch_hela_2d_example_data, fetch_livecell_example_data, fetch_wholeslide_example_data
from micro_sam.util import get_cache_directory

DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")
EMBEDDING_CACHE = os.path.join(get_cache_directory(), "embeddings")
os.makedirs(EMBEDDING_CACHE, exist_ok=True)


def livecell_annotator(use_finetuned_model):
    """Run the 2d annotator for an example image from the LiveCELL dataset.
    See https://doi.org/10.1038/s41592-021-01249-6 for details on the data.
    """
    example_data = fetch_livecell_example_data("./data")
    example_data = fetch_livecell_example_data(DATA_CACHE)
    image = imageio.imread(example_data)

    if use_finetuned_model:
        embedding_path = "./embeddings/embeddings-livecell-vit_h_lm.zarr"
        model_type = "vit_h_lm"
        embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell-vit_b_lm.zarr")
        model_type = "vit_b_lm"
    else:
        embedding_path = "./embeddings/embeddings-livecell.zarr"
        embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell.zarr")
        model_type = "vit_h"

    annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type)
@@ -24,14 +31,14 @@ def livecell_annotator(use_finetuned_model):
def hela_2d_annotator(use_finetuned_model):
    """Run the 2d annotator for an example image from the cell tracking challenge HeLa 2d dataset.
    """
    example_data = fetch_hela_2d_example_data("./data")
    example_data = fetch_hela_2d_example_data(DATA_CACHE)
    image = imageio.imread(example_data)

    if use_finetuned_model:
        embedding_path = "./embeddings/embeddings-hela2d-vit_h_lm.zarr"
        model_type = "vit_h_lm"
        embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-hela2d-vit_b_lm.zarr")
        model_type = "vit_b_lm"
    else:
        embedding_path = "./embeddings/embeddings-hela2d.zarr"
        embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-hela2d.zarr")
        model_type = "vit_h"

    annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type, precompute_amg_state=True)
@@ -43,29 +50,28 @@ def wholeslide_annotator(use_finetuned_model):
    See https://neurips22-cellseg.grand-challenge.org/ for details on the data.
    """
    example_data = fetch_wholeslide_example_data("./data")
    example_data = fetch_wholeslide_example_data(DATA_CACHE)
    image = imageio.imread(example_data)

    if use_finetuned_model:
        embedding_path = "./embeddings/whole-slide-embeddings-vit_h_lm.zarr"
        model_type = "vit_h_lm"
        embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings-vit_b_lm.zarr")
        model_type = "vit_b_lm"
    else:
        embedding_path = "./embeddings/whole-slide-embeddings.zarr"
        embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings.zarr")
        model_type = "vit_h"

    annotator_2d(image, embedding_path, tile_shape=(1024, 1024), halo=(256, 256), model_type=model_type)


def main():
    # whether to use the fine-tuned SAM model
    # this feature is still experimental!
    use_finetuned_model = False
    # Whether to use the fine-tuned SAM model for light microscopy data.
    use_finetuned_model = True

    # 2d annotator for livecell data
    # livecell_annotator(use_finetuned_model)
    livecell_annotator(use_finetuned_model)

    # 2d annotator for cell tracking challenge hela data
    hela_2d_annotator(use_finetuned_model)
    # hela_2d_annotator(use_finetuned_model)

    # 2d annotator for a whole slide image
    # wholeslide_annotator(use_finetuned_model)
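A brief note on the whole-slide example above: the `tile_shape` and `halo` arguments switch to tiled embedding computation, which keeps memory use bounded for large images. A minimal sketch of the same idea with the utility function (assuming `precompute_image_embeddings` accepts the same tiling parameters; the image path is a placeholder):

```python
import imageio.v3 as imageio
from micro_sam.util import get_sam_model, precompute_image_embeddings

predictor = get_sam_model(model_type="vit_b_lm")
image = imageio.imread("./data/whole-slide-example.tif")  # placeholder path

# Embeddings are computed tile by tile with an overlapping halo and cached
# in the zarr file, so the full image never has to fit into memory at once.
image_embeddings = precompute_image_embeddings(
    predictor, image, "./embeddings/whole-slide-embeddings.zarr",
    tile_shape=(1024, 1024), halo=(256, 256),
)
```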
38 changes: 25 additions & 13 deletions examples/annotator_3d.py
@@ -1,33 +1,45 @@
import os

from elf.io import open_file
from micro_sam.sam_annotator import annotator_3d
from micro_sam.sample_data import fetch_3d_example_data
from micro_sam.util import get_cache_directory

DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")
EMBEDDING_CACHE = os.path.join(get_cache_directory(), "embeddings")
os.makedirs(EMBEDDING_CACHE, exist_ok=True)


def em_3d_annotator(use_finetuned_model):
def em_3d_annotator(finetuned_model):
    """Run the 3d annotator for an example EM volume."""
    # download the example data
    example_data = fetch_3d_example_data("./data")
    example_data = fetch_3d_example_data(DATA_CACHE)
    # load the example data (load the sequence of tif files as 3d volume)
    with open_file(example_data) as f:
        raw = f["*.png"][:]

    if use_finetuned_model:
        embedding_path = "./embeddings/embeddings-lucchi-vit_h_em.zarr"
        model_type = "vit_h_em"
    else:
        embedding_path = "./embeddings/embeddings-lucchi.zarr"
    if not finetuned_model:
        embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-lucchi.zarr")
        model_type = "vit_h"
    else:
        assert finetuned_model in ("organelles", "boundaries")
        embedding_path = os.path.join(EMBEDDING_CACHE, f"embeddings-lucchi-vit_b_em_{finetuned_model}.zarr")
        model_type = f"vit_b_em_{finetuned_model}"
    print(embedding_path)

    # start the annotator, cache the embeddings
    annotator_3d(raw, embedding_path, model_type=model_type, show_embeddings=False)
    annotator_3d(raw, embedding_path, model_type=model_type)


def main():
    # whether to use the fine-tuned SAM model
    # this feature is still experimental!
    use_finetuned_model = False

    em_3d_annotator(use_finetuned_model)
    # Whether to use the fine-tuned SAM model for mitochondria (organelles) or boundaries.
    # valid choices are:
    # - None / False (will use the vanilla model)
    # - "organelles": will use the model for mitochondria and other organelles
    # - "boundaries": will use the model for boundary based structures
    finetuned_model = "boundaries"

    em_3d_annotator(finetuned_model)


if __name__ == "__main__":
22 changes: 14 additions & 8 deletions examples/annotator_tracking.py
@@ -1,32 +1,38 @@
import os

from elf.io import open_file
from micro_sam.sam_annotator import annotator_tracking
from micro_sam.sample_data import fetch_tracking_example_data
from micro_sam.util import get_cache_directory

DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")
EMBEDDING_CACHE = os.path.join(get_cache_directory(), "embeddings")
os.makedirs(EMBEDDING_CACHE, exist_ok=True)


def track_ctc_data(use_finetuned_model):
    """Run interactive tracking for data from the cell tracking challenge.
    """
    # download the example data
    example_data = fetch_tracking_example_data("./data")
    example_data = fetch_tracking_example_data(DATA_CACHE)
    # load the example data (load the sequence of tif files as timeseries)
    with open_file(example_data, mode="r") as f:
        timeseries = f["*.tif"]

    if use_finetuned_model:
        embedding_path = "./embeddings/embeddings-ctc-vit_h_lm.zarr"
        model_type = "vit_h_lm"
        embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc-vit_b_lm.zarr")
        model_type = "vit_b_lm"
    else:
        embedding_path = "./embeddings/embeddings-ctc.zarr"
        embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc.zarr")
        model_type = "vit_h"

    # start the annotator with cached embeddings
    annotator_tracking(timeseries, embedding_path=embedding_path, show_embeddings=False, model_type=model_type)
    annotator_tracking(timeseries, embedding_path=embedding_path, model_type=model_type)


def main():
    # whether to use the fine-tuned SAM model
    # this feature is still experimental!
    use_finetuned_model = False
    # Whether to use the fine-tuned SAM model.
    use_finetuned_model = True
    track_ctc_data(use_finetuned_model)


21 changes: 19 additions & 2 deletions examples/annotator_with_custom_model.py
@@ -1,8 +1,24 @@
import os

import imageio
import h5py
import micro_sam.sam_annotator as annotator

from micro_sam.util import get_sam_model
from micro_sam.util import get_cache_directory
from micro_sam.sample_data import fetch_hela_2d_example_data


DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")


def annotator_2d_with_custom_model():
    example_data = fetch_hela_2d_example_data(DATA_CACHE)
    image = imageio.imread(example_data)

    # TODO add an example for the 2d annotator with a custom model
    custom_model = "/home/pape/Downloads/exported_models/vit_b_lm.pth"
    predictor = get_sam_model(checkpoint_path=custom_model, model_type="vit_b")
    annotator.annotator_2d(image, predictor=predictor)


def annotator_3d_with_custom_model():
@@ -16,7 +32,8 @@ def annotator_3d_with_custom_model():


def main():
    annotator_3d_with_custom_model()
    annotator_2d_with_custom_model()
    # annotator_3d_with_custom_model()


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/finetuning/finetune_hela.py
@@ -58,7 +58,7 @@ def get_dataloader(split, patch_shape, batch_size):
        patch_shape=patch_shape, batch_size=batch_size,
        ndim=2, is_seg_dataset=True, rois=roi,
        label_transform=torch_em.transform.label.connected_components,
        num_workers=8, shuffle=True,
        num_workers=8, shuffle=True, raw_transform=sam_training.identity,
    )
    return loader
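One detail worth noting in this change: `raw_transform=sam_training.identity` presumably disables the dataloader's default raw normalization, since the SAM training code applies its own preprocessing. A minimal sketch of what such an identity transform amounts to (the actual implementation lives in `micro_sam.training`):

```python
def identity(raw):
    # Return the raw data unchanged, so no normalization happens
    # before the model's own preprocessing.
    return raw
```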
