Skip to content

Commit

Permalink
Improve full-visit sky subtraction
Browse files Browse the repository at this point in the history
  • Loading branch information
michitaro committed Mar 13, 2019
1 parent 7ce5656 commit ad0a025
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 49 deletions.
123 changes: 122 additions & 1 deletion python/lsst/pipe/drivers/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import lsst.afw.image as afwImage
import lsst.afw.geom as afwGeom
import lsst.afw.cameraGeom as afwCameraGeom
import lsst.meas.algorithms as measAlg
import lsst.afw.table as afwTable

from lsst.pex.config import Config, Field, ListField, ChoiceField, ConfigField, RangeField
from lsst.pex.config import Config, Field, ListField, ChoiceField, ConfigField, RangeField, ConfigurableField
from lsst.pipe.base import Task


Expand Down Expand Up @@ -466,6 +468,9 @@ class FocalPlaneBackgroundConfig(Config):
"NONE": "No background estimation is to be attempted",
},
)
doSmooth = Field(dtype=bool, default=False, doc="Do smoothing?")
smoothScale = Field(dtype=float, doc="Smoothing scale")
smoothWindowSize = Field(dtype=int, default=15, doc="Window size for smoothing")
binning = Field(dtype=int, default=64, doc="Binning to use for CCD background model (pixels)")


Expand Down Expand Up @@ -719,7 +724,123 @@ def getStatsImage(self):
values /= self._numbers
thresh = self.config.minFrac*self.config.xSize*self.config.ySize
isBad = self._numbers.getArray() < thresh
if self.config.doSmooth:
array = values.getArray()
array[isBad] = numpy.nan
gridSize = min(self.config.xSize, self.config.ySize)
array[:] = NanSafeSmoothing.gaussianSmoothing(array, self.config.smoothWindowSize, self.config.smoothScale / gridSize)
isBad = numpy.isnan(values.array)
interpolateBadPixels(values.getArray(), isBad, self.config.interpolation)
return values


class MaskObjectsConfig(Config):
"""Configuration for MaskObjectsTask"""
nIter = Field(doc="Iteration for masking", dtype=int, default=3)
subtractBackground = ConfigurableField(target=measAlg.SubtractBackgroundTask, doc='Background configuration')
detection = ConfigurableField(target=measAlg.SourceDetectionTask, doc="Detection configuration")
detectSigma = Field(dtype=float, default=5., doc='Detection PSF gaussian sigmas')
doInterpolate = Field(dtype=bool, default=True, doc='Interpolate masked region?')
interpolate = ConfigurableField(target=measAlg.SubtractBackgroundTask, doc='Interpolate configuration')

def setDefaults(self):
self.detection.reEstimateBackground = False
self.detection.doTempLocalBackground = False
self.detection.doTempWideBackground = False
self.detection.thresholdValue = 2.5
# self.detection.thresholdPolarity = "both"
self.subtractBackground.binSize = 1024
self.subtractBackground.useApprox = False
self.interpolate.binSize = 256
self.interpolate.useApprox = False

def validate(self):
assert not self.detection.reEstimateBackground
assert not self.detection.doTempLocalBackground
assert not self.detection.doTempWideBackground


class MaskObjectsTask(Task):
"""MaskObjectsTask
This task makes more exhaustive object mask by iteratively doing detection and background-subtraction.
The purpose of this task is to get true background removing faint tails of large objects.
This is useful to make clean SKY from relatively small number of visits.
We deliberately use the specified 'detectSigma' instead of the PSF,
in order to better pick up the faint wings of objects.
"""
ConfigClass = MaskObjectsConfig

def __init__(self, *args, **kwargs):
super(MaskObjectsTask, self).__init__(*args, **kwargs)
# Disposable schema suppresses warning from SourceDetectionTask.__init__
self.makeSubtask("detection", schema=afwTable.Schema())
self.makeSubtask('interpolate')
self.makeSubtask('subtractBackground')

def run(self, exp):
for i in range(self.config.nIter):
self.log.info("Masking %d/%d", i + 1, self.config.nIter)
bg = self.subtractBackground.run(exp).background
fp = self.detection.detectFootprints(exp, sigma=self.config.detectSigma, clearMask=True)
exp.maskedImage += bg.getImage()

if self.config.doInterpolate:
self.log.info("Interpolating")
smooth = self.interpolate.run(exp).background
exp.maskedImage += smooth.getImage()
mask = exp.maskedImage.mask
detected = mask.array & mask.getPlaneBitMask(['DETECTED']) > 0
exp.maskedImage.image.array[detected] = smooth.getImage().getArray()[detected]


class NanSafeSmoothing:
'''
Smooth image dealing with NaN pixels
'''

@classmethod
def gaussianSmoothing(cls, array, windowSize, sigma):
return cls._safeConvolve2d(array, cls._gaussianKernel(windowSize, sigma))

@classmethod
def _gaussianKernel(cls, windowSize, sigma):
''' Returns 2D gaussian kernel '''
s = sigma
r = windowSize
X, Y = numpy.meshgrid(
numpy.linspace(-r, r, 2 * r + 1),
numpy.linspace(-r, r, 2 * r + 1),
)
kernel = cls._normalDist(X, s) * cls._normalDist(Y, s)
# cut off
kernel[X**2 + Y**2 > (r + 0.5)**2] = 0.
return kernel / kernel.sum()

@staticmethod
def _normalDist(x, s=1., m=0.):
''' Normal Distribution '''
return 1. / (s * numpy.sqrt(2. * numpy.pi)) * numpy.exp(-(x-m)**2/(2*s**2)) / (s * numpy.sqrt(2*numpy.pi))

@staticmethod
def _safeConvolve2d(image, kernel):
''' Convolve 2D safely dealing with NaNs in `image` '''
assert numpy.ndim(image) == 2
assert numpy.ndim(kernel) == 2
assert kernel.shape[0] % 2 == 1 and kernel.shape[1] % 2 == 1
ks = kernel.shape
kl = (ks[0] - 1) // 2, \
(ks[1] - 1) // 2
image2 = numpy.pad(image, ((kl[0], kl[0]), (kl[1], kl[1])), 'constant', constant_values=numpy.nan)
convolved = numpy.empty_like(image)
convolved.fill(numpy.nan)
for yi in range(convolved.shape[0]):
for xi in range(convolved.shape[1]):
patch = image2[yi : yi + ks[0], xi : xi + ks[1]]
c = patch * kernel
ok = numpy.isfinite(c)
if numpy.any(ok):
convolved[yi, xi] = numpy.nansum(c) / kernel[ok].sum()
return convolved

48 changes: 21 additions & 27 deletions python/lsst/pipe/drivers/constructCalibs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from builtins import zip
from builtins import range

from lsst.pex.config import Config, ConfigurableField, Field, ListField, ConfigField
from lsst.pex.config import Config, ConfigurableField, Field, ListField, ConfigField, ConfigurableField
from lsst.pipe.base import Task, Struct, TaskRunner, ArgumentParser
import lsst.daf.base as dafBase
import lsst.afw.math as afwMath
Expand All @@ -26,7 +26,7 @@

from lsst.ctrl.pool.parallel import BatchPoolTask
from lsst.ctrl.pool.pool import Pool, NODE
from lsst.pipe.drivers.background import SkyMeasurementTask, FocalPlaneBackground, FocalPlaneBackgroundConfig
from lsst.pipe.drivers.background import SkyMeasurementTask, FocalPlaneBackground, FocalPlaneBackgroundConfig, MaskObjectsTask
from lsst.pipe.drivers.visualizeVisit import makeCameraImage

from .checksum import checksum
Expand Down Expand Up @@ -1066,7 +1066,6 @@ def setDefaults(self):
CalibConfig.setDefaults(self)
self.detection.reEstimateBackground = False


class FringeTask(CalibTask):
"""Fringe construction task
Expand Down Expand Up @@ -1115,10 +1114,9 @@ def processSingle(self, sensorRef):

class SkyConfig(CalibConfig):
"""Configuration for sky frame construction"""
detection = ConfigurableField(target=measAlg.SourceDetectionTask, doc="Detection configuration")
detectSigma = Field(dtype=float, default=2.0, doc="Detection PSF gaussian sigma")
subtractBackground = ConfigurableField(target=measAlg.SubtractBackgroundTask,
doc="Regular-scale background configuration, for object detection")
maskObjects = ConfigurableField(target=MaskObjectsTask,
doc="Configuration for masking objects aggressively")
largeScaleBackground = ConfigField(dtype=FocalPlaneBackgroundConfig,
doc="Large-scale background configuration")
sky = ConfigurableField(target=SkyMeasurementTask, doc="Sky measurement")
Expand All @@ -1145,8 +1143,7 @@ class SkyTask(CalibTask):

def __init__(self, *args, **kwargs):
CalibTask.__init__(self, *args, **kwargs)
self.makeSubtask("detection")
self.makeSubtask("subtractBackground")
self.makeSubtask("maskObjects")
self.makeSubtask("sky")

def scatterProcess(self, pool, ccdIdLists):
Expand Down Expand Up @@ -1190,6 +1187,12 @@ def scatterProcess(self, pool, ccdIdLists):
backgrounds[visit] = bgModel
scales[visit] = np.median(bgModel.getStatsImage().getArray())

# for debug
if False:
butler = pool._pool._store['process']['butler']
outputDir = butler._initArgs['outputs']['root']
bgModel.getStatsImage().writeFits(f"{outputDir}/bgModel-{'-'.join(map(str, visit))}.fits")

return mapToMatrix(pool, self.process, ccdIdLists, backgrounds=backgrounds, scales=scales)

def measureBackground(self, cache, dataId):
Expand Down Expand Up @@ -1230,27 +1233,18 @@ def processSingleBackground(self, dataRef):
return dataRef.get("postISRCCD")
exposure = CalibTask.processSingle(self, dataRef)

# Detect sources. Requires us to remove the background; we'll restore it later.
bgTemp = self.subtractBackground.run(exposure).background
footprints = self.detection.detectFootprints(exposure, sigma=self.config.detectSigma)
image = exposure.getMaskedImage()
if footprints.background is not None:
image += footprints.background.getImage()

# Mask high pixels
variance = image.getVariance()
noise = np.sqrt(np.median(variance.getArray()))
isHigh = image.getImage().getArray() > self.config.maskThresh*noise
image.getMask().getArray()[isHigh] |= image.getMask().getPlaneBitMask("DETECTED")

# Restore the background: it's what we want!
image += bgTemp.getImage()
self.maskObjects.run(exposure)
mi = exposure.maskedImage

# Set detected/bad pixels to background to ensure they don't corrupt the background
maskVal = image.getMask().getPlaneBitMask(self.config.mask)
isBad = image.getMask().getArray() & maskVal > 0
bgLevel = np.median(image.getImage().getArray()[~isBad])
image.getImage().getArray()[isBad] = bgLevel
if self.config.maskObjects.doInterpolate:
mi.mask.array &= ~mi.mask.getPlaneBitMask(['DETECTED'])
else:
maskVal = mi.mask.getPlaneBitMask(self.config.mask)
isBad = mi.mask.array & maskVal > 0
bgLevel = np.median(mi.image.array[~isBad])
mi.image.array[isBad] = bgLevel

dataRef.put(exposure, "postISRCCD")
return exposure

Expand Down
57 changes: 36 additions & 21 deletions python/lsst/pipe/drivers/skyCorrection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from lsst.pex.config import Config, Field, ConfigurableField, ConfigField
from lsst.ctrl.pool.pool import Pool
from lsst.ctrl.pool.parallel import BatchPoolTask
from lsst.pipe.drivers.background import SkyMeasurementTask, FocalPlaneBackground, FocalPlaneBackgroundConfig
from lsst.pipe.drivers.background import SkyMeasurementTask, FocalPlaneBackground, FocalPlaneBackgroundConfig, MaskObjectsTask
import lsst.pipe.drivers.visualizeVisit as visualizeVisit

DEBUG = False # Debugging outputs?
Expand Down Expand Up @@ -41,21 +41,21 @@ def makeCameraImage(camera, exposures, filename=None, binning=8):
class SkyCorrectionConfig(Config):
"""Configuration for SkyCorrectionTask"""
bgModel = ConfigField(dtype=FocalPlaneBackgroundConfig, doc="Background model")
bgModel2 = ConfigField(dtype=FocalPlaneBackgroundConfig, doc="2nd Background model")
sky = ConfigurableField(target=SkyMeasurementTask, doc="Sky measurement")
detection = ConfigurableField(target=measAlg.SourceDetectionTask, doc="Detection configuration")
doDetection = Field(dtype=bool, default=True, doc="Detect sources (to find good sky)?")
detectSigma = Field(dtype=float, default=5.0, doc="Detection PSF gaussian sigma")
maskObjects = ConfigurableField(target=MaskObjectsTask, doc="Mask Objects")
doMaskObjects = Field(dtype=bool, default=True, doc="Mask objects to find good sky?")
doBgModel = Field(dtype=bool, default=True, doc="Do background model subtraction?")
doSky = Field(dtype=bool, default=True, doc="Do sky frame subtraction?")
binning = Field(dtype=int, default=8, doc="Binning factor for constructing focal-plane images")

def setDefaults(self):
Config.setDefaults(self)
self.detection.reEstimateBackground = False
self.detection.thresholdPolarity = "both"
self.detection.doTempLocalBackground = False
self.detection.thresholdType = "pixel_stdev"
self.detection.thresholdValue = 3.0
self.maskObjects.doInterpolate = False
self.bgModel2.doSmooth = True

def validate(self):
assert not self.maskObjects.doInterpolate

class SkyCorrectionTask(BatchPoolTask):
"""Correct sky over entire focal plane"""
Expand All @@ -64,9 +64,8 @@ class SkyCorrectionTask(BatchPoolTask):

def __init__(self, *args, **kwargs):
BatchPoolTask.__init__(self, *args, **kwargs)
self.makeSubtask("maskObjects")
self.makeSubtask("sky")
# Disposable schema suppresses warning from SourceDetectionTask.__init__
self.makeSubtask("detection", schema=afwTable.Schema())

@classmethod
def _makeArgumentParser(cls, *args, **kwargs):
Expand Down Expand Up @@ -102,7 +101,10 @@ def run(self, expRef):
algorithms. We optionally apply:
1. A large-scale background model.
This step removes very-large-scale sky such as moonlight.
2. A sky frame.
3. A medium-scale background model.
This step removes residual sky (This is smooth on the focal plane).
Only the master node executes this method. The data is held on
the slave nodes, which do all the hard work.
Expand All @@ -112,6 +114,7 @@ def run(self, expRef):
expRef : `lsst.daf.persistence.ButlerDataRef`
Data reference for exposure.
"""

if DEBUG:
extension = "-%(visit)d.fits" % expRef.dataId

Expand Down Expand Up @@ -159,13 +162,31 @@ def run(self, expRef):
calibs = pool.mapToPrevious(self.collectSky, dataIdList)
makeCameraImage(camera, calibs, "sky" + extension)

exposures = self.smoothFocalPlaneSubtraction(camera, pool, dataIdList)

# Persist camera-level image of calexp
image = makeCameraImage(camera, exposures)
expRef.put(image, "calexp_camera")

pool.mapToPrevious(self.write, dataIdList)

def loadImage(self, cache, dataId):
def smoothFocalPlaneSubtraction(self, camera, pool, dataIdList):
'''Do 2nd Focal Plane subtraction
After doSky, we get smooth focal plane image.
(Before doSky, sky pistons remain in HSC-G)
Now make smooth focal plane background and subtract it.
'''
bgModel = FocalPlaneBackground.fromCamera(self.config.bgModel2, camera)
data = [Struct(dataId=dataId, bgModel=bgModel.clone()) for dataId in dataIdList]
bgModelList = pool.mapToPrevious(self.accumulateModel, data)
for ii, bg in enumerate(bgModelList):
self.log.info("Background %d: %d pixels", ii, bg._numbers.array.sum())
bgModel.merge(bg)
exposures = pool.mapToPrevious(self.subtractModel, dataIdList, bgModel)
return exposures

def loadImage(self, cache, dataId, outputDir):
"""Load original image and restore the sky
This method runs on the slave nodes.
Expand All @@ -187,15 +208,6 @@ def loadImage(self, cache, dataId):
bgOld = cache.butler.get("calexpBackground", dataId, immediate=True)
image = cache.exposure.getMaskedImage()

if self.config.doDetection:
# We deliberately use the specified 'detectSigma' instead of the PSF, in order to better pick up
# the faint wings of objects.
results = self.detection.detectFootprints(cache.exposure, doSmooth=True,
sigma=self.config.detectSigma, clearMask=True)
if hasattr(results, "background") and results.background:
# Restore any background that was removed during detection
maskedImage += results.background.getImage()

# We're removing the old background, so change the sense of all its components
for bgData in bgOld:
statsImage = bgData[0].getStatsImage()
Expand All @@ -206,6 +218,9 @@ def loadImage(self, cache, dataId):
for bgData in bgOld:
cache.bgList.append(bgData)

if self.config.doMaskObjects:
self.maskObjects.run(cache.exposure)

return self.collect(cache)

def measureSkyFrame(self, cache, dataId):
Expand Down

0 comments on commit ad0a025

Please sign in to comment.