Merge branch 'inference' into legion_workflow
goliaro authored Oct 22, 2023
2 parents 5559374 + caf5d61 commit 69cb57e
Showing 3 changed files with 134 additions and 138 deletions.
85 changes: 43 additions & 42 deletions src/ops/inc_multihead_self_attention.cu
@@ -650,19 +650,19 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m,
C_softmax));
// Matmul softmax(QK^T/sqrt(d_k)) by V
alpha = 1.0f, beta = 0.0f;
m_ = num_new_tokens;
n = m->vProjSize;
m_ = m->vProjSize;
n = num_new_tokens;
k = total_tokens;
lda = m_, ldb = n * m->num_q_heads, ldc = m_;
strideA = num_new_tokens * total_tokens;
strideB = vt_block_size;
strideC = num_new_tokens * m->vProjSize;
// To get A, skip over softmax(QK^T/sqrt(d_k)) entries from previous
// requests (all heads)
A = C_softmax;
// To get B, skip over V^T entries from previous requests (all heads +
lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads;
strideA = vt_block_size;
strideB = num_new_tokens * total_tokens;
strideC = m->vProjSize;
// To get A, skip over V^T entries from previous requests (all heads +
// padding)
B = static_cast<DT *>(m->valueCache) + i * vt_req_block_size;
A = static_cast<DT *>(m->valueCache) + i * vt_req_block_size;
// To get B, skip over softmax(QK^T/sqrt(d_k)) entries from previous
// requests (all heads)
B = C_softmax;
// To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous
// requests
C = static_cast<DT *>(m->attn_heads) +
@@ -690,40 +690,41 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m,
m->num_q_heads,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
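
The parameter change above alters where each head's softmax(QK^T/sqrt(d_k))V result lands in m->attn_heads: with ldc = m_ * m->num_q_heads and strideC = m->vProjSize, the heads of a given token are written next to each other, so the buffer becomes token-major across the whole batch. The following is a minimal host-side reference loop illustrating that layout; it is not part of the commit, and the sizes and the softmax_scores/value_cache helpers are made up for illustration.

    #include <vector>

    // Hypothetical sizes, chosen only for the sketch.
    constexpr int kHeads = 2;       // m->num_q_heads
    constexpr int kNewTokens = 3;   // num_new_tokens
    constexpr int kTotalTokens = 4; // total_tokens
    constexpr int kVProj = 5;       // m->vProjSize

    // Stand-ins for C_softmax and m->valueCache (illustrative values only).
    float softmax_scores(int head, int query, int key) {
      return 1.0f / kTotalTokens;
    }
    float value_cache(int head, int key, int dim) {
      return float(head + key + dim);
    }

    int main() {
      // Layout implied by ldc = kVProj * kHeads, strideC = kVProj:
      // one contiguous [kVProj * kHeads] slice per query token, heads interleaved.
      std::vector<float> attn_heads(kNewTokens * kVProj * kHeads, 0.0f);
      for (int h = 0; h < kHeads; ++h) {
        for (int q = 0; q < kNewTokens; ++q) {
          for (int d = 0; d < kVProj; ++d) {
            float acc = 0.0f;
            for (int t = 0; t < kTotalTokens; ++t) {
              acc += softmax_scores(h, q, t) * value_cache(h, t, d);
            }
            attn_heads[q * kVProj * kHeads + h * kVProj + d] = acc;
          }
        }
      }
      return 0;
    }
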
// Project to output, save result directly on output tensor
alpha = 1.0f, beta = 0.0f;
m_ = m->oProjSize;
k = m->vProjSize * m->num_q_heads;
n = num_new_tokens;
lda = k, ldb = n, ldc = m_;
A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads +
m->kProjSize * m->num_q_heads +
m->vProjSize * m->num_q_heads);
B = C;
C = static_cast<DT *>(output_ptr) + tokens_previous_requests * m->oProjSize;
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_T,
CUBLAS_OP_T,
m_,
n,
k,
&alpha,
A,
cublas_data_type,
lda,
B,
cublas_data_type,
ldb,
&beta,
C,
cublas_data_type,
ldc,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
tokens_previous_requests += num_new_tokens;
}
// Project to output, save result directly on output tensor
DT alpha = 1.0f, beta = 0.0f;
int m_ = m->oProjSize;
int k = m->vProjSize * m->num_q_heads;
int n = bc->num_active_tokens();
int lda = k, ldb = k, ldc = m_;
DT const *A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads +
m->kProjSize * m->num_q_heads +
m->vProjSize * m->num_q_heads);
DT const *B = static_cast<DT *>(m->attn_heads);
DT *C = static_cast<DT *>(output_ptr);
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_T,
CUBLAS_OP_N,
m_,
n,
k,
&alpha,
A,
cublas_data_type,
lda,
B,
cublas_data_type,
ldb,
&beta,
C,
cublas_data_type,
ldc,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
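
Because m->attn_heads is now laid out token-major for the entire batch, the output projection can run once over bc->num_active_tokens() tokens instead of once per request inside the loop. Below is a plain reference loop written only to illustrate the index mapping implied by the GEMM parameters (lda = ldb = k, ldc = m_, A transposed); the sizes and buffers are hypothetical, and this is a sketch rather than the code path used at runtime.

    #include <vector>

    // Hypothetical sizes, for the sketch only.
    constexpr int kOProj = 4;        // m->oProjSize
    constexpr int kHidden = 6;       // m->vProjSize * m->num_q_heads
    constexpr int kActiveTokens = 5; // bc->num_active_tokens()

    int main() {
      // Output-projection weights, column-major with leading dimension kHidden
      // (hence CUBLAS_OP_T in the call above). Values are placeholders.
      std::vector<float> o_weight(kHidden * kOProj, 0.01f);
      // attn_heads: kHidden contiguous values per token (ldb = kHidden).
      std::vector<float> attn_heads(kHidden * kActiveTokens, 1.0f);
      std::vector<float> output(kOProj * kActiveTokens, 0.0f);

      // Single pass over every active token in the batch; the removed code
      // performed an equivalent GEMM separately for each request.
      for (int tok = 0; tok < kActiveTokens; ++tok) {
        for (int o = 0; o < kOProj; ++o) {
          float acc = 0.0f;
          for (int j = 0; j < kHidden; ++j) {
            acc += o_weight[o * kHidden + j] * attn_heads[tok * kHidden + j];
          }
          output[tok * kOProj + o] = acc; // ldc = kOProj
        }
      }
      return 0;
    }
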
if (*m->final_bias && shard_id == 0) {
int parallelism = m->oProjSize * num_tokens;
int qkv_weight_size = m->qProjSize * m->global_num_q_heads +
100 changes: 49 additions & 51 deletions src/ops/spec_inc_multihead_self_attention.cu
@@ -223,7 +223,7 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
#endif
// int num_requests = bc->num_active_requests();
int num_tokens = bc->num_active_tokens();
int tokens_previous_requests = 0;
// int tokens_previous_requests = 0;
int tokens_prev_requests_squares = 0;
// int qkv_block_size =
// (m->qProjSize + m->kProjSize + m->vProjSize) * num_tokens;
@@ -241,10 +241,7 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
if (bc->request_completed[i]) {
continue;
}
assert(tokens_previous_requests ==
bc->requestsInfo[i].first_token_offset_in_batch);
for (int sub_req_id = 0; sub_req_id < bc->sub_requests[i]; sub_req_id++) {
// int num_new_tokens = bc->num_processing_tokens[i];
// int total_tokens = bc->token_last_available_idx[i] + 1;
@@ -273,8 +270,8 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
}
// To get A, skip over Q entries from previous requests (same head)
DT const *A = static_cast<DT *>(m->devQKVProjArray) +
tokens_previous_requests * m->qProjSize * m->num_q_heads *
QKV_WEIGHT_NUM;
bc->requestsInfo[i].first_token_offset_in_batch *
m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM;
// To get B, skip over K entries from previous requests (all heads +
// padding)
DT const *B = static_cast<DT *>(m->keyCache) +
@@ -380,24 +377,25 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
C_softmax));
// Matmul softmax(QK^T/sqrt(d_k)) by V
alpha = 1.0f, beta = 0.0f;
m_ = num_new_tokens;
n = m->vProjSize;
m_ = m->vProjSize;
n = num_new_tokens;
k = total_tokens;
lda = m_, ldb = n * m->num_q_heads, ldc = m_;
strideA = num_new_tokens * total_tokens;
strideB = vt_block_size;
strideC = num_new_tokens * m->vProjSize;
// To get A, skip over softmax(QK^T/sqrt(d_k)) entries from previous
// requests (all heads)
A = C_softmax;
// To get B, skip over V^T entries from previous requests (all heads +
lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads;
strideA = vt_block_size;
strideB = num_new_tokens * total_tokens;
strideC = m->vProjSize;
// To get A, skip over V^T entries from previous requests (all heads +
// padding)
B = static_cast<DT *>(m->valueCache) +
A = static_cast<DT *>(m->valueCache) +
(i * bc->MAX_BEAM_WIDTH + sub_req_id) * vt_req_block_size;
// To get B, skip over softmax(QK^T/sqrt(d_k)) entries from previous
// requests (all heads)
B = C_softmax;
// To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous
// requests
C = static_cast<DT *>(m->attn_heads) +
tokens_previous_requests * m->num_q_heads * m->vProjSize;
bc->requestsInfo[i].first_token_offset_in_batch * m->num_q_heads *
m->vProjSize;
checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas,
CUBLAS_OP_N,
CUBLAS_OP_T,
@@ -422,42 +420,42 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Project to output, save result directly on output tensor
alpha = 1.0f, beta = 0.0f;
m_ = m->oProjSize;
k = m->vProjSize * m->num_q_heads;
n = num_new_tokens;
lda = k, ldb = n, ldc = m_;
A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads +
m->kProjSize * m->num_q_heads +
m->vProjSize * m->num_q_heads);
B = C;
C = static_cast<DT *>(output_ptr) +
tokens_previous_requests * m->oProjSize;
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_T,
CUBLAS_OP_T,
m_,
n,
k,
&alpha,
A,
cublas_data_type,
lda,
B,
cublas_data_type,
ldb,
&beta,
C,
cublas_data_type,
ldc,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
tokens_previous_requests += num_new_tokens;
// tokens_previous_requests += num_new_tokens;
tokens_prev_requests_squares += num_new_tokens * total_tokens;
}
}
// Project to output, save result directly on output tensor
DT alpha = 1.0f, beta = 0.0f;
int m_ = m->oProjSize;
int k = m->vProjSize * m->num_q_heads;
int n = bc->num_active_tokens();
int lda = k, ldb = k, ldc = m_;
DT const *A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads +
m->kProjSize * m->num_q_heads +
m->vProjSize * m->num_q_heads);
DT const *B = static_cast<DT *>(m->attn_heads);
DT *C = static_cast<DT *>(output_ptr);
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_T,
CUBLAS_OP_N,
m_,
n,
k,
&alpha,
A,
cublas_data_type,
lda,
B,
cublas_data_type,
ldb,
&beta,
C,
cublas_data_type,
ldc,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
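
In the speculative-decoding kernel the same hoisting is applied, and the per-request pointer offsets now come straight from bc->requestsInfo[i].first_token_offset_in_batch rather than from the running tokens_previous_requests counter that the old code accumulated (and asserted against). A small sketch of that offset arithmetic follows; RequestInfoSketch is a hypothetical stand-in for the real BatchConfig fields, included only to show the computation.

    // Hypothetical mirror of the one field used above; not the real BatchConfig.
    struct RequestInfoSketch {
      int first_token_offset_in_batch;
    };

    // Query pointer for request i, as in the new code: the offset is read from
    // the batch config instead of being accumulated across loop iterations.
    template <typename DT>
    DT const *query_ptr_for_request(DT const *devQKVProjArray,
                                    RequestInfoSketch const *requestsInfo,
                                    int i,
                                    int qProjSize,
                                    int num_q_heads,
                                    int qkv_weight_num /* QKV_WEIGHT_NUM */) {
      return devQKVProjArray +
             requestsInfo[i].first_token_offset_in_batch *
                 static_cast<size_t>(qProjSize) * num_q_heads * qkv_weight_num;
    }
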
if (*m->final_bias && shard_id == 0) {
int parallelism = m->oProjSize * num_tokens;
int qkv_weight_size = m->qProjSize * m->global_num_q_heads +
87 changes: 42 additions & 45 deletions src/ops/tree_inc_multihead_self_attention.cu
@@ -338,24 +338,23 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m,
C_softmax));
// Matmul softmax(QK^T/sqrt(d_k)) by V
alpha = 1.0f, beta = 0.0f;
m_ = num_new_tokens;
n = m->vProjSize;
m_ = m->vProjSize;
n = num_new_tokens;
k = total_tokens_in_request;
lda = m_, ldb = n * m->num_q_heads, ldc = m_;
strideA = num_new_tokens * total_tokens_in_request;
strideB = vt_block_size;
strideC = num_new_tokens * m->vProjSize;
// To get A, skip over softmax(QK^T/sqrt(d_k)) entries from previous
// requests (all heads)
A = C_softmax;
// To get B, skip over V^T entries from previous requests (all heads +
lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads;
strideA = vt_block_size;
strideB = num_new_tokens * total_tokens_in_request;
strideC = m->vProjSize;
// To get A, skip over V^T entries from previous requests (all heads +
// padding)
B = static_cast<DT *>(m->valueCache) + i * vt_req_block_size;
A = static_cast<DT *>(m->valueCache) + i * vt_req_block_size;
// To get B, skip over softmax(QK^T/sqrt(d_k)) entries from previous
// requests (all heads)
B = C_softmax;
// To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous
// requests
C = static_cast<DT *>(m->attn_heads) +
processed_tokens_in_batch * m->num_q_heads * m->vProjSize;
checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas,
CUBLAS_OP_N,
CUBLAS_OP_T,
@@ -379,45 +378,43 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m,
m->num_q_heads,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Project to output, save result directly on output tensor
alpha = 1.0f, beta = 0.0f;
m_ = m->oProjSize;
k = m->vProjSize * m->num_q_heads;
n = num_new_tokens;
lda = k, ldb = n, ldc = m_;
A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads +
m->kProjSize * m->num_q_heads +
m->vProjSize * m->num_q_heads);
B = C;
C = static_cast<DT *>(output_ptr) +
processed_tokens_in_batch * m->oProjSize;
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_T,
CUBLAS_OP_T,
m_,
n,
k,
&alpha,
A,
cublas_data_type,
lda,
B,
cublas_data_type,
ldb,
&beta,
C,
cublas_data_type,
ldc,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
processed_tokens_in_batch += num_new_tokens;
}
// Before moving to the next request
// check that we have finished all tokens of the request
assert(last_token_idx_of_the_request + 1 == processed_tokens_in_batch);
}
// Project to output, save result directly on output tensor
DT alpha = 1.0f, beta = 0.0f;
int m_ = m->oProjSize;
int k = m->vProjSize * m->num_q_heads;
int n = processed_tokens_in_batch;
int lda = k, ldb = k, ldc = m_;
DT const *A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads +
m->kProjSize * m->num_q_heads +
m->vProjSize * m->num_q_heads);
DT const *B = static_cast<DT *>(m->attn_heads);
DT *C = static_cast<DT *>(output_ptr);
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_T,
CUBLAS_OP_N,
m_,
n,
k,
&alpha,
A,
cublas_data_type,
lda,
B,
cublas_data_type,
ldb,
&beta,
C,
cublas_data_type,
ldc,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
if (*m->final_bias && shard_id == 0) {
int parallelism = m->oProjSize * processed_tokens_in_batch;
int qkv_weight_size = m->qProjSize * m->global_num_q_heads +
