From 2a70bd602a8ab69d2fc1db36cef33499be3acdb9 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 22 Nov 2024 08:08:01 -0700 Subject: [PATCH 01/11] 1d_nc_loading --- sup3r/bias/bias_transforms.py | 5 ++- sup3r/preprocessing/cachers/base.py | 11 ++++--- sup3r/preprocessing/loaders/base.py | 31 ++++++++++++++++++- sup3r/preprocessing/loaders/h5.py | 26 ++++++++-------- sup3r/preprocessing/loaders/nc.py | 48 ++++++++++++++++++++++------- 5 files changed, 88 insertions(+), 33 deletions(-) diff --git a/sup3r/bias/bias_transforms.py b/sup3r/bias/bias_transforms.py index 6229cf010..29bc536e8 100644 --- a/sup3r/bias/bias_transforms.py +++ b/sup3r/bias/bias_transforms.py @@ -791,10 +791,9 @@ def local_qdm_bc( if out_range is not None: data_unbiased = np.maximum(data_unbiased, np.min(out_range)) data_unbiased = np.minimum(data_unbiased, np.max(out_range)) - - if da.isnan(data_unbiased).any(): + if not da.isfinite(data_unbiased).all(): msg = ( - 'QDM bias correction resulted in NaN values! If this is a ' + 'QDM bias correction resulted in NaN / inf values! If this is a ' 'relative QDM, you may try setting ``delta_denom_min`` or ' '``delta_denom_zero``' ) diff --git a/sup3r/preprocessing/cachers/base.py b/sup3r/preprocessing/cachers/base.py index 2a19fc3c7..be380b673 100644 --- a/sup3r/preprocessing/cachers/base.py +++ b/sup3r/preprocessing/cachers/base.py @@ -328,13 +328,16 @@ def write_h5( coord_names = [ crd for crd in data.coords if crd in Dimension.coords_4d() ] + + if Dimension.TIME in data: + data[Dimension.TIME] = data[Dimension.TIME].astype(int) + for dset in [*coord_names, *features]: data_var, chunksizes = cls.get_chunksizes(dset, data, chunks) + data_var = data_var.data - if dset == Dimension.TIME: - data_var = da.asarray(data_var.astype(int).data) - else: - data_var = data_var.data + if not isinstance(data_var, da.core.Array): + data_var = da.asarray(data_var) dset_name = dset if dset == Dimension.TIME: diff --git a/sup3r/preprocessing/loaders/base.py b/sup3r/preprocessing/loaders/base.py index bdad1ff9a..ee9994eea 100644 --- a/sup3r/preprocessing/loaders/base.py +++ b/sup3r/preprocessing/loaders/base.py @@ -5,12 +5,13 @@ import logging from abc import ABC, abstractmethod from datetime import datetime as dt +from functools import cached_property from typing import Callable import numpy as np from sup3r.preprocessing.base import Container -from sup3r.preprocessing.names import FEATURE_NAMES +from sup3r.preprocessing.names import FEATURE_NAMES, Dimension from sup3r.preprocessing.utilities import ( expand_paths, log_args, @@ -153,3 +154,31 @@ def _load(self): ------- xr.Dataset """ + + @cached_property + @abstractmethod + def _lat_lon_shape(self): + """Get shape of lat lon grid only.""" + + @cached_property + @abstractmethod + def _is_flattened(self): + """Check if data is flattened or not""" + + @cached_property + def _lat_lon_dims(self): + """Get dim names for lat lon grid. Either + ``Dimension.FLATTENED_SPATIAL`` or ``(Dimension.SOUTH_NORTH, + Dimension.WEST_EAST)``""" + if self._is_flattened: + return (Dimension.FLATTENED_SPATIAL,) + return Dimension.dims_2d() + + @property + def _time_independent(self): + return 'time_index' not in self.res and 'time' not in self.res + + def _is_spatial_dset(self, data): + """Check if given data is spatial only. 
We compare against the size of + the spatial domain.""" + return len(data.shape) == 1 and len(data) == self._lat_lon_shape[0] diff --git a/sup3r/preprocessing/loaders/h5.py b/sup3r/preprocessing/loaders/h5.py index bbfaa2ec8..cd00be965 100644 --- a/sup3r/preprocessing/loaders/h5.py +++ b/sup3r/preprocessing/loaders/h5.py @@ -6,6 +6,7 @@ """ import logging +from functools import cached_property from typing import Dict, Tuple from warnings import warn @@ -32,26 +33,23 @@ class LoaderH5(BaseLoader): BASE_LOADER = MultiFileWindX - @property - def _time_independent(self): - return 'time_index' not in self.res - @property def _time_steps(self): return ( len(self.res['time_index']) if not self._time_independent else None ) - def _meta_shape(self): + @cached_property + def _lat_lon_shape(self): """Get shape of spatial domain only.""" if 'latitude' in self.res.h5: return self.res.h5['latitude'].shape return self.res.h5['meta']['latitude'].shape - def _is_spatial_dset(self, data): - """Check if given data is spatial only. We compare against the size of - the meta.""" - return len(data.shape) == 1 and len(data) == self._meta_shape()[0] + @cached_property + def _is_flattened(self): + """Check if dims include a single spatial dimension.""" + return self._lat_lon_shape == 1 def _res_shape(self): """Get shape of H5 file. @@ -61,9 +59,9 @@ def _res_shape(self): Flattened files are 2D but we have 3D H5 files available through caching and bias correction factor calculations.""" return ( - self._meta_shape() + self._lat_lon_shape if self._time_independent - else (self._time_steps, *self._meta_shape()) + else (self._time_steps, *self._lat_lon_shape) ) def _get_coords(self, dims): @@ -74,7 +72,7 @@ def _get_coords(self, dims): coord_base = ( self.res.h5 if 'latitude' in self.res.h5 else self.res.h5['meta'] ) - coord_dims = dims[-len(self._meta_shape()) :] + coord_dims = dims[-len(self._lat_lon_shape) :] chunks = self._parse_chunks(coord_dims) lats = da.asarray( coord_base['latitude'], dtype=np.float32, chunks=chunks @@ -152,7 +150,7 @@ def _check_for_elevation(self, data_vars, dims, chunks): elevation to data_vars if it is.""" flattened_with_elevation = ( - len(self._meta_shape()) == 1 + len(self._lat_lon_shape) == 1 and hasattr(self.res, 'meta') and 'elevation' in self.res.meta ) @@ -191,7 +189,7 @@ def _get_data_vars(self, dims): def _get_dims(self): """Get tuple of named dims for dataset.""" - if len(self._meta_shape()) == 2: + if len(self._lat_lon_shape) == 2: dims = Dimension.dims_2d() else: dims = (Dimension.FLATTENED_SPATIAL,) diff --git a/sup3r/preprocessing/loaders/nc.py b/sup3r/preprocessing/loaders/nc.py index 205e47d02..8711d8c05 100644 --- a/sup3r/preprocessing/loaders/nc.py +++ b/sup3r/preprocessing/loaders/nc.py @@ -3,6 +3,7 @@ classes.""" import logging +from functools import cached_property from warnings import warn import dask.array as da @@ -31,7 +32,7 @@ def BASE_LOADER(self, file_paths, **kwargs): def _enforce_descending_lats(self, dset): """Make sure latitudes are in descending order so that min lat / lon is at ``lat_lon[-1, 0]``.""" - invert_lats = ( + invert_lats = not self._is_flattened(self.res) and ( dset[Dimension.LATITUDE][-1, 0] > dset[Dimension.LATITUDE][0, 0] ) if invert_lats: @@ -63,8 +64,31 @@ def _enforce_descending_levels(self, dset): dset.update({Dimension.PRESSURE_LEVEL: new_press}) return dset - @staticmethod - def get_coords(res): + @cached_property + def _lat_lon_shape(self): + """Get shape of lat lon grid only.""" + space_key = ( + Dimension.LATITUDE + if 
Dimension.LATITUDE in self.res + else Dimension.SOUTH_NORTH + ) + return self.res[space_key].shape + + @cached_property + def _is_flattened(self): + """Check if dims include a single spatial dimension.""" + crd_names = ( + Dimension.coords_2d() + if Dimension.LATITUDE in self.res + else Dimension.dims_2d() + ) + check = ( + self._lat_lon_shape == 1 + and self.res[crd_names[0]].dims == self.res[crd_names[1]].dims + ) + return check + + def get_coords(self, res): """Get coordinate dictionary to use in ``xr.Dataset().assign_coords()``.""" lats = res[Dimension.LATITUDE].data.astype(np.float32) @@ -76,12 +100,14 @@ def get_coords(res): if lons.ndim == 3: lons = lons.squeeze() - if len(lats.shape) == 1: + if len(lats.shape) == 1 and not self._is_flattened: lons, lats = da.meshgrid(lons, lats) - lats = ((Dimension.SOUTH_NORTH, Dimension.WEST_EAST), lats) - lons = ((Dimension.SOUTH_NORTH, Dimension.WEST_EAST), lons) - coords = {Dimension.LATITUDE: lats, Dimension.LONGITUDE: lons} + dim_names = self._lat_lon_dims + coords = { + Dimension.LATITUDE: (dim_names, lats), + Dimension.LONGITUDE: (dim_names, lons), + } if Dimension.TIME in res: if Dimension.TIME in res.indexes: @@ -95,16 +121,16 @@ def get_coords(res): coords[Dimension.TIME] = times return coords - @staticmethod - def get_dims(res): + def get_dims(self, res): """Get dimension name map using our standard mappping and the names used for coordinate dimensions.""" rename_dims = {k: v for k, v in DIM_NAMES.items() if k in res.dims} lat_dims = res[Dimension.LATITUDE].dims lon_dims = res[Dimension.LONGITUDE].dims if len(lat_dims) == 1 and len(lon_dims) == 1: - rename_dims[lat_dims[0]] = Dimension.SOUTH_NORTH - rename_dims[lon_dims[0]] = Dimension.WEST_EAST + dim_names = self._lat_lon_dims + rename_dims[lat_dims[0]] = dim_names[0] + rename_dims[lon_dims[0]] = dim_names[-1] else: msg = ( 'Latitude and Longitude dimension names are different. ' From 0d7e512f83bedda5f53f86dad9f6390ea8cbeb9b Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 24 Nov 2024 07:33:37 -0700 Subject: [PATCH 02/11] test for 1d netcdf loading. 
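
A "flattened" netcdf file stores latitude and longitude as 1D arrays that
share a single spatial dimension instead of spanning a regular 2D grid. A
minimal sketch of such a file, mirroring the new test added here (the
``space_dummy`` dimension name is illustrative only):

    import numpy as np
    import xarray as xr

    # 100 sites that share one spatial dimension, plus 5 time steps
    coords = {
        'time': np.arange(5),
        'latitude': ('space_dummy', np.random.uniform(30, 50, 100)),
        'longitude': ('space_dummy', np.random.uniform(-110, -90, 100)),
    }
    data_vars = {'u_100m': (('time', 'space_dummy'), np.zeros((5, 100)))}
    xr.Dataset(coords=coords, data_vars=data_vars).to_netcdf('flat.nc')

Because latitude and longitude are 1D and share their dimension, the
loader keeps a single flattened spatial dimension instead of meshgridding
the coordinates into a 2D raster.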
--- sup3r/preprocessing/loaders/__init__.py | 2 +- sup3r/preprocessing/loaders/base.py | 14 +++---- sup3r/preprocessing/loaders/h5.py | 44 +++++++++++---------- sup3r/preprocessing/loaders/nc.py | 33 ++++++---------- sup3r/preprocessing/rasterizers/extended.py | 2 +- tests/loaders/test_file_loading.py | 35 +++++++++++++++- 6 files changed, 78 insertions(+), 52 deletions(-) diff --git a/sup3r/preprocessing/loaders/__init__.py b/sup3r/preprocessing/loaders/__init__.py index 4208399e2..1ffdf7b59 100644 --- a/sup3r/preprocessing/loaders/__init__.py +++ b/sup3r/preprocessing/loaders/__init__.py @@ -28,4 +28,4 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, trace): - self.res.close() + self._res.close() diff --git a/sup3r/preprocessing/loaders/base.py b/sup3r/preprocessing/loaders/base.py index ee9994eea..187f78d5c 100644 --- a/sup3r/preprocessing/loaders/base.py +++ b/sup3r/preprocessing/loaders/base.py @@ -75,7 +75,7 @@ def __init__( self.file_paths = file_paths self.chunks = chunks BASE_LOADER = BaseLoader or self.BASE_LOADER - self.res = BASE_LOADER(self.file_paths, **self.res_kwargs) + self._res = BASE_LOADER(self.file_paths, **self.res_kwargs) data = lower_names(self._load()) data = self._add_attrs(data) data = standardize_values(data) @@ -84,8 +84,8 @@ def __init__( features = list(data.dims) if features == [] else features self.data = data[features] if features != 'all' else data - if 'meta' in self.res: - self.data.meta = self.res.meta + if 'meta' in self._res: + self.data.meta = self._res.meta if self.chunks is None: logger.info(f'Pre-loading data into memory for: {features}') @@ -107,8 +107,8 @@ def _parse_chunks(self, dims, feature=None): def _add_attrs(self, data): """Add meta data to dataset.""" attrs = {'source_files': self.file_paths} - attrs['global_attrs'] = getattr(self.res, 'global_attrs', []) - attrs.update(getattr(self.res, 'attrs', {})) + attrs['global_attrs'] = getattr(self._res, 'global_attrs', []) + attrs.update(getattr(self._res, 'attrs', {})) attrs['date_modified'] = attrs.get( 'date_modified', dt.utcnow().isoformat() ) @@ -119,7 +119,7 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, trace): - self.res.close() + self._res.close() @property def file_paths(self): @@ -176,7 +176,7 @@ def _lat_lon_dims(self): @property def _time_independent(self): - return 'time_index' not in self.res and 'time' not in self.res + return 'time_index' not in self._res and 'time' not in self._res def _is_spatial_dset(self, data): """Check if given data is spatial only. 
We compare against the size of diff --git a/sup3r/preprocessing/loaders/h5.py b/sup3r/preprocessing/loaders/h5.py index cd00be965..560a93682 100644 --- a/sup3r/preprocessing/loaders/h5.py +++ b/sup3r/preprocessing/loaders/h5.py @@ -36,20 +36,22 @@ class LoaderH5(BaseLoader): @property def _time_steps(self): return ( - len(self.res['time_index']) if not self._time_independent else None + len(self._res['time_index']) + if not self._time_independent + else None ) @cached_property def _lat_lon_shape(self): """Get shape of spatial domain only.""" - if 'latitude' in self.res.h5: - return self.res.h5['latitude'].shape - return self.res.h5['meta']['latitude'].shape + if 'latitude' in self._res.h5: + return self._res.h5['latitude'].shape + return self._res.h5['meta']['latitude'].shape @cached_property def _is_flattened(self): """Check if dims include a single spatial dimension.""" - return self._lat_lon_shape == 1 + return len(self._lat_lon_shape) == 1 def _res_shape(self): """Get shape of H5 file. @@ -68,9 +70,11 @@ def _get_coords(self, dims): """Get coords dict for xr.Dataset construction.""" coords: Dict[str, Tuple] = {} if not self._time_independent: - coords[Dimension.TIME] = self.res['time_index'] + coords[Dimension.TIME] = self._res['time_index'] coord_base = ( - self.res.h5 if 'latitude' in self.res.h5 else self.res.h5['meta'] + self._res.h5 + if 'latitude' in self._res.h5 + else self._res.h5['meta'] ) coord_dims = dims[-len(self._lat_lon_shape) :] chunks = self._parse_chunks(coord_dims) @@ -92,14 +96,14 @@ def _get_dset_tuple(self, dset, dims, chunks): spatiotemporal, 3D spatiotemporal, 4D spatiotemporal (with presssure levels), etc """ - # if self.res includes time-dependent and time-independent variables + # if self._res includes time-dependent and time-independent variables # and chunks is 3-tuple we only use the spatial chunk for # time-indepdent variables dset_chunks = chunks - if len(chunks) == 3 and len(self.res.h5[dset].shape) == 2: - dset_chunks = chunks[-len(self.res.h5[dset].shape)] + if len(chunks) == 3 and len(self._res.h5[dset].shape) == 2: + dset_chunks = chunks[-len(self._res.h5[dset].shape)] arr = da.asarray( - self.res.h5[dset], dtype=np.float32, chunks=dset_chunks + self._res.h5[dset], dtype=np.float32, chunks=dset_chunks ) arr /= self.scale_factor(dset) if len(arr.shape) == 4: @@ -128,7 +132,7 @@ def _get_dset_tuple(self, dset, dims, chunks): arr_dims = dims[-len(arr.shape) :] elif len(arr.shape) == 1: msg = ( - f'Received 1D feature "{dset}" with shape that does not ' + f'Received 1D feature "{dset}" with shape that does not equal ' 'the length of the meta nor the time_index.' 
) is_ts = not self._time_independent and len(arr) == self._time_steps @@ -136,7 +140,7 @@ def _get_dset_tuple(self, dset, dims, chunks): arr_dims = (Dimension.TIME,) else: arr_dims = dims[: len(arr.shape)] - return (arr_dims, arr, dict(self.res.h5[dset].attrs)) + return (arr_dims, arr, dict(self._res.h5[dset].attrs)) def _parse_chunks(self, dims, feature=None): """Get chunks for given dimensions from ``self.chunks``.""" @@ -151,14 +155,14 @@ def _check_for_elevation(self, data_vars, dims, chunks): flattened_with_elevation = ( len(self._lat_lon_shape) == 1 - and hasattr(self.res, 'meta') - and 'elevation' in self.res.meta + and hasattr(self._res, 'meta') + and 'elevation' in self._res.meta ) if flattened_with_elevation: - elev = self.res.meta['elevation'].values.astype(np.float32) + elev = self._res.meta['elevation'].values.astype(np.float32) elev = da.asarray(elev) if not self._time_independent: - t_steps = len(self.res['time_index']) + t_steps = len(self._res['time_index']) elev = da.repeat(elev[None, ...], t_steps, axis=0) elev = elev.rechunk(chunks) data_vars['elevation'] = (dims, elev) @@ -173,7 +177,7 @@ def _get_data_vars(self, dims): data_vars, dims=dims, chunks=chunks ) - feats = set(self.res.h5.datasets) + feats = set(self._res.h5.datasets) exclude = { 'meta', 'time_index', @@ -214,8 +218,8 @@ def _load(self) -> xr.Dataset: def scale_factor(self, feature): """Get scale factor for given feature. Data is stored in scaled form to reduce memory.""" - feat = feature if feature in self.res.datasets else feature.lower() - feat = self.res.h5[feat] + feat = feature if feature in self._res.datasets else feature.lower() + feat = self._res.h5[feat] return np.float32( 1.0 if not hasattr(feat, 'attrs') diff --git a/sup3r/preprocessing/loaders/nc.py b/sup3r/preprocessing/loaders/nc.py index 8711d8c05..2336eb9db 100644 --- a/sup3r/preprocessing/loaders/nc.py +++ b/sup3r/preprocessing/loaders/nc.py @@ -32,7 +32,7 @@ def BASE_LOADER(self, file_paths, **kwargs): def _enforce_descending_lats(self, dset): """Make sure latitudes are in descending order so that min lat / lon is at ``lat_lon[-1, 0]``.""" - invert_lats = not self._is_flattened(self.res) and ( + invert_lats = not self._is_flattened and ( dset[Dimension.LATITUDE][-1, 0] > dset[Dimension.LATITUDE][0, 0] ) if invert_lats: @@ -67,28 +67,19 @@ def _enforce_descending_levels(self, dset): @cached_property def _lat_lon_shape(self): """Get shape of lat lon grid only.""" - space_key = ( - Dimension.LATITUDE - if Dimension.LATITUDE in self.res - else Dimension.SOUTH_NORTH - ) - return self.res[space_key].shape + return self._res[Dimension.LATITUDE].shape @cached_property def _is_flattened(self): """Check if dims include a single spatial dimension.""" - crd_names = ( - Dimension.coords_2d() - if Dimension.LATITUDE in self.res - else Dimension.dims_2d() - ) check = ( - self._lat_lon_shape == 1 - and self.res[crd_names[0]].dims == self.res[crd_names[1]].dims + len(self._lat_lon_shape) == 1 + and self._res[Dimension.LATITUDE].dims + == self._res[Dimension.LONGITUDE].dims ) return check - def get_coords(self, res): + def _get_coords(self, res): """Get coordinate dictionary to use in ``xr.Dataset().assign_coords()``.""" lats = res[Dimension.LATITUDE].data.astype(np.float32) @@ -121,7 +112,7 @@ def get_coords(self, res): coords[Dimension.TIME] = times return coords - def get_dims(self, res): + def _get_dims(self, res): """Get dimension name map using our standard mappping and the names used for coordinate dimensions.""" rename_dims = {k: v for k, v in 
DIM_NAMES.items() if k in res.dims} @@ -159,19 +150,19 @@ def _rechunk_dsets(self, res): def _load(self): """Load netcdf ``xarray.Dataset()``.""" - res = lower_names(self.res) + res = lower_names(self._res) rename_coords = { k: v for k, v in COORD_NAMES.items() if k in res and v not in res } - res = res.rename(rename_coords) + self._res = res.rename(rename_coords) - if not all(coord in res for coord in Dimension.coords_2d()): + if not all(coord in self._res for coord in Dimension.coords_2d()): err = 'Could not find valid coordinates in given files: %s' logger.error(err, self.file_paths) raise OSError(err % (self.file_paths)) - res = res.swap_dims(self.get_dims(res)) - res = res.assign_coords(self.get_coords(res)) + res = self._res.swap_dims(self._get_dims(self._res)) + res = res.assign_coords(self._get_coords(res)) res = self._enforce_descending_lats(res) res = self._rechunk_dsets(res) return self._enforce_descending_levels(res).astype(np.float32) diff --git a/sup3r/preprocessing/rasterizers/extended.py b/sup3r/preprocessing/rasterizers/extended.py index d62874c0f..8aca23b88 100644 --- a/sup3r/preprocessing/rasterizers/extended.py +++ b/sup3r/preprocessing/rasterizers/extended.py @@ -175,7 +175,7 @@ def _get_flat_data_raster_index(self): assert ( self._target is not None and self._grid_shape is not None ), msg - raster_index = self.loader.res.get_raster_index( + raster_index = self.loader._res.get_raster_index( self._target, self._grid_shape, max_delta=self.max_delta ) else: diff --git a/tests/loaders/test_file_loading.py b/tests/loaders/test_file_loading.py index 72d16b7d9..59bdbf340 100644 --- a/tests/loaders/test_file_loading.py +++ b/tests/loaders/test_file_loading.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd import pytest +import xarray as xr from rex import Resource from sup3r.preprocessing import Dimension, Loader, LoaderH5, LoaderNC @@ -51,7 +52,7 @@ def test_dim_ordering(): Dimension.TIME, Dimension.PRESSURE_LEVEL, 'nbnd', - Dimension.VARIABLE + Dimension.VARIABLE, ) @@ -178,6 +179,36 @@ def test_load_era5(fp): ) +def test_load_flattened_nc(): + """Test simple netcdf file loading when nc data is spatially flattened.""" + with TemporaryDirectory() as td: + temp_file = os.path.join(td, 'test.nc') + coords = { + 'time': np.array(range(5)), + 'latitude': ('space_dummy', np.array(range(100))), + 'longitude': ('space_dummy', np.array(range(100))), + } + data_vars = { + 'u_100m': (('time', 'space_dummy'), np.zeros((5, 100))), + 'v_100m': (('time', 'space_dummy'), np.zeros((5, 100))), + } + nc = xr.Dataset(coords=coords, data_vars=data_vars) + nc.to_netcdf(temp_file) + chunks = {'time': 5, 'space': 5} + loader = LoaderNC(temp_file, chunks=chunks) + assert loader.shape == (100, 5, 2) + assert 'space' in loader['latitude'].dims + assert 'space' in loader['longitude'].dims + assert all( + loader[f].data.chunksize == tuple(chunks.values()) + for f in loader.features + ) + + gen_loader = Loader(temp_file, chunks=chunks) + + assert np.array_equal(loader.as_array(), gen_loader.as_array()) + + def test_load_nc(): """Test simple netcdf file loading. 
Make sure general loader matches nc specific loader""" @@ -224,7 +255,7 @@ def test_load_h5(): assert np.array_equal(loader.as_array(), gen_loader.as_array()) loader_attrs = {f: loader[f].attrs for f in feats} resource_attrs = Resource(pytest.FP_WTK).attrs - assert np.array_equal(loader.meta, loader.res.meta) + assert np.array_equal(loader.meta, loader._res.meta) matching_feats = set(Resource(pytest.FP_WTK).datasets).intersection(feats) assert all(loader_attrs[f] == resource_attrs[f] for f in matching_feats) From 8db2f34d7ce772b61c32a41962d6b8cf6ee47b36 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 24 Nov 2024 09:21:31 -0700 Subject: [PATCH 03/11] last padded slice should be at minimum equal in shape to the forward pass chunk shape. For some grid sizes and forward pass chunk shapes not forcing this minimum shape can result in a chunk that is too small for the generator. Fixed slice test, which was not checking all slices. --- sup3r/pipeline/slicer.py | 35 ++++++++++++++++--------- tests/forward_pass/test_forward_pass.py | 33 ++++++++++++++++++----- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/sup3r/pipeline/slicer.py b/sup3r/pipeline/slicer.py index a502c546e..08f03ea20 100644 --- a/sup3r/pipeline/slicer.py +++ b/sup3r/pipeline/slicer.py @@ -175,11 +175,12 @@ def t_lr_pad_slices(self): """ if self._t_lr_pad_slices is None: self._t_lr_pad_slices = self.get_padded_slices( - self.t_lr_slices, - self.time_steps, - 1, - self.temporal_pad, - self.time_slice.step, + slices=self.t_lr_slices, + shape=self.time_steps, + enhancement=1, + padding=self.temporal_pad, + step=self.time_slice.step, + min_size=self.chunk_shape[-1], ) return self._t_lr_pad_slices @@ -344,10 +345,11 @@ def s1_lr_pad_slices(self): spatial dimension""" if self._s1_lr_pad_slices is None: self._s1_lr_pad_slices = self.get_padded_slices( - self.s1_lr_slices, - self.coarse_shape[0], - 1, + slices=self.s1_lr_slices, + shape=self.coarse_shape[0], + enhancement=1, padding=self.spatial_pad, + min_size=self.chunk_shape[0], ) return self._s1_lr_pad_slices @@ -357,10 +359,11 @@ def s2_lr_pad_slices(self): spatial dimension""" if self._s2_lr_pad_slices is None: self._s2_lr_pad_slices = self.get_padded_slices( - self.s2_lr_slices, - self.coarse_shape[1], - 1, + slices=self.s2_lr_slices, + shape=self.coarse_shape[1], + enhancement=1, padding=self.spatial_pad, + min_size=self.chunk_shape[1], ) return self._s2_lr_pad_slices @@ -461,7 +464,9 @@ def n_chunks(self): return self.n_spatial_chunks * self.n_time_chunks @staticmethod - def get_padded_slices(slices, shape, enhancement, padding, step=None): + def get_padded_slices( + slices, shape, enhancement, padding, min_size, step=None + ): """Get padded slices with the specified padding size, max shape, enhancement, and step size @@ -481,6 +486,10 @@ def get_padded_slices(slices, shape, enhancement, padding, step=None): dimension and the spatial_pad is 10 this is 10. It will be multiplied by the enhancement factor if the slices are to be used to index an enhanced dimension. + min_size : int + Minimum size of a slice. This is usually the forward pass chunk + shape. A padded slice (the size of data passed to the generator) + should not be smaller than the forward pass chunk shape step : int | None Step size for slices. e.g. If these slices are indexing a temporal dimension and time_slice.step = 3 then step=3. 
@@ -496,6 +505,8 @@ def get_padded_slices(slices, shape, enhancement, padding, step=None): for _, s in enumerate(slices): start = np.max([0, s.start * enhancement - pad]) end = np.min([enhancement * shape, s.stop * enhancement + pad]) + if end - start < min_size: + start = end - min_size pad_slices.append(slice(start, end, step)) return pad_slices diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 61171b80a..a76e3313b 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -1,4 +1,5 @@ """pytests for forward pass module""" + import json import os import tempfile @@ -681,18 +682,38 @@ def test_slicing_no_pad(input_files): ) fwp = ForwardPass(strategy) - for i in range(len(strategy.node_chunks)): + for i in strategy.node_chunks[0]: chunk = fwp.get_input_chunk(i) s_idx, t_idx = strategy.get_chunk_indices(i) - s_slices = strategy.lr_pad_slices[s_idx] + s_slices = strategy.lr_slices[s_idx] + s_pad_slices = strategy.lr_pad_slices[s_idx] + s_crop_slices = strategy.fwp_slicer.s_lr_crop_slices[s_idx] + t_crop_slice = strategy.fwp_slicer.t_lr_crop_slices[t_idx] + lr_pad_data_slice = ( + s_pad_slices[0], + s_pad_slices[1], + fwp.strategy.ti_pad_slices[t_idx], + ) + lr_crop_data_slice = ( + s_crop_slices[0], + s_crop_slices[1], + t_crop_slice, + ) lr_data_slice = ( s_slices[0], s_slices[1], - fwp.strategy.ti_pad_slices[t_idx], + fwp.strategy.ti_slices[t_idx], ) - truth = handler.data[lr_data_slice] - assert np.allclose(chunk.input_data, truth) + assert handler.data[lr_pad_data_slice].shape[:-1] == (3, 2, 4) + assert chunk.input_data.shape[:-1] == (3, 2, 4) + assert np.allclose( + chunk.input_data, handler.data[lr_pad_data_slice] + ) + assert np.allclose( + chunk.input_data[lr_crop_data_slice], + handler.data[lr_data_slice], + ) def test_slicing_pad(input_files): @@ -752,7 +773,7 @@ def test_slicing_pad(input_files): assert chunk_lookup[0, 1, 1] == n_s1 * n_s2 + 1 fwp = ForwardPass(strategy) - for i in range(len(strategy.node_chunks)): + for i in strategy.node_chunks[0]: chunk = fwp.get_input_chunk(i, mode='constant') s_idx, t_idx = strategy.get_chunk_indices(i) s_slices = strategy.lr_pad_slices[s_idx] From fe91a348202630407c587be294c5e03e57291fa7 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 24 Nov 2024 09:44:25 -0700 Subject: [PATCH 04/11] min_size should not be larger than grid size --- sup3r/pipeline/slicer.py | 16 +++++++--------- tests/forward_pass/test_forward_pass.py | 2 +- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/sup3r/pipeline/slicer.py b/sup3r/pipeline/slicer.py index 08f03ea20..d6eb369c7 100644 --- a/sup3r/pipeline/slicer.py +++ b/sup3r/pipeline/slicer.py @@ -179,8 +179,7 @@ def t_lr_pad_slices(self): shape=self.time_steps, enhancement=1, padding=self.temporal_pad, - step=self.time_slice.step, - min_size=self.chunk_shape[-1], + step=self.time_slice.step ) return self._t_lr_pad_slices @@ -465,7 +464,7 @@ def n_chunks(self): @staticmethod def get_padded_slices( - slices, shape, enhancement, padding, min_size, step=None + slices, shape, enhancement, padding, min_size=None, step=None ): """Get padded slices with the specified padding size, max shape, enhancement, and step size @@ -487,9 +486,8 @@ def get_padded_slices( multiplied by the enhancement factor if the slices are to be used to index an enhanced dimension. min_size : int - Minimum size of a slice. This is usually the forward pass chunk - shape. 
A padded slice (the size of data passed to the generator)
-            should not be smaller than the forward pass chunk shape
+            Minimum size of a slice. This is usually at least 4. Padding layers
+            in the generator model typically require a minimum shape of 4.
         step : int | None
             Step size for slices. e.g. If these slices are indexing a temporal
             dimension and time_slice.step = 3 then step=3.
@@ -503,10 +501,10 @@ def get_padded_slices(
         pad = step * padding * enhancement
         pad_slices = []
         for _, s in enumerate(slices):
-            start = np.max([0, s.start * enhancement - pad])
             end = np.min([enhancement * shape, s.stop * enhancement + pad])
-            if end - start < min_size:
-                start = end - min_size
+            start = np.max([0, s.start * enhancement - pad])
+            if min_size is not None and end - start < min_size:
+                start = np.max([0, end - min_size])
             pad_slices.append(slice(start, end, step))
         return pad_slices
 
diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py
index a76e3313b..6c0b3f003 100644
--- a/tests/forward_pass/test_forward_pass.py
+++ b/tests/forward_pass/test_forward_pass.py
@@ -750,7 +750,7 @@ def test_slicing_pad(input_files):
         input_files,
         model_kwargs={'model_dir': st_out_dir},
         model_class='Sup3rGan',
-        fwp_chunk_shape=(4, 1, 4),
+        fwp_chunk_shape=(2, 1, 4),
         input_handler_kwargs=input_handler_kwargs,
         spatial_pad=2,
         temporal_pad=2,

From 2a4f3b03fa950e596df0deac7dabe87d660cfc23 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Sun, 24 Nov 2024 14:46:25 -0700
Subject: [PATCH 05/11] some slicer clean up

---
 sup3r/pipeline/slicer.py                | 61 ++++++++++++-----------
 tests/forward_pass/test_forward_pass.py |  6 +--
 tests/pipeline/test_pipeline.py         |  2 +-
 3 files changed, 38 insertions(+), 31 deletions(-)

diff --git a/sup3r/pipeline/slicer.py b/sup3r/pipeline/slicer.py
index d6eb369c7..aa3842732 100644
--- a/sup3r/pipeline/slicer.py
+++ b/sup3r/pipeline/slicer.py
@@ -1,5 +1,6 @@
 """Slicer class for chunking forward pass input"""
 
+import itertools as it
 import logging
 from dataclasses import dataclass
 from typing import Union
@@ -133,11 +134,9 @@ def s_lr_slices(self):
         going through the generator
         """
         if self._s_lr_slices is None:
-            self._s_lr_slices = [
-                (s1, s2)
-                for s1 in self.s1_lr_slices
-                for s2 in self.s2_lr_slices
-            ]
+            self._s_lr_slices = list(
+                it.product(self.s1_lr_slices, self.s2_lr_slices)
+            )
         return self._s_lr_slices
 
     @property
@@ -154,11 +153,9 @@ def s_lr_pad_slices(self):
         padded data volume passed through the generator
         """
         if self._s_lr_pad_slices is None:
-            self._s_lr_pad_slices = [
-                (s1, s2)
-                for s1 in self.s1_lr_pad_slices
-                for s2 in self.s2_lr_pad_slices
-            ]
+            self._s_lr_pad_slices = list(
+                it.product(self.s1_lr_pad_slices, self.s2_lr_pad_slices)
+            )
         return self._s_lr_pad_slices
 
     @property
@@ -179,7 +176,7 @@ def t_lr_pad_slices(self):
                 shape=self.time_steps,
                 enhancement=1,
                 padding=self.temporal_pad,
-                step=self.time_slice.step
+                step=self.time_slice.step,
             )
         return self._t_lr_pad_slices
 
@@ -250,12 +247,9 @@ def s_hr_slices(self):
         domain corresponding to data_handler.data[lr_slice]
         """
         if self._s_hr_slices is None:
-            self._s_hr_slices = []
-            self._s_hr_slices = [
-                (s1, s2)
-                for s1 in self.s1_hr_slices
-                for s2 in self.s2_hr_slices
-            ]
+            self._s_hr_slices = list(
+                it.product(self.s1_hr_slices, self.s2_hr_slices)
+            )
         return self._s_hr_slices
 
     @property
@@ -276,9 +270,9 @@ def s_lr_crop_slices(self):
         s2_crop_slices = self.get_cropped_slices(
             self.s2_lr_slices, self.s2_lr_pad_slices, 1
         )
-        self._s_lr_crop_slices = [
-            (s1, s2) for s1 in s1_crop_slices for s2 in 
s2_crop_slices - ] + self._s_lr_crop_slices = list( + it.product(s1_crop_slices, s2_crop_slices) + ) return self._s_lr_crop_slices @property @@ -308,11 +302,24 @@ def s_hr_crop_slices(self): for _ in range(len(self.s2_lr_slices)) ] - self._s_hr_crop_slices = [ - (s1, s2) - for s1 in s1_hr_crop_slices - for s2 in s2_hr_crop_slices - ] + if self.spatial_pad == 0: + s1_end_slice = self.get_cropped_slices( + self.s1_lr_slices[-1:], + self.s1_lr_pad_slices[-1:], + self.s_enhance, + ) + s2_end_slice = self.get_cropped_slices( + self.s2_lr_slices[-1:], + self.s2_lr_pad_slices[-1:], + self.s_enhance, + ) + + s1_hr_crop_slices[-1] = slice(s1_end_slice[0].start, None) + s2_hr_crop_slices[-1] = slice(s2_end_slice[0].start, None) + + self._s_hr_crop_slices = list( + it.product(s1_hr_crop_slices, s2_hr_crop_slices) + ) return self._s_hr_crop_slices @property @@ -348,7 +355,7 @@ def s1_lr_pad_slices(self): shape=self.coarse_shape[0], enhancement=1, padding=self.spatial_pad, - min_size=self.chunk_shape[0], + min_size=self.chunk_shape[0] ) return self._s1_lr_pad_slices @@ -503,7 +510,7 @@ def get_padded_slices( for _, s in enumerate(slices): end = np.min([enhancement * shape, s.stop * enhancement + pad]) start = np.max([0, s.start * enhancement - pad]) - if min_size is not None and end - start < min_size: + if min_size is not None and end - start < min_size and pad == 0: start = np.max([0, end - min_size]) pad_slices.append(slice(start, end, step)) return pad_slices diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 6c0b3f003..9c64d048d 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -705,8 +705,8 @@ def test_slicing_no_pad(input_files): fwp.strategy.ti_slices[t_idx], ) - assert handler.data[lr_pad_data_slice].shape[:-1] == (3, 2, 4) - assert chunk.input_data.shape[:-1] == (3, 2, 4) + assert handler.data[lr_pad_data_slice].shape[:-2] == (3, 2) + assert chunk.input_data.shape[:-2] == (3, 2) assert np.allclose( chunk.input_data, handler.data[lr_pad_data_slice] ) @@ -750,7 +750,7 @@ def test_slicing_pad(input_files): input_files, model_kwargs={'model_dir': st_out_dir}, model_class='Sup3rGan', - fwp_chunk_shape=(4, 1, 4), + fwp_chunk_shape=(2, 1, 4), input_handler_kwargs=input_handler_kwargs, spatial_pad=2, temporal_pad=2, diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index d3f58e40d..d05225bb2 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -29,7 +29,7 @@ def input_files(tmpdir_factory): """Dummy netcdf input files for :class:`ForwardPass`""" input_file = str(tmpdir_factory.mktemp('data').join('fwp_input.nc')) - make_fake_nc_file(input_file, shape=(100, 100, 80), features=FEATURES) + make_fake_nc_file(input_file, shape=(109, 261, 80), features=FEATURES) return input_file From 3284f97392bad211a29a7d3c9d535f36e3dafbc0 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 25 Nov 2024 07:12:30 -0700 Subject: [PATCH 06/11] fixes for edge case where spatial slices are too small for the generator padding layer requirements. 
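
The failure mode is easiest to see with plain numbers. A grid dimension of
15 cells chunked with a forward pass chunk shape of 7 produces the slices
[0, 7), [7, 14), and [14, 15), so the last chunk is a single cell, which is
smaller than the minimum of ~4 elements required by the padding layers in
the generator. A rough sketch of the arithmetic (illustrative only, not the
sup3r helper itself):

    def chunk_slices(grid_size, chunk_size):
        """Split range(grid_size) into consecutive chunks."""
        return [
            slice(i, min(i + chunk_size, grid_size))
            for i in range(0, grid_size, chunk_size)
        ]

    for s in chunk_slices(15, 7):
        print(s, 'length =', s.stop - s.start)
    # slice(0, 7) length = 7
    # slice(7, 14) length = 7
    # slice(14, 15) length = 1  <- too small for the generator

With spatial_pad=0 the boundary slice is now widened to meet the minimum
shape, and an error is raised if the combination of spatial padding and
chunk shape can never meet it.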
--- sup3r/pipeline/slicer.py | 176 +++++++++++++++++------- tests/forward_pass/test_forward_pass.py | 102 +++++++++++++- 2 files changed, 227 insertions(+), 51 deletions(-) diff --git a/sup3r/pipeline/slicer.py b/sup3r/pipeline/slicer.py index aa3842732..1cf323c2f 100644 --- a/sup3r/pipeline/slicer.py +++ b/sup3r/pipeline/slicer.py @@ -4,6 +4,7 @@ import logging from dataclasses import dataclass from typing import Union +from warnings import warn import numpy as np @@ -75,6 +76,8 @@ def __post_init__(self): self._s2_lr_slices = None self._s1_lr_pad_slices = None self._s2_lr_pad_slices = None + self._s1_hr_crop_slices = None + self._s2_hr_crop_slices = None self._s_lr_slices = None self._s_lr_pad_slices = None self._s_lr_crop_slices = None @@ -235,6 +238,32 @@ def s2_hr_slices(self): """Get high res spatial slices for second spatial dimension""" return self.get_hr_slices(self.s2_lr_slices, self.s_enhance) + @property + def s1_hr_crop_slices(self): + """Get high res cropped slices for first spatial dimension""" + + if self._s1_hr_crop_slices is None: + self._s1_hr_crop_slices = self.get_hr_cropped_slices( + unpadded_slices=self.s1_lr_slices, + padded_slices=self.s1_lr_pad_slices, + enhancement=self.s_enhance, + padding=self.spatial_pad, + ) + return self._s1_hr_crop_slices + + @property + def s2_hr_crop_slices(self): + """Get high res cropped slices for first spatial dimension""" + + if self._s2_hr_crop_slices is None: + self._s2_hr_crop_slices = self.get_hr_cropped_slices( + unpadded_slices=self.s2_lr_slices, + padded_slices=self.s2_lr_pad_slices, + enhancement=self.s_enhance, + padding=self.spatial_pad, + ) + return self._s2_hr_crop_slices + @property def s_hr_slices(self): """Get high res slices for indexing full generator output array @@ -285,40 +314,9 @@ def s_hr_crop_slices(self): List of high res cropped slices. Each entry in this list has a slice for each spatial dimension. """ - hr_crop_start = None - hr_crop_stop = None - if self.spatial_pad > 0: - hr_crop_start = self.s_enhance * self.spatial_pad - hr_crop_stop = -hr_crop_start - if self._s_hr_crop_slices is None: - self._s_hr_crop_slices = [] - s1_hr_crop_slices = [ - slice(hr_crop_start, hr_crop_stop) - for _ in range(len(self.s1_lr_slices)) - ] - s2_hr_crop_slices = [ - slice(hr_crop_start, hr_crop_stop) - for _ in range(len(self.s2_lr_slices)) - ] - - if self.spatial_pad == 0: - s1_end_slice = self.get_cropped_slices( - self.s1_lr_slices[-1:], - self.s1_lr_pad_slices[-1:], - self.s_enhance, - ) - s2_end_slice = self.get_cropped_slices( - self.s2_lr_slices[-1:], - self.s2_lr_pad_slices[-1:], - self.s_enhance, - ) - - s1_hr_crop_slices[-1] = slice(s1_end_slice[0].start, None) - s2_hr_crop_slices[-1] = slice(s2_end_slice[0].start, None) - self._s_hr_crop_slices = list( - it.product(s1_hr_crop_slices, s2_hr_crop_slices) + it.product(self.s1_hr_crop_slices, self.s2_hr_crop_slices) ) return self._s_hr_crop_slices @@ -345,6 +343,52 @@ def hr_crop_slices(self): self._hr_crop_slices.append(node_slices) return self._hr_crop_slices + def check_boundary_slice(self, slices, dim): + """Check boundary slice for minimum shape. + + When spatial padding is used data is always padded to have at least 2 * + spatial_pad + 1 elements. When spatial padding is not used it's + possible for the forward pass chunk shape to divide the grid size such + that the last slice does not meet the minimum number of elements. + (Padding layers in the generator typically require a minimum shape of + 4). 
So, when spatial padding is not used we add extra padding to
+        meet the minimum shape requirement; otherwise we raise an error if the
+        minimum shape is not met."""
+
+        end_slice = slices[-1]
+        err_msg = (
+            'The final spatial slice for dimension #%s is too small (%s). '
+            'Adjust the forward pass chunk shape (%s) and / or spatial '
+            'padding (%s) so that 2 * spatial_pad + '
+            'modulo(grid_shape, fwp_chunk_shape) > 3'
+        )
+        warn_msg = (
+            'The final spatial slice for dimension #%s is too small (%s). '
+            'The start of this slice will be reduced to try to meet the '
+            'minimum slice length.'
+        )
+
+        if end_slice.stop - end_slice.start < 4:
+            if self.spatial_pad == 0:
+                logger.warning(warn_msg, dim + 1, end_slice)
+                warn(warn_msg % (dim + 1, end_slice))
+                new_start = np.max([0, end_slice.stop - self.chunk_shape[dim]])
+                end_slice = slice(new_start, end_slice.stop, end_slice.step)
+                slices[-1] = end_slice
+            if 2 * self.spatial_pad + (end_slice.stop - end_slice.start) < 4:
+                logger.error(
+                    err_msg,
+                    dim + 1,
+                    end_slice,
+                    self.chunk_shape,
+                    self.spatial_pad,
+                )
+                raise ValueError(
+                    err_msg
+                    % (dim + 1, end_slice, self.chunk_shape, self.spatial_pad)
+                )
+        return slices
+
     @property
     def s1_lr_pad_slices(self):
         """List of low resolution spatial slices with padding for first
@@ -355,7 +399,9 @@ def s1_lr_pad_slices(self):
                 shape=self.coarse_shape[0],
                 enhancement=1,
                 padding=self.spatial_pad,
-                min_size=self.chunk_shape[0]
+            )
+            self._s1_lr_pad_slices = self.check_boundary_slice(
+                slices=self._s1_lr_pad_slices, dim=0
             )
         return self._s1_lr_pad_slices
 
     @property
@@ -369,7 +415,9 @@ def s2_lr_pad_slices(self):
                 shape=self.coarse_shape[1],
                 enhancement=1,
                 padding=self.spatial_pad,
-                min_size=self.chunk_shape[1],
+            )
+            self._s2_lr_pad_slices = self.check_boundary_slice(
+                slices=self._s2_lr_pad_slices, dim=1
             )
         return self._s2_lr_pad_slices
 
@@ -378,20 +426,18 @@ def s1_lr_slices(self):
         """List of low resolution spatial slices for first spatial dimension
         considering padding on all sides of the spatial raster."""
         ind = slice(0, self.coarse_shape[0])
-        slices = get_chunk_slices(
+        return get_chunk_slices(
            self.coarse_shape[0], self.chunk_shape[0], index_slice=ind
         )
-        return slices
 
     @property
     def s2_lr_slices(self):
         """List of low resolution spatial slices for second spatial dimension
         considering padding on all sides of the spatial raster."""
         ind = slice(0, self.coarse_shape[1])
-        slices = get_chunk_slices(
+        return get_chunk_slices(
            self.coarse_shape[1], self.chunk_shape[1], index_slice=ind
         )
-        return slices
 
     @property
     def t_lr_slices(self):
@@ -401,10 +447,9 @@ def t_lr_slices(self):
         n_chunks = int(np.ceil(n_chunks))
         ti_slices = self.dummy_time_index[self.time_slice]
         ti_slices = np.array_split(ti_slices, n_chunks)
-        ti_slices = [
+        return [
             slice(c[0], c[-1] + 1, self.time_slice.step) for c in ti_slices
         ]
-        return ti_slices
 
     @staticmethod
     def get_hr_slices(slices, enhancement, step=None):
@@ -470,12 +515,18 @@ def n_chunks(self):
         return self.n_spatial_chunks * self.n_time_chunks
 
     @staticmethod
-    def get_padded_slices(
-        slices, shape, enhancement, padding, min_size=None, step=None
-    ):
+    def get_padded_slices(slices, shape, enhancement, padding, step=None):
         """Get padded slices with the specified padding size, max shape,
         enhancement, and step size
 
+        Note
+        ----
+        It's possible to get a boundary slice that is too small for generator
+        input (padding layers typically need at least 4 elements) if the
+        forward pass chunk shape does not evenly divide the grid shape. 
We add + extra padding in the low res slices to account for this with + ``min_size`` argument. + Parameters ---------- slices : list @@ -492,9 +543,6 @@ def get_padded_slices( dimension and the spatial_pad is 10 this is 10. It will be multiplied by the enhancement factor if the slices are to be used to index an enhanced dimension. - min_size : int - Minimum size of a slice. This is usually at least 4. Padding layers - in the generator model typpically require a minimum shape of 4. step : int | None Step size for slices. e.g. If these slices are indexing a temporal dimension and time_slice.step = 3 then step=3. @@ -510,8 +558,6 @@ def get_padded_slices( for _, s in enumerate(slices): end = np.min([enhancement * shape, s.stop * enhancement + pad]) start = np.max([0, s.start * enhancement - pad]) - if min_size is not None and end - start < min_size and pad == 0: - start = np.max([0, end - min_size]) pad_slices.append(slice(start, end, step)) return pad_slices @@ -548,3 +594,37 @@ def get_cropped_slices(unpadded_slices, padded_slices, enhancement): stop = None cropped_slices.append(slice(start, stop)) return cropped_slices + + @classmethod + def get_hr_cropped_slices( + cls, unpadded_slices, padded_slices, padding, enhancement + ): + """Get high res cropped slices + + Note + ---- + It's possible to get a boundary slice that is too small for generator + input (padding layers typically need at least 4 elements) if the + forward pass chunk shape does not evenly divide the grid shape. We add + extra padding in the low res slices to account for this (with + :meth:`check_boundary_slice`) and need to adjust the high res cropped + slices accordingly. + """ + + hr_crop_start = None + hr_crop_stop = None + + if padding > 0: + hr_crop_start = enhancement * padding + hr_crop_stop = -hr_crop_start + + slices = [slice(hr_crop_start, hr_crop_stop)] * len(unpadded_slices) + + if padding == 0: + end_slice = cls.get_cropped_slices( + unpadded_slices[-1:], + padded_slices[-1:], + enhancement, + ) + slices[-1] = slice(end_slice[0].start, None) + return slices diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 9c64d048d..8774c1759 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -669,11 +669,25 @@ def test_slicing_no_pad(input_files): 'shape': shape, 'time_slice': time_slice, } + + # raises error because fwp_chunk_shape is too small + with pytest.raises(ValueError): + strategy = ForwardPassStrategy( + input_files, + model_kwargs={'model_dir': st_out_dir}, + model_class='Sup3rGan', + fwp_chunk_shape=(3, 2, 4), + spatial_pad=0, + temporal_pad=0, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + max_nodes=1, + ) strategy = ForwardPassStrategy( input_files, model_kwargs={'model_dir': st_out_dir}, model_class='Sup3rGan', - fwp_chunk_shape=(3, 2, 4), + fwp_chunk_shape=(4, 4, 4), spatial_pad=0, temporal_pad=0, input_handler_kwargs=input_handler_kwargs, @@ -705,8 +719,90 @@ def test_slicing_no_pad(input_files): fwp.strategy.ti_slices[t_idx], ) - assert handler.data[lr_pad_data_slice].shape[:-2] == (3, 2) - assert chunk.input_data.shape[:-2] == (3, 2) + assert handler.data[lr_pad_data_slice].shape[:-2] == (4, 4) + assert chunk.input_data.shape[:-2] == (4, 4) + assert np.allclose( + chunk.input_data, handler.data[lr_pad_data_slice] + ) + assert np.allclose( + chunk.input_data[lr_crop_data_slice], + handler.data[lr_data_slice], + ) + + +def test_slicing_auto_boundary_pad(input_files): + """Test that 
automatic boundary padding is applied when the fwp chunk shape + and grid size result in a slice that is too small for the generator.""" + + Sup3rGan.seed() + s_enhance = 3 + t_enhance = 4 + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + features = ['u_100m', 'v_100m'] + st_model.meta['lr_features'] = features + st_model.meta['hr_out_features'] = features + st_model.meta['s_enhance'] = s_enhance + st_model.meta['t_enhance'] = t_enhance + _ = st_model.generate(np.ones((4, 10, 10, 6, 2))) + + with tempfile.TemporaryDirectory() as td: + out_files = os.path.join(td, 'out_{file_id}.h5') + st_out_dir = os.path.join(td, 'st_gan') + st_model.save(st_out_dir) + + handler = DataHandler( + input_files, features, target=target, shape=shape + ) + + input_handler_kwargs = { + 'target': target, + 'shape': shape, + 'time_slice': time_slice, + } + + # raises warning because modulo(shape, fwp_chunk_shape) = 1 for the + # spatial dimensions. The slices of length 1 are then padded to 7 + with pytest.warns(match='too small'): + strategy = ForwardPassStrategy( + input_files, + model_kwargs={'model_dir': st_out_dir}, + model_class='Sup3rGan', + fwp_chunk_shape=(7, 7, 4), + spatial_pad=0, + temporal_pad=0, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + max_nodes=1, + ) + + fwp = ForwardPass(strategy) + for i in strategy.node_chunks[0]: + chunk = fwp.get_input_chunk(i) + s_idx, t_idx = strategy.get_chunk_indices(i) + s_slices = strategy.lr_slices[s_idx] + s_pad_slices = strategy.lr_pad_slices[s_idx] + s_crop_slices = strategy.fwp_slicer.s_lr_crop_slices[s_idx] + t_crop_slice = strategy.fwp_slicer.t_lr_crop_slices[t_idx] + lr_pad_data_slice = ( + s_pad_slices[0], + s_pad_slices[1], + fwp.strategy.ti_pad_slices[t_idx], + ) + lr_crop_data_slice = ( + s_crop_slices[0], + s_crop_slices[1], + t_crop_slice, + ) + lr_data_slice = ( + s_slices[0], + s_slices[1], + fwp.strategy.ti_slices[t_idx], + ) + + assert handler.data[lr_pad_data_slice].shape[:-2] == (7, 7) + assert chunk.input_data.shape[:-2] == (7, 7) assert np.allclose( chunk.input_data, handler.data[lr_pad_data_slice] ) From 317382375d375561b9d665fc535edd6dfbed256d Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 25 Nov 2024 11:01:02 -0700 Subject: [PATCH 07/11] Added boundary padding when spatial_pad > 0 also. 
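
The check generalizes the zero-padding case: a boundary chunk only has a
workable padded shape if 2 * spatial_pad + modulo(grid_shape,
fwp_chunk_shape) is at least 4. A trimmed sketch of the pad-width rule
added to the strategy below (same variable names as the diff, simplified
for illustration):

    import numpy as np

    def pad_width(window, max_steps, max_pad):
        """Pad width (before, after) for one dimension of one chunk."""
        win_start = window.start or 0
        win_stop = window.stop or max_steps
        start = int(np.maximum(0, max_pad - win_start))
        stop = int(np.maximum(0, max_pad + win_stop - max_steps))
        # boundary chunk too small for the generator: force extra padding
        if win_stop == max_steps and (win_stop - win_start) < 4:
            start = stop = max(2, max_pad)
        return (start, stop)

    print(pad_width(slice(14, 15), max_steps=15, max_pad=0))  # (2, 2)

A length-1 boundary chunk with no user padding is therefore padded out to
5 elements, which clears the 4-element minimum before the chunk is passed
through the generator.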
--- sup3r/pipeline/slicer.py | 140 +++++++++++------------- sup3r/pipeline/strategy.py | 49 +++++++-- tests/forward_pass/test_forward_pass.py | 64 +++++------ 3 files changed, 132 insertions(+), 121 deletions(-) diff --git a/sup3r/pipeline/slicer.py b/sup3r/pipeline/slicer.py index 1cf323c2f..97996eb50 100644 --- a/sup3r/pipeline/slicer.py +++ b/sup3r/pipeline/slicer.py @@ -245,10 +245,15 @@ def s1_hr_crop_slices(self): if self._s1_hr_crop_slices is None: self._s1_hr_crop_slices = self.get_hr_cropped_slices( unpadded_slices=self.s1_lr_slices, - padded_slices=self.s1_lr_pad_slices, enhancement=self.s_enhance, padding=self.spatial_pad, ) + + self._s1_hr_crop_slices = self.check_boundary_slice( + unpadded_slices=self.s1_lr_slices, + cropped_slices=self._s1_hr_crop_slices, + dim=0, + ) return self._s1_hr_crop_slices @property @@ -258,10 +263,14 @@ def s2_hr_crop_slices(self): if self._s2_hr_crop_slices is None: self._s2_hr_crop_slices = self.get_hr_cropped_slices( unpadded_slices=self.s2_lr_slices, - padded_slices=self.s2_lr_pad_slices, enhancement=self.s_enhance, padding=self.spatial_pad, ) + self._s2_hr_crop_slices = self.check_boundary_slice( + unpadded_slices=self.s2_lr_slices, + cropped_slices=self._s2_hr_crop_slices, + dim=1, + ) return self._s2_hr_crop_slices @property @@ -296,9 +305,20 @@ def s_lr_crop_slices(self): s1_crop_slices = self.get_cropped_slices( self.s1_lr_slices, self.s1_lr_pad_slices, 1 ) + + s1_crop_slices = self.check_boundary_slice( + unpadded_slices=self.s1_lr_slices, + cropped_slices=s1_crop_slices, + dim=0, + ) s2_crop_slices = self.get_cropped_slices( self.s2_lr_slices, self.s2_lr_pad_slices, 1 ) + s2_crop_slices = self.check_boundary_slice( + unpadded_slices=self.s2_lr_slices, + cropped_slices=s2_crop_slices, + dim=1, + ) self._s_lr_crop_slices = list( it.product(s1_crop_slices, s2_crop_slices) ) @@ -343,52 +363,6 @@ def hr_crop_slices(self): self._hr_crop_slices.append(node_slices) return self._hr_crop_slices - def check_boundary_slice(self, slices, dim): - """Check boundary slice for minimum shape. - - When spatial padding is used data is always padded to have at least 2 * - spatial_pad + 1 elements. When spatial padding is not used it's - possible for the forward pass chunk shape to divide the grid size such - that the last slice does not meet the minimum number of elements. - (Padding layers in the generator typically require a minimum shape of - 4). So, when spatial padding is not used so we add extra padding to - meet the minimum shape requirement, otherwise we raise an error if the - minimum shape is not met.""" - - end_slice = slices[-1] - err_msg = ( - 'The final spatial slice for dimension #%s is too small (%s). ' - 'Adjust the forward pass chunk shape (%s) and / or spatial ' - 'padding (%s) so that 2 * spatial_pad + ' - 'modulo(grid_shape, fwp_chunk_shape) > 3' - ) - warn_msg = ( - 'The final spatial slice for dimension #%s is too small (%s). ' - 'The start of this slice will be reduced to try to meet the ' - 'minimum slice length.' 
- ) - - if end_slice.stop - end_slice.start < 4: - if self.spatial_pad == 0: - logger.warning(warn_msg, dim + 1, end_slice) - warn(warn_msg % (dim + 1, end_slice)) - new_start = np.max([0, end_slice.stop - self.chunk_shape[dim]]) - end_slice = slice(new_start, end_slice.stop, end_slice.step) - slices[-1] = end_slice - if 2 * self.spatial_pad + (end_slice.stop - end_slice.start) < 4: - logger.error( - err_msg, - dim + 1, - end_slice, - self.chunk_shape, - self.spatial_pad, - ) - raise ValueError( - err_msg - % (dim + 1, end_slice, self.chunk_shape, self.spatial_pad) - ) - return slices - @property def s1_lr_pad_slices(self): """List of low resolution spatial slices with padding for first @@ -400,9 +374,6 @@ def s1_lr_pad_slices(self): enhancement=1, padding=self.spatial_pad, ) - self._s1_lr_pad_slices = self.check_boundary_slice( - slices=self._s1_lr_pad_slices, dim=0 - ) return self._s1_lr_pad_slices @property @@ -416,9 +387,6 @@ def s2_lr_pad_slices(self): enhancement=1, padding=self.spatial_pad, ) - self._s2_lr_pad_slices = self.check_boundary_slice( - slices=self._s2_lr_pad_slices, dim=1 - ) return self._s2_lr_pad_slices @property @@ -561,6 +529,42 @@ def get_padded_slices(slices, shape, enhancement, padding, step=None): pad_slices.append(slice(start, end, step)) return pad_slices + def check_boundary_slice(self, unpadded_slices, cropped_slices, dim): + """Check cropped slice at the right boundary for minimum shape. + + It is possible for the forward pass chunk shape to divide the grid size + such that the last slice (right boundary) does not meet the minimum + number of elements. (Padding layers in the generator typically require + a minimum shape of 4). When this minimum shape is not met we apply + extra padding in ``ForwardPassStrategy._get_pad_width``. Cropped slices + have to be adjusted to account for this here.""" + + warn_msg = ( + 'The final spatial slice for dimension #%s is too small ' + '(slice=slice(%s, %s), padding=%s). The start of this slice will ' + 'be reduced to try to meet the minimum slice length.' + ) + + lr_slice_start = unpadded_slices[-1].start or 0 + lr_slice_stop = unpadded_slices[-1].stop or self.coarse_shape[dim] + + # last slice adjustment + if 2 * self.spatial_pad + (lr_slice_stop - lr_slice_start) < 4: + logger.warning( + warn_msg, + dim + 1, + lr_slice_start, + lr_slice_stop, + self.spatial_pad, + ) + warn( + warn_msg + % (dim + 1, lr_slice_start, lr_slice_stop, self.spatial_pad) + ) + cropped_slices[-1] = slice(2 * self.s_enhance, -2 * self.s_enhance) + + return cropped_slices + @staticmethod def get_cropped_slices(unpadded_slices, padded_slices, enhancement): """Get cropped slices to cut off padded output @@ -593,23 +597,12 @@ def get_cropped_slices(unpadded_slices, padded_slices, enhancement): if stop is not None and stop >= 0: stop = None cropped_slices.append(slice(start, stop)) + return cropped_slices @classmethod - def get_hr_cropped_slices( - cls, unpadded_slices, padded_slices, padding, enhancement - ): - """Get high res cropped slices - - Note - ---- - It's possible to get a boundary slice that is too small for generator - input (padding layers typically need at least 4 elements) if the - forward pass chunk shape does not evenly divide the grid shape. We add - extra padding in the low res slices to account for this (with - :meth:`check_boundary_slice`) and need to adjust the high res cropped - slices accordingly. 
- """ + def get_hr_cropped_slices(cls, unpadded_slices, padding, enhancement): + """Get high res cropped slices""" hr_crop_start = None hr_crop_stop = None @@ -618,13 +611,4 @@ def get_hr_cropped_slices( hr_crop_start = enhancement * padding hr_crop_stop = -hr_crop_start - slices = [slice(hr_crop_start, hr_crop_stop)] * len(unpadded_slices) - - if padding == 0: - end_slice = cls.get_cropped_slices( - unpadded_slices[-1:], - padded_slices[-1:], - enhancement, - ) - slices[-1] = slice(end_slice[0].start, None) - return slices + return [slice(hr_crop_start, hr_crop_stop)] * len(unpadded_slices) diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index 4aa9e4d32..3935108a0 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -6,7 +6,6 @@ import os import pathlib import pprint -import warnings from dataclasses import dataclass from functools import cached_property from typing import Dict, Optional, Tuple, Union @@ -336,6 +335,18 @@ def preflight(self): out = self.fwp_slicer.get_time_slices() self.ti_slices, self.ti_pad_slices = out + fwp_s1_steps = self.fwp_chunk_shape[0] + 2 * self.spatial_pad + fwp_s2_steps = self.fwp_chunk_shape[1] + 2 * self.spatial_pad + msg = ( + 'The padding layers in the generator typically require at least 4 ' + 'elements per spatial dimension. The padded chunk shape (%s, %s) ' + 'is smaller than this.' + ) + + if fwp_s1_steps < 4 or fwp_s2_steps < 4: + logger.warning(msg, fwp_s1_steps, fwp_s2_steps) + warn(msg % (fwp_s1_steps, fwp_s2_steps)) + fwp_tsteps = self.fwp_chunk_shape[2] + 2 * self.temporal_pad tsteps = len(self.input_handler.time_index[self.time_slice]) msg = ( @@ -345,7 +356,7 @@ def preflight(self): ) if fwp_tsteps > tsteps: logger.warning(msg) - warnings.warn(msg) + warn(msg) out = self.fwp_slicer.get_spatial_slices() self.lr_slices, self.lr_pad_slices, self.hr_slices = out @@ -400,7 +411,7 @@ def out_files(self): return out_file_list @staticmethod - def _get_pad_width(window, max_steps, max_pad): + def _get_pad_width(window, max_steps, max_pad, check_boundary=False): """ Parameters ---------- @@ -410,16 +421,30 @@ def _get_pad_width(window, max_steps, max_pad): Maximum number of steps available. Padding cannot extend past this max_pad : int Maximum amount of padding to apply. + check_bounary : bool + Whether to check the final slice for minimum size requirement Returns ------- tuple Tuple of pad width for the given window. """ - start = window.start or 0 - stop = window.stop or max_steps - start = int(np.maximum(0, (max_pad - start))) - stop = int(np.maximum(0, max_pad + stop - max_steps)) + win_start = window.start or 0 + win_stop = window.stop or max_steps + start = int(np.maximum(0, (max_pad - win_start))) + stop = int(np.maximum(0, max_pad + win_stop - max_steps)) + + # We add minimum padding to the last slice if the padded window is + # too small for the generator. 
This can happen if 2 * spatial_pad + + # modulo(grid_size, fwp_chunk_shape) < 4 + if ( + check_boundary + and win_stop == max_steps + and (win_stop - win_start) < 4 + ): + stop = np.max([2, max_pad]) + start = np.max([2, max_pad]) + return (start, stop) def get_pad_width(self, chunk_index): @@ -438,10 +463,16 @@ def get_pad_width(self, chunk_index): return ( self._get_pad_width( - lr_slice[0], self.input_handler.grid_shape[0], self.spatial_pad + lr_slice[0], + self.input_handler.grid_shape[0], + self.spatial_pad, + check_boundary=True, ), self._get_pad_width( - lr_slice[1], self.input_handler.grid_shape[1], self.spatial_pad + lr_slice[1], + self.input_handler.grid_shape[1], + self.spatial_pad, + check_boundary=True, ), self._get_pad_width( ti_slice, len(self.input_handler.time_index), self.temporal_pad diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 8774c1759..5574ff8ca 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -670,8 +670,8 @@ def test_slicing_no_pad(input_files): 'time_slice': time_slice, } - # raises error because fwp_chunk_shape is too small - with pytest.raises(ValueError): + # raises warning because fwp_chunk_shape is too small + with pytest.warns(match='at least 4'): strategy = ForwardPassStrategy( input_files, model_kwargs={'model_dir': st_out_dir}, @@ -719,8 +719,10 @@ def test_slicing_no_pad(input_files): fwp.strategy.ti_slices[t_idx], ) - assert handler.data[lr_pad_data_slice].shape[:-2] == (4, 4) - assert chunk.input_data.shape[:-2] == (4, 4) + assert handler.data[lr_pad_data_slice].shape[:-2][0] > 3 + assert handler.data[lr_pad_data_slice].shape[:-2][1] > 3 + assert chunk.input_data.shape[:-2][0] > 3 + assert chunk.input_data.shape[:-2][1] > 3 assert np.allclose( chunk.input_data, handler.data[lr_pad_data_slice] ) @@ -730,7 +732,8 @@ def test_slicing_no_pad(input_files): ) -def test_slicing_auto_boundary_pad(input_files): +@pytest.mark.parametrize('spatial_pad', [0, 1]) +def test_slicing_auto_boundary_pad(input_files, spatial_pad): """Test that automatic boundary padding is applied when the fwp chunk shape and grid size result in a slice that is too small for the generator.""" @@ -762,34 +765,26 @@ def test_slicing_auto_boundary_pad(input_files): 'time_slice': time_slice, } - # raises warning because modulo(shape, fwp_chunk_shape) = 1 for the - # spatial dimensions. 
The slices of length 1 are then padded to 7 - with pytest.warns(match='too small'): - strategy = ForwardPassStrategy( - input_files, - model_kwargs={'model_dir': st_out_dir}, - model_class='Sup3rGan', - fwp_chunk_shape=(7, 7, 4), - spatial_pad=0, - temporal_pad=0, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - max_nodes=1, - ) + strategy = ForwardPassStrategy( + input_files, + model_kwargs={'model_dir': st_out_dir}, + model_class='Sup3rGan', + fwp_chunk_shape=(7, 7, 4), + spatial_pad=spatial_pad, + temporal_pad=0, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + max_nodes=1, + ) fwp = ForwardPass(strategy) for i in strategy.node_chunks[0]: chunk = fwp.get_input_chunk(i) s_idx, t_idx = strategy.get_chunk_indices(i) + pad_width = strategy.get_pad_width(i) s_slices = strategy.lr_slices[s_idx] - s_pad_slices = strategy.lr_pad_slices[s_idx] s_crop_slices = strategy.fwp_slicer.s_lr_crop_slices[s_idx] t_crop_slice = strategy.fwp_slicer.t_lr_crop_slices[t_idx] - lr_pad_data_slice = ( - s_pad_slices[0], - s_pad_slices[1], - fwp.strategy.ti_pad_slices[t_idx], - ) lr_crop_data_slice = ( s_crop_slices[0], s_crop_slices[1], @@ -801,15 +796,16 @@ def test_slicing_auto_boundary_pad(input_files): fwp.strategy.ti_slices[t_idx], ) - assert handler.data[lr_pad_data_slice].shape[:-2] == (7, 7) - assert chunk.input_data.shape[:-2] == (7, 7) - assert np.allclose( - chunk.input_data, handler.data[lr_pad_data_slice] - ) - assert np.allclose( - chunk.input_data[lr_crop_data_slice], - handler.data[lr_data_slice], - ) + assert chunk.input_data.shape[:-2][0] > 3 + assert chunk.input_data.shape[:-2][1] > 3 + input_data = chunk.input_data.copy() + if spatial_pad > 0: + slices = [ + slice(pw[0] or None, -pw[1] or None) for pw in pad_width + ] + input_data = input_data[slices[0], slices[1]] + hdata = handler.data[lr_data_slice] + assert np.allclose(input_data[lr_crop_data_slice], hdata) def test_slicing_pad(input_files): From 3456315ede296a44743f5faf298543334cf0f61a Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 25 Nov 2024 12:13:01 -0700 Subject: [PATCH 08/11] pad_width fix. check padded width for min shape, not unpadded width. --- sup3r/pipeline/strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index 3935108a0..a6ff0f64d 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -440,7 +440,7 @@ def _get_pad_width(window, max_steps, max_pad, check_boundary=False): if ( check_boundary and win_stop == max_steps - and (win_stop - win_start) < 4 + and (start + stop + win_stop - win_start) < 4 ): stop = np.max([2, max_pad]) start = np.max([2, max_pad]) From 11acb626282a0434f81c0db973ae2991832f8545 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 25 Nov 2024 14:56:53 -0700 Subject: [PATCH 09/11] Moved pad_width methods from Strategy to Slicer. Choices made for slicing values (hr_cropped_slices, for example) depend a lot on the pad width implementations so these should be grouped. 
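
For reference, the boundary rule being moved here is small enough to check in
isolation. Below is a minimal standalone sketch of that rule (the name
pad_width_sketch and the literal grid values are illustrative only; in this
patch the logic lives in the Slicer._get_pad_width method with the same
arguments):

    import numpy as np

    def pad_width_sketch(window, max_steps, max_pad, check_boundary=False):
        """Illustrative copy of the boundary padding rule."""
        win_start = window.start or 0
        win_stop = window.stop or max_steps
        start = int(np.maximum(0, max_pad - win_start))
        stop = int(np.maximum(0, max_pad + win_stop - max_steps))
        # If the padded window is still smaller than the 4 elements the
        # generator's padding layers need, force at least 2 elements of
        # padding on each side of the final (boundary) slice.
        if (
            check_boundary
            and win_stop == max_steps
            and (2 * max_pad + win_stop - win_start) < 4
        ):
            start = stop = max(2, max_pad)
        return (start, stop)

    # Splitting a grid of 8 with fwp_chunk_shape 7 and spatial_pad 0: the
    # interior chunk slice(0, 7) needs no extra padding, while the boundary
    # chunk slice(7, 8) has width 1 and gets (2, 2), i.e. padded width 5.
    assert pad_width_sketch(slice(0, 7), 8, 0, check_boundary=True) == (0, 0)
    assert pad_width_sketch(slice(7, 8), 8, 0, check_boundary=True) == (2, 2)

Grouping this rule with the slice bookkeeping means check_boundary_slice can
adjust the cropped slices against exactly the same condition.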
---
 sup3r/pipeline/slicer.py                | 88 ++++++++++++++++++++++++-
 sup3r/pipeline/strategy.py              | 77 ++--------------------
 tests/forward_pass/test_forward_pass.py | 19 ++++--
 3 files changed, 104 insertions(+), 80 deletions(-)

diff --git a/sup3r/pipeline/slicer.py b/sup3r/pipeline/slicer.py
index 97996eb50..e0412d2bc 100644
--- a/sup3r/pipeline/slicer.py
+++ b/sup3r/pipeline/slicer.py
@@ -72,6 +72,7 @@ def __post_init__(self):
         self.time_slice = _parse_time_slice(self.time_slice)
 
         self._chunk_lookup = None
+        self._extra_padding = None
         self._s1_lr_slices = None
         self._s2_lr_slices = None
         self._s1_lr_pad_slices = None
@@ -536,7 +537,7 @@ def check_boundary_slice(self, unpadded_slices, cropped_slices, dim):
         such that the last slice (right boundary) does not meet the minimum
         number of elements. (Padding layers in the generator typically require
         a minimum shape of 4). When this minimum shape is not met we apply
-        extra padding in ``ForwardPassStrategy._get_pad_width``. Cropped slices
+        extra padding in :meth:`self._get_pad_width`. Cropped slices
         have to be adjusted to account for this here."""
 
         warn_msg = (
             'The final spatial slice for dimension #%s is too small '
@@ -612,3 +613,88 @@ def get_hr_cropped_slices(cls, unpadded_slices, padding, enhancement):
             hr_crop_stop = -hr_crop_start
 
         return [slice(hr_crop_start, hr_crop_stop)] * len(unpadded_slices)
+
+    @staticmethod
+    def _get_pad_width(window, max_steps, max_pad, check_boundary=False):
+        """
+        Parameters
+        ----------
+        window : slice
+            Slice with start and stop of window to pad.
+        max_steps : int
+            Maximum number of steps available. Padding cannot extend past this
+        max_pad : int
+            Maximum amount of padding to apply.
+        check_boundary : bool
+            Whether to check the final slice for minimum size requirement
+
+        Returns
+        -------
+        tuple
+            Tuple of pad width for the given window.
+        """
+        win_start = window.start or 0
+        win_stop = window.stop or max_steps
+        start = int(np.maximum(0, (max_pad - win_start)))
+        stop = int(np.maximum(0, max_pad + win_stop - max_steps))
+
+        # We add minimum padding to the last slice if the padded window is
+        # too small for the generator. This can happen if 2 * spatial_pad +
+        # modulo(grid_size, fwp_chunk_shape) < 4
+        if (
+            check_boundary
+            and win_stop == max_steps
+            and (2 * max_pad + win_stop - win_start) < 4
+        ):
+            stop = np.max([2, max_pad])
+            start = np.max([2, max_pad])
+
+        return (start, stop)
+
+    def get_chunk_indices(self, chunk_index):
+        """Get (spatial, temporal) indices for the given chunk index"""
+        return (
+            chunk_index % self.n_spatial_chunks,
+            chunk_index // self.n_spatial_chunks,
+        )
+
+    def get_pad_width(self, chunk_index):
+        """Get extra padding for the current spatiotemporal chunk
+
+        Returns
+        -------
+        padding : tuple
+            Tuple of tuples with padding width for spatial and temporal
+            dimensions. Each tuple includes the start and end of padding for
+            that dimension. Ordering is spatial_1, spatial_2, temporal.
+        """
+        s_chunk_idx, t_chunk_idx = self.get_chunk_indices(chunk_index)
+        ti_slice = self.t_lr_slices[t_chunk_idx]
+        lr_slice = self.s_lr_slices[s_chunk_idx]
+
+        return (
+            self._get_pad_width(
+                lr_slice[0],
+                self.coarse_shape[0],
+                self.spatial_pad,
+                check_boundary=True,
+            ),
+            self._get_pad_width(
+                lr_slice[1],
+                self.coarse_shape[1],
+                self.spatial_pad,
+                check_boundary=True,
+            ),
+            self._get_pad_width(
+                ti_slice, len(self.dummy_time_index), self.temporal_pad
+            ),
+        )
+
+    @property
+    def extra_padding(self):
+        """Get list of pad widths for each chunk index"""
+        if self._extra_padding is None:
+            self._extra_padding = [
+                self.get_pad_width(idx) for idx in range(self.n_chunks)
+            ]
+        return self._extra_padding
diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py
index a6ff0f64d..9a488e6a1 100644
--- a/sup3r/pipeline/strategy.py
+++ b/sup3r/pipeline/strategy.py
@@ -410,75 +410,6 @@ def out_files(self):
         ]
         return out_file_list
 
-    @staticmethod
-    def _get_pad_width(window, max_steps, max_pad, check_boundary=False):
-        """
-        Parameters
-        ----------
-        window : slice
-            Slice with start and stop of window to pad.
-        max_steps : int
-            Maximum number of steps available. Padding cannot extend past this
-        max_pad : int
-            Maximum amount of padding to apply.
-        check_boundary : bool
-            Whether to check the final slice for minimum size requirement
-
-        Returns
-        -------
-        tuple
-            Tuple of pad width for the given window.
-        """
-        win_start = window.start or 0
-        win_stop = window.stop or max_steps
-        start = int(np.maximum(0, (max_pad - win_start)))
-        stop = int(np.maximum(0, max_pad + win_stop - max_steps))
-
-        # We add minimum padding to the last slice if the padded window is
-        # too small for the generator. This can happen if 2 * spatial_pad +
-        # modulo(grid_size, fwp_chunk_shape) < 4
-        if (
-            check_boundary
-            and win_stop == max_steps
-            and (start + stop + win_stop - win_start) < 4
-        ):
-            stop = np.max([2, max_pad])
-            start = np.max([2, max_pad])
-
-        return (start, stop)
-
-    def get_pad_width(self, chunk_index):
-        """Get padding for the current spatiotemporal chunk
-
-        Returns
-        -------
-        padding : tuple
-            Tuple of tuples with padding width for spatial and temporal
-            dimensions. Each tuple includes the start and end of padding for
-            that dimension. Ordering is spatial_1, spatial_2, temporal.
-        """
-        s_chunk_idx, t_chunk_idx = self.get_chunk_indices(chunk_index)
-        ti_slice = self.ti_slices[t_chunk_idx]
-        lr_slice = self.lr_slices[s_chunk_idx]
-
-        return (
-            self._get_pad_width(
-                lr_slice[0],
-                self.input_handler.grid_shape[0],
-                self.spatial_pad,
-                check_boundary=True,
-            ),
-            self._get_pad_width(
-                lr_slice[1],
-                self.input_handler.grid_shape[1],
-                self.spatial_pad,
-                check_boundary=True,
-            ),
-            self._get_pad_width(
-                ti_slice, len(self.input_handler.time_index), self.temporal_pad
-            ),
-        )
-
     def prep_chunk_data(self, chunk_index=0):
         """Get low res input data and exo data for given chunk index and
         bias correct low res data if requested.
@@ -530,7 +461,9 @@ def init_chunk(self, chunk_index=0):
             with that data and other chunk specific attributes.
""" - s_chunk_idx, t_chunk_idx = self.get_chunk_indices(chunk_index) + s_chunk_idx, t_chunk_idx = self.fwp_slicer.get_chunk_indices( + chunk_index + ) msg = ( f'Requested forward pass on chunk_index={chunk_index} > ' @@ -579,7 +512,7 @@ def init_chunk(self, chunk_index=0): ), gids=self.gids[hr_slice[:2]], out_file=self.out_files[chunk_index], - pad_width=self.get_pad_width(chunk_index), + pad_width=self.fwp_slicer.extra_padding[chunk_index], index=chunk_index, ) @@ -687,7 +620,7 @@ def chunk_masked(self, chunk_idx, log=True): """Check if the region for this chunk is masked. This is used to skip running the forward pass for region with just ocean, for example.""" - s_chunk_idx, _ = self.get_chunk_indices(chunk_idx) + s_chunk_idx, _ = self.fwp_slicer.get_chunk_indices(chunk_idx) mask_check = self.fwp_mask[s_chunk_idx] if mask_check and log: logger.info( diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 5574ff8ca..4d25f3fb2 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -479,7 +479,9 @@ def test_fwp_chunking(input_files): output_workers=strat.output_workers, meta=fwp.meta, ) - s_chunk_idx, t_chunk_idx = fwp.strategy.get_chunk_indices(i) + s_chunk_idx, t_chunk_idx = ( + fwp.strategy.fwp_slicer.get_chunk_indices(i) + ) ti_slice = fwp.strategy.ti_slices[t_chunk_idx] hr_slice = fwp.strategy.hr_slices[s_chunk_idx] @@ -698,7 +700,7 @@ def test_slicing_no_pad(input_files): fwp = ForwardPass(strategy) for i in strategy.node_chunks[0]: chunk = fwp.get_input_chunk(i) - s_idx, t_idx = strategy.get_chunk_indices(i) + s_idx, t_idx = strategy.fwp_slicer.get_chunk_indices(i) s_slices = strategy.lr_slices[s_idx] s_pad_slices = strategy.lr_pad_slices[s_idx] s_crop_slices = strategy.fwp_slicer.s_lr_crop_slices[s_idx] @@ -780,8 +782,8 @@ def test_slicing_auto_boundary_pad(input_files, spatial_pad): fwp = ForwardPass(strategy) for i in strategy.node_chunks[0]: chunk = fwp.get_input_chunk(i) - s_idx, t_idx = strategy.get_chunk_indices(i) - pad_width = strategy.get_pad_width(i) + s_idx, t_idx = strategy.fwp_slicer.get_chunk_indices(i) + pad_width = strategy.fwp_slicer.get_pad_width(i) s_slices = strategy.lr_slices[s_idx] s_crop_slices = strategy.fwp_slicer.s_lr_crop_slices[s_idx] t_crop_slice = strategy.fwp_slicer.t_lr_crop_slices[t_idx] @@ -796,8 +798,8 @@ def test_slicing_auto_boundary_pad(input_files, spatial_pad): fwp.strategy.ti_slices[t_idx], ) - assert chunk.input_data.shape[:-2][0] > 3 - assert chunk.input_data.shape[:-2][1] > 3 + assert chunk.input_data.shape[0] > 3 + assert chunk.input_data.shape[1] > 3 input_data = chunk.input_data.copy() if spatial_pad > 0: slices = [ @@ -867,7 +869,7 @@ def test_slicing_pad(input_files): fwp = ForwardPass(strategy) for i in strategy.node_chunks[0]: chunk = fwp.get_input_chunk(i, mode='constant') - s_idx, t_idx = strategy.get_chunk_indices(i) + s_idx, t_idx = strategy.fwp_slicer.get_chunk_indices(i) s_slices = strategy.lr_pad_slices[s_idx] lr_data_slice = ( s_slices[0], @@ -904,6 +906,9 @@ def test_slicing_pad(input_files): (0, 0), ) + assert chunk.input_data.shape[0] > 3 + assert chunk.input_data.shape[1] > 3 + truth = handler.data[lr_data_slice] padded_truth = np.pad(truth, pad_width, mode='constant') From 18070f95536f0d3222b44fa02ebdd37b5204d830 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 26 Nov 2024 08:01:33 -0700 Subject: [PATCH 10/11] doc string additions --- sup3r/pipeline/slicer.py | 53 +++++++++---------------- tests/forward_pass/test_forward_pass.py | 
8 +++- 2 files changed, 26 insertions(+), 35 deletions(-) diff --git a/sup3r/pipeline/slicer.py b/sup3r/pipeline/slicer.py index e0412d2bc..332d62a31 100644 --- a/sup3r/pipeline/slicer.py +++ b/sup3r/pipeline/slicer.py @@ -244,11 +244,12 @@ def s1_hr_crop_slices(self): """Get high res cropped slices for first spatial dimension""" if self._s1_hr_crop_slices is None: - self._s1_hr_crop_slices = self.get_hr_cropped_slices( - unpadded_slices=self.s1_lr_slices, - enhancement=self.s_enhance, - padding=self.spatial_pad, - ) + hr_crop_start = self.s_enhance * self.spatial_pad or None + hr_crop_stop = None if self.spatial_pad == 0 else -hr_crop_start + + self._s1_hr_crop_slices = [ + slice(hr_crop_start, hr_crop_stop) + ] * len(self.s1_lr_slices) self._s1_hr_crop_slices = self.check_boundary_slice( unpadded_slices=self.s1_lr_slices, @@ -262,11 +263,13 @@ def s2_hr_crop_slices(self): """Get high res cropped slices for first spatial dimension""" if self._s2_hr_crop_slices is None: - self._s2_hr_crop_slices = self.get_hr_cropped_slices( - unpadded_slices=self.s2_lr_slices, - enhancement=self.s_enhance, - padding=self.spatial_pad, - ) + hr_crop_start = self.s_enhance * self.spatial_pad or None + hr_crop_stop = None if self.spatial_pad == 0 else -hr_crop_start + + self._s2_hr_crop_slices = [ + slice(hr_crop_start, hr_crop_stop) + ] * len(self.s2_lr_slices) + self._s2_hr_crop_slices = self.check_boundary_slice( unpadded_slices=self.s2_lr_slices, cropped_slices=self._s2_hr_crop_slices, @@ -488,14 +491,6 @@ def get_padded_slices(slices, shape, enhancement, padding, step=None): """Get padded slices with the specified padding size, max shape, enhancement, and step size - Note - ---- - It's possible to get a boundary slice that is too small for generator - input (padding layers typically need at least 4 elements) if the - forward pass chunk shape does not evenly divide the grid shape. We add - extra padding in the low res slices to account for this with - ``min_size`` argument. - Parameters ---------- slices : list @@ -536,9 +531,12 @@ def check_boundary_slice(self, unpadded_slices, cropped_slices, dim): It is possible for the forward pass chunk shape to divide the grid size such that the last slice (right boundary) does not meet the minimum number of elements. (Padding layers in the generator typically require - a minimum shape of 4). When this minimum shape is not met we apply - extra padding in :meth:`self._get_pad_width`. Cropped slices - have to be adjusted to account for this here.""" + a minimum shape of 4). e.g. ``grid_size = (8, 8)`` with + ``fwp_chunk_shape = (7, 7, ...)`` results in unpadded slices with just + one element. If the padding is 0 or 1 these padded slices have length + less than 4. When this minimum shape is not met we apply extra padding + in :meth:`self._get_pad_width`. 
Cropped slices have to be adjusted to
+        account for this here."""
 
         warn_msg = (
             'The final spatial slice for dimension #%s is too small '
@@ -601,19 +599,6 @@ def get_cropped_slices(unpadded_slices, padded_slices, enhancement):
 
         return cropped_slices
 
-    @classmethod
-    def get_hr_cropped_slices(cls, unpadded_slices, padding, enhancement):
-        """Get high res cropped slices"""
-
-        hr_crop_start = None
-        hr_crop_stop = None
-
-        if padding > 0:
-            hr_crop_start = enhancement * padding
-            hr_crop_stop = -hr_crop_start
-
-        return [slice(hr_crop_start, hr_crop_stop)] * len(unpadded_slices)
-
     @staticmethod
     def _get_pad_width(window, max_steps, max_pad, check_boundary=False):
         """
diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py
index 4d25f3fb2..69d6cfda7 100644
--- a/tests/forward_pass/test_forward_pass.py
+++ b/tests/forward_pass/test_forward_pass.py
@@ -737,7 +737,13 @@ def test_slicing_no_pad(input_files):
 @pytest.mark.parametrize('spatial_pad', [0, 1])
 def test_slicing_auto_boundary_pad(input_files, spatial_pad):
     """Test that automatic boundary padding is applied when the fwp chunk shape
-    and grid size result in a slice that is too small for the generator."""
+    and grid size result in a slice that is too small for the generator.
+
+    Here the fwp chunk shape is (7, 7, 4) and the grid size is (8, 8), so with
+    no spatial padding this results in some chunk slices that have length 1.
+    With spatial padding equal to 1 some slices have length 3. In each of
+    these cases we need to pad the slices so the input to the generator has
+    at least 4 elements."""
 
     Sup3rGan.seed()
     s_enhance = 3

From 164ae4405a6c29e17f44f2df46ac4506f255cf10 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Wed, 4 Dec 2024 09:13:51 -0700
Subject: [PATCH 11/11] earlier error catch on chunks which didn't get all variables written.

---
 sup3r/postprocessing/collectors/h5.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/sup3r/postprocessing/collectors/h5.py b/sup3r/postprocessing/collectors/h5.py
index ebd5d9c70..72ff97f13 100644
--- a/sup3r/postprocessing/collectors/h5.py
+++ b/sup3r/postprocessing/collectors/h5.py
@@ -146,12 +146,25 @@ def get_data(
         with RexOutputs(file_path, unscale=False, mode='r') as f:
             f_ti = f.time_index
             f_meta = f.meta
+
+            if feature not in f.attrs:
+                e = (
+                    'Trying to collect dataset "{}" from {} but cannot find '
+                    'in available attributes: {}'.format(
+                        feature, file_path, f.attrs
+                    )
+                )
+                logger.error(e)
+                raise KeyError(e)
+
             source_scale_factor = f.attrs[feature].get('scale_factor', 1)
 
             if feature not in f.dsets:
                 e = (
-                    'Trying to collect dataset "{}" but cannot find in '
-                    'available: {}'.format(feature, f.dsets)
+                    'Trying to collect dataset "{}" from {} but cannot find '
+                    'in available features: {}'.format(
+                        feature, file_path, f.dsets
+                    )
                 )
                 logger.error(e)
                 raise KeyError(e)
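
A note on the chunk indexing that patch 09 moved onto the Slicer: the mapping
from a flat chunk index to (spatial, temporal) indices is plain modular
arithmetic. A short sketch follows (the n_spatial_chunks value of 6 is an
arbitrary illustrative choice; on the Slicer this count comes from its
n_spatial_chunks attribute):

    # Flat index -> (spatial, temporal) mapping used by get_chunk_indices.
    n_spatial_chunks = 6

    def get_chunk_indices(chunk_index):
        # The spatial index cycles fastest; the temporal index advances
        # once every n_spatial_chunks chunks.
        return (chunk_index % n_spatial_chunks,
                chunk_index // n_spatial_chunks)

    assert get_chunk_indices(0) == (0, 0)
    assert get_chunk_indices(5) == (5, 0)
    assert get_chunk_indices(6) == (0, 1)  # wraps to the next time chunk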
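
The early feature check added in the final patch generalizes to a simple
pre-collection scan. Below is a hedged sketch using plain h5py rather than
rex's RexOutputs (the helper name validate_chunk_file is hypothetical, and
this sketch only checks dataset membership, not the per-dataset attributes
that rex tracks):

    import h5py

    def validate_chunk_file(file_path, features):
        """Raise early if a chunk file is missing any expected feature."""
        with h5py.File(file_path, 'r') as f:
            missing = [feat for feat in features if feat not in f]
        if missing:
            raise KeyError(
                'Chunk file {} is missing features {}; it was likely '
                'interrupted before all variables were written.'.format(
                    file_path, missing
                )
            )

Running a scan like this over every chunk file before collection starts
surfaces incomplete chunks up front rather than partway through a long
collection run.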