
Commit

Merge branch 'inference' into cuda_graph
jiazhihao authored Dec 31, 2023
2 parents 21c00bd + 4957b7c commit 032dc7a
Showing 23 changed files with 1,727 additions and 719 deletions.
29 changes: 28 additions & 1 deletion include/flexflow/batch_config.h
@@ -56,6 +56,7 @@ class BatchConfig {
// across workers
static int const MAX_NUM_REQUESTS = 64;
static int const MAX_NUM_TOKENS = 1024;
static int const MAX_SPEC_TREE_TOKEN_NUM = 64;

// Set by update
int num_tokens;
@@ -68,13 +69,34 @@
int first_token_offset_in_batch;
int num_tokens_in_batch;
int max_sequence_length;

// request id in batch config:
int batch_config_request_id;
RequestGuid request_guid;
};
struct PerTokenInfo {
int abs_depth_in_request;
int request_index;
TokenId token_id;
};

struct BitMask {
unsigned long long mask[MAX_SPEC_TREE_TOKEN_NUM] = {0};

// number of tokens before the tree; every sub-request needs this part of the cache
int non_tree_cache_size;

// current tree size
int tree_size;

int this_layer_size;

// input length-> prompt/root
int prompt_size;
};

BitMask causalMask[MAX_NUM_REQUESTS];
PerRequestInfo requestsInfo[MAX_NUM_REQUESTS];
PerTokenInfo tokensInfo[MAX_NUM_TOKENS];
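
Not part of this commit — a reading of the new BitMask struct: with MAX_SPEC_TREE_TOKEN_NUM = 64, each speculative-tree token appears to own one 64-bit word whose bits can mark which other tree tokens it may attend to, while the non_tree_cache_size tokens that precede the tree are presumably shared by all sub-requests. A minimal sketch under that interpretation (helper names are hypothetical):

    // Sketch only; assumes bit j of mask[i] means "tree token i may attend to tree token j".
    inline void set_visible(BatchConfig::BitMask &bm, int i, int j) {
      bm.mask[i] |= 1ULL << j;
    }
    inline bool is_visible(BatchConfig::BitMask const &bm, int i, int j) {
      return (bm.mask[i] >> j) & 1ULL;
    }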

@@ -126,9 +148,12 @@ class BeamSearchBatchConfig : public BatchConfig {

size_t beam_width;
size_t target_iterations;
inline static int const MAX_BEAM_WIDTH = 1;
inline static int const MAX_BEAM_WIDTH = 3;
inline static int const MAX_BEAM_DEPTH = 8;

// maximum tree branches for a request
inline static int const MAX_SPECULATIVE_TREE_BRANCHES = 3;

int model_id;

struct BeamSearchPerRequestInfo {
@@ -139,6 +164,7 @@ class BeamSearchBatchConfig : public BatchConfig {
BatchConfig::TokenId tokens[BeamSearchBatchConfig::MAX_BEAM_WIDTH];
float probs[BeamSearchBatchConfig::MAX_BEAM_WIDTH];
int parent_id[BeamSearchBatchConfig::MAX_BEAM_WIDTH];
int sub_request_num;
};

struct BeamSearchPerTokenInfo {
@@ -147,6 +173,7 @@

BeamSearchPerRequestInfo beamRequestsInfo[MAX_NUM_REQUESTS];
BeamSearchPerTokenInfo beamTokenInfo[MAX_NUM_TOKENS * MAX_BEAM_WIDTH];

// why is this == MAX_NUM_REQUESTS * MAX_BEAM_WIDTH?
int sub_requests[MAX_NUM_REQUESTS * MAX_BEAM_WIDTH];

11 changes: 11 additions & 0 deletions include/flexflow/config.h
@@ -16,6 +16,7 @@
#ifndef _FLEXFLOW_CONFIG_H_
#define _FLEXFLOW_CONFIG_H_
#include "ffconst.h"
#include "flexflow/batch_config.h"
#include "legion.h"
#include <cstring>
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
@@ -75,6 +76,15 @@ struct FFHandler {
#endif
void *workSpace;
size_t workSpaceSize;
void *batch_config_metadata;

// request info + token info + topology mask info
size_t batch_config_metadata_size =
sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) +
sizeof(BeamSearchBatchConfig::beamTokenInfo) +
sizeof(BeamSearchBatchConfig::beamRequestsInfo) +
sizeof(BatchConfig::causalMask) +
sizeof(TreeVerifyBatchConfig::committed_tokens);
void *offload_reserve_space;
size_t offload_reserve_space_size;
DataType quantization_type;
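
Not part of this commit — a sketch of how the batch_config_metadata buffer could be carved up. Its size is the sum of six fixed-size arrays, which suggests one contiguous allocation holding, in order, token info, request info, beam token info, beam request info, the causal masks, and the committed-token info; the offsets below are assumptions derived from that size expression, not the committed code:

    char *base = static_cast<char *>(handler.batch_config_metadata);
    auto *tokens_info = reinterpret_cast<BatchConfig::PerTokenInfo *>(base);
    base += sizeof(BatchConfig::tokensInfo);
    auto *requests_info = reinterpret_cast<BatchConfig::PerRequestInfo *>(base);
    base += sizeof(BatchConfig::requestsInfo);
    auto *beam_token_info =
        reinterpret_cast<BeamSearchBatchConfig::BeamSearchPerTokenInfo *>(base);
    base += sizeof(BeamSearchBatchConfig::beamTokenInfo);
    auto *beam_request_info =
        reinterpret_cast<BeamSearchBatchConfig::BeamSearchPerRequestInfo *>(base);
    base += sizeof(BeamSearchBatchConfig::beamRequestsInfo);
    auto *causal_mask = reinterpret_cast<BatchConfig::BitMask *>(base);
    base += sizeof(BatchConfig::causalMask);
    auto *committed_tokens =
        reinterpret_cast<TreeVerifyBatchConfig::CommittedTokensInfo *>(base);
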
@@ -132,6 +142,7 @@ class FFConfig {
size_t workSpaceSize;
Legion::Context lg_ctx;
Legion::Runtime *lg_hlr;
Legion::IndexSpaceT<1> all_gpu_task_is;
// Legion::FieldSpace field_space;
bool syntheticInput, profiling, perform_fusion;
bool inference_debugging;
1 change: 1 addition & 0 deletions include/flexflow/model.h
@@ -240,6 +240,7 @@ enum TaskIDs {
// InferenceManager & RequestManager
RM_LOAD_TOKENS_TASK_ID,
RM_LOAD_POSITION_TASK_ID,
RM_LOAD_BATCH_CONFIG_TASK_ID,
RM_PREPARE_NEXT_BATCH_TASK_ID,
RM_PREPARE_NEXT_BATCH_INIT_TASK_ID,
RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID,
1 change: 1 addition & 0 deletions include/flexflow/ops/spec_inc_multihead_self_attention.h
@@ -142,6 +142,7 @@ class SpecIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta {
Realm::RegionInstance beam_search_reserve_inst;
BeamSearchBatchConfig::BeamSearchPerTokenInfo *beam_token_infos;
BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos;
BatchConfig::BitMask *causalMask;
};

}; // namespace FlexFlow
1 change: 1 addition & 0 deletions include/flexflow/ops/tree_inc_multihead_self_attention.h
@@ -147,6 +147,7 @@ class TreeIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta {
int num_active_tokens;
Realm::RegionInstance committed_token_reserve_inst;
TreeVerifyBatchConfig::CommittedTokensInfo *committed_token_infos;
BatchConfig::BitMask *causalMask;
};

}; // namespace FlexFlow
33 changes: 29 additions & 4 deletions include/flexflow/request_manager.h
@@ -38,10 +38,13 @@ class InferenceManager {
Legion::FutureMap
inference(FFModel *model, int index, BatchConfigFuture const &bc);
void load_input_tokens_from_batch_config(BatchConfigFuture const &bc,
ParallelTensor const input);
ParallelTensor const input,
FFHandler *handlers);
void load_positions(BatchConfigFuture const &bc,
ParallelTensor position_input,
int offset);
void load_inference_metadata_batch_config(BatchConfigFuture const &bc,
FFHandler *handlers);

public:
FFConfig ff_config;
@@ -72,9 +75,10 @@ struct Request {
struct BeamTree {
struct treeLayer {
BeamSearchBatchConfig::TokenId
tokens[BeamSearchBatchConfig::MAX_BEAM_WIDTH];
tokens[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
int parent_ids[BeamSearchBatchConfig::MAX_BEAM_WIDTH];
float probs[BeamSearchBatchConfig::MAX_BEAM_WIDTH];
float probs[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
int nodes_num_this_layer = 0;
};
treeLayer treeLayers[BeamSearchBatchConfig::MAX_BEAM_DEPTH + 1];
};
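
Not part of this commit — the new nodes_num_this_layer field records how many of the MAX_SPECULATIVE_TREE_BRANCHES slots are actually used at each depth, so the total number of speculated tokens in a BeamTree can be recovered by summing over layers. A minimal sketch:

    // Sketch: total speculated tokens after `depth` layers (depth <= MAX_BEAM_DEPTH).
    int speculative_tree_size(BeamTree const &tree, int depth) {
      int total = 0;
      for (int d = 0; d <= depth; d++) {
        total += tree.treeLayers[d].nodes_num_this_layer;
      }
      return total;
    }
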
@@ -100,13 +104,24 @@ class RequestManager {
void set_max_tokens_per_batch(int max_num_tokens);
int get_max_tokens_per_batch();
void set_max_sequence_length(int max_seq_length);
void push_spec_infer_tree_width(int tree_width);
int get_max_sequence_length();
int register_ssm_model(FFModel *model);
void register_tokenizer(ModelType model_type,
int bos_token_id,
int eos_token_id,
std::string const &path);
void register_output_filepath(std::string const &);
void initBitMask(BatchConfig::BitMask &bitmask, int initLength);
void appendBitMask(BatchConfig::BitMask &bitmask,
int newNodes,
int preBeamSize,
int old_sub_num,
BeamTree const tree,
int currentDepth);
void updateBitMask(BatchConfig::BitMask &bitmask,
int initLength,
int non_tree_size);

FFModel *get_model(int model_id);
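
Not part of this commit — the three new BitMask helpers presumably seed the mask from the prompt (initBitMask), extend it as each new layer of speculated tokens is appended (appendBitMask), and refresh it against the non-tree cache (updateBitMask); the real logic lives in request_manager.cc and may differ. A rough sketch of a prompt-initialization step, assuming bit j of mask[i] marks token j as visible to token i and initLength <= MAX_SPEC_TREE_TOKEN_NUM:

    void init_causal_prompt(BatchConfig::BitMask &bitmask, int initLength) {
      bitmask.non_tree_cache_size = 0;
      bitmask.tree_size = initLength;
      bitmask.prompt_size = initLength;
      bitmask.this_layer_size = initLength;
      for (int i = 0; i < initLength; i++) {
        for (int j = 0; j <= i; j++) {
          bitmask.mask[i] |= 1ULL << j; // lower-triangular causal pattern
        }
      }
    }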

@@ -148,6 +163,7 @@ class RequestManager {
void store_beam_metadata(BeamSearchBatchConfig const &old_bc,
BeamInferenceResult const &result);
void update_beam_metadata(BeamSearchBatchConfig &new_bc,
BeamSearchBatchConfig const &old_bc,
BeamTree &tree,
int request_index);

@@ -181,6 +197,11 @@ class RequestManager {
Legion::Context ctx,
Legion::Runtime *runtime);

static void
load_batch_config_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static BatchConfig prepare_next_batch_task(
Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
@@ -210,6 +231,9 @@ class RequestManager {
int max_requests_per_batch;
int max_tokens_per_batch;
int max_sequence_length;

// tree width at each speculative step; defaults to 1 if not specified
std::vector<int> spec_infer_tree_width;
// private fields
std::unique_ptr<Tokenizer> tokenizer_;
bool verbose;
@@ -243,7 +267,8 @@ class RequestManager {

private:
struct ProfileInfo {
int decoding_steps;
int llm_decoding_steps;
int ssm_decoding_steps;
double start_time, finish_time;
};
std::unordered_map<RequestGuid, ProfileInfo> profiling_requests;
4 changes: 3 additions & 1 deletion inference/models/llama.cc
@@ -246,7 +246,9 @@ void LLAMA::create_llama_model(FFModel &ff,
if (mode == BEAM_SEARCH_MODE) {
Tensor softmax = ff.softmax(dense, -1);
// output = ff.beam_top_k(softmax, llama_config.max_beam_width, false);
output = ff.argmax(softmax, /*beam_Search*/ true);
// output = ff.argmax(softmax, /*beam_Search*/ true);
output = ff.beam_top_k(softmax, llama_config.max_beam_width, false);
// output = ff.top_k(softmax, )
} else {
// Tensor softmax = ff.softmax(dense, -1);
if (generation_config.do_sample) {
3 changes: 3 additions & 0 deletions inference/spec_infer/spec_infer.cc
@@ -302,6 +302,9 @@ void FlexFlow::top_level_task(Task const *task,
model_metadata.llm_tokenizer_path);
rm->register_output_filepath(file_paths.output_file_path);

// first decoding step: 3 results
rm->push_spec_infer_tree_width(3);

// Create LLM model
FFModel tree_model(ffconfig, ffconfig.cpu_offload);
if (model_metadata.llm_model_type == ModelType::LLAMA) {
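
Not part of this commit — push_spec_infer_tree_width registers the tree width used at a speculative step (see spec_infer_tree_width in request_manager.h, which defaults to 1 for unspecified steps). If each call appends one per-step width, a schedule with three candidate branches at the first step and a single branch afterwards might look like:

    rm->push_spec_infer_tree_width(3); // step 0: 3 candidate branches
    rm->push_spec_infer_tree_width(1); // step 1: single branch
    rm->push_spec_infer_tree_width(1); // step 2: single branch
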
2 changes: 1 addition & 1 deletion src/ops/argmax.cc
@@ -352,7 +352,6 @@ BeamInferenceResult
GenericTensorAccessorW parent = helperGetGenericTensorAccessorWO(
DT_INT32, regions[2], task->regions[2], FID_DATA, ctx, runtime);
ArgMax::forward_kernel_wrapper(m, input, indices, parent, batch_size);

BeamInferenceResult ir;
download_tensor<BatchConfig::TokenId>(
indices.get_int32_ptr(), ir.token_ids, batch_size);
@@ -398,6 +397,7 @@ InferenceResult
ArgMax::save_inference_tensors_to_file(
m, shard_id, bc, {}, {}, {input, indices});
}

download_tensor<BatchConfig::TokenId>(
indices.get_int32_ptr(), ir.token_ids, batch_size);
return ir;
2 changes: 1 addition & 1 deletion src/ops/beam_topk.cc
@@ -366,7 +366,7 @@ BeamInferenceResult
GenericTensorAccessorW value = helperGetGenericTensorAccessorWO(
DT_FLOAT, regions[2], task->regions[2], FID_DATA, ctx, runtime);
GenericTensorAccessorW parent = helperGetGenericTensorAccessorWO(
DT_FLOAT, regions[3], task->regions[3], FID_DATA, ctx, runtime);
DT_INT32, regions[3], task->regions[3], FID_DATA, ctx, runtime);

Domain input_domain = runtime->get_index_space_domain(
ctx, task->regions[0].region.get_index_space());
65 changes: 37 additions & 28 deletions src/ops/beam_topk.cu
@@ -556,21 +556,20 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m,
int beam_size = bc->beamRequestsInfo[i].beam_size;

// initial request
log_beam_topk.debug() << "sub_requests: " << i << ", " << sub_requests[i]
<< "\n";
assert(sub_requests[i] > 0);
// process sub requests
for (int j = 0; j < sub_requests[i]; j++) {
parent_ids[req_index * BeamSearchBatchConfig::MAX_BEAM_WIDTH + j] = j;
// beam_slots[i].parent_id[j];
acc_probs[req_index * BeamSearchBatchConfig::MAX_BEAM_WIDTH + j] =
bc->beamRequestsInfo[i].probs[j];
log_beam_topk.debug()
<< "probbbb req: " << i
<< ", sub req probability : " << bc->beamRequestsInfo[i].probs[j]
<< ", sub request id " << j << ", parent id "
<< bc->beamRequestsInfo[i].parent_id[j] << ", data inddd"
<< req_index * BeamSearchBatchConfig::MAX_BEAM_WIDTH + j << "\n";
// std::cout << "probbbb req: " << i << ", sub req probability : "
// << bc->beamRequestsInfo[i].probs[j] << ", sub request id " <<
// j
// << ", parent id " << bc->beamRequestsInfo[i].parent_id[j]
// << ", data inddd"
// << req_index * BeamSearchBatchConfig::MAX_BEAM_WIDTH + j
// << "\n";
}

// process tokens
@@ -584,6 +583,7 @@

max_heap_size = std::max(max_heap_size, beam_size * sub_requests[i]);
max_beam_width = std::max(max_beam_width, beam_size);

req_index += 1;
block_start_index += (sub_requests[i] - 1) * num_new_tokens * length;
}
@@ -613,28 +613,37 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m,
assert(num_shards >= (size_t)max_heap_size);
num_shards = max_heap_size;

checkCUDA(cudaMemcpy(m->parent_ids,
parent_ids,
sizeof(int) * max_total_requests,
cudaMemcpyHostToDevice));
checkCUDA(cudaMemcpy(m->acc_probs,
acc_probs,
sizeof(DT) * max_total_requests,
cudaMemcpyHostToDevice));
checkCUDA(cudaMemcpy(m->block_start_index,
beam_block_start_index.data(),
sizeof(int) * beam_num_blocks,
cudaMemcpyHostToDevice));
checkCUDA(cudaMemcpy(m->request_id,
request_id.data(),
sizeof(int) * beam_num_blocks,
cudaMemcpyHostToDevice));
checkCUDA(cudaMemcpy(m->tokens_per_request,
tokens_per_request.data(),
sizeof(int) * beam_num_blocks,
cudaMemcpyHostToDevice));
checkCUDA(cudaMemcpyAsync(m->parent_ids,
parent_ids,
sizeof(int) * max_total_requests,
cudaMemcpyHostToDevice,
stream));
checkCUDA(cudaMemcpyAsync(m->acc_probs,
acc_probs,
sizeof(DT) * max_total_requests,
cudaMemcpyHostToDevice,
stream));
// trick, set acc_probs to 0;
checkCUDA(cudaMemsetAsync(
m->acc_probs, 1.0, max_total_requests * sizeof(DT), stream));
checkCUDA(cudaMemcpyAsync(m->block_start_index,
beam_block_start_index.data(),
sizeof(int) * beam_num_blocks,
cudaMemcpyHostToDevice,
stream));
checkCUDA(cudaMemcpyAsync(m->request_id,
request_id.data(),
sizeof(int) * beam_num_blocks,
cudaMemcpyHostToDevice,
stream));
checkCUDA(cudaMemcpyAsync(m->tokens_per_request,
tokens_per_request.data(),
sizeof(int) * beam_num_blocks,
cudaMemcpyHostToDevice,
stream));
// int depth =
// bc->beamRequestsInfo[bc->tokensInfo[0].request_index].current_depth;
beam_num_blocks = bc->num_active_tokens();
beam_topk_forward_kernel<<<beam_num_blocks, num_shards, 0, stream>>>(
input_ptr,
shared_memory_size,
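
Not part of this commit — one caveat on the new asynchronous initialization above: cudaMemsetAsync fills bytes, so passing 1.0 converts to the int 1 and writes 0x01 into every byte of acc_probs, a bit pattern that is neither 0.0 nor 1.0 when read back as a float or half value. If the intent is to reset the accumulated probabilities to 1.0, one simple alternative (a sketch under that assumption, not the committed behavior) is to fill the host-side staging array before the cudaMemcpyAsync that already exists:

    for (int i = 0; i < max_total_requests; i++) {
      acc_probs[i] = static_cast<DT>(1.0); // then copied to m->acc_probs by cudaMemcpyAsync
    }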