Skip to content

Commit

Permalink
Merge branch 'main' into verify-norm
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Jul 16, 2024
2 parents be32576 + 3b79b91 commit 1e09cb0
Show file tree
Hide file tree
Showing 52 changed files with 2,458 additions and 120 deletions.
9 changes: 7 additions & 2 deletions experiments/unet-segmentation/dsb/train_boundaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,16 @@ def train_boundaries(args):

patch_shape = (1, 256, 256)
train_loader = get_dsb_loader(
args.input, patch_shape, split="train",
args.input, patch_shape=patch_shape, split="train",
download=True, boundaries=True, batch_size=args.batch_size
)

# Uncomment this for checking the loader.
# from torch_em.util.debug import check_loader
# check_loader(train_loader, 4)

val_loader = get_dsb_loader(
args.input, patch_shape, split="test",
args.input, patch_shape=patch_shape, split="test",
boundaries=True, batch_size=args.batch_size
)
loss = torch_em.loss.DiceLoss()
Expand Down
23 changes: 23 additions & 0 deletions scripts/datasets/medical/check_acdc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from torch_em.util.debug import check_loader
from torch_em.data.datasets.medical import get_acdc_loader
from torch_em.data import MinInstanceSampler


ROOT = "/media/anwai/ANWAI/data/acdc"


def check_acdc():
loader = get_acdc_loader(
path=ROOT,
patch_shape=(4, 256, 256),
batch_size=2,
split="train",
download=True,
sampler=MinInstanceSampler(min_num_instances=4),
)

check_loader(loader, 8)


if __name__ == "__main__":
check_acdc()
24 changes: 24 additions & 0 deletions scripts/datasets/medical/check_amos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from torch_em.util.debug import check_loader
from torch_em.data import MinInstanceSampler
from torch_em.data.datasets.medical import get_amos_loader

ROOT = "/media/anwai/ANWAI/data/amos"


def check_amos():
loader = get_amos_loader(
path=ROOT,
split="train",
patch_shape=(1, 512, 512),
modality="mri",
ndim=2,
batch_size=2,
download=True,
sampler=MinInstanceSampler(min_num_instances=3),
resize_inputs=False,
)
check_loader(loader, 8)


if __name__ == "__main__":
check_amos()
24 changes: 24 additions & 0 deletions scripts/datasets/medical/check_cbis_ddsm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from torch_em.data import MinInstanceSampler
from torch_em.util.debug import check_loader
from torch_em.data.datasets.medical import get_cbis_ddsm_loader


ROOT = "/media/anwai/ANWAI/data/cbis_ddsm"


def check_cbis_ddsm():
loader = get_cbis_ddsm_loader(
path=ROOT,
patch_shape=(512, 512),
batch_size=2,
split="Train",
task=None,
tumour_type=None,
resize_inputs=True,
sampler=MinInstanceSampler()
)
check_loader(loader, 8)


if __name__ == "__main__":
check_cbis_ddsm()
21 changes: 21 additions & 0 deletions scripts/datasets/medical/check_cholecseg8k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from torch_em.util.debug import check_loader
from torch_em.data.datasets.medical import get_cholecseg8k_loader


ROOT = "/media/anwai/ANWAI/data/cholecseg8k"


def get_cholecseg8k():
loader = get_cholecseg8k_loader(
path=ROOT,
patch_shape=(512, 512),
batch_size=2,
split="train",
resize_inputs=True,
download=False,
)
check_loader(loader, 8)


if __name__ == "__main__":
get_cholecseg8k()
23 changes: 23 additions & 0 deletions scripts/datasets/medical/check_covid19_seg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from torch_em.util.debug import check_loader
from torch_em.data import MinInstanceSampler
from torch_em.data.datasets.medical import get_covid19_seg_loader


ROOT = "/media/anwai/ANWAI/data/covid19_seg"


def check_covid19_seg():
loader = get_covid19_seg_loader(
path=ROOT,
patch_shape=(32, 512, 512),
batch_size=2,
task="lung",
download=True,
sampler=MinInstanceSampler(),
)

check_loader(loader, 8)


if __name__ == "__main__":
check_covid19_seg()
21 changes: 21 additions & 0 deletions scripts/datasets/medical/check_dca1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from torch_em.util.debug import check_loader
from torch_em.data.datasets.medical import get_dca1_loader


ROOT = "/media/anwai/ANWAI/data/dca1"


def check_dca1():
loader = get_dca1_loader(
path=ROOT,
patch_shape=(512, 512),
batch_size=2,
split="test",
resize_inputs=True,
download=False,
)
check_loader(loader, 8)


if __name__ == "__main__":
check_dca1()
23 changes: 23 additions & 0 deletions scripts/datasets/medical/check_duke_liver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from torch_em.util.debug import check_loader
from torch_em.data.datasets.medical import get_duke_liver_loader


ROOT = "/media/anwai/ANWAI/data/duke_liver"


def check_duke_liver():
from micro_sam.training import identity
loader = get_duke_liver_loader(
path=ROOT,
patch_shape=(32, 512, 512),
batch_size=2,
split="train",
download=False,
raw_transform=identity,

)
check_loader(loader, 8)


if __name__ == "__main__":
check_duke_liver()
20 changes: 20 additions & 0 deletions scripts/datasets/medical/check_han_seg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from torch_em.util.debug import check_loader
from torch_em.data.datasets.medical import get_han_seg_loader


ROOT = "/media/anwai/ANWAI/data/han-seg/"


def check_han_seg():
loader = get_han_seg_loader(
path=ROOT,
patch_shape=(32, 512, 512),
batch_size=2,
download=False,
)

check_loader(loader, 8)


if __name__ == "__main__":
check_han_seg()
22 changes: 22 additions & 0 deletions scripts/datasets/medical/check_isic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from torch_em.util.debug import check_loader
from torch_em.data.datasets.medical import get_isic_loader


ROOT = "/scratch/share/cidas/cca/data/isic"


def check_isic():
loader = get_isic_loader(
path=ROOT,
patch_shape=(700, 700),
batch_size=2,
split="test",
download=True,
resize_inputs=True,
)

check_loader(loader, 8, plt=True, save_path="./isic.png")


if __name__ == "__main__":
check_isic()
21 changes: 21 additions & 0 deletions scripts/datasets/medical/check_m2caiseg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from torch_em.util.debug import check_loader
from torch_em.data.datasets.medical import get_m2caiseg_loader


ROOT = "/media/anwai/ANWAI/data/m2caiseg"


def check_m2caiseg():
loader = get_m2caiseg_loader(
path=ROOT,
split="train",
patch_shape=(512, 512),
batch_size=2,
resize_inputs=True,
download=True,
)
check_loader(loader, 8)


if __name__ == "__main__":
check_m2caiseg()
5 changes: 3 additions & 2 deletions scripts/datasets/medical/check_oimhs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@
from torch_em.data.datasets.medical import get_oimhs_loader


ROOT = "/media/anwai/ANWAI/data/oimhs"
ROOT = "/scratch/share/cidas/cca/data/oimhs"


def check_oimhs():
loader = get_oimhs_loader(
path=ROOT,
patch_shape=(512, 512),
batch_size=2,
split="test",
download=False,
resize_inputs=True,
)

check_loader(loader, 8)
check_loader(loader, 8, plt=True, save_path="./oimhs.png")


if __name__ == "__main__":
Expand Down
19 changes: 1 addition & 18 deletions scripts/datasets/medical/check_osic_pulmofib.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,11 @@ def check_osic_pulmofib():
patch_shape=(1, 512, 512),
batch_size=2,
resize_inputs=False,
download=False,
download=True,
)

check_loader(loader, 8)


def visualize_data():
import os
from glob import glob

import nrrd
import napari

all_volume_paths = sorted(glob(os.path.join(ROOT, "nrrd_heart", "*", "*")))
for vol_path in all_volume_paths:
vol, header = nrrd.read(vol_path)

v = napari.Viewer()
v.add_image(vol.transpose(2, 0, 1))
napari.run()


if __name__ == "__main__":
# visualize_data()
check_osic_pulmofib()
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch_em.data.datasets.medical import get_papila_loader


ROOT = "/media/anwai/ANWAI/data/papila"
ROOT = "/scratch/share/cidas/cca/data/papila"


def check_papila():
Expand All @@ -16,7 +16,7 @@ def check_papila():
download=True,
)

check_loader(loader, 8)
check_loader(loader, 8, plt=True, save_path="./papila.png")


if __name__ == "__main__":
Expand Down
20 changes: 20 additions & 0 deletions scripts/datasets/medical/check_piccolo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from torch_em.util.debug import check_loader
from torch_em.data.datasets.medical import get_piccolo_loader


ROOT = "/media/anwai/ANWAI/data/piccolo"


def check_piccolo():
loader = get_piccolo_loader(
path=ROOT,
patch_shape=(512, 512),
batch_size=2,
split="train",
resize_inputs=True,
)
check_loader(loader, 8)


if __name__ == "__main__":
check_piccolo()
4 changes: 1 addition & 3 deletions scripts/datasets/medical/check_sega.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
def check_sega():
loader = get_sega_loader(
path=ROOT,
patch_shape=(1, 512, 512),
patch_shape=(32, 512, 512),
batch_size=2,
ndim=2,
data_choice="KiTS",
resize_inputs=True,
download=True,
sampler=MinInstanceSampler(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch_em.data.datasets.medical import get_siim_acr_loader


ROOT = "/media/anwai/ANWAI/data/siim_acr"
ROOT = "/scratch/share/cidas/cca/data/siim_acr"


def check_siim_acr():
Expand All @@ -13,10 +13,10 @@ def check_siim_acr():
patch_shape=(512, 512),
batch_size=2,
download=True,
resize_inputs=False,
resize_inputs=True,
sampler=MinInstanceSampler()
)
check_loader(loader, 8)
check_loader(loader, 8, plt=True, save_path="./siim_acr.png")


if __name__ == "__main__":
Expand Down
20 changes: 20 additions & 0 deletions scripts/datasets/medical/check_spider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from torch_em.util.debug import check_loader
from torch_em.data import MinInstanceSampler
from torch_em.data.datasets.medical import get_spider_loader


ROOT = "/media/anwai/ANWAI/data/spider"


def check_spider():
loader = get_spider_loader(
path=ROOT,
patch_shape=(1, 512, 512),
batch_size=2,
sampler=MinInstanceSampler()
)

check_loader(loader, 8)


check_spider()
Loading

0 comments on commit 1e09cb0

Please sign in to comment.