Skip to content

Commit

Permalink
Update for scarlet lite changes
Browse files Browse the repository at this point in the history
  • Loading branch information
fred3m committed Dec 24, 2024
1 parent 1dcb35c commit d18972c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
27 changes: 23 additions & 4 deletions python/lsst/pipe/tasks/deblendCoaddSourcesPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ class DeblendCoaddSourcesMultiConnections(PipelineTaskConnections,
multiple=True,
dimensions=("tract", "patch", "band", "skymap")
)
deconvolvedCoadds = cT.Input(
doc="Deconvolved coadds",
name="deconvolved_{inputCoaddName}_coadd",
storageClass="ExposureF",
multiple=True,
dimensions=("tract", "patch", "band", "skymap")
)
outputSchema = cT.InitOutput(
doc="Output of the schema used in deblending task",
name="{outputCoaddName}Coadd_deblendedFlux_schema",
Expand Down Expand Up @@ -251,14 +258,26 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputRefs = reorderRefs(inputRefs, bandOrder, dataIdKey="band")
inputs = butlerQC.get(inputRefs)
inputs["idFactory"] = self.config.idGenerator.apply(butlerQC.quantum.dataId).make_table_id_factory()
inputs["filters"] = [dRef.dataId["band"] for dRef in inputRefs.coadds]

# Ensure that the coadd bands and deconvolved coadd bands match
bands = [dRef.dataId["band"] for dRef in inputRefs.coadds]
deconvBands = [dRef.dataId["band"] for dRef in inputRefs.deconvolvedCoadds]
if len(bands) != len(deconvBands):
raise RuntimeError("Number of coadd bands and deconvolved coadd bands do not match")

for band, deconvBand in zip(bands, deconvBands):
if band != deconvBand:
raise RuntimeError(f"Bands {band} and {deconvBand} do not match")

inputs["bands"] = [dRef.dataId["band"] for dRef in inputRefs.coadds]
outputs = self.run(**inputs)
butlerQC.put(outputs, outputRefs)

def run(self, coadds, filters, mergedDetections, idFactory):
def run(self, coadds, bands, mergedDetections, deconvolvedCoadds, idFactory):
sources = self._makeSourceCatalog(mergedDetections, idFactory)
multiExposure = afwImage.MultibandExposure.fromExposures(filters, coadds)
catalog, modelData = self.multibandDeblend.run(multiExposure, sources)
multiExposure = afwImage.MultibandExposure.fromExposures(bands, coadds)
mDeconvolved = afwImage.MultibandExposure.fromExposures(bands, deconvolvedCoadds)
catalog, modelData = self.multibandDeblend.run(multiExposure, mDeconvolved, sources)
retStruct = Struct(deblendedCatalog=catalog, scarletModelData=modelData)
return retStruct

Expand Down
10 changes: 9 additions & 1 deletion tests/test_isPrimaryFlag.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from lsst.meas.algorithms import SourceDetectionTask, SkyObjectsTask, SetPrimaryFlagsTask
import lsst.meas.extensions.scarlet as mes
from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask
from lsst.meas.extensions.scarlet.deconvolveExposureTask import DeconvolveExposureTask
from lsst.meas.base import SingleFrameMeasurementTask
from lsst.afw.table import SourceCatalog

Expand Down Expand Up @@ -214,6 +215,10 @@ def testIsScarletPrimaryFlag(self):
skySourcesTask = SkyObjectsTask(name="skySources", config=skyConfig)
schema.addField("merge_peak_sky", type="Flag")

# Initialize the deconvolution task
deconvolveConfig = DeconvolveExposureTask.ConfigClass()
deconvolveTask = DeconvolveExposureTask(config=deconvolveConfig)

# Initialize the deblender task
scarletConfig = ScarletDeblendTask.ConfigClass()
scarletConfig.maxIter = 20
Expand Down Expand Up @@ -243,8 +248,11 @@ def testIsScarletPrimaryFlag(self):
src = catalog.addNew()
src.setFootprint(foot)
src.set("merge_peak_sky", True)
# deconvolve the images
deconvolved = deconvolveTask.run(coadds["test"], catalog).deconvolved
mDeconvolved = afwImage.MultibandExposure.fromExposures(["test"], [deconvolved])
# deblend
catalog, modelData = deblendTask.run(coadds, catalog)
catalog, modelData = deblendTask.run(coadds, mDeconvolved, catalog)
# Attach footprints to the catalog
mes.io.updateCatalogFootprints(
modelData=modelData,
Expand Down

0 comments on commit d18972c

Please sign in to comment.