From 96f046466495bdf093a11ce8863fe7702f1d84ce Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Tue, 20 Jun 2023 15:24:42 +0100 Subject: [PATCH] Rewrite coordinate systems tests for APE 14 This highlights some API changes here. --- gwcs/api.py | 4 +- gwcs/coordinate_frames.py | 64 +++++----- gwcs/tests/test_api.py | 4 +- gwcs/tests/test_coordinate_systems.py | 167 +++++++++++++++----------- 4 files changed, 131 insertions(+), 108 deletions(-) diff --git a/gwcs/api.py b/gwcs/api.py index ca606f4d..ff633143 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -256,11 +256,11 @@ def serialized_classes(self): @property def world_axis_object_classes(self): - return self.output_frame._world_axis_object_classes + return self.output_frame.world_axis_object_classes @property def world_axis_object_components(self): - return self.output_frame._world_axis_object_components + return self.output_frame.world_axis_object_components @property def pixel_axis_names(self): diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index 83647529..02c385bd 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -249,7 +249,7 @@ def axis_physical_types(self): @property @abc.abstractmethod - def _world_axis_object_classes(self): + def world_axis_object_classes(self): """ The APE 14 object classes for this frame. @@ -260,7 +260,7 @@ def _world_axis_object_classes(self): @property @abc.abstractmethod - def _world_axis_object_components(self): + def world_axis_object_components(self): """ The APE 14 object components for this frame. @@ -444,14 +444,14 @@ def axis_physical_types(self): return self._axis_physical_types or self._default_axis_physical_types @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): return {f"{at}{i}" if i != 0 else at: (u.Quantity, (), {'unit': unit}) for i, (at, unit) in enumerate(zip(self._axes_type, self.unit))} @property - def _world_axis_object_components(self): + def world_axis_object_components(self): return [(f"{at}{i}" if i != 0 else at, 0, 'value') for i, at in enumerate(self._axes_type)] @@ -543,7 +543,7 @@ def _default_axis_physical_types(self): return tuple("custom:{}".format(t) for t in self.axes_names) @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): return {'celestial': ( coord.SkyCoord, (), @@ -551,7 +551,7 @@ def _world_axis_object_classes(self): 'unit': self.unit})} @property - def _world_axis_object_components(self): + def world_axis_object_components(self): return [('celestial', 0, 'spherical.lon'), ('celestial', 1, 'spherical.lat')] @@ -605,14 +605,14 @@ def _default_axis_physical_types(self): return ("custom:{}".format(self.unit[0].physical_type),) @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): return {'spectral': ( coord.SpectralCoord, (), {'unit': self.unit[0]})} @property - def _world_axis_object_components(self): + def world_axis_object_components(self): return [('spectral', 0, 'value')] @@ -657,8 +657,19 @@ def __init__(self, reference_frame, unit=None, axes_order=(0,), def _default_axis_physical_types(self): return ("time",) + def _convert_to_time(self, dt, *, unit, **kwargs): + if (not isinstance(dt, time.TimeDelta) and + isinstance(dt, time.Time) or + isinstance(self.reference_frame.value, np.ndarray)): + return time.Time(dt, **kwargs) + + if not hasattr(dt, 'unit'): + dt = dt * unit + + return self.reference_frame + dt + @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): comp = ( time.Time, (), @@ -668,7 +679,7 @@ def _world_axis_object_classes(self): return {'temporal': comp} @property - def _world_axis_object_components(self): + def world_axis_object_components(self): if isinstance(self.reference_frame.value, np.ndarray): return [('temporal', 0, 'value')] @@ -676,17 +687,6 @@ def offset_from_time_and_reference(time): return (time - self.reference_frame).sec return [('temporal', 0, offset_from_time_and_reference)] - def _convert_to_time(self, dt, *, unit, **kwargs): - if (not isinstance(dt, time.TimeDelta) and - isinstance(dt, time.Time) or - isinstance(self.reference_frame.value, np.ndarray)): - return time.Time(dt, **kwargs) - - if not hasattr(dt, 'unit'): - dt = dt * unit - - return self.reference_frame + dt - class CompositeFrame(CoordinateFrame): """ @@ -723,10 +723,10 @@ def __init__(self, frames, name=None): "axes_order should contain unique numbers, " "got {}.".format(axes_order)) - super(CompositeFrame, self).__init__(naxes, axes_type=axes_type, - axes_order=axes_order, - unit=unit, axes_names=axes_names, - name=name) + super().__init__(naxes, axes_type=axes_type, + axes_order=axes_order, + unit=unit, axes_names=axes_names, + name=name) self._axis_physical_types = tuple(ph_type) @property @@ -743,7 +743,7 @@ def _wao_classes_rename_map(self): for frame in self.frames: # ensure the frame is in the mapper mapper[frame] - for key in frame._world_axis_object_classes.keys(): + for key in frame.world_axis_object_classes.keys(): if key in seen_names: new_key = f"{key}{seen_names.count(key)}" mapper[frame][key] = new_key @@ -755,7 +755,7 @@ def _wao_renamed_components_iter(self): mapper = self._wao_classes_rename_map for frame in self.frames: renamed_components = [] - for comp in frame._world_axis_object_components: + for comp in frame.world_axis_object_components: comp = list(comp) rename = mapper[frame].get(comp[0]) if rename: @@ -767,14 +767,14 @@ def _wao_renamed_components_iter(self): def _wao_renamed_classes_iter(self): mapper = self._wao_classes_rename_map for frame in self.frames: - for key, value in frame._world_axis_object_classes.items(): + for key, value in frame.world_axis_object_classes.items(): rename = mapper[frame].get(key) if rename: key = rename yield key, value @property - def _world_axis_object_components(self): + def world_axis_object_components(self): """ We need to generate the components respecting the axes_order. """ @@ -788,7 +788,7 @@ def _world_axis_object_components(self): return out @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): return dict(self._wao_renamed_classes_iter) @@ -814,7 +814,7 @@ def _default_axis_physical_types(self): return ("phys.polarization.stokes",) @property - def _world_axis_object_classes(self): + def world_axis_object_classes(self): return {'stokes': ( StokesCoord, (), @@ -822,7 +822,7 @@ def _world_axis_object_classes(self): )} @property - def _world_axis_object_components(self): + def world_axis_object_components(self): return [('stokes', 0, 'value')] diff --git a/gwcs/tests/test_api.py b/gwcs/tests/test_api.py index 65b886f6..0c5b5c4e 100644 --- a/gwcs/tests/test_api.py +++ b/gwcs/tests/test_api.py @@ -491,12 +491,12 @@ def test_composite_many_base_frame(): q_frame_2 = cf.CoordinateFrame(name='distance', axes_order=(1,), naxes=1, axes_type="SPATIAL", unit=(u.m,)) frame = cf.CompositeFrame([q_frame_1, q_frame_2]) - wao_classes = frame._world_axis_object_classes + wao_classes = frame.world_axis_object_classes assert len(wao_classes) == 2 assert not set(wao_classes.keys()).difference({"SPATIAL", "SPATIAL1"}) - wao_components = frame._world_axis_object_components + wao_components = frame.world_axis_object_components assert len(wao_components) == 2 assert not {c[0] for c in wao_components}.difference({"SPATIAL", "SPATIAL1"}) diff --git a/gwcs/tests/test_coordinate_systems.py b/gwcs/tests/test_coordinate_systems.py index 967657f8..88035b10 100644 --- a/gwcs/tests/test_coordinate_systems.py +++ b/gwcs/tests/test_coordinate_systems.py @@ -10,11 +10,12 @@ from astropy.tests.helper import assert_quantity_allclose from astropy.modeling import models as m from astropy.wcs.wcsapi.fitswcs import CTYPE_TO_UCD1 -from astropy.coordinates import StokesCoord +from astropy.coordinates import StokesCoord, SpectralCoord from .. import WCS from .. import coordinate_frames as cf +from astropy.wcs.wcsapi.high_level_api import values_to_high_level_objects, high_level_objects_to_values import astropy astropy_version = astropy.__version__ @@ -33,7 +34,7 @@ focal = cf.Frame2D(name='focal', axes_order=(0, 1), unit=(u.m, u.m)) spec1 = cf.SpectralFrame(name='freq', unit=[u.Hz, ], axes_order=(2, )) -spec2 = cf.SpectralFrame(name='wave', unit=[u.m, ], axes_order=(2, ), axes_names=('lambda', )) +spec2 = cf.SpectralFrame(name='wave', unit=[u.m, ], axes_order=(2, ), axes_names=('lambda',)) spec3 = cf.SpectralFrame(name='energy', unit=[u.J, ], axes_order=(2, )) spec4 = cf.SpectralFrame(name='pixel', unit=[u.pix, ], axes_order=(2, )) spec5 = cf.SpectralFrame(name='speed', unit=[u.m/u.s, ], axes_order=(2, )) @@ -55,6 +56,19 @@ inputs3 = [(xscalar, yscalar, xscalar), (xarr, yarr, xarr)] +@pytest.fixture(autouse=True, scope="module") +def serialized_classes(): + """ + In the rest of this test file we are passing the CoordinateFrame object to + astropy helper functions as if they were a low level WCS object. + + This little patch means that this works. + """ + cf.CoordinateFrame.serialized_classes = False + yield + del cf.CoordinateFrame.serialized_classes + + def test_units(): assert(comp1.unit == (u.deg, u.deg, u.Hz)) assert(comp2.unit == (u.m, u.m, u.m)) @@ -64,19 +78,34 @@ def test_units(): assert(comp.unit == (u.deg, u.deg, u.Hz, u.m)) +# These two functions fake the old methods on CoordinateFrame to reduce the +# amount of refactoring that needed doing in these tests. +def coordinates(*inputs, frame): + results = values_to_high_level_objects(*inputs, low_level_wcs=frame) + if isinstance(results, list) and len(results) == 1: + return results[0] + return results + + +def coordinate_to_quantity(*inputs, frame): + results = high_level_objects_to_values(*inputs, low_level_wcs=frame) + results = [r<