Skip to content

Commit

Permalink
skip layer test with convolutions which is typical in sup3r
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster committed Jul 8, 2024
1 parent 12ac219 commit d5da2f1
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,36 +101,41 @@ def test_repeat_layers():
def test_skip_connection():
"""Test a functional skip connection"""
hidden_layers = [
{'units': 64, 'activation': 'relu', 'dropout': 0.01},
{'class': 'Conv2D', 'filters': 4, 'kernel_size': 3,
'activation': 'relu', 'padding': 'same'},
{'class': 'SkipConnection', 'name': 'a'},
{'units': 64, 'activation': 'relu', 'dropout': 0.01},
{'class': 'SkipConnection', 'name': 'a'}]
{'class': 'Conv2D', 'filters': 4, 'kernel_size': 3,
'activation': 'relu', 'padding': 'same'},
{'class': 'SkipConnection', 'name': 'a'},
{'class': 'Conv2D', 'filters': 4, 'kernel_size': 3,
'activation': 'relu', 'padding': 'same'},
]
layers = HiddenLayers(hidden_layers)
assert len(layers.layers) == 8
assert len(layers.layers) == 5

skip_layers = [x for x in layers.layers if isinstance(x, SkipConnection)]
assert len(skip_layers) == 2
assert id(skip_layers[0]) == id(skip_layers[1])

x = np.ones((5, 3))
x = np.ones((5, 10, 10, 4))
cache = None
x_input = None

for i, layer in enumerate(layers):
if i == 3: # skip start
if i == 1: # skip start
cache = tf.identity(x)
assert id(cache) != id(x)
elif i == 7: # skip end
elif i == 3: # skip end
x_input = tf.identity(x)
assert id(x_input) != id(x)

x = layer(x)

if i == 3: # skip start
if i == 1: # skip start
assert layer._cache is not None
elif i == 4 or i == 5 or i == 6:
elif i == 2:
assert np.allclose(cache.numpy(), layers[3]._cache.numpy())
elif i == 7: # skip end
elif i == 3: # skip end
assert layer._cache is None
tf.assert_equal(x, tf.add(x_input, cache))

Expand Down

0 comments on commit d5da2f1

Please sign in to comment.