Skip to content

Commit

Permalink
generalized min pad width for padding slices so that this can accomod…
Browse files Browse the repository at this point in the history
…ate models with increased receptive field and larger padding values.
  • Loading branch information
bnb32 committed Jan 18, 2025
1 parent 212a77f commit c01df1b
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 35 deletions.
99 changes: 64 additions & 35 deletions sup3r/pipeline/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import itertools as it
import logging
from dataclasses import dataclass
from typing import Union
from typing import Optional, Union
from warnings import warn

import numpy as np
Expand All @@ -27,9 +27,23 @@ class ForwardPassSlicer:
time_steps : int
Number of time steps for full temporal domain of low res data. This
is used to construct a dummy_time_index from np.arange(time_steps)
s_enhance : int
Spatial enhancement factor
t_enhance : int
Temporal enhancement factor
time_slice : slice | list
Slice to use to extract range from time_index. Can be a ``slice(start,
stop, step)`` or list ``[start, stop, step]``
temporal_pad : int
Size of temporal overlap between coarse chunks passed to forward
passes for subsequent temporal stitching. This overlap will pad
both sides of the fwp_chunk_shape. Note that the first and last
chunks in the temporal dimension will not be padded.
spatial_pad : int
Size of spatial overlap between coarse chunks passed to forward
passes for subsequent spatial stitching. This overlap will pad both
sides of the fwp_chunk_shape. Note that the first and last chunks
in any of the spatial dimension will not be padded.
chunk_shape : tuple
Max shape (spatial_1, spatial_2, temporal) of an unpadded coarse
chunk to use for a forward pass. The number of nodes that the
Expand All @@ -41,20 +55,11 @@ class ForwardPassSlicer:
to the generator can be bigger than this shape. If running in
serial set this equal to the shape of the full spatiotemporal data
volume for best performance.
s_enhance : int
Spatial enhancement factor
t_enhance : int
Temporal enhancement factor
spatial_pad : int
Size of spatial overlap between coarse chunks passed to forward
passes for subsequent spatial stitching. This overlap will pad both
sides of the fwp_chunk_shape. Note that the first and last chunks
in any of the spatial dimension will not be padded.
temporal_pad : int
Size of temporal overlap between coarse chunks passed to forward
passes for subsequent temporal stitching. This overlap will pad
both sides of the fwp_chunk_shape. Note that the first and last
chunks in the temporal dimension will not be padded.
min_width : tuple
Minimum width of padded slices, with each element providing the min
width for the corresponding dimension. e.g. (spatial_1, spatial_2,
temporal). This is used to make sure generator network input meets the
minimum size requirement for padding layers.
"""

coarse_shape: Union[tuple, list]
Expand All @@ -65,6 +70,7 @@ class ForwardPassSlicer:
temporal_pad: int
spatial_pad: int
chunk_shape: Union[tuple, list]
min_width: Optional[Union[tuple, list]] = None

@log_args
def __post_init__(self):
Expand All @@ -88,6 +94,9 @@ def __post_init__(self):
self._s_hr_crop_slices = None
self._t_hr_crop_slices = None
self._hr_crop_slices = None
self.min_width = (
self.chunk_shape if self.min_width is None else self.min_width
)

def get_spatial_slices(self):
"""Get spatial slices for small data chunks that are passed through
Expand Down Expand Up @@ -254,6 +263,8 @@ def s1_hr_crop_slices(self):
self._s1_hr_crop_slices = self.check_boundary_slice(
unpadded_slices=self.s1_lr_slices,
cropped_slices=self._s1_hr_crop_slices,
enhancement=self.s_enhance,
padding=self.spatial_pad,
dim=0,
)
return self._s1_hr_crop_slices
Expand All @@ -273,6 +284,8 @@ def s2_hr_crop_slices(self):
self._s2_hr_crop_slices = self.check_boundary_slice(
unpadded_slices=self.s2_lr_slices,
cropped_slices=self._s2_hr_crop_slices,
enhancement=self.s_enhance,
padding=self.spatial_pad,
dim=1,
)
return self._s2_hr_crop_slices
Expand Down Expand Up @@ -525,21 +538,22 @@ def get_padded_slices(slices, shape, enhancement, padding, step=None):
pad_slices.append(slice(start, end, step))
return pad_slices

def check_boundary_slice(self, unpadded_slices, cropped_slices, dim):
def check_boundary_slice(
self, unpadded_slices, cropped_slices, enhancement, padding, dim
):
"""Check cropped slice at the right boundary for minimum shape.
It is possible for the forward pass chunk shape to divide the grid size
such that the last slice (right boundary) does not meet the minimum
number of elements. (Padding layers in the generator typically require
a minimum shape of 4). e.g. ``grid_size = (8, 8)`` with
``fwp_chunk_shape = (7, 7, ...)`` results in unpadded slices with just
one element. If the padding is 0 or 1 these padded slices have length
less than 4. When this minimum shape is not met we apply extra padding
in :meth:`self._get_pad_width`. Cropped slices have to be adjusted to
number of elements. (Padding layers in the generator require a minimum
shape). e.g. ``grid_size = (8, 8)`` with ``fwp_chunk_shape = (7, 7,
...)`` results in unpadded slices with just one element. When this
minimum shape is not met we apply extra padding in
:meth:`self._get_pad_width`. Cropped slices have to be adjusted to
account for this here."""

warn_msg = (
'The final spatial slice for dimension #%s is too small '
'The final slice for dimension #%s is too small '
'(slice=slice(%s, %s), padding=%s). The start of this slice will '
'be reduced to try to meet the minimum slice length.'
)
Expand All @@ -548,19 +562,22 @@ def check_boundary_slice(self, unpadded_slices, cropped_slices, dim):
lr_slice_stop = unpadded_slices[-1].stop or self.coarse_shape[dim]

# last slice adjustment
if 2 * self.spatial_pad + (lr_slice_stop - lr_slice_start) < 4:
if (
2 * padding + (lr_slice_stop - lr_slice_start)
<= self.min_width[dim]
):
half_width = self.min_width[dim] // 2 + 1
logger.warning(
warn_msg,
dim + 1,
lr_slice_start,
lr_slice_stop,
self.spatial_pad,
padding,
)
warn(
warn_msg
% (dim + 1, lr_slice_start, lr_slice_stop, self.spatial_pad)
warn(warn_msg % (dim + 1, lr_slice_start, lr_slice_stop, padding))
cropped_slices[-1] = slice(
half_width * enhancement, -half_width * enhancement
)
cropped_slices[-1] = slice(2 * self.s_enhance, -2 * self.s_enhance)

return cropped_slices

Expand Down Expand Up @@ -600,7 +617,9 @@ def get_cropped_slices(unpadded_slices, padded_slices, enhancement):
return cropped_slices

@staticmethod
def _get_pad_width(window, max_steps, max_pad, check_boundary=False):
def _get_pad_width(
window, max_steps, max_pad, min_width=None, check_boundary=False
):
"""
Parameters
----------
Expand All @@ -610,6 +629,10 @@ def _get_pad_width(window, max_steps, max_pad, check_boundary=False):
Maximum number of steps available. Padding cannot extend past this
max_pad : int
Maximum amount of padding to apply.
min_width : int | None
Minimum width to enforce. This could be the forward pass chunk
shape or the padding value in the first padding layer of the
generator network. This is only used if ``check_boundary = True``
check_bounary : bool
Whether to check the final slice for minimum size requirement
Expand All @@ -625,14 +648,16 @@ def _get_pad_width(window, max_steps, max_pad, check_boundary=False):

# We add minimum padding to the last slice if the padded window is
# too small for the generator. This can happen if 2 * spatial_pad +
# modulo(grid_size, fwp_chunk_shape) < 4
# modulo(grid_size, fwp_chunk_shape) is less than the padding applied
# in the first padding layer of the generator
if (
check_boundary
and win_stop == max_steps
and (2 * max_pad + win_stop - win_start) < 4
and (2 * max_pad + win_stop - win_start) < min_width
):
stop = np.max([2, max_pad])
start = np.max([2, max_pad])
half_width = min_width // 2 + 1
stop = np.max([half_width, max_pad])
start = np.max([half_width, max_pad])

return (start, stop)

Expand Down Expand Up @@ -662,16 +687,20 @@ def get_pad_width(self, chunk_index):
lr_slice[0],
self.coarse_shape[0],
self.spatial_pad,
self.min_width[0],
check_boundary=True,
),
self._get_pad_width(
lr_slice[1],
self.coarse_shape[1],
self.spatial_pad,
self.min_width[1],
check_boundary=True,
),
self._get_pad_width(
ti_slice, len(self.dummy_time_index), self.temporal_pad
ti_slice,
len(self.dummy_time_index),
self.temporal_pad
),
)

Expand Down
14 changes: 14 additions & 0 deletions sup3r/pipeline/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def __post_init__(self):
t_enhance=self.t_enhance,
spatial_pad=self.spatial_pad,
temporal_pad=self.temporal_pad,
min_width=self.get_min_pad_width(model)
)
self.n_chunks = self.fwp_slicer.n_chunks

Expand All @@ -253,6 +254,19 @@ def __post_init__(self):

self.preflight()

def get_min_pad_width(self, model):
"""Get the padding values applied in the first padding layer of the
model. This is used to determine the minimum width of padded slices
used to chunk the generator input."""
pad_width = (1, 1, 1)
for layer in model._gen.layers:
if hasattr(layer, 'paddings'):
pad_width = np.max(layer.paddings, axis=1)[1:-1]
if len(pad_width) < 3:
pad_width = (*pad_width, 1)
break
return pad_width

@property
def meta(self):
"""Meta data dictionary for the strategy. Used to add info to forward
Expand Down

0 comments on commit c01df1b

Please sign in to comment.