Commit

shorten
Dobiasd committed Dec 31, 2023
1 parent db62540 commit 702cb60
Showing 1 changed file with 10 additions and 20 deletions.
30 changes: 10 additions & 20 deletions include/fdeep/layers/multi_head_attention_layer.hpp
@@ -36,13 +36,9 @@ class multi_head_attention_layer : public layer
         const std::size_t index, const std::string& name)
     {
         assertion(index <= 2, "Invalid dense layer index.");
-
         const std::size_t index_factor = use_bias ? 2 : 1;
-
         tensor weights = weights_and_biases[index_factor * index];
-
         const std::size_t units = weights.shape().depth_;
-
         tensor biases = use_bias ?
             weights_and_biases[index_factor * index + 1] :
             tensor(tensor_shape(num_heads, units), 0);
@@ -51,32 +47,26 @@ class multi_head_attention_layer : public layer
         const auto biases_per_head = tensor_to_tensors_width_slices(biases);
         assertion(weights_per_head.size() == num_heads, "Invalid weights for number of heads.");
         assertion(biases_per_head.size() == num_heads, "Invalid biases for number of heads.");
-        const std::vector<dense_layer> dense_layers =
-            fplus::transform(
-                [&](const std::pair<std::size_t, std::pair<tensor, tensor>>& n_and_w_with_b)
-                {
-                    return dense_layer(
-                        name + "_" + std::to_string(n_and_w_with_b.first),
-                        units,
-                        *n_and_w_with_b.second.first.as_vector(),
-                        *n_and_w_with_b.second.second.as_vector());
-                },
-                fplus::enumerate(fplus::zip(weights_per_head, biases_per_head)));
-        return dense_layers;
+        return fplus::transform(
+            [&](const std::pair<std::size_t, std::pair<tensor, tensor>>& n_and_w_with_b)
+            {
+                return dense_layer(
+                    name + "_" + std::to_string(n_and_w_with_b.first),
+                    units,
+                    *n_and_w_with_b.second.first.as_vector(),
+                    *n_and_w_with_b.second.second.as_vector());
+            },
+            fplus::enumerate(fplus::zip(weights_per_head, biases_per_head)));
     }
     dense_layer create_output_dense_layer(
         const tensors& weights_and_biases, bool use_bias, const std::string& name)
     {
         const std::size_t index_factor = use_bias ? 2 : 1;
-
         tensor weights = weights_and_biases[index_factor * 3];
-
         const std::size_t units = weights.shape().depth_;
-
         tensor biases = use_bias ?
             weights_and_biases[index_factor * 3 + 1] :
             tensor(tensor_shape(units), 0);
-
         return dense_layer(name + "_output", units, *weights.as_vector(), *biases.as_vector());
     }
     tensors extract_biases(const tensors& saved_weights, bool use_bias)
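
For context, the change in create_dense_layers is the usual shortening pattern of returning a transformed range directly instead of binding it to a named local and returning that. Below is a minimal, self-contained sketch of the same pattern, not the library code; it assumes FunctionalPlus is available on the include path, and the names label_pairs, names, and values are purely illustrative.

// Sketch of the "return the transform directly" pattern (hypothetical example,
// assuming FunctionalPlus is installed; names are illustrative only).
#include <fplus/fplus.hpp>
#include <iostream>
#include <string>
#include <vector>

// Combine two parallel vectors into labeled strings, one per (index, pair).
std::vector<std::string> label_pairs(
    const std::vector<std::string>& names,
    const std::vector<int>& values)
{
    // Shortened form: no intermediate variable, the transform result is returned directly.
    return fplus::transform(
        [](const std::pair<std::size_t, std::pair<std::string, int>>& n_and_pair)
        {
            return std::to_string(n_and_pair.first) + "_" +
                n_and_pair.second.first + "=" +
                std::to_string(n_and_pair.second.second);
        },
        fplus::enumerate(fplus::zip(names, values)));
}

int main()
{
    for (const auto& s : label_pairs({"alpha", "beta"}, {1, 2}))
        std::cout << s << "\n";  // prints "0_alpha=1" and "1_beta=2"
}

Behavior is unchanged by this kind of refactor; it only drops the named intermediate and the extra return statement, which is all the commit does besides removing blank lines.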
