From 6fd7e2b4a5dd0bc4e0d9dc5555dbc4140d6c88aa Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Wed, 4 Dec 2024 09:42:24 +0000 Subject: [PATCH] More test fixes --- gwcs/api.py | 2 ++ gwcs/tests/test_wcs.py | 22 ++++++++++++---------- gwcs/wcs.py | 4 ++-- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/gwcs/api.py b/gwcs/api.py index 472357dd..0c32482e 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -9,6 +9,8 @@ from astropy.modeling import separable import astropy.units as u +from gwcs import utils + __all__ = ["GWCSAPIMixin"] diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index 06bf4335..bfd5f115 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -23,7 +23,7 @@ from gwcs import utils from gwcs.wcstools import (wcs_from_fiducial, grid_from_bounding_box, wcs_from_points) from gwcs.utils import CoordinateFrameError -from gwcs.utils import _gwcs_from_hst_fits_wcs +from gwcs.tests.utils import _gwcs_from_hst_fits_wcs from gwcs.tests import data from gwcs.examples import gwcs_2d_bad_bounding_box_order @@ -467,7 +467,8 @@ def test_bounding_box_eval(): Tests evaluation with and without respecting the bounding_box. """ trans3 = models.Shift(10) & models.Scale(2) & models.Shift(-1) - pipeline = [('detector', trans3), ('sky', None)] + pipeline = [(cf.CoordinateFrame(naxes=1, axes_type=("PIXEL",), axes_order=(0,), name='detector'), trans3), + (cf.CoordinateFrame(naxes=1, axes_type=("SPATIAL",), axes_order=(0,), name='sky'), None)] w = wcs.WCS(pipeline) w.bounding_box = ((-1, 10), (6, 15), (4.3, 6.9)) @@ -608,11 +609,13 @@ def setup_class(self): tan = models.Pix2Sky_TAN(name='tangent_projection') sky_cs = cf.CelestialFrame(reference_frame=coord.ICRS(), name='sky') det = cf.Frame2D(name='detector') + focal = cf.Frame2D(name='focal') wcs_forward = wcslin | tan | n2c - pipeline = [wcs.Step('detector', distortion), - wcs.Step('focal', wcs_forward), - wcs.Step(sky_cs, None) - ] + pipeline = [ + wcs.Step(det, distortion), + wcs.Step(focal, wcs_forward), + wcs.Step(sky_cs, None) + ] self.wcs = wcs.WCS(input_frame=det, output_frame=sky_cs, @@ -659,7 +662,7 @@ def test_inverse(self): def test_back_coordinates(self): sky_coord = self.wcs(1, 2, with_units=True) - res = self.wcs.transform('sky', 'focal', sky_coord, with_units=True) + res = self.wcs.transform('sky', 'focal', sky_coord, with_units=False) assert_allclose(res, self.wcs.get_transform('detector', 'focal')(1, 2)) def test_units(self): @@ -1553,9 +1556,8 @@ def test_high_level_objects_in_pipeline_forward(gwcs_with_pipeline_celestial): *input_pixel, with_units=True ) - assert len(intermediate_world_with_units) == 1 - assert isinstance(intermediate_world_with_units[0], coord.SkyCoord) - sc = intermediate_world_with_units[0] + assert isinstance(intermediate_world_with_units, coord.SkyCoord) + sc = intermediate_world_with_units assert u.allclose(sc.ra, 20*u.arcsec) assert u.allclose(sc.dec, 15*u.deg) diff --git a/gwcs/wcs.py b/gwcs/wcs.py index 8c99b753..62fcb01d 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -1153,8 +1153,8 @@ def transform(self, from_frame, to_frame, *args, **kwargs): to_frame = self._get_frame_by_name(to_frame) with_units = kwargs.pop("with_units", False) - if with_units and backward: - args = high_level_objects_to_values(*args, low_level_wcs=self) + if backward and utils.is_high_level(*args, low_level_wcs=from_frame): + args = high_level_objects_to_values(*args, low_level_wcs=from_frame) results = self._call_forward(*args, from_frame=from_frame, to_frame=to_frame, **kwargs)