From 8fdbbb92159301eef68b48bf4c0ecdaa5e86a5dd Mon Sep 17 00:00:00 2001 From: grantbuster Date: Wed, 1 Mar 2023 10:15:28 -0700 Subject: [PATCH] added depth to time spatiotemporal expansion option with tests --- phygnn/layers/custom_layers.py | 50 ++++++++++++++++++++++++++++------ phygnn/version.py | 2 +- tests/test_layers.py | 39 ++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 10 deletions(-) diff --git a/phygnn/layers/custom_layers.py b/phygnn/layers/custom_layers.py index 9bcb23a..c7e2f0e 100644 --- a/phygnn/layers/custom_layers.py +++ b/phygnn/layers/custom_layers.py @@ -327,7 +327,7 @@ class SpatioTemporalExpansion(tf.keras.layers.Layer): """ def __init__(self, spatial_mult=1, temporal_mult=1, - temporal_method='nearest'): + temporal_method='nearest', t_roll=0): """ Parameters ---------- @@ -343,12 +343,23 @@ def __init__(self, spatial_mult=1, temporal_mult=1, if the input layer has shape (123, 5, 5, 24, 2) with multiplier=2 the output shape will be (123, 5, 5, 48, 2). temporal_method : str - Interpolation method for tf.image.resize(). + Interpolation method for tf.image.resize(). Can also be + "depth_to_time" for an operation similar to tf.nn.depth_to_space + where the feature axis is unpacked into the temporal axis. + t_roll : int + Option to roll the temporal axis after expanding. When using + temporal_method="depth_to_time", the default (t_roll=0) will + add temporal steps after the input steps such that if input + temporal shape is 3 and the temporal_mult is 24x, the output will + have the original timesteps at idt=0,24,48 but if t_roll=12, the + output will have the original timesteps at idt=12,36,60 """ + super().__init__() self._spatial_mult = int(spatial_mult) self._temporal_mult = int(temporal_mult) self._temporal_meth = temporal_method + self._t_roll = t_roll @staticmethod def _check_shape(input_shape): @@ -377,14 +388,35 @@ def build(self, input_shape): def _temporal_expand(self, x): """Expand the temporal dimension (axis=3) of a 5D tensor""" - temp_expand_shape = tf.stack( - [x.shape[2], x.shape[3] * self._temporal_mult]) - out = [] - for x_unstack in tf.unstack(x, axis=1): - out.append(tf.image.resize(x_unstack, temp_expand_shape, - method=self._temporal_meth)) - return tf.stack(out, axis=1) + if self._temporal_meth == 'depth_to_time': + check_shape = x.shape[-1] % self._temporal_mult + if check_shape != 0: + msg = ('Temporal expansion of factor {} is being attempted on ' + 'input tensor of shape {}, but the last dimension of ' + 'the input tensor ({}) must be divisible by the ' + 'temporal factor ({}).' + .format(self._temporal_mult, x.shape, x.shape[-1], + self._temporal_mult)) + logger.error(msg) + raise RuntimeError(msg) + + shape = (x.shape[0], x.shape[1], x.shape[2], + x.shape[3] * self._temporal_mult, + x.shape[4] // self._temporal_mult) + out = tf.reshape(x, shape) + out = tf.roll(out, self._t_roll, axis=3) + + else: + temp_expand_shape = tf.stack([x.shape[2], + x.shape[3] * self._temporal_mult]) + out = [] + for x_unstack in tf.unstack(x, axis=1): + out.append(tf.image.resize(x_unstack, temp_expand_shape, + method=self._temporal_meth)) + out = tf.stack(out, axis=1) + + return out def _spatial_expand(self, x): """Expand the two spatial dimensions (axis=1,2) of a 5D tensor using diff --git a/phygnn/version.py b/phygnn/version.py index 7d29f8d..a34ad93 100644 --- a/phygnn/version.py +++ b/phygnn/version.py @@ -1,4 +1,4 @@ # -*- coding: utf-8 -*- """Physics Guided Neural Network version.""" -__version__ = '0.0.22' +__version__ = '0.0.23' diff --git a/tests/test_layers.py b/tests/test_layers.py index c7244ef..64a7a31 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -187,6 +187,45 @@ def test_st_expansion(t_mult, s_mult): assert y.shape[4] == x.shape[4] / (s_mult**2) +@pytest.mark.parametrize( + ('t_mult', 's_mult', 't_roll'), + ((2, 1, 0), + (2, 1, 1), + (1, 2, 0), + (2, 2, 0), + (2, 2, 1), + (5, 3, 0), + (5, 1, 0), + (5, 1, 2), + (5, 1, 3), + (5, 2, 3), + (24, 1, 12))) +def test_temporal_depth_to_time(t_mult, s_mult, t_roll): + """Test the spatiotemporal expansion layer.""" + layer = SpatioTemporalExpansion(spatial_mult=s_mult, temporal_mult=t_mult, + temporal_method='depth_to_time', + t_roll=t_roll) + n_filters = 2 * s_mult**2 * t_mult + shape = (1, 4, 4, 3, n_filters) + n = np.product(shape) + x = np.arange(n).reshape((shape)) + y = layer(x) + assert y.shape[0] == x.shape[0] + assert y.shape[1] == s_mult * x.shape[1] + assert y.shape[2] == s_mult * x.shape[2] + assert y.shape[3] == t_mult * x.shape[3] + assert y.shape[4] == x.shape[4] / (t_mult * s_mult**2) + if s_mult == 1: + for idy in range(y.shape[3]): + idx = np.maximum(0, idy - t_roll) // t_mult + even = ((idy - t_roll) % t_mult) == 0 + x1, y1 = x[0, :, :, idx, 0], y[0, :, :, idy, 0] + if even: + assert np.allclose(x1, y1) + else: + assert not np.allclose(x1, y1) + + def test_st_expansion_new_shape(): """Test that the spatiotemporal expansion layer can expand multiple shapes and is not bound to the shape it was built on (bug found on 3/16/2022.)"""