diff --git a/python/lsst/pipe/tasks/interpImage.py b/python/lsst/pipe/tasks/interpImage.py index 569116363..9b5274e17 100644 --- a/python/lsst/pipe/tasks/interpImage.py +++ b/python/lsst/pipe/tasks/interpImage.py @@ -19,9 +19,18 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -__all__ = ["InterpImageConfig", "InterpImageTask"] +__all__ = ( + "CloughTocher2DInterpolateConfig", + "CloughTocher2DInterpolateTask", + "InterpImageConfig", + "InterpImageTask", +) + from contextlib import contextmanager +from itertools import product +from typing import Iterable + import lsst.pex.config as pexConfig import lsst.geom import lsst.afw.image as afwImage @@ -30,6 +39,7 @@ import lsst.meas.algorithms as measAlg import lsst.pipe.base as pipeBase from lsst.utils.timer import timeMethod +from scipy.interpolate import CloughTocher2DInterpolator class InterpImageConfig(pexConfig.Config): @@ -262,3 +272,192 @@ def interpolateImage(self, maskedImage, psf, defectList, fallbackValue): with self.transposeContext(maskedImage, defectList) as (image, defects): measAlg.interpolateOverDefects(image, psf, defects, fallbackValue, self.config.useFallbackValueAtEdge) + + +class CloughTocher2DInterpolateConfig(pexConfig.Config): + """Config for CloughTocher2DInterpolateTask.""" + + badMaskPlanes = pexConfig.ListField[str]( + doc="List of mask planes to interpolate over.", + default=["BAD", "SAT", "CR"], + ) + fillValue = pexConfig.Field[float]( + doc="Constant value to fill outside of the convex hull of the good " + "pixels. A long (longer than twice the ``interpLength``) streak of " + "bad pixels at an edge will be set to this value.", + default=0.0, + ) + interpLength = pexConfig.Field[int]( + doc="Maximum number of pixels away from a bad pixel to include in " + "building the interpolant. Must be greater than or equal to 1.", + default=4, + check=lambda x: x >= 1, + ) + + +class CloughTocher2DInterpolateTask(pipeBase.Task): + """Interpolated over bad pixels using CloughTocher2DInterpolator. + + Pixels with mask bits set to any of those listed ``badMaskPlanes`` config + are considered bad and are interpolated over. All good (non-bad) pixels + within ``interpLength`` pixels of a bad pixel in either direction are used + to construct the interpolant. An extended streak of bad pixels at an edge, + longer than ``interpLength``, is set to `fillValue`` specified in config. + """ + + ConfigClass = CloughTocher2DInterpolateConfig + _DefaultName = "cloughTocher2DInterpolate" + + def run(self, maskedImage, badpix: set[tuple[int, int]] | None = None, goodpix: dict | None = None): + """Interpolate over bad pixels in a masked image. + + This modifies the ``image`` attribute of the ``maskedImage`` in place. + This method returns, and accepts, the coordinates of the bad pixels + that were interpolated over, and the coordinates and values of the + good pixels that were used to construct the interpolant. This avoids + having to search for the bad and the good pixels repeatedly when the + mask plane is shared among many images, as would be the case with + noise realizations. + + Parameters + ---------- + maskedImage : `~lsst.afw.image.MaskedImage` + Image on which to perform interpolation (and modify in-place). + badpix: `set` [`tuple` [`int`, `int`]], optional + The coordinates of the bad pixels to interpolate over. + If None, then the coordinates of the bad pixels are determined by + an exhaustive search over the image. + goodpix: `dict` [`tuple` [`int`, `int`], `float`], optional + A mapping whose keys are the coordinates of the good pixels around + ``badpix`` that must be included when constructing the + interpolant. If ``badpix`` is provided, then the pixels in + ``goodpix`` are used as to construct the interpolatant. If not, + any additional good pixels around internally determined ``badpix`` + are added to ``goodpix`` and used to construct the interpolant. In + all cases, the values are populated from the image plane of the + ``maskedImage`` (provided values will be ignored. + + Returns + ------- + badpix: `set` [`tuple` [`int`, `int`]] + The coordinates of the bad pixels that were interpolated over. + goodpix: `dict` [`tuple` [`int`, `int`], `float`] + Mapping of the coordinates of the good pixels around ``badpix`` + to their values that were included when constructing the + interpolant. + + Raises + ------ + RuntimeError + If a pixel passed in as ``goodpix`` is found to be bad as specified by + ``maskPlanes``. + ValueError + If an input ``badpix`` is not found to be bad as specified by + ``maskPlanes``. + """ + max_window_extent = lsst.geom.Extent2I( + 2 * self.config.interpLength + 1, 2 * self.config.interpLength + 1 + ) + # Even if badpix and/or goodpix is provided, make sure to update + # the values of goodpix. + badpix, goodpix = find_good_pixels_around_bad_pixels( + maskedImage, + self.config.badMaskPlanes, + max_window_extent=max_window_extent, + badpix=badpix, + goodpix=goodpix, + ) + + # Construct the interpolant. + interpolator = CloughTocher2DInterpolator( + list(goodpix.keys()), + list(goodpix.values()), + fill_value=self.config.fillValue, + ) + + # Fill in the bad pixels. + for x, y in badpix: + maskedImage.image[x, y] = interpolator((x, y)) + + return badpix, goodpix + + +def find_good_pixels_around_bad_pixels( + image: afwImage.MaskedImage, + maskPlanes: Iterable[str], + *, + max_window_extent: lsst.geom.Extent2I, + badpix: set | None = None, + goodpix: dict | None = None, +): + """Find the location of bad pixels, and neighboring good pixels. + + Parameters + ---------- + image : `~lsst.afw.image.MaskedImage` + Image from which to find the bad and the good pixels. + maskPlanes : `list` [`str`] + List of mask planes to consider as bad pixels. + max_window_extent : `lsst.geom.Extent2I` + Maximum extent of the window around a bad pixel to consider when + looking for good pixels. + badpix : `list` [`tuple` [`int`, `int`]], optional + A known list of bad pixels. If provided, the function does not look for + any additional bad pixels, but it verifies that the provided + coordinates correspond to bad pixels. If an input``badpix`` is not + found to be bad as specified by ``maskPlanes``, an exception is raised. + goodpix : `dict` [`tuple` [`int`, `int`], `float`], optional + A known mapping of the coordinates of good pixels to their values, to + which any newly found good pixels locations will be added, and the + values (even for existing items) will be updated. + + Returns + ------- + badpix : `list` [`tuple` [`int`, `int`]] + The coordinates of the bad pixels. If ``badpix`` was provided as an + input argument, the returned quantity is the same as the input. + goodpix : `dict` [`tuple` [`int`, `int`], `float`] + Updated mapping of the coordinates of good pixels to their values. + + Raises + ------ + RuntimeError + If a pixel passed in as ``goodpix`` is found to be bad as specified by + ``maskPlanes``. + ValueError + If an input ``badpix`` is not found to be bad as specified by + ``maskPlanes``. + """ + + bbox = image.getBBox() + if badpix is None: + iterator = product(range(bbox.minX, bbox.maxX + 1), range(bbox.minY, bbox.maxY + 1)) + badpix = set() + else: + iterator = badpix + + if goodpix is None: + goodpix = {} + + for x, y in iterator: + if image.mask[x, y] & afwImage.Mask.getPlaneBitMask(maskPlanes): + if (x, y) in goodpix: + raise RuntimeError(f"Pixel ({x}, {y}) is bad as specified by maskPlanes {maskPlanes} but " + "passed in as goodpix") + badpix.add((x, y)) + window = lsst.geom.Box2I.makeCenteredBox( + center=lsst.geom.Point2D(x, y), # center has to be a Point2D instance. + size=max_window_extent, + ) + # Restrict to the bounding box of the image. + window.clip(bbox) + + for xx, yy in product(range(window.minX, window.maxX + 1), range(window.minY, window.maxY + 1)): + if not (image.mask[xx, yy] & afwImage.Mask.getPlaneBitMask(maskPlanes)): + goodpix[(xx, yy)] = image.image[xx, yy] + elif (x, y) in badpix: + # If (x, y) is in badpix, but did not get flagged as bad, + # raise an exception. + raise ValueError(f"Pixel ({x}, {y}) is not bad as specified by maskPlanes {maskPlanes}") + + return badpix, goodpix diff --git a/tests/test_import.py b/tests/test_import.py index e45144c8f..88ea64471 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -40,6 +40,7 @@ class PipeTasksImportTestCase(ImportTestCase): "lsst.pipe.tasks": { "assembleCoadd.py", # TODO: Remove in DM-40826 "assembleChi2Coadd.py", # TODO: Remove in DM-40826 + "configurableActions.py", # TODO: Remove in DM-38415 "dcrAssembleCoadd.py", # TODO: Remove in DM-40826 } } diff --git a/tests/test_interpImageTask.py b/tests/test_interpImageTask.py index ecc88c95e..4c229af07 100755 --- a/tests/test_interpImageTask.py +++ b/tests/test_interpImageTask.py @@ -36,7 +36,7 @@ import lsst.afw.image as afwImage import lsst.pex.config as pexConfig import lsst.ip.isr as ipIsr -from lsst.pipe.tasks.interpImage import InterpImageTask +from lsst.pipe.tasks.interpImage import CloughTocher2DInterpolateTask, InterpImageTask try: display @@ -166,6 +166,137 @@ def testTranspose(self): self.assertFloatsEqual(image.image.array, value) +class CloughTocher2DInterpolateTestCase(lsst.utils.tests.TestCase): + """Test the CloughTocher2DInterpolateTask.""" + + def setUp(self): + super().setUp() + + self.maskedimage = afwImage.MaskedImageF(100, 121) + for x in range(100): + for y in range(121): + self.maskedimage[x, y] = (3 * y + x * 5, 0, 1.0) + + # Clone the maskedimage so we can compare it after running the task. + self.reference = self.maskedimage.clone() + + # Set some central pixels as SAT + sliceX, sliceY = slice(30, 35), slice(40, 45) + self.maskedimage.mask[sliceX, sliceY] = afwImage.Mask.getPlaneBitMask("SAT") + self.maskedimage.image[sliceX, sliceY] = np.nan + # Put nans here to make sure interp is done ok + + # Set an entire column as BAD + self.maskedimage.mask[54:55, :] = afwImage.Mask.getPlaneBitMask("BAD") + self.maskedimage.image[54:55, :] = np.nan + + # Set an entire row as BAD + self.maskedimage.mask[:, 110:111] = afwImage.Mask.getPlaneBitMask("BAD") + self.maskedimage.image[:, 110:111] = np.nan + + # Set a diagonal set of pixels as CR + for i in range(74, 78): + self.maskedimage.mask[i, i] = afwImage.Mask.getPlaneBitMask("CR") + self.maskedimage.image[i, i] = np.nan + + # Set one of the edges as EDGE + self.maskedimage.mask[0:1, :] = afwImage.Mask.getPlaneBitMask("EDGE") + self.maskedimage.image[0:1, :] = np.nan + + # Set a smaller streak at the edge + self.maskedimage.mask[25:28, 0:1] = afwImage.Mask.getPlaneBitMask("EDGE") + self.maskedimage.image[25:28, 0:1] = np.nan + + # Update the reference image's mask alone, so we can compare them after + # running the task. + self.reference.mask.array[:, :] = self.maskedimage.mask.array + + # Create a noise image + self.noise = self.maskedimage.clone() + np.random.seed(12345) + self.noise.image.array[:, :] = np.random.normal(size=self.noise.image.array.shape) + + @lsst.utils.tests.methodParameters(n_runs=(1, 2)) + def test_interpolation(self, n_runs: int): + """Test that the interpolation is done correctly. + + Parameters + ---------- + n_runs : `int` + Number of times to run the task. Running the task more than once + should have no effect. + """ + config = CloughTocher2DInterpolateTask.ConfigClass() + config.badMaskPlanes = ( + "BAD", + "SAT", + "CR", + "EDGE", + ) + config.fillValue = 0.5 + task = CloughTocher2DInterpolateTask(config) + for n in range(n_runs): + task.run(self.maskedimage) + + # Assert that the mask and the variance planes remain unchanged. + self.assertImagesEqual(self.maskedimage.variance, self.reference.variance) + self.assertMasksEqual(self.maskedimage.mask, self.reference.mask) + + # Check that the long streak of bad pixels have been replaced with the + # fillValue, but not the short streak. + np.testing.assert_array_equal(self.maskedimage.image[0:1, :].array, config.fillValue) + with self.assertRaises(AssertionError): + np.testing.assert_array_equal(self.maskedimage.image[25:28, 0:1].array, config.fillValue) + + # Check that interpolated pixels are close to the reference (original), + # and that none of them is still NaN. + self.assertTrue(np.isfinite(self.maskedimage.image.array).all()) + self.assertImagesAlmostEqual( + self.maskedimage.image[1:, :], self.reference.image[1:, :], rtol=1e-05, atol=1e-08 + ) + + @lsst.utils.tests.methodParametersProduct(pass_badpix=(True, False), pass_goodpix=(True, False)) + def test_interpolation_with_noise(self, pass_badpix: bool = True, pass_goodpix: bool = True): + """Test that we can reuse the badpix and goodpix. + + Parameters + ---------- + pass_badpix : `bool` + Whether to pass the badpix to the task? + pass_goodpix : `bool` + Whether to pass the goodpix to the task? + """ + + config = CloughTocher2DInterpolateTask.ConfigClass() + config.badMaskPlanes = ( + "BAD", + "SAT", + "CR", + "EDGE", + ) + task = CloughTocher2DInterpolateTask(config) + + badpix, goodpix = task.run(self.noise) + task.run( + self.maskedimage, + badpix=(badpix if pass_badpix else None), + goodpix=(goodpix if pass_goodpix else None), + ) + + # Check that the long streak of bad pixels by the edge have been + # replaced with fillValue, but not the short streak. + np.testing.assert_array_equal(self.maskedimage.image[0:1, :].array, config.fillValue) + with self.assertRaises(AssertionError): + np.testing.assert_array_equal(self.maskedimage.image[25:28, 0:1].array, config.fillValue) + + # Check that interpolated pixels are close to the reference (original), + # and that none of them is still NaN. + self.assertTrue(np.isfinite(self.maskedimage.image.array).all()) + self.assertImagesAlmostEqual( + self.maskedimage.image[1:, :], self.reference.image[1:, :], rtol=1e-05, atol=1e-08 + ) + + def setup_module(module): lsst.utils.tests.init()