From 4957b7c7d4c73a6fca94ea40f140319b50b49e9a Mon Sep 17 00:00:00 2001 From: Xinhao Cheng <99570243+xinhaoc@users.noreply.github.com> Date: Sat, 30 Dec 2023 23:24:37 -0500 Subject: [PATCH] Specinfer - new kernel (#1252) * init * fix speculative * fix speculative * bitmap+tree verify * fix. * fix * multi batch * copy metadata once * fix some corner cases * Replicate load_token tasks so that it can be fused with other compute tasks; this eliminates Replicate and enables a larger fused op * more fix. * clean up * . * load batchconfig * clean * hip * hip --------- Co-authored-by: Zhihao Jia --- include/flexflow/batch_config.h | 29 +- include/flexflow/config.h | 11 + include/flexflow/model.h | 1 + .../ops/spec_inc_multihead_self_attention.h | 1 + .../ops/tree_inc_multihead_self_attention.h | 1 + include/flexflow/request_manager.h | 33 +- inference/models/llama.cc | 4 +- inference/spec_infer/spec_infer.cc | 3 + src/ops/argmax.cc | 2 +- src/ops/beam_topk.cc | 2 +- src/ops/beam_topk.cu | 65 +- src/ops/embedding.cc | 18 +- src/ops/inc_multihead_self_attention.cu | 81 +- src/ops/spec_inc_multihead_self_attention.cc | 12 +- src/ops/spec_inc_multihead_self_attention.cu | 964 +++++++++++------- src/ops/tree_inc_multihead_self_attention.cu | 232 +++-- src/runtime/inference_manager.cc | 56 +- src/runtime/model.cc | 48 +- src/runtime/model.cpp | 48 + src/runtime/model.cu | 25 + src/runtime/request_manager.cc | 639 +++++++++--- src/runtime/request_manager.cpp | 85 ++ src/runtime/request_manager.cu | 86 ++ 23 files changed, 1727 insertions(+), 719 deletions(-) diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index e2903c4d11..13904aaa46 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -56,6 +56,7 @@ class BatchConfig { // across workers static int const MAX_NUM_REQUESTS = 64; static int const MAX_NUM_TOKENS = 1024; + static int const MAX_SPEC_TREE_TOKEN_NUM = 64; // Set by update int num_tokens; @@ -68,6 +69,9 @@ class BatchConfig { int first_token_offset_in_batch; int num_tokens_in_batch; int max_sequence_length; + + // request id in batch config: + int batch_config_request_id; RequestGuid request_guid; }; struct PerTokenInfo { @@ -75,6 +79,24 @@ class BatchConfig { int request_index; TokenId token_id; }; + + struct BitMask { + unsigned long long mask[MAX_SPEC_TREE_TOKEN_NUM] = {0}; + + // how many tokens before the tree, every sub requests need this part of + // cache + int non_tree_cache_size; + + // current tree size + int tree_size; + + int this_layer_size; + + // input length-> prompt/root + int prompt_size; + }; + + BitMask causalMask[MAX_NUM_REQUESTS]; PerRequestInfo requestsInfo[MAX_NUM_REQUESTS]; PerTokenInfo tokensInfo[MAX_NUM_TOKENS]; @@ -126,9 +148,12 @@ class BeamSearchBatchConfig : public BatchConfig { size_t beam_width; size_t target_iterations; - inline static int const MAX_BEAM_WIDTH = 1; + inline static int const MAX_BEAM_WIDTH = 3; inline static int const MAX_BEAM_DEPTH = 8; + // maximum tree branches for a request + inline static int const MAX_SPECULATIVE_TREE_BRANCHES = 3; + int model_id; struct BeamSearchPerRequestInfo { @@ -139,6 +164,7 @@ class BeamSearchBatchConfig : public BatchConfig { BatchConfig::TokenId tokens[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; float probs[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; int parent_id[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; + int sub_request_num; }; struct BeamSearchPerTokenInfo { @@ -147,6 +173,7 @@ class BeamSearchBatchConfig : public BatchConfig { BeamSearchPerRequestInfo 
beamRequestsInfo[MAX_NUM_REQUESTS]; BeamSearchPerTokenInfo beamTokenInfo[MAX_NUM_TOKENS * MAX_BEAM_WIDTH]; + // why is this == MAX_NUM_REQUESTS * MAX_BEAM_WIDTH? int sub_requests[MAX_NUM_REQUESTS * MAX_BEAM_WIDTH]; diff --git a/include/flexflow/config.h b/include/flexflow/config.h index c2af6d707c..e1480264cc 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -16,6 +16,7 @@ #ifndef _FLEXFLOW_CONFIG_H_ #define _FLEXFLOW_CONFIG_H_ #include "ffconst.h" +#include "flexflow/batch_config.h" #include "legion.h" #include #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) @@ -75,6 +76,15 @@ struct FFHandler { #endif void *workSpace; size_t workSpaceSize; + void *batch_config_metadata; + + // request info + token info + topolopgy mask info + size_t batch_config_metadata_size = + sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + + sizeof(BeamSearchBatchConfig::beamTokenInfo) + + sizeof(BeamSearchBatchConfig::beamRequestsInfo) + + sizeof(BatchConfig::causalMask) + + sizeof(TreeVerifyBatchConfig::committed_tokens); void *offload_reserve_space; size_t offload_reserve_space_size; DataType quantization_type; @@ -132,6 +142,7 @@ class FFConfig { size_t workSpaceSize; Legion::Context lg_ctx; Legion::Runtime *lg_hlr; + Legion::IndexSpaceT<1> all_gpu_task_is; // Legion::FieldSpace field_space; bool syntheticInput, profiling, perform_fusion; bool inference_debugging; diff --git a/include/flexflow/model.h b/include/flexflow/model.h index d8402ba622..16df99ab1a 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -240,6 +240,7 @@ enum TaskIDs { // InferenceManager & RequestManager RM_LOAD_TOKENS_TASK_ID, RM_LOAD_POSITION_TASK_ID, + RM_LOAD_BATCH_CONFIG_TASK_ID, RM_PREPARE_NEXT_BATCH_TASK_ID, RM_PREPARE_NEXT_BATCH_INIT_TASK_ID, RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID, diff --git a/include/flexflow/ops/spec_inc_multihead_self_attention.h b/include/flexflow/ops/spec_inc_multihead_self_attention.h index 56bb2bd80d..a306f7985a 100644 --- a/include/flexflow/ops/spec_inc_multihead_self_attention.h +++ b/include/flexflow/ops/spec_inc_multihead_self_attention.h @@ -142,6 +142,7 @@ class SpecIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta { Realm::RegionInstance beam_search_reserve_inst; BeamSearchBatchConfig::BeamSearchPerTokenInfo *beam_token_infos; BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos; + BatchConfig::BitMask *causalMask; }; }; // namespace FlexFlow diff --git a/include/flexflow/ops/tree_inc_multihead_self_attention.h b/include/flexflow/ops/tree_inc_multihead_self_attention.h index 6e2da19ce9..d160da4a72 100644 --- a/include/flexflow/ops/tree_inc_multihead_self_attention.h +++ b/include/flexflow/ops/tree_inc_multihead_self_attention.h @@ -147,6 +147,7 @@ class TreeIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta { int num_active_tokens; Realm::RegionInstance committed_token_reserve_inst; TreeVerifyBatchConfig::CommittedTokensInfo *committed_token_infos; + BatchConfig::BitMask *causalMask; }; }; // namespace FlexFlow diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index baf6844801..1c4b0b2a2f 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -38,10 +38,13 @@ class InferenceManager { Legion::FutureMap inference(FFModel *model, int index, BatchConfigFuture const &bc); void load_input_tokens_from_batch_config(BatchConfigFuture const &bc, - ParallelTensor const input); + ParallelTensor const input, + FFHandler 
*handlers); void load_positions(BatchConfigFuture const &bc, ParallelTensor position_input, int offset); + void load_inference_metadata_batch_config(BatchConfigFuture const &bc, + FFHandler *handlers); public: FFConfig ff_config; @@ -72,9 +75,10 @@ struct Request { struct BeamTree { struct treeLayer { BeamSearchBatchConfig::TokenId - tokens[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; + tokens[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES]; int parent_ids[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; - float probs[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; + float probs[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES]; + int nodes_num_this_layer = 0; }; treeLayer treeLayers[BeamSearchBatchConfig::MAX_BEAM_DEPTH + 1]; }; @@ -100,6 +104,7 @@ class RequestManager { void set_max_tokens_per_batch(int max_num_tokens); int get_max_tokens_per_batch(); void set_max_sequence_length(int max_seq_length); + void push_spec_infer_tree_width(int tree_width); int get_max_sequence_length(); int register_ssm_model(FFModel *model); void register_tokenizer(ModelType model_type, @@ -107,6 +112,16 @@ class RequestManager { int eos_token_id, std::string const &path); void register_output_filepath(std::string const &); + void initBitMask(BatchConfig::BitMask &bitmask, int initLength); + void appendBitMask(BatchConfig::BitMask &bitmask, + int newNodes, + int preBeamSize, + int old_sub_num, + BeamTree const tree, + int currentDepth); + void updateBitMask(BatchConfig::BitMask &bitmask, + int initLength, + int non_tree_size); FFModel *get_model(int model_id); @@ -148,6 +163,7 @@ class RequestManager { void store_beam_metadata(BeamSearchBatchConfig const &old_bc, BeamInferenceResult const &result); void update_beam_metadata(BeamSearchBatchConfig &new_bc, + BeamSearchBatchConfig const &old_bc, BeamTree &tree, int request_index); @@ -181,6 +197,11 @@ class RequestManager { Legion::Context ctx, Legion::Runtime *runtime); + static void + load_batch_config_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); static BatchConfig prepare_next_batch_task( Legion::Task const *task, std::vector const ®ions, @@ -210,6 +231,9 @@ class RequestManager { int max_requests_per_batch; int max_tokens_per_batch; int max_sequence_length; + + // tree width in each speculative step, if not specified 1 + std::vector spec_infer_tree_width; // private fields std::unique_ptr tokenizer_; bool verbose; @@ -243,7 +267,8 @@ class RequestManager { private: struct ProfileInfo { - int decoding_steps; + int llm_decoding_steps; + int ssm_decoding_steps; double start_time, finish_time; }; std::unordered_map profiling_requests; diff --git a/inference/models/llama.cc b/inference/models/llama.cc index b8fe70526d..10001ee916 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -246,7 +246,9 @@ void LLAMA::create_llama_model(FFModel &ff, if (mode == BEAM_SEARCH_MODE) { Tensor softmax = ff.softmax(dense, -1); // output = ff.beam_top_k(softmax, llama_config.max_beam_width, false); - output = ff.argmax(softmax, /*beam_Search*/ true); + // output = ff.argmax(softmax, /*beam_Search*/ true); + output = ff.beam_top_k(softmax, llama_config.max_beam_width, false); + // output = ff.top_k(softmax, ) } else { // Tensor softmax = ff.softmax(dense, -1); if (generation_config.do_sample) { diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index 8b0eb926d9..b369a13c1d 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -302,6 
+302,9 @@ void FlexFlow::top_level_task(Task const *task, model_metadata.llm_tokenizer_path); rm->register_output_filepath(file_paths.output_file_path); + // first decoding step: 3 results + rm->push_spec_infer_tree_width(3); + // Create LLM model FFModel tree_model(ffconfig, ffconfig.cpu_offload); if (model_metadata.llm_model_type == ModelType::LLAMA) { diff --git a/src/ops/argmax.cc b/src/ops/argmax.cc index f336c843e8..dc7e4ea3b3 100644 --- a/src/ops/argmax.cc +++ b/src/ops/argmax.cc @@ -352,7 +352,6 @@ BeamInferenceResult GenericTensorAccessorW parent = helperGetGenericTensorAccessorWO( DT_INT32, regions[2], task->regions[2], FID_DATA, ctx, runtime); ArgMax::forward_kernel_wrapper(m, input, indices, parent, batch_size); - BeamInferenceResult ir; download_tensor( indices.get_int32_ptr(), ir.token_ids, batch_size); @@ -398,6 +397,7 @@ InferenceResult ArgMax::save_inference_tensors_to_file( m, shard_id, bc, {}, {}, {input, indices}); } + download_tensor( indices.get_int32_ptr(), ir.token_ids, batch_size); return ir; diff --git a/src/ops/beam_topk.cc b/src/ops/beam_topk.cc index 2883428254..18d0ec1587 100644 --- a/src/ops/beam_topk.cc +++ b/src/ops/beam_topk.cc @@ -366,7 +366,7 @@ BeamInferenceResult GenericTensorAccessorW value = helperGetGenericTensorAccessorWO( DT_FLOAT, regions[2], task->regions[2], FID_DATA, ctx, runtime); GenericTensorAccessorW parent = helperGetGenericTensorAccessorWO( - DT_FLOAT, regions[3], task->regions[3], FID_DATA, ctx, runtime); + DT_INT32, regions[3], task->regions[3], FID_DATA, ctx, runtime); Domain input_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); diff --git a/src/ops/beam_topk.cu b/src/ops/beam_topk.cu index 72ab7862a6..a958786be3 100644 --- a/src/ops/beam_topk.cu +++ b/src/ops/beam_topk.cu @@ -556,8 +556,6 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, int beam_size = bc->beamRequestsInfo[i].beam_size; // initial request - log_beam_topk.debug() << "sub_requests: " << i << ", " << sub_requests[i] - << "\n"; assert(sub_requests[i] > 0); // process sub requests for (int j = 0; j < sub_requests[i]; j++) { @@ -565,12 +563,13 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, // beam_slots[i].parent_id[j]; acc_probs[req_index * BeamSearchBatchConfig::MAX_BEAM_WIDTH + j] = bc->beamRequestsInfo[i].probs[j]; - log_beam_topk.debug() - << "probbbb req: " << i - << ", sub req probability : " << bc->beamRequestsInfo[i].probs[j] - << ", sub request id " << j << ", parent id " - << bc->beamRequestsInfo[i].parent_id[j] << ", data inddd" - << req_index * BeamSearchBatchConfig::MAX_BEAM_WIDTH + j << "\n"; + // std::cout << "probbbb req: " << i << ", sub req probability : " + // << bc->beamRequestsInfo[i].probs[j] << ", sub request id " << + // j + // << ", parent id " << bc->beamRequestsInfo[i].parent_id[j] + // << ", data inddd" + // << req_index * BeamSearchBatchConfig::MAX_BEAM_WIDTH + j + // << "\n"; } // process tokens @@ -584,6 +583,7 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, max_heap_size = std::max(max_heap_size, beam_size * sub_requests[i]); max_beam_width = std::max(max_beam_width, beam_size); + req_index += 1; block_start_index += (sub_requests[i] - 1) * num_new_tokens * length; } @@ -613,28 +613,37 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, assert(num_shards >= (size_t)max_heap_size); num_shards = max_heap_size; - checkCUDA(cudaMemcpy(m->parent_ids, - parent_ids, - sizeof(int) * max_total_requests, - cudaMemcpyHostToDevice)); - checkCUDA(cudaMemcpy(m->acc_probs, - acc_probs, 
- sizeof(DT) * max_total_requests, - cudaMemcpyHostToDevice)); - checkCUDA(cudaMemcpy(m->block_start_index, - beam_block_start_index.data(), - sizeof(int) * beam_num_blocks, - cudaMemcpyHostToDevice)); - checkCUDA(cudaMemcpy(m->request_id, - request_id.data(), - sizeof(int) * beam_num_blocks, - cudaMemcpyHostToDevice)); - checkCUDA(cudaMemcpy(m->tokens_per_request, - tokens_per_request.data(), - sizeof(int) * beam_num_blocks, - cudaMemcpyHostToDevice)); + checkCUDA(cudaMemcpyAsync(m->parent_ids, + parent_ids, + sizeof(int) * max_total_requests, + cudaMemcpyHostToDevice, + stream)); + checkCUDA(cudaMemcpyAsync(m->acc_probs, + acc_probs, + sizeof(DT) * max_total_requests, + cudaMemcpyHostToDevice, + stream)); + // trick, set acc_probs to 0; + checkCUDA(cudaMemsetAsync( + m->acc_probs, 1.0, max_total_requests * sizeof(DT), stream)); + checkCUDA(cudaMemcpyAsync(m->block_start_index, + beam_block_start_index.data(), + sizeof(int) * beam_num_blocks, + cudaMemcpyHostToDevice, + stream)); + checkCUDA(cudaMemcpyAsync(m->request_id, + request_id.data(), + sizeof(int) * beam_num_blocks, + cudaMemcpyHostToDevice, + stream)); + checkCUDA(cudaMemcpyAsync(m->tokens_per_request, + tokens_per_request.data(), + sizeof(int) * beam_num_blocks, + cudaMemcpyHostToDevice, + stream)); // int depth = // bc->beamRequestsInfo[bc->tokensInfo[0].request_index].current_depth; + beam_num_blocks = bc->num_active_tokens(); beam_topk_forward_kernel<<>>( input_ptr, shared_memory_size, diff --git a/src/ops/embedding.cc b/src/ops/embedding.cc index 007e799fe0..76236e65ff 100644 --- a/src/ops/embedding.cc +++ b/src/ops/embedding.cc @@ -155,11 +155,8 @@ int Embedding::output_size(ParallelDim output_dims[MAX_TENSOR_DIM]) { output_dims[OUT_CHANNELS].size = this->out_channels; output_dims[OUT_CHANNELS].degree = 1; output_dims[OUT_CHANNELS].parallel_idx = -1; - // Currently do not support parallelizing over the replica dim - output_dims[num_dims - 1].size = 1; - output_dims[num_dims - 1].degree = 1; - output_dims[num_dims - 1].parallel_idx = -1; - output_dims[num_dims - 1].is_replica_dim = true; + // Copy replica dim + output_dims[num_dims - 1] = input->dims[input->num_dims - 1]; return num_dims; } else { int num_dims = input->num_dims; @@ -170,11 +167,8 @@ int Embedding::output_size(ParallelDim output_dims[MAX_TENSOR_DIM]) { output_dims[OUT_CHANNELS].size = this->out_channels; output_dims[OUT_CHANNELS].degree = 1; output_dims[OUT_CHANNELS].parallel_idx = -1; - // Currently do not support parallelizing over the replica dim - output_dims[num_dims - 1].size = 1; - output_dims[num_dims - 1].degree = 1; - output_dims[num_dims - 1].parallel_idx = -1; - output_dims[num_dims - 1].is_replica_dim = true; + // Copy replica dim + output_dims[num_dims - 1] = input->dims[input->num_dims - 1]; return num_dims; } // const int REPLICA = this->output_vocab_size_replica_dim(); @@ -189,13 +183,13 @@ int Embedding::weight_size(ParallelDim weight_dims[MAX_TENSOR_DIM]) { weight_dims[Weight::VOCAB_SIZE].size = this->num_entries; weight_dims[Weight::VOCAB_SIZE].degree = 1; weight_dims[Weight::VOCAB_SIZE].parallel_idx = -1; - for (int i = 2; i < input->num_dims; i++) { + for (int i = 2; i < input->num_dims + 1; i++) { weight_dims[i].size = input->dims[i - 1].degree; weight_dims[i].degree = weight_dims[i].size; weight_dims[i].parallel_idx = input->dims[i - 1].parallel_idx; weight_dims[i].is_replica_dim = true; } - return input->num_dims; + return input->num_dims + 1; } void Embedding::register_output_mappings() { diff --git 
a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 695f4b13b9..da70e23f87 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -82,6 +82,9 @@ __global__ void compute_attention_kernel_generation_kernel( // request idx int const request_idx = blockIdx.y; + int const batch_config_request_id = + request_infos[request_idx].batch_config_request_id; + int const beam_request_idx = is_beam ? request_idx / max_beam_width : request_idx; int const beam_sub_request_idx = is_beam ? request_idx % max_beam_width : 0; @@ -89,8 +92,8 @@ __global__ void compute_attention_kernel_generation_kernel( int const first_step = 0; int const tlength = - request_infos[beam_request_idx].first_token_depth_in_request + - request_infos[beam_request_idx].num_tokens_in_batch; + request_infos[batch_config_request_id].first_token_depth_in_request + + request_infos[batch_config_request_id].num_tokens_in_batch; // shared memory objects extern __shared__ char smem_[]; @@ -103,7 +106,8 @@ __global__ void compute_attention_kernel_generation_kernel( // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - const DT *q_ptr = query + beam_request_idx * hidden_size * QKV_WEIGHT_NUM + + const DT *q_ptr = query + + batch_config_request_id * hidden_size * QKV_WEIGHT_NUM + head_idx * per_head_size; __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; // DT const *q_ptr = @@ -139,7 +143,7 @@ __global__ void compute_attention_kernel_generation_kernel( DT const *k_cache_batch = key_cache + - (beam_request_idx * max_beam_width + beam_sub_request_idx) * + (batch_config_request_id * max_beam_width + beam_sub_request_idx) * max_seq_length * hidden_size + ki; @@ -245,7 +249,7 @@ __global__ void compute_attention_kernel_generation_kernel( // The base pointer for the value in the cache buffer. DT const *v_cache_batch = value_cache + - (beam_request_idx * max_beam_width + beam_sub_request_idx) * + (batch_config_request_id * max_beam_width + beam_sub_request_idx) * max_seq_length * hidden_size + vi; @@ -825,19 +829,6 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta const *m, bias_ptr = static_cast
(m->bias_ptr); } - // todo Xinhao copy how many requests if requests are not continous? - cudaMemcpyAsync(m->token_infos, - &(bc->tokensInfo), - bc->num_active_tokens() * sizeof(BatchConfig::PerTokenInfo), - cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(m->request_infos, - &(bc->requestsInfo), - bc->max_requests_per_batch() * - sizeof(BatchConfig::PerRequestInfo), - cudaMemcpyHostToDevice, - stream); - // phase 1: Implement kernel to compute KQV for input tokens compute_qkv_kernel(m, bc, @@ -1364,8 +1355,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( vProjSize * num_q_heads); size_t key_cache_size = 0, value_cache_size = 0; switch (infer_mode) { - case INC_DECODING_MODE: - case TREE_VERIFY_MODE: { + case INC_DECODING_MODE: { key_cache_size = num_q_heads * kProjSize * BatchConfig::max_requests_per_batch() * BatchConfig::max_sequence_length(); @@ -1374,22 +1364,24 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( BatchConfig::max_sequence_length(); break; } - case BEAM_SEARCH_MODE: { + case BEAM_SEARCH_MODE: + case TREE_VERIFY_MODE: { + // a K-ary tree max node is (k^n - 1) / 2 key_cache_size = num_q_heads * kProjSize * BeamSearchBatchConfig::max_requests_per_batch() * - BatchConfig::max_sequence_length() * - BeamSearchBatchConfig::MAX_BEAM_WIDTH; + (BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); value_cache_size = num_q_heads * vProjSize * BeamSearchBatchConfig::max_requests_per_batch() * - BatchConfig::max_sequence_length() * - BeamSearchBatchConfig::MAX_BEAM_WIDTH; + (BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); break; } default: assert(false && "Unkown inference mode"); } size_t requestinfo_size = BatchConfig::max_requests_per_batch(); - size_t tokeninfo_size = max_tokens_per_batch; + // size_t tokeninfo_size = max_tokens_per_batch; size_t qk_prod_size = max_tokens_per_batch * BatchConfig::max_sequence_length() * num_q_heads; size_t attn_heads_size = max_tokens_per_batch * num_q_heads * vProjSize; @@ -1400,11 +1392,8 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( (qkv_max_proj_size + key_cache_size + value_cache_size + 2 * qk_prod_size + attn_heads_size) * size_of_dt + - tokeninfo_size * sizeof(BatchConfig::PerTokenInfo) + - complex_size * sizeof(cuFloatComplex) + - requestinfo_size * - sizeof(BatchConfig::PerRequestInfo); // more components will - // be added here later + complex_size * sizeof(cuFloatComplex); // more components will + // be added here later if (offload) { // assert that we have enough reserved work space left size_t totalSharedSize = @@ -1447,10 +1436,16 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( valueCache = gpu_mem_allocator.allocate_instance_untyped(value_cache_size * size_of_dt); + token_infos = + static_cast(handler.batch_config_metadata); + request_infos = reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo)); + if (offload) { - token_infos = - gpu_mem_allocator.allocate_reserved( - tokeninfo_size); + // token_infos = + // gpu_mem_allocator.allocate_reserved( + // tokeninfo_size); // offset += sizeof(BatchConfig::PerTokenInfo) * tokeninfo_size; qk_prods = gpu_mem_allocator.allocate_reserved_untyped(qk_prod_size * size_of_dt); @@ -1464,13 +1459,13 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( complex_input = gpu_mem_allocator.allocate_reserved(complex_size); // offset += complex_size * sizeof(cuFloatComplex); - request_infos = - 
gpu_mem_allocator.allocate_reserved( - requestinfo_size); + // request_infos = + // gpu_mem_allocator.allocate_reserved( + // requestinfo_size); } else { - token_infos = - gpu_mem_allocator.allocate_instance( - tokeninfo_size); + // token_infos = + // gpu_mem_allocator.allocate_instance( + // tokeninfo_size); qk_prods = gpu_mem_allocator.allocate_instance_untyped(qk_prod_size * size_of_dt); qk_prods_softmax = gpu_mem_allocator.allocate_instance_untyped( @@ -1479,9 +1474,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( size_of_dt); complex_input = gpu_mem_allocator.allocate_instance(complex_size); - request_infos = - gpu_mem_allocator.allocate_instance( - requestinfo_size); + // request_infos = + // gpu_mem_allocator.allocate_instance( + // requestinfo_size); } // allocate more size for quantization data diff --git a/src/ops/spec_inc_multihead_self_attention.cc b/src/ops/spec_inc_multihead_self_attention.cc index eb6fd721e6..5d234df822 100644 --- a/src/ops/spec_inc_multihead_self_attention.cc +++ b/src/ops/spec_inc_multihead_self_attention.cc @@ -53,7 +53,7 @@ bool SpecIncMultiHeadSelfAttentionParams::is_valid( } Tensor - FFModel::spec_inc_multihead_self_attention(const Tensor input, + FFModel::spec_inc_multihead_self_attention(Tensor const input, int embed_dim, int num_heads, int kdim, @@ -91,7 +91,7 @@ Tensor } Tensor - FFModel::spec_inc_multiquery_self_attention(const Tensor input, + FFModel::spec_inc_multiquery_self_attention(Tensor const input, int embed_dim, int num_q_heads, int num_kv_heads, @@ -257,7 +257,7 @@ Op *SpecIncMultiHeadSelfAttention::create_operator_from_layer( SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( FFModel &model, LayerID const &_layer_guid, - const ParallelTensor _input, + ParallelTensor const _input, int _embed_dim, int _num_q_heads, int _num_kv_heads, @@ -358,8 +358,8 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( FFModel &model, - const ParallelTensor _input, - const ParallelTensor _weight, + ParallelTensor const _input, + ParallelTensor const _weight, int _embed_dim, int _num_q_heads, int _num_kv_heads, @@ -465,7 +465,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( FFModel &model, SpecIncMultiHeadSelfAttention const &other, - const ParallelTensor input, + ParallelTensor const input, bool allocate_weights) : SpecIncMultiHeadSelfAttention(model, other.layer_guid, diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 562dee4d93..88dd3f92e4 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -23,16 +23,286 @@ namespace FlexFlow { +#define WARP_SIZE 32 + // declare Legion names using Legion::coord_t; using Legion::Memory; using namespace Kernels::IncMultiHeadAttention; namespace Kernels { -namespace SpecIncMultiHeadAttention { +namespace SpecIncMultiHeadSelfAttention { + +template +__global__ void compute_spec_inc_attention_kernel_generation_kernel( + DT const *query, + DT const *key_cache, + DT const *value_cache, + DT *output_ptr, + float const scale, + int const max_seq_length, + int per_head_size, + int hidden_size, + BatchConfig::PerRequestInfo *request_infos, + BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos, + BatchConfig::BitMask *causalMask) { + + // q, k + using Q_vec = typename VEC_K::Type; + using K_vec = typename VEC_K::Type; 
+ using V_vec = typename VEC_V<DT>
::Type; + using Out_sum = typename Vec_fp32_::Type; + + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(DT); + constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY; + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + // constexpr int QK_ELTS_IN_16B = 16 / sizeof(DT); + + // thread id + int const tidx = threadIdx.x; + // head id + int const head_idx = blockIdx.x; + // nth request idx + int const request_idx = blockIdx.y; + + // request id in batch config + int const batch_config_request_id = + request_infos[request_idx].batch_config_request_id; + + // request_idx = re + + BatchConfig::BitMask bitmask = causalMask[batch_config_request_id]; + + int const first_step = 0; + + // int const tlength = + // request_infos[batch_config_request_id].first_token_depth_in_request + + // request_infos[batch_config_request_id].num_tokens_in_batch; + + int const totalCacheSize = bitmask.non_tree_cache_size + bitmask.tree_size; + + int first_token_idx = 0; + for (int r = 0; r < request_idx; r++) { + first_token_idx += causalMask[r].this_layer_size; + } + + int const tree_branch_num = + beam_request_infos[batch_config_request_id].sub_request_num; + + // shared memory objects + extern __shared__ char smem_[]; + + float *qk_smem = reinterpret_cast(smem_); + float *out_smem = reinterpret_cast(smem_); + + float qk_max = -FLT_MAX; + + // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + + const DT *q_ptr = query + first_token_idx * hidden_size * QKV_WEIGHT_NUM + + head_idx * per_head_size; + __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; + + // the start offset of the element eg. (0, 1, 2, 3) * K_VEC_SIZE + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + int ki_o = tidx % THREADS_PER_KEY; + // the first key's offset for this thread + // ko = 0, 0, 0, 0, 1, 1, 1, 1, .... + int ko = tidx / THREADS_PER_KEY; + // load q tensor + Q_vec q_vec[K_VECS_PER_THREAD]; + + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // The number of keys per warp. 
+ constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + DT const *k_cache_batch = + key_cache + batch_config_request_id * max_seq_length * hidden_size + ki; + + int ti_end = + div_up(totalCacheSize - first_step, K_PER_WARP) * K_PER_WARP + first_step; + + for (int qi = 0; qi < tree_branch_num; qi += 1) { +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + q_vecs[ki_o][ii] = *reinterpret_cast( + q_ptr + (hidden_size * QKV_WEIGHT_NUM * qi) + ki + + ii * THREADS_PER_KEY * K_VEC_SIZE); + } + + int const query_token = bitmask.tree_size - tree_branch_num + qi; + + __syncthreads(); + for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { + K_vec k[K_VECS_PER_THREAD]; + int const ti_circ = ti % max_seq_length; + + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * THREADS_PER_KEY * K_VEC_SIZE; + if (ti < totalCacheSize) { + + k[ii] = *reinterpret_cast( + k_cache_batch + ti_circ * hidden_size + head_idx * per_head_size + + jj); + } + } + float qk = scale * Qk_dot::dot(q_vecs[ki_o], k); + + if (ti < totalCacheSize && tidx % THREADS_PER_KEY == 0) { + // todo add alobi here + // bool const mask = ti_circ >= totalCacheSize; + bool const mask = (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << query_token)))); + + // if (blockIdx.y == 0 && blockIdx.x == 0 && !mask) { + // printf("spec inc attn qkqkqk %d, %.10f, %d\n", ti, qk, qi); + // } + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + qk_smem[ti - first_step] = mask ? 0.f : qk; + } + } + + __syncthreads(); + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + int const warp = tidx / WARP_SIZE; + int const lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if (lane == 0) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. + __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + // printf("spec inc attn first token qk_max %.10f\n", qk_max); + // } + + float exp_sum = 0.f; + for (int ti = first_step + tidx; ti < totalCacheSize; + ti += THREADS_PER_BLOCK) { + bool const mask = (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << query_token)))); + float logit = mask ? 0.0f : __expf(qk_smem[ti - first_step] - qk_max); + exp_sum += logit; + qk_smem[ti - first_step] = mask ? 0.0f : logit; + } + + // Compute the sum. + exp_sum = block_sum(&red_smem[WARPS_PER_BLOCK], exp_sum); + + // softmax + float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); + for (int ti = first_step + tidx; ti < totalCacheSize; + ti += THREADS_PER_BLOCK) { + qk_smem[ti - first_step] *= inv_sum; + } + + __syncthreads(); + + // value projection + constexpr int V_VEC_SIZE = 16 / sizeof(DT); + // A vector of V elements for the current timestep. + // using V_vec_k = typename V_vec_k_::Type; + // using V_vec_acum = typename V_vec_acum_fp32_::Type; + + // The value computed by this thread. 
+ int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + Out_sum out; + zero(out); + + // The base pointer for the value in the cache buffer. + DT const *v_cache_batch = + value_cache + batch_config_request_id * max_seq_length * hidden_size + + vi; + + if (Dh == Dh_MAX || vi < Dh) { + for (int ti = first_step + vo; ti < totalCacheSize; ti += V_PER_ITER) { + // Load the values from the cache. + int const ti_circ = ti % max_seq_length; + V_vec v = *reinterpret_cast( + v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); + + bool const mask = (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << query_token)))); + float logit = mask ? 0.0f : qk_smem[ti - first_step]; + out = FlexFlow::fma(logit, cast_to_float(v), out); + } + } + + // // Make sure we can start writing to shared memory. + __syncthreads(); + + // Run the final reduction amongst the different groups computing different + // partial outputs. + if (Dh == Dh_MAX || vi < Dh) { +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; + active_groups /= 2) { + + // The midpoint in the number of active groups. + int midpoint = active_groups / 2; + + // The upper part of active threads store to shared memory. + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { + *reinterpret_cast(out_smem + (vo - midpoint) * Dh + vi) = + out; + } + __syncthreads(); + + // The bottom warps update their values. + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = add(*reinterpret_cast(out_smem + vo * Dh + vi), + out); + } + __syncthreads(); + } + } + + // Output the final values. 
+ if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { + convert_from_float(*reinterpret_cast( + output_ptr + (first_token_idx + qi) * hidden_size + + head_idx * per_head_size + vi), + out); + } + } +} template -__global__ void spec_store_kv_cache( +__global__ void spec_inc_store_kv_cache( DT const *devQKVProjArray, DT *kCache_ptr, DT *vCache_ptr, @@ -40,16 +310,16 @@ __global__ void spec_store_kv_cache( BatchConfig::PerRequestInfo *requestInfo, BeamSearchBatchConfig::BeamSearchPerTokenInfo *beamTokenInfos, BeamSearchBatchConfig::BeamSearchPerRequestInfo *beamRequestInfos, + BatchConfig::BitMask *causalMask, int qProjSize, int kProjSize, int vProjSize, int num_tokens, int max_seq_len, - int max_beam_width, bool is_root, int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size * 2) { - int token_idx = i / (hidden_size * KV_WEIGHT_NUM); + CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + int token_idx = i / (hidden_size); int offset = i % hidden_size; size_t val_idx = @@ -58,100 +328,30 @@ __global__ void spec_store_kv_cache( DT kVal = devQKVProjArray[val_idx]; DT vVal = devQKVProjArray[val_idx + hidden_size]; - // above no need to be changed - // int const req_id = id_map[token_idx].request_index; - // int const tok_id = id_map[token_idx].token_position; - // int const sub_req_id = id_map[token_idx].sub_request_index; - // int const parent_id = id_map[token_idx].parent_id; - // int const beam_depth = id_map[token_idx].beam_depth; - // int const beam_width = id_map[token_idx].beam_width; - int const req_id = tokenInfos[token_idx].request_index; - int const tok_id = tokenInfos[token_idx].abs_depth_in_request; - int const sub_req_id = beamTokenInfos[token_idx].sub_request_index; - int const parent_id = beamRequestInfos[req_id].parent_id[sub_req_id]; - int const beam_depth = beamRequestInfos[req_id].current_depth; - int const beam_width = beamRequestInfos[req_id].beam_size; - - kCache_ptr[(req_id * max_beam_width + sub_req_id) * - (hidden_size * max_seq_len) + - tok_id * hidden_size + offset] = kVal; - vCache_ptr[(req_id * max_beam_width + sub_req_id) * - (hidden_size * max_seq_len) + - tok_id * hidden_size + offset] = vVal; - - // replica in the root iteration - if (beam_depth == 1) { - for (int i = 1; i < beam_width; i++) { - kCache_ptr[(req_id * max_beam_width + i) * (hidden_size * max_seq_len) + - tok_id * hidden_size + offset] = kVal; - vCache_ptr[(req_id * max_beam_width + i) * (hidden_size * max_seq_len) + - tok_id * hidden_size + offset] = vVal; - } - } + // int const tok_id = tokenInfos[token_idx].abs_depth_in_request; - // if (head_idx == 0 && beam_depth == 0 && token_idx == 8 && k_cache) { - // // printf("token idx %d\n", token_idx); - // printf("data idx: %d, tok_id %d, new_token_cache_idx %d, parent_id %d, - // " - // "sub_req_id %d, num_tokens %d, kProjSize %d, num_kv_heads %d, - // val " - // "%f, beam_width %d\n", - // data_idx, - // tok_id, - // new_token_cache_idx, - // parent_id, - // sub_req_id, - // num_tokens, - // kProjSize, - // num_kv_heads, - // val, - // beam_width); - // } + int const request_token_offset = + requestInfo[req_id].first_token_offset_in_batch; - // naive cache stealing - if (sub_req_id != parent_id) { - if (offset == 0 && tok_id == 0) { - printf("cache stealing!, depth %d req_id %d sub_req_id %d, parentid " - "%d, tok_id %d\n", - beam_depth, - req_id, - sub_req_id, - parent_id, - tok_id); - } + BatchConfig::BitMask bitmask = causalMask[req_id]; - for (int depth = 0; depth < beam_depth; depth++) { - int steal_token_idx = tok_id - beam_depth + depth; - 
int steal_from_idx = (req_id * max_beam_width + parent_id) * - (hidden_size * max_seq_len) + - steal_token_idx * hidden_size + offset; - int steal_to_idx = (req_id * max_beam_width + sub_req_id) * - (hidden_size * max_seq_len) + - steal_token_idx * hidden_size + offset; - kCache_ptr[steal_to_idx] = kCache_ptr[steal_from_idx]; - vCache_ptr[steal_to_idx] = vCache_ptr[steal_from_idx]; - - // if(data_idx == 0 && head_idx == 0 && k_cache && req_id == 1){ - // printf("cache stealing kernel!, steal_token_idx %d\n", - // steal_token_idx); - // } - } - } + // int const tree_branch_num = beamRequestInfos[req_id].sub_request_num; + + // int const query_token = bitmask.non_tree_cache_size + bitmask.tree_size - + // tree_branch_num + sub_req_id + tok_id; + // bitmask.tree_size - tree_branch_num + sub_req_id; + + // if prompt token -> token id + // if tree token: + int const cache_idx = bitmask.non_tree_cache_size + bitmask.tree_size - + bitmask.this_layer_size + token_idx - + request_token_offset; - // parallel cache stealing not yet implemented - // logic shld be - // launch spec_store_kv_cache with parallelism * current depth - // from the i here, get depth index - // if depth index not the current one, check if we need to steal - // steal if needed - - // cache stealing theory - // identify which sub request does this token come from - // for initial token, 0 - // for other, may 0,0,1/ 0,1,2/ 1,1,1 to get which cache to be reuse and - // which to be delete copy beam_size bunch of blocks when sub_req_id == - // parent_id : like 0 -> 0, 1->1, 2->2, do nothing, just append the new k/v + kCache_ptr[req_id * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + + offset] = kVal; + vCache_ptr[req_id * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + + offset] = vVal; } } @@ -161,28 +361,79 @@ void update_kv_cache_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, cudaStream_t stream) { int num_tokens = bc->num_active_tokens(); int curr_depth = bc->beamRequestsInfo[0].current_depth; - // printf("curr depth: %d\n", curr_depth); - // assert(curr_depth < 3); if (num_tokens > 0) { int parallelism = m->hidden_size * KV_WEIGHT_NUM * num_tokens; - spec_store_kv_cache<<>>(static_cast
<DT *>(m->devQKVProjArray), - static_cast
<DT *>(m->keyCache), - static_cast
<DT *>(m->valueCache), - m->token_infos, - m->request_infos, - m->beam_token_infos, - m->beam_request_infos, - m->qProjSize, - m->kProjSize, - m->vProjSize, - num_tokens, - BatchConfig::max_sequence_length(), - BeamSearchBatchConfig::MAX_BEAM_WIDTH, - /*root*/ curr_depth == 0, - m->hidden_size); + spec_inc_store_kv_cache<DT><<<GET_BLOCKS(parallelism), min(CUDA_NUM_THREADS, parallelism), 0, stream>>>( + static_cast
<DT *>(m->devQKVProjArray), + static_cast
<DT *>(m->keyCache), + static_cast
<DT *>(m->valueCache), + m->token_infos, + m->request_infos, + m->beam_token_infos, + m->beam_request_infos, + m->causalMask, + m->qProjSize, + m->kProjSize, + m->vProjSize, + num_tokens, + BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, + /*root*/ curr_depth == 0, + m->hidden_size); + } +} + +#define LAUNCH_SPEC_INC_ATTENTION_SCORE_KERNEL( \ + DT, Dh, Dh_MAX, THDS_PER_KEY, THREADS_PER_VALUE, THDS_PER_BLOCK, stream) \ + smem_sz = smem_size_in_bytes<DT>
(m->qProjSize, \ + BatchConfig::max_sequence_length() + \ + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, \ + THREADS_PER_VALUE, \ + THDS_PER_BLOCK); \ + compute_spec_inc_attention_kernel_generation_kernel \ + <<>>( \ + static_cast
<DT *>(m->devQKVProjArray), \ + static_cast
<DT *>(m->keyCache), \ + static_cast<DT *>
(m->valueCache), \ + output_ptr, \ + scale, \ + BatchConfig::max_sequence_length() + \ + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, \ + m->qProjSize, \ + m->hidden_size, \ + m->request_infos, \ + m->beam_request_infos, \ + m->causalMask) + +template +void compute_spec_inc_attention_kernel_generation( + SpecIncMultiHeadSelfAttentionMeta const *m, + BeamSearchBatchConfig const *bc, + DT *output_ptr, + cudaStream_t stream) { + // one block == one head per request + dim3 grid(m->num_q_heads, bc->num_active_requests()); + int const per_head_size = m->qProjSize; + float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; + size_t smem_sz; + if (per_head_size == 64) { + constexpr int THREADS_PER_VALUE_64 = threads_per_value_t::value; + LAUNCH_SPEC_INC_ATTENTION_SCORE_KERNEL( + DT, 64, 64, 4, THREADS_PER_VALUE_64, 128, stream); + } else if (per_head_size == 128) { + constexpr int THREADS_PER_VALUE_128 = threads_per_value_t::value; + LAUNCH_SPEC_INC_ATTENTION_SCORE_KERNEL( + DT, 128, 128, 4, THREADS_PER_VALUE_128, 128, stream); + } else { + assert(false && "a unsupported head size"); } } @@ -236,199 +487,208 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, int q_block_size = m->qProjSize; int kt_block_size = m->kProjSize; - int kt_req_block_size = - kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + int kt_req_block_size = kt_block_size * m->num_q_heads * + (BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); int vt_block_size = m->vProjSize; - int vt_req_block_size = - vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + int vt_req_block_size = vt_block_size * m->num_q_heads * + (BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); assert(m->qProjSize == m->kProjSize); for (int i = 0; i < bc->max_requests_per_batch(); i++) { if (bc->request_completed[i]) { continue; } - for (int sub_req_id = 0; sub_req_id < bc->sub_requests[i]; sub_req_id++) { - // int num_new_tokens = bc->num_processing_tokens[i]; - // int total_tokens = bc->token_last_available_idx[i] + 1; + // else if (tokens_previous_requests < bc->num_generation_tokens) { + // tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; + // continue; + // } - int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; - int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + - bc->requestsInfo[i].num_tokens_in_batch; + // all requests in prompt phase should only have one sub requests; + assert(bc->sub_requests[i] == 1); + // int num_new_tokens = bc->num_processing_tokens[i]; + // int total_tokens = bc->token_last_available_idx[i] + 1; - if (num_new_tokens <= 0) { - continue; - } + int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; + int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + + bc->requestsInfo[i].num_tokens_in_batch; - // Compute (QK^T/sqrt(d_k)) - int m_ = num_new_tokens; - int n = total_tokens; - int k = m->qProjSize; - int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, - ldc = m_; - int strideA = q_block_size; - int strideB = kt_block_size; - int strideC = num_new_tokens * total_tokens; - - // a flag of using this scaling alpha - DT alpha = 1.0f, beta = 0.0f; - if (*m->qk_prod_scaling) { - alpha = static_cast
<DT>(1.0f / sqrt(m->kProjSize)); - } - // To get A, skip over Q entries from previous requests (same head) - DT const *A = static_cast
<DT *>(m->devQKVProjArray) + - bc->requestsInfo[i].first_token_offset_in_batch * - m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM; - // To get B, skip over K entries from previous requests (all heads + - // padding) - DT const *B = static_cast
<DT *>(m->keyCache) + - (i * bc->MAX_BEAM_WIDTH + sub_req_id) * kt_req_block_size; - - // if (i == 0 && sub_req_id == 0 && - // bc->beam_slots.at(0).current_depth == 1) { - // int offset = (float *)B - m->keyCache; - // printf("key cache offset %d\n", kt_req_block_size); - // } - // To get C, skip over QK^T products from previous requests - DT *C = static_cast<DT *>
(m->qk_prods) + - m->num_q_heads * tokens_prev_requests_squares; - checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_N, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - // add alibi position bias to qk production - // add alibi position bias to qk production - if (*m->position_bias) { - size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; - apply_position_bias_qkprd<<>>(C, - num_new_tokens, - total_tokens, - m->num_q_heads, - m->global_num_q_heads, - shard_id); - } - // Fill all elements above diagonal in qk prods with -inf to force - // causal attention. - assert(num_new_tokens <= total_tokens); - if (num_new_tokens > 1) { - size_t parallelism = m->num_q_heads * num_new_tokens * total_tokens; - spec_fill_entries_above_diagonal<<>>( - C, - num_new_tokens, - total_tokens, - m->num_q_heads, - static_cast
(-INFINITY)); - } - // Compute Softmax(QK^T/sqrt(d_k)) - // Before modifying the parameters below, make sure to read the following - // description of the CUDNN_TENSOR_NCHW tensor layout, from - // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t: - // This tensor format specifies that the data is laid out in the following - // order: batch size, feature maps, rows, columns. The strides are - // implicitly defined in such a way that the data are contiguous in memory - // with no padding between images, feature maps, rows, and columns; the - // columns are the inner dimension and the images are the outermost - // dimension. - int n_param = m->num_q_heads; - int c_param = total_tokens; - int h_param = 1; - int w_param = num_new_tokens; - checkCUDNN(cudnnSetTensor4dDescriptor(m->qk_tensor, - CUDNN_TENSOR_NCHW, - cudnn_data_type, - n_param, - c_param, - h_param, - w_param)); - float softmax_alpha = 1.0f, softmax_beta = 0.0f; - DT *C_softmax = static_cast
(m->qk_prods_softmax) + - m->num_q_heads * tokens_prev_requests_squares; - // The softmax operation below is executed according to the - // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The - // softmax operation is computed per spatial location (H,W) per image (N) - // across dimension C. - checkCUDNN(cudnnSoftmaxForward(m->handle.dnn, - CUDNN_SOFTMAX_ACCURATE, - CUDNN_SOFTMAX_MODE_CHANNEL, - &softmax_alpha, - m->qk_tensor, - C, - &softmax_beta, - m->qk_tensor, - C_softmax)); - // Matmul softmax(QK^T/sqrt(d_k)) by V - alpha = 1.0f, beta = 0.0f; - m_ = m->vProjSize; - n = num_new_tokens; - k = total_tokens; - lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads; - strideA = vt_block_size; - strideB = num_new_tokens * total_tokens; - strideC = m->vProjSize; - // To get A, skip over V^T entries from previous requests (all heads + - // padding) - A = static_cast
(m->valueCache) + - (i * bc->MAX_BEAM_WIDTH + sub_req_id) * vt_req_block_size; - // To get B, skip over softmax(QK^T/sqrt(d_k)) entries from previous - // requests (all heads) - B = C_softmax; - // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous - // requests - C = static_cast
(m->attn_heads) + - (tokens_previous_requests + bc->num_generation_tokens) * - m->num_q_heads * m->vProjSize; - checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_T, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - tokens_previous_requests += num_new_tokens; - tokens_prev_requests_squares += num_new_tokens * total_tokens; + if (num_new_tokens <= 0) { + continue; + } + + // Compute (QK^T/sqrt(d_k)) + int m_ = num_new_tokens; + int n = total_tokens; + int k = m->qProjSize; + int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, + ldc = m_; + int strideA = q_block_size; + int strideB = kt_block_size; + int strideC = num_new_tokens * total_tokens; + + // a flag of using this scaling alpha + DT alpha = 1.0f, beta = 0.0f; + if (*m->qk_prod_scaling) { + alpha = static_cast
<DT>(1.0f / sqrt(m->kProjSize)); + } + // To get A, skip over Q entries from previous requests (same head) + DT const *A = static_cast
<DT *>(m->devQKVProjArray) + + bc->requestsInfo[i].first_token_offset_in_batch * + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM; + // To get B, skip over K entries from previous requests (all heads + + // padding) + + // print_tensor((float*)A, 32, "A"); + DT const *B = static_cast
<DT *>(m->keyCache) + i * kt_req_block_size; + + // if (i == 0 && sub_req_id == 0 && + // bc->beam_slots.at(0).current_depth == 1) { + // int offset = (float *)B - m->keyCache; + // printf("key cache offset %d\n", kt_req_block_size); + // } + // To get C, skip over QK^T products from previous requests + DT *C = static_cast<DT *>
(m->qk_prods) + + m->num_q_heads * tokens_prev_requests_squares; + checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, + CUBLAS_OP_T, + CUBLAS_OP_N, + m_, + n, + k, + &alpha, + A, + cublas_data_type, + lda, + strideA, + B, + cublas_data_type, + ldb, + strideB, + &beta, + C, + cublas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + // print_tensor((float*)C, 32, "C"); + // add alibi position bias to qk production + // add alibi position bias to qk production + if (*m->position_bias) { + size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; + apply_position_bias_qkprd<<>>(C, + num_new_tokens, + total_tokens, + m->num_q_heads, + m->global_num_q_heads, + shard_id); } + // Fill all elements above diagonal in qk prods with -inf to force + // causal attention. + assert(num_new_tokens <= total_tokens); + if (num_new_tokens > 1) { + size_t parallelism = m->num_q_heads * num_new_tokens * total_tokens; + spec_fill_entries_above_diagonal<<>>(C, + num_new_tokens, + total_tokens, + m->num_q_heads, + static_cast
(-INFINITY)); + } + // Compute Softmax(QK^T/sqrt(d_k)) + // Before modifying the parameters below, make sure to read the following + // description of the CUDNN_TENSOR_NCHW tensor layout, from + // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t: + // This tensor format specifies that the data is laid out in the following + // order: batch size, feature maps, rows, columns. The strides are + // implicitly defined in such a way that the data are contiguous in memory + // with no padding between images, feature maps, rows, and columns; the + // columns are the inner dimension and the images are the outermost + // dimension. + int n_param = m->num_q_heads; + int c_param = total_tokens; + int h_param = 1; + int w_param = num_new_tokens; + checkCUDNN(cudnnSetTensor4dDescriptor(m->qk_tensor, + CUDNN_TENSOR_NCHW, + cudnn_data_type, + n_param, + c_param, + h_param, + w_param)); + float softmax_alpha = 1.0f, softmax_beta = 0.0f; + DT *C_softmax = static_cast
(m->qk_prods_softmax) + + m->num_q_heads * tokens_prev_requests_squares; + // The softmax operation below is executed according to the + // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The + // softmax operation is computed per spatial location (H,W) per image (N) + // across dimension C. + checkCUDNN(cudnnSoftmaxForward(m->handle.dnn, + CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_CHANNEL, + &softmax_alpha, + m->qk_tensor, + C, + &softmax_beta, + m->qk_tensor, + C_softmax)); + // Matmul softmax(QK^T/sqrt(d_k)) by V + alpha = 1.0f, beta = 0.0f; + m_ = m->vProjSize; + n = num_new_tokens; + k = total_tokens; + lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads; + strideA = vt_block_size; + strideB = num_new_tokens * total_tokens; + strideC = m->vProjSize; + // To get A, skip over V^T entries from previous requests (all heads + + // padding) + A = static_cast
<DT *>(m->valueCache) + i * vt_req_block_size; + // To get B, skip over softmax(QK^T/sqrt(d_k)) entries from previous + // requests (all heads) + B = C_softmax; + // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous + // requests + + // print_tensor((float*)C_softmax, 32, "C_softmax"); + C = static_cast<DT *>
(m->attn_heads) + + (tokens_previous_requests + bc->num_generation_tokens) * + m->num_q_heads * m->vProjSize; + checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, + CUBLAS_OP_N, + CUBLAS_OP_T, + m_, + n, + k, + &alpha, + A, + cublas_data_type, + lda, + strideA, + B, + cublas_data_type, + ldb, + strideB, + &beta, + C, + cublas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + tokens_previous_requests += num_new_tokens; + tokens_prev_requests_squares += num_new_tokens * total_tokens; } // assert(tokens_previous_requests == num_tokens); @@ -443,31 +703,8 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, DT *output_ptr, DT const *bias_ptr, cudaStream_t stream) { - // here because we need postion info in infernece 1 - cudaMemcpyAsync(m->token_infos, - &(bc->tokensInfo), - bc->num_active_tokens() * sizeof(BatchConfig::PerTokenInfo), - cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(m->request_infos, - &(bc->requestsInfo), - bc->max_requests_per_batch() * - sizeof(BatchConfig::PerRequestInfo), - cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(m->beam_token_infos, - &(bc->beamTokenInfo), - bc->num_active_tokens() * bc->MAX_BEAM_WIDTH * - sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo), - cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(m->beam_request_infos, - &(bc->beamRequestsInfo), - bc->max_requests_per_batch() * - sizeof(BeamSearchBatchConfig::BeamSearchPerRequestInfo), - cudaMemcpyHostToDevice, - stream); // phase 1: Implement kernel to compute KQV for input tokens + compute_qkv_kernel(m, bc, shard_id, @@ -479,7 +716,7 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, // phase 2: Update key/val cache update_kv_cache_kernel
<DT>(m, bc, stream); if (bc->num_generation_tokens > 0) { - compute_attention_kernel_generation
<DT>( + compute_spec_inc_attention_kernel_generation
<DT>( + m, bc, static_cast<DT *>
(m->attn_heads), stream); } // phase 3: Compute attention score @@ -488,16 +725,14 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, compute_attention_kernel_prompt( m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); } - // compute output production and bias together for all tokens - int num_tokens = - bc->num_active_tokens() * BeamSearchBatchConfig::MAX_BEAM_WIDTH; + int num_tokens = bc->num_active_tokens(); compute_o_prod_bias( m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, num_tokens, stream); } -} // namespace SpecIncMultiHeadAttention +} // namespace SpecIncMultiHeadSelfAttention } // namespace Kernels /*static*/ @@ -529,25 +764,27 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( if (input.data_type == DT_HALF) { half const *bias_ptr = use_bias ? bias.get_half_ptr() : static_cast(nullptr); - Kernels::SpecIncMultiHeadAttention::inference_kernel(m, - bc, - shard_id, - input.get_half_ptr(), - weight.get_half_ptr(), - output.get_half_ptr(), - bias_ptr, - stream); + Kernels::SpecIncMultiHeadSelfAttention::inference_kernel( + m, + bc, + shard_id, + input.get_half_ptr(), + weight.get_half_ptr(), + output.get_half_ptr(), + bias_ptr, + stream); } else if (input.data_type == DT_FLOAT) { float const *bias_ptr = use_bias ? bias.get_float_ptr() : static_cast(nullptr); - Kernels::SpecIncMultiHeadAttention::inference_kernel(m, - bc, - shard_id, - input.get_float_ptr(), - weight.get_float_ptr(), - output.get_float_ptr(), - bias_ptr, - stream); + Kernels::SpecIncMultiHeadSelfAttention::inference_kernel( + m, + bc, + shard_id, + input.get_float_ptr(), + weight.get_float_ptr(), + output.get_float_ptr(), + bias_ptr, + stream); } else { assert(false && "Unspported data type"); } @@ -606,38 +843,23 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( // allocate memory for the seqArray and reserve space { - int max_tokens_per_batch = BatchConfig::max_tokens_per_batch(); - size_t beam_tokeninfo_size = - max_tokens_per_batch * BeamSearchBatchConfig::MAX_BEAM_WIDTH; - size_t requestinfo_size = BeamSearchBatchConfig::max_requests_per_batch(); - size_t beam_requestinfo_size = - BeamSearchBatchConfig::max_requests_per_batch(); - size_t total_size = - beam_tokeninfo_size * - sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo) + - beam_requestinfo_size * - sizeof(BeamSearchBatchConfig:: - BeamSearchPerRequestInfo); // more components will - // be added here later - - // We always directly allocate memory for small speculative models - gpu_mem_allocator.create_legion_instance(beam_search_reserve_inst, - total_size); beam_token_infos = - gpu_mem_allocator - .allocate_instance( - beam_tokeninfo_size); - // offset += beam_tokeninfo_size * - // sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo); + reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + + sizeof(BatchConfig::requestsInfo)); + beam_request_infos = - gpu_mem_allocator - .allocate_instance( - beam_requestinfo_size); - // offset += beam_requestinfo_size * - // sizeof(BeamSearchBatchConfig::BeamSearchPerRequestInfo); - // assert(offset == total_size); - assert(gpu_mem_allocator.instance_total_size == - gpu_mem_allocator.instance_allocated_size); + reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + + sizeof(BatchConfig::requestsInfo) + + sizeof(BeamSearchBatchConfig::beamTokenInfo)); + causalMask = reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + 
sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + + sizeof(BeamSearchBatchConfig::beamTokenInfo) + + sizeof(BeamSearchBatchConfig::beamRequestsInfo)); } cudaStreamSynchronize(stream); diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index bc7d1017b7..b4af80976f 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -53,6 +53,7 @@ __global__ void compute_attention_kernel_fused_kernel( BatchConfig::PerRequestInfo *request_infos, int num_heads, int num_requests, + BatchConfig::BitMask *causalMask, int qk_smem_sz) { // q, k @@ -75,17 +76,28 @@ __global__ void compute_attention_kernel_fused_kernel( // request idx int const request_idx = blockIdx.y; + int const batch_config_request_id = + request_infos[request_idx].batch_config_request_id; + int const first_step = 0; - int const tlength = request_infos[request_idx].first_token_depth_in_request + - request_infos[request_idx].num_tokens_in_batch; - int const qlength = request_infos[request_idx].num_tokens_in_batch; + int const tlength = + request_infos[batch_config_request_id].first_token_depth_in_request + + request_infos[batch_config_request_id].num_tokens_in_batch; + int const qlength = + request_infos[batch_config_request_id].num_tokens_in_batch; + + BatchConfig::BitMask bitmask = causalMask[batch_config_request_id]; int first_token_idx = 0; for (int r = 0; r < request_idx; r++) { - first_token_idx += request_infos[request_idx].num_tokens_in_batch; + first_token_idx += request_infos[r].num_tokens_in_batch; } + // if(tidx == 0 && head_idx == 0){ + // printf("tree req: %d, %d\n", request_idx, first_token_idx); + // } + // shared memory objects extern __shared__ char smem_[]; @@ -115,7 +127,7 @@ __global__ void compute_attention_kernel_fused_kernel( constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; DT const *k_cache_batch = - key_cache + request_idx * max_seq_length * hidden_size + ki; + key_cache + batch_config_request_id * max_seq_length * hidden_size + ki; int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; @@ -126,11 +138,19 @@ __global__ void compute_attention_kernel_fused_kernel( q_vecs[ki_o][ii] = *reinterpret_cast( q_ptr + (hidden_size * QKV_WEIGHT_NUM * qi) + ki + ii * THREADS_PER_KEY * K_VEC_SIZE); + + // if (head_idx == 0 && qi == 1 && tidx == 0) { + // printf("laod q %d, %d %.10f\n", + // request_idx, + // qi,q_vecs[ki_o][ii].x); + // } } + __syncthreads(); for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { K_vec k[K_VECS_PER_THREAD]; int const ti_circ = ti % max_seq_length; + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { int jj = ii * THREADS_PER_KEY * K_VEC_SIZE; if (ti < tlength) { @@ -142,22 +162,28 @@ __global__ void compute_attention_kernel_fused_kernel( float qk = scale * Qk_dot::dot(q_vecs[ki_o], k); if (ti < tlength && tidx % THREADS_PER_KEY == 0) { - bool const mask = ti_circ >= tlength; - if (mask) { - assert(false); - } + bool const mask = + (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); - int pos = ti * qlength + qi; - if (((pos / qlength) % tlength) > (pos % qlength + tlength - qlength)) { - qk = -FLT_MAX; - } qk_max = mask ? qk_max : fmaxf(qk_max, qk); - qk_smem[pos] = mask ? 
0.f : qk; + // if (head_idx == 0 && qi == 0 && !mask) { + // printf("tree attn qkqkqkqk request id %d, %d %.10f, %.10f, %.10f\n + // ", + // request_idx, + // ti, + // qk, + // q_vecs[ki_o][0].x, + // k[0].x); + // } + qk_smem[ti - first_step] = mask ? 0.0f : qk; } } + __syncthreads(); +#pragma unroll for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); } @@ -176,7 +202,7 @@ __global__ void compute_attention_kernel_fused_kernel( // The warps finalize the reduction. qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; - +#pragma unroll for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); } @@ -184,12 +210,18 @@ __global__ void compute_attention_kernel_fused_kernel( // Broadcast to all the threads in the warp. qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - float exp_sum = 0.f; + // if (head_idx == 0 && qi == 9 && tidx == 0) { + // printf("tree attn first token qk_max %f\n", qk_max); + // } + float exp_sum = 0.f; for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { - float logit = __expf(qk_smem[ti * qlength + qi] - qk_max); + bool const mask = + (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); + float logit = mask ? 0.0f : __expf(qk_smem[ti - first_step] - qk_max); exp_sum += logit; - qk_smem[ti * qlength + qi] = logit; + qk_smem[ti - first_step] = mask ? 0.0f : logit; } // Compute the sum. @@ -197,43 +229,51 @@ __global__ void compute_attention_kernel_fused_kernel( // softmax float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); - for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { - qk_smem[ti * qlength + qi] *= inv_sum; + qk_smem[ti - first_step] *= inv_sum; } __syncthreads(); - } - // value projection - constexpr int V_VEC_SIZE = 16 / sizeof(DT); - // The value computed by this thread. - int vo = tidx / THREADS_PER_VALUE; - // The hidden dimensions computed by this particular thread. - int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; - constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + // value projection + constexpr int V_VEC_SIZE = 16 / sizeof(DT); + // A vector of V elements for the current timestep. + // using V_vec_k = typename V_vec_k_::Type; + // using V_vec_acum = typename V_vec_acum_fp32_::Type; - Out_sum out; - // The base pointer for the value in the cache buffer. - DT const *v_cache_batch = - value_cache + request_idx * max_seq_length * hidden_size + vi; + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; - for (int qi = 0; qi < qlength; qi++) { + Out_sum out; zero(out); - __syncthreads(); + + // The base pointer for the value in the cache buffer. + DT const *v_cache_batch = + value_cache + batch_config_request_id * max_seq_length * hidden_size + + vi; + if (Dh == Dh_MAX || vi < Dh) { for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { // Load the values from the cache. 
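        // Only timesteps visible to query token qi contribute to the output:
        // cache positions before non_tree_cache_size hold already-committed
        // tokens and are always attended, while positions inside the
        // speculative tree are kept only when bit qi is set in the
        // corresponding causalMask word.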
int const ti_circ = ti % max_seq_length; - + // int const real_cache_idx = topology.real_token_pos[sub_req_idx][ti]; V_vec v = *reinterpret_cast( v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); - float logit = qk_smem[ti * qlength + qi]; - out = FlexFlow::fma(logit, cast_to_float(v), out); + + if (ti < tlength) { + bool const mask = + (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); + float logit = mask ? 0.0f : qk_smem[ti - first_step]; + out = FlexFlow::fma(logit, cast_to_float(v), out); + } } } - // Make sure we can start writing to shared memory. + // // Make sure we can start writing to shared memory. __syncthreads(); // Run the final reduction amongst the different groups computing different @@ -268,6 +308,17 @@ __global__ void compute_attention_kernel_fused_kernel( output_ptr + (first_token_idx + qi) * hidden_size + head_idx * per_head_size + vi), out); + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { + // printf("tree attn final value, %.9f, %.9f, %.9f, %.9f, %d, %d\n", + // out.x, + // out.y, + // out.z, + // out.w, + // vi, + // (first_token_idx + qi) * hidden_size + head_idx * + // per_head_size + + // vi); + // } } } } @@ -286,9 +337,9 @@ __global__ void commit_tokens_kernel( int max_seq_len, int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens_to_commit * hidden_size * 2) { + CUDA_KERNEL_LOOP(i, num_tokens_to_commit * hidden_size) { - int token_pos = i / (hidden_size * KV_WEIGHT_NUM); + int token_pos = i / (hidden_size); int token_idx_in_last_batch = committedTokenInfos[token_pos].token_index; int offset = i % hidden_size; assert(token_idx_in_last_batch < num_active_tokens_in_last_batch); @@ -329,7 +380,8 @@ void commit_tokens(TreeIncMultiHeadSelfAttentionMeta const *m, m->vProjSize, num_tokens_to_commit, m->num_active_tokens, // number of active tokens in previous batch - BatchConfig::max_sequence_length(), + BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, m->hidden_size); } } @@ -348,9 +400,9 @@ __global__ void update_tree_branch_kv_cache( int total_tokens_in_batch, int max_seq_len, int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens_in_branch * hidden_size * 2) { + CUDA_KERNEL_LOOP(i, num_tokens_in_branch * hidden_size) { - int token_idx = i / (hidden_size * KV_WEIGHT_NUM); + int token_idx = i / (hidden_size); int offset = i % hidden_size; token_idx += processed_tokens_in_batch; // get index in the whole batch @@ -375,6 +427,7 @@ __global__ void update_tree_branch_kv_cache_fused( DT *kCache_ptr, DT *vCache_ptr, TreeVerifyBatchConfig::PerTokenInfo const *tokenInfos, + BatchConfig::PerRequestInfo *request_infos, int qProjSize, int kProjSize, int vProjSize, @@ -392,10 +445,25 @@ __global__ void update_tree_branch_kv_cache_fused( DT vVal = devQKVProjArray[val_idx + hidden_size]; int const req_id = tokenInfos[token_idx].request_index; - int const tok_id = tokenInfos[token_idx].abs_depth_in_request; - kCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + + // int const tok_id = tokenInfos[token_idx].abs_depth_in_request; + + int const request_token_offset = + request_infos[req_id].first_token_offset_in_batch; + int const first_token_depth = + request_infos[req_id].first_token_depth_in_request; + + // if(i % hidden_size == 0){ + // printf("update token request id: %d, %d, %d real id %d, value%.10f\n", + // req_id, token_idx, request_token_offset,(token_idx + first_token_depth + // - request_token_offset), kVal); + // } + kCache_ptr[req_id 
* (hidden_size * max_seq_len) + + (token_idx + first_token_depth - request_token_offset) * + hidden_size + offset] = kVal; - vCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + + vCache_ptr[req_id * (hidden_size * max_seq_len) + + (token_idx + first_token_depth - request_token_offset) * + hidden_size + offset] = vVal; } } @@ -448,10 +516,12 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, int q_block_size = m->qProjSize; int kt_block_size = m->kProjSize; int kt_req_block_size = - kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM; int vt_block_size = m->vProjSize; int vt_req_block_size = - vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM; assert(m->qProjSize == m->kProjSize); for (int i = 0; i < bc->max_requests_per_batch(); i++) { @@ -472,9 +542,6 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, num_new_tokens++; } - std::cout << "num_new_tokens: " << num_new_tokens << "\n"; - assert(false); - int total_tokens_in_request = bc->tokensInfo[j].abs_depth_in_request + 1; assert(num_new_tokens >= 1 && total_tokens_in_request >= num_new_tokens); { @@ -716,7 +783,8 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, #define LAUNCH_TREE_VERIFY_ATTENTION_SCORE_KERNEL( \ DT, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ smem_size_in_bytes_tree
(m->qProjSize, \ - BatchConfig::max_sequence_length(), \ + BatchConfig::max_sequence_length() + \ + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, \ THDS_PER_VALUE, \ THDS_PER_BLOCK, \ bc, \ @@ -733,17 +801,19 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, static_cast
(m->valueCache), \ output_ptr, \ scale, \ - BatchConfig::max_sequence_length(), \ + BatchConfig::max_sequence_length() + \ + BatchConfig::BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, \ BatchConfig::max_tokens_per_batch(), \ m->qProjSize, \ m->hidden_size, \ m->request_infos, \ m->num_q_heads, \ bc->num_active_requests(), \ + m->causalMask, \ smem_sz[0]) template -void compute_attention_kernel_fused(IncMultiHeadSelfAttentionMeta const *m, +void compute_attention_kernel_fused(TreeIncMultiHeadSelfAttentionMeta const *m, TreeVerifyBatchConfig const *bc, DT *output_ptr, cudaStream_t stream) { @@ -760,11 +830,12 @@ void compute_attention_kernel_fused(IncMultiHeadSelfAttentionMeta const *m, static_cast
(m->keyCache), static_cast
(m->valueCache), m->token_infos, + m->request_infos, m->qProjSize, m->kProjSize, m->vProjSize, num_new_tokens, - BatchConfig::max_sequence_length(), + BatchConfig::max_sequence_length() + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, m->hidden_size); dim3 grid(m->num_q_heads, bc->num_active_requests()); @@ -816,12 +887,20 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, // Note that m->num_active_tokens stores the number of active // tokens in the previous batch, which is needed for committing // keys/values to the key-value cache + // std::cout << "tokens to be committed: " << bc->num_tokens_to_commit << + // "\n"; + cudaMemcpyAsync(m->committed_token_infos, &(bc->committed_tokens), bc->num_tokens_to_commit * sizeof(TreeVerifyBatchConfig::CommittedTokensInfo), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(m->causalMask, + &(bc->causalMask), + bc->num_active_requests() * sizeof(BatchConfig::BitMask), + cudaMemcpyHostToDevice, + stream); commit_tokens
(m, bc, stream); // After commit we update m->num_active_tokens to be the number of active @@ -834,18 +913,6 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, m->bias_ptr, bias_ptr, m->biasSize, cudaMemcpyHostToDevice, stream); bias_ptr = static_cast
(m->bias_ptr); } - cudaMemcpyAsync(m->token_infos, - &(bc->tokensInfo), - bc->num_active_tokens() * - sizeof(TreeVerifyBatchConfig::PerTokenInfo), - cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(m->request_infos, - &(bc->requestsInfo), - bc->max_requests_per_batch() * - sizeof(BatchConfig::PerRequestInfo), - cudaMemcpyHostToDevice, - stream); // phase 1: Implement kernel to compute KQV for input tokens compute_qkv_kernel(m, bc, @@ -991,27 +1058,16 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( // allocate memory for the seqArray and reserve space { - int max_tokens_per_batch = BatchConfig::max_tokens_per_batch(); - size_t committed_tokeninfo_size = max_tokens_per_batch; - size_t total_size = committed_tokeninfo_size * - sizeof(TreeVerifyBatchConfig::CommittedTokensInfo); - if (offload) { - // assert that we have enough reserved work space left - assert(gpu_mem_allocator.reserved_total_size - - gpu_mem_allocator.reserved_allocated_size >= - total_size); - committed_token_infos = - gpu_mem_allocator - .allocate_reserved( - committed_tokeninfo_size); - } else { - gpu_mem_allocator.create_legion_instance(committed_token_reserve_inst, - total_size); - committed_token_infos = - gpu_mem_allocator - .allocate_instance( - committed_tokeninfo_size); - } + + causalMask = reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo)); + committed_token_infos = + reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + + sizeof(BatchConfig::requestsInfo) + + sizeof(BatchConfig::causalMask)); } cudaStreamSynchronize(stream); diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index eb045e8159..8af0ed8978 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -318,7 +318,8 @@ FutureMap InferenceManager::inference(FFModel *model, found_input_operator = true; assert(op->numOutputs == 1); ParallelTensor pt = tensor_buffer[op->outputs[0]][batch_index]; - load_input_tokens_from_batch_config(bc, pt); + load_input_tokens_from_batch_config(bc, pt, model->handlers); + load_inference_metadata_batch_config(bc, model->handlers); } } @@ -348,11 +349,34 @@ FutureMap InferenceManager::inference(FFModel *model, }; void InferenceManager::load_input_tokens_from_batch_config( - BatchConfigFuture const &bc, ParallelTensor const input) { + BatchConfigFuture const &bc, + ParallelTensor const input, + FFHandler *handlers) { Context ctx = ff_config.lg_ctx; Runtime *runtime = ff_config.lg_hlr; size_t machine_view_hash = input->machine_view.hash(); ArgumentMap argmap; + Domain domain = runtime->get_index_space_domain(ctx, input->parallel_is); + + switch (domain.get_dim()) { +#define DIMFUNC(DIM) \ + case DIM: { \ + Rect rect = domain; \ + MachineView view = input->machine_view; \ + int idx = 0; \ + for (PointInRectIterator it(rect); it(); it++) { \ + argmap.set_point(*it, \ + TaskArgument(&handlers[view.get_device_id(*it)], \ + sizeof(FFHandler))); \ + } \ + break; \ + } + LEGION_FOREACH_N(DIMFUNC) +#undef DIMFUNC + default: + assert(false); + } + IndexLauncher launcher(RM_LOAD_TOKENS_TASK_ID, input->parallel_is, TaskArgument(nullptr, 0), @@ -368,6 +392,34 @@ void InferenceManager::load_input_tokens_from_batch_config( runtime->execute_index_space(ctx, launcher); } +void InferenceManager::load_inference_metadata_batch_config( + BatchConfigFuture const &bc, FFHandler *handlers) { + Context ctx = ff_config.lg_ctx; 
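+  // One task per GPU point in all_gpu_task_is: every point receives its own
+  // FFHandler so the RM_LOAD_BATCH_CONFIG task can stage the per-batch
+  // metadata (tokensInfo, requestsInfo, beam search info, causalMask and
+  // committed tokens) into that worker's preallocated
+  // batch_config_metadata buffer.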
+ Runtime *runtime = ff_config.lg_hlr; + ArgumentMap argmap; + + Domain domain = + runtime->get_index_space_domain(ctx, ff_config.all_gpu_task_is); + Rect<1> task_rect = domain; + + int idx = 0; + for (PointInRectIterator<1> it(task_rect); it(); it++) { + FFHandler handler = handlers[idx++]; + argmap.set_point(*it, TaskArgument(&handler, sizeof(FFHandler))); + } + + IndexLauncher launcher(RM_LOAD_BATCH_CONFIG_TASK_ID, + ff_config.all_gpu_task_is, + TaskArgument(nullptr, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + FFConfig::DataParallelism_GPU); + launcher.add_future(bc); + runtime->execute_index_space(ctx, launcher); +} + void InferenceManager::load_positions(BatchConfigFuture const &bc, ParallelTensor position_input, int offset) { diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 92f0cff472..37605c44a4 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -1499,10 +1499,8 @@ FFRuntime::FFRuntime(FFConfig &config) { Context ctx = config.lg_ctx; ArgumentMap argmap; - Rect<1> task_rect(Point<1>(0), - Point<1>(config.workersPerNode * config.numNodes - 1)); - IndexSpaceT<1> task_is = runtime->create_index_space(ctx, task_rect); - + Domain domain = runtime->get_index_space_domain(ctx, config.all_gpu_task_is); + Rect<1> task_rect = domain; // int rank = 0; for (PointInRectIterator<1> it(task_rect); it(); it++) { FFInitInfo info; @@ -1518,7 +1516,7 @@ FFRuntime::FFRuntime(FFConfig &config) { // Init CUDA library on each worker IndexLauncher initLauncher(FF_INIT_TASK_ID, - task_is, + config.all_gpu_task_is, TaskArgument(NULL, 0), argmap, Predicate::TRUE_PRED, @@ -2993,6 +2991,12 @@ Op *FFModel::create_operator_from_layer( dims[num_dims].degree = 1; dims[num_dims].parallel_idx = -1; dims[num_dims].is_replica_dim = true; + if (config.computationMode == COMP_MODE_INFERENCE && + config.tensor_parallelism_degree > 1) { + dims[num_dims].size *= config.tensor_parallelism_degree; + dims[num_dims].degree *= config.tensor_parallelism_degree; + dims[num_dims].parallel_idx = 0; + } // create_parallel_tensor adds an NoOp into operators ParallelTensor pt = create_parallel_tensor_legion_ordering(num_dims + 1, @@ -3002,6 +3006,7 @@ Op *FFModel::create_operator_from_layer( 0, true /*gradients*/, tensor->tensor_guid); + assert(pt->get_shape().is_valid()); // assert that this tensor hasn't been mapped before assert(tensor->parallel_tensor == nullptr); tensor->parallel_tensor = pt; @@ -3260,12 +3265,12 @@ void FFModel::create_operators_from_layers() { if (config.computationMode == COMP_MODE_INFERENCE && config.tensor_parallelism_degree > 1 && l->op_type == OP_EMBEDDING) { assert(op->numOutputs == 1); - Replicate *repl = new Replicate(*this, - op->outputs[0], - op->outputs[0]->num_dims - 1, - config.tensor_parallelism_degree); - operators.push_back(repl); - op = repl; + // Replicate *repl = new Replicate(*this, + // op->outputs[0], + // op->outputs[0]->num_dims - 1, + // config.tensor_parallelism_degree); + // operators.push_back(repl); + // op = repl; } else if (config.computationMode == COMP_MODE_INFERENCE && config.tensor_parallelism_degree > 1 && (l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION || @@ -4076,6 +4081,10 @@ FFConfig::FFConfig() { Runtime *runtime = Runtime::get_runtime(); lg_hlr = runtime; lg_ctx = Runtime::get_context(); + Rect<1> task_rect(Point<1>(0), Point<1>(workersPerNode * numNodes - 1)); + // Create an index space for tasks running on all GPUs + all_gpu_task_is = runtime->create_index_space(lg_ctx, task_rect); + // field_space = 
runtime->create_field_space(lg_ctx); } @@ -4337,6 +4346,23 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar); } } + // RequestManager load metadata + { + TaskVariantRegistrar registrar(RM_LOAD_BATCH_CONFIG_TASK_ID, + "RequestManager Load meta data"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "RequestManager Load metadata Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant( + registrar); + } + } // RequestManager prepare_next_batch { TaskVariantRegistrar registrar(RM_PREPARE_NEXT_BATCH_TASK_ID, diff --git a/src/runtime/model.cpp b/src/runtime/model.cpp index 6c482426eb..ad2b781567 100644 --- a/src/runtime/model.cpp +++ b/src/runtime/model.cpp @@ -131,6 +131,54 @@ FFHandler .wait(); handle.workSpace = workspaceInst.pointer_untyped(0, sizeof(char)); } + if (handle.offload_reserve_space_size > 0) { + // allocate memory for offload reserve space + Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) + .only_kind(Memory::GPU_FB_MEM) + .best_affinity_to(task->target_proc) + .first(); + Realm::Rect<1, coord_t> bounds( + Realm::Point<1, coord_t>(0), + Realm::Point<1, coord_t>(handle.offload_reserve_space_size - 1)); + std::vector field_sizes; + field_sizes.push_back(sizeof(char)); + Realm::RegionInstance workspaceInst; + Realm::RegionInstance::create_instance(workspaceInst, + gpu_mem, + bounds, + field_sizes, + 0, + Realm::ProfilingRequestSet()) + .wait(); + handle.offload_reserve_space = + workspaceInst.pointer_untyped(0, sizeof(char)); + } else { + handle.offload_reserve_space = nullptr; + } + if (handle.batch_config_metadata_size > 0) { + // allocate memory for offload reserve space + Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) + .only_kind(Memory::GPU_FB_MEM) + .best_affinity_to(task->target_proc) + .first(); + Realm::Rect<1, coord_t> bounds( + Realm::Point<1, coord_t>(0), + Realm::Point<1, coord_t>(handle.batch_config_metadata_size - 1)); + std::vector field_sizes; + field_sizes.push_back(sizeof(char)); + Realm::RegionInstance workspaceInst; + Realm::RegionInstance::create_instance(workspaceInst, + gpu_mem, + bounds, + field_sizes, + 0, + Realm::ProfilingRequestSet()) + .wait(); + handle.batch_config_metadata = + workspaceInst.pointer_untyped(0, sizeof(char)); + } else { + handle.batch_config_metadata = nullptr; + } // checkCUDA(hipMalloc(&handle.workSpace, handle.workSpaceSize)); #ifdef FF_USE_NCCL handle.ncclComm = NULL; diff --git a/src/runtime/model.cu b/src/runtime/model.cu index 17401a0f14..c885b29db2 100644 --- a/src/runtime/model.cu +++ b/src/runtime/model.cu @@ -151,6 +151,31 @@ FFHandler } else { handle.offload_reserve_space = nullptr; } + if (handle.batch_config_metadata_size > 0) { + // allocate memory for offload reserve space + Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) + .only_kind(Memory::GPU_FB_MEM) + .best_affinity_to(task->target_proc) + .first(); + Realm::Rect<1, coord_t> bounds( + Realm::Point<1, coord_t>(0), + Realm::Point<1, coord_t>(handle.batch_config_metadata_size - 1)); + std::vector field_sizes; + field_sizes.push_back(sizeof(char)); + Realm::RegionInstance workspaceInst; + Realm::RegionInstance::create_instance(workspaceInst, + gpu_mem, + bounds, + field_sizes, + 0, + Realm::ProfilingRequestSet()) + .wait(); + handle.batch_config_metadata = + workspaceInst.pointer_untyped(0, sizeof(char)); + } else 
{ + handle.batch_config_metadata = nullptr; + } + // checkCUDA(cudaMalloc(&handle.workSpace, handle.workSpaceSize)); #ifdef FF_USE_NCCL diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 7c37f3391e..89d4ddaed4 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -16,6 +16,7 @@ #include "flexflow/request_manager.h" #include "flexflow/parallel_ops/parallel_op.h" // #include "flexflow/tokenizers.h" +#include #include #include #include @@ -106,6 +107,11 @@ int RequestManager::get_max_sequence_length() { return max_sequence_length; } +void RequestManager::push_spec_infer_tree_width(int tree_width) { + assert(tree_width <= BeamSearchBatchConfig::MAX_BEAM_WIDTH); + spec_infer_tree_width.emplace_back(tree_width); +} + void RequestManager::register_tokenizer(ModelType type, int bos_token_id, int eos_token_id, @@ -358,6 +364,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, } } int num_generation_tokens = 0; + int num_active_req = -1; // Step 2: prepare the next batch for existing requests BatchConfig new_bc; @@ -406,13 +413,14 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, total_request_run_time += profile_info.finish_time - profile_info.start_time; profiling_requests[request.guid] = profile_info; - log_req_mgr.print("[Profile] guid(%zu) decoding_steps(%d) start(%.1lf) " - "finish(%.1lf) latency(%.1lf)", - request.guid, - profile_info.decoding_steps, - profile_info.start_time, - profile_info.finish_time, - profile_info.finish_time - profile_info.start_time); + log_req_mgr.print( + "[Profile] guid(%zu) llm_decoding_steps(%d) start(%.1lf) " + "finish(%.1lf) latency(%.1lf)", + request.guid, + profile_info.llm_decoding_steps, + profile_info.start_time, + profile_info.finish_time, + profile_info.finish_time - profile_info.start_time); // Write output to file if needed: if (!output_filepath.empty()) { std::ofstream outputFile(output_filepath, std::ios::app); @@ -420,8 +428,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, outputFile << "end-to-end latency: " << std::fixed << std::setprecision(3) << total_request_run_time << std::endl; - outputFile << "num decoding steps: " << profile_info.decoding_steps - << std::endl; + outputFile << "num decoding steps: " + << profile_info.llm_decoding_steps << std::endl; outputFile << "token IDs: "; for (int i = 0; i < request.tokens.size(); i++) { outputFile << request.tokens[i]; @@ -447,6 +455,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, old_bc.requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; + num_active_req++; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; if (new_bc.requestsInfo[i].first_token_depth_in_request + 1 == request.tokens.size()) { // Incremental phase @@ -469,7 +479,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, } // Update profiling profiling_requests[new_bc.requestsInfo[i].request_guid] - .decoding_steps++; + .llm_decoding_steps++; } } } @@ -483,6 +493,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, Request new_request = pending_request_queue.front(); pending_request_queue.pop(); // all_requests[new_request.guid] = new_request; + new_bc.requestsInfo[i].first_token_depth_in_request = 0; new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens; new_bc.requestsInfo[i].request_guid = new_request.guid; @@ 
-492,9 +503,11 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, new_bc.requestsInfo[i].max_sequence_length = new_request.max_sequence_length; new_bc.request_completed[i] = false; + num_active_req++; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // add profile_info for the new request ProfileInfo profile_info; - profile_info.decoding_steps = 1; + profile_info.llm_decoding_steps = 1; profile_info.start_time = Realm::Clock::current_time_in_microseconds(); profiling_requests[new_request.guid] = profile_info; for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { @@ -567,6 +580,7 @@ BeamSearchBatchConfig int result_index = 0; int num_generation_tokens = 0; + int num_active_req = -1; for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) { if (old_bc.request_completed[i]) { @@ -602,6 +616,8 @@ BeamSearchBatchConfig committed_tokens[guid].emplace_back(abs_depth, result_index); } else if (abs_depth >= root_abs_depth) { tree_outputs.emplace_back(token_id, abs_depth + 1); + // std::cout << "committred tokens push: " << abs_depth + // << " ,result index: " << result_index << "\n"; committed_tokens[guid].emplace_back(abs_depth, result_index); if (verbose) { @@ -612,22 +628,23 @@ BeamSearchBatchConfig tree_outputs.back().second, token_id); } - std::cout << "Index within old batch: " << result_index << std::endl; - printf(" Input: [%d] %d ---> [%d] %d \n", - abs_depth, - old_bc.tokensInfo[result_index].token_id, - tree_outputs.back().second, - token_id); + // std::cout << "Index within old batch: " << result_index << std::endl; + // printf(" Input: [%d] %d ---> [%d] %d \n", + // abs_depth, + // old_bc.tokensInfo[result_index].token_id, + // tree_outputs.back().second, + // token_id); } result_index++; } if (request.status == Request::RUNNING) { + std::vector> verified_tokens = traverse_verify_tree(guid, dfs_tree_inputs.at(guid), tree_outputs); + log_req_mgr.print("Number of Verified Tokens = %zu", verified_tokens.size()); - // check if the request is finished if (verified_tokens.size() + request.tokens.size() >= request.max_sequence_length) { @@ -664,16 +681,18 @@ BeamSearchBatchConfig // Log profiling info ProfileInfo profile_info = profiling_requests[request.guid]; profile_info.finish_time = Realm::Clock::current_time_in_microseconds(); + profile_info.ssm_decoding_steps = 0; total_request_run_time += profile_info.finish_time - profile_info.start_time; profiling_requests[request.guid] = profile_info; - log_req_mgr.print("[Profile] guid(%zu) decoding_steps(%d) start(%.1lf) " - "finish(%.1lf) latency(%.1lf)", - request.guid, - profile_info.decoding_steps, - profile_info.start_time, - profile_info.finish_time, - profile_info.finish_time - profile_info.start_time); + log_req_mgr.print( + "[Profile] guid(%zu) llm_decoding_steps(%d) start(%.1lf) " + "finish(%.1lf) latency(%.1lf)", + request.guid, + profile_info.llm_decoding_steps, + profile_info.start_time, + profile_info.finish_time, + profile_info.finish_time - profile_info.start_time); // Write output to file if needed: if (!output_filepath.empty()) { @@ -682,8 +701,8 @@ BeamSearchBatchConfig outputFile << "end-to-end latency: " << std::fixed << std::setprecision(3) << total_request_run_time << std::endl; - outputFile << "num decoding steps: " << profile_info.decoding_steps - << std::endl; + outputFile << "num decoding steps: " + << profile_info.llm_decoding_steps << std::endl; outputFile << "token IDs: "; for (int i = 0; i < request.tokens.size(); i++) { outputFile << 
request.tokens[i]; @@ -709,6 +728,7 @@ BeamSearchBatchConfig new_bc.request_completed[i] = false; new_bc.request_running[i] = true; + num_active_req++; // Normal Request Info new_bc.requestsInfo[i].first_token_depth_in_request = @@ -719,6 +739,7 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; new_bc.requestsInfo[i].num_tokens_in_batch = verified_tokens.size(); + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // TODO: Beam Request Info, missing from VerifyTreeBatchConfig int new_max_depth = @@ -726,8 +747,14 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].first_token_depth_in_request - verified_tokens.size(); new_bc.beamRequestsInfo[i].current_depth = 1; + + profiling_requests[request.guid].ssm_decoding_steps = 0; + + int ssm_decoding_steps = 0; new_bc.beamRequestsInfo[i].beam_size = - BeamSearchBatchConfig::MAX_BEAM_WIDTH; + spec_infer_tree_width.size() > ssm_decoding_steps + ? spec_infer_tree_width[ssm_decoding_steps] + : 1; new_bc.beamRequestsInfo[i].max_depth = std::min(new_max_depth, BeamSearchBatchConfig::MAX_BEAM_DEPTH); for (int j = 0; j < BeamSearchBatchConfig::MAX_BEAM_WIDTH; j++) { @@ -735,8 +762,14 @@ BeamSearchBatchConfig new_bc.beamRequestsInfo[i].probs[j] = 1; } + new_bc.beamRequestsInfo[i].sub_request_num = 1; + new_bc.sub_requests[i] = 1; + updateBitMask(new_bc.causalMask[i], + verified_tokens.size(), + request.tokens.size()); + // Token Info for (int j = 0; j < verified_tokens.size(); j++) { auto token = verified_tokens.at(j); @@ -758,6 +791,7 @@ BeamSearchBatchConfig break; } } + std::string output = this->tokenizer_->Decode(request.tokens); // Unlike Huggingface, the sentencepiece C++ library automatically // removes the BOS token @@ -767,9 +801,11 @@ BeamSearchBatchConfig } log_req_mgr.print("Output: %s", output.c_str()); } + } else if (request.status == Request::PENDING) { new_bc.request_completed[i] = false; new_bc.request_running[i] = false; + num_active_req++; std::cout << "ssm_cache_size: " << request.ssm_cache_size << ", " << "initial_len: " << request.initial_len << std::endl; @@ -783,17 +819,24 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; new_bc.requestsInfo[i].num_tokens_in_batch = 0; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // TODO: Beam Request Info, missing from VerifyTreeBatchConfig new_bc.beamRequestsInfo[i].current_depth = 1; + int ssm_decoding_steps = + profiling_requests[request.guid].ssm_decoding_steps; new_bc.beamRequestsInfo[i].beam_size = - BeamSearchBatchConfig::MAX_BEAM_WIDTH; + spec_infer_tree_width.size() > ssm_decoding_steps + ? 
spec_infer_tree_width[ssm_decoding_steps] + : 1; new_bc.beamRequestsInfo[i].max_depth = 0; for (int j = 0; j < BeamSearchBatchConfig::MAX_BEAM_WIDTH; j++) { new_bc.beamRequestsInfo[i].parent_id[j] = 0; new_bc.beamRequestsInfo[i].probs[j] = 1; } + new_bc.beamRequestsInfo[i].sub_request_num = 1; + new_bc.sub_requests[i] = 1; // Token Info @@ -818,6 +861,7 @@ BeamSearchBatchConfig Request new_request = pending_request_queue.front(); pending_request_queue.pop(); // all_requests[new_request.guid] = new_request; + num_active_req++; new_bc.requestsInfo[i].first_token_depth_in_request = 0; new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens; new_bc.requestsInfo[i].request_guid = new_request.guid; @@ -826,15 +870,21 @@ BeamSearchBatchConfig (int)new_request.tokens.size()); new_bc.requestsInfo[i].max_sequence_length = new_request.max_sequence_length; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // add profile_info for the new request ProfileInfo profile_info; - profile_info.decoding_steps = 0; + profile_info.llm_decoding_steps = 0; + profile_info.ssm_decoding_steps = 0; profile_info.start_time = Realm::Clock::current_time_in_microseconds(); profiling_requests[new_request.guid] = profile_info; // init the beam search metadata per request + int ssm_decoding_steps = profile_info.ssm_decoding_steps; + new_bc.beamRequestsInfo[i].beam_size = - BeamSearchBatchConfig::MAX_BEAM_WIDTH; + spec_infer_tree_width.size() > ssm_decoding_steps + ? spec_infer_tree_width[ssm_decoding_steps] + : 1; new_bc.beamRequestsInfo[i].current_depth = 1; new_bc.beamRequestsInfo[i].max_depth = std::min(BeamSearchBatchConfig::MAX_BEAM_DEPTH, @@ -846,6 +896,11 @@ BeamSearchBatchConfig } new_bc.request_completed[i] = false; + + new_bc.beamRequestsInfo[i].sub_request_num = 1; + printf("sub request num == 1, %d \n", + new_bc.beamRequestsInfo[i].beam_size); + new_bc.sub_requests[i] = 1; for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { @@ -862,6 +917,9 @@ BeamSearchBatchConfig new_bc.num_tokens++; } + initBitMask(new_bc.causalMask[i], + new_bc.requestsInfo[i].num_tokens_in_batch); + // if (new_bc.requestsInfo[i].num_tokens_in_batch < // new_request.initial_len) { // all_requests[new_request.guid].status = Request::PENDING; @@ -949,6 +1007,8 @@ BeamSearchBatchConfig } std::cout << "Current Beam Depth: " << old_bc.beamRequestsInfo[0].current_depth << "\n"; + std::cout << "Current sub request num: " + << old_bc.beamRequestsInfo[0].sub_request_num << "\n"; } // Step 1: Store result to the beam tree struct store_beam_metadata(old_bc, result); @@ -960,10 +1020,12 @@ BeamSearchBatchConfig int num_generation_tokens = 0; // Add incremental tokens to the batch + int num_active_req = -1; for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) { if (old_bc.request_completed[i] || !old_bc.request_running[i]) { continue; } + num_active_req++; // Comment out this assertion since num_tokens_in_batch can be // zero when beam search has reached required sequence length // assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0); @@ -973,29 +1035,6 @@ BeamSearchBatchConfig // assert(processed_tokens < request.tokens.size()); log_req_mgr.debug() << "processed_tokens: " << processed_tokens << "\n"; - // if (processed_tokens > - // old_bc.beamRequestsInfo[i].max_depth + request.tokens.size() && - // request.status == Request::RUNNING - // // || ir.results[t] == 0 TODO: replace this with - // ) { - // // log_req_mgr.print("[Done] guid(%zu) with spec_tree_depth(%d)", - // // 
old_bc.requestsInfo[i].request_guid, - // // old_bc.beamRequestsInfo[i].max_depth); - // // // new_bc.request_completed[i] = true; - // // new_bc.request_completed[i] = false; - // // new_bc.requestsInfo[i].first_token_depth_in_request = - // processed_tokens; - // // new_bc.requestsInfo[i].request_guid = - // // old_bc.requestsInfo[i].request_guid; - // // new_bc.requestsInfo[i].max_sequence_length = - // // old_bc.requestsInfo[i].max_sequence_length; - // // new_bc.beamRequestsInfo[i].current_depth = - // // old_bc.beamRequestsInfo[i].current_depth; - // // new_bc.request_running[i] = false; - // std::cout << "beam search end:" << request.status << i << ", " - // << new_bc.requestsInfo[i].num_tokens_in_batch << "\n"; - // } - // else { log_req_mgr.debug() << "num tokens: " << old_bc.num_tokens << ", " << new_bc.num_tokens; @@ -1005,25 +1044,42 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; - + profiling_requests[request.guid].ssm_decoding_steps += 1; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // update the beam search metadata // how many sub request in current request // why is sub_requests has max_requests_per_batch() * MAX_BEAM_WIDTH // entries? - new_bc.sub_requests[i] = old_bc.beamRequestsInfo[i].beam_size; - // update the parentid, accumalated_probs, depth, and token_ids + int ssm_decoding_steps = + profiling_requests[request.guid].ssm_decoding_steps; + new_bc.beamRequestsInfo[i].beam_size = - old_bc.beamRequestsInfo[i].beam_size; + spec_infer_tree_width.size() > ssm_decoding_steps + ? spec_infer_tree_width[ssm_decoding_steps] + : 1; + new_bc.beamRequestsInfo[i].max_depth = old_bc.beamRequestsInfo[i].max_depth; + new_bc.sub_requests[i] = + old_bc.sub_requests[i] * new_bc.beamRequestsInfo[i].beam_size; + new_bc.beamRequestsInfo[i].sub_request_num = + old_bc.beamRequestsInfo[i].sub_request_num * + old_bc.beamRequestsInfo[i].beam_size; + + assert(new_bc.beamRequestsInfo[i].sub_request_num <= + BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES && + "exceed maximum nodes per layer"); + if (request.status == Request::RUNNING) { new_bc.beamRequestsInfo[i].current_depth = old_bc.beamRequestsInfo[i].current_depth + 1; new_bc.request_running[i] = true; // do the slot exchange to minimize the cache exchange in kernel. 
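        // update_beam_metadata now also takes old_bc; with per-layer node
        // counts stored in the beam tree, the request metadata is copied
        // directly from the stored tree layer rather than performing the
        // old cache-stealing slot exchange.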
- update_beam_metadata(new_bc, request.beam_trees.at(old_bc.model_id), i); + update_beam_metadata( + new_bc, old_bc, request.beam_trees.at(old_bc.model_id), i); + } else { assert(false && "Request should not be pending in beam search phase"); } @@ -1035,6 +1091,7 @@ BeamSearchBatchConfig request.tokens.size()) { // Incremental phase if (request.status == Request::RUNNING) { + // todo this is replaced by this_layer_size, but should check it new_bc.requestsInfo[i].num_tokens_in_batch = 1; } else { assert(false && "Request should be done"); @@ -1057,9 +1114,22 @@ BeamSearchBatchConfig } // register more tokens due to the beam width + + // copy metadata + memcpy(&new_bc.causalMask[i], + &old_bc.causalMask[i], + sizeof(BatchConfig::BitMask)); + BeamTree tree = request.beam_trees[old_bc.model_id]; + appendBitMask(new_bc.causalMask[i], + new_bc.beamRequestsInfo[i].sub_request_num, + old_bc.beamRequestsInfo[i].beam_size, + old_bc.beamRequestsInfo[i].sub_request_num, + tree, + old_bc.beamRequestsInfo[i].current_depth); + // assert(false); for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j; - for (int k = 0; k < new_bc.sub_requests[i]; k++) { + for (int k = 0; k < new_bc.beamRequestsInfo[i].sub_request_num; k++) { new_bc.tokensInfo[new_bc.num_tokens].request_index = i; new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth; @@ -1069,6 +1139,8 @@ BeamSearchBatchConfig new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = k; new_bc.num_tokens++; + + num_generation_tokens++; } } } @@ -1079,6 +1151,7 @@ BeamSearchBatchConfig if (old_bc.request_completed[i] || old_bc.request_running[i]) { continue; } + num_active_req++; // Comment out this assertion since num_tokens_in_batch can be // zero when beam search has reached required sequence length // assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0); @@ -1098,18 +1171,34 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // update the beam search metadata // how many sub request in current request // why is sub_requests has max_requests_per_batch() * MAX_BEAM_WIDTH // entries? - new_bc.sub_requests[i] = old_bc.beamRequestsInfo[i].beam_size; + int ssm_decoding_steps = + profiling_requests[request.guid].ssm_decoding_steps; - // update the parentid, accumalated_probs, depth, and token_ids new_bc.beamRequestsInfo[i].beam_size = - old_bc.beamRequestsInfo[i].beam_size; + spec_infer_tree_width.size() > ssm_decoding_steps + ? 
spec_infer_tree_width[ssm_decoding_steps] + : 1; + printf("beam size: %d, %d\n", + new_bc.beamRequestsInfo[i].beam_size, + ssm_decoding_steps); new_bc.beamRequestsInfo[i].max_depth = old_bc.beamRequestsInfo[i].max_depth; + new_bc.sub_requests[i] = + old_bc.sub_requests[i] * new_bc.beamRequestsInfo[i].beam_size; + new_bc.beamRequestsInfo[i].sub_request_num = + old_bc.beamRequestsInfo[i].sub_request_num; + + assert(new_bc.beamRequestsInfo[i].sub_request_num <= + BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES && + "exceed maximum nodes per layer"); + + // update the parentid, accumalated_probs, depth, and token_ids if (request.status == Request::PENDING) { // if the request is pending, we need to update the beam search @@ -1121,6 +1210,10 @@ BeamSearchBatchConfig assert(false && "Request should be pending"); } + memcpy(&new_bc.causalMask[i], + &old_bc.causalMask[i], + sizeof(BatchConfig::BitMask)); + if (new_bc.requestsInfo[i].first_token_depth_in_request >= request.tokens.size()) { // request is done @@ -1133,6 +1226,13 @@ BeamSearchBatchConfig (int)request.tokens.size() - new_bc.requestsInfo[i].first_token_depth_in_request); request.ssm_cache_size += new_bc.requestsInfo[i].num_tokens_in_batch; + BeamTree tree = request.beam_trees[old_bc.model_id]; + appendBitMask(new_bc.causalMask[i], + new_bc.beamRequestsInfo[i].sub_request_num, + old_bc.beamRequestsInfo[i].beam_size, + old_bc.beamRequestsInfo[i].sub_request_num, + tree, + old_bc.beamRequestsInfo[i].current_depth); } if (verbose) { @@ -1152,7 +1252,7 @@ BeamSearchBatchConfig // register more tokens due to the beam width for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j; - for (int k = 0; k < new_bc.sub_requests[i]; k++) { + for (int k = 0; k < new_bc.beamRequestsInfo[i].sub_request_num; k++) { new_bc.tokensInfo[new_bc.num_tokens].request_index = i; new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth; @@ -1229,21 +1329,20 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( max_prompt_load_size -= 1; } } - + int num_active_req = -1; for (int i = 0; i < TreeVerifyBatchConfig::max_requests_per_batch(); i++) { if (old_batches.at(0).request_completed[i]) { continue; } + num_active_req++; size_t guid = old_batches.at(0).requestsInfo[i].request_guid; Request &request = all_requests[guid]; // Profiling - profiling_requests[request.guid].decoding_steps += 1; + profiling_requests[request.guid].llm_decoding_steps += 1; if (request.status == Request::RUNNING) { new_bc.request_running[i] = true; - std::cout << "[Verify] Request " << request.guid << " is running" - << std::endl; // Get the dfs tree std::vector>> @@ -1274,31 +1373,44 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( old_batches.at(0).requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = old_batches.at(0).requestsInfo[i].max_sequence_length; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; + + // copy bitmask to verify batchconfig + memcpy(&(new_bc.causalMask[i]), + &(old_batches.at(0).causalMask[i]), + sizeof(BatchConfig::BitMask)); // TODO: Check this new_bc.requestsInfo[i].num_tokens_in_batch = 0; new_bc.request_completed[i] = false; + // std::cout << "dfs_tree_inputs: " << dfs_tree_inputs.size() << ", " + // << new_bc.causalMask[i].tree_size << ", " + // << new_bc.causalMask[i].non_tree_cache_size << "\n"; + // std::cout << "mask: " << std::bitset<64>(new_bc.causalMask[i].mask[0]) + // << "\n"; + // 
Committed Tokens if (committed_tokens.find(guid) != committed_tokens.end()) { - for (int j = 0; j < dfs_tree_inputs.size(); j++) { - if (j < committed_tokens.at(guid).size()) { - auto committed_token = committed_tokens.at(guid).at(j); - new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_index = - committed_token.second; - new_bc.committed_tokens[new_bc.num_tokens_to_commit].request_index = - i; - new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_depth = - committed_token.first; - if (verbose) { - std::cout << new_bc.num_tokens_to_commit - << "- committed_token.token_depth: " - << committed_token.first - << ", token_index: " << committed_token.second - << std::endl; - } - new_bc.num_tokens_to_commit++; - request.llm_cache_size++; + for (int j = 0; j < committed_tokens.at(guid).size(); j++) { + // if (j < committed_tokens.at(guid).size()) { + + auto committed_token = committed_tokens.at(guid).at(j); + new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_index = + committed_token.second; + new_bc.committed_tokens[new_bc.num_tokens_to_commit].request_index = + i; + new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_depth = + committed_token.first; + if (verbose) { + std::cout << new_bc.num_tokens_to_commit + << "- committed_token.token_depth: " + << committed_token.first + << ", token_index: " << committed_token.second + << std::endl; } + new_bc.num_tokens_to_commit++; + request.llm_cache_size++; + // } } } if (verbose) { @@ -1324,6 +1436,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.requestsInfo[i].first_token_depth_in_request = request.tokens.size() - 1; + bool cutLayer = false; // Add Tokens from the DFS Tree to the next batch for (int j = 1; j < dfs_tree_inputs.size(); j++) { auto token = dfs_tree_inputs.at(j); @@ -1340,11 +1453,27 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.num_tokens++; new_bc.requestsInfo[i].num_tokens_in_batch++; - if (new_bc.num_tokens == get_max_tokens_per_batch() - 1) { + if (new_bc.num_tokens == get_max_tokens_per_batch() && + (j != dfs_tree_inputs.size() - 1)) { + cutLayer = true; break; } } + // delete the last incomplete layer + if (cutLayer) { + int total_tokens = new_bc.num_tokens; + for (int j = total_tokens - 1; j >= 1; j--) { + new_bc.num_tokens--; + new_bc.requestsInfo[i].num_tokens_in_batch--; + // std::cout << "cut: " << j << "\n"; + if (new_bc.tokensInfo[j].abs_depth_in_request != + new_bc.tokensInfo[j - 1].abs_depth_in_request) { + break; + } + } + } + } else if (request.status == Request::PENDING) { new_bc.request_running[i] = false; if (verbose) { @@ -1374,6 +1503,10 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( << new_bc.num_tokens_to_commit << std::endl; } + memcpy(&(new_bc.causalMask[i]), + &(old_batches.at(0).causalMask[i]), + sizeof(BatchConfig::BitMask)); + // Normal Request Info new_bc.requestsInfo[i].first_token_depth_in_request = request.llm_cache_size; @@ -1382,6 +1515,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( old_batches.at(0).requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = old_batches.at(0).requestsInfo[i].max_sequence_length; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; new_bc.request_completed[i] = false; @@ -1395,6 +1529,9 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( << std::endl; if (request.llm_cache_size < request.initial_len) { + // std::cout << "Initialization (prompt) phase: " + // << 
new_bc.requestsInfo[i].num_tokens_in_batch << ", " + // << old_batches.at(0).beamRequestsInfo[i].beam_size << "\n"; // Initialization (prompt) phase for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { new_bc.tokensInfo[new_bc.num_tokens].request_index = i; @@ -1402,7 +1539,6 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( request.tokens[request.llm_cache_size + j]; new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = request.llm_cache_size + j; - new_bc.num_tokens++; } @@ -1428,6 +1564,8 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( } } else { // launch the request into running phase after loading all prompt if (get_max_tokens_per_batch() - new_bc.num_tokens > 0) { + // std::cout << "Initialization running phase: " + // << new_bc.requestsInfo[i].num_tokens_in_batch << "\n"; request.status = Request::RUNNING; new_bc.request_running[i] = true; @@ -1476,26 +1614,41 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, old_bc.requestsInfo[old_bc.tokensInfo[i].request_index].request_guid != guid) { + // std::cout << "i is: " << i << "old guid" << guid << " new guid" + // << old_bc.requestsInfo[old_bc.tokensInfo[i].request_index] + // .request_guid + // << "\n"; + int index = old_bc.tokensInfo[i - 1].request_index; int beam_size = old_bc.beamRequestsInfo[index].beam_size; + + // int leaf_node_num = old_bc.sub_requests[index]; + int leaf_node_num = + old_bc.beamRequestsInfo[index].sub_request_num * beam_size; int depth = old_bc.beamRequestsInfo[index].current_depth; // Each token yields (beam_width) results - int beam_width = old_bc.beamRequestsInfo[index].beam_size; + // int beam_width = old_bc.beamRequestsInfo[index].beam_size; // Count tokens sent to model in this request to find the final token's // index result_index += (old_bc.tokensInfo[i - 1].abs_depth_in_request - start_depth) * - beam_width; + beam_size; if (verbose) { std::cout << "i = " << i << ", result index = " << result_index - << ", value: " << result.token_ids[result_index] << "\n"; + << ", value: " << result.token_ids[result_index] + << ", leaf node num: " << leaf_node_num << ", depth" << depth + << ", beam size: " << beam_size << "\n"; } Request &request = all_requests[old_bc.requestsInfo[index].request_guid]; + if (old_bc.requestsInfo[index].num_tokens_in_batch == 0) { + continue; + } + if (depth == 1) { // store the last input into the tree; if (verbose) { @@ -1507,14 +1660,20 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, request.tokens.back(); request.beam_trees.at(old_bc.model_id).treeLayers[0].probs[0] = 1; request.beam_trees.at(old_bc.model_id).treeLayers[0].parent_ids[0] = -1; + request.beam_trees.at(old_bc.model_id) + .treeLayers[0] + .nodes_num_this_layer = 1; if (verbose) { std::cout << "Store the previous last token to the tree root: " << request.tokens.back() << "\n"; } } + request.beam_trees.at(old_bc.model_id) + .treeLayers[depth] + .nodes_num_this_layer = leaf_node_num; + for (int beam_id = 0; beam_id < leaf_node_num; beam_id++) { - for (int beam_id = 0; beam_id < beam_width; beam_id++) { request.beam_trees.at(old_bc.model_id) .treeLayers[depth] .tokens[beam_id] = result.token_ids[result_index]; @@ -1534,10 +1693,10 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, } result_index += 1; } - // update the guid and start_depth for current request if (i < old_bc.num_tokens) { - guid = old_bc.requestsInfo[index].request_guid; + int new_req_idx = 
old_bc.tokensInfo[i].request_index; + guid = old_bc.requestsInfo[new_req_idx].request_guid; start_depth = old_bc.tokensInfo[i].abs_depth_in_request; } } @@ -1546,6 +1705,7 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, // for updating the beam search metadata in requests in incremental phase void RequestManager::update_beam_metadata(BeamSearchBatchConfig &new_bc, + BeamSearchBatchConfig const &old_bc, BeamTree &tree, int request_index) { @@ -1556,6 +1716,9 @@ void RequestManager::update_beam_metadata(BeamSearchBatchConfig &new_bc, int depth = new_bc.beamRequestsInfo[request_index].current_depth - 1; int beam_size = new_bc.beamRequestsInfo[request_index].beam_size; + // int leaf_node_num = old_bc.sub_requests[request_index]; + int leaf_node_num = new_bc.beamRequestsInfo[request_index].sub_request_num; + if (new_bc.beamRequestsInfo[request_index].current_depth == 1) { // TODO: check if this is correct // for (int j = 0; j < beam_size; j++) { @@ -1568,48 +1731,15 @@ void RequestManager::update_beam_metadata(BeamSearchBatchConfig &new_bc, // Do nothing // assert(false); } else { - std::set parents; - std::set childs; - // cache stealing - for (int j = 0; j < beam_size; j++) { - int parent_id = tree.treeLayers[depth].parent_ids[j]; - if (childs.find(parent_id) == childs.end()) { - // copy beam slot - new_bc.beamRequestsInfo[request_index].parent_id[parent_id] = - tree.treeLayers[depth].parent_ids[j]; - new_bc.beamRequestsInfo[request_index].probs[parent_id] = - tree.treeLayers[depth].probs[j]; - new_bc.beamRequestsInfo[request_index].tokens[parent_id] = - tree.treeLayers[depth].tokens[j]; - parents.emplace(j); - childs.emplace(parent_id); - } - } - if (parents.size() < beam_size) { - for (int j = 0; j < beam_size; j++) { - if (parents.find(j) == parents.end()) { - // this slot has not been assigned - // find the smallest not assigned child and put in - if (verbose) { - std::cout << "request_index" << request_index - << ", miss slot: " << j << "\n"; - } - for (int k = 0; k < beam_size; k++) { - if (childs.find(k) == childs.end()) { - // parent -> j to child k; - new_bc.beamRequestsInfo[request_index].parent_id[k] = - tree.treeLayers[depth].parent_ids[j]; - new_bc.beamRequestsInfo[request_index].probs[k] = - tree.treeLayers[depth].probs[j]; - new_bc.beamRequestsInfo[request_index].tokens[k] = - tree.treeLayers[depth].tokens[j]; - parents.emplace(j); - childs.emplace(k); - break; - } - } - } - } + for (int j = 0; j < leaf_node_num; j++) { + new_bc.beamRequestsInfo[request_index].parent_id[j] = + tree.treeLayers[depth].parent_ids[j]; + new_bc.beamRequestsInfo[request_index].probs[j] = + tree.treeLayers[depth].probs[j]; + new_bc.beamRequestsInfo[request_index].tokens[j] = + tree.treeLayers[depth].tokens[j]; + // std::cout << "token: " << j << ": " + // << new_bc.beamRequestsInfo[request_index].tokens[j] << "\n"; } } if (verbose) { @@ -1625,6 +1755,139 @@ void RequestManager::update_beam_metadata(BeamSearchBatchConfig &new_bc, } } +// bit mask related function + +// prompt phase, init task +void RequestManager::initBitMask(BatchConfig::BitMask &bitmask, + int initLength) { + assert(initLength <= BatchConfig::MAX_SPEC_TREE_TOKEN_NUM && + "do not support tree size > 64"); + // eg. 
4 tokens: t1: 0000000..1111, t2: 0000000..1110, t3: 0000000..1100, t4: + // 0000000..1000 + bitmask.non_tree_cache_size = 0; + bitmask.tree_size = initLength; + + bitmask.prompt_size = initLength; + bitmask.this_layer_size = initLength; + for (int i = 0; i < bitmask.prompt_size; i++) { + for (int j = i; j < bitmask.prompt_size; j++) { + bitmask.mask[i] |= (1 << j); + } + } + // std::cout << "see bit mask" << bitmask.prompt_size << "\n"; + // std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[0]) << "\n"; + // std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[1]) << "\n"; + // std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[2]) << "\n"; +} + +// prepare next init +void RequestManager::updateBitMask(BatchConfig::BitMask &bitmask, + int initLength, + int non_tree_size) { + // assert(initLength == 1); + // eg. 4 tokens: t1: 0000000..1111, t2: 0000000..1110, t3: 0000000..1100, t4: + // 0000000..1000 + assert(initLength <= BatchConfig::MAX_SPEC_TREE_TOKEN_NUM && + "do not support tree size > 64"); + assert(initLength >= 1 && "verified token num should >= 1"); + + // std::cout << "non tree size: " << non_tree_size << ", " + // << bitmask.non_tree_cache_size << "\n"; + + bitmask.non_tree_cache_size = non_tree_size + initLength - 1; + bitmask.tree_size = 1; + bitmask.this_layer_size = initLength; + // std::cout << "non_tree_size: " << non_tree_size << "\n"; + bitmask.prompt_size = 1; + for (int i = 0; i < bitmask.prompt_size; i++) { + for (int j = i; j < bitmask.prompt_size; j++) { + bitmask.mask[i] |= (1 << j); + } + } + + // std::cout << "see bit mask update" << bitmask.prompt_size << "\n"; + // std::cout << "see bit mask update" << std::bitset<64>(bitmask.mask[0]) + // << "\n"; +} + +// prepare next beam, append layers to the tree +void RequestManager::appendBitMask(BatchConfig::BitMask &bitmask, + int newNodes, + int preBeamSize, + int old_sub_num, + BeamTree const tree, + int currentDepth) { + int pre_tree_size = bitmask.tree_size; + bitmask.tree_size += newNodes; + bitmask.this_layer_size = newNodes; + assert(bitmask.tree_size <= BatchConfig::MAX_SPEC_TREE_TOKEN_NUM && + "do not support tree size > 64"); + // preBeamSize: replicate num + + // add relationship with input/prompt + for (int i = 0; i < bitmask.prompt_size; i++) { + for (int j = pre_tree_size; j < bitmask.tree_size; j++) { + bitmask.mask[i] |= (1 << j); + // std::cout << "see bit mask append: " << i << ", to" << j + // << std::bitset<64>(bitmask.mask[i]) << "\n"; + } + } + + // std::cout << "bitmask.tree_size: " << bitmask.tree_size << ", " + // << pre_tree_size << ", " << bitmask.prompt_size << ", " + // << preBeamSize << "\n"; + + // int num_groups = newNodes / preBeamSize; + // int group_size = newNodes / num_groups; + // add relations to branch + // requests in same groups share same relations, except the last token. 
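+  // In effect (up to the grouping approximation noted above), bit j of
+  // mask[i] ends up set whenever tree position i is an ancestor of, or equal
+  // to, tree position j, with every prompt token treated as an ancestor of
+  // the whole tree.
+  // Illustrative trace (values follow from the loops in initBitMask and
+  // appendBitMask; currentDepth == 1 is assumed here for the example): a
+  // 3-token prompt set up by initBitMask(bitmask, 3), then 2 speculative
+  // nodes appended by this function:
+  //   after initBitMask : mask[0]=..00111 mask[1]=..00110 mask[2]=..00100
+  //   after appendBitMask: mask[0]=..11111 mask[1]=..11110 mask[2]=..11100
+  //                        mask[3]=..01000 mask[4]=..10000
+  // i.e. both new nodes see the whole prompt but not each other.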
+ + // set middle layers + // skip the root prompt/tokens + int token_idx = bitmask.prompt_size; + int new_nodes_start_idx = pre_tree_size; + // std::cout << "new nodes start " << new_nodes_start_idx << "\n"; + for (int i = 1; i < currentDepth; i++) { + new_nodes_start_idx = pre_tree_size; + int nodes_this_layer = tree.treeLayers[i].nodes_num_this_layer; + // std::cout << "tree layer: " << i << " nodes:" << nodes_this_layer + // << "group size: " << newNodes / nodes_this_layer << "\n"; + for (int j = 0; j < nodes_this_layer; j++) { + int group_size = newNodes / nodes_this_layer; + for (int k = 0; k < group_size; k++) { + bitmask.mask[token_idx] |= (1 << new_nodes_start_idx); + new_nodes_start_idx += 1; + } + token_idx += 1; + } + } + + // std::cout << "token idx: " << token_idx << ", " << pre_tree_size << ", " + // << new_nodes_start_idx << ", " << newNodes + // << "current depth: " << currentDepth << "\n"; + // std::cout << "new nodes end " << new_nodes_start_idx << "\n"; + + // std::cout << "tree size: " << bitmask.tree_size << "\n"; + assert(token_idx == pre_tree_size); + assert(currentDepth <= 1 || new_nodes_start_idx == bitmask.tree_size); + + // assert(currentDepth <= 2); + // set last layer, all tokens are only relevant to it self; + for (int i = token_idx; i < bitmask.tree_size; i++) { + bitmask.mask[i] |= (1 << i); + // std::cout << "set rel: " << i << "to: " << i << "\n"; + } + + // if(bitmask.non_tree_cache_size == 19 && bitmask.tree_size > 2){ + // assert(false); + // } + + // std::cout << "see bit mask append" << bitmask.prompt_size << "\n"; + // std::cout << "see bit mask append" << bitmask.non_tree_cache_size << "\n"; + // std::cout << "see bit mask append" << std::bitset<64>(bitmask.mask[0]) + // << "\n"; +} + bool PreOrder( BeamTree const &tree, int max_depth, @@ -1740,12 +2003,43 @@ std::vector> // In this case the inputSeriedTree ends with padding 0s assert(inputSerializedTree.size() >= outputSerializedTree.size()); + int *treeLayers = new int[inputSerializedTree.size()]; + int node_num = 1; + int layer_num = 0; + for (int token_id = 0; token_id < inputSerializedTree.size(); token_id++) { + if (token_id == (inputSerializedTree.size() - 1) || + inputSerializedTree.at(token_id + 1).second != + inputSerializedTree.at(token_id).second) { + treeLayers[layer_num] = node_num; + layer_num += 1; + node_num = 1; + } else { + node_num++; + } + } + + // to avoid branch switch when same tokens in input tree. + // todo, only checked for N->1->1->1 cases + + bool findFirst = false; + layer_num = -1; + int first_layer_slot = 0; + int first_layer_slot_total = 0; + int processed_whole_layer_tokens = 0; + for (int i = 0; i < outputSerializedTree.size(); i++) { auto input = inputSerializedTree.at(i); auto output = outputSerializedTree.at(i); + if (i == 0 || inputSerializedTree.at(i - 1).second != + inputSerializedTree.at(i).second) { + layer_num += 1; + processed_whole_layer_tokens += i == 0 ? 0 : treeLayers[layer_num - 1]; + } + if (i == 0) { verifiedTree.push_back(output); + new_committed_tokens.push_back(std::make_pair( input.second, committed_tokens.at(guid).at(i).second)); // > if (input.first == verifiedTree.back().first && input.second == verifiedTree.back().second) { - verifiedTree.push_back(output); - new_committed_tokens.push_back(std::make_pair( - input.second, - committed_tokens.at(guid).at(i).second)); // + if (findFirst) { + // must in this branch. 
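+        // Once the first matching token after the root has fixed
+        // first_layer_slot, later layers only accept the candidate sitting
+        // in the same slot of its own layer (layer_slot below); this pins
+        // the verified path to a single branch even when sibling branches
+        // carry identical token ids. Per the todo above, this has only been
+        // checked for N->1->1->1 shaped trees.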
+ int layer_slot = i - processed_whole_layer_tokens; + int layer_slot_total = treeLayers[layer_num]; + if ((first_layer_slot == layer_slot)) { + verifiedTree.push_back(output); + new_committed_tokens.push_back(std::make_pair( + input.second, committed_tokens.at(guid).at(i).second)); + // at this point, you'll not go other branches + // std::cout << "verify tree push back: " << output.first + // << ", tree size is: " << verifiedTree.size() + // << ", ??: " << input.first << ", " << input.second << + // "\n"; + + } else { + printf("not correct slot\n"); + } + } else { + verifiedTree.push_back(output); + first_layer_slot = i - processed_whole_layer_tokens; + first_layer_slot_total = treeLayers[layer_num]; + findFirst = true; + new_committed_tokens.push_back(std::make_pair( + input.second, + committed_tokens.at(guid).at(i).second)); // + // at this point, you'll not go other branches + // std::cout << "verify tree push back: " << output.first + // << ", tree size is: " << verifiedTree.size() + // << ", ??: " << input.first << ", " << input.second << "\n"; + } + assert(committed_tokens.at(guid).at(i).first == input.second); } } @@ -1804,6 +2125,8 @@ std::vector> << old_bc.beamRequestsInfo[request_index].current_depth << "\n"; std::cout << "[Traverse Beam Tree] beam_width: " << old_bc.beamRequestsInfo[request_index].beam_size << "\n"; + std::cout << "[Traverse Beam Tree] start index: " + << first_token_depth_in_request << "\n"; } auto guid = old_bc.requestsInfo[request_index].request_guid; @@ -1811,18 +2134,30 @@ std::vector> // std::cout << "request.beam_trees.size(): " << request.beam_trees.size() // << std::endl; BeamTree tree = request.beam_trees.at(old_bc.model_id); - // std::cout << "\n\n"; + // std::cout << "print beam tree: " + // << "\n"; + std::vector> serializedTree; + for (int i = 0; i <= old_bc.beamRequestsInfo[request_index].max_depth; i++) { + // std::cout << "tree layer: " << i + // << ", num_nodes: " << tree.treeLayers[i].nodes_num_this_layer + // << "\n"; + // push tokens into tree + for (int j = 0; j < tree.treeLayers[i].nodes_num_this_layer; j++) { + // std::cout << "token: " << tree.treeLayers[i].tokens[j] << "\n"; + serializedTree.push_back(std::make_pair(tree.treeLayers[i].tokens[j], i)); + } + } // token, index // todo make this one global for different stages - std::vector> serializedTree; - PreOrder(tree, - old_bc.beamRequestsInfo[request_index].max_depth, - 0, - old_bc.beamRequestsInfo[request_index].beam_size, - 0, - serializedTree, - verbose); + + // PreOrder(tree, + // old_bc.beamRequestsInfo[request_index].max_depth, + // 0, + // old_bc.beamRequestsInfo[request_index].beam_size, + // 0, + // serializedTree, + // verbose); // print it if (verbose) { @@ -1857,6 +2192,10 @@ std::vector> input_trees, int root_depth, RequestGuid guid) { + assert(input_trees.size() == 1 && "currently using one ssm"); + dfs_tree_inputs[guid] = input_trees.at(0); + return input_trees.at(0); + std::vector> merged_tree; std::unordered_map> childrens; diff --git a/src/runtime/request_manager.cpp b/src/runtime/request_manager.cpp index 1e756606f8..fadbf80d6d 100644 --- a/src/runtime/request_manager.cpp +++ b/src/runtime/request_manager.cpp @@ -58,6 +58,91 @@ void RequestManager::load_tokens_task( stream)); } +void RequestManager::load_batch_config_task( + Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 0); + assert(task->regions.size() == 0); + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + + // BatchConfig const 
batch_config = *((BatchConfig *)task->args); + BatchConfig const *batch_config = BatchConfig::from_future(task->futures[0]); + + // copy meta data to workSpace + FFHandler handle = *((FFHandler const *)task->local_args); + size_t total_copy_size = 0; + checkCUDA(hipMemcpyAsync(handle.batch_config_metadata, + &(batch_config->tokensInfo), + sizeof(BatchConfig::tokensInfo), + hipMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(BatchConfig::tokensInfo); + + checkCUDA(hipMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(batch_config->requestsInfo), + sizeof(BatchConfig::requestsInfo), + hipMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(BatchConfig::requestsInfo); + + // load speculative metadata + if (batch_config->get_mode() == BEAM_SEARCH_MODE) { + BeamSearchBatchConfig const *beam_batch_config = + static_cast(batch_config); + + checkCUDA(hipMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(beam_batch_config->beamTokenInfo), + sizeof(BeamSearchBatchConfig::beamTokenInfo), + hipMemcpyHostToDevice, + stream)); + + total_copy_size += sizeof(BeamSearchBatchConfig::beamTokenInfo); + + checkCUDA(hipMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(beam_batch_config->beamRequestsInfo), + sizeof(BeamSearchBatchConfig::beamRequestsInfo), + hipMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(BeamSearchBatchConfig::beamRequestsInfo); + + checkCUDA(hipMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(beam_batch_config->causalMask), + sizeof(BatchConfig::causalMask), + hipMemcpyHostToDevice, + stream)); + + total_copy_size += sizeof(BatchConfig::causalMask); + } else if (batch_config->get_mode() == TREE_VERIFY_MODE) { + TreeVerifyBatchConfig const *tree_batch_config = + static_cast(batch_config); + + checkCUDA(hipMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(tree_batch_config->causalMask), + sizeof(BatchConfig::causalMask), + hipMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(BatchConfig::causalMask); + checkCUDA(hipMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(tree_batch_config->committed_tokens), + sizeof(TreeVerifyBatchConfig::committed_tokens), + hipMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(TreeVerifyBatchConfig::committed_tokens); + } + + // add a size check + assert(total_copy_size <= handle.batch_config_metadata_size); +} + void RequestManager::load_positions_task( Task const *task, std::vector const ®ions, diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index cd3e03fff6..51c52c3026 100644 --- a/src/runtime/request_manager.cu +++ b/src/runtime/request_manager.cu @@ -30,6 +30,7 @@ void RequestManager::load_tokens_task( // BatchConfig const batch_config = *((BatchConfig *)task->args); BatchConfig const *batch_config = BatchConfig::from_future(task->futures[0]); + BatchConfig::TokenId dram_copy[BatchConfig::MAX_NUM_TOKENS]; // Extreme long prompts are not supported, only load up to @@ -57,6 +58,91 @@ void RequestManager::load_tokens_task( stream)); } +void RequestManager::load_batch_config_task( + Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 0); + assert(task->regions.size() == 0); + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + + // BatchConfig const batch_config = *((BatchConfig *)task->args); + BatchConfig const *batch_config = 
BatchConfig::from_future(task->futures[0]); + + // copy meta data to workSpace + FFHandler handle = *((FFHandler const *)task->local_args); + size_t total_copy_size = 0; + checkCUDA(cudaMemcpyAsync(handle.batch_config_metadata, + &(batch_config->tokensInfo), + sizeof(BatchConfig::tokensInfo), + cudaMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(BatchConfig::tokensInfo); + + checkCUDA(cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(batch_config->requestsInfo), + sizeof(BatchConfig::requestsInfo), + cudaMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(BatchConfig::requestsInfo); + + // load speculative metadata + if (batch_config->get_mode() == BEAM_SEARCH_MODE) { + BeamSearchBatchConfig const *beam_batch_config = + static_cast(batch_config); + + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size, + &(beam_batch_config->beamTokenInfo), + sizeof(BeamSearchBatchConfig::beamTokenInfo), + cudaMemcpyHostToDevice, + stream)); + + total_copy_size += sizeof(BeamSearchBatchConfig::beamTokenInfo); + + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size, + &(beam_batch_config->beamRequestsInfo), + sizeof(BeamSearchBatchConfig::beamRequestsInfo), + cudaMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(BeamSearchBatchConfig::beamRequestsInfo); + + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size, + &(beam_batch_config->causalMask), + sizeof(BatchConfig::causalMask), + cudaMemcpyHostToDevice, + stream)); + + total_copy_size += sizeof(BatchConfig::causalMask); + } else if (batch_config->get_mode() == TREE_VERIFY_MODE) { + TreeVerifyBatchConfig const *tree_batch_config = + static_cast(batch_config); + + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size, + &(tree_batch_config->causalMask), + sizeof(BatchConfig::causalMask), + cudaMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(BatchConfig::causalMask); + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size, + &(tree_batch_config->committed_tokens), + sizeof(TreeVerifyBatchConfig::committed_tokens), + cudaMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(TreeVerifyBatchConfig::committed_tokens); + } + + // add a size check + assert(total_copy_size <= handle.batch_config_metadata_size); +} + void RequestManager::load_positions_task( Task const *task, std::vector const ®ions,