Skip to content

Commit

Permalink
Merge branch 'inference' into cuda_graph
Browse files (browse the repository at this point in the history)
  • Loading branch information
goliaro authored Dec 24, 2023
2 parents 64a8058 + 7e7f955 commit 21c00bd
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions src/ops/inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1515,4 +1515,24 @@ template void Kernels::IncMultiHeadAttention::pre_build_weight_kernel<half>(
GenericTensorAccessorR const weight,
DataType data_type,
cudaStream_t stream);
// Explicit instantiation of Kernels::IncMultiHeadAttention::compute_o_prod_bias
// for T = float, so that the float version of the template is emitted in this
// translation unit and can be linked against from other files.
// The parameter list must match the template's declaration exactly; the
// template definition itself is not visible in this part of the file.
// NOTE(review): judging by the name, this presumably computes the attention
// output projection and adds the bias -- confirm against the definition.
template void Kernels::IncMultiHeadAttention::compute_o_prod_bias<float>(
IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
int shard_id,
float *output_ptr,
float const *weight_ptr,
float const *bias_ptr,
int num_tokens,
cudaStream_t stream);
// Explicit instantiation of Kernels::IncMultiHeadAttention::compute_o_prod_bias
// for T = half, so that the half-precision version of the template is emitted
// in this translation unit and can be linked against from other files.
// Takes the same parameters as the float instantiation, with the output,
// weight and bias pointers typed as half.
// NOTE(review): the template definition is not visible in this part of the
// file -- confirm the signature stays in sync with it.
template void Kernels::IncMultiHeadAttention::compute_o_prod_bias<half>(
IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
int shard_id,
half *output_ptr,
half const *weight_ptr,
half const *bias_ptr,
int num_tokens,
cudaStream_t stream);
}; // namespace FlexFlow

0 comments on commit 21c00bd

Please sign in to comment.