variable renaming (#1194)
jiazhihao authored Oct 17, 2023
1 parent f243b40 commit 4c06a09
Showing 10 changed files with 50 additions and 42 deletions.
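
This commit renames BatchConfig::PerRequestInfo::token_start_offset to first_token_depth_in_request (together with the matching traverse_beam_tree parameter); every hunk below is a mechanical use of the new name. The new name states what the value actually is: the absolute depth, within the full request, of the first token scheduled in the current batch, rather than an offset into the batch. A minimal compilable sketch of that invariant (illustrative only; just the field names come from the diff):

    // Sketch of the invariant behind the rename; not part of the commit.
    // Only the field names are taken from the diff; the driver is made up.
    #include <cassert>

    struct PerRequestInfo {
      int first_token_depth_in_request; // absolute depth of this batch's first token
      int num_tokens_in_batch;          // tokens scheduled for the request now
    };

    int main() {
      PerRequestInfo info{/*first_token_depth_in_request=*/5,
                          /*num_tokens_in_batch=*/3};
      // The j-th token scheduled in this batch sits at absolute depth
      // first_token_depth_in_request + j within the request.
      for (int j = 0; j < info.num_tokens_in_batch; j++) {
        int depth = info.first_token_depth_in_request + j; // 5, 6, 7
        assert(depth < info.first_token_depth_in_request +
                           info.num_tokens_in_batch);
      }
      // Tokens processed for the request once this batch completes:
      int processed_tokens =
          info.first_token_depth_in_request + info.num_tokens_in_batch; // 8
      return processed_tokens == 8 ? 0 : 1;
    }
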
2 changes: 1 addition & 1 deletion include/flexflow/batch_config.h
@@ -61,7 +61,7 @@ class BatchConfig {
   int num_tokens;

   struct PerRequestInfo {
-    int token_start_offset;
+    int first_token_depth_in_request;
     int num_tokens_in_batch;
     int max_sequence_length;
     RequestGuid request_guid;

2 changes: 1 addition & 1 deletion include/flexflow/request_manager.h
@@ -154,7 +154,7 @@ class RequestManager {
   std::vector<std::pair<BatchConfig::TokenId, int>>
       traverse_beam_tree(BeamSearchBatchConfig const &old_bc,
                          int request_index,
-                         int token_start_offset);
+                         int first_token_depth_in_request);

   // remove guid after put the cached tree in request
   std::vector<std::pair<BatchConfig::TokenId, int>> merge_dfs_trees(

2 changes: 1 addition & 1 deletion src/ops/inc_multihead_self_attention.cpp
@@ -532,7 +532,7 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m,
       continue;
     }
     int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch;
-    int total_tokens = bc->requestsInfo[i].token_start_offset +
+    int total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
                        bc->requestsInfo[i].num_tokens_in_batch;
     // bc->token_last_available_idx[i] + 1;
     // Compute (QK^T/sqrt(d_k))

2 changes: 1 addition & 1 deletion src/ops/inc_multihead_self_attention.cu
@@ -531,7 +531,7 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m,
       continue;
     }
     int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch;
-    int total_tokens = bc->requestsInfo[i].token_start_offset +
+    int total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
                        bc->requestsInfo[i].num_tokens_in_batch;
     // bc->token_last_available_idx[i] + 1;
     // Compute (QK^T/sqrt(d_k))

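In all four attention kernels, the renamed field feeds total_tokens, the length of the key/value range a request attends over: the new queries must score against every token already cached plus the tokens added in this batch. A hedged sketch of the dimension bookkeeping behind the "Compute (QK^T/sqrt(d_k))" comment (the m_ naming follows the spec kernels below; d_k and the driver are made up):

    // Illustrative sketch, not the FlexFlow kernel: how the renamed field
    // sizes the QK^T product for one request.
    #include <cstdio>

    struct PerRequestInfo {
      int first_token_depth_in_request;
      int num_tokens_in_batch;
    };

    int main() {
      PerRequestInfo info{/*first_token_depth_in_request=*/10,
                          /*num_tokens_in_batch=*/4};
      int num_new_tokens = info.num_tokens_in_batch;
      int total_tokens =
          info.first_token_depth_in_request + info.num_tokens_in_batch;
      // GEMM shape for (QK^T / sqrt(d_k)); d_k is a hypothetical per-head size.
      int d_k = 64;
      int m_ = num_new_tokens; // rows: queries issued in this batch
      int n = total_tokens;    // cols: every key up to and including this batch
      std::printf("QK^T block: %d x %d, inner dim %d\n", m_, n, d_k);
      return 0;
    }
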
2 changes: 1 addition & 1 deletion src/ops/spec_inc_multihead_self_attention.cpp
@@ -231,7 +231,7 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
     // int total_tokens = bc->token_last_available_idx[i] + 1;

     int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch;
-    int total_tokens = bc->requestsInfo[i].token_start_offset +
+    int total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
                        bc->requestsInfo[i].num_tokens_in_batch;
     // Compute (QK^T/sqrt(d_k))
     int m_ = num_new_tokens;

2 changes: 1 addition & 1 deletion src/ops/spec_inc_multihead_self_attention.cu
@@ -248,7 +248,7 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
     // int total_tokens = bc->token_last_available_idx[i] + 1;
     int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch;
-    int total_tokens = bc->requestsInfo[i].token_start_offset +
+    int total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
                        bc->requestsInfo[i].num_tokens_in_batch;
     if (num_new_tokens <= 0) {

6 changes: 3 additions & 3 deletions src/runtime/batch_config.cc
@@ -27,7 +27,7 @@ using Legion::Memory;

 BatchConfig::BatchConfig() : num_tokens(0) {
   for (int i = 0; i < MAX_NUM_REQUESTS; i++) {
-    requestsInfo[i].token_start_offset = 0;
+    requestsInfo[i].first_token_depth_in_request = 0;
     requestsInfo[i].num_tokens_in_batch = 0;
     request_completed[i] = true;
   }
@@ -104,8 +104,8 @@ std::ostream &operator<<(std::ostream &os, BatchConfig const &bc) {
   for (int i = 0; i < bc.max_requests_per_batch(); i++) {
     if (!bc.request_completed[i]) {
       os << " Request " << i << ":\n";
-      os << " Token start offset: " << bc.requestsInfo[i].token_start_offset
-         << std::endl;
+      os << " Token start offset: "
+         << bc.requestsInfo[i].first_token_depth_in_request << std::endl;
       os << " Number of tokens in batch: "
          << bc.requestsInfo[i].num_tokens_in_batch << std::endl;
       os << " GUID: " << bc.requestsInfo[i].request_guid << std::endl;

4 changes: 2 additions & 2 deletions src/runtime/beam_search_batch_config.cc
@@ -126,8 +126,8 @@ std::ostream &operator<<(std::ostream &os, BeamSearchBatchConfig const &bc) {
   for (int i = 0; i < bc.max_requests_per_batch(); i++) {
     if (!bc.request_completed[i]) {
       os << " Request " << i << ":\n";
-      os << " Token start offset: " << bc.requestsInfo[i].token_start_offset
-         << std::endl;
+      os << " Token start offset: "
+         << bc.requestsInfo[i].first_token_depth_in_request << std::endl;
       os << " Number of tokens in batch: "
          << bc.requestsInfo[i].num_tokens_in_batch << std::endl;
       os << " GUID: " << bc.requestsInfo[i].request_guid << std::endl;

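The operator<< hunks here and in tree_verify_batch_config.cc at the end change only which field is streamed; the printed label still reads "Token start offset" after the rename. A hypothetical driver showing the shape of the resulting output, with made-up values:

    // Made-up driver mirroring the streaming pattern in the hunks above;
    // not FlexFlow code, and the GUID/values are invented.
    #include <iostream>

    struct PerRequestInfo {
      int first_token_depth_in_request;
      int num_tokens_in_batch;
      unsigned long request_guid;
    };

    int main() {
      PerRequestInfo info{12, 4, 1000001};
      std::ostream &os = std::cout;
      os << " Request " << 0 << ":\n";
      os << " Token start offset: "
         << info.first_token_depth_in_request << std::endl;
      os << " Number of tokens in batch: "
         << info.num_tokens_in_batch << std::endl;
      os << " GUID: " << info.request_guid << std::endl;
      return 0;
    }
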
66 changes: 37 additions & 29 deletions src/runtime/request_manager.cc
@@ -367,7 +367,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
        Request new_request = pending_request_queue.front();
        pending_request_queue.pop();
        // all_requests[new_request.guid] = new_request;
-        new_bc.requestsInfo[i].token_start_offset = 0;
+        new_bc.requestsInfo[i].first_token_depth_in_request = 0;
        new_bc.requestsInfo[i].request_guid = new_request.guid;
        new_bc.requestsInfo[i].num_tokens_in_batch =
            std::min(get_max_tokens_per_batch() - new_bc.num_tokens -
@@ -382,7 +382,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
        profile_info.start_time = Realm::Clock::current_time_in_microseconds();
        profiling_requests[new_request.guid] = profile_info;
        for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
-          int depth = new_bc.requestsInfo[i].token_start_offset + j;
+          int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j;
          new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
          new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
          assert(depth < new_request.tokens.size());
@@ -397,8 +397,9 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
    } else {
      assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0);
      Request &request = all_requests[old_bc.requestsInfo[i].request_guid];
-      int processed_tokens = old_bc.requestsInfo[i].token_start_offset +
-                             old_bc.requestsInfo[i].num_tokens_in_batch;
+      int processed_tokens =
+          old_bc.requestsInfo[i].first_token_depth_in_request +
+          old_bc.requestsInfo[i].num_tokens_in_batch;
      assert(processed_tokens < request.tokens.size());
      bool request_completed = false;
      // printf("model_type = %d\n", this->model_type);
@@ -464,12 +465,12 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,

      } else {
        new_bc.request_completed[i] = false;
-        new_bc.requestsInfo[i].token_start_offset = processed_tokens;
+        new_bc.requestsInfo[i].first_token_depth_in_request = processed_tokens;
        new_bc.requestsInfo[i].request_guid =
            old_bc.requestsInfo[i].request_guid;
        new_bc.requestsInfo[i].max_sequence_length =
            old_bc.requestsInfo[i].max_sequence_length;
-        if (new_bc.requestsInfo[i].token_start_offset + 1 ==
+        if (new_bc.requestsInfo[i].first_token_depth_in_request + 1 ==
            request.tokens.size()) {
          // Incremental phase
          new_bc.requestsInfo[i].num_tokens_in_batch = 1;
@@ -478,10 +479,10 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
          new_bc.requestsInfo[i].num_tokens_in_batch =
              std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
                       (int)request.tokens.size() -
-                           new_bc.requestsInfo[i].token_start_offset);
+                           new_bc.requestsInfo[i].first_token_depth_in_request);
        }
        for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
-          int depth = new_bc.requestsInfo[i].token_start_offset + j;
+          int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j;
          new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
          new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
          assert(depth < request.tokens.size());

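The two hunks above cover the scheduler's core rule: once first_token_depth_in_request has reached the end of the cached prompt, the request advances one token per batch (incremental phase); otherwise it loads as much of the remaining prompt as the batch budget allows. A hedged sketch of that rule in isolation (the helper and its inputs are hypothetical; only the min() logic mirrors the diff):

    // Illustrative only: the prompt-vs-incremental scheduling rule from the
    // hunks above, lifted into a standalone helper with made-up inputs.
    #include <algorithm>
    #include <cstdio>

    int tokens_to_schedule(int first_token_depth_in_request,
                           int request_length,       // tokens in the request
                           int max_tokens_per_batch, // global batch budget
                           int tokens_already_in_batch) {
      if (first_token_depth_in_request + 1 == request_length) {
        return 1; // incremental (decode) phase: one token per step
      }
      // prompt (prefill) phase: fill what remains of the batch budget
      return std::min(max_tokens_per_batch - tokens_already_in_batch,
                      request_length - first_token_depth_in_request);
    }

    int main() {
      std::printf("%d\n", tokens_to_schedule(0, 100, 64, 10));  // prefill: 54
      std::printf("%d\n", tokens_to_schedule(99, 100, 64, 10)); // decode: 1
      return 0;
    }
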
@@ -685,7 +686,7 @@ BeamSearchBatchConfig
      new_bc.request_running[i] = true;

      // Normal Request Info
-      new_bc.requestsInfo[i].token_start_offset =
+      new_bc.requestsInfo[i].first_token_depth_in_request =
          verified_tokens.front().second;
      new_bc.requestsInfo[i].request_guid =
          old_bc.requestsInfo[i].request_guid;
@@ -694,9 +695,10 @@ BeamSearchBatchConfig
      new_bc.requestsInfo[i].num_tokens_in_batch = verified_tokens.size();

      // TODO: Beam Request Info, missing from VerifyTreeBatchConfig
-      int new_max_depth = new_bc.requestsInfo[i].max_sequence_length -
-                          new_bc.requestsInfo[i].token_start_offset -
-                          verified_tokens.size();
+      int new_max_depth =
+          new_bc.requestsInfo[i].max_sequence_length -
+          new_bc.requestsInfo[i].first_token_depth_in_request -
+          verified_tokens.size();
      new_bc.beamRequestsInfo[i].current_depth = 1;
      new_bc.beamRequestsInfo[i].beam_size =
          BeamSearchBatchConfig::MAX_BEAM_WIDTH;
@@ -742,7 +744,8 @@ BeamSearchBatchConfig
      assert(request.ssm_cache_size == request.initial_len);

      // Normal Request Info
-      new_bc.requestsInfo[i].token_start_offset = request.ssm_cache_size;
+      new_bc.requestsInfo[i].first_token_depth_in_request =
+          request.ssm_cache_size;
      new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid;
      new_bc.requestsInfo[i].max_sequence_length =
          old_bc.requestsInfo[i].max_sequence_length;
@@ -776,7 +779,7 @@ BeamSearchBatchConfig
      Request new_request = pending_request_queue.front();
      pending_request_queue.pop();
      // all_requests[new_request.guid] = new_request;
-      new_bc.requestsInfo[i].token_start_offset = 0;
+      new_bc.requestsInfo[i].first_token_depth_in_request = 0;
      new_bc.requestsInfo[i].request_guid = new_request.guid;
      new_bc.requestsInfo[i].num_tokens_in_batch =
          std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
@@ -806,7 +809,7 @@ BeamSearchBatchConfig
      new_bc.sub_requests[i] = 1;

      for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
-        int depth = new_bc.requestsInfo[i].token_start_offset + j;
+        int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j;
        new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
        new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
        assert(depth < new_request.tokens.size());
@@ -922,7 +925,7 @@ BeamSearchBatchConfig
      // zero when beam search has reached required sequence length
      // assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0);
      Request &request = all_requests[old_bc.requestsInfo[i].request_guid];
-      int processed_tokens = old_bc.requestsInfo[i].token_start_offset +
+      int processed_tokens = old_bc.requestsInfo[i].first_token_depth_in_request +
                             old_bc.requestsInfo[i].num_tokens_in_batch;

      // assert(processed_tokens < request.tokens.size());
@@ -937,7 +940,8 @@ BeamSearchBatchConfig
      // // old_bc.beamRequestsInfo[i].max_depth);
      // // // new_bc.request_completed[i] = true;
      // // new_bc.request_completed[i] = false;
-      // // new_bc.requestsInfo[i].token_start_offset = processed_tokens;
+      // // new_bc.requestsInfo[i].first_token_depth_in_request =
+      // processed_tokens;
      // // new_bc.requestsInfo[i].request_guid =
      // // old_bc.requestsInfo[i].request_guid;
      // // new_bc.requestsInfo[i].max_sequence_length =
@@ -953,7 +957,7 @@ BeamSearchBatchConfig
      log_req_mgr.debug() << "num tokens: " << old_bc.num_tokens << ", "
                          << new_bc.num_tokens;
      new_bc.request_completed[i] = false;
-      new_bc.requestsInfo[i].token_start_offset = processed_tokens;
+      new_bc.requestsInfo[i].first_token_depth_in_request = processed_tokens;
      new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid;
      new_bc.requestsInfo[i].max_sequence_length =
          old_bc.requestsInfo[i].max_sequence_length;
@@ -986,7 +990,8 @@ BeamSearchBatchConfig
      // do the slot exchange to minimize the cache exchange in kernel.
      // update_beam_metadata(new_bc, request.beam_trees.at(old_bc.model_id),
      // i);
-      if (new_bc.requestsInfo[i].token_start_offset >= request.tokens.size()) {
+      if (new_bc.requestsInfo[i].first_token_depth_in_request >=
+          request.tokens.size()) {
        // Incremental phase
        if (request.status == Request::RUNNING) {
          new_bc.requestsInfo[i].num_tokens_in_batch = 1;
@@ -1006,7 +1011,7 @@ BeamSearchBatchConfig
            std::min(get_max_tokens_per_batch() - new_bc.num_tokens -
                         BatchConfig::max_requests_per_batch() + i,
                     (int)request.tokens.size() -
-                         new_bc.requestsInfo[i].token_start_offset);
+                         new_bc.requestsInfo[i].first_token_depth_in_request);
        request.ssm_cache_size += new_bc.requestsInfo[i].num_tokens_in_batch;
        if (verbose) {
          std::cout << "[ Beam Spec] " << request.guid << std::endl;
@@ -1027,7 +1032,7 @@ BeamSearchBatchConfig

      // register more tokens due to the beam width
      for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
-        int depth = new_bc.requestsInfo[i].token_start_offset + j;
+        int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j;
        for (int k = 0; k < new_bc.sub_requests[i]; k++) {
          new_bc.tokensInfo[new_bc.num_tokens].request_index = i;
          new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth;
@@ -1151,7 +1156,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
      }

      // Normal Request Info
-      new_bc.requestsInfo[i].token_start_offset =
+      new_bc.requestsInfo[i].first_token_depth_in_request =
          dfs_tree_inputs.front().second;
      new_bc.requestsInfo[i].request_guid =
          old_batches.at(0).requestsInfo[i].request_guid;
@@ -1204,7 +1209,8 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
          break;
        }

-        new_bc.requestsInfo[i].token_start_offset = request.tokens.size() - 1;
+        new_bc.requestsInfo[i].first_token_depth_in_request =
+            request.tokens.size() - 1;

        // Add Tokens from the DFS Tree to the next batch
        for (int j = 1; j < dfs_tree_inputs.size(); j++) {
@@ -1257,17 +1263,19 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
      }

      // Normal Request Info
-      new_bc.requestsInfo[i].token_start_offset = request.llm_cache_size;
+      new_bc.requestsInfo[i].first_token_depth_in_request =
+          request.llm_cache_size;
      new_bc.requestsInfo[i].request_guid =
          old_batches.at(0).requestsInfo[i].request_guid;
      new_bc.requestsInfo[i].max_sequence_length =
          old_batches.at(0).requestsInfo[i].max_sequence_length;

      new_bc.request_completed[i] = false;

-      new_bc.requestsInfo[i].num_tokens_in_batch = std::min(
-          max_prompt_load_size,
-          (int)request.initial_len - new_bc.requestsInfo[i].token_start_offset);
+      new_bc.requestsInfo[i].num_tokens_in_batch =
+          std::min(max_prompt_load_size,
+                   (int)request.initial_len -
+                       new_bc.requestsInfo[i].first_token_depth_in_request);
      max_prompt_load_size -= new_bc.requestsInfo[i].num_tokens_in_batch;

      std::cout << "max_prompt_load_size: " << max_prompt_load_size
@@ -1673,7 +1681,7 @@ std::vector<std::pair<BatchConfig::TokenId, int>>
 std::vector<std::pair<BatchConfig::TokenId, int>>
     RequestManager::traverse_beam_tree(BeamSearchBatchConfig const &old_bc,
                                        int request_index,
-                                       int token_start_offset) {
+                                       int first_token_depth_in_request) {
   if (verbose) {
     std::cout << "[Traverse Beam Tree] request_index: " << request_index
               << "\n";
@@ -1709,7 +1717,7 @@ std::vector<std::pair<BatchConfig::TokenId, int>>
                 << serializedTree.size() << "\n";
   }
   for (int k = 0; k < serializedTree.size(); k++) {
-    serializedTree.at(k).second += token_start_offset;
+    serializedTree.at(k).second += first_token_depth_in_request;
     if (verbose) {
       std::cout << "token id: " << serializedTree.at(k).first
                 << ", depth: " << serializedTree.at(k).second << "\n";

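The last two hunks rename the traverse_beam_tree parameter; its job is to shift the serialized beam-tree nodes from tree-local depths to absolute depths within the request. A hedged sketch of that shift (the pair layout follows the signature above; the tree contents and TokenId alias are made up):

    // Illustrative only: offsetting tree-local depths by
    // first_token_depth_in_request, as the second hunk above does.
    #include <cstdio>
    #include <utility>
    #include <vector>

    using TokenId = int; // stand-in for BatchConfig::TokenId

    int main() {
      // (token id, depth within the speculated tree)
      std::vector<std::pair<TokenId, int>> serializedTree = {
          {101, 0}, {102, 1}, {103, 1}, {104, 2}};
      int first_token_depth_in_request = 37; // where the tree is rooted
      for (size_t k = 0; k < serializedTree.size(); k++) {
        serializedTree.at(k).second += first_token_depth_in_request;
        std::printf("token id: %d, depth: %d\n",
                    serializedTree.at(k).first, serializedTree.at(k).second);
      }
      return 0;
    }
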
4 changes: 2 additions & 2 deletions src/runtime/tree_verify_batch_config.cc
@@ -47,8 +47,8 @@ std::ostream &operator<<(std::ostream &os, TreeVerifyBatchConfig const &bc) {
   for (int i = 0; i < bc.max_requests_per_batch(); i++) {
     if (!bc.request_completed[i]) {
       os << " Request " << i << ":\n";
-      os << " Token start offset: " << bc.requestsInfo[i].token_start_offset
-         << std::endl;
+      os << " Token start offset: "
+         << bc.requestsInfo[i].first_token_depth_in_request << std::endl;
       os << " Number of tokens in batch: "
          << bc.requestsInfo[i].num_tokens_in_batch << std::endl;
       os << " GUID: " << bc.requestsInfo[i].request_guid << std::endl;
