Skip to content

Commit

Permalink
min padding depends on the .paddings attribute of the ``FlexibleP…
Browse files Browse the repository at this point in the history
…adding`` layers in the generator model. Generalized current min padding to use these values.
  • Loading branch information
bnb32 committed Jan 18, 2025
1 parent c01df1b commit 095063e
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
6 changes: 5 additions & 1 deletion sup3r/pipeline/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ def s_lr_crop_slices(self):
s1_crop_slices = self.check_boundary_slice(
unpadded_slices=self.s1_lr_slices,
cropped_slices=s1_crop_slices,
enhancement=self.s_enhance,
padding=self.spatial_pad,
dim=0,
)
s2_crop_slices = self.get_cropped_slices(
Expand All @@ -334,6 +336,8 @@ def s_lr_crop_slices(self):
s2_crop_slices = self.check_boundary_slice(
unpadded_slices=self.s2_lr_slices,
cropped_slices=s2_crop_slices,
enhancement=self.s_enhance,
padding=self.spatial_pad,
dim=1,
)
self._s_lr_crop_slices = list(
Expand Down Expand Up @@ -653,7 +657,7 @@ def _get_pad_width(
if (
check_boundary
and win_stop == max_steps
and (2 * max_pad + win_stop - win_start) < min_width
and (2 * max_pad + win_stop - win_start) <= min_width
):
half_width = min_width // 2 + 1
stop = np.max([half_width, max_pad])
Expand Down
13 changes: 7 additions & 6 deletions sup3r/pipeline/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ def __post_init__(self):
self.input_handler_kwargs.get('time_slice', slice(None))
)
self.fwp_chunk_shape = self._get_fwp_chunk_shape()

self.fwp_slicer = ForwardPassSlicer(
coarse_shape=self.input_handler.grid_shape,
time_steps=len(self.input_handler.time_index),
Expand All @@ -231,7 +230,7 @@ def __post_init__(self):
t_enhance=self.t_enhance,
spatial_pad=self.spatial_pad,
temporal_pad=self.temporal_pad,
min_width=self.get_min_pad_width(model)
min_width=self.get_min_pad_width(model),
)
self.n_chunks = self.fwp_slicer.n_chunks

Expand Down Expand Up @@ -261,10 +260,12 @@ def get_min_pad_width(self, model):
pad_width = (1, 1, 1)
for layer in model._gen.layers:
if hasattr(layer, 'paddings'):
pad_width = np.max(layer.paddings, axis=1)[1:-1]
if len(pad_width) < 3:
pad_width = (*pad_width, 1)
break
new_pw = np.max(layer.paddings, axis=1)[1:-1]
if len(new_pw) < 3:
new_pw = (*new_pw, 1)
pad_width = [
np.max((new_pw[i], pad_width[i])) for i in range(3)
]
return pad_width

@property
Expand Down
4 changes: 2 additions & 2 deletions tests/forward_pass/test_forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,8 +802,8 @@ def test_slicing_auto_boundary_pad(input_files, spatial_pad):
fwp.strategy.ti_slices[t_idx],
)

assert chunk.input_data.shape[0] > 3
assert chunk.input_data.shape[1] > 3
assert chunk.input_data.shape[0] > strategy.fwp_slicer.min_width[0]
assert chunk.input_data.shape[1] > strategy.fwp_slicer.min_width[1]
input_data = chunk.input_data.copy()
if spatial_pad > 0:
slices = [
Expand Down

0 comments on commit 095063e

Please sign in to comment.