Skip to content

Commit

Permalink
added depth to time spatiotemporal expansion option with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster committed Mar 1, 2023
1 parent 4297a7a commit 8fdbbb9
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 10 deletions.
50 changes: 41 additions & 9 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ class SpatioTemporalExpansion(tf.keras.layers.Layer):
"""

def __init__(self, spatial_mult=1, temporal_mult=1,
temporal_method='nearest'):
temporal_method='nearest', t_roll=0):
"""
Parameters
----------
Expand All @@ -343,12 +343,23 @@ def __init__(self, spatial_mult=1, temporal_mult=1,
if the input layer has shape (123, 5, 5, 24, 2) with multiplier=2
the output shape will be (123, 5, 5, 48, 2).
temporal_method : str
Interpolation method for tf.image.resize().
Interpolation method for tf.image.resize(). Can also be
"depth_to_time" for an operation similar to tf.nn.depth_to_space
where the feature axis is unpacked into the temporal axis.
t_roll : int
Option to roll the temporal axis after expanding. When using
temporal_method="depth_to_time", the default (t_roll=0) will
add temporal steps after the input steps such that if input
temporal shape is 3 and the temporal_mult is 24x, the output will
have the original timesteps at idt=0,24,48 but if t_roll=12, the
output will have the original timesteps at idt=12,36,60
"""

super().__init__()
self._spatial_mult = int(spatial_mult)
self._temporal_mult = int(temporal_mult)
self._temporal_meth = temporal_method
self._t_roll = t_roll

@staticmethod
def _check_shape(input_shape):
Expand Down Expand Up @@ -377,14 +388,35 @@ def build(self, input_shape):

def _temporal_expand(self, x):
"""Expand the temporal dimension (axis=3) of a 5D tensor"""
temp_expand_shape = tf.stack(
[x.shape[2], x.shape[3] * self._temporal_mult])
out = []
for x_unstack in tf.unstack(x, axis=1):
out.append(tf.image.resize(x_unstack, temp_expand_shape,
method=self._temporal_meth))

return tf.stack(out, axis=1)
if self._temporal_meth == 'depth_to_time':
check_shape = x.shape[-1] % self._temporal_mult
if check_shape != 0:
msg = ('Temporal expansion of factor {} is being attempted on '
'input tensor of shape {}, but the last dimension of '
'the input tensor ({}) must be divisible by the '
'temporal factor ({}).'
.format(self._temporal_mult, x.shape, x.shape[-1],
self._temporal_mult))
logger.error(msg)
raise RuntimeError(msg)

shape = (x.shape[0], x.shape[1], x.shape[2],
x.shape[3] * self._temporal_mult,
x.shape[4] // self._temporal_mult)
out = tf.reshape(x, shape)
out = tf.roll(out, self._t_roll, axis=3)

else:
temp_expand_shape = tf.stack([x.shape[2],
x.shape[3] * self._temporal_mult])
out = []
for x_unstack in tf.unstack(x, axis=1):
out.append(tf.image.resize(x_unstack, temp_expand_shape,
method=self._temporal_meth))
out = tf.stack(out, axis=1)

return out

def _spatial_expand(self, x):
"""Expand the two spatial dimensions (axis=1,2) of a 5D tensor using
Expand Down
2 changes: 1 addition & 1 deletion phygnn/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""Physics Guided Neural Network version."""

__version__ = '0.0.22'
__version__ = '0.0.23'
39 changes: 39 additions & 0 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,45 @@ def test_st_expansion(t_mult, s_mult):
assert y.shape[4] == x.shape[4] / (s_mult**2)


@pytest.mark.parametrize(
('t_mult', 's_mult', 't_roll'),
((2, 1, 0),
(2, 1, 1),
(1, 2, 0),
(2, 2, 0),
(2, 2, 1),
(5, 3, 0),
(5, 1, 0),
(5, 1, 2),
(5, 1, 3),
(5, 2, 3),
(24, 1, 12)))
def test_temporal_depth_to_time(t_mult, s_mult, t_roll):
"""Test the spatiotemporal expansion layer."""
layer = SpatioTemporalExpansion(spatial_mult=s_mult, temporal_mult=t_mult,
temporal_method='depth_to_time',
t_roll=t_roll)
n_filters = 2 * s_mult**2 * t_mult
shape = (1, 4, 4, 3, n_filters)
n = np.product(shape)
x = np.arange(n).reshape((shape))
y = layer(x)
assert y.shape[0] == x.shape[0]
assert y.shape[1] == s_mult * x.shape[1]
assert y.shape[2] == s_mult * x.shape[2]
assert y.shape[3] == t_mult * x.shape[3]
assert y.shape[4] == x.shape[4] / (t_mult * s_mult**2)
if s_mult == 1:
for idy in range(y.shape[3]):
idx = np.maximum(0, idy - t_roll) // t_mult
even = ((idy - t_roll) % t_mult) == 0
x1, y1 = x[0, :, :, idx, 0], y[0, :, :, idy, 0]
if even:
assert np.allclose(x1, y1)
else:
assert not np.allclose(x1, y1)


def test_st_expansion_new_shape():
"""Test that the spatiotemporal expansion layer can expand multiple shapes
and is not bound to the shape it was built on (bug found on 3/16/2022.)"""
Expand Down

0 comments on commit 8fdbbb9

Please sign in to comment.