diff --git a/sup3r/utilities/interpolation.py b/sup3r/utilities/interpolation.py index 0c0903aca..79f093518 100644 --- a/sup3r/utilities/interpolation.py +++ b/sup3r/utilities/interpolation.py @@ -20,9 +20,6 @@ def get_level_masks(cls, lev_array, level): Parameters ---------- - var_array : Union[np.ndarray, da.core.Array] - Array of variable data, for example u-wind in a 4D array of shape - (lat, lon, time, level) lev_array : Union[np.ndarray, da.core.Array] Height or pressure values for the corresponding entries in var_array, in the same shape as var_array. If this is height and @@ -51,7 +48,7 @@ def get_level_masks(cls, lev_array, level): da.arange(lev_array.shape[-1]), lev_array.shape ) mask1 = lev_indices == argmin1 - lev_diff = da.where(mask1, np.inf, lev_diff) + lev_diff = da.abs(da.ma.masked_array(lev_array, mask1) - level) argmin2 = da.argmin(lev_diff, axis=-1, keepdims=True) mask2 = lev_indices == argmin2 return mask1, mask2 @@ -61,7 +58,7 @@ def _lin_interp(cls, lev_samps, var_samps, level): """Linearly interpolate between levels.""" diff = da.map_blocks(lambda x, y: x - y, lev_samps[1], lev_samps[0]) alpha = da.where( - diff < 1e-3, + diff == 0, 0, da.map_blocks(lambda x, y: x / y, (level - lev_samps[0]), diff), ) @@ -109,9 +106,6 @@ def interp_to_level( Parameters ---------- - var_array : xr.DataArray - Array of variable data, for example u-wind in a 4D array of shape - (lat, lon, time, level) lev_array : xr.DataArray Height or pressure values for the corresponding entries in var_array, in the same shape as var_array. If this is height and @@ -119,6 +113,9 @@ def interp_to_level( should be the geopotential height corresponding to every var_array index relative to the surface elevation (subtract the elevation at the surface from the geopotential height) + var_array : xr.DataArray + Array of variable data, for example u-wind in a 4D array of shape + (lat, lon, time, level) level : float level or levels to interpolate to (e.g. final desired hub height above surface elevation) diff --git a/tests/data_handlers/test_dh_nc_cc.py b/tests/data_handlers/test_dh_nc_cc.py index 9dd868023..b952ab83c 100644 --- a/tests/data_handlers/test_dh_nc_cc.py +++ b/tests/data_handlers/test_dh_nc_cc.py @@ -69,6 +69,7 @@ def test_reload_cache(): target=target, shape=(20, 20), cache_kwargs=cache_kwargs, + interp_kwargs={'include_single_levels': True} ) # reload from cache @@ -80,7 +81,9 @@ def test_reload_cache(): cache_kwargs=cache_kwargs, ) assert all(f in cached for f in features) - assert np.array_equal(handler.as_array(), cached.as_array()) + harr = handler.as_array().compute() + carr = cached.as_array().compute() + assert np.array_equal(harr, carr) @pytest.mark.parametrize(