Skip to content

Commit

Permalink
Have sed.resample_sed actaully resample rather than interpolate.
Browse files Browse the repository at this point in the history
  • Loading branch information
yoachim committed Oct 21, 2024
1 parent cd1ee95 commit 0974542
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 34 deletions.
1 change: 1 addition & 0 deletions rubin_sim/phot_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .physical_parameters import *
from .sed import *
from .signaltonoise import *
from .spectral_resampling import *
16 changes: 0 additions & 16 deletions rubin_sim/phot_utils/bandpass.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,6 @@ def __init__(self, wavelen=None, sb=None, sampling_warning=0.2):
self.set_bandpass(wavelen, sb)
return

def _check_wavelength_sampling(self):
"""Check that the wavelength sampling is above some threshold."""
if self.wavelen is not None:
dif = np.diff(self.wavelen)
if np.max(dif) > self.sampling_warning:
warnings.warn(
"Wavelength sampling of %.1f nm is > %.1f nm" % (np.max(dif), self.sampling_warning)
+ ", this may not work well"
" with a Sed object. Consider resampling with resample_bandpass method."
)

def set_bandpass(self, wavelen, sb):
"""
Populate bandpass data with wavelen/sb arrays.
Expand Down Expand Up @@ -114,7 +103,6 @@ def set_bandpass(self, wavelen, sb):
self.phi = None
self.sb = np.copy(sb)
self.bandpassname = "FromArrays"
self._check_wavelength_sampling()

def imsim_bandpass(self, imsimwavelen=500.0, wavelen_min=300, wavelen_max=1150, wavelen_step=0.1):
"""
Expand All @@ -134,7 +122,6 @@ def imsim_bandpass(self, imsimwavelen=500.0, wavelen_min=300, wavelen_max=1150,
self.sb = np.zeros(len(self.wavelen), dtype="float")
self.sb[abs(self.wavelen - imsimwavelen) < wavelen_step / 2.0] = 1.0
self.bandpassname = "IMSIM"
self._check_wavelength_sampling()

def read_throughput(self, filename):
"""
Expand Down Expand Up @@ -194,7 +181,6 @@ def read_throughput(self, filename):
p = self.wavelen.argsort()
self.wavelen = self.wavelen[p]
self.sb = self.sb[p]
self._check_wavelength_sampling()

def read_throughput_list(
self,
Expand Down Expand Up @@ -252,7 +238,6 @@ def read_throughput_list(
# Multiply self by new sb values.
self.sb = self.sb * tempbandpass.sb
self.bandpassname = "".join(component_list)
self._check_wavelength_sampling()

def get_bandpass(self):
wavelen = np.copy(self.wavelen)
Expand Down Expand Up @@ -324,7 +309,6 @@ def resample_bandpass(
self.wavelen = wavelen_grid
self.sb = sb_grid
return
self._check_wavelength_sampling()
return wavelen_grid, sb_grid

# more complicated bandpass functions
Expand Down
29 changes: 13 additions & 16 deletions rubin_sim/phot_utils/sed.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@
import warnings

import numpy
import scipy.interpolate as interpolate
from rubin_scheduler.data import get_data_dir

from .physical_parameters import PhysicalParameters
from .spectral_resampling import spectres

_global_lsst_sed_cache = None

Expand Down Expand Up @@ -476,6 +476,7 @@ def set_sed(self, wavelen, flambda=None, fnu=None, name="FromArray"):
raise ValueError("(No Flambda) - Fnu must be numpy array of same length as Wavelen.")
# Convert fnu to flambda.
self.wavelen, self.flambda = self.fnu_toflambda(wavelen, fnu)
self.fnu = fnu
self.name = name
return

Expand Down Expand Up @@ -747,6 +748,7 @@ def resample_sed(
wavelen_max=None,
wavelen_step=None,
force=False,
fill=numpy.nan,
):
"""
Resample flux onto grid defined by min/max/step OR
Expand Down Expand Up @@ -795,13 +797,8 @@ def resample_sed(
+ " (%.2f to %.2f)" % (wavelen_grid.min(), wavelen_grid.max())
+ "and sed %s (%.2f to %.2f)" % (self.name, wavelen.min(), wavelen.max())
)
# Do the interpolation of wavelen/flux onto grid.
# (type/len failures will die here).
if wavelen[0] > wavelen_grid[0] or wavelen[-1] < wavelen_grid[-1]:
f = interpolate.interp1d(wavelen, flux, bounds_error=False, fill_value=numpy.nan)
flux_grid = f(wavelen_grid)
else:
flux_grid = numpy.interp(wavelen_grid, wavelen, flux)
# rebin the spectra. Fill with NaNs if there's non-overlap regions.
flux_grid = spectres(wavelen_grid, wavelen, flux, fill=fill, verbose=False)

# Update self values if necessary.
if update_self:
Expand Down Expand Up @@ -1245,7 +1242,7 @@ def mag_from_flux(self, flux):

return -2.5 * numpy.log10(flux) - self.zp

def calc_ergs(self, bandpass):
def calc_ergs(self, bandpass, fill=numpy.nan):
r"""
Integrate the SED over a bandpass directly. If self.flambda
is in ergs/s/cm^2/nm and bandpass.sb is the unitless probability
Expand All @@ -1272,7 +1269,7 @@ def calc_ergs(self, bandpass):
The flux of the current SED through the bandpass in ergs/s/cm^2
"""
wavelen, flambda = self.resample_sed(
wavelen=self.wavelen, flux=self.flambda, wavelen_match=bandpass.wavelen
wavelen=self.wavelen, flux=self.flambda, wavelen_match=bandpass.wavelen, fill=fill
)

dlambda = wavelen[1] - wavelen[0]
Expand All @@ -1281,7 +1278,7 @@ def calc_ergs(self, bandpass):
energy = (0.5 * (flambda[1:] * bandpass.sb[1:] + flambda[:-1] * bandpass.sb[:-1]) * dlambda).sum()
return energy

def calc_flux(self, bandpass, wavelen=None, fnu=None):
def calc_flux(self, bandpass, wavelen=None, fnu=None, fill=numpy.nan):
"""
Integrate the specific flux density of the object over the normalized
response curve of a bandpass, giving a flux in Janskys
Expand Down Expand Up @@ -1316,15 +1313,15 @@ def calc_flux(self, bandpass, wavelen=None, fnu=None):
wavelen = self.wavelen
fnu = self.fnu
# Go on with magnitude calculation.
wavelen, fnu = self.resample_sed(wavelen, fnu, wavelen_match=bandpass.wavelen)
wavelen, fnu = self.resample_sed(wavelen, fnu, wavelen_match=bandpass.wavelen, fill=fill)
# Calculate bandpass phi value if required.
if bandpass.phi is None:
bandpass.sb_tophi()
# Calculate flux in bandpass and return this value.
flux = numpy.trapz(fnu * bandpass.phi, x=wavelen)
return flux

def calc_mag(self, bandpass, wavelen=None, fnu=None):
def calc_mag(self, bandpass, wavelen=None, fnu=None, fill=numpy.nan):
"""
Calculate the AB magnitude of an object using the normalized system
response (phi from Section 4.1 of the LSST design document LSE-180).
Expand All @@ -1334,13 +1331,13 @@ def calc_mag(self, bandpass, wavelen=None, fnu=None):
wavelen/fnu pair to be on the same grid as bandpass;
(but only temporary values of these are used).
"""
flux = self.calc_flux(bandpass, wavelen=wavelen, fnu=fnu)
flux = self.calc_flux(bandpass, wavelen=wavelen, fnu=fnu, fill=fill)
if flux < 1e-300:
raise ValueError("This SED has no flux within this bandpass.")
mag = self.mag_from_flux(flux)
return mag

def calc_flux_norm(self, magmatch, bandpass, wavelen=None, fnu=None):
def calc_flux_norm(self, magmatch, bandpass, wavelen=None, fnu=None, fill=numpy.nan):
"""
Calculate the fluxNorm (SED normalization value for a given mag)
for a sed.
Expand All @@ -1359,7 +1356,7 @@ def calc_flux_norm(self, magmatch, bandpass, wavelen=None, fnu=None):
# (fluxnorm * SED(f_nu) * PHI = mag - 8.9 (AB zeropoint).
# FluxNorm * SED => correct magnitudes for this object.
# Calculate fluxnorm.
curmag = self.calc_mag(bandpass, wavelen, fnu)
curmag = self.calc_mag(bandpass, wavelen, fnu, fill=fill)
if curmag == self.badval:
return self.badval
dmag = magmatch - curmag
Expand Down
158 changes: 158 additions & 0 deletions rubin_sim/phot_utils/spectral_resampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
__all__ = ("spectres",)

import warnings

import numpy as np

# Taken from https://github.com/ACCarnall/SpectRes
# (which is under GPL 3), cite https://arxiv.org/abs/1705.05165


def make_bins(wavs):
"""Given a series of wavelength points, find the edges and widths
of corresponding wavelength bins."""
edges = np.zeros(wavs.shape[0] + 1)
widths = np.zeros(wavs.shape[0])
edges[0] = wavs[0] - (wavs[1] - wavs[0]) / 2
widths[-1] = wavs[-1] - wavs[-2]
edges[-1] = wavs[-1] + (wavs[-1] - wavs[-2]) / 2
edges[1:-1] = (wavs[1:] + wavs[:-1]) / 2
widths[:-1] = edges[1:-1] - edges[:-2]

return edges, widths


def spectres(new_wavs, spec_wavs, spec_fluxes, spec_errs=None, fill=0, verbose=True):
"""
Function for resampling spectra (and optionally associated
uncertainties) onto a new wavelength basis.
Parameters
----------
new_wavs : numpy.ndarray
Array containing the new wavelength sampling desired for the
spectrum or spectra.
spec_wavs : numpy.ndarray
1D array containing the current wavelength sampling of the
spectrum or spectra.
spec_fluxes : numpy.ndarray
Array containing spectral fluxes at the wavelengths specified in
spec_wavs, last dimension must correspond to the shape of
spec_wavs. Extra dimensions before this may be used to include
multiple spectra.
spec_errs : numpy.ndarray (optional)
Array of the same shape as spec_fluxes containing uncertainties
associated with each spectral flux value.
fill : float (optional)
Where new_wavs extends outside the wavelength range in spec_wavs
this value will be used as a filler in new_fluxes and new_errs.
verbose : bool (optional)
Setting verbose to False will suppress the default warning about
new_wavs extending outside spec_wavs and "fill" being used.
Returns
-------
new_fluxes : numpy.ndarray
Array of resampled flux values, last dimension is the same
length as new_wavs, other dimensions are the same as
spec_fluxes.
new_errs : numpy.ndarray
Array of uncertainties associated with fluxes in new_fluxes.
Only returned if spec_errs was specified.
"""

# Rename the input variables for clarity within the function.
old_wavs = spec_wavs
old_fluxes = spec_fluxes
old_errs = spec_errs

# Make arrays of edge positions and widths for the old and new bins

old_edges, old_widths = make_bins(old_wavs)
new_edges, new_widths = make_bins(new_wavs)

# Generate output arrays to be populated
new_fluxes = np.zeros(old_fluxes[..., 0].shape + new_wavs.shape)

if old_errs is not None:
if old_errs.shape != old_fluxes.shape:
raise ValueError("If specified, spec_errs must be the same shape " "as spec_fluxes.")
else:
new_errs = np.copy(new_fluxes)

start = 0
stop = 0

# Calculate new flux and uncertainty values, looping over new bins
for j in range(new_wavs.shape[0]):

# Add filler values if new_wavs extends outside of spec_wavs
if (new_edges[j] < old_edges[0]) or (new_edges[j + 1] > old_edges[-1]):
new_fluxes[..., j] = fill

if spec_errs is not None:
new_errs[..., j] = fill

if (j == 0 or j == new_wavs.shape[0] - 1) and verbose:
warnings.warn(
"Spectres: new_wavs contains values outside the range "
"in spec_wavs, new_fluxes and new_errs will be filled "
"with the value set in the 'fill' keyword argument "
"(by default 0).",
category=RuntimeWarning,
)
continue

# Find first old bin which is partially covered by the new bin
while old_edges[start + 1] <= new_edges[j]:
start += 1

# Find last old bin which is partially covered by the new bin
while old_edges[stop + 1] < new_edges[j + 1]:
stop += 1

# If new bin is fully inside an old bin start and stop are equal
if stop == start:
new_fluxes[..., j] = old_fluxes[..., start]
if old_errs is not None:
new_errs[..., j] = old_errs[..., start]

# Otherwise multiply the first and last old bin widths by P_ij
else:
start_factor = (old_edges[start + 1] - new_edges[j]) / (old_edges[start + 1] - old_edges[start])

end_factor = (new_edges[j + 1] - old_edges[stop]) / (old_edges[stop + 1] - old_edges[stop])

old_widths[start] *= start_factor
old_widths[stop] *= end_factor

# Populate new_fluxes spectrum and uncertainty arrays
f_widths = old_widths[start : stop + 1] * old_fluxes[..., start : stop + 1]
new_fluxes[..., j] = np.sum(f_widths, axis=-1)
new_fluxes[..., j] /= np.sum(old_widths[start : stop + 1])

if old_errs is not None:
e_wid = old_widths[start : stop + 1] * old_errs[..., start : stop + 1]

new_errs[..., j] = np.sqrt(np.sum(e_wid**2, axis=-1))
new_errs[..., j] /= np.sum(old_widths[start : stop + 1])

# Put back the old bin widths to their initial values
old_widths[start] /= start_factor
old_widths[stop] /= end_factor

# If errors were supplied return both new_fluxes and new_errs.
if old_errs is not None:
return new_fluxes, new_errs

# Otherwise just return the new_fluxes spectrum array
else:
return new_fluxes
35 changes: 33 additions & 2 deletions tests/phot_utils/test_sed.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,36 @@ def test_sed_bandpass_match(self):
np.testing.assert_equal(testsed.flambda[0], np.NaN)
np.testing.assert_equal(testsed.flambda[49], np.NaN)

def test_rebin(self):
"""Test that rebinning an SED does not change integrated flux
much.
"""
sed = Sed()
sed.set_flat_sed(wavelen_step=0.01)

# Make a line feature.
sigma = 0.05
fnu = sed.fnu - sed.fnu.max() * np.exp(-((sed.wavelen - 365.2) ** 2) / sigma**2)

sed.set_sed(sed.wavelen, fnu=fnu)
wave_fine = np.arange(350, 380 + 0.01, 0.01)
bp_fine = Bandpass(wavelen=wave_fine, sb=np.ones(wave_fine.size))

wave_rough = np.arange(350, 380 + 0.5, 0.5)
bp_rough = Bandpass(wavelen=wave_rough, sb=np.ones(wave_rough.size))

# Flux computed with a fine sampled bandpass
# should match lower resolution bandpass
flux_fine = sed.calc_flux(bp_fine)
flux_rough = sed.calc_flux(bp_rough)

assert np.isclose(flux_fine, flux_rough, rtol=1e-5)

# Check magnitudes as well.
mag_fine = sed.calc_mag(bp_fine)
mag_rough = sed.calc_mag(bp_rough)
assert np.isclose(mag_fine, mag_rough, rtol=1e-3)

def test_sed_mag_errors(self):
"""Test error handling at mag and adu calculation levels of sed."""
sedwavelen = np.arange(self.wmin + 50, self.wmax, 1)
Expand Down Expand Up @@ -279,7 +309,7 @@ def test_calc_ergs(self):
# Now test it on a bandpass with throughput=0.25 and an wavelength
# array that is not the same as the SED

wavelen_arr = np.arange(10.0, 100000.0, 146.0) # in nm
wavelen_arr = np.arange(5.0, 100000.0, 146.0) # in nm
bp = Bandpass(wavelen=wavelen_arr, sb=0.25 * np.ones(len(wavelen_arr)))

wavelen_arr = np.arange(5.0, 200000.0, 17.0)
Expand All @@ -306,12 +336,13 @@ def test_calc_ergs(self):
bb_flambda = np.pi * np.power(10.0, log10_bb_factor + log10_bose_factor - 7.0)

sed = Sed(wavelen=wavelen_arr, flambda=bb_flambda)
ergs = sed.calc_ergs(bp)
ergs = sed.calc_ergs(bp, fill=0)

log10_ergs = np.log10(stefan_boltzmann_sigma) + 4.0 * np.log10(temp)
ergs_truth = np.power(10.0, log10_ergs)

msg = "\ntemp: %e\nergs: %e\nergs_truth: %e" % (temp, ergs, ergs_truth)

self.assertAlmostEqual(ergs / ergs_truth, 0.25, 3, msg=msg)

def test_mags_vs_flux(self):
Expand Down

0 comments on commit 0974542

Please sign in to comment.