From ad0a0259f67244ac9d528a6072bff2387fe376c6 Mon Sep 17 00:00:00 2001 From: "michitaro.koike" Date: Wed, 13 Feb 2019 06:18:05 +0900 Subject: [PATCH] Improve full-visit sky subtraction https://jira.lsstcorp.org/browse/DM-17426 --- python/lsst/pipe/drivers/background.py | 123 +++++++++++++++++++- python/lsst/pipe/drivers/constructCalibs.py | 48 ++++---- python/lsst/pipe/drivers/skyCorrection.py | 57 +++++---- 3 files changed, 179 insertions(+), 49 deletions(-) diff --git a/python/lsst/pipe/drivers/background.py b/python/lsst/pipe/drivers/background.py index d33e652..61e41bd 100644 --- a/python/lsst/pipe/drivers/background.py +++ b/python/lsst/pipe/drivers/background.py @@ -7,8 +7,10 @@ import lsst.afw.image as afwImage import lsst.afw.geom as afwGeom import lsst.afw.cameraGeom as afwCameraGeom +import lsst.meas.algorithms as measAlg +import lsst.afw.table as afwTable -from lsst.pex.config import Config, Field, ListField, ChoiceField, ConfigField, RangeField +from lsst.pex.config import Config, Field, ListField, ChoiceField, ConfigField, RangeField, ConfigurableField from lsst.pipe.base import Task @@ -466,6 +468,9 @@ class FocalPlaneBackgroundConfig(Config): "NONE": "No background estimation is to be attempted", }, ) + doSmooth = Field(dtype=bool, default=False, doc="Do smoothing?") + smoothScale = Field(dtype=float, doc="Smoothing scale") + smoothWindowSize = Field(dtype=int, default=15, doc="Window size for smoothing") binning = Field(dtype=int, default=64, doc="Binning to use for CCD background model (pixels)") @@ -719,7 +724,123 @@ def getStatsImage(self): values /= self._numbers thresh = self.config.minFrac*self.config.xSize*self.config.ySize isBad = self._numbers.getArray() < thresh + if self.config.doSmooth: + array = values.getArray() + array[isBad] = numpy.nan + gridSize = min(self.config.xSize, self.config.ySize) + array[:] = NanSafeSmoothing.gaussianSmoothing(array, self.config.smoothWindowSize, self.config.smoothScale / gridSize) + isBad = numpy.isnan(values.array) interpolateBadPixels(values.getArray(), isBad, self.config.interpolation) return values +class MaskObjectsConfig(Config): + """Configuration for MaskObjectsTask""" + nIter = Field(doc="Iteration for masking", dtype=int, default=3) + subtractBackground = ConfigurableField(target=measAlg.SubtractBackgroundTask, doc='Background configuration') + detection = ConfigurableField(target=measAlg.SourceDetectionTask, doc="Detection configuration") + detectSigma = Field(dtype=float, default=5., doc='Detection PSF gaussian sigmas') + doInterpolate = Field(dtype=bool, default=True, doc='Interpolate masked region?') + interpolate = ConfigurableField(target=measAlg.SubtractBackgroundTask, doc='Interpolate configuration') + + def setDefaults(self): + self.detection.reEstimateBackground = False + self.detection.doTempLocalBackground = False + self.detection.doTempWideBackground = False + self.detection.thresholdValue = 2.5 + # self.detection.thresholdPolarity = "both" + self.subtractBackground.binSize = 1024 + self.subtractBackground.useApprox = False + self.interpolate.binSize = 256 + self.interpolate.useApprox = False + + def validate(self): + assert not self.detection.reEstimateBackground + assert not self.detection.doTempLocalBackground + assert not self.detection.doTempWideBackground + + +class MaskObjectsTask(Task): + """MaskObjectsTask + + This task makes more exhaustive object mask by iteratively doing detection and background-subtraction. + The purpose of this task is to get true background removing faint tails of large objects. + This is useful to make clean SKY from relatively small number of visits. + + We deliberately use the specified 'detectSigma' instead of the PSF, + in order to better pick up the faint wings of objects. + """ + ConfigClass = MaskObjectsConfig + + def __init__(self, *args, **kwargs): + super(MaskObjectsTask, self).__init__(*args, **kwargs) + # Disposable schema suppresses warning from SourceDetectionTask.__init__ + self.makeSubtask("detection", schema=afwTable.Schema()) + self.makeSubtask('interpolate') + self.makeSubtask('subtractBackground') + + def run(self, exp): + for i in range(self.config.nIter): + self.log.info("Masking %d/%d", i + 1, self.config.nIter) + bg = self.subtractBackground.run(exp).background + fp = self.detection.detectFootprints(exp, sigma=self.config.detectSigma, clearMask=True) + exp.maskedImage += bg.getImage() + + if self.config.doInterpolate: + self.log.info("Interpolating") + smooth = self.interpolate.run(exp).background + exp.maskedImage += smooth.getImage() + mask = exp.maskedImage.mask + detected = mask.array & mask.getPlaneBitMask(['DETECTED']) > 0 + exp.maskedImage.image.array[detected] = smooth.getImage().getArray()[detected] + + +class NanSafeSmoothing: + ''' + Smooth image dealing with NaN pixels + ''' + + @classmethod + def gaussianSmoothing(cls, array, windowSize, sigma): + return cls._safeConvolve2d(array, cls._gaussianKernel(windowSize, sigma)) + + @classmethod + def _gaussianKernel(cls, windowSize, sigma): + ''' Returns 2D gaussian kernel ''' + s = sigma + r = windowSize + X, Y = numpy.meshgrid( + numpy.linspace(-r, r, 2 * r + 1), + numpy.linspace(-r, r, 2 * r + 1), + ) + kernel = cls._normalDist(X, s) * cls._normalDist(Y, s) + # cut off + kernel[X**2 + Y**2 > (r + 0.5)**2] = 0. + return kernel / kernel.sum() + + @staticmethod + def _normalDist(x, s=1., m=0.): + ''' Normal Distribution ''' + return 1. / (s * numpy.sqrt(2. * numpy.pi)) * numpy.exp(-(x-m)**2/(2*s**2)) / (s * numpy.sqrt(2*numpy.pi)) + + @staticmethod + def _safeConvolve2d(image, kernel): + ''' Convolve 2D safely dealing with NaNs in `image` ''' + assert numpy.ndim(image) == 2 + assert numpy.ndim(kernel) == 2 + assert kernel.shape[0] % 2 == 1 and kernel.shape[1] % 2 == 1 + ks = kernel.shape + kl = (ks[0] - 1) // 2, \ + (ks[1] - 1) // 2 + image2 = numpy.pad(image, ((kl[0], kl[0]), (kl[1], kl[1])), 'constant', constant_values=numpy.nan) + convolved = numpy.empty_like(image) + convolved.fill(numpy.nan) + for yi in range(convolved.shape[0]): + for xi in range(convolved.shape[1]): + patch = image2[yi : yi + ks[0], xi : xi + ks[1]] + c = patch * kernel + ok = numpy.isfinite(c) + if numpy.any(ok): + convolved[yi, xi] = numpy.nansum(c) / kernel[ok].sum() + return convolved + diff --git a/python/lsst/pipe/drivers/constructCalibs.py b/python/lsst/pipe/drivers/constructCalibs.py index 1ca05fa..2d659a1 100644 --- a/python/lsst/pipe/drivers/constructCalibs.py +++ b/python/lsst/pipe/drivers/constructCalibs.py @@ -11,7 +11,7 @@ from builtins import zip from builtins import range -from lsst.pex.config import Config, ConfigurableField, Field, ListField, ConfigField +from lsst.pex.config import Config, ConfigurableField, Field, ListField, ConfigField, ConfigurableField from lsst.pipe.base import Task, Struct, TaskRunner, ArgumentParser import lsst.daf.base as dafBase import lsst.afw.math as afwMath @@ -26,7 +26,7 @@ from lsst.ctrl.pool.parallel import BatchPoolTask from lsst.ctrl.pool.pool import Pool, NODE -from lsst.pipe.drivers.background import SkyMeasurementTask, FocalPlaneBackground, FocalPlaneBackgroundConfig +from lsst.pipe.drivers.background import SkyMeasurementTask, FocalPlaneBackground, FocalPlaneBackgroundConfig, MaskObjectsTask from lsst.pipe.drivers.visualizeVisit import makeCameraImage from .checksum import checksum @@ -1066,7 +1066,6 @@ def setDefaults(self): CalibConfig.setDefaults(self) self.detection.reEstimateBackground = False - class FringeTask(CalibTask): """Fringe construction task @@ -1115,10 +1114,9 @@ def processSingle(self, sensorRef): class SkyConfig(CalibConfig): """Configuration for sky frame construction""" - detection = ConfigurableField(target=measAlg.SourceDetectionTask, doc="Detection configuration") detectSigma = Field(dtype=float, default=2.0, doc="Detection PSF gaussian sigma") - subtractBackground = ConfigurableField(target=measAlg.SubtractBackgroundTask, - doc="Regular-scale background configuration, for object detection") + maskObjects = ConfigurableField(target=MaskObjectsTask, + doc="Configuration for masking objects aggressively") largeScaleBackground = ConfigField(dtype=FocalPlaneBackgroundConfig, doc="Large-scale background configuration") sky = ConfigurableField(target=SkyMeasurementTask, doc="Sky measurement") @@ -1145,8 +1143,7 @@ class SkyTask(CalibTask): def __init__(self, *args, **kwargs): CalibTask.__init__(self, *args, **kwargs) - self.makeSubtask("detection") - self.makeSubtask("subtractBackground") + self.makeSubtask("maskObjects") self.makeSubtask("sky") def scatterProcess(self, pool, ccdIdLists): @@ -1190,6 +1187,12 @@ def scatterProcess(self, pool, ccdIdLists): backgrounds[visit] = bgModel scales[visit] = np.median(bgModel.getStatsImage().getArray()) + # for debug + if False: + butler = pool._pool._store['process']['butler'] + outputDir = butler._initArgs['outputs']['root'] + bgModel.getStatsImage().writeFits(f"{outputDir}/bgModel-{'-'.join(map(str, visit))}.fits") + return mapToMatrix(pool, self.process, ccdIdLists, backgrounds=backgrounds, scales=scales) def measureBackground(self, cache, dataId): @@ -1230,27 +1233,18 @@ def processSingleBackground(self, dataRef): return dataRef.get("postISRCCD") exposure = CalibTask.processSingle(self, dataRef) - # Detect sources. Requires us to remove the background; we'll restore it later. - bgTemp = self.subtractBackground.run(exposure).background - footprints = self.detection.detectFootprints(exposure, sigma=self.config.detectSigma) - image = exposure.getMaskedImage() - if footprints.background is not None: - image += footprints.background.getImage() - - # Mask high pixels - variance = image.getVariance() - noise = np.sqrt(np.median(variance.getArray())) - isHigh = image.getImage().getArray() > self.config.maskThresh*noise - image.getMask().getArray()[isHigh] |= image.getMask().getPlaneBitMask("DETECTED") - - # Restore the background: it's what we want! - image += bgTemp.getImage() + self.maskObjects.run(exposure) + mi = exposure.maskedImage # Set detected/bad pixels to background to ensure they don't corrupt the background - maskVal = image.getMask().getPlaneBitMask(self.config.mask) - isBad = image.getMask().getArray() & maskVal > 0 - bgLevel = np.median(image.getImage().getArray()[~isBad]) - image.getImage().getArray()[isBad] = bgLevel + if self.config.maskObjects.doInterpolate: + mi.mask.array &= ~mi.mask.getPlaneBitMask(['DETECTED']) + else: + maskVal = mi.mask.getPlaneBitMask(self.config.mask) + isBad = mi.mask.array & maskVal > 0 + bgLevel = np.median(mi.image.array[~isBad]) + mi.image.array[isBad] = bgLevel + dataRef.put(exposure, "postISRCCD") return exposure diff --git a/python/lsst/pipe/drivers/skyCorrection.py b/python/lsst/pipe/drivers/skyCorrection.py index 86b9f13..62880bc 100644 --- a/python/lsst/pipe/drivers/skyCorrection.py +++ b/python/lsst/pipe/drivers/skyCorrection.py @@ -12,7 +12,7 @@ from lsst.pex.config import Config, Field, ConfigurableField, ConfigField from lsst.ctrl.pool.pool import Pool from lsst.ctrl.pool.parallel import BatchPoolTask -from lsst.pipe.drivers.background import SkyMeasurementTask, FocalPlaneBackground, FocalPlaneBackgroundConfig +from lsst.pipe.drivers.background import SkyMeasurementTask, FocalPlaneBackground, FocalPlaneBackgroundConfig, MaskObjectsTask import lsst.pipe.drivers.visualizeVisit as visualizeVisit DEBUG = False # Debugging outputs? @@ -41,21 +41,21 @@ def makeCameraImage(camera, exposures, filename=None, binning=8): class SkyCorrectionConfig(Config): """Configuration for SkyCorrectionTask""" bgModel = ConfigField(dtype=FocalPlaneBackgroundConfig, doc="Background model") + bgModel2 = ConfigField(dtype=FocalPlaneBackgroundConfig, doc="2nd Background model") sky = ConfigurableField(target=SkyMeasurementTask, doc="Sky measurement") - detection = ConfigurableField(target=measAlg.SourceDetectionTask, doc="Detection configuration") - doDetection = Field(dtype=bool, default=True, doc="Detect sources (to find good sky)?") - detectSigma = Field(dtype=float, default=5.0, doc="Detection PSF gaussian sigma") + maskObjects = ConfigurableField(target=MaskObjectsTask, doc="Mask Objects") + doMaskObjects = Field(dtype=bool, default=True, doc="Mask objects to find good sky?") doBgModel = Field(dtype=bool, default=True, doc="Do background model subtraction?") doSky = Field(dtype=bool, default=True, doc="Do sky frame subtraction?") binning = Field(dtype=int, default=8, doc="Binning factor for constructing focal-plane images") def setDefaults(self): Config.setDefaults(self) - self.detection.reEstimateBackground = False - self.detection.thresholdPolarity = "both" - self.detection.doTempLocalBackground = False - self.detection.thresholdType = "pixel_stdev" - self.detection.thresholdValue = 3.0 + self.maskObjects.doInterpolate = False + self.bgModel2.doSmooth = True + + def validate(self): + assert not self.maskObjects.doInterpolate class SkyCorrectionTask(BatchPoolTask): """Correct sky over entire focal plane""" @@ -64,9 +64,8 @@ class SkyCorrectionTask(BatchPoolTask): def __init__(self, *args, **kwargs): BatchPoolTask.__init__(self, *args, **kwargs) + self.makeSubtask("maskObjects") self.makeSubtask("sky") - # Disposable schema suppresses warning from SourceDetectionTask.__init__ - self.makeSubtask("detection", schema=afwTable.Schema()) @classmethod def _makeArgumentParser(cls, *args, **kwargs): @@ -102,7 +101,10 @@ def run(self, expRef): algorithms. We optionally apply: 1. A large-scale background model. + This step removes very-large-scale sky such as moonlight. 2. A sky frame. + 3. A medium-scale background model. + This step removes residual sky (This is smooth on the focal plane). Only the master node executes this method. The data is held on the slave nodes, which do all the hard work. @@ -112,6 +114,7 @@ def run(self, expRef): expRef : `lsst.daf.persistence.ButlerDataRef` Data reference for exposure. """ + if DEBUG: extension = "-%(visit)d.fits" % expRef.dataId @@ -159,13 +162,31 @@ def run(self, expRef): calibs = pool.mapToPrevious(self.collectSky, dataIdList) makeCameraImage(camera, calibs, "sky" + extension) + exposures = self.smoothFocalPlaneSubtraction(camera, pool, dataIdList) + # Persist camera-level image of calexp image = makeCameraImage(camera, exposures) expRef.put(image, "calexp_camera") pool.mapToPrevious(self.write, dataIdList) - def loadImage(self, cache, dataId): + def smoothFocalPlaneSubtraction(self, camera, pool, dataIdList): + '''Do 2nd Focal Plane subtraction + + After doSky, we get smooth focal plane image. + (Before doSky, sky pistons remain in HSC-G) + Now make smooth focal plane background and subtract it. + ''' + bgModel = FocalPlaneBackground.fromCamera(self.config.bgModel2, camera) + data = [Struct(dataId=dataId, bgModel=bgModel.clone()) for dataId in dataIdList] + bgModelList = pool.mapToPrevious(self.accumulateModel, data) + for ii, bg in enumerate(bgModelList): + self.log.info("Background %d: %d pixels", ii, bg._numbers.array.sum()) + bgModel.merge(bg) + exposures = pool.mapToPrevious(self.subtractModel, dataIdList, bgModel) + return exposures + + def loadImage(self, cache, dataId, outputDir): """Load original image and restore the sky This method runs on the slave nodes. @@ -187,15 +208,6 @@ def loadImage(self, cache, dataId): bgOld = cache.butler.get("calexpBackground", dataId, immediate=True) image = cache.exposure.getMaskedImage() - if self.config.doDetection: - # We deliberately use the specified 'detectSigma' instead of the PSF, in order to better pick up - # the faint wings of objects. - results = self.detection.detectFootprints(cache.exposure, doSmooth=True, - sigma=self.config.detectSigma, clearMask=True) - if hasattr(results, "background") and results.background: - # Restore any background that was removed during detection - maskedImage += results.background.getImage() - # We're removing the old background, so change the sense of all its components for bgData in bgOld: statsImage = bgData[0].getStatsImage() @@ -206,6 +218,9 @@ def loadImage(self, cache, dataId): for bgData in bgOld: cache.bgList.append(bgData) + if self.config.doMaskObjects: + self.maskObjects.run(cache.exposure) + return self.collect(cache) def measureSkyFrame(self, cache, dataId):