diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 7080cbf05b..cff5550c85 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -241,8 +241,13 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) cudaDataType_t compute_type = cublas_data_type; #else - // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance + // For best performance, set the default cublas compute type to + // CUBLAS_COMPUTE_16F for half precision and to + // CUBLAS_COMPUTE_32F_FAST_16F for full precision cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; + if (m->output_type[0] == DT_FLOAT) { + compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + } #endif // Compute (W^T)x matmul: einsum(ijkl,im->jmkl) // Weights: qSize x qProjSize x 3 x num_q_heads @@ -511,8 +516,13 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m, #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) cudaDataType_t compute_type = cublas_data_type; #else - // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance + // For best performance, set the default cublas compute type to + // CUBLAS_COMPUTE_16F for half precision and to + // CUBLAS_COMPUTE_32F_FAST_16F for full precision cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; + if (m->output_type[0] == DT_FLOAT) { + compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + } #endif // int num_requests = bc->num_active_requests(); int num_tokens = bc->num_active_tokens(); diff --git a/src/ops/kernels/linear_kernels.cu b/src/ops/kernels/linear_kernels.cu index d8a9b5aa16..9373c2fb2f 100644 --- a/src/ops/kernels/linear_kernels.cu +++ b/src/ops/kernels/linear_kernels.cu @@ -314,8 +314,13 @@ void forward_kernel(LinearMeta const *m, #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) cudaDataType_t compute_type = cublas_data_type; #else - // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance + // For best performance, set the default cublas compute type to + // CUBLAS_COMPUTE_16F for half precision and to + // CUBLAS_COMPUTE_32F_FAST_16F for full precision cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; + if (m->output_type[0] == DT_FLOAT) { + compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + } #endif checkCUDA(cublasGemmEx(m->handle.blas, CUBLAS_OP_T, @@ -404,8 +409,13 @@ void backward_kernel(LinearMeta const *m, #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) cudaDataType_t compute_type = cublas_data_type; #else - // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance + // For best performance, set the default cublas compute type to + // CUBLAS_COMPUTE_16F for half precision and to + // CUBLAS_COMPUTE_32F_FAST_16F for full precision cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; + if (m->output_type[0] == DT_FLOAT) { + compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + } #endif int output_size = out_dim * batch_size; if (m->activation == AC_MODE_RELU) { diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 681c7a0f72..52e083889e 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -218,8 +218,13 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) cudaDataType_t compute_type = cublas_data_type; #else - // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance + // For best performance, set the default cublas compute type to + // CUBLAS_COMPUTE_16F for half precision and to + // CUBLAS_COMPUTE_32F_FAST_16F for full precision cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; + if (m->output_type[0] == DT_FLOAT) { + compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + } #endif // int num_requests = bc->num_active_requests(); int num_tokens = bc->num_active_tokens(); diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 758a93bbf7..0aa50f605c 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -161,8 +161,13 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) cudaDataType_t compute_type = cublas_data_type; #else - // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance + // For best performance, set the default cublas compute type to + // CUBLAS_COMPUTE_16F for half precision and to + // CUBLAS_COMPUTE_32F_FAST_16F for full precision cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; + if (m->output_type[0] == DT_FLOAT) { + compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + } #endif // int num_requests = bc->num_active_requests(); int processed_tokens_in_batch = 0;