diff --git a/src/leibnetz/leibnet.py b/src/leibnetz/leibnet.py index 38aad90..f9a1d3b 100644 --- a/src/leibnetz/leibnet.py +++ b/src/leibnetz/leibnet.py @@ -313,12 +313,12 @@ def is_valid_input_shape(self, input_key, input_shape): == 0 ).all() - def step_up_size(self, steps: int = 1): + def step_up_size(self, steps: int = 1, step_size: int = 1): for n in range(steps): target_arrays = {} for name, metadata in self.output_shapes.items(): target_arrays[name] = tuple( - (tuple(s + 1 for s in metadata["shape"]), metadata["scale"]) + (tuple(s + step_size for s in metadata["shape"]), metadata["scale"]) ) self.compute_shapes(target_arrays, set=True)