Skip to content

Commit

Permalink
Added load_features argument to DataHandler so this can be use…
Browse files Browse the repository at this point in the history
…d to control which features are used in derivations, instead of using ``interp_kwargs``. Removed ``include_single_levels`` parsing from ``interp_kwargs``
  • Loading branch information
bnb32 committed Jan 11, 2025
1 parent efcacfa commit c03ad14
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 21 deletions.
17 changes: 13 additions & 4 deletions sup3r/preprocessing/data_handlers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
self,
file_paths,
features='all',
load_features='all',
res_kwargs: Optional[dict] = None,
chunks: Union[str, Dict[str, int]] = 'auto',
target: Optional[tuple] = None,
Expand All @@ -69,9 +70,16 @@ def __init__(
file_paths : str | list | pathlib.Path
file_paths input to LoaderClass
features : list | str
Features to load and / or derive. If 'all' then all available raw
features will be loaded. Specify explicit feature names for
derivations.
Features to derive. If 'all' then all available raw features will
just be loaded. Specify explicit feature names for derivations.
load_features : list | str
Features to load and make available for derivations. If 'all' then
all available raw features will be loaded and made available for
derivations. This can be used to restrict features used for
derivations. For example, to derive 'temperature_100m' from only
temperature isobars, from data that includes single level values as
well (like temperature_2m), don't include 'temperature_2m' in the
``load_features`` list.
res_kwargs : dict
Additional keyword arguments passed through to the ``BaseLoader``.
BaseLoader is usually xr.open_mfdataset for NETCDF files and
Expand Down Expand Up @@ -146,12 +154,13 @@ def __init__(
)

just_coords = not features
raster_feats = 'all' if any(missing_features) else []
raster_feats = load_features if any(missing_features) else []
self.rasterizer = self.loader = self.cache = None

if any(cached_features):
self.cache = Loader(
file_paths=cached_files,
features=load_features,
res_kwargs=res_kwargs,
chunks=chunks,
BaseLoader=BaseLoader,
Expand Down
17 changes: 5 additions & 12 deletions sup3r/preprocessing/derivers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,10 @@ def __init__(
that is not found in the :class:`Rasterizer` data it will look for
a method to derive the feature in the registry.
interp_kwargs : dict | None
Dictionary of kwargs for level interpolation. Can include "method",
"run_level_check", and "include_single_levels" keys. Method
specifies how to perform height interpolation. e.g. Deriving
u_20m from u_10m and u_100m. Options are "linear" and "log".
``include_single_levels = True`` will include single level
variables in addition to pressure level variables in the
interpolation. e.g. the 3D array of ``temperature_2m`` along
with the 4D array of ``temperature``. ``See
Dictionary of kwargs for level interpolation. Can include "method"
and "run_level_check". "method" specifies how to perform height
interpolation. e.g. Deriving u_20m from u_10m and u_100m. Options
are "linear" and "log". See
:py:meth:`sup3r.preprocessing.derivers.Deriver.do_level_interpolation`
""" # pylint: disable=line-too-long
if FeatureRegistry is not None:
Expand Down Expand Up @@ -324,10 +320,7 @@ def do_level_interpolation(
) -> xr.DataArray:
"""Interpolate over height or pressure to derive the given feature."""
ml_var, ml_levs = self.get_multi_level_data(feature)
if interp_kwargs.get('include_single_levels', False):
sl_var, sl_levs = self.get_single_level_data(feature)
else:
sl_var, sl_levs = None, None
sl_var, sl_levs = self.get_single_level_data(feature)

fstruct = parse_feature(feature)
attrs = {}
Expand Down
3 changes: 1 addition & 2 deletions tests/data_handlers/test_dh_nc_cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ def test_reload_cache():
features=features,
target=target,
shape=(20, 20),
cache_kwargs=cache_kwargs,
interp_kwargs={'include_single_levels': True}
cache_kwargs=cache_kwargs
)

# reload from cache
Expand Down
40 changes: 37 additions & 3 deletions tests/derivers/test_height_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,41 @@ def test_plevel_height_interp_nc_with_cache():
)


def test_plevel_height_interp_with_filtered_load_features():
"""Test that filtering load features can be used to control the features
used in the derivations.

Derives ``u_20m`` three ways and compares the results:

- ``dh_filt``: all three files are given, but ``load_features`` omits
  ``u_10m``.
- ``dh_no_filt``: all three files are given with the default
  ``load_features``.
- ``dh``: only the orography and level files are given, so no ``u_10m``
  file is provided at all.

The filtered handler is expected to match ``dh`` and to differ from
``dh_no_filt``.
"""

with TemporaryDirectory() as td:
# Fake inputs: an orography file, a single-level file holding u_10m, and
# a multi-level file holding zg and u (note the extra trailing dimension
# of size 3 in its shape).
orog_file = os.path.join(td, 'orog.nc')
make_fake_nc_file(orog_file, shape=(10, 10, 20), features=['orog'])
sfc_file = os.path.join(td, 'u_10m.nc')
make_fake_nc_file(sfc_file, shape=(10, 10, 20), features=['u_10m'])
level_file = os.path.join(td, 'wind_levs.nc')
make_fake_nc_file(
level_file, shape=(10, 10, 20, 3), features=['zg', 'u']
)
derive_features = ['u_20m']
# u_10m is on disk but left out of load_features.
# NOTE(review): the file is written with feature 'orog' while
# load_features asks for 'topography' -- presumably the loader renames
# it; confirm against the loader.
dh_filt = DataHandler(
[orog_file, sfc_file, level_file],
features=derive_features,
load_features=['topography', 'zg', 'u'],
)
# Same files, default load_features.
dh_no_filt = DataHandler(
[orog_file, sfc_file, level_file],
features=derive_features,
)
# Reference case: the u_10m file is not passed in at all.
dh = DataHandler(
[orog_file, level_file],
features=derive_features,
)
# Filtering out u_10m must give the same u_20m as never providing it.
assert np.array_equal(
dh_filt.data['u_20m'].data, dh.data['u_20m'].data
)
# The unfiltered handler must give a different u_20m than the filtered one.
assert not np.array_equal(
dh_filt.data['u_20m'].data, dh_no_filt.data['u_20m'].data
)


def test_only_interp_method():
"""Test that interp method alone returns the right values"""
hgt = np.zeros((10, 10, 5, 3))
Expand Down Expand Up @@ -156,7 +191,7 @@ def test_single_levels_height_interp_nc(shape=(10, 10), target=(37.25, -107)):
transform = Deriver(
no_transform.data,
derive_features,
interp_kwargs={'method': 'linear', 'include_single_levels': True},
interp_kwargs={'method': 'linear'},
)

h10 = np.zeros(transform.shape[:3], dtype=np.float32)[..., None]
Expand Down Expand Up @@ -200,7 +235,6 @@ def test_plevel_height_interp_with_single_lev_data_nc(
transform = Deriver(
no_transform.data,
derive_features,
interp_kwargs={'include_single_levels': True},
)

hgt_array = (
Expand Down Expand Up @@ -241,7 +275,7 @@ def test_log_interp(shape=(10, 10), target=(37.25, -107)):
transform = Deriver(
no_transform.data,
derive_features,
interp_kwargs={'method': 'log', 'include_single_levels': True},
interp_kwargs={'method': 'log'},
)

hgt_array = (
Expand Down

0 comments on commit c03ad14

Please sign in to comment.