Skip to content

Commit

Permalink
Refactor AEW
Browse files Browse the repository at this point in the history
  • Loading branch information
aemerywatkins committed Dec 16, 2024
1 parent 52660e1 commit f379682
Showing 1 changed file with 20 additions and 170 deletions.
190 changes: 20 additions & 170 deletions python/lsst/pipe/tasks/matchBackgrounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,12 @@

__all__ = ["MatchBackgroundsConnections", "MatchBackgroundsConfig", "MatchBackgroundsTask"]

import lsstDebug
import numpy as np
from lsst.afw.image import LOCAL, PARENT, ExposureF, ImageF, Mask, MaskedImageF
from lsst.afw.image import LOCAL, ImageF, Mask, MaskedImageF
from lsst.afw.math import (
MEAN,
MEANCLIP,
MEANSQUARE,
MEDIAN,
NPOINT,
STDEV,
VARIANCE,
ApproximateControl,
BackgroundControl,
Expand All @@ -43,7 +39,6 @@
stringToStatisticsProperty,
stringToUndersampleStyle,
)
from lsst.geom import Box2D, Box2I, PointI
from lsst.pex.config import ChoiceField, Field, ListField, RangeField
from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct, TaskError
from lsst.pipe.base.connectionTypes import Input, Output
Expand Down Expand Up @@ -137,7 +132,16 @@ class MatchBackgroundsConfig(PipelineTaskConfig, pipelineConnections=MatchBackgr
)
badMaskPlanes = ListField[str](
doc="Names of mask planes to ignore while estimating the background.",
default=["NO_DATA", "DETECTED", "DETECTED_NEGATIVE", "SAT", "BAD", "INTRP", "CR"],
default=[
"NO_DATA",
"DETECTED",
"DETECTED_NEGATIVE",
"SAT",
"BAD",
"INTRP",
"CR",
"NOT_DEBLENDED",
],
itemCheck=lambda x: x in Mask().getMaskPlaneDict(),
)
gridStatistic = ChoiceField(
Expand Down Expand Up @@ -237,9 +241,6 @@ def __init__(self, *args, **kwargs):
super().__init__(**kwargs)
self.statsFlag = stringToStatisticsProperty(self.config.gridStatistic)
self.statsCtrl = StatisticsControl()
# TODO: Check that setting the mask planes here work - these planes
# can vary from exposure to exposure, I think?
# Aaron: I think only the bit values vary, not the names, which this is referencing.
self.statsCtrl.setAndMask(Mask.getPlaneBitMask(self.config.badMaskPlanes))
self.statsCtrl.setNanSafe(True)
self.statsCtrl.setNumSigmaClip(self.config.numSigmaClip)
Expand Down Expand Up @@ -278,7 +279,7 @@ def run(self, warps):
raise TaskError("No exposures to match")

# Define a reference warp; 'warps' is modified in-place to exclude it
refWarp, refInd = self._defineWarps(warps=warps, refWarpVisit=self.config.refWarpVisit)
refWarp, refInd, bkgd = self._defineWarps(warps=warps, refWarpVisit=self.config.refWarpVisit)

# Images must be scaled to a common ZP
# Converting everything to nJy to accomplish this
Expand All @@ -287,29 +288,11 @@ def run(self, warps):

self.log.info("Matching %d Exposures", numExp)

# Creating a null BackgroundList object by fitting a blank image
statsFlag = stringToStatisticsProperty(self.config.gridStatistic)
self.statsCtrl.setNumSigmaClip(self.config.numSigmaClip)
self.statsCtrl.setNumIter(self.config.numIter)

# TODO: refactor below to construct blank bg model
im = refExposure.getMaskedImage()
blankIm = im.clone()
blankIm.image.array *= 0

width = blankIm.getWidth()
height = blankIm.getHeight()
nx = width // self.config.binSize
if width % self.config.binSize != 0:
nx += 1
ny = height // self.config.binSize
if height % self.config.binSize != 0:
ny += 1

bctrl = BackgroundControl(nx, ny, self.statsCtrl, statsFlag)
bctrl.setUndersampleStyle(self.config.undersampleStyle)

bkgd = makeBackground(blankIm, bctrl)
# Blank ref warp background as reference background
bkgdIm = bkgd.getImageF()
bkgdStatsIm = bkgd.getStatsImage()
bkgdIm *= 0
bkgdStatsIm *= 0
blank = BackgroundList(
(
bkgd,
Expand All @@ -325,7 +308,6 @@ def run(self, warps):
backgroundInfoList = []
matchedImageList = []
for exp in warps:
# TODO: simplify what this prints?
self.log.info(
"Matching background of %s to %s",
exp.dataId,
Expand All @@ -347,7 +329,6 @@ def run(self, warps):
toMatchExposure.image /= instFluxToNanojansky # Back to cts
matchedImageList.append(toMatchExposure)

# TODO: more elegant solution than inserting blank model at ref ind?
backgroundInfoList.insert(refInd, blank)
refExposure.image /= instFluxToNanojanskyRef # Back to cts
matchedImageList.insert(refInd, refExposure)
Expand Down Expand Up @@ -377,6 +358,8 @@ def _defineWarps(self, warps, refWarpVisit=None):
Reference warped exposure.
refWarpIndex : `int`
Index of the reference removed from the list of warps.
warpBg : `~lsst.afw.math.BackgroundMI`
Temporary background model, used to make a blank BG for the ref
Notes
-----
Expand Down Expand Up @@ -454,7 +437,7 @@ def _defineWarps(self, warps, refWarpVisit=None):
ind = np.nanargmin(costFunctionVals)
refWarp = warps.pop(ind)
self.log.info("Using best reference visit %d", refWarp.dataId["visit"])
return refWarp, ind
return refWarp, ind, warpBg

def _makeBackground(self, warp: MaskedImageF, binSize) -> tuple[BackgroundMI, BackgroundControl]:
"""Generate a simple binned background masked image for warped data.
Expand Down Expand Up @@ -528,11 +511,6 @@ def matchBackgrounds(self, refExposure, sciExposure):
model : `~lsst.afw.math.BackgroundMI`
Background model of difference image, reference - science
"""
# TODO: this is deprecated
if lsstDebug.Info(__name__).savefits:
refExposure.writeFits(lsstDebug.Info(__name__).figpath + "refExposure.fits")
sciExposure.writeFits(lsstDebug.Info(__name__).figpath + "sciExposure.fits")

# Check Configs for polynomials:
if self.config.usePolynomial:
x, y = sciExposure.getDimensions()
Expand Down Expand Up @@ -622,17 +600,6 @@ def matchBackgrounds(self, refExposure, sciExposure):
resids = bgZ - modelValueArr
rms = np.sqrt(np.mean(resids[~np.isnan(resids)] ** 2))

# TODO: also deprecated; _gridImage() maybe can go?
if lsstDebug.Info(__name__).savefits:
sciExposure.writeFits(lsstDebug.Info(__name__).figpath + "sciMatchedExposure.fits")

if lsstDebug.Info(__name__).savefig:
bbox = Box2D(refExposure.getMaskedImage().getBBox())
try:
self._debugPlot(bgX, bgY, bgZ, bgdZ, bkgdImage, bbox, modelValueArr, resids)
except Exception as e:
self.log.warning("Debug plot not generated: %s", e)

meanVar = makeStatistics(diffMI.getVariance(), diffMI.getMask(), MEANCLIP, self.statsCtrl).getValue()

diffIm = diffMI.getImage()
Expand All @@ -642,7 +609,6 @@ def matchBackgrounds(self, refExposure, sciExposure):

outBkgd = approx if self.config.usePolynomial else bkgd
# Convert this back into counts
# TODO: is there a one-line way to do this?
statsIm = outBkgd.getStatsImage()
statsIm /= instFluxToNanojansky
bkgdIm = outBkgd.getImageF()
Expand All @@ -667,119 +633,3 @@ def matchBackgrounds(self, refExposure, sciExposure):
False,
)
)

def _debugPlot(self, X, Y, Z, dZ, modelImage, bbox, model, resids):
"""
Consider deleting this entirely
Generate a plot showing the background fit and residuals.
It is called when lsstDebug.Info(__name__).savefig = True.
Saves the fig to lsstDebug.Info(__name__).figpath.
Displays on screen if lsstDebug.Info(__name__).display = True.
Parameters
----------
X : `np.ndarray`, (N,)
Array of x positions.
Y : `np.ndarray`, (N,)
Array of y positions.
Z : `np.ndarray`
Array of the grid values that were interpolated.
dZ : `np.ndarray`, (len(Z),)
Array of the error on the grid values.
modelImage : `Unknown`
Image of the model of the fit.
model : `np.ndarray`, (len(Z),)
Array of len(Z) containing the grid values predicted by the model.
resids : `Unknown`
Z - model.
"""
import matplotlib.colors
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

zeroIm = MaskedImageF(Box2I(bbox))
zeroIm += modelImage
x0, y0 = zeroIm.getXY0()
dx, dy = zeroIm.getDimensions()
if len(Z) == 0:
self.log.warning("No grid. Skipping plot generation.")
else:
max, min = np.max(Z), np.min(Z)
norm = matplotlib.colors.normalize(vmax=max, vmin=min)
maxdiff = np.max(np.abs(resids))
diffnorm = matplotlib.colors.normalize(vmax=maxdiff, vmin=-maxdiff)
rms = np.sqrt(np.mean(resids**2))
fig = plt.figure(1, (8, 6))
meanDz = np.mean(dZ)
grid = ImageGrid(
fig,
111,
nrows_ncols=(1, 2),
axes_pad=0.1,
share_all=True,
label_mode="L",
cbar_mode="each",
cbar_size="7%",
cbar_pad="2%",
cbar_location="top",
)
im = grid[0].imshow(
zeroIm.getImage().getArray(), extent=(x0, x0 + dx, y0 + dy, y0), norm=norm, cmap="Spectral"
)
im = grid[0].scatter(
X, Y, c=Z, s=15.0 * meanDz / dZ, edgecolor="none", norm=norm, marker="o", cmap="Spectral"
)
im2 = grid[1].scatter(X, Y, c=resids, edgecolor="none", norm=diffnorm, marker="s", cmap="seismic")
grid.cbar_axes[0].colorbar(im)
grid.cbar_axes[1].colorbar(im2)
grid[0].axis([x0, x0 + dx, y0 + dy, y0])
grid[1].axis([x0, x0 + dx, y0 + dy, y0])
grid[0].set_xlabel("model and grid")
grid[1].set_xlabel("residuals. rms = %0.3f" % (rms))
if lsstDebug.Info(__name__).savefig:
fig.savefig(lsstDebug.Info(__name__).figpath + self.debugDataIdString + ".png")
if lsstDebug.Info(__name__).display:
plt.show()
plt.clf()

def _gridImage(self, maskedImage, binsize, statsFlag):
"""Private method to grid an image for debugging."""
width, height = maskedImage.getDimensions()
x0, y0 = maskedImage.getXY0()
xedges = np.arange(0, width, binsize)
yedges = np.arange(0, height, binsize)
xedges = np.hstack((xedges, width)) # add final edge
yedges = np.hstack((yedges, height)) # add final edge

# Use lists/append to protect against the case where
# a bin has no valid pixels and should not be included in the fit
bgX = []
bgY = []
bgZ = []
bgdZ = []

for ymin, ymax in zip(yedges[0:-1], yedges[1:]):
for xmin, xmax in zip(xedges[0:-1], xedges[1:]):
subBBox = Box2I(
PointI(int(x0 + xmin), int(y0 + ymin)),
PointI(int(x0 + xmax - 1), int(y0 + ymax - 1)),
)
subIm = MaskedImageF(maskedImage, subBBox, PARENT, False)
stats = makeStatistics(
subIm,
MEAN | MEANCLIP | MEDIAN | NPOINT | STDEV,
self.statsCtrl,
)
npoints, _ = stats.getResult(NPOINT)
if npoints >= 2:
stdev, _ = stats.getResult(STDEV)
if stdev < self.config.gridStdevEpsilon:
stdev = self.config.gridStdevEpsilon
bgX.append(0.5 * (x0 + xmin + x0 + xmax))
bgY.append(0.5 * (y0 + ymin + y0 + ymax))
bgdZ.append(stdev / np.sqrt(npoints))
est, _ = stats.getResult(statsFlag)
bgZ.append(est)

return np.array(bgX), np.array(bgY), np.array(bgZ), np.array(bgdZ)

0 comments on commit f379682

Please sign in to comment.