From fd9055ca32ff710a64dd7fe80785e70a86228907 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 8 Jan 2025 14:01:25 -0700 Subject: [PATCH] test fixes --- sup3r/utilities/interpolation.py | 2 +- tests/data_handlers/test_dh_nc_cc.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sup3r/utilities/interpolation.py b/sup3r/utilities/interpolation.py index 776361db8..79f093518 100644 --- a/sup3r/utilities/interpolation.py +++ b/sup3r/utilities/interpolation.py @@ -48,7 +48,7 @@ def get_level_masks(cls, lev_array, level): da.arange(lev_array.shape[-1]), lev_array.shape ) mask1 = lev_indices == argmin1 - lev_diff = da.where(mask1, np.inf, lev_diff) + lev_diff = da.abs(da.ma.masked_array(lev_array, mask1) - level) argmin2 = da.argmin(lev_diff, axis=-1, keepdims=True) mask2 = lev_indices == argmin2 return mask1, mask2 diff --git a/tests/data_handlers/test_dh_nc_cc.py b/tests/data_handlers/test_dh_nc_cc.py index 9dd868023..b952ab83c 100644 --- a/tests/data_handlers/test_dh_nc_cc.py +++ b/tests/data_handlers/test_dh_nc_cc.py @@ -69,6 +69,7 @@ def test_reload_cache(): target=target, shape=(20, 20), cache_kwargs=cache_kwargs, + interp_kwargs={'include_single_levels': True} ) # reload from cache @@ -80,7 +81,9 @@ def test_reload_cache(): cache_kwargs=cache_kwargs, ) assert all(f in cached for f in features) - assert np.array_equal(handler.as_array(), cached.as_array()) + harr = handler.as_array().compute() + carr = cached.as_array().compute() + assert np.array_equal(harr, carr) @pytest.mark.parametrize(