diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu
index 7da9aa389c..695f4b13b9 100644
--- a/src/ops/inc_multihead_self_attention.cu
+++ b/src/ops/inc_multihead_self_attention.cu
@@ -1515,4 +1515,24 @@ template void Kernels::IncMultiHeadAttention::pre_build_weight_kernel(
     GenericTensorAccessorR const weight,
     DataType data_type,
     cudaStream_t stream);
+// Explicit instantiation of compute_o_prod_bias for float output/weight/bias.
+template void Kernels::IncMultiHeadAttention::compute_o_prod_bias(
+    IncMultiHeadSelfAttentionMeta const *m,
+    BatchConfig const *bc,
+    int shard_id,
+    float *output_ptr,
+    float const *weight_ptr,
+    float const *bias_ptr,
+    int num_tokens,
+    cudaStream_t stream);
+// Explicit instantiation of compute_o_prod_bias for half output/weight/bias.
+template void Kernels::IncMultiHeadAttention::compute_o_prod_bias(
+    IncMultiHeadSelfAttentionMeta const *m,
+    BatchConfig const *bc,
+    int shard_id,
+    half *output_ptr,
+    half const *weight_ptr,
+    half const *bias_ptr,
+    int num_tokens,
+    cudaStream_t stream);
 }; // namespace FlexFlow