Update to new finetuned models (#326)
constantinpape authored Jan 17, 2024

1 parent 3fda1d9 commit 102c4a4
Showing 10 changed files with 157 additions and 60 deletions.
36 changes: 26 additions & 10 deletions doc/finetuned_models.md
@@ -1,13 +1,14 @@
# Finetuned models

We provide models that were finetuned on microscopy data using `micro_sam.training`. They are hosted on zenodo. We currently offer the following models:
In addition to the original Segment Anything models, we provide models that were finetuned on microscopy data using the functionality from `micro_sam.training`.
The models are hosted on zenodo. We currently offer the following models:
- `vit_h`: Default Segment Anything model with vit-h backbone.
- `vit_l`: Default Segment Anything model with vit-l backbone.
- `vit_b`: Default Segment Anything model with vit-b backbone.
- `vit_h_lm`: Finetuned Segment Anything model for cells and nuclei in light microscopy data with vit-h backbone.
- `vit_t`: Segment Anything model with vit-tiny backbone. From the [MobileSAM publication](https://arxiv.org/abs/2306.14289).
- `vit_b_lm`: Finetuned Segment Anything model for cells and nuclei in light microscopy data with vit-b backbone.
- `vit_h_em`: Finetuned Segment Anything model for neurites and cells in electron microscopy data with vit-h backbone.
- `vit_b_em`: Finetuned Segment Anything model for neurites and cells in electron microscopy data with vit-b backbone.
- `vit_b_em_organelles`: Finetuned Segment Anything model for mitochondria and nuclei in electron microscopy data with vit-b backbone.
- `vit_b_em_boundaries`: Finetuned Segment Anything model for neurites and cells in electron microscopy data with vit-b backbone.

See the two figures below for the improvements achieved by the finetuned models on LM and EM data.

@@ -20,17 +21,32 @@ You can select which of the models is used in the annotation tools by selecting
<img src="https://raw.githubusercontent.com/computational-cell-analytics/micro-sam/master/doc/images/model-type-selector.png" width="256">

To use a specific model in the Python library you need to pass the corresponding name as the value of the `model_type` parameter exposed by all relevant functions.
See for example the [2d annotator example](https://github.com/computational-cell-analytics/micro-sam/blob/master/examples/annotator_2d.py#L62) where `use_finetuned_model` can be set to `True` to use the `vit_h_lm` model.
See for example the [2d annotator example](https://github.com/computational-cell-analytics/micro-sam/blob/master/examples/annotator_2d.py#L62) where `use_finetuned_model` can be set to `True` to use the `vit_b_lm` model.
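For instance, here is a minimal sketch of starting the 2d annotator with one of the finetuned models (assuming `image` is a 2d numpy array you have already loaded; the keyword arguments follow the example scripts updated in this commit):

```python
from micro_sam.sam_annotator import annotator_2d

# Select the finetuned light microscopy model by passing its name as `model_type`.
# Any other model name from the list above works the same way.
annotator_2d(image, embedding_path="embeddings-lm.zarr", model_type="vit_b_lm")
```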

Note that we are still working on improving these models and may update them from time to time. All older models will stay available for download on zenodo, see [model sources](#model-sources) below.


## Which model should I choose?

As a rule of thumb:
- Use the `_lm` models for segmenting cells or nuclei in light microscopy.
- Use the `_em` models for segmenting cells or neurites in electron microscopy.
- Note that this model does not work well for segmenting mitochondria or other organelles because it is biased towards segmenting the full cell / cellular compartment.
- For other cases use the default models.
- Use the `vit_b_lm` model for segmenting cells or nuclei in light microscopy.
- Use the `vit_b_em_organelles` model for segmenting mitochondria, nuclei or other organelles in electron microscopy.
- Use the `vit_b_em_boundaries` model for segmenting cells or neurites in electron microscopy.
- For other use cases, use one of the default models.

See also the figures above for examples where the finetuned models work better than the vanilla models.
Currently the model `vit_h` is used by default.

We are working on releasing more fine-tuned models, in particular for mitochondria and other organelles in EM.
We are working on further improving these models and adding new models for other biomedical imaging domains.


## Model Sources

Here is an overview of all finetuned models we have released to zenodo so far:
- [vit_b_em_boundaries](https://zenodo.org/records/10524894): for segmenting compartments delineated by boundaries such as cells or neurites in EM.
- [vit_b_em_organelles](https://zenodo.org/records/10524828): for segmenting mitochondria, nuclei or other organelles in EM.
- [vit_b_lm](https://zenodo.org/records/10524791): for segmenting cells and nuclei in LM.
- [vit_h_em](https://zenodo.org/records/8250291): this model is outdated.
- [vit_h_lm](https://zenodo.org/records/8250299): this model is outdated.

Some of these models contain multiple versions.
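If you want to fetch one of these checkpoints directly, outside of the `micro_sam` model registry, here is a minimal sketch with `pooch` (URL and sha256 checksum taken from the registry update in `micro_sam/util.py` below):

```python
import pooch

# Download the vit_b_lm checkpoint from zenodo and verify it against
# the sha256 checksum registered in micro_sam/util.py.
checkpoint_path = pooch.retrieve(
    url="https://zenodo.org/records/10524791/files/vit_b_lm.pth?download=1",
    known_hash="sha256:e8f5feb1ad837a7507935409c7f83f7c8af11c6e39cfe3df03f8d3bd4a358449",
    fname="vit_b_lm.pth",
)
print(checkpoint_path)  # local path of the verified download
```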
Binary file modified doc/images/model-type-selector.png
21 changes: 10 additions & 11 deletions examples/annotator_2d.py
@@ -19,8 +19,8 @@ def livecell_annotator(use_finetuned_model):
image = imageio.imread(example_data)

if use_finetuned_model:
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell-vit_h_lm.zarr")
model_type = "vit_h_lm"
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell-vit_b_lm.zarr")
model_type = "vit_b_lm"
else:
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell.zarr")
model_type = "vit_h"
@@ -35,8 +35,8 @@ def hela_2d_annotator(use_finetuned_model):
image = imageio.imread(example_data)

if use_finetuned_model:
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-hela2d-vit_h_lm.zarr")
model_type = "vit_h_lm"
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-hela2d-vit_b_lm.zarr")
model_type = "vit_b_lm"
else:
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-hela2d.zarr")
model_type = "vit_h"
@@ -54,8 +54,8 @@ def wholeslide_annotator(use_finetuned_model):
image = imageio.imread(example_data)

if use_finetuned_model:
embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings-vit_h_lm.zarr")
model_type = "vit_h_lm"
embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings-vit_b_lm.zarr")
model_type = "vit_b_lm"
else:
embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings.zarr")
model_type = "vit_h"
@@ -64,15 +64,14 @@ def wholeslide_annotator(use_finetuned_model):


def main():
# whether to use the fine-tuned SAM model
# this feature is still experimental!
use_finetuned_model = False
# Whether to use the fine-tuned SAM model for light microscopy data.
use_finetuned_model = True

# 2d annotator for livecell data
# livecell_annotator(use_finetuned_model)
livecell_annotator(use_finetuned_model)

# 2d annotator for cell tracking challenge hela data
hela_2d_annotator(use_finetuned_model)
# hela_2d_annotator(use_finetuned_model)

# 2d annotator for a whole slide image
# wholeslide_annotator(use_finetuned_model)
27 changes: 16 additions & 11 deletions examples/annotator_3d.py
@@ -10,31 +10,36 @@
os.makedirs(EMBEDDING_CACHE, exist_ok=True)


def em_3d_annotator(use_finetuned_model):
def em_3d_annotator(finetuned_model):
"""Run the 3d annotator for an example EM volume."""
# download the example data
example_data = fetch_3d_example_data(DATA_CACHE)
# load the example data (load the sequence of png files as a 3d volume)
with open_file(example_data) as f:
raw = f["*.png"][:]

if use_finetuned_model:
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-lucchi-vit_h_em.zarr")
model_type = "vit_h_em"
else:
if not finetuned_model:
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-lucchi.zarr")
model_type = "vit_h"
else:
assert finetuned_model in ("organelles", "boundaries")
embedding_path = os.path.join(EMBEDDING_CACHE, f"embeddings-lucchi-vit_b_em_{finetuned_model}.zarr")
model_type = f"vit_b_em_{finetuned_model}"
print(embedding_path)

# start the annotator, cache the embeddings
annotator_3d(raw, embedding_path, model_type=model_type, show_embeddings=False)
annotator_3d(raw, embedding_path, model_type=model_type)


def main():
# whether to use the fine-tuned SAM model
# this feature is still experimental!
use_finetuned_model = False

em_3d_annotator(use_finetuned_model)
# Whether to use the fine-tuned SAM model for mitochondria (organelles) or boundaries.
# valid choices are:
# - None / False (will use the vanilla model)
# - "organelles": will use the model for mitochondria and other organelles
# - "boundaries": will use the model for boundary based structures
finetuned_model = "boundaries"

em_3d_annotator(finetuned_model)


if __name__ == "__main__":
11 changes: 5 additions & 6 deletions examples/annotator_tracking.py
@@ -20,20 +20,19 @@ def track_ctc_data(use_finetuned_model):
timeseries = f["*.tif"]

if use_finetuned_model:
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc-vit_h_lm.zarr")
model_type = "vit_h_lm"
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc-vit_b_lm.zarr")
model_type = "vit_b_lm"
else:
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc.zarr")
model_type = "vit_h"

# start the annotator with cached embeddings
annotator_tracking(timeseries, embedding_path=embedding_path, show_embeddings=False, model_type=model_type)
annotator_tracking(timeseries, embedding_path=embedding_path, model_type=model_type)


def main():
# whether to use the fine-tuned SAM model
# this feature is still experimental!
use_finetuned_model = False
# Whether to use the fine-tuned SAM model.
use_finetuned_model = True
track_ctc_data(use_finetuned_model)


21 changes: 19 additions & 2 deletions examples/annotator_with_custom_model.py
@@ -1,8 +1,24 @@
import os

import imageio
import h5py
import micro_sam.sam_annotator as annotator

from micro_sam.util import get_sam_model
from micro_sam.util import get_cache_directory
from micro_sam.sample_data import fetch_hela_2d_example_data


DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")


def annotator_2d_with_custom_model():
example_data = fetch_hela_2d_example_data(DATA_CACHE)
image = imageio.imread(example_data)

# TODO add an example for the 2d annotator with a custom model
custom_model = "/home/pape/Downloads/exported_models/vit_b_lm.pth"
predictor = get_sam_model(checkpoint_path=custom_model, model_type="vit_b")
annotator.annotator_2d(image, predictor=predictor)


def annotator_3d_with_custom_model():
@@ -16,7 +32,8 @@ def annotator_3d_with_custom_model():


def main():
annotator_3d_with_custom_model()
annotator_2d_with_custom_model()
# annotator_3d_with_custom_model()


if __name__ == "__main__":
7 changes: 3 additions & 4 deletions examples/image_series_annotator.py
@@ -14,8 +14,8 @@ def series_annotation(use_finetuned_model):
"""

if use_finetuned_model:
embedding_path = os.path.join(EMBEDDING_CACHE, "series-embeddings-vit_h_lm")
model_type = "vit_h_lm"
embedding_path = os.path.join(EMBEDDING_CACHE, "series-embeddings-vit_b_lm")
model_type = "vit_b_lm"
else:
embedding_path = os.path.join(EMBEDDING_CACHE, "series-embeddings")
model_type = "vit_h"
@@ -29,8 +29,7 @@ def series_annotation(use_finetuned_model):


def main():
# whether to use the fine-tuned SAM model
# this feature is still experimental!
# Whether to use the fine-tuned SAM model.
use_finetuned_model = False
series_annotation(use_finetuned_model)

30 changes: 14 additions & 16 deletions micro_sam/util.py
@@ -60,24 +60,25 @@ def get_cache_directory() -> None:
Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
"""
default_cache_directory = os.path.expanduser(pooch.os_cache('micro_sam'))
cache_directory = Path(os.environ.get('MICROSAM_CACHEDIR', default_cache_directory))
default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam"))
cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory))
return cache_directory


#
# Functionality for model download and export
#

def microsam_cachedir():

def microsam_cachedir() -> None:
"""Return the micro-sam cache directory.
Returns the top level cache directory for micro-sam models and sample data.
Every time this function is called, we check for any user updates made to
the MICROSAM_CACHEDIR os environment variable since the last time.
"""
cache_directory = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam')
cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")
return cache_directory
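
A minimal sketch of the override described in the docstring above (the default path and the re-read behavior are assumed from this diff alone):

```python
import os
from micro_sam.util import microsam_cachedir

print(microsam_cachedir())  # default cache, e.g. ~/.cache/micro_sam on Linux

# MICROSAM_CACHEDIR is re-read on every call, so setting it at runtime
# redirects the cache for all subsequent calls.
os.environ["MICROSAM_CACHEDIR"] = "/tmp/micro_sam_cache"
print(microsam_cachedir())  # now /tmp/micro_sam_cache
```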


@@ -106,10 +107,9 @@ def models():
# the model with vit tiny backbone from https://github.com/ChaoningZhang/MobileSAM
"vit_t": "sha256:6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f",
# first version of finetuned models on zenodo
"vit_h_lm": "sha256:9a65ee0cddc05a98d60469a12a058859c89dc3ea3ba39fed9b90d786253fbf26",
"vit_b_lm": "sha256:5a59cc4064092d54cd4d92cd967e39168f3760905431e868e474d60fe5464ecd",
"vit_h_em": "sha256:ae3798a0646c8df1d4db147998a2d37e402ff57d3aa4e571792fbb911d8a979c",
"vit_b_em": "sha256:c04a714a4e14a110f0eec055a65f7409d54e6bf733164d2933a0ce556f7d6f81",
"vit_b_lm": "sha256:e8f5feb1ad837a7507935409c7f83f7c8af11c6e39cfe3df03f8d3bd4a358449",
"vit_b_em_organelles": "sha256:8fabbe38a427a0c91bbe6518a5c0f103f36b73e6ee6c86fbacd32b4fc66294b4",
"vit_b_em_boundaries": "sha256:d87348b2adef30ab427fb787d458643300eb30624a0e808bf36af21764705f4f",
}
registry_xxh128 = {
# the default segment anything models
@@ -119,10 +119,9 @@ def models():
# the model with vit tiny backbone from https://github.com/ChaoningZhang/MobileSAM
"vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
# first version of finetuned models on zenodo
"vit_h_lm": "xxh128:e113adac6a0a21514bb2d73de16b921b",
"vit_b_lm": "xxh128:5fc0851abf8a209dcbed4e95634d9e27",
"vit_h_em": "xxh128:64b6eb2d32ac9c5d9b022b1ac57f1cc6",
"vit_b_em": "xxh128:f50d499db5bf54dc9849c3dbd271d5c9",
"vit_b_lm": "xxh128:6b061eb8684d9d5f55545330d6dce50d",
"vit_b_em_organelles": "xxh128:3919c2b761beba7d3f4ece342c9f5369",
"vit_b_em_boundaries": "xxh128:3099fe6339f5be91ca84db889db1909f",
}

models = pooch.create(
@@ -138,10 +137,9 @@ def models():
# the model with vit tiny backbone from https://github.com/ChaoningZhang/MobileSAM
"vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
# first version of finetuned models on zenodo
"vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1",
"vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1",
"vit_h_em": "https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1",
"vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1",
"vit_b_lm": "https://zenodo.org/records/10524791/files/vit_b_lm.pth?download=1",
"vit_b_em_organelles": "https://zenodo.org/records/10524828/files/vit_b_em_organelles.pth?download=1",
"vit_b_em_boundaries": "https://zenodo.org/records/10524894/files/vit_b_em_boundaries.pth?download=1",
},
)
return models
2 changes: 2 additions & 0 deletions scripts/.gitignore
@@ -0,0 +1,2 @@
new_models/
exported_models/
62 changes: 62 additions & 0 deletions scripts/export_models_for_upload.py
@@ -0,0 +1,62 @@
"""Helper scripts to export models for upload to zenodo.
"""

import hashlib
import os
import warnings
from glob import glob

import xxhash
from micro_sam.util import export_custom_sam_model

BUF_SIZE = 65536  # let's read the file in 64kb chunks!


def export_model(model_path, model_type, export_name):
output_folder = "./exported_models"
os.makedirs(output_folder, exist_ok=True)

output_path = os.path.join(output_folder, export_name)
if os.path.exists(output_path):
print("The model", export_name, "has already been exported.")
return

with warnings.catch_warnings():
warnings.simplefilter("ignore")
export_custom_sam_model(
checkpoint_path=model_path,
model_type=model_type,
save_path=output_path,
)

print("Exported", export_name)

sha_checksum = hashlib.sha256()
xxh_checksum = xxhash.xxh128()

with open(output_path, "rb") as f:
while True:
data = f.read(BUF_SIZE)
if not data:
break
sha_checksum.update(data)
xxh_checksum.update(data)

print("sha256:", f"sha256:{sha_checksum.hexdigest()}")
print("xxh128:", f"xxh128:{xxh_checksum.hexdigest()}")


def export_all_models():
models = glob(os.path.join("./new_models", "*.pt"))
model_type = "vit_b"
for model_path in models:
export_name = os.path.basename(model_path).replace(".pt", ".pth")
export_model(model_path, model_type, export_name)


def main():
export_all_models()


if __name__ == "__main__":
main()
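
Usage note, inferred from the glob pattern and output folder above: place trained checkpoints as `.pt` files in `./new_models/`, run the script, and the exported `.pth` files are written to `./exported_models/` along with printed sha256 and xxh128 checksums for the zenodo upload and the registry in `micro_sam/util.py`.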
