Skip to content

Commit

Permalink
General refactoring, using generators, update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
teutoburg committed Aug 29, 2023
1 parent b6563ea commit 2bfeab6
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 129 deletions.
205 changes: 114 additions & 91 deletions scopesim/optics/fov_manager_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import logging
from copy import deepcopy
from itertools import product
from more_itertools import pairwise

import numpy as np
from astropy import units as u
Expand All @@ -9,6 +12,8 @@
from ..effects.effects_utils import get_all_effects
from ..utils import check_keys

# TODO: Where are all these functions used??


def get_3d_shifts(effects, **kwargs):
"""
Expand Down Expand Up @@ -41,9 +46,11 @@ def get_3d_shifts(effects, **kwargs):
shifts = [eff.fov_grid(which="shifts", **kwargs) for eff in effects]

old_bin_edges = [shift[0] for shift in shifts if len(shift[0]) >= 2]
# TODO: could this use combine_wavesets?
new_bin_edges = np.unique(np.sort(np.concatenate(old_bin_edges),
kind="stable"))

# TODO: could this be zeros_like?
x_shifts = np.zeros(len(new_bin_edges))
y_shifts = np.zeros(len(new_bin_edges))
# .. todo:: replace the 1e-7 with a variable in !SIM
Expand Down Expand Up @@ -99,33 +106,35 @@ def get_imaging_waveset(effects_list, **kwargs):

wave_bin_edges = [filt.fov_grid(which="waveset", **kwargs)
for filt in filters]
if len(wave_bin_edges) > 0:
kwargs["wave_min"] = np.max([w[0].value for w in wave_bin_edges])
kwargs["wave_max"] = np.min([w[1].value for w in wave_bin_edges])
if wave_bin_edges:
kwargs["wave_min"] = max(wave[0].value for wave in wave_bin_edges)
kwargs["wave_max"] = min(wave[1].value for wave in wave_bin_edges)
# Bit confusing...
wave_bin_edges = [[kwargs["wave_min"], kwargs["wave_max"]]]

if kwargs["wave_min"] > kwargs["wave_max"]:
raise ValueError(f"Filter wavelength ranges do not overlap: {wave_bin_edges}")
raise ValueError("Filter wavelength ranges do not overlap: "
f"{wave_bin_edges}.")

# ..todo: add in Atmospheric dispersion and ADC here
for effect_class in [efs.PSF]:
effects = get_all_effects(effects_list, effect_class)
for eff in effects:
for eff in get_all_effects(effects_list, effect_class):
waveset = eff.fov_grid(which="waveset", **kwargs)
if waveset is not None:
wave_bin_edges += [waveset]
wave_bin_edges.append(waveset)

wave_bin_edges = combine_wavesets(wave_bin_edges)
wave_bin_edges = combine_wavesets(*wave_bin_edges)

if len(wave_bin_edges) == 0:
if not wave_bin_edges:
# This is already set at the top, why again here?
wave_bin_edges = [kwargs["wave_min"], kwargs["wave_max"]]

return wave_bin_edges


def get_imaging_headers(effects, **kwargs):
"""
Returns a list of Header objects for each of the FieldOfVIew objects
Return a generator of Header objects for each of the FieldOfVIew objects.
Parameters
----------
Expand All @@ -135,7 +144,7 @@ def get_imaging_headers(effects, **kwargs):
Returns
-------
hdrs : list of Header objects
hdrs : generator of Header objects
Notes
-----
Expand Down Expand Up @@ -163,33 +172,39 @@ def get_imaging_headers(effects, **kwargs):
aperture_effects = get_all_effects(effects, (efs.ApertureMask,
efs.SlitWheel,
efs.ApertureList))
if len(aperture_effects) == 0:
if not aperture_effects:
detector_arrays = get_all_effects(effects, efs.DetectorList)
if len(detector_arrays) > 0:
aperture_effects += [detarr.fov_grid(which="edges",
pixel_scale=pixel_scale)
for detarr in detector_arrays]
else:
raise ValueError("No ApertureMask or DetectorList was provided. At "
"least one must be passed to make an ImagePlane: "
f"{effects}")

if not detector_arrays:
raise ValueError("No ApertureMask or DetectorList was provided. "
"At least one must be passed to make an "
f"ImagePlane: {effects}.")
aperture_effects.extend(
detarr.fov_grid(which="edges", pixel_scale=pixel_scale)
for detarr in detector_arrays)

# FIXME: all of this is a bit inconsistent; fov_grid(which="edges" is
# called afterwards, but when looking in detector_arrays, the same
# is called immediately; does that even work? is this all tested??
# get aperture headers from fov_grid()
# - for-loop catches mutliple headers from ApertureList.fov_grid()
hdrs = []
for ap_eff in aperture_effects:
# ..todo:: add this functionality to ApertureList effect
hdr = ap_eff.fov_grid(which="edges", pixel_scale=pixel_scale)
hdrs += hdr if isinstance(hdr, (list, tuple)) else [hdr]
def _get_hdrs(ap_effs):
for ap_eff in ap_effs:
# ..todo:: add this functionality to ApertureList effect
hdr = ap_eff.fov_grid(which="edges", pixel_scale=pixel_scale)
if isinstance(hdr, (list, tuple)):
yield from hdr
else:
yield hdr
headers = _get_hdrs(aperture_effects)

# check size of aperture in pixels - split if necessary
sky_hdrs = []
for hdr in hdrs:
if hdr["NAXIS1"] * hdr["NAXIS2"] > kwargs["max_segment_size"]:
sky_hdrs += imp_utils.split_header(hdr, kwargs["chunk_size"])
else:
sky_hdrs += [hdr]

def _get_sky_hdrs(hdrs):
for hdr in hdrs:
if hdr["NAXIS1"] * hdr["NAXIS2"] > kwargs["max_segment_size"]:
yield from imp_utils.split_header(hdr, kwargs["chunk_size"])
else:
yield hdr
sky_hdrs = _get_sky_hdrs(headers)
# ..todo:: Deal with the case that two or more ApertureMasks overlap

# map the on-sky apertures directly to the image plane using plate_scale
Expand All @@ -208,13 +223,12 @@ def get_imaging_headers(effects, **kwargs):

dethdr = imp_utils.header_from_list_of_xy(x_det, y_det, pixel_size, "D")
skyhdr.update(dethdr)

return sky_hdrs
yield skyhdr


def get_imaging_fovs(headers, waveset, shifts, **kwargs):
"""
Returns a list of ``FieldOfView`` objects
Return a generator of ``FieldOfView`` objects.
Parameters
----------
Expand All @@ -224,54 +238,50 @@ def get_imaging_fovs(headers, waveset, shifts, **kwargs):
waveset : list of floats
[um] N+1 wavelengths for N spectral layers
shifts : list of tuples
shifts : list of tuples (or actually arrays?)
[deg] x,y shifts w.r.t to the optical axis plane. N shifts for N
spectral layers
Returns
-------
fovs : list of FieldOfView objects
fovs : generator of ``FieldOfView`` objects
"""

shift_waves = shifts["wavelengths"] # in [um]
# Ensure array for later indexing
shift_waves = np.array(shifts["wavelengths"]) # in [um]
shift_dx = shifts["x_shifts"] # in [deg]
shift_dy = shifts["y_shifts"]

# combine the wavelength bins from 1D spectral effects and 3D shift effects
if len(shifts["wavelengths"]) > 0:
mask = (shift_waves > np.min(waveset)) * (shift_waves < np.max(waveset))
waveset = combine_wavesets([waveset, shift_waves[mask]])

counter = 0
fovs = []
if shift_waves.size:
mask = (shift_waves > min(waveset)) * (shift_waves < max(waveset))
waveset = combine_wavesets(waveset, shift_waves[mask])

print(f"Preparing {(len(waveset)-1)*len(headers)} FieldOfViews", flush=True)
# Actually evaluating the generators here is only necessary for the log msg
waveset = list(waveset)
headers = list(headers)
logging.info("Preparing %d FieldOfViews", (len(waveset) - 1) * len(headers))

for ii in range(len(waveset) - 1):
for hdr in headers:
# add any pre-instrument shifts to the FOV sky coords
wave_mid = 0.5 * (waveset[ii] + waveset[ii+1])
x_shift = np.interp(wave_mid, shift_waves, shift_dx)
y_shift = np.interp(wave_mid, shift_waves, shift_dy)

fov_hdr = deepcopy(hdr)
fov_hdr["CRVAL1"] += x_shift # headers are in [deg]
fov_hdr["CRVAL2"] += y_shift
combos = product(pairwise(waveset), headers)
for fov_id, ((wave_min, wave_max), hdr) in enumerate(combos):
# add any pre-instrument shifts to the FOV sky coords
wave_mid = 0.5 * (wave_min + wave_max)
x_shift = np.interp(wave_mid, shift_waves, shift_dx)
y_shift = np.interp(wave_mid, shift_waves, shift_dy)

# define the wavelength range for the FOV
waverange = [waveset[ii], waveset[ii + 1]]
fov_hdr = deepcopy(hdr)
fov_hdr["CRVAL1"] += x_shift # headers are in [deg]
fov_hdr["CRVAL2"] += y_shift

# Make the FOV
fov = FieldOfView(fov_hdr, waverange, id=counter, **kwargs)
fovs += [fov]
counter += 1
# define the wavelength range for the FOV
waverange = [wave_min, wave_max]

return fovs
# Make the FOV
yield FieldOfView(fov_hdr, waverange, id=fov_id, **kwargs)


def get_spectroscopy_headers(effects, **kwargs):

"""Return generator of Header objects."""
required_keys = ["pixel_scale", "plate_scale",
"wave_min", "wave_max"]
check_keys(kwargs, required_keys, action="error")
Expand All @@ -285,15 +295,15 @@ def get_spectroscopy_headers(effects, **kwargs):
efs.SlitWheel,
efs.ApertureMask))

if len(surface_list_effects) > 0:
if surface_list_effects:
waves = surface_list_effects[0].fov_grid(which="waveset")
if len(waves) == 2:
kwargs["wave_min"] = np.max([waves[0].value, kwargs["wave_min"]])
kwargs["wave_max"] = np.min([waves[1].value, kwargs["wave_max"]])

if len(detector_list_effects) > 0:
if detector_list_effects:
implane_hdr = detector_list_effects[0].image_plane_header
elif len(spec_trace_effects) > 0:
elif spec_trace_effects:
implane_hdr = spec_trace_effects[0].image_plane_header
else:
raise ValueError("Missing a way to determine the image plane size")
Expand All @@ -304,6 +314,7 @@ def get_spectroscopy_headers(effects, **kwargs):
f"{spec_trace_effects}")
spec_trace = spec_trace_effects[0]

# TODO: The following is WET with the code in get_imaging_headers
sky_hdrs = []
for ap_eff in aperture_effects:
# if ApertureList, a list of ApertureMask headers is returned
Expand All @@ -320,21 +331,23 @@ def get_spectroscopy_headers(effects, **kwargs):
plate_scale=kwargs["plate_scale"]
)
for sky_hdr in sky_hdrs]
fov_headers = [hdr for hdr_list in fov_headers for hdr in hdr_list]
# ..todo: check that each header is not larger than chunk_size

return fov_headers
for hdr_list in fov_headers:
for hdr in hdr_list:
yield hdr
# TODO: check that each header is not larger than chunk_size
# that's already done in get_imaging_headers, isn't it?


def get_spectroscopy_fovs(headers, shifts, effects=None, **kwargs):
"""Return a generator of ``FieldOfView`` objects."""
if effects is None:
effects = []

shift_waves = shifts["wavelengths"] # in [um]
shift_dx = shifts["x_shifts"] # in [deg]
shift_dy = shifts["y_shifts"]

print(f"Preparing {len(headers)} FieldOfViews", flush=True)
logging.info("Preparing %d FieldOfViews", len(headers))

apertures = get_all_effects(effects, (efs.ApertureList, efs.ApertureMask))
masks = [ap.fov_grid(which="masks") for ap in apertures]
Expand All @@ -345,8 +358,7 @@ def get_spectroscopy_fovs(headers, shifts, effects=None, **kwargs):
elif isinstance(mask, np.ndarray):
mask_dict[len(mask_dict)] = mask

fovs = []
for ii, hdr in enumerate(headers):
for fov_id, hdr in enumerate(headers):
# add any pre-instrument shifts to the FOV sky coords
wave_mid = hdr["WAVE_MID"]
x_shift = np.interp(wave_mid, shift_waves, shift_dx)
Expand All @@ -357,42 +369,53 @@ def get_spectroscopy_fovs(headers, shifts, effects=None, **kwargs):
fov_hdr["CRVAL2"] += y_shift

# Make the FOV
fov = FieldOfView(fov_hdr, waverange=[hdr["WAVE_MIN"], hdr["WAVE_MAX"]],
**kwargs)
waverange = [hdr["WAVE_MIN"], hdr["WAVE_MAX"]]
fov = FieldOfView(fov_hdr, waverange=waverange, **kwargs)
fov.meta["distortion"]["rotation"] = hdr["ROTANGD"]
fov.meta["distortion"]["shear"] = hdr["SKEWANGD"]
fov.meta["conserve_image"] = hdr["IMG_CONS"]
fov.meta["fov_id"] = ii
# TODO: In the other function, the id is set via the contructor.
# What's the difference?
fov.meta["fov_id"] = fov_id
fov.meta["aperture_id"] = hdr["APERTURE"]

# .. todo: get these masks working
# there needs to be fov_grid(which="mask") in ApertureList/Mask
# fov.mask = mask_dict[hdr["APERTURE"]]
fovs += [fov]
yield fov

return fovs


def combine_wavesets(waveset_list):
# FIXME: This functions doesn't seem to be covered by any separate unit test.
def combine_wavesets(*wavesets):
"""
Joins and sorts several sets of wavelengths into a single 1D array
Join and sorts several sets of wavelengths into a single 1D array.
Parameters
----------
waveset_list : list
wavesets : one or more iterables
A group of wavelength arrays or lists
Returns
-------
wave_set : np.ndarray
Combined set of wavelengths
Note
----
This assumes that all wavesets are given in the same unit!
"""
wave_set = []
for wbe in waveset_list: # wbe = waveset bin edges
wbe = wbe.value if isinstance(wbe, u.Quantity) else wbe
wave_set += list(wbe)
# ..todo:: set variable in !SIM.computing for rounding to the 7th decimal
wave_set = np.unique(np.round(np.sort(wave_set, kind="stable"), 7))

# TODO: set variable in !SIM.computing for rounding to the 7th decimal
decimals = 7

def _get_waves(waves):
for wave in waves:
if isinstance(wave, u.Quantity):
round_wave = wave.round(decimals).value
else:
round_wave = np.round(wave, decimals)
yield from round_wave

# NOTE: This function previously used np.sort(wave_set, kind="stable").
# If any issues occur with the buitin sorted, go back to that!
wave_set = sorted(set(_get_waves(wavesets)))
return wave_set
Loading

0 comments on commit 2bfeab6

Please sign in to comment.