From c01df1b09229d13753c261bc816eccf3d866f6b6 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 11 Jan 2025 11:22:34 -0700 Subject: [PATCH] generalized min pad width for padding slices so that this can accomodate models with increased receptive field and larger padding values. --- sup3r/pipeline/slicer.py | 99 ++++++++++++++++++++++++-------------- sup3r/pipeline/strategy.py | 14 ++++++ 2 files changed, 78 insertions(+), 35 deletions(-) diff --git a/sup3r/pipeline/slicer.py b/sup3r/pipeline/slicer.py index 332d62a31..3df7ace7c 100644 --- a/sup3r/pipeline/slicer.py +++ b/sup3r/pipeline/slicer.py @@ -3,7 +3,7 @@ import itertools as it import logging from dataclasses import dataclass -from typing import Union +from typing import Optional, Union from warnings import warn import numpy as np @@ -27,9 +27,23 @@ class ForwardPassSlicer: time_steps : int Number of time steps for full temporal domain of low res data. This is used to construct a dummy_time_index from np.arange(time_steps) + s_enhance : int + Spatial enhancement factor + t_enhance : int + Temporal enhancement factor time_slice : slice | list Slice to use to extract range from time_index. Can be a ``slice(start, stop, step)`` or list ``[start, stop, step]`` + temporal_pad : int + Size of temporal overlap between coarse chunks passed to forward + passes for subsequent temporal stitching. This overlap will pad + both sides of the fwp_chunk_shape. Note that the first and last + chunks in the temporal dimension will not be padded. + spatial_pad : int + Size of spatial overlap between coarse chunks passed to forward + passes for subsequent spatial stitching. This overlap will pad both + sides of the fwp_chunk_shape. Note that the first and last chunks + in any of the spatial dimension will not be padded. chunk_shape : tuple Max shape (spatial_1, spatial_2, temporal) of an unpadded coarse chunk to use for a forward pass. The number of nodes that the @@ -41,20 +55,11 @@ class ForwardPassSlicer: to the generator can be bigger than this shape. If running in serial set this equal to the shape of the full spatiotemporal data volume for best performance. - s_enhance : int - Spatial enhancement factor - t_enhance : int - Temporal enhancement factor - spatial_pad : int - Size of spatial overlap between coarse chunks passed to forward - passes for subsequent spatial stitching. This overlap will pad both - sides of the fwp_chunk_shape. Note that the first and last chunks - in any of the spatial dimension will not be padded. - temporal_pad : int - Size of temporal overlap between coarse chunks passed to forward - passes for subsequent temporal stitching. This overlap will pad - both sides of the fwp_chunk_shape. Note that the first and last - chunks in the temporal dimension will not be padded. + min_width : tuple + Minimum width of padded slices, with each element providing the min + width for the corresponding dimension. e.g. (spatial_1, spatial_2, + temporal). This is used to make sure generator network input meets the + minimum size requirement for padding layers. """ coarse_shape: Union[tuple, list] @@ -65,6 +70,7 @@ class ForwardPassSlicer: temporal_pad: int spatial_pad: int chunk_shape: Union[tuple, list] + min_width: Optional[Union[tuple, list]] = None @log_args def __post_init__(self): @@ -88,6 +94,9 @@ def __post_init__(self): self._s_hr_crop_slices = None self._t_hr_crop_slices = None self._hr_crop_slices = None + self.min_width = ( + self.chunk_shape if self.min_width is None else self.min_width + ) def get_spatial_slices(self): """Get spatial slices for small data chunks that are passed through @@ -254,6 +263,8 @@ def s1_hr_crop_slices(self): self._s1_hr_crop_slices = self.check_boundary_slice( unpadded_slices=self.s1_lr_slices, cropped_slices=self._s1_hr_crop_slices, + enhancement=self.s_enhance, + padding=self.spatial_pad, dim=0, ) return self._s1_hr_crop_slices @@ -273,6 +284,8 @@ def s2_hr_crop_slices(self): self._s2_hr_crop_slices = self.check_boundary_slice( unpadded_slices=self.s2_lr_slices, cropped_slices=self._s2_hr_crop_slices, + enhancement=self.s_enhance, + padding=self.spatial_pad, dim=1, ) return self._s2_hr_crop_slices @@ -525,21 +538,22 @@ def get_padded_slices(slices, shape, enhancement, padding, step=None): pad_slices.append(slice(start, end, step)) return pad_slices - def check_boundary_slice(self, unpadded_slices, cropped_slices, dim): + def check_boundary_slice( + self, unpadded_slices, cropped_slices, enhancement, padding, dim + ): """Check cropped slice at the right boundary for minimum shape. It is possible for the forward pass chunk shape to divide the grid size such that the last slice (right boundary) does not meet the minimum - number of elements. (Padding layers in the generator typically require - a minimum shape of 4). e.g. ``grid_size = (8, 8)`` with - ``fwp_chunk_shape = (7, 7, ...)`` results in unpadded slices with just - one element. If the padding is 0 or 1 these padded slices have length - less than 4. When this minimum shape is not met we apply extra padding - in :meth:`self._get_pad_width`. Cropped slices have to be adjusted to + number of elements. (Padding layers in the generator require a minimum + shape). e.g. ``grid_size = (8, 8)`` with ``fwp_chunk_shape = (7, 7, + ...)`` results in unpadded slices with just one element. When this + minimum shape is not met we apply extra padding in + :meth:`self._get_pad_width`. Cropped slices have to be adjusted to account for this here.""" warn_msg = ( - 'The final spatial slice for dimension #%s is too small ' + 'The final slice for dimension #%s is too small ' '(slice=slice(%s, %s), padding=%s). The start of this slice will ' 'be reduced to try to meet the minimum slice length.' ) @@ -548,19 +562,22 @@ def check_boundary_slice(self, unpadded_slices, cropped_slices, dim): lr_slice_stop = unpadded_slices[-1].stop or self.coarse_shape[dim] # last slice adjustment - if 2 * self.spatial_pad + (lr_slice_stop - lr_slice_start) < 4: + if ( + 2 * padding + (lr_slice_stop - lr_slice_start) + <= self.min_width[dim] + ): + half_width = self.min_width[dim] // 2 + 1 logger.warning( warn_msg, dim + 1, lr_slice_start, lr_slice_stop, - self.spatial_pad, + padding, ) - warn( - warn_msg - % (dim + 1, lr_slice_start, lr_slice_stop, self.spatial_pad) + warn(warn_msg % (dim + 1, lr_slice_start, lr_slice_stop, padding)) + cropped_slices[-1] = slice( + half_width * enhancement, -half_width * enhancement ) - cropped_slices[-1] = slice(2 * self.s_enhance, -2 * self.s_enhance) return cropped_slices @@ -600,7 +617,9 @@ def get_cropped_slices(unpadded_slices, padded_slices, enhancement): return cropped_slices @staticmethod - def _get_pad_width(window, max_steps, max_pad, check_boundary=False): + def _get_pad_width( + window, max_steps, max_pad, min_width=None, check_boundary=False + ): """ Parameters ---------- @@ -610,6 +629,10 @@ def _get_pad_width(window, max_steps, max_pad, check_boundary=False): Maximum number of steps available. Padding cannot extend past this max_pad : int Maximum amount of padding to apply. + min_width : int | None + Minimum width to enforce. This could be the forward pass chunk + shape or the padding value in the first padding layer of the + generator network. This is only used if ``check_boundary = True`` check_bounary : bool Whether to check the final slice for minimum size requirement @@ -625,14 +648,16 @@ def _get_pad_width(window, max_steps, max_pad, check_boundary=False): # We add minimum padding to the last slice if the padded window is # too small for the generator. This can happen if 2 * spatial_pad + - # modulo(grid_size, fwp_chunk_shape) < 4 + # modulo(grid_size, fwp_chunk_shape) is less than the padding applied + # in the first padding layer of the generator if ( check_boundary and win_stop == max_steps - and (2 * max_pad + win_stop - win_start) < 4 + and (2 * max_pad + win_stop - win_start) < min_width ): - stop = np.max([2, max_pad]) - start = np.max([2, max_pad]) + half_width = min_width // 2 + 1 + stop = np.max([half_width, max_pad]) + start = np.max([half_width, max_pad]) return (start, stop) @@ -662,16 +687,20 @@ def get_pad_width(self, chunk_index): lr_slice[0], self.coarse_shape[0], self.spatial_pad, + self.min_width[0], check_boundary=True, ), self._get_pad_width( lr_slice[1], self.coarse_shape[1], self.spatial_pad, + self.min_width[1], check_boundary=True, ), self._get_pad_width( - ti_slice, len(self.dummy_time_index), self.temporal_pad + ti_slice, + len(self.dummy_time_index), + self.temporal_pad ), ) diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index 9a488e6a1..6b01294a3 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -231,6 +231,7 @@ def __post_init__(self): t_enhance=self.t_enhance, spatial_pad=self.spatial_pad, temporal_pad=self.temporal_pad, + min_width=self.get_min_pad_width(model) ) self.n_chunks = self.fwp_slicer.n_chunks @@ -253,6 +254,19 @@ def __post_init__(self): self.preflight() + def get_min_pad_width(self, model): + """Get the padding values applied in the first padding layer of the + model. This is used to determine the minimum width of padded slices + used to chunk the generator input.""" + pad_width = (1, 1, 1) + for layer in model._gen.layers: + if hasattr(layer, 'paddings'): + pad_width = np.max(layer.paddings, axis=1)[1:-1] + if len(pad_width) < 3: + pad_width = (*pad_width, 1) + break + return pad_width + @property def meta(self): """Meta data dictionary for the strategy. Used to add info to forward