From db62540cba18fcde87ed5465226d46ea39e2e183 Mon Sep 17 00:00:00 2001 From: Dobiasd Date: Sun, 31 Dec 2023 12:18:04 +0100 Subject: [PATCH] add more tests --- keras_export/generate_test_models.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/keras_export/generate_test_models.py b/keras_export/generate_test_models.py index 018a2124..d8eb14c3 100644 --- a/keras_export/generate_test_models.py +++ b/keras_export/generate_test_models.py @@ -463,6 +463,12 @@ def get_test_model_exhaustive(): outputs.append(MultiHeadAttention( num_heads=1, key_dim=1, value_dim=None, use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51])) + outputs.append(MultiHeadAttention( + num_heads=2, key_dim=3, value_dim=5, + use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51])) + outputs.append(MultiHeadAttention( + num_heads=2, key_dim=3, value_dim=5, + use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51])) shared_conv = Conv2D(1, (1, 1), padding='valid', name='shared_conv', activation='relu')