Skip to content

Commit

Permalink
Modify region definition config in extended_psf
Browse files Browse the repository at this point in the history
  • Loading branch information
bazkiaei authored and leeskelvin committed Oct 24, 2023
1 parent dacf46b commit fb1e298
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 62 deletions.
145 changes: 92 additions & 53 deletions python/lsst/pipe/tasks/extended_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,61 @@
"StackBrightStarsTask",
"MeasureExtendedPsfConfig",
"MeasureExtendedPsfTask",
"DetectorsInRegion",
]

from dataclasses import dataclass
from typing import List

from lsst.afw.fits import Fits, readMetadata
from lsst.afw.image import ImageF, MaskedImageF, MaskX
from lsst.afw.math import StatisticsControl, statisticsStack, stringToStatisticsProperty
from lsst.daf.base import PropertyList
from lsst.geom import Extent2I
from lsst.pex.config import ChoiceField, Config, ConfigurableField, DictField, Field, ListField
from lsst.pex.config import ChoiceField, Config, ConfigDictField, ConfigurableField, Field, ListField
from lsst.pipe.base import PipelineTaskConfig, PipelineTaskConnections, Struct, Task
from lsst.pipe.tasks.coaddBase import subBBoxIter
from lsst.pipe.base.connectionTypes import Input, Output
from lsst.pipe.tasks.coaddBase import subBBoxIter


def find_region_for_detector(detector_id, detectors_focal_plane_regions):
"""Find the focal plane region that contains a given detector.
Parameters
----------
detector_id : `int`
The detector ID.
detectors_focal_plane_regions :
`dict` [`str`, `lsst.pipe.tasks.extended_psf.DetectorsInRegion`]
A dictionary containing focal plane region names as keys, and the
corresponding detector IDs encoded within the values.
Returns
-------
key: `str`
The name of the region to which the given detector belongs.
Raises
------
KeyError
Raised if the given detector is not included in any focal plane region.
"""
for region_id, detectors_in_region in detectors_focal_plane_regions.items():
if detector_id in detectors_in_region.detectors:
return region_id
raise KeyError(
"Detector %d is not included in any focal plane region.",
detector_id,
)


class DetectorsInRegion(Config):
"""Provides a list of detectors that define a region."""

detectors = ListField[int](
doc="A list containing the detectors IDs.",
default=[],
)


@dataclass
Expand All @@ -54,13 +95,13 @@ class FocalPlaneRegionExtendedPsf:
----------
extended_psf_image : `lsst.afw.image.MaskedImageF`
Image of the extended PSF model.
detector_list : `list` [`int`]
region_detectors : `lsst.pipe.tasks.extended_psf.DetectorsInRegion`
List of detector IDs that define the focal plane region over which this
extended PSF model has been built (and can be used).
"""

extended_psf_image: MaskedImageF
detector_list: List[int]
region_detectors: DetectorsInRegion


class ExtendedPsf:
Expand All @@ -81,8 +122,8 @@ def __init__(self, default_extended_psf=None):
self.focal_plane_regions = {}
self.detectors_focal_plane_regions = {}

def add_regional_extended_psf(self, extended_psf_image, region_name, detector_list):
"""Add a new focal plane region, along wit hits extended PSF, to the
def add_regional_extended_psf(self, extended_psf_image, region_name, region_detectors):
"""Add a new focal plane region, along with its extended PSF, to the
ExtendedPsf instance.
Parameters
Expand All @@ -91,17 +132,17 @@ def add_regional_extended_psf(self, extended_psf_image, region_name, detector_li
Extended PSF model for the region.
region_name : `str`
Name of the focal plane region. Will be converted to all-uppercase.
detector_list : `list` [`int`]
List of IDs for the detectors that define the focal plane region.
region_detectors : `lsst.pipe.tasks.extended_psf.DetectorsInRegion`
List of detector IDs for the detectors that define a region on the
focal plane.
"""
region_name = region_name.upper()
if region_name in self.focal_plane_regions:
raise ValueError(f"Region name {region_name} is already used by this ExtendedPsf instance.")
self.focal_plane_regions[region_name] = FocalPlaneRegionExtendedPsf(
extended_psf_image=extended_psf_image, detector_list=detector_list
extended_psf_image=extended_psf_image, region_detectors=region_detectors
)
for det in detector_list:
self.detectors_focal_plane_regions[det] = region_name
self.detectors_focal_plane_regions[region_name] = region_detectors

def __call__(self, detector=None):
"""Return the appropriate extended PSF.
Expand Down Expand Up @@ -130,7 +171,7 @@ def __call__(self, detector=None):
return self.default_extended_psf
elif not self.focal_plane_regions:
return self.default_extended_psf
return self.get_regional_extended_psf(detector=detector)
return self.get_extended_psf(region_name=detector)

def __len__(self):
"""Returns the number of extended PSF models present in the instance.
Expand All @@ -146,31 +187,36 @@ def __len__(self):
n_regions += 1
return n_regions

def get_regional_extended_psf(self, region_name=None, detector=None):
"""Returns the extended PSF for a focal plane region.
def get_extended_psf(self, region_name):
"""Returns the extended PSF for a focal plane region or detector.
The region can be identified either by name, or through a detector ID.
This method takes either a region name or a detector ID as input. If
the input is a `str` type, it is assumed to be the region name and if
the input is a `int` type it is assumed to be the detector ID.
Parameters
----------
region_name : `str` or `None`, optional
Name of the region for which the extended PSF should be retrieved.
Ignored if ``detector`` is provided. Must be provided if
``detector`` is None.
detector : `int` or `None`, optional
If provided, returns the extended PSF for the focal plane region
that includes this detector.
region_name : `str` or `int`
Name of the region (str) or detector (int) for which the extended
PSF should be retrieved.
Returns
-------
extended_psf_image: `lsst.afw.image.MaskedImageF`
The extended PSF model for the requested region or detector.
Raises
------
ValueError
Raised if neither ``detector`` nor ``regionName`` is provided.
Raised if the input is not in the correct type.
"""
if detector is None:
if region_name is None:
raise ValueError("One of either a regionName or a detector number must be provided.")
if isinstance(region_name, str):
return self.focal_plane_regions[region_name].extended_psf_image
return self.focal_plane_regions[self.detectors_focal_plane_regions[detector]].extended_psf_image
elif isinstance(region_name, int):
region_name = find_region_for_detector(region_name, self.detectors_focal_plane_regions)
return self.focal_plane_regions[region_name].extended_psf_image
else:
raise ValueError("A region name with `str` type or detector number with `int` must be provided")

def write_fits(self, filename):
"""Write this object to a file.
Expand All @@ -187,7 +233,7 @@ def write_fits(self, filename):
metadata["HAS_REGIONS"] = True
metadata["REGION_NAMES"] = list(self.focal_plane_regions.keys())
for region, e_psf_region in self.focal_plane_regions.items():
metadata[region] = e_psf_region.detector_list
metadata[region] = e_psf_region.region_detectors.detectors
else:
metadata["HAS_REGIONS"] = False
fits_primary = Fits(filename, "w")
Expand Down Expand Up @@ -260,8 +306,9 @@ def read_fits(cls, filename):
# Generate extended PSF regions mappings.
for r_name in focal_plane_region_names:
extended_psf_image = MaskedImageF(**extended_psf_parts[r_name])
detector_list = global_metadata.getArray(r_name)
extended_psf.add_regional_extended_psf(extended_psf_image, r_name, detector_list)
region_detectors = DetectorsInRegion()
region_detectors.detectors = global_metadata.getArray(r_name)
extended_psf.add_regional_extended_psf(extended_psf_image, r_name, region_detectors)
# Instantiate ExtendedPsf.
return extended_psf

Expand Down Expand Up @@ -413,12 +460,13 @@ class MeasureExtendedPsfConfig(PipelineTaskConfig, pipelineConnections=MeasureEx
target=StackBrightStarsTask,
doc="Stack selected bright stars",
)
detectors_focal_plane_regions = DictField(
keytype=int,
itemtype=str,
detectors_focal_plane_regions = ConfigDictField(
keytype=str,
itemtype=DetectorsInRegion,
doc=(
"Mapping from detector IDs to focal plane region names. If empty, a constant extended PSF model "
"is built from all selected bright stars."
"Mapping from focal plane region names to detector IDs. "
"If empty, a constant extended PSF model is built from all selected bright stars. "
"It's possible for a single detector to be included in multiple regions if so desired."
),
default={},
)
Expand All @@ -442,15 +490,7 @@ class MeasureExtendedPsfTask(Task):
def __init__(self, initInputs=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.makeSubtask("stack_bright_stars")
self.focal_plane_regions = {
region: [] for region in set(self.config.detectors_focal_plane_regions.values())
}
for det, region in self.config.detectors_focal_plane_regions.items():
self.focal_plane_regions[region].append(det)
# make no assumption on what detector IDs should be, but if we come
# across one where there are processed bright stars, but no
# corresponding focal plane region, make sure we keep track of
# it (eg to raise a warning only once)
self.detectors_focal_plane_regions = self.config.detectors_focal_plane_regions
self.regionless_dets = []

def select_detector_refs(self, ref_list):
Expand All @@ -463,29 +503,28 @@ def select_detector_refs(self, ref_list):
`lsst.daf.butler._deferredDatasetHandle.DeferredDatasetHandle`
List of available bright star stamps data references.
"""
region_ref_list = {region: [] for region in self.focal_plane_regions.keys()}
region_ref_list = {region: [] for region in self.detectors_focal_plane_regions.keys()}
for dataset_handle in ref_list:
det_id = dataset_handle.ref.dataId["detector"]
if det_id in self.regionless_dets:
detector_id = dataset_handle.ref.dataId["detector"]
if detector_id in self.regionless_dets:
continue
try:
region_name = self.config.detectors_focal_plane_regions[det_id]
region_name = find_region_for_detector(detector_id, self.detectors_focal_plane_regions)
except KeyError:
self.log.warning(
"Bright stars were available for detector %d, but it was missing from the %s config "
"field, so they will not be used to build any of the extended PSF models.",
det_id,
detector_id,
"'detectors_focal_plane_regions'",
)
self.regionless_dets.append(det_id)
self.regionless_dets.append(detector_id)
continue
region_ref_list[region_name].append(dataset_handle)
return region_ref_list

def runQuantum(self, butlerQC, inputRefs, outputRefs):
input_data = butlerQC.get(inputRefs)
bss_ref_list = input_data["input_brightStarStamps"]
# Handle default case of a single region with empty detector list
if not self.config.detectors_focal_plane_regions:
self.log.info(
"No detector groups were provided to MeasureExtendedPsfTask; computing a single, "
Expand All @@ -505,7 +544,7 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
continue
ext_psf = self.stack_bright_stars.run(ref_list, region_name)
output_e_psf.add_regional_extended_psf(
ext_psf, region_name, self.focal_plane_regions[region_name]
ext_psf, region_name, self.detectors_focal_plane_regions[region_name]
)
output = Struct(extended_psf=output_e_psf)
butlerQC.put(output, outputRefs)
25 changes: 16 additions & 9 deletions tests/test_extended_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,20 @@ def setUp(self):
self.default_e_psf = make_extended_psf(1)[0]
self.constant_e_psf = extended_psf.ExtendedPsf(self.default_e_psf)
self.regions = ["NW", "SW", "E"]
self.region_detectors = [list(range(10)), list(range(10, 20)), list(range(20, 40))]
self.region_detectors = []
for i in range(3):
self.det = extended_psf.DetectorsInRegion()
r0 = 10*i
r1 = 10*(i+1)
self.det.detectors = list(range(r0, r1))
self.region_detectors.append(self.det)
self.regional_e_psfs = make_extended_psf(3)

def tearDown(self):
del self.default_e_psf
del self.regions
del self.region_detectors
del self.det
del self.regional_e_psfs

def test_constant_psf(self):
Expand Down Expand Up @@ -79,20 +86,20 @@ def test_regional_psf_addition(self):
self.assertEqual(len(with_default_e_psf), 3)
# Ensure we recover the correct regional PSF.
for j in range(2):
for det in self.region_detectors[j]:
for det in self.region_detectors[j].detectors:
# Try it by calling the class directly.
reg_psf0, reg_psf1 = starts_empty_e_psf(det), with_default_e_psf(det)
self.assertMaskedImagesAlmostEqual(reg_psf0, self.regional_e_psfs[j])
self.assertMaskedImagesAlmostEqual(reg_psf1, self.regional_e_psfs[j])
# Try it by passing on a detector number to the
# get_regional_extended_psf method.
reg_psf0 = starts_empty_e_psf.get_regional_extended_psf(detector=det)
reg_psf1 = with_default_e_psf.get_regional_extended_psf(detector=det)
# get_extended_psf method.
reg_psf0 = starts_empty_e_psf.get_extended_psf(region_name=det)
reg_psf1 = with_default_e_psf.get_extended_psf(region_name=det)
self.assertMaskedImagesAlmostEqual(reg_psf0, self.regional_e_psfs[j])
self.assertMaskedImagesAlmostEqual(reg_psf1, self.regional_e_psfs[j])
# Try it by passing on a region name.
reg_psf0 = starts_empty_e_psf.get_regional_extended_psf(region_name=self.regions[j])
reg_psf1 = with_default_e_psf.get_regional_extended_psf(region_name=self.regions[j])
reg_psf0 = starts_empty_e_psf.get_extended_psf(region_name=self.regions[j])
reg_psf1 = with_default_e_psf.get_extended_psf(region_name=self.regions[j])
self.assertMaskedImagesAlmostEqual(reg_psf0, self.regional_e_psfs[j])
self.assertMaskedImagesAlmostEqual(reg_psf1, self.regional_e_psfs[j])
# Ensure we recover the original default PSF.
Expand All @@ -118,7 +125,7 @@ def test_IO(self):
self.assertMaskedImagesAlmostEqual(per_region_e_psf0(), read_e_psf0())
# And per-region extended PSFs.
for j in range(3):
for det in self.region_detectors[j]:
for det in self.region_detectors[j].detectors:
reg_psf0, read_reg_psf0 = per_region_e_psf0(det), read_e_psf0(det)
self.assertMaskedImagesAlmostEqual(reg_psf0, read_reg_psf0)
# Test IO with a single per-region extended PSF.
Expand All @@ -130,7 +137,7 @@ def test_IO(self):
read_e_psf1 = extended_psf.ExtendedPsf.readFits(f.name)
self.assertEqual(per_region_e_psf0.detectors_focal_plane_regions,
read_e_psf0.detectors_focal_plane_regions)
for det in self.region_detectors[1]:
for det in self.region_detectors[1].detectors:
reg_psf1, read_reg_psf1 = per_region_e_psf1(det), read_e_psf1(det)
self.assertMaskedImagesAlmostEqual(reg_psf1, read_reg_psf1)

Expand Down

0 comments on commit fb1e298

Please sign in to comment.