Skip to content

Commit

Permalink
adding docstrings and denoise tests
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Feb 21, 2024
1 parent fdd2b68 commit aeb8bc8
Show file tree
Hide file tree
Showing 14 changed files with 1,630 additions and 1,425 deletions.
267 changes: 126 additions & 141 deletions cellpose/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,56 +11,47 @@
import cv2
from scipy.stats import mode
import fastremap
from . import transforms, dynamics, utils, plot, metrics
from . import transforms, dynamics, utils, plot, metrics, resnet_torch

import torch
# from GPUtil import showUtilization as gpu_usage #for gpu memory debugging
from torch import nn
from torch.utils import mkldnn as mkldnn_utils
from . import resnet_torch

TORCH_ENABLED = True

core_logger = logging.getLogger(__name__)
tqdm_out = utils.TqdmToLogger(core_logger, level=logging.INFO)

def use_gpu(gpu_number=0, use_torch=True):
"""
Check if GPU is available for use.
def parse_model_string(pretrained_model):
if isinstance(pretrained_model, list):
model_str = os.path.split(pretrained_model[0])[-1]
else:
model_str = os.path.split(pretrained_model)[-1]
if len(model_str) > 3 and model_str[:4] == "unet":
cp = False
nclasses = max(2, int(model_str[4]))
elif len(model_str) > 7 and model_str[:8] == "cellpose":
cp = True
nclasses = 3
else:
return 3, True, True, False

if "residual" in model_str and "style" in model_str and "concatentation" in model_str:
ostrs = model_str.split("_")[2::2]
residual_on = ostrs[0] == "on"
style_on = ostrs[1] == "on"
concatenation = ostrs[2] == "on"
return nclasses, residual_on, style_on, concatenation
else:
if cp:
return 3, True, True, False
else:
return nclasses, False, False, True
Args:
gpu_number (int): The index of the GPU to be used. Default is 0.
use_torch (bool): Whether to use PyTorch for GPU check. Default is True.
Returns:
bool: True if GPU is available, False otherwise.
def use_gpu(gpu_number=0, use_torch=True):
""" check if gpu works """
Raises:
ValueError: If use_torch is False, as cellpose only runs with PyTorch now.
"""
if use_torch:
return _use_gpu_torch(gpu_number)
else:
raise ValueError("cellpose only runs with pytorch now")
raise ValueError("cellpose only runs with PyTorch now")


def _use_gpu_torch(gpu_number=0):
"""
Checks if CUDA is available and working with PyTorch.
Args:
gpu_number (int): The GPU device number to use (default is 0).
Returns:
bool: True if CUDA is available and working, False otherwise.
"""
try:
device = torch.device("cuda:" + str(gpu_number))
_ = torch.zeros([1, 2, 3]).to(device)
Expand All @@ -72,6 +63,18 @@ def _use_gpu_torch(gpu_number=0):


def assign_device(use_torch=True, gpu=False, device=0):
"""
Assigns the device (CPU or GPU or mps) to be used for computation.
Args:
use_torch (bool, optional): Whether to use torch for GPU detection. Defaults to True.
gpu (bool, optional): Whether to use GPU for computation. Defaults to False.
device (int or str, optional): The device index or name to be used. Defaults to 0.
Returns:
torch.device: The assigned device.
bool: True if GPU is used, False otherwise.
"""
mac = False
cpu = True
if isinstance(device, str):
Expand Down Expand Up @@ -102,11 +105,18 @@ def assign_device(use_torch=True, gpu=False, device=0):


def check_mkl(use_torch=True):
#core_logger.info("Running test snippet to check if MKL-DNN working")
"""
Checks if MKL-DNN is enabled and working.
Args:
use_torch (bool, optional): Whether to use torch. Defaults to True.
Returns:
bool: True if MKL-DNN is enabled, False otherwise.
"""
mkl_enabled = torch.backends.mkldnn.is_available()
if mkl_enabled:
mkl_enabled = True
#core_logger.info("MKL version working - CPU version is sped up.")
else:
core_logger.info(
"WARNING: MKL version on torch not working/installed - CPU version will be slightly slower."
Expand All @@ -117,6 +127,16 @@ def check_mkl(use_torch=True):


def _to_device(x, device):
"""
Converts the input tensor or numpy array to the specified device.
Args:
x (torch.Tensor or numpy.ndarray): The input tensor or numpy array.
device (torch.device): The target device.
Returns:
torch.Tensor: The converted tensor on the specified device.
"""
if not isinstance(x, torch.Tensor):
X = torch.from_numpy(x).float().to(device)
return X
Expand All @@ -125,12 +145,29 @@ def _to_device(x, device):


def _from_device(X):
"""
Converts a PyTorch tensor from the device to a NumPy array on the CPU.
Args:
X (torch.Tensor): The input PyTorch tensor.
Returns:
numpy.ndarray: The converted NumPy array.
"""
x = X.detach().cpu().numpy()
return x


def _forward(net, x):
""" convert imgs to torch and run network model and return numpy """
"""Converts images to torch tensors, runs the network model, and returns numpy arrays.
Args:
net (torch.nn.Module): The network model.
x (numpy.ndarray): The input images.
Returns:
Tuple[numpy.ndarray, numpy.ndarray]: The output predictions (flows and cellprob) and style features.
"""
X = _to_device(x, net.device)
net.eval()
if net.mkldnn:
Expand All @@ -143,42 +180,26 @@ def _forward(net, x):
return y, style


def run_net(net, imgs, augment=False, tile=True, tile_overlap=0.1, bsize=224):
""" run network on image or stack of images
def run_net(net, imgs, batch_size=8, augment=False, tile=True, tile_overlap=0.1,
bsize=224):
"""
Run network on image or stack of images.
(faster if augment is False)
Parameters
--------------
imgs: array [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan]
rsz: float (optional, default 1.0)
resize coefficient(s) for image
augment: bool (optional, default False)
tiles image with overlapping tiles and flips overlapped regions to augment
tile: bool (optional, default True)
tiles image to ensure GPU/CPU memory usage limited (recommended);
cannot be turned off for 3D segmentation
tile_overlap: float (optional, default 0.1)
fraction of overlap of tiles when computing flows
bsize: int (optional, default 224)
size of tiles to use in pixels [bsize x bsize]
Returns
------------------
y: array [Ly x Lx x 3] or [Lz x Ly x Lx x 3]
y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability
style: array [64]
1D array summarizing the style of the image,
if tiled it is averaged over tiles
Args:
imgs (np.ndarray): The input image or stack of images of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan].
batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0.
augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
tile (bool, optional): Tiles image to ensure GPU/CPU memory usage limited (recommended); cannot be turned off for 3D segmentation. Defaults to True.
tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
Returns:
y (np.ndarray): output of network, if tiled it is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
style (np.ndarray): 1D array of size 256 summarizing the style of the image, if tiled it is averaged over tiles.
"""
if imgs.ndim == 4:
# make image Lz x nchan x Ly x Lx for net
Expand All @@ -204,8 +225,8 @@ def run_net(net, imgs, augment=False, tile=True, tile_overlap=0.1, bsize=224):

# run network
if tile or augment or imgs.ndim == 4:
y, style = _run_tiled(net, imgs, augment=augment, bsize=bsize,
tile_overlap=tile_overlap)
y, style = _run_tiled(net, imgs, augment=augment, bsize=bsize,
batch_size=batch_size, tile_overlap=tile_overlap)
else:
imgs = np.expand_dims(imgs, axis=0)
y, style = _forward(net, imgs)
Expand All @@ -221,36 +242,22 @@ def run_net(net, imgs, augment=False, tile=True, tile_overlap=0.1, bsize=224):


def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0.1):
""" run network in tiles of size [bsize x bsize]
First image is split into overlapping tiles of size [bsize x bsize].
If augment, tiles have 50% overlap and are flipped at overlaps.
The average of the network output over tiles is returned.
Parameters
--------------
imgi: array [nchan x Ly x Lx] or [Lz x nchan x Ly x Lx]
augment: bool (optional, default False)
tiles image with overlapping tiles and flips overlapped regions to augment
bsize: int (optional, default 224)
size of tiles to use in pixels [bsize x bsize]
tile_overlap: float (optional, default 0.1)
fraction of overlap of tiles when computing flows
Returns
------------------
yf: array [3 x Ly x Lx] or [Lz x 3 x Ly x Lx]
yf is averaged over tiles
yf[0] is Y flow; yf[1] is X flow; yf[2] is cell probability
styles: array [64]
1D array summarizing the style of the image, averaged over tiles
"""
Run network on tiles of size [bsize x bsize]
(faster if augment is False)
Args:
imgs (np.ndarray): The input image or stack of images of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan].
batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
Returns:
y (np.ndarray): output of network, if tiled it is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
style (np.ndarray): 1D array of size 256 summarizing the style of the image, if tiled it is averaged over tiles.
"""
nout = net.nout
if imgi.ndim == 4:
Expand Down Expand Up @@ -328,50 +335,28 @@ def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0
return yf, styles


def run_3D(net, imgs, rsz=1.0, anisotropy=None, augment=False, tile=True,
def run_3D(net, imgs, batch_size=8, rsz=1.0, anisotropy=None, augment=False, tile=True,
tile_overlap=0.1, bsize=224, progress=None):
""" run network on stack of images
"""
Run network on image z-stack.
(faster if augment is False)
Parameters
--------------
imgs: array [Lz x Ly x Lx x nchan]
rsz: float (optional, default 1.0)
resize coefficient(s) for image
anisotropy: float (optional, default None)
for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y)
augment: bool (optional, default False)
tiles image with overlapping tiles and flips overlapped regions to augment
tile: bool (optional, default True)
tiles image to ensure GPU/CPU memory usage limited (recommended);
cannot be turned off for 3D segmentation
tile_overlap: float (optional, default 0.1)
fraction of overlap of tiles when computing flows
bsize: int (optional, default 224)
size of tiles to use in pixels [bsize x bsize]
progress: pyqt progress bar (optional, default None)
to return progress bar status to GUI
Returns
------------------
yf: array [Lz x Ly x Lx x 3]
y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability
style: array [64]
1D array summarizing the style of the image,
if tiled it is averaged over tiles
Args:
imgs (np.ndarray): The input image stack of size [Lz x Ly x Lx x nchan].
batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0.
anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
tile (bool, optional): Tiles image to ensure GPU/CPU memory usage limited (recommended); cannot be turned off for 3D segmentation. Defaults to True.
tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
progress (QProgressBar, optional): pyqt progress bar. Defaults to None.
Returns:
y (np.ndarray): output of network, if tiled it is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
style (np.ndarray): 1D array of size 256 summarizing the style of the image, if tiled it is averaged over tiles.
"""
sstr = ["YX", "ZY", "ZX"]
if anisotropy is not None:
Expand All @@ -390,8 +375,8 @@ def run_3D(net, imgs, rsz=1.0, anisotropy=None, augment=False, tile=True,
# per image
core_logger.info("running %s: %d planes of size (%d, %d)" %
(sstr[p], shape[0], shape[1], shape[2]))
y, style = run_net(net, xsl, augment=augment, tile=tile, bsize=bsize,
tile_overlap=tile_overlap)
y, style = run_net(net, xsl, batch_size=batch_size, augment=augment, tile=tile,
bsize=bsize, tile_overlap=tile_overlap)
y = transforms.resize_image(y, shape[1], shape[2])
yf[p] = y.transpose(ipm[p])
if progress is not None:
Expand Down
Loading

0 comments on commit aeb8bc8

Please sign in to comment.