From dc68b5166f21c3ae891e751cbabd989781868e74 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Mon, 1 Jan 2024 20:43:43 -0500 Subject: [PATCH] fix --- include/flexflow/ops/fused.h | 4 +++- src/ops/fused.cu | 22 +++++++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/include/flexflow/ops/fused.h b/include/flexflow/ops/fused.h index 4183481013..0232ae9e94 100644 --- a/include/flexflow/ops/fused.h +++ b/include/flexflow/ops/fused.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_FUSED_H_ #define _FLEXFLOW_FUSED_H_ +#include "flexflow/batch_config.h" #include "flexflow/model.h" namespace FlexFlow { @@ -12,7 +13,8 @@ class FusedOpMeta { OpMeta *meta[MAX_NUM_FUSED_OPERATORS]; FusedOp *fused_op; int numOperators; - std::map, cudaGraphExec_t> graph_collections; + cudaGraphExec_t graph_collections[BatchConfig::MAX_NUM_REQUESTS] + [BatchConfig::MAX_NUM_TOKENS] = {nullptr}; }; class FusedOp : public Op { diff --git a/src/ops/fused.cu b/src/ops/fused.cu index 4c4cc4e340..11e1ad5ee8 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -607,20 +607,20 @@ __host__ void } } - // create cuda graph if not yet available cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - cudaGraph_t graph; - cudaGraphExec_t instance; // check if graph exists - std::pair graph_params = std::make_pair( - bc->num_active_tokens(), bc->num_active_requests()); - if (metas->graph_collections.find(graph_params) != - metas->graph_collections.end()) { - cudaGraphExec_t instance = metas->graph_collections[graph_params]; + if (metas->graph_collections[bc->num_active_requests()] + [bc->num_active_tokens()] != nullptr) { + cudaGraphExec_t instance = + metas->graph_collections[bc->num_active_requests()] + [bc->num_active_tokens()]; cudaGraphLaunch(instance, stream); return; } + // create new cuda graph + cudaGraph_t graph; + cudaGraphExec_t instance; cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal); int ioff = 0, woff = 0, ooff = 0; @@ -1142,9 +1142,13 @@ __host__ void // for (int i = 0; i < fused->numOutputs; i++) // print_tensor(output_ptr[i], output_domain[i].get_volume(), // "[Fused:forward:output]"); + cudaStreamEndCapture(stream, &graph); cudaGraphInstantiate(&instance, graph, NULL, NULL, 0); - metas->graph_collections[graph_params] = instance; + metas->graph_collections[bc->num_active_requests()][bc->num_active_tokens()] = + instance; + assert(metas->graph_collections[bc->num_active_requests()] + [bc->num_active_tokens()] != nullptr); cudaGraphLaunch(instance, stream); }