diff --git a/include/fdeep/layers/multi_head_attention_layer.hpp b/include/fdeep/layers/multi_head_attention_layer.hpp
index ef0167b3..aa530d67 100644
--- a/include/fdeep/layers/multi_head_attention_layer.hpp
+++ b/include/fdeep/layers/multi_head_attention_layer.hpp
@@ -37,7 +37,7 @@ class multi_head_attention_layer : public layer
     {
         const std::size_t index_factor = use_bias ? 2 : 1;
         const tensor weights = weights_and_biases[index_factor * index];
-        const std::size_t n = weights.shape().width_ * weights.shape().depth_;
+        const std::size_t n = weights.shape().depth_;
         const tensor biases = use_bias ?
             weights_and_biases[index_factor * index + 1] :
             tensor(tensor_shape(n), 0);
@@ -54,18 +54,19 @@ class multi_head_attention_layer : public layer
         const tensor query_raw = input[0];
         const tensor value_raw = input[1];
         const tensor key_raw = input.size() > 2 ? input[2] : value_raw;
-        const tensor query = query_dense_.apply({query_raw}).front();
-        const tensor value = value_dense_.apply({value_raw}).front();
-        const tensor key = key_dense_.apply({key_raw}).front();
         assertion(
-            query.shape().rank() == 2 &&
-            value.shape().rank() == 2 &&
-            key.shape().rank() == 2 &&
-            query.shape().depth_ == value.shape().depth_ &&
-            query.shape().depth_ == key.shape().depth_ &&
-            value.shape().width_ == key.shape().width_,
+            query_raw.shape().rank() == 2 &&
+            value_raw.shape().rank() == 2 &&
+            key_raw.shape().rank() == 2 &&
+            query_raw.shape().depth_ == value_raw.shape().depth_ &&
+            query_raw.shape().depth_ == key_raw.shape().depth_ &&
+            value_raw.shape().width_ == key_raw.shape().width_,
             "Invalid shapes; need a query tensor of shape (B, T, dim) and a value/key tensor of shape (B, S, dim)."
         );
+        const tensor query = query_dense_.apply({query_raw}).front();
+        const tensor value = value_dense_.apply({value_raw}).front();
+        const tensor key = key_dense_.apply({key_raw}).front();
+
         // https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853
         // https://dmol.pub/dl/attention.html#multi-head-attention-block
         // https://github.com/keras-team/keras/blob/v2.14.0/keras/layers/attention/multi_head_attention.py
diff --git a/keras_export/generate_test_models.py b/keras_export/generate_test_models.py
index 809168f6..1b6b4c23 100644
--- a/keras_export/generate_test_models.py
+++ b/keras_export/generate_test_models.py
@@ -443,18 +443,24 @@ def get_test_model_exhaustive():
         num_heads=1, key_dim=1, value_dim=None,
         use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
     outputs.append(MultiHeadAttention(
-        num_heads=3, key_dim=1, value_dim=None,
+        num_heads=1, key_dim=2, value_dim=None,
         use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
+    outputs.append(MultiHeadAttention(
+        num_heads=1, key_dim=2, value_dim=None,
+        use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
+    outputs.append(MultiHeadAttention(
+        num_heads=1, key_dim=1, value_dim=2,
+        use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
+    outputs.append(MultiHeadAttention(
+        num_heads=1, key_dim=1, value_dim=2,
+        use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
     # todo: re-enable
-    #outputs.append(MultiHeadAttention(
-    #    num_heads=1, key_dim=2, value_dim=None,
+    # outputs.append(MultiHeadAttention(
+    #     num_heads=3, key_dim=1, value_dim=None,
     #    use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
     #outputs.append(MultiHeadAttention(
-    #    num_heads=1, key_dim=1, value_dim=2,
-    #    use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
-    outputs.append(MultiHeadAttention(
-        num_heads=1, key_dim=1, value_dim=None,
-        use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51]))
+    #    num_heads=1, key_dim=1, value_dim=None,
+    #    use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51]))
 
     shared_conv = Conv2D(1, (1, 1),
                          padding='valid', name='shared_conv', activation='relu')
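
Note (commentary, not part of the diff): the zero-filled stand-in bias created when use_bias=False needs one entry per output unit of the projection, and that count is now read from the kernel tensor's depth_ alone; the old width_ * depth_ product only coincided with it while key_dim == value_dim == 1, which is presumably why the key_dim=2 and value_dim=2 test models had to stay disabled. Moving the assertion matters for the same models: after query_dense_/key_dense_/value_dense_ are applied, query and key have depth key_dim while value has depth value_dim, so the old post-projection check query.shape().depth_ == value.shape().depth_ rejects a valid layer whenever value_dim differs from key_dim. The (B, T, dim)/(B, S, dim) contract named in the error message only holds for the raw inputs.

For reference, here is a standalone sketch (a hypothetical helper script, not part of this PR; shapes as produced by tf.keras 2.14, the version referenced in the source comments) of the weight layout the loader indexes with index_factor * index:

    # Hypothetical helper script: print the Keras MultiHeadAttention
    # weight layout that multi_head_attention_layer.hpp indexes into.
    import numpy as np
    from tensorflow.keras.layers import MultiHeadAttention

    mha = MultiHeadAttention(num_heads=1, key_dim=2, value_dim=None,
                             use_bias=True)
    query = np.zeros((1, 3, 5), dtype=np.float32)  # (B, T, dim)
    value = np.zeros((1, 4, 5), dtype=np.float32)  # (B, S, dim)
    mha(query, value)  # the first call builds the weights

    for w in mha.get_weights():
        print(w.shape)
    # Kernels and biases alternate (hence index_factor = 2 with bias):
    #   (5, 1, 2)  query kernel  (dim, num_heads, key_dim)
    #   (1, 2)     query bias    (num_heads, key_dim)
    #   (5, 1, 2)  key kernel
    #   (1, 2)     key bias
    #   (5, 1, 2)  value kernel  (value_dim=None falls back to key_dim)
    #   (1, 2)     value bias
    #   (1, 2, 5)  output kernel (num_heads, value_dim, dim)
    #   (5,)       output bias
    # With use_bias=False only the kernels remain (index_factor = 1) and
    # the .hpp code zero-fills a bias of length n for each projection.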