From a02bcfe04b8ae3ad9bbc3ff4fb55a649986f7d7c Mon Sep 17 00:00:00 2001 From: Arun Kannawadi Date: Thu, 11 Jan 2024 21:17:09 -0800 Subject: [PATCH] Add a CloughTocher2DInterpolateTask --- python/lsst/pipe/tasks/interpImage.py | 201 +++++++++++++++++++++++++- 1 file changed, 200 insertions(+), 1 deletion(-) diff --git a/python/lsst/pipe/tasks/interpImage.py b/python/lsst/pipe/tasks/interpImage.py index 5691163637..1b0885dc4f 100644 --- a/python/lsst/pipe/tasks/interpImage.py +++ b/python/lsst/pipe/tasks/interpImage.py @@ -19,9 +19,18 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -__all__ = ["InterpImageConfig", "InterpImageTask"] +__all__ = ( + "CloughTocher2DInterpolateConfig", + "CloughTocher2DInterpolateTask", + "InterpImageConfig", + "InterpImageTask", +) + from contextlib import contextmanager +from itertools import product +from typing import Iterable + import lsst.pex.config as pexConfig import lsst.geom import lsst.afw.image as afwImage @@ -30,6 +39,7 @@ import lsst.meas.algorithms as measAlg import lsst.pipe.base as pipeBase from lsst.utils.timer import timeMethod +from scipy.interpolate import CloughTocher2DInterpolator class InterpImageConfig(pexConfig.Config): @@ -262,3 +272,192 @@ def interpolateImage(self, maskedImage, psf, defectList, fallbackValue): with self.transposeContext(maskedImage, defectList) as (image, defects): measAlg.interpolateOverDefects(image, psf, defects, fallbackValue, self.config.useFallbackValueAtEdge) + + +class CloughTocher2DInterpolateConfig(pexConfig.Config): + """Config for CloughTocher2DInterpolateTask.""" + + badMaskPlanes = pexConfig.ListField[str]( + doc="List of mask planes to interpolate over.", + default=["BAD", "SAT", "CR"], + ) + fillValue = pexConfig.Field[float]( + doc="Constant value to fill outside of the convex hull of the good " + "pixels. A long (longer than twice the ``interpLength``) streak of " + "bad pixels at an edge will be set to this value.", + default=0.0, + ) + interpLength = pexConfig.Field[int]( + doc="Maximum number of pixels away from a bad pixel to include in " + "building the interpolant. Must be greater than or equal to 1.", + default=4, + check=lambda x: x >= 1, + ) + + +class CloughTocher2DInterpolateTask(pipeBase.Task): + """Interpolated over bad pixels using CloughTocher2DInterpolator. + + Pixels with mask bits set to any of those listed ``badMaskPlanes`` config + are considered bad and are interpolated over. All good (non-bad) pixels + within ``interpLength`` pixels of a bad pixel in either direction are used + to construct the interpolant. An extended streak of bad pixels at an edge, + longer than ``interpLength``, is set to `fillValue`` specified in config. + """ + + ConfigClass = CloughTocher2DInterpolateConfig + _DefaultName = "cloughTocher2DInterpolate" + + def run(self, maskedImage, badpix: set[tuple[int, int]] | None = None, goodpix: dict | None = None): + """Interpolate over bad pixels in a masked image. + + This modifies the ``image`` attribute of the ``maskedImage`` in place. + This method returns, and accepts, the coordinates of the bad pixels + that were interpolated over, and the coordinates and values of the + good pixels that were used to construct the interpolant. This avoids + having to search for the bad and the good pixels repeatedly when the + mask plane is shared among many images, as would be the case with + noise realizations. + + Parameters + ---------- + maskedImage : `~lsst.afw.image.MaskedImage` + Image on which to perform interpolation (and modify in-place). + badpix: `set` [`tuple` [`int`, `int`]], optional + The coordinates of the bad pixels to interpolate over. + If None, then the coordinates of the bad pixels are determined by + an exhaustive search over the image. + goodpix: `dict` [`tuple` [`int`, `int`], `float`], optional + A mapping whose keys are the coordinates of the good pixels around + ``badpix`` that must be included when constructing the + interpolant. If ``badpix`` is provided, then the pixels in + ``goodpix`` are used as to construct the interpolatant. If not, + any additional good pixels around internally determined ``badpix`` + are added to ``goodpix`` and used to construct the interpolant. In + all cases, the values are populated from the image plane of the + ``maskedImage`` (provided values will be ignored. + + Returns + ------- + badpix: `set` [`tuple` [`int`, `int`]] + The coordinates of the bad pixels that were interpolated over. + goodpix: `dict` [`tuple` [`int`, `int`], `float`] + Mapping of the coordinates of the good pixels around ``badpix`` + to their values that were included when constructing the + interpolant. + + Raises + ------ + RuntimeError + If a pixel passed in as ``goodpix`` is found to be bad as specified by + ``maskPlanes``. + ValueError + If an input ``badpix`` is not found to be bad as specified by + ``maskPlanes``. + """ + max_window_extent = lsst.geom.Extent2I( + 2 * self.config.interpLength + 1, 2 * self.config.interpLength + 1 + ) + # Even if badpix and/or goodpix is provided, make sure to update + # the values of goodpix. + badpix, goodpix = find_good_pixels_around_bad_pixels( + maskedImage, + self.config.badMaskPlanes, + max_window_extent=max_window_extent, + badpix=badpix, + goodpix=goodpix, + ) + + # Construct the interpolant. + interpolator = CloughTocher2DInterpolator( + list(goodpix.keys()), + list(goodpix.values()), + fill_value=self.config.fillValue, + ) + + # Fill in the bad pixels. + for x, y in badpix: + maskedImage.image[x, y] = interpolator((x, y)) + + return badpix, goodpix + + +def find_good_pixels_around_bad_pixels( + image: afwImage.MaskedImage, + maskPlanes: Iterable[str], + *, + max_window_extent: lsst.geom.Extent2I, + badpix: set | None = None, + goodpix: dict | None = None, +): + """Find the location of bad pixels, and neighboring good pixels. + + Parameters + ---------- + image : `~lsst.afw.image.MaskedImage` + Image from which to find the bad and the good pixels. + maskPlanes : `list` [`str`] + List of mask planes to consider as bad pixels. + max_window_extent : `lsst.geom.Extent2I` + Maximum extent of the window around a bad pixel to consider when + looking for good pixels. + badpix : `list` [`tuple` [`int`, `int`]], optional + A known list of bad pixels. If provided, the function does not look for + any additional bad pixels, but it verifies that the provided + coordinates correspond to bad pixels. If an input``badpix`` is not + found to be bad as specified by ``maskPlanes``, an exception is raised. + goodpix : `dict` [`tuple` [`int`, `int`], `float`], optional + A known mapping of the coordinates of good pixels to their values, to + which any newly found good pixels locations will be added, and the + values (even for existing items) will be updated. + + Returns + ------- + badpix : `list` [`tuple` [`int`, `int`]] + The coordinates of the bad pixels. If ``badpix`` was provided as an + input argument, the returned quantity is the same as the input. + goodpix : `dict` [`tuple` [`int`, `int`], `float`] + Updated mapping of the coordinates of good pixels to their values. + + Raises + ------ + RuntimeError + If a pixel passed in as ``goodpix`` is found to be bad as specified by + ``maskPlanes``. + ValueError + If an input ``badpix`` is not found to be bad as specified by + ``maskPlanes``. + """ + + bbox = image.getBBox() + if badpix is None: + iterator = product(range(bbox.minX, bbox.maxX + 1), range(bbox.minY, bbox.maxY + 1)) + badpix = set() + else: + iterator = badpix + + if goodpix is None: + goodpix = {} + + for x, y in iterator: + if image.mask[x, y] & afwImage.Mask.getPlaneBitMask(maskPlanes): + if (x, y) in goodpix: + raise RuntimeError(f"Pixel ({x}, {y}) is bad as specified by maskPlanes {maskPlanes} but " + "passed in as goodpix") + badpix.add((x, y)) + window = lsst.geom.Box2I.makeCenteredBox( + center=lsst.geom.Point2D(x, y), # center has to be a Point2D instance. + size=max_window_extent, + ) + # Restrict to the bounding box of the image. + window.clip(bbox) + + for xx, yy in product(range(window.minX, window.maxX + 1), range(window.minY, window.maxY + 1)): + if not (image.mask[xx, yy] & afwImage.Mask.getPlaneBitMask(maskPlanes)): + goodpix[(xx, yy)] = image.image[xx, yy] + elif (x, y) in badpix: + # If (x, y) is in badpix, but did not get flagged as bad, + # raise an exception. + raise ValueError(f"Pixel ({x}, {y}) is not bad as specified by maskPlanes {maskPlanes}") + + return badpix, goodpix