Skip to content

Commit

Permalink
Merge branch 'tickets/DM-43831'
Browse files Browse the repository at this point in the history
  • Loading branch information
taranu committed Jun 28, 2024
2 parents 8b7d742 + 7fdefcd commit de2a100
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 91 deletions.
208 changes: 129 additions & 79 deletions python/lsst/pipe/tasks/diff_matched_tract_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,22 @@
import astropy.units as u
from dataclasses import dataclass
from decimal import Decimal
from deprecated.sphinx import deprecated
from enum import Enum
import numpy as np
import pandas as pd
from scipy.stats import iqr
from smatch.matcher import sphdist
from types import SimpleNamespace
from typing import Sequence


def is_sequence_set(x: Sequence) -> bool:
    """Return whether a sequence contains no duplicate elements.

    Parameters
    ----------
    x
        The sequence to check. Its elements must be hashable, because they
        are collected into a `set` for the comparison.

    Returns
    -------
    is_set
        True if every element of ``x`` is distinct (an empty sequence
        counts as distinct); False if any value is repeated.
    """
    # Duplicates collapse in a set, so equal lengths imply all-unique items.
    return len(x) == len(set(x))


@deprecated(reason="This method is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
def is_percentile(x: str):
    """Return whether a string represents a valid percentile.

    Parameters
    ----------
    x
        The value to check, as a string parseable by `decimal.Decimal`.

    Returns
    -------
    valid
        True if the parsed value lies in the closed interval [0, 100].

    Notes
    -----
    The string is parsed with `decimal.Decimal`; a string that is not a
    valid decimal literal makes that constructor raise rather than
    returning False.
    """
    return 0 <= Decimal(x) <= 100

Expand Down Expand Up @@ -174,25 +178,34 @@ class DiffMatchedTractCatalogConfig(
)
column_ref_extended = pexConfig.Field[str](
default='is_pointsource',
deprecated='This field is no longer being used and will be removed after v28.',
doc='The boolean reference table column specifying if the target is extended',
)
column_ref_extended_inverted = pexConfig.Field[bool](
default=True,
deprecated='This field is no longer being used and will be removed after v28.',
doc='Whether column_ref_extended specifies if the object is compact, not extended',
)
column_target_extended = pexConfig.Field[str](
default='refExtendedness',
deprecated='This field is no longer being used and will be removed after v28.',
doc='The target table column estimating the extendedness of the object (0 <= x <= 1)',
)
compute_stats = pexConfig.Field[bool](
default=False,
deprecated='This field is no longer being used and will be removed after v28.',
doc='Whether to compute matched difference statistics',
)
include_unmatched = pexConfig.Field[bool](
default=False,
doc="Whether to include unmatched rows in the matched table",
doc='Whether to include unmatched rows in the matched table',
)

@property
def columns_in_ref(self) -> list[str]:
columns_all = [self.coord_format.column_ref_coord1, self.coord_format.column_ref_coord2,
self.column_ref_extended]
columns_all = [self.coord_format.column_ref_coord1, self.coord_format.column_ref_coord2]
if self.compute_stats:
columns_all.append(self.column_ref_extended)
for column_lists in (
(
self.columns_ref_copy,
Expand All @@ -206,8 +219,9 @@ def columns_in_ref(self) -> list[str]:

@property
def columns_in_target(self) -> list[str]:
columns_all = [self.coord_format.column_target_coord1, self.coord_format.column_target_coord2,
self.column_target_extended]
columns_all = [self.coord_format.column_target_coord1, self.coord_format.column_target_coord2]
if self.compute_stats:
columns_all.append(self.column_target_extended)
if self.coord_format.coords_ref_to_convert is not None:
columns_all.extend(col for col in self.coord_format.coords_ref_to_convert.values()
if col not in columns_all)
Expand Down Expand Up @@ -266,40 +280,47 @@ def columns_in_target(self) -> list[str]:
doc="Configuration for coordinate conversion",
)
extendedness_cut = pexConfig.Field[float](
dtype=float,
deprecated="This field is no longer being used and will be removed after v28.",
default=0.5,
doc='Minimum extendedness for a measured source to be considered extended',
)
mag_num_bins = pexConfig.Field[int](
deprecated="This field is no longer being used and will be removed after v28.",
doc='Number of magnitude bins',
default=15,
)
mag_brightest_ref = pexConfig.Field[float](
deprecated="This field is no longer being used and will be removed after v28.",
doc='Brightest magnitude cutoff for binning',
default=15,
)
mag_ceiling_target = pexConfig.Field[float](
deprecated="This field is no longer being used and will be removed after v28.",
doc='Ceiling (maximum/faint) magnitude for target sources',
default=None,
optional=True,
)
mag_faintest_ref = pexConfig.Field[float](
deprecated="This field is no longer being used and will be removed after v28.",
doc='Faintest magnitude cutoff for binning',
default=30,
)
mag_zeropoint_ref = pexConfig.Field[float](
deprecated="This field is no longer being used and will be removed after v28.",
doc='Magnitude zeropoint for reference sources',
default=31.4,
)
mag_zeropoint_target = pexConfig.Field[float](
deprecated="This field is no longer being used and will be removed after v28.",
doc='Magnitude zeropoint for target sources',
default=31.4,
)
percentiles = pexConfig.ListField[str](
deprecated="This field is no longer being used and will be removed after v28.",
doc='Percentiles to compute for diff/chi values',
# -2, -1, +1, +2 sigma percentiles for normal distribution
default=('2.275', '15.866', '84.134', '97.725'),
itemCheck=is_percentile,
itemCheck=lambda x: 0 <= Decimal(x) <= 100,
listCheck=is_sequence_set,
)
refcat_sharding_type = pexConfig.ChoiceField[str](
Expand Down Expand Up @@ -333,23 +354,29 @@ def validate(self):
raise ValueError("\n".join(errors))


@deprecated(reason="This class is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
@dataclass(frozen=True)
class MeasurementTypeInfo:
    """Immutable description of a type of measurement.

    Instances are frozen, so attributes cannot be reassigned after
    construction.
    """
    # Human-readable description, e.g. "difference (measured - reference)".
    doc: str
    # Short identifier for the measurement type, e.g. "diff".
    name: str


@deprecated(reason="This class is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
class MeasurementType(Enum):
    """The kinds of matched-difference measurement.

    Each member's value is a namespace with two string attributes:
    ``doc`` (a human-readable description) and ``name`` (a short
    identifier). Note that ``member.name`` is the Enum member name
    (e.g. ``'DIFF'``), whereas ``member.value.name`` is the short
    identifier (e.g. ``'diff'``).
    """
    # Each member is defined exactly once; an Enum body rejects a reused key.
    DIFF = SimpleNamespace(
        doc="difference (measured - reference)",
        name="diff",
    )
    CHI = SimpleNamespace(
        doc="scaled difference (measured - reference)/error",
        name="chi",
    )


@deprecated(reason="This class is no longer being used and will be removed after v28.",
version="v28.0", category=FutureWarning)
class Statistic(metaclass=ABCMeta):
"""A statistic that can be applied to a set of values.
"""
Expand Down Expand Up @@ -380,6 +407,8 @@ def value(self, values):
raise NotImplementedError('Subclasses must implement this method')


@deprecated(reason="This class is no longer being used and will be removed after v28.",
version="v28.0", category=FutureWarning)
class Median(Statistic):
"""The median of a set of values."""
@classmethod
Expand All @@ -394,6 +423,8 @@ def value(self, values):
return np.median(values)


@deprecated(reason="This class is no longer being used and will be removed after v28.",
version="v28.0", category=FutureWarning)
class SigmaIQR(Statistic):
"""The re-scaled interquartile range (sigma equivalent)."""
@classmethod
Expand All @@ -408,6 +439,8 @@ def value(self, values):
return iqr(values, scale='normal')


@deprecated(reason="This class is no longer being used and will be removed after v28.",
version="v28.0", category=FutureWarning)
class SigmaMAD(Statistic):
"""The re-scaled median absolute deviation (sigma equivalent)."""
@classmethod
Expand All @@ -422,6 +455,8 @@ def value(self, values):
return mad_std(values)


@deprecated(reason="This class is no longer being used and will be removed after v28.",
version="v28.0", category=FutureWarning)
@dataclass(frozen=True)
class Percentile(Statistic):
"""An arbitrary percentile.
Expand All @@ -447,14 +482,20 @@ def __post_init__(self):
raise ValueError(f'percentile={self.percentile} not >=0 and <= 100')


@deprecated(reason="This method is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
def _get_stat_name(*args):
    """Build a statistic name by joining name components with underscores.

    Parameters
    ----------
    *args
        The string components of the name, in order.

    Returns
    -------
    name
        The components joined by ``'_'``; an empty string if no components
        are given.
    """
    return '_'.join(args)


@deprecated(reason="This method is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
def _get_column_name(band, *args):
    """Build a column name consisting of a band prefix and a statistic name.

    Parameters
    ----------
    band
        The band identifier used as the column name prefix.
    *args
        The string components of the statistic name, passed through to
        `_get_stat_name`.

    Returns
    -------
    name
        The band, an underscore, and the underscore-joined components.
    """
    return f"{band}_{_get_stat_name(*args)}"


@deprecated(reason="This method is no longer being used and will be removed after v28.",
version="v28.0", category=FutureWarning)
def compute_stats(values_ref, values_target, errors_target, row, stats, suffixes, prefix, skip_diff=False):
"""Compute statistics on differences and store results in a row.
Expand Down Expand Up @@ -508,24 +549,32 @@ def compute_stats(values_ref, values_target, errors_target, row, stats, suffixes
return row


@deprecated(reason="This class is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
@dataclass(frozen=True)
class SourceTypeInfo:
    """Immutable description of a source type selection.

    Instances are frozen, so attributes cannot be reassigned after
    construction.
    """
    # Whether the source type is extended. None is a permitted value;
    # presumably it means "no selection on extendedness" - confirm with users.
    is_extended: bool | None
    # Short text label for the source type, e.g. 'resolved'.
    label: str


@deprecated(reason="This class is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
class SourceType(Enum):
    """The source type selections, by extendedness.

    Each member's value is a namespace with an ``is_extended`` attribute
    (True, False, or None) and a string ``label``.
    """
    # Each member is defined exactly once; an Enum body rejects a reused key.
    ALL = SimpleNamespace(is_extended=None, label='all')
    RESOLVED = SimpleNamespace(is_extended=True, label='resolved')
    UNRESOLVED = SimpleNamespace(is_extended=False, label='unresolved')


@deprecated(reason="This class is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
class MatchType(Enum):
    """The match categories, each valued by a short string label.

    NOTE(review): the member names suggest ``MATCH_RIGHT``/``MATCH_WRONG``
    distinguish matches whose classification agrees or disagrees with the
    reference; confirm against the code that consumes this enum.
    """
    ALL = 'all'
    MATCH_RIGHT = 'match_right'
    MATCH_WRONG = 'match_wrong'


@deprecated(reason="This method is no longer being used and will be removed after v28.",
version="v28.0", category=FutureWarning)
def _get_columns(bands_columns: dict, suffixes: dict, suffixes_flux: dict, suffixes_mag: dict,
stats: dict, target: ComparableCatalog, column_dist: str):
"""Get column names for a table of difference statistics.
Expand Down Expand Up @@ -768,75 +817,76 @@ def run(
for column_flux in columns_convert.values():
cat_convert[column_flux] = u.ABmag.to(u.nJy, cat_convert[column_flux])

# TODO: Deprecate all matched difference output in DM-43831 (per RFC-1008)

# Slightly smelly hack for when a column (like distance) is already relative to truth
column_dummy = 'dummy'
cat_ref[column_dummy] = np.zeros_like(ref.coord1)

# Add a boolean column for whether a match is classified correctly
# TODO: remove the assumption of a boolean column
extended_ref = cat_ref[config.column_ref_extended] == (not config.column_ref_extended_inverted)

extended_target = cat_target[config.column_target_extended].values >= config.extendedness_cut

# Define difference/chi columns and statistics thereof
suffixes = {MeasurementType.DIFF: 'diff', MeasurementType.CHI: 'chi'}
# Skip diff for fluxes - covered by mags
suffixes_flux = {MeasurementType.CHI: suffixes[MeasurementType.CHI]}
# Skip chi for magnitudes, which have strange errors
suffixes_mag = {MeasurementType.DIFF: suffixes[MeasurementType.DIFF]}
stats = {stat.name_short(): stat() for stat in (Median, SigmaIQR, SigmaMAD)}

for percentile in self.config.percentiles:
stat = Percentile(percentile=float(Decimal(percentile)))
stats[stat.name_short()] = stat

# Get dict of column names
columns, n_models = _get_columns(
bands_columns=config.columns_flux,
suffixes=suffixes,
suffixes_flux=suffixes_flux,
suffixes_mag=suffixes_mag,
stats=stats,
target=target,
column_dist=column_dist,
)

# Setup numpy table
n_bins = config.mag_num_bins
data = np.zeros((n_bins,), dtype=[(key, value) for key, value in columns.items()])
data['bin'] = np.arange(n_bins)

# Setup bins
bins_mag = np.linspace(start=config.mag_brightest_ref, stop=config.mag_faintest_ref,
num=n_bins + 1)
data['mag_min'] = bins_mag[:-1]
data['mag_max'] = bins_mag[1:]
bins_mag = tuple((bins_mag[idx], bins_mag[idx + 1]) for idx in range(n_bins))

# Define temporary columns for intermediate storage
column_mag_temp = 'mag_temp'
column_color_temp = 'color_temp'
column_color_err_temp = 'colorErr_temp'
flux_err_frac_prev = [None]*n_models
mag_prev = [None]*n_models

columns_target = {
target.column_coord1: (
ref.column_coord1, target.column_coord1, coord1_target_err, False,
),
target.column_coord2: (
ref.column_coord2, target.column_coord2, coord2_target_err, False,
),
column_dist: (column_dummy, column_dist, column_dist_err, False),
}

# Cheat a little and do the first band last so that the color is
# based on the last band
data = None
band_fluxes = [(band, config_flux) for (band, config_flux) in config.columns_flux.items()]
n_bands = len(band_fluxes)
if n_bands > 0:

# TODO: Deprecated by RFC-1017 and to be removed in DM-44988
if self.config.compute_stats and (n_bands > 0):
# Slightly smelly hack for when a column (like distance) is already relative to truth
column_dummy = 'dummy'
cat_ref[column_dummy] = np.zeros_like(ref.coord1)

# Add a boolean column for whether a match is classified correctly
# TODO: remove the assumption of a boolean column
extended_ref = cat_ref[config.column_ref_extended] == (not config.column_ref_extended_inverted)

extended_target = cat_target[config.column_target_extended].values >= config.extendedness_cut

# Define difference/chi columns and statistics thereof
suffixes = {MeasurementType.DIFF: 'diff', MeasurementType.CHI: 'chi'}
# Skip diff for fluxes - covered by mags
suffixes_flux = {MeasurementType.CHI: suffixes[MeasurementType.CHI]}
# Skip chi for magnitudes, which have strange errors
suffixes_mag = {MeasurementType.DIFF: suffixes[MeasurementType.DIFF]}
stats = {stat.name_short(): stat() for stat in (Median, SigmaIQR, SigmaMAD)}

for percentile in self.config.percentiles:
stat = Percentile(percentile=float(Decimal(percentile)))
stats[stat.name_short()] = stat

# Get dict of column names
columns, n_models = _get_columns(
bands_columns=config.columns_flux,
suffixes=suffixes,
suffixes_flux=suffixes_flux,
suffixes_mag=suffixes_mag,
stats=stats,
target=target,
column_dist=column_dist,
)

# Setup numpy table
n_bins = config.mag_num_bins
data = np.zeros((n_bins,), dtype=[(key, value) for key, value in columns.items()])
data['bin'] = np.arange(n_bins)

# Setup bins
bins_mag = np.linspace(start=config.mag_brightest_ref, stop=config.mag_faintest_ref,
num=n_bins + 1)
data['mag_min'] = bins_mag[:-1]
data['mag_max'] = bins_mag[1:]
bins_mag = tuple((bins_mag[idx], bins_mag[idx + 1]) for idx in range(n_bins))

# Define temporary columns for intermediate storage
column_mag_temp = 'mag_temp'
column_color_temp = 'color_temp'
column_color_err_temp = 'colorErr_temp'
flux_err_frac_prev = [None]*n_models
mag_prev = [None]*n_models

columns_target = {
target.column_coord1: (
ref.column_coord1, target.column_coord1, coord1_target_err, False,
),
target.column_coord2: (
ref.column_coord2, target.column_coord2, coord2_target_err, False,
),
column_dist: (column_dummy, column_dist, column_dist_err, False),
}

# Cheat a little and do the first band last so that the color is
# based on the last band
band_fluxes.append(band_fluxes[0])
flux_err_frac_first = None
mag_first = None
Expand Down
Loading

0 comments on commit de2a100

Please sign in to comment.