Skip to content

Commit

Permalink
fixes for edge case where spatial slices are too small for the genera…
Browse files Browse the repository at this point in the history
…tor padding layer requirements.
  • Loading branch information
bnb32 committed Nov 25, 2024
1 parent 2a4f3b0 commit 3284f97
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 51 deletions.
176 changes: 128 additions & 48 deletions sup3r/pipeline/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from dataclasses import dataclass
from typing import Union
from warnings import warn

import numpy as np

Expand Down Expand Up @@ -75,6 +76,8 @@ def __post_init__(self):
self._s2_lr_slices = None
self._s1_lr_pad_slices = None
self._s2_lr_pad_slices = None
self._s1_hr_crop_slices = None
self._s2_hr_crop_slices = None
self._s_lr_slices = None
self._s_lr_pad_slices = None
self._s_lr_crop_slices = None
Expand Down Expand Up @@ -235,6 +238,32 @@ def s2_hr_slices(self):
"""Get high res spatial slices for second spatial dimension"""
return self.get_hr_slices(self.s2_lr_slices, self.s_enhance)

@property
def s1_hr_crop_slices(self):
"""Get high res cropped slices for first spatial dimension"""

if self._s1_hr_crop_slices is None:
self._s1_hr_crop_slices = self.get_hr_cropped_slices(
unpadded_slices=self.s1_lr_slices,
padded_slices=self.s1_lr_pad_slices,
enhancement=self.s_enhance,
padding=self.spatial_pad,
)
return self._s1_hr_crop_slices

@property
def s2_hr_crop_slices(self):
"""Get high res cropped slices for first spatial dimension"""

if self._s2_hr_crop_slices is None:
self._s2_hr_crop_slices = self.get_hr_cropped_slices(
unpadded_slices=self.s2_lr_slices,
padded_slices=self.s2_lr_pad_slices,
enhancement=self.s_enhance,
padding=self.spatial_pad,
)
return self._s2_hr_crop_slices

@property
def s_hr_slices(self):
"""Get high res slices for indexing full generator output array
Expand Down Expand Up @@ -285,40 +314,9 @@ def s_hr_crop_slices(self):
List of high res cropped slices. Each entry in this list has a
slice for each spatial dimension.
"""
hr_crop_start = None
hr_crop_stop = None
if self.spatial_pad > 0:
hr_crop_start = self.s_enhance * self.spatial_pad
hr_crop_stop = -hr_crop_start

if self._s_hr_crop_slices is None:
self._s_hr_crop_slices = []
s1_hr_crop_slices = [
slice(hr_crop_start, hr_crop_stop)
for _ in range(len(self.s1_lr_slices))
]
s2_hr_crop_slices = [
slice(hr_crop_start, hr_crop_stop)
for _ in range(len(self.s2_lr_slices))
]

if self.spatial_pad == 0:
s1_end_slice = self.get_cropped_slices(
self.s1_lr_slices[-1:],
self.s1_lr_pad_slices[-1:],
self.s_enhance,
)
s2_end_slice = self.get_cropped_slices(
self.s2_lr_slices[-1:],
self.s2_lr_pad_slices[-1:],
self.s_enhance,
)

s1_hr_crop_slices[-1] = slice(s1_end_slice[0].start, None)
s2_hr_crop_slices[-1] = slice(s2_end_slice[0].start, None)

self._s_hr_crop_slices = list(
it.product(s1_hr_crop_slices, s2_hr_crop_slices)
it.product(self.s1_hr_crop_slices, self.s2_hr_crop_slices)
)
return self._s_hr_crop_slices

Expand All @@ -345,6 +343,52 @@ def hr_crop_slices(self):
self._hr_crop_slices.append(node_slices)
return self._hr_crop_slices

def check_boundary_slice(self, slices, dim):
"""Check boundary slice for minimum shape.
When spatial padding is used data is always padded to have at least 2 *
spatial_pad + 1 elements. When spatial padding is not used it's
possible for the forward pass chunk shape to divide the grid size such
that the last slice does not meet the minimum number of elements.
(Padding layers in the generator typically require a minimum shape of
4). So, when spatial padding is not used so we add extra padding to
meet the minimum shape requirement, otherwise we raise an error if the
minimum shape is not met."""

end_slice = slices[-1]
err_msg = (
'The final spatial slice for dimension #%s is too small (%s). '
'Adjust the forward pass chunk shape (%s) and / or spatial '
'padding (%s) so that 2 * spatial_pad + '
'modulo(grid_shape, fwp_chunk_shape) > 3'
)
warn_msg = (
'The final spatial slice for dimension #%s is too small (%s). '
'The start of this slice will be reduced to try to meet the '
'minimum slice length.'
)

if end_slice.stop - end_slice.start < 4:
if self.spatial_pad == 0:
logger.warning(warn_msg, dim + 1, end_slice)
warn(warn_msg % (dim + 1, end_slice))
new_start = np.max([0, end_slice.stop - self.chunk_shape[dim]])
end_slice = slice(new_start, end_slice.stop, end_slice.step)
slices[-1] = end_slice
if 2 * self.spatial_pad + (end_slice.stop - end_slice.start) < 4:
logger.error(
err_msg,
dim + 1,
end_slice,
self.chunk_shape,
self.spatial_pad,
)
raise ValueError(
err_msg
% (dim + 1, end_slice, self.chunk_shape, self.spatial_pad)
)
return slices

@property
def s1_lr_pad_slices(self):
"""List of low resolution spatial slices with padding for first
Expand All @@ -355,7 +399,9 @@ def s1_lr_pad_slices(self):
shape=self.coarse_shape[0],
enhancement=1,
padding=self.spatial_pad,
min_size=self.chunk_shape[0]
)
self._s1_lr_pad_slices = self.check_boundary_slice(
slices=self._s1_lr_pad_slices, dim=0
)
return self._s1_lr_pad_slices

Expand All @@ -369,7 +415,9 @@ def s2_lr_pad_slices(self):
shape=self.coarse_shape[1],
enhancement=1,
padding=self.spatial_pad,
min_size=self.chunk_shape[1],
)
self._s2_lr_pad_slices = self.check_boundary_slice(
slices=self._s2_lr_pad_slices, dim=1
)
return self._s2_lr_pad_slices

Expand All @@ -378,20 +426,18 @@ def s1_lr_slices(self):
"""List of low resolution spatial slices for first spatial dimension
considering padding on all sides of the spatial raster."""
ind = slice(0, self.coarse_shape[0])
slices = get_chunk_slices(
return get_chunk_slices(
self.coarse_shape[0], self.chunk_shape[0], index_slice=ind
)
return slices

@property
def s2_lr_slices(self):
"""List of low resolution spatial slices for second spatial dimension
considering padding on all sides of the spatial raster."""
ind = slice(0, self.coarse_shape[1])
slices = get_chunk_slices(
return get_chunk_slices(
self.coarse_shape[1], self.chunk_shape[1], index_slice=ind
)
return slices

@property
def t_lr_slices(self):
Expand All @@ -401,10 +447,9 @@ def t_lr_slices(self):
n_chunks = int(np.ceil(n_chunks))
ti_slices = self.dummy_time_index[self.time_slice]
ti_slices = np.array_split(ti_slices, n_chunks)
ti_slices = [
return [
slice(c[0], c[-1] + 1, self.time_slice.step) for c in ti_slices
]
return ti_slices

@staticmethod
def get_hr_slices(slices, enhancement, step=None):
Expand Down Expand Up @@ -470,12 +515,18 @@ def n_chunks(self):
return self.n_spatial_chunks * self.n_time_chunks

@staticmethod
def get_padded_slices(
slices, shape, enhancement, padding, min_size=None, step=None
):
def get_padded_slices(slices, shape, enhancement, padding, step=None):
"""Get padded slices with the specified padding size, max shape,
enhancement, and step size
Note
----
It's possible to get a boundary slice that is too small for generator
input (padding layers typically need at least 4 elements) if the
forward pass chunk shape does not evenly divide the grid shape. We add
extra padding in the low res slices to account for this with
``min_size`` argument.
Parameters
----------
slices : list
Expand All @@ -492,9 +543,6 @@ def get_padded_slices(
dimension and the spatial_pad is 10 this is 10. It will be
multiplied by the enhancement factor if the slices are to be used
to index an enhanced dimension.
min_size : int
Minimum size of a slice. This is usually at least 4. Padding layers
in the generator model typpically require a minimum shape of 4.
step : int | None
Step size for slices. e.g. If these slices are indexing a temporal
dimension and time_slice.step = 3 then step=3.
Expand All @@ -510,8 +558,6 @@ def get_padded_slices(
for _, s in enumerate(slices):
end = np.min([enhancement * shape, s.stop * enhancement + pad])
start = np.max([0, s.start * enhancement - pad])
if min_size is not None and end - start < min_size and pad == 0:
start = np.max([0, end - min_size])
pad_slices.append(slice(start, end, step))
return pad_slices

Expand Down Expand Up @@ -548,3 +594,37 @@ def get_cropped_slices(unpadded_slices, padded_slices, enhancement):
stop = None
cropped_slices.append(slice(start, stop))
return cropped_slices

@classmethod
def get_hr_cropped_slices(
cls, unpadded_slices, padded_slices, padding, enhancement
):
"""Get high res cropped slices
Note
----
It's possible to get a boundary slice that is too small for generator
input (padding layers typically need at least 4 elements) if the
forward pass chunk shape does not evenly divide the grid shape. We add
extra padding in the low res slices to account for this (with
:meth:`check_boundary_slice`) and need to adjust the high res cropped
slices accordingly.
"""

hr_crop_start = None
hr_crop_stop = None

if padding > 0:
hr_crop_start = enhancement * padding
hr_crop_stop = -hr_crop_start

slices = [slice(hr_crop_start, hr_crop_stop)] * len(unpadded_slices)

if padding == 0:
end_slice = cls.get_cropped_slices(
unpadded_slices[-1:],
padded_slices[-1:],
enhancement,
)
slices[-1] = slice(end_slice[0].start, None)
return slices
Loading

0 comments on commit 3284f97

Please sign in to comment.