From d08f296cbe3dbdb8a3ea7a01ad9f68794adb869d Mon Sep 17 00:00:00 2001 From: Ivan Ivanov Date: Fri, 22 Mar 2024 11:17:29 -0700 Subject: [PATCH] fix bug finding focus in stack with only one slice (#162) * fix bug finding focus in stack with only one slice * refactor for clarify * formatting * print -> warnings.warn * test single-slice case * fix test bugs --------- Co-authored-by: Talon Chandler --- tests/test_focus_estimator.py | 15 ++++++++-- waveorder/focus.py | 54 +++++++++++++++++++++++------------ 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/tests/test_focus_estimator.py b/tests/test_focus_estimator.py index 560dc580..c3cca07b 100644 --- a/tests/test_focus_estimator.py +++ b/tests/test_focus_estimator.py @@ -39,12 +39,21 @@ def test_focus_estimator(tmp_path): plot_path = tmp_path.joinpath("test.pdf") data3D = np.random.random((11, 256, 256)) slice = focus.focus_from_transverse_band( - data3D, ps, lambda_ill, NA_det, plot_path=str(plot_path) + data3D, NA_det, lambda_ill, ps, plot_path=str(plot_path) ) assert slice >= 0 assert slice <= data3D.shape[0] assert plot_path.exists() + # Check single slice + slice = focus.focus_from_transverse_band( + np.random.random((1, 10, 10)), + NA_det, + lambda_ill, + ps, + ) + assert slice == 0 + def test_focus_estimator_snr(tmp_path): ps = 6.5 / 100 @@ -66,9 +75,9 @@ def test_focus_estimator_snr(tmp_path): plot_path = tmp_path / f"test-{snr}.pdf" slice = focus.focus_from_transverse_band( data, - ps, - lambda_ill, NA_det, + lambda_ill, + ps, plot_path=plot_path, threshold_FWHM=5, ) diff --git a/waveorder/focus.py b/waveorder/focus.py index 3c259c19..709be235 100644 --- a/waveorder/focus.py +++ b/waveorder/focus.py @@ -3,6 +3,7 @@ from waveorder import util import matplotlib.pyplot as plt import numpy as np +import warnings def focus_from_transverse_band( @@ -60,10 +61,19 @@ def focus_from_transverse_band( >>> slice = focus_from_transverse_band(zyx_array, NA_det=0.55, lambda_ill=0.532, pixel_size=6.5/20) >>> in_focus_data = data[slice,:,:] """ - minmaxfunc = _check_focus_inputs( - zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions, mode + minmaxfunc = _mode_to_minmaxfunc(mode) + + _check_focus_inputs( + zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions ) + # Check for single slice + if zyx_array.shape[0] == 1: + warnings.warn( + "The dataset only contained a single slice. Returning trivial slice index = 0." + ) + return 0 + # Calculate coordinates _, Y, X = zyx_array.shape _, _, fxx, fyy = util.gen_coordinate((Y, X), pixel_size) @@ -94,25 +104,35 @@ def focus_from_transverse_band( # Plot if plot_path is not None: _plot_focus_metric( - plot_path, midband_sum, peak_index, in_focus_index, peak_results, threshold_FWHM + plot_path, + midband_sum, + peak_index, + in_focus_index, + peak_results, + threshold_FWHM, ) return in_focus_index +def _mode_to_minmaxfunc(mode): + if mode == "min": + minmaxfunc = np.argmin + elif mode == "max": + minmaxfunc = np.argmax + else: + raise ValueError("mode must be either `min` or `max`") + return minmaxfunc + + def _check_focus_inputs( - zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions, mode + zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions ): N = len(zyx_array.shape) if N != 3: raise ValueError( f"{N}D array supplied. `focus_from_transverse_band` only accepts 3D arrays." ) - if zyx_array.shape[0] == 1: - print( - "WARNING: The dataset only contained a single slice. Returning trivial slice index = 0." - ) - return 0 if NA_det < 0: raise ValueError("NA must be > 0") @@ -121,7 +141,7 @@ def _check_focus_inputs( if pixel_size < 0: raise ValueError("pixel_size must be > 0") if not 0.4 < lambda_ill / pixel_size < 10: - print( + warnings.warn( f"WARNING: lambda_ill/pixel_size = {lambda_ill/pixel_size}." f"Did you use the same units?" f"Did you enter the pixel size in (demagnified) object-space units?" @@ -134,17 +154,15 @@ def _check_focus_inputs( raise ValueError("midband_fractions[0] must be between 0 and 1") if not (0 <= midband_fractions[1] <= 1): raise ValueError("midband_fractions[1] must be between 0 and 1") - if mode == "min": - minmaxfunc = np.argmin - elif mode == "max": - minmaxfunc = np.argmax - else: - raise ValueError("mode must be either `min` or `max`") - return minmaxfunc def _plot_focus_metric( - plot_path, midband_sum, peak_index, in_focus_index, peak_results, threshold_FWHM + plot_path, + midband_sum, + peak_index, + in_focus_index, + peak_results, + threshold_FWHM, ): _, ax = plt.subplots(1, 1, figsize=(4, 4)) ax.plot(midband_sum, "-k")