Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
goliaro committed Jan 2, 2024
1 parent 80a42f5 commit 6f4b4d9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
4 changes: 3 additions & 1 deletion include/flexflow/ops/fused.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef _FLEXFLOW_FUSED_H_
#define _FLEXFLOW_FUSED_H_

#include "flexflow/batch_config.h"
#include "flexflow/model.h"

namespace FlexFlow {
Expand All @@ -12,7 +13,8 @@ class FusedOpMeta {
OpMeta *meta[MAX_NUM_FUSED_OPERATORS];
FusedOp *fused_op;
int numOperators;
std::map<std::pair<int, int>, cudaGraphExec_t> graph_collections;
cudaGraphExec_t graph_collections[BatchConfig::MAX_NUM_REQUESTS]
[BatchConfig::MAX_NUM_TOKENS] = {nullptr};
};

class FusedOp : public Op {
Expand Down
22 changes: 13 additions & 9 deletions src/ops/fused.cu
Original file line number Diff line number Diff line change
Expand Up @@ -608,20 +608,20 @@ __host__ void
}
}

// create cuda graph if not yet available
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
cudaGraph_t graph;
cudaGraphExec_t instance;
// check if graph exists
std::pair<int, int> graph_params = std::make_pair<int, int>(
bc->num_active_tokens(), bc->num_active_requests());
if (metas->graph_collections.find(graph_params) !=
metas->graph_collections.end()) {
cudaGraphExec_t instance = metas->graph_collections[graph_params];
if (metas->graph_collections[bc->num_active_requests()]
[bc->num_active_tokens()] != nullptr) {
cudaGraphExec_t instance =
metas->graph_collections[bc->num_active_requests()]
[bc->num_active_tokens()];
cudaGraphLaunch(instance, stream);
return;
}
// create new cuda graph
cudaGraph_t graph;
cudaGraphExec_t instance;
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

int ioff = 0, woff = 0, ooff = 0;
Expand Down Expand Up @@ -1149,9 +1149,13 @@ __host__ void
// for (int i = 0; i < fused->numOutputs; i++)
// print_tensor<float>(output_ptr[i], output_domain[i].get_volume(),
// "[Fused:forward:output]");

cudaStreamEndCapture(stream, &graph);
cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
metas->graph_collections[graph_params] = instance;
metas->graph_collections[bc->num_active_requests()][bc->num_active_tokens()] =
instance;
assert(metas->graph_collections[bc->num_active_requests()]
[bc->num_active_tokens()] != nullptr);
cudaGraphLaunch(instance, stream);
}

Expand Down

0 comments on commit 6f4b4d9

Please sign in to comment.