From d5da2f16cc3d4ecfca215bcd735e191701cee879 Mon Sep 17 00:00:00 2001 From: grantbuster Date: Mon, 8 Jul 2024 15:53:19 -0600 Subject: [PATCH] skip layer test with convolutions which is typical in sup3r --- tests/test_layers.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/tests/test_layers.py b/tests/test_layers.py index 5645a24..28897d0 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -101,36 +101,41 @@ def test_repeat_layers(): def test_skip_connection(): """Test a functional skip connection""" hidden_layers = [ - {'units': 64, 'activation': 'relu', 'dropout': 0.01}, + {'class': 'Conv2D', 'filters': 4, 'kernel_size': 3, + 'activation': 'relu', 'padding': 'same'}, {'class': 'SkipConnection', 'name': 'a'}, - {'units': 64, 'activation': 'relu', 'dropout': 0.01}, - {'class': 'SkipConnection', 'name': 'a'}] + {'class': 'Conv2D', 'filters': 4, 'kernel_size': 3, + 'activation': 'relu', 'padding': 'same'}, + {'class': 'SkipConnection', 'name': 'a'}, + {'class': 'Conv2D', 'filters': 4, 'kernel_size': 3, + 'activation': 'relu', 'padding': 'same'}, + ] layers = HiddenLayers(hidden_layers) - assert len(layers.layers) == 8 + assert len(layers.layers) == 5 skip_layers = [x for x in layers.layers if isinstance(x, SkipConnection)] assert len(skip_layers) == 2 assert id(skip_layers[0]) == id(skip_layers[1]) - x = np.ones((5, 3)) + x = np.ones((5, 10, 10, 4)) cache = None x_input = None for i, layer in enumerate(layers): - if i == 3: # skip start + if i == 1: # skip start cache = tf.identity(x) assert id(cache) != id(x) - elif i == 7: # skip end + elif i == 3: # skip end x_input = tf.identity(x) assert id(x_input) != id(x) x = layer(x) - if i == 3: # skip start + if i == 1: # skip start assert layer._cache is not None - elif i == 4 or i == 5 or i == 6: + elif i == 2: assert np.allclose(cache.numpy(), layers[3]._cache.numpy()) - elif i == 7: # skip end + elif i == 3: # skip end assert layer._cache is None tf.assert_equal(x, tf.add(x_input, cache))