From 2bfeab65fbd3a393d45ec9d9ba5213f4dc119e4e Mon Sep 17 00:00:00 2001 From: teutoburg Date: Sat, 26 Aug 2023 00:21:33 +0200 Subject: [PATCH] General refactoring, using generators, update tests --- scopesim/optics/fov_manager_utils.py | 205 ++++++++++-------- .../tests_optics/test_fov_manager_utils.py | 93 ++++---- 2 files changed, 169 insertions(+), 129 deletions(-) diff --git a/scopesim/optics/fov_manager_utils.py b/scopesim/optics/fov_manager_utils.py index c03e6084..f97db5aa 100644 --- a/scopesim/optics/fov_manager_utils.py +++ b/scopesim/optics/fov_manager_utils.py @@ -1,4 +1,7 @@ +import logging from copy import deepcopy +from itertools import product +from more_itertools import pairwise import numpy as np from astropy import units as u @@ -9,6 +12,8 @@ from ..effects.effects_utils import get_all_effects from ..utils import check_keys +# TODO: Where are all these functions used?? + def get_3d_shifts(effects, **kwargs): """ @@ -41,9 +46,11 @@ def get_3d_shifts(effects, **kwargs): shifts = [eff.fov_grid(which="shifts", **kwargs) for eff in effects] old_bin_edges = [shift[0] for shift in shifts if len(shift[0]) >= 2] + # TODO: could this use combine_wavesets? new_bin_edges = np.unique(np.sort(np.concatenate(old_bin_edges), kind="stable")) + # TODO: could this be zeros_like? x_shifts = np.zeros(len(new_bin_edges)) y_shifts = np.zeros(len(new_bin_edges)) # .. todo:: replace the 1e-7 with a variable in !SIM @@ -99,25 +106,27 @@ def get_imaging_waveset(effects_list, **kwargs): wave_bin_edges = [filt.fov_grid(which="waveset", **kwargs) for filt in filters] - if len(wave_bin_edges) > 0: - kwargs["wave_min"] = np.max([w[0].value for w in wave_bin_edges]) - kwargs["wave_max"] = np.min([w[1].value for w in wave_bin_edges]) + if wave_bin_edges: + kwargs["wave_min"] = max(wave[0].value for wave in wave_bin_edges) + kwargs["wave_max"] = min(wave[1].value for wave in wave_bin_edges) + # Bit confusing... wave_bin_edges = [[kwargs["wave_min"], kwargs["wave_max"]]] if kwargs["wave_min"] > kwargs["wave_max"]: - raise ValueError(f"Filter wavelength ranges do not overlap: {wave_bin_edges}") + raise ValueError("Filter wavelength ranges do not overlap: " + f"{wave_bin_edges}.") # ..todo: add in Atmospheric dispersion and ADC here for effect_class in [efs.PSF]: - effects = get_all_effects(effects_list, effect_class) - for eff in effects: + for eff in get_all_effects(effects_list, effect_class): waveset = eff.fov_grid(which="waveset", **kwargs) if waveset is not None: - wave_bin_edges += [waveset] + wave_bin_edges.append(waveset) - wave_bin_edges = combine_wavesets(wave_bin_edges) + wave_bin_edges = combine_wavesets(*wave_bin_edges) - if len(wave_bin_edges) == 0: + if not wave_bin_edges: + # This is already set at the top, why again here? wave_bin_edges = [kwargs["wave_min"], kwargs["wave_max"]] return wave_bin_edges @@ -125,7 +134,7 @@ def get_imaging_waveset(effects_list, **kwargs): def get_imaging_headers(effects, **kwargs): """ - Returns a list of Header objects for each of the FieldOfVIew objects + Return a generator of Header objects for each of the FieldOfVIew objects. Parameters ---------- @@ -135,7 +144,7 @@ def get_imaging_headers(effects, **kwargs): Returns ------- - hdrs : list of Header objects + hdrs : generator of Header objects Notes ----- @@ -163,33 +172,39 @@ def get_imaging_headers(effects, **kwargs): aperture_effects = get_all_effects(effects, (efs.ApertureMask, efs.SlitWheel, efs.ApertureList)) - if len(aperture_effects) == 0: + if not aperture_effects: detector_arrays = get_all_effects(effects, efs.DetectorList) - if len(detector_arrays) > 0: - aperture_effects += [detarr.fov_grid(which="edges", - pixel_scale=pixel_scale) - for detarr in detector_arrays] - else: - raise ValueError("No ApertureMask or DetectorList was provided. At " - "least one must be passed to make an ImagePlane: " - f"{effects}") - + if not detector_arrays: + raise ValueError("No ApertureMask or DetectorList was provided. " + "At least one must be passed to make an " + f"ImagePlane: {effects}.") + aperture_effects.extend( + detarr.fov_grid(which="edges", pixel_scale=pixel_scale) + for detarr in detector_arrays) + + # FIXME: all of this is a bit inconsistent; fov_grid(which="edges" is + # called afterwards, but when looking in detector_arrays, the same + # is called immediately; does that even work? is this all tested?? # get aperture headers from fov_grid() # - for-loop catches mutliple headers from ApertureList.fov_grid() - hdrs = [] - for ap_eff in aperture_effects: - # ..todo:: add this functionality to ApertureList effect - hdr = ap_eff.fov_grid(which="edges", pixel_scale=pixel_scale) - hdrs += hdr if isinstance(hdr, (list, tuple)) else [hdr] + def _get_hdrs(ap_effs): + for ap_eff in ap_effs: + # ..todo:: add this functionality to ApertureList effect + hdr = ap_eff.fov_grid(which="edges", pixel_scale=pixel_scale) + if isinstance(hdr, (list, tuple)): + yield from hdr + else: + yield hdr + headers = _get_hdrs(aperture_effects) # check size of aperture in pixels - split if necessary - sky_hdrs = [] - for hdr in hdrs: - if hdr["NAXIS1"] * hdr["NAXIS2"] > kwargs["max_segment_size"]: - sky_hdrs += imp_utils.split_header(hdr, kwargs["chunk_size"]) - else: - sky_hdrs += [hdr] - + def _get_sky_hdrs(hdrs): + for hdr in hdrs: + if hdr["NAXIS1"] * hdr["NAXIS2"] > kwargs["max_segment_size"]: + yield from imp_utils.split_header(hdr, kwargs["chunk_size"]) + else: + yield hdr + sky_hdrs = _get_sky_hdrs(headers) # ..todo:: Deal with the case that two or more ApertureMasks overlap # map the on-sky apertures directly to the image plane using plate_scale @@ -208,13 +223,12 @@ def get_imaging_headers(effects, **kwargs): dethdr = imp_utils.header_from_list_of_xy(x_det, y_det, pixel_size, "D") skyhdr.update(dethdr) - - return sky_hdrs + yield skyhdr def get_imaging_fovs(headers, waveset, shifts, **kwargs): """ - Returns a list of ``FieldOfView`` objects + Return a generator of ``FieldOfView`` objects. Parameters ---------- @@ -224,54 +238,50 @@ def get_imaging_fovs(headers, waveset, shifts, **kwargs): waveset : list of floats [um] N+1 wavelengths for N spectral layers - shifts : list of tuples + shifts : list of tuples (or actually arrays?) [deg] x,y shifts w.r.t to the optical axis plane. N shifts for N spectral layers Returns ------- - fovs : list of FieldOfView objects + fovs : generator of ``FieldOfView`` objects """ - - shift_waves = shifts["wavelengths"] # in [um] + # Ensure array for later indexing + shift_waves = np.array(shifts["wavelengths"]) # in [um] shift_dx = shifts["x_shifts"] # in [deg] shift_dy = shifts["y_shifts"] # combine the wavelength bins from 1D spectral effects and 3D shift effects - if len(shifts["wavelengths"]) > 0: - mask = (shift_waves > np.min(waveset)) * (shift_waves < np.max(waveset)) - waveset = combine_wavesets([waveset, shift_waves[mask]]) - - counter = 0 - fovs = [] + if shift_waves.size: + mask = (shift_waves > min(waveset)) * (shift_waves < max(waveset)) + waveset = combine_wavesets(waveset, shift_waves[mask]) - print(f"Preparing {(len(waveset)-1)*len(headers)} FieldOfViews", flush=True) + # Actually evaluating the generators here is only necessary for the log msg + waveset = list(waveset) + headers = list(headers) + logging.info("Preparing %d FieldOfViews", (len(waveset) - 1) * len(headers)) - for ii in range(len(waveset) - 1): - for hdr in headers: - # add any pre-instrument shifts to the FOV sky coords - wave_mid = 0.5 * (waveset[ii] + waveset[ii+1]) - x_shift = np.interp(wave_mid, shift_waves, shift_dx) - y_shift = np.interp(wave_mid, shift_waves, shift_dy) - - fov_hdr = deepcopy(hdr) - fov_hdr["CRVAL1"] += x_shift # headers are in [deg] - fov_hdr["CRVAL2"] += y_shift + combos = product(pairwise(waveset), headers) + for fov_id, ((wave_min, wave_max), hdr) in enumerate(combos): + # add any pre-instrument shifts to the FOV sky coords + wave_mid = 0.5 * (wave_min + wave_max) + x_shift = np.interp(wave_mid, shift_waves, shift_dx) + y_shift = np.interp(wave_mid, shift_waves, shift_dy) - # define the wavelength range for the FOV - waverange = [waveset[ii], waveset[ii + 1]] + fov_hdr = deepcopy(hdr) + fov_hdr["CRVAL1"] += x_shift # headers are in [deg] + fov_hdr["CRVAL2"] += y_shift - # Make the FOV - fov = FieldOfView(fov_hdr, waverange, id=counter, **kwargs) - fovs += [fov] - counter += 1 + # define the wavelength range for the FOV + waverange = [wave_min, wave_max] - return fovs + # Make the FOV + yield FieldOfView(fov_hdr, waverange, id=fov_id, **kwargs) def get_spectroscopy_headers(effects, **kwargs): - + """Return generator of Header objects.""" required_keys = ["pixel_scale", "plate_scale", "wave_min", "wave_max"] check_keys(kwargs, required_keys, action="error") @@ -285,15 +295,15 @@ def get_spectroscopy_headers(effects, **kwargs): efs.SlitWheel, efs.ApertureMask)) - if len(surface_list_effects) > 0: + if surface_list_effects: waves = surface_list_effects[0].fov_grid(which="waveset") if len(waves) == 2: kwargs["wave_min"] = np.max([waves[0].value, kwargs["wave_min"]]) kwargs["wave_max"] = np.min([waves[1].value, kwargs["wave_max"]]) - if len(detector_list_effects) > 0: + if detector_list_effects: implane_hdr = detector_list_effects[0].image_plane_header - elif len(spec_trace_effects) > 0: + elif spec_trace_effects: implane_hdr = spec_trace_effects[0].image_plane_header else: raise ValueError("Missing a way to determine the image plane size") @@ -304,6 +314,7 @@ def get_spectroscopy_headers(effects, **kwargs): f"{spec_trace_effects}") spec_trace = spec_trace_effects[0] + # TODO: The following is WET with the code in get_imaging_headers sky_hdrs = [] for ap_eff in aperture_effects: # if ApertureList, a list of ApertureMask headers is returned @@ -320,13 +331,15 @@ def get_spectroscopy_headers(effects, **kwargs): plate_scale=kwargs["plate_scale"] ) for sky_hdr in sky_hdrs] - fov_headers = [hdr for hdr_list in fov_headers for hdr in hdr_list] - # ..todo: check that each header is not larger than chunk_size - - return fov_headers + for hdr_list in fov_headers: + for hdr in hdr_list: + yield hdr + # TODO: check that each header is not larger than chunk_size + # that's already done in get_imaging_headers, isn't it? def get_spectroscopy_fovs(headers, shifts, effects=None, **kwargs): + """Return a generator of ``FieldOfView`` objects.""" if effects is None: effects = [] @@ -334,7 +347,7 @@ def get_spectroscopy_fovs(headers, shifts, effects=None, **kwargs): shift_dx = shifts["x_shifts"] # in [deg] shift_dy = shifts["y_shifts"] - print(f"Preparing {len(headers)} FieldOfViews", flush=True) + logging.info("Preparing %d FieldOfViews", len(headers)) apertures = get_all_effects(effects, (efs.ApertureList, efs.ApertureMask)) masks = [ap.fov_grid(which="masks") for ap in apertures] @@ -345,8 +358,7 @@ def get_spectroscopy_fovs(headers, shifts, effects=None, **kwargs): elif isinstance(mask, np.ndarray): mask_dict[len(mask_dict)] = mask - fovs = [] - for ii, hdr in enumerate(headers): + for fov_id, hdr in enumerate(headers): # add any pre-instrument shifts to the FOV sky coords wave_mid = hdr["WAVE_MID"] x_shift = np.interp(wave_mid, shift_waves, shift_dx) @@ -357,29 +369,30 @@ def get_spectroscopy_fovs(headers, shifts, effects=None, **kwargs): fov_hdr["CRVAL2"] += y_shift # Make the FOV - fov = FieldOfView(fov_hdr, waverange=[hdr["WAVE_MIN"], hdr["WAVE_MAX"]], - **kwargs) + waverange = [hdr["WAVE_MIN"], hdr["WAVE_MAX"]] + fov = FieldOfView(fov_hdr, waverange=waverange, **kwargs) fov.meta["distortion"]["rotation"] = hdr["ROTANGD"] fov.meta["distortion"]["shear"] = hdr["SKEWANGD"] fov.meta["conserve_image"] = hdr["IMG_CONS"] - fov.meta["fov_id"] = ii + # TODO: In the other function, the id is set via the contructor. + # What's the difference? + fov.meta["fov_id"] = fov_id fov.meta["aperture_id"] = hdr["APERTURE"] # .. todo: get these masks working # there needs to be fov_grid(which="mask") in ApertureList/Mask # fov.mask = mask_dict[hdr["APERTURE"]] - fovs += [fov] + yield fov - return fovs - -def combine_wavesets(waveset_list): +# FIXME: This functions doesn't seem to be covered by any separate unit test. +def combine_wavesets(*wavesets): """ - Joins and sorts several sets of wavelengths into a single 1D array + Join and sorts several sets of wavelengths into a single 1D array. Parameters ---------- - waveset_list : list + wavesets : one or more iterables A group of wavelength arrays or lists Returns @@ -387,12 +400,22 @@ def combine_wavesets(waveset_list): wave_set : np.ndarray Combined set of wavelengths + Note + ---- + This assumes that all wavesets are given in the same unit! """ - wave_set = [] - for wbe in waveset_list: # wbe = waveset bin edges - wbe = wbe.value if isinstance(wbe, u.Quantity) else wbe - wave_set += list(wbe) - # ..todo:: set variable in !SIM.computing for rounding to the 7th decimal - wave_set = np.unique(np.round(np.sort(wave_set, kind="stable"), 7)) - + # TODO: set variable in !SIM.computing for rounding to the 7th decimal + decimals = 7 + + def _get_waves(waves): + for wave in waves: + if isinstance(wave, u.Quantity): + round_wave = wave.round(decimals).value + else: + round_wave = np.round(wave, decimals) + yield from round_wave + + # NOTE: This function previously used np.sort(wave_set, kind="stable"). + # If any issues occur with the buitin sorted, go back to that! + wave_set = sorted(set(_get_waves(wavesets))) return wave_set diff --git a/scopesim/tests/tests_optics/test_fov_manager_utils.py b/scopesim/tests/tests_optics/test_fov_manager_utils.py index 1007af7e..63468858 100644 --- a/scopesim/tests/tests_optics/test_fov_manager_utils.py +++ b/scopesim/tests/tests_optics/test_fov_manager_utils.py @@ -13,11 +13,31 @@ from scopesim.tests.mocks.py_objects import trace_list_objects as tlo +@pytest.fixture(scope="function") +def wave_kwargs(): + return {"wave_min": 0.5, "wave_max": 2.5} + + +@pytest.fixture(scope="function") +def th_filt(): + return eo._filter_tophat_curve() + + @pytest.fixture(scope="function") def full_trace_list(): return tlo.make_trace_hdulist() +@pytest.fixture(scope="function") +def spec_hdrs(full_trace_list): + params = {"pixel_scale": 0.1, "plate_scale": 0.1, + "wave_min": 0.7, "wave_max": 2.5} + spt = SpectralTraceList(hdulist=full_trace_list, **params) + apm = apo._basic_aperture() + hdrs = fm_utils.get_spectroscopy_headers(effects=[spt, apm], **params) + return list(hdrs) + + PLOTS = False @@ -73,35 +93,35 @@ def test_combined_shifts_reduced_to_usable_number(self, sub_pix_frac): class TestGetImagingWaveset: - def test_returns_default_wave_range_when_passed_no_effects(self): - kwargs = {"wave_min": 0.5, "wave_max": 2.5} - wave_bin_edges = fm_utils.get_imaging_waveset([], **kwargs) + @pytest.mark.usefixtures("wave_kwargs") + def test_returns_default_wave_range_when_passed_no_effects(self, wave_kwargs): + wave_bin_edges = fm_utils.get_imaging_waveset([], **wave_kwargs) assert len(wave_bin_edges) == 2 - def test_returns_waveset_of_filter(self): - filt = eo._filter_tophat_curve() - kwargs = {"wave_min": 0.5, "wave_max": 2.5} - wave_bin_edges = fm_utils.get_imaging_waveset([filt], **kwargs) + @pytest.mark.usefixtures("wave_kwargs", "th_filt") + def test_returns_waveset_of_filter(self, wave_kwargs, th_filt): + wave_bin_edges = fm_utils.get_imaging_waveset([th_filt], **wave_kwargs) assert len(wave_bin_edges) == 2 - def test_returns_waveset_of_psf(self): + @pytest.mark.usefixtures("wave_kwargs") + def test_returns_waveset_of_psf(self, wave_kwargs): psf = eo._const_psf() - kwargs = {"wave_min": 0.5, "wave_max": 2.5} - wave_bin_edges = fm_utils.get_imaging_waveset([psf], **kwargs) + wave_bin_edges = fm_utils.get_imaging_waveset([psf], **wave_kwargs) assert len(wave_bin_edges) == 4 - def test_returns_waveset_of_psf_and_filter(self): - filt = eo._filter_tophat_curve() + @pytest.mark.usefixtures("wave_kwargs", "th_filt") + def test_returns_waveset_of_psf_and_filter(self, wave_kwargs, th_filt): psf = eo._const_psf() - kwargs = {"wave_min": 0.5, "wave_max": 2.5} - wave_bin_edges = fm_utils.get_imaging_waveset([filt, psf], **kwargs) + wave_bin_edges = fm_utils.get_imaging_waveset([th_filt, psf], + **wave_kwargs) assert len(wave_bin_edges) == 4 - def test_returns_waveset_of_ncpa_psf_inside_filter_edges(self): - filt = eo._filter_tophat_curve() + @pytest.mark.usefixtures("wave_kwargs", "th_filt") + def test_returns_waveset_of_ncpa_psf_inside_filter_edges(self, wave_kwargs, + th_filt): psf = eo._ncpa_psf() - kwargs = {"wave_min": 0.5, "wave_max": 2.5} - wave_bin_edges = fm_utils.get_imaging_waveset([psf, filt], **kwargs) + wave_bin_edges = fm_utils.get_imaging_waveset([psf, th_filt], + **wave_kwargs) assert min(wave_bin_edges) == 1. assert max(wave_bin_edges) == 2. assert len(wave_bin_edges) == 9 @@ -111,7 +131,7 @@ class TestGetImagingHeaders: def test_throws_error_if_not_all_keywords_are_passed(self): apm = eo._img_aperture_mask() with pytest.raises(ValueError): - fm_utils.get_imaging_headers([apm]) + list(fm_utils.get_imaging_headers([apm])) def test_returns_set_of_headers_from_aperture_effects(self): apm = eo._img_aperture_mask(array_dict={"x": [-1.28, 1., 1., -1.28], @@ -129,6 +149,8 @@ def test_returns_set_of_headers_from_detector_list_effect(self): kwargs = {"pixel_scale": 0.004, "plate_scale": 0.26666666666, "max_segment_size": 2048 ** 2, "chunk_size": 1024} hdrs = fm_utils.get_imaging_headers([det], **kwargs) + # Evaluate generator for testing + hdrs = list(hdrs) area_sum = np.sum([hdr["NAXIS1"] * hdr["NAXIS2"] for hdr in hdrs]) assert area_sum == 4096**2 @@ -161,6 +183,8 @@ def test_returns_fov_objects_for_basic_input(self): "max_segment_size": 100 ** 2, "chunk_size": 100} hdrs = fm_utils.get_imaging_headers([apm], **kwargs) + # Evaluate generator for testing + hdrs = list(hdrs) waveset = np.linspace(1, 2, 6) shifts = {"wavelengths": np.array([1, 2]), "x_shifts": np.zeros(2), @@ -168,7 +192,10 @@ def test_returns_fov_objects_for_basic_input(self): fovs = fm_utils.get_imaging_fovs(headers=hdrs, waveset=waveset, shifts=shifts) + # Evaluate generator for testing + fovs = list(fovs) assert len(fovs) == (len(waveset)-1) * len(hdrs) + assert fovs if PLOTS: from scopesim.optics.image_plane_utils import calc_footprint @@ -189,19 +216,13 @@ def test_returns_fov_objects_for_basic_input(self): plt.show() -@pytest.mark.usefixtures("full_trace_list") +@pytest.mark.usefixtures("spec_hdrs") class TestGetSpectroscopyHeaders: - def test_returns_headers(self, full_trace_list): - params = {"pixel_scale": 0.1, "plate_scale": 0.1, - "wave_min": 0.7, "wave_max": 2.5} - spt = SpectralTraceList(hdulist=full_trace_list, **params) - apm = apo._basic_aperture() - - hdrs = fm_utils.get_spectroscopy_headers(effects=[spt, apm], **params) - assert all([isinstance(hdr, PoorMansHeader) for hdr in hdrs]) + def test_returns_headers(self, spec_hdrs): + assert all([isinstance(hdr, PoorMansHeader) for hdr in spec_hdrs]) if PLOTS: - for hdr in hdrs: + for hdr in spec_hdrs: x = np.array([0, hdr["NAXIS1"], hdr["NAXIS1"], 0]) y = np.array([0, 0, hdr["NAXIS2"], hdr["NAXIS2"]]) xw, yw = pix2val(hdr, x, y, "D") @@ -210,19 +231,15 @@ def test_returns_headers(self, full_trace_list): plt.show() -@pytest.mark.usefixtures("full_trace_list") +@pytest.mark.usefixtures("spec_hdrs") class TestGetSpectroscopyFOVs: - def test_returns_fovs(self, full_trace_list): - params = {"pixel_scale": 0.1, "plate_scale": 0.1, - "wave_min": 0.7, "wave_max": 2.5} - spt = SpectralTraceList(hdulist=full_trace_list, **params) - apm = apo._basic_aperture() - + def test_returns_fovs(self, spec_hdrs): shifts = {"wavelengths": np.array([0.7, 2.5]), "x_shifts": np.array([0, 0]), "y_shifts": np.array([0, 1/3600.])} - hdrs = fm_utils.get_spectroscopy_headers(effects=[spt, apm], **params) - fovs = fm_utils.get_spectroscopy_fovs(hdrs, shifts) + fovs = fm_utils.get_spectroscopy_fovs(spec_hdrs, shifts) + # Evaluate generator for testing + fovs = list(fovs) assert all([isinstance(fov, FieldOfView) for fov in fovs])