From 31bc5befb3573f9dd0cc46db392788fe32c96bcd Mon Sep 17 00:00:00 2001 From: Qin Yu Date: Wed, 24 Jul 2024 01:47:35 +0200 Subject: [PATCH 1/5] Fix and test channel-wise linear normalization 1. Make it actually channel-wise; 2. Before it should take `(old_min, old_max)` of the input image (otherwise it doesn't make sense), but now it takes `(new_min, new_max)` of the expected normalized output. Before this fix, if `lowhigh` is provided, channel-wise linear normalization is performed by: ```python for c in ...: img_norm[...,c] = (img_norm[..., c] - lowhigh[0]) / (lowhigh[1] - lowhigh[0]) ``` which has a redundant `[..., c]` and is not "channel-wise", i.e. it is equivalent to ```python img_norm = (img_norm - lowhigh[0]) / (lowhigh[1] - lowhigh[0]) ``` This parameter, `lowhigh`, is never used in Cellpose itself, however. --- cellpose/transforms.py | 8 ++++++-- tests/test_transforms.py | 3 +++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/cellpose/transforms.py b/cellpose/transforms.py index 4fcacfcd..ad8e8bfd 100644 --- a/cellpose/transforms.py +++ b/cellpose/transforms.py @@ -638,9 +638,13 @@ def normalize_img(img, normalize=True, norm3D=False, invert=False, lowhigh=None, nchan = img_norm.shape[-1] if lowhigh is not None: + new_min, new_max = lowhigh for c in range(nchan): - img_norm[..., - c] = (img_norm[..., c] - lowhigh[0]) / (lowhigh[1] - lowhigh[0]) + c_min = img_norm[..., c].min() + c_max = img_norm[..., c].max() + eps = 1.0e-6 + img_norm[..., c] = (img_norm[..., c] - c_min) / (c_max - c_min + eps) + img_norm[..., c] = img_norm[..., c] * (new_max - new_min) + new_min else: if sharpen_radius > 0 or smooth_radius > 0: img_norm = smooth_sharpen_img(img_norm, sharpen_radius=sharpen_radius, diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 91f40d7e..21ddca6b 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -13,6 +13,9 @@ def test_normalize_img(data_dir): img = io.imread(str(data_dir.joinpath('3D').joinpath('rgb_3D.tif'))) img = img.transpose(0, 2, 3, 1).astype('float32') + img_norm = normalize_img(img, lowhigh=(0, 1)) + assert img_norm.min() >= 0 and img_norm.max() <= 1 + img_norm = normalize_img(img, norm3D=True) assert img_norm.shape == img.shape From 91469d1273ee0b3e3e4f614d0e8a80b0587ceb40 Mon Sep 17 00:00:00 2001 From: Qin Yu Date: Wed, 24 Jul 2024 03:54:16 +0200 Subject: [PATCH 2/5] Improve channel-wise inversion Previously, only the last channel was inverted. Now, channel 0 is inverted, which is typically cells used for segmentation. This update addresses the common use case where the biological structure of interest is in the first channel. Future improvements could allow for user-specified channels. --- cellpose/transforms.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cellpose/transforms.py b/cellpose/transforms.py index ad8e8bfd..d50e7df7 100644 --- a/cellpose/transforms.py +++ b/cellpose/transforms.py @@ -393,7 +393,7 @@ def move_min_dim(img, force=False): Args: img (ndarray): The input image. - force (bool, optional): If True, the minimum dimension will always be moved. + force (bool, optional): If True, the minimum dimension will always be moved. Defaults to False. Returns: @@ -668,7 +668,7 @@ def normalize_img(img, normalize=True, norm3D=False, invert=False, lowhigh=None, c], lower=percentile[0], upper=percentile[1], copy=False) if (tile_norm_blocksize > 0 or normalize) and invert: - img_norm[..., c] = -1 * img_norm[..., c] + 1 + img_norm[..., 0] = -1 * img_norm[..., 0] + 1 elif invert: error_message = "cannot invert image without normalizing" transforms_logger.critical(error_message) @@ -690,7 +690,7 @@ def resize_image(img0, Ly=None, Lx=None, rsz=None, interpolation=cv2.INTER_LINEA Lx (int, optional): Desired width of the resized image. Defaults to None. rsz (float, optional): Resize coefficient(s) for the image. If Ly is None, rsz is used. Defaults to None. interpolation (int, optional): OpenCV interpolation method. Defaults to cv2.INTER_LINEAR. - no_channels (bool, optional): Flag indicating whether to treat the third dimension as a channel. + no_channels (bool, optional): Flag indicating whether to treat the third dimension as a channel. Defaults to False. Returns: @@ -742,8 +742,8 @@ def pad_image_ND(img0, div=16, extra=1, min_size=None): extra (int, optional): Extra padding. Defaults to 1. min_size (tuple, optional): Minimum size of the image. Defaults to None. - Returns: - tuple containing + Returns: + tuple containing - I (ndarray): Padded image. - ysub (ndarray): Y range of pixels in the padded image corresponding to img0. - xsub (ndarray): X range of pixels in the padded image corresponding to img0. From 5eb40b747dc1be5c73cda2c2f24cf0cfbcb0bca8 Mon Sep 17 00:00:00 2001 From: Qin Yu Date: Tue, 10 Sep 2024 03:16:34 +0200 Subject: [PATCH 3/5] refactor `normalize_img()`: channel-wise `lowhigh`, volume-wise `invert` * `lowhigh` could be either a 2-tuple or a `nchan`-tuple of 2-tuple * `invert` inverts all channels Note that `lowhigh` shouldn't be combined with pre-smoothing or -sharpening. --- cellpose/transforms.py | 112 +++++++++++++++++++++++++---------------- 1 file changed, 68 insertions(+), 44 deletions(-) diff --git a/cellpose/transforms.py b/cellpose/transforms.py index 1ce37a5a..46d063e5 100644 --- a/cellpose/transforms.py +++ b/cellpose/transforms.py @@ -598,21 +598,24 @@ def reshape(data, channels=[0, 0], chan_first=False): def normalize_img(img, normalize=True, norm3D=False, invert=False, lowhigh=None, - percentile=None, sharpen_radius=0, smooth_radius=0, + percentile=(1., 99.), sharpen_radius=0, smooth_radius=0, tile_norm_blocksize=0, tile_norm_smooth3D=1, axis=-1): - """Normalize each channel of the image. + """Normalize each channel of the image with optional inversion, smoothing, and sharpening. Args: img (ndarray): The input image. It should have at least 3 dimensions. If it is 4-dimensional, it assumes the first non-channel axis is the Z dimension. normalize (bool, optional): Whether to perform normalization. Defaults to True. - norm3D (bool, optional): Whether to normalize in 3D. Defaults to False. + norm3D (bool, optional): Whether to normalize in 3D. If True, the entire 3D stack will + be normalized per channel. If False, normalization is applied per Z-slice. Defaults to False. invert (bool, optional): Whether to invert the image. Useful if cells are dark instead of bright. Defaults to False. - lowhigh (tuple, optional): The lower and upper bounds for normalization. If provided, it should be a tuple - of two values. Defaults to None. + lowhigh (tuple or ndarray, optional): The lower and upper bounds for normalization. + Can be a tuple of two values (applied to all channels) or an array of shape (nchan, 2) + for per-channel normalization. Incompatible with smoothing and sharpening. + Defaults to None. percentile (tuple, optional): The lower and upper percentiles for normalization. If provided, it should be - a tuple of two values. Each value should be between 0 and 100. Defaults to None. + a tuple of two values. Each value should be between 0 and 100. Defaults to (1.0, 99.0). sharpen_radius (int, optional): The radius for sharpening the image. Defaults to 0. smooth_radius (int, optional): The radius for smoothing the image. Defaults to 0. tile_norm_blocksize (int, optional): The block size for tile-based normalization. Defaults to 0. @@ -633,60 +636,81 @@ def normalize_img(img, normalize=True, norm3D=False, invert=False, lowhigh=None, transforms_logger.critical(error_message) raise ValueError(error_message) - if lowhigh is not None: - assert len(lowhigh) == 2 - assert lowhigh[1] > lowhigh[0] - elif percentile is not None: - assert len(percentile) == 2 - assert percentile[0] >= 0 and percentile[1] > 0 - assert percentile[0] < 100 and percentile[1] <= 100 - assert percentile[1] > percentile[0] - else: - percentile = [1., 99.] - img_norm = img.astype(np.float32) - # move channel axis last - img_norm = np.moveaxis(img_norm, axis, -1) + img_norm = np.moveaxis(img_norm, axis, -1) # Move channel axis to last + nchan = img_norm.shape[-1] + # Validate and handle lowhigh bounds + if lowhigh is not None: + lowhigh = np.array(lowhigh) + if lowhigh.ndim == 1: + lowhigh = np.tile(lowhigh, (nchan, 1)) + elif lowhigh.shape != (nchan, 2): + error_message = "`lowhigh` must have shape (nchan, 2)" + transforms_logger.critical(error_message) + raise ValueError(error_message) + + # Validate percentile + if percentile is None: + percentile = (1.0, 99.0) + elif not (0 <= percentile[0] < percentile[1] <= 100): + error_message = "Invalid percentile range, should be between 0 and 100" + transforms_logger.critical(error_message) + raise ValueError(error_message) + + # Apply normalization based on lowhigh or percentile if lowhigh is not None: - new_min, new_max = lowhigh for c in range(nchan): - c_min = img_norm[..., c].min() - c_max = img_norm[..., c].max() - eps = 1.0e-6 - img_norm[..., c] = (img_norm[..., c] - c_min) / (c_max - c_min + eps) - img_norm[..., c] = img_norm[..., c] * (new_max - new_min) + new_min + lower = lowhigh[c, 0] + upper = lowhigh[c, 1] + img_norm[..., c] = (img_norm[..., c] - lower) / (upper - lower) + else: + # Apply sharpening and smoothing if specified if sharpen_radius > 0 or smooth_radius > 0: - img_norm = smooth_sharpen_img(img_norm, sharpen_radius=sharpen_radius, - smooth_radius=smooth_radius) + img_norm = smooth_sharpen_img( + img_norm, sharpen_radius=sharpen_radius, smooth_radius=smooth_radius + ) + # Apply tile-based normalization or standard normalization if tile_norm_blocksize > 0: - img_norm = normalize99_tile(img_norm, blocksize=tile_norm_blocksize, - lower=percentile[0], upper=percentile[1], - smooth3D=tile_norm_smooth3D, norm3D=norm3D) + img_norm = normalize99_tile( + img_norm, + blocksize=tile_norm_blocksize, + lower=percentile[0], + upper=percentile[1], + smooth3D=tile_norm_smooth3D, + norm3D=norm3D, + ) elif normalize: - if img_norm.ndim == 3 or norm3D: + if img_norm.ndim == 3 or norm3D: # i.e. if YXC, or ZYXC with norm3D=True for c in range(nchan): - img_norm[..., c] = normalize99(img_norm[..., - c], lower=percentile[0], - upper=percentile[1], copy=False) - else: + img_norm[..., c] = normalize99( + img_norm[..., c], + lower=percentile[0], + upper=percentile[1], + copy=False, + ) + else: # i.e. if ZYXC with norm3D=False then per Z-slice for z in range(img_norm.shape[0]): for c in range(nchan): - img_norm[z, :, :, - c] = normalize99(img_norm[z, :, :, - c], lower=percentile[0], - upper=percentile[1], copy=False) - if (tile_norm_blocksize > 0 or normalize) and invert: - img_norm[..., 0] = -1 * img_norm[..., 0] + 1 - elif invert: - error_message = "cannot invert image without normalizing" + img_norm[z, ..., c] = normalize99( + img_norm[z, ..., c], + lower=percentile[0], + upper=percentile[1], + copy=False, + ) + + if invert: + if lowhigh is not None or tile_norm_blocksize > 0 or normalize: + img_norm = 1 - img_norm + else: + error_message = "Cannot invert image without normalization" transforms_logger.critical(error_message) raise ValueError(error_message) - # move channel axis back to original position + # Move channel axis back to the original position img_norm = np.moveaxis(img_norm, -1, axis) return img_norm From 93b050e2f545555381dca7998c7eecd22f85a52a Mon Sep 17 00:00:00 2001 From: Qin Yu Date: Tue, 10 Sep 2024 04:04:29 +0200 Subject: [PATCH 4/5] update tests for `normalize_img()` and format code * Tested the changes made in transform.py * Removed unused import and formatted whitespace --- cellpose/transforms.py | 39 ++++++++++++++---------------- tests/test_transforms.py | 52 ++++++++++++++++++++++++++++++---------- 2 files changed, 58 insertions(+), 33 deletions(-) diff --git a/cellpose/transforms.py b/cellpose/transforms.py index 46d063e5..13b091ec 100644 --- a/cellpose/transforms.py +++ b/cellpose/transforms.py @@ -2,19 +2,17 @@ Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. """ -import numpy as np +import logging import warnings + import cv2 +import numpy as np import torch -from torch.fft import fft2, ifft2, fftshift from scipy.ndimage import gaussian_filter1d - -import logging +from torch.fft import fft2, fftshift, ifft2 transforms_logger = logging.getLogger(__name__) -from . import dynamics, utils - def _taper_mask(ly=224, lx=224, sig=7.5): """ @@ -492,7 +490,7 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha transforms_logger.warning(f"z_axis not specified, assuming it is dim {z_axis}") transforms_logger.warning(f"if this is actually the channel_axis, use 'model.eval(channel_axis={z_axis}, ...)'") z_axis = 0 - + if z_axis is not None: if x.ndim == 3: x = x[..., np.newaxis] @@ -512,7 +510,7 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha if channel_axis is None: x = move_min_dim(x) - + if x.ndim > 3: transforms_logger.info( "multi-stack tiff read in as having %d planes %d channels" % @@ -533,7 +531,7 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha % (nchan, nchan)) x = x[..., :nchan] - #if not do_3D and x.ndim > 3: + # if not do_3D and x.ndim > 3: # transforms_logger.critical("ERROR: cannot process 4D images in 2D mode") # raise ValueError("ERROR: cannot process 4D images in 2D mode") @@ -716,41 +714,40 @@ def normalize_img(img, normalize=True, norm3D=False, invert=False, lowhigh=None, return img_norm def resize_safe(img, Ly, Lx, interpolation=cv2.INTER_LINEAR): - """OpenCV resize function does not support uint32. - + """OpenCV resize function does not support uint32. + This function converts the image to float32 before resizing and then converts it back to uint32. Not safe! References issue: https://github.com/MouseLand/cellpose/issues/937 - + Implications: - * Runtime: Runtime increases by 5x-50x due to type casting. However, with resizing being very efficient, this is not + * Runtime: Runtime increases by 5x-50x due to type casting. However, with resizing being very efficient, this is not a big issue. A 10,000x10,000 image takes 0.47s instead of 0.016s to cast and resize on 32 cores on GPU. * Memory: However, memory usage increases. Not tested by how much. - + Args: img (ndarray): Image of size [Ly x Lx]. Ly (int): Desired height of the resized image. Lx (int): Desired width of the resized image. interpolation (int, optional): OpenCV interpolation method. Defaults to cv2.INTER_LINEAR. - + Returns: ndarray: Resized image of size [Ly x Lx]. - + """ - + # cast image cast = img.dtype == np.uint32 if cast: - # img = img.astype(np.float32) - + # resize img = cv2.resize(img, (Lx, Ly), interpolation=interpolation) - + # cast back if cast: transforms_logger.warning("resizing image from uint32 to float32 and back to uint32") img = img.round().astype(np.uint32) - + return img diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 5572e3ab..bb9d71b9 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,5 +1,7 @@ -from cellpose.transforms import * -from cellpose import io +import numpy as np + +from cellpose.io import imread +from cellpose.transforms import normalize_img, random_rotate_and_resize, resize_image def test_random_rotate_and_resize__default(): @@ -10,12 +12,9 @@ def test_random_rotate_and_resize__default(): def test_normalize_img(data_dir): - img = io.imread(str(data_dir.joinpath('3D').joinpath('rgb_3D.tif'))) + img = imread(str(data_dir.joinpath('3D').joinpath('rgb_3D.tif'))) img = img.transpose(0, 2, 3, 1).astype('float32') - img_norm = normalize_img(img, lowhigh=(0, 1)) - assert img_norm.min() >= 0 and img_norm.max() <= 1 - img_norm = normalize_img(img, norm3D=True) assert img_norm.shape == img.shape @@ -25,21 +24,50 @@ def test_normalize_img(data_dir): img_norm = normalize_img(img, norm3D=False, sharpen_radius=8) assert img_norm.shape == img.shape + +def test_normalize_img_with_lowhigh_and_invert(data_dir): + img = imread(str(data_dir.joinpath('3D').joinpath('rgb_3D.tif'))) + img = img.transpose(0, 2, 3, 1).astype('float32') + + img_norm = normalize_img(img, lowhigh=(img.min(), img.max())) + assert img_norm.min() >= 0 and img_norm.max() <= 1 + + img_norm_channelwise = normalize_img( + img, + lowhigh=( + (img[..., 0].min(), img[..., 0].max()), + (img[..., 1].min(), img[..., 1].max()), + ), + ) + assert img_norm_channelwise.min() >= 0 and img_norm_channelwise.max() <= 1 + + img_norm_channelwise_inverted = normalize_img( + img, + lowhigh=( + (img[..., 0].min(), img[..., 0].max()), + (img[..., 1].min(), img[..., 1].max()), + ), + invert=True, + ) + np.testing.assert_allclose( + img_norm_channelwise, 1 - img_norm_channelwise_inverted, rtol=1e-3 + ) + + def test_resize(data_dir): - img = io.imread(str(data_dir.joinpath('2D').joinpath('rgb_2D_tif.tif'))) - + img = imread(str(data_dir.joinpath('2D').joinpath('rgb_2D_tif.tif'))) + Lx = 100 Ly = 200 - + img8 = resize_image(img.astype("uint8"), Lx=Lx, Ly=Ly) assert img8.shape == (Ly, Lx, 3) assert img8.dtype == np.uint8 - + img16 = resize_image(img.astype("uint16"), Lx=Lx, Ly=Ly) assert img16.shape == (Ly, Lx, 3) assert img16.dtype == np.uint16 - + img32 = resize_image(img.astype("uint32"), Lx=Lx, Ly=Ly) assert img32.shape == (Ly, Lx, 3) assert img32.dtype == np.uint32 - From 4df273146b844686deb72aec675ccde317386fde Mon Sep 17 00:00:00 2001 From: Qin Yu Date: Tue, 10 Sep 2024 12:11:20 +0200 Subject: [PATCH 5/5] Add tests to cover exception handling in `normalize_img()` --- cellpose/transforms.py | 6 +-- tests/test_transforms.py | 85 +++++++++++++++++++++++++++------------- 2 files changed, 61 insertions(+), 30 deletions(-) diff --git a/cellpose/transforms.py b/cellpose/transforms.py index 13b091ec..cbb18b05 100644 --- a/cellpose/transforms.py +++ b/cellpose/transforms.py @@ -642,10 +642,10 @@ def normalize_img(img, normalize=True, norm3D=False, invert=False, lowhigh=None, # Validate and handle lowhigh bounds if lowhigh is not None: lowhigh = np.array(lowhigh) - if lowhigh.ndim == 1: - lowhigh = np.tile(lowhigh, (nchan, 1)) + if lowhigh.shape == (2,): + lowhigh = np.tile(lowhigh, (nchan, 1)) # Expand to per-channel bounds elif lowhigh.shape != (nchan, 2): - error_message = "`lowhigh` must have shape (nchan, 2)" + error_message = "`lowhigh` must have shape (2,) or (nchan, 2)" transforms_logger.critical(error_message) raise ValueError(error_message) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index bb9d71b9..7a56b52e 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,51 +1,61 @@ import numpy as np +import pytest from cellpose.io import imread from cellpose.transforms import normalize_img, random_rotate_and_resize, resize_image +@pytest.fixture +def img_3d(data_dir): + """Fixture to load 3D image data for tests.""" + img = imread(str(data_dir.joinpath('3D').joinpath('rgb_3D.tif'))) + return img.transpose(0, 2, 3, 1).astype('float32') + + +@pytest.fixture +def img_2d(data_dir): + """Fixture to load 2D image data for tests.""" + return imread(str(data_dir.joinpath('2D').joinpath('rgb_2D_tif.tif'))) + + def test_random_rotate_and_resize__default(): nimg = 2 X = [np.random.rand(64, 64) for i in range(nimg)] - random_rotate_and_resize(X) -def test_normalize_img(data_dir): - img = imread(str(data_dir.joinpath('3D').joinpath('rgb_3D.tif'))) - img = img.transpose(0, 2, 3, 1).astype('float32') +def test_normalize_img(img_3d): + img_norm = normalize_img(img_3d, norm3D=True) + assert img_norm.shape == img_3d.shape - img_norm = normalize_img(img, norm3D=True) - assert img_norm.shape == img.shape + img_norm = normalize_img(img_3d, norm3D=True, tile_norm_blocksize=25) + assert img_norm.shape == img_3d.shape - img_norm = normalize_img(img, norm3D=True, tile_norm_blocksize=25) - assert img_norm.shape == img.shape + img_norm = normalize_img(img_3d, norm3D=False, sharpen_radius=8) + assert img_norm.shape == img_3d.shape - img_norm = normalize_img(img, norm3D=False, sharpen_radius=8) - assert img_norm.shape == img.shape +def test_normalize_img_with_lowhigh_and_invert(img_3d): + img_norm = normalize_img(img_3d, lowhigh=(img_3d.min() + 1, img_3d.max() - 1)) + assert img_norm.min() < 0 and img_norm.max() > 1 -def test_normalize_img_with_lowhigh_and_invert(data_dir): - img = imread(str(data_dir.joinpath('3D').joinpath('rgb_3D.tif'))) - img = img.transpose(0, 2, 3, 1).astype('float32') - - img_norm = normalize_img(img, lowhigh=(img.min(), img.max())) - assert img_norm.min() >= 0 and img_norm.max() <= 1 + img_norm = normalize_img(img_3d, lowhigh=(img_3d.min(), img_3d.max())) + assert 0 <= img_norm.min() < img_norm.max() <= 1 img_norm_channelwise = normalize_img( - img, + img_3d, lowhigh=( - (img[..., 0].min(), img[..., 0].max()), - (img[..., 1].min(), img[..., 1].max()), + (img_3d[..., 0].min(), img_3d[..., 0].max()), + (img_3d[..., 1].min(), img_3d[..., 1].max()), ), ) assert img_norm_channelwise.min() >= 0 and img_norm_channelwise.max() <= 1 img_norm_channelwise_inverted = normalize_img( - img, + img_3d, lowhigh=( - (img[..., 0].min(), img[..., 0].max()), - (img[..., 1].min(), img[..., 1].max()), + (img_3d[..., 0].min(), img_3d[..., 0].max()), + (img_3d[..., 1].min(), img_3d[..., 1].max()), ), invert=True, ) @@ -54,20 +64,41 @@ def test_normalize_img_with_lowhigh_and_invert(data_dir): ) -def test_resize(data_dir): - img = imread(str(data_dir.joinpath('2D').joinpath('rgb_2D_tif.tif'))) +def test_normalize_img_exceptions(img_3d): + img_2D = img_3d[0, ..., 0] + with pytest.raises(ValueError): + normalize_img(img_2D) + + with pytest.raises(ValueError): + normalize_img(img_3d, lowhigh=(0, 1, 2)) + + with pytest.raises(ValueError): + normalize_img(img_3d, lowhigh=((0, 1), (0, 1, 2))) + + with pytest.raises(ValueError): + normalize_img(img_3d, lowhigh=((0, 1),) * 4) + + with pytest.raises(ValueError): + normalize_img(img_3d, percentile=(1, 101)) + + with pytest.raises(ValueError): + normalize_img( + img_3d, lowhigh=None, tile_norm_blocksize=0, normalize=False, invert=True + ) + +def test_resize(img_2d): Lx = 100 Ly = 200 - img8 = resize_image(img.astype("uint8"), Lx=Lx, Ly=Ly) + img8 = resize_image(img_2d.astype("uint8"), Lx=Lx, Ly=Ly) assert img8.shape == (Ly, Lx, 3) assert img8.dtype == np.uint8 - img16 = resize_image(img.astype("uint16"), Lx=Lx, Ly=Ly) + img16 = resize_image(img_2d.astype("uint16"), Lx=Lx, Ly=Ly) assert img16.shape == (Ly, Lx, 3) assert img16.dtype == np.uint16 - img32 = resize_image(img.astype("uint32"), Lx=Lx, Ly=Ly) + img32 = resize_image(img_2d.astype("uint32"), Lx=Lx, Ly=Ly) assert img32.shape == (Ly, Lx, 3) assert img32.dtype == np.uint32