Use CUDA graph to fuse kernel launches #1251

Closed
wants to merge 19 commits into from
4 changes: 4 additions & 0 deletions include/flexflow/config.h
@@ -76,6 +76,8 @@ struct FFHandler {
#endif
void *workSpace;
size_t workSpaceSize;
void *cublasWorkSpace;
size_t cublasWorkSpaceSize;
void *batch_config_metadata;

// request info + token info + topology mask info
@@ -97,6 +99,7 @@ struct FFHandler {

struct FFInitInfo {
size_t workSpaceSize;
size_t cublasWorkSpaceSize;
size_t offload_reserve_space_size;
DataType quantization_type;
bool allowTensorOpMathConversion;
@@ -141,6 +144,7 @@ class FFConfig {
float device_mem; // The device (GPU) memory threshold; given by -ll:fsize
float learningRate, weightDecay;
size_t workSpaceSize;
size_t cublasWorkSpaceSize;
Legion::Context lg_ctx;
Legion::Runtime *lg_hlr;
Legion::IndexSpaceT<1> all_gpu_task_is;
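Note: binding an explicitly allocated workspace to the cuBLAS handle keeps cuBLAS from allocating scratch memory on its own, which is presumably what makes the GEMM calls elsewhere in this PR safe to record during CUDA graph stream capture. A minimal sketch of how the new fields could be populated at handler-initialization time (illustrative only, not the PR's actual init code):

// Illustrative: allocate a dedicated cuBLAS workspace once per FFHandler and
// bind it to the handle, so later captured GEMMs never allocate internally.
FFHandler handle;
handle.cublasWorkSpaceSize = info.cublasWorkSpaceSize;
checkCUDA(cudaMalloc(&handle.cublasWorkSpace, handle.cublasWorkSpaceSize));
checkCUDA(cublasCreate(&handle.blas));
checkCUDA(cublasSetWorkspace(
    handle.blas, handle.cublasWorkSpace, handle.cublasWorkSpaceSize));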
13 changes: 12 additions & 1 deletion include/flexflow/ops/fused.h
@@ -1,17 +1,28 @@
#ifndef _FLEXFLOW_FUSED_H_
#define _FLEXFLOW_FUSED_H_

#include "flexflow/batch_config.h"
#include "flexflow/model.h"

namespace FlexFlow {

class FusedOp;
class FusedOpMeta {
public:
FusedOpMeta(void) {}
FusedOpMeta(void) {
graph_collections.reserve(BatchConfig::MAX_NUM_REQUESTS *
BatchConfig::MAX_NUM_TOKENS * 2);
}
OpMeta *meta[MAX_NUM_FUSED_OPERATORS];
FusedOp *fused_op;
int numOperators;
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
std::unordered_map<std::tuple<int, int, bool>, cudaGraphExec_t>
graph_collections;
#else
std::unordered_map<std::tuple<int, int, bool>, hipGraphExec_t>
graph_collections;
#endif
};

class FusedOp : public Op {
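std::unordered_map has no default hash for std::tuple keys, so the graph_collections declaration above needs a hash supplied somewhere. If the codebase does not already provide one, a minimal functor (names here are hypothetical) could look like:

// Illustrative hash for the graph-cache key
// (num_active_requests, num_active_tokens, has_generation_tokens).
struct GraphKeyHash {
  size_t operator()(std::tuple<int, int, bool> const &key) const noexcept {
    size_t h = std::hash<int>{}(std::get<0>(key));
    h = h * 31 + std::hash<int>{}(std::get<1>(key));
    h = h * 31 + std::hash<bool>{}(std::get<2>(key));
    return h;
  }
};
// ...with the map then declared as:
// std::unordered_map<std::tuple<int, int, bool>, cudaGraphExec_t, GraphKeyHash>
//     graph_collections;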
6 changes: 6 additions & 0 deletions src/ops/aggregate.cu
@@ -212,6 +212,9 @@ void Aggregate::forward_kernel_wrapper(AggregateMeta const *m,
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDA(cublasSetWorkspace(m->handle.blas,
m->handle.cublasWorkSpace,
m->handle.cublasWorkSpaceSize));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));

cudaEvent_t t_start, t_end;
Expand Down Expand Up @@ -266,6 +269,9 @@ void Aggregate::backward_kernel_wrapper(AggregateMeta const *m,
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDA(cublasSetWorkspace(m->handle.blas,
m->handle.cublasWorkSpace,
m->handle.cublasWorkSpaceSize));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));

cudaEvent_t t_start, t_end;
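The preamble added here (bind the cuBLAS handle to the stream, point it at the shared workspace, bind the cuDNN handle) is repeated verbatim in every kernel wrapper this PR touches. Purely as an illustration of the pattern, not code from the PR, it could be factored as:

// Hypothetical helper bundling the per-call handle setup repeated in each
// kernel wrapper below.
static void bind_handle_to_stream(FFHandler const &handle,
                                  cudaStream_t stream) {
  checkCUDA(cublasSetStream(handle.blas, stream));
  checkCUDA(cublasSetWorkspace(
      handle.blas, handle.cublasWorkSpace, handle.cublasWorkSpaceSize));
  checkCUDNN(cudnnSetStream(handle.dnn, stream));
}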
6 changes: 6 additions & 0 deletions src/ops/aggregate_spec.cu
@@ -222,6 +222,9 @@ void AggregateSpec::forward_kernel_wrapper(AggregateSpecMeta const *m,
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDA(cublasSetWorkspace(m->handle.blas,
m->handle.cublasWorkSpace,
m->handle.cublasWorkSpaceSize));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));

// call forward kernel
@@ -261,6 +264,9 @@ void AggregateSpec::backward_kernel_wrapper(AggregateSpecMeta const *m,
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDA(cublasSetWorkspace(m->handle.blas,
m->handle.cublasWorkSpace,
m->handle.cublasWorkSpaceSize));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));

// call backward kernel
3 changes: 3 additions & 0 deletions src/ops/cache.cu
@@ -40,6 +40,9 @@ void Cache::cache_forward(Task const *task,
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDA(cublasSetWorkspace(m->handle.blas,
m->handle.cublasWorkSpace,
m->handle.cublasWorkSpaceSize));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));

cudaMemcpy(output_ptr,
3 changes: 3 additions & 0 deletions src/ops/experts.cu
@@ -305,6 +305,9 @@ void experts_forward_GemmBatched_kernel(ExpertsMeta const *m,
ffStream_t stream) {

checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDA(cublasSetWorkspace(m->handle.blas,
m->handle.cublasWorkSpace,
m->handle.cublasWorkSpaceSize));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));

float alpha = 1.0f, beta = 0.0f;
34 changes: 32 additions & 2 deletions src/ops/fused.cpp
@@ -523,7 +523,7 @@ __host__ void
Context ctx,
Runtime *runtime) {
// const FusedOp* fused = (FusedOp*) task->args;
FusedOpMeta const *metas = *((FusedOpMeta **)task->local_args);
FusedOpMeta *metas = *((FusedOpMeta **)task->local_args);
FusedOp const *fused = metas->fused_op;
BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
if (bc->num_tokens == 0) {
@@ -587,6 +587,11 @@ __host__ void
checkCUDA(get_legion_stream(&stream));
}

// create new hip graph
hipGraph_t graph;
hipGraphExec_t instance;
hipStreamBeginCapture(stream, hipStreamCaptureModeThreadLocal);

int ioff = 0, woff = 0, ooff = 0;
for (int op = 0; op < fused->numOperators; op++) {
GenericTensorAccessorR my_input_accessor[MAX_NUM_INPUTS];
@@ -1061,6 +1066,31 @@ __host__ void
woff += fused->op_num_weights[op];
ooff += fused->op_num_outputs[op];
}
hipStreamEndCapture(stream, &graph);
std::tuple<int, int, bool> graph_params =
std::make_tuple(bc->num_active_requests(),
bc->num_active_tokens(),
bc->num_generation_tokens > 0);
// check if graph exists
if (metas->graph_collections.find(graph_params) !=
metas->graph_collections.end()) {
instance = metas->graph_collections[graph_params];
hipGraphExecUpdateResult updateResult;
hipGraphNode_t errorNode;
hipGraphExecUpdate(instance, graph, &errorNode, &updateResult);
if (updateResult != hipGraphExecUpdateSuccess) {
hipGraphExecDestroy(instance);
hipGraphInstantiate(&instance, graph, NULL, NULL, 0);
}
} else {
hipGraphInstantiate(&instance, graph, NULL, NULL, 0);
}
metas->graph_collections[graph_params] = instance;
assert(metas->graph_collections.find(graph_params) !=
metas->graph_collections.end());
hipGraphDestroy(graph);
hipGraphLaunch(instance, stream);

// for (int i = 0; i < fused->numOutputs; i++)
// print_tensor<float>(output_ptr[i], output_domain[i].get_volume(),
// "[Fused:forward:output]");
@@ -1079,7 +1109,7 @@ __host__ void FusedOp::backward_task(Task const *task,
Context ctx,
Runtime *runtime) {
// const FusedOp* fused = (FusedOp*) task->args;
FusedOpMeta const *metas = *((FusedOpMeta **)task->local_args);
FusedOpMeta *metas = *((FusedOpMeta **)task->local_args);
FusedOp const *fused = metas->fused_op;

assert(metas->numOperators == fused->numOperators);
40 changes: 40 additions & 0 deletions src/ops/fused.cu
@@ -13,6 +13,7 @@
* limitations under the License.
*/

#include "cuda.h"
#include "flexflow/accessor.h"
#include "flexflow/model.h"
#include "flexflow/ops/add_bias_residual_layer_norm.h"
@@ -608,6 +609,14 @@ __host__ void
}
}

cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));

// create new cuda graph
cudaGraph_t graph;
cudaGraphExec_t instance;
cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);

int ioff = 0, woff = 0, ooff = 0;
for (int op = 0; op < fused->numOperators; op++) {
// Domain my_id[MAX_NUM_INPUTS];
@@ -1132,6 +1141,37 @@ __host__ void
// for (int i = 0; i < fused->numOutputs; i++)
// print_tensor<float>(output_ptr[i], output_domain[i].get_volume(),
// "[Fused:forward:output]");

cudaStreamEndCapture(stream, &graph);
std::tuple<int, int, bool> graph_params =
std::make_tuple(bc->num_active_requests(),
bc->num_active_tokens(),
bc->num_generation_tokens > 0);
// check if graph exists
if (metas->graph_collections.find(graph_params) !=
metas->graph_collections.end()) {
instance = metas->graph_collections[graph_params];
#if defined(CUDA_VERSION) && (CUDA_VERSION < 12000)
cudaGraphExecUpdateResult updateResult;
cudaGraphNode_t errorNode;
cudaGraphExecUpdate(instance, graph, &errorNode, &updateResult);
bool update_failed = (updateResult != cudaGraphExecUpdateSuccess);
#else
cudaError_t update_result = cudaGraphExecUpdate(instance, graph, NULL);
bool update_failed = (update_result != cudaSuccess);
#endif
if (update_failed) {
cudaGraphExecDestroy(instance);
cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
}
} else {
cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
}
metas->graph_collections[graph_params] = instance;
assert(metas->graph_collections.find(graph_params) !=
metas->graph_collections.end());
cudaGraphDestroy(graph);
cudaGraphLaunch(instance, stream);
}

/*
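For readers skimming the hunk above (and the matching HIP path in fused.cpp), the capture / cache / launch logic reduces to the following sketch. It uses the pre-CUDA-12 signatures, matching the CUDA_VERSION < 12000 branch; graph_cache, key, and run_all_operator_kernels are hypothetical stand-ins for the names in the PR:

// Illustrative condensation of the graph caching flow above.
cudaGraph_t graph;
cudaGraphExec_t instance;
checkCUDA(cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal));
run_all_operator_kernels(stream); // the per-operator launches recorded above
checkCUDA(cudaStreamEndCapture(stream, &graph));

auto it = graph_cache.find(key);
bool reused = false;
if (it != graph_cache.end()) {
  cudaGraphExecUpdateResult update_result;
  cudaGraphNode_t error_node;
  cudaGraphExecUpdate(it->second, graph, &error_node, &update_result);
  if (update_result == cudaGraphExecUpdateSuccess) {
    instance = it->second;            // same topology: parameters patched in place
    reused = true;
  } else {
    cudaGraphExecDestroy(it->second); // topology changed: drop the old executable
  }
}
if (!reused) {
  checkCUDA(cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0));
}
graph_cache[key] = instance;
checkCUDA(cudaGraphDestroy(graph));   // the instantiated executable is independent
checkCUDA(cudaGraphLaunch(instance, stream));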
10 changes: 5 additions & 5 deletions src/ops/inc_multihead_self_attention.cu
@@ -492,9 +492,6 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,
DT *output_ptr,
DT const *bias_ptr,
cudaStream_t stream) {

checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
assert(m->qSize == m->vSize && m->qSize == m->kSize);
cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]);
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
@@ -807,6 +804,11 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta const *m,
DT *output_ptr,
DT const *bias_ptr,
cudaStream_t stream) {
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDA(cublasSetWorkspace(m->handle.blas,
m->handle.cublasWorkSpace,
m->handle.cublasWorkSpaceSize));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));

if (m->offload && m->biasSize > 0) {
cudaMemcpyAsync(
@@ -900,8 +902,6 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m,
DT const *bias_ptr,
DT const *weight_ptr,
cudaStream_t stream) {
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]);
cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]);
assert(data_type_size(m->output_type[0]) == sizeof(DT));
6 changes: 6 additions & 0 deletions src/ops/kernels/batch_matmul.cu
@@ -136,6 +136,9 @@ void forward_kernel(BatchMatmulMeta const *meta,
int b_seq_length_dim,
int seq_length) {
checkCUDA(cublasSetStream(meta->handle.blas, stream));
checkCUDA(cublasSetWorkspace(meta->handle.blas,
meta->handle.cublasWorkSpace,
meta->handle.cublasWorkSpaceSize));
checkCUDNN(cudnnSetStream(meta->handle.dnn, stream));

// int a_stride = n * k;
@@ -213,6 +216,9 @@ void backward_kernel(BatchMatmulMeta const *meta,
int batch,
cudaStream_t stream) {
checkCUDA(cublasSetStream(meta->handle.blas, stream));
checkCUDA(cublasSetWorkspace(meta->handle.blas,
meta->handle.cublasWorkSpace,
meta->handle.cublasWorkSpaceSize));
checkCUDNN(cudnnSetStream(meta->handle.dnn, stream));

int a_stride = n * k;
6 changes: 6 additions & 0 deletions src/ops/kernels/element_binary_kernels.cu
@@ -315,6 +315,9 @@ void forward_kernel(ElementBinaryMeta const *m,
DT *out_ptr,
cudaStream_t stream) {
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDA(cublasSetWorkspace(m->handle.blas,
m->handle.cublasWorkSpace,
m->handle.cublasWorkSpaceSize));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
float alpha1 = 1.0f, alpha2 = 1.0f, beta = 0.0f;
switch (m->op_type) {
@@ -419,6 +422,9 @@ void backward_kernel(ElementBinaryMeta const *m,
float *in2_grad_ptr,
cudaStream_t stream) {
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDA(cublasSetWorkspace(m->handle.blas,
m->handle.cublasWorkSpace,
m->handle.cublasWorkSpaceSize));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));

if (m->op_type == OP_EW_ADD || m->op_type == OP_EW_SUB) {
6 changes: 6 additions & 0 deletions src/ops/kernels/linear_kernels.cu
@@ -315,6 +315,9 @@ void forward_kernel(LinearMeta const *m,
}
}
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDA(cublasSetWorkspace(m->handle.blas,
m->handle.cublasWorkSpace,
m->handle.cublasWorkSpaceSize));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
DT alpha = 1.0f, beta = 0.0f;
cudaDataType_t input_type = ff_to_cuda_datatype(m->input_type[0]);
@@ -421,6 +424,9 @@ void backward_kernel(LinearMeta const *m,
int batch_size,
ffStream_t stream) {
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDA(cublasSetWorkspace(m->handle.blas,
m->handle.cublasWorkSpace,
m->handle.cublasWorkSpaceSize));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));

DT alpha = 1.0f;
9 changes: 6 additions & 3 deletions src/ops/spec_inc_multihead_self_attention.cu
@@ -466,8 +466,6 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m,
DT const *bias_ptr,
DT const *weight_ptr,
cudaStream_t stream) {
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]);
cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]);
assert(data_type_size(m->output_type[0]) == sizeof(DT));
@@ -707,8 +705,13 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
DT *output_ptr,
DT const *bias_ptr,
cudaStream_t stream) {
// phase 1: Implement kernel to compute KQV for input tokens
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDA(cublasSetWorkspace(m->handle.blas,
m->handle.cublasWorkSpace,
m->handle.cublasWorkSpaceSize));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));

// phase 1: Implement kernel to compute KQV for input tokens
compute_qkv_kernel(m,
bc,
shard_id,
7 changes: 5 additions & 2 deletions src/ops/tree_inc_multihead_self_attention.cu
@@ -504,8 +504,6 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m,
DT const *bias_ptr,
DT const *weight_ptr,
cudaStream_t stream) {
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]);
cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]);
assert(data_type_size(m->output_type[0]) == sizeof(DT));
@@ -877,6 +875,11 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m,
DT *output_ptr,
DT const *bias_ptr,
cudaStream_t stream) {
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDA(cublasSetWorkspace(m->handle.blas,
m->handle.cublasWorkSpace,
m->handle.cublasWorkSpaceSize));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
// additional processing for weight uploading
if (m->handle.offload_reserve_space != nullptr) {
// Note that we update weight_ptr and bias_ptr when uploading weight and
3 changes: 3 additions & 0 deletions src/runtime/model.cc
@@ -1507,6 +1507,7 @@ FFRuntime::FFRuntime(FFConfig &config) {
// info.myRank = rank++;
// info.allRanks = config.workersPerNode * config.numNodes;
info.workSpaceSize = config.workSpaceSize;
info.cublasWorkSpaceSize = config.cublasWorkSpaceSize;
info.offload_reserve_space_size =
config.cpu_offload ? config.offload_reserve_space_size : 0;
info.quantization_type = config.quantization_type;
@@ -4019,6 +4020,7 @@ struct DefaultConfig {
constexpr static float learningRate = 0.01f;
constexpr static float weightDecay = 0.0001f;
const static size_t workSpaceSize = (size_t)128 * 1024 * 1024; // 128 MB
const static size_t cublasWorkSpaceSize = (size_t)4096 * 8 * 1024;
const static int numNodes = 1;
const static int workersPerNode = 0;
const static int cpusPerNode = 0;
Expand Down Expand Up @@ -4054,6 +4056,7 @@ FFConfig::FFConfig() {
learningRate = DefaultConfig::learningRate;
weightDecay = DefaultConfig::weightDecay;
workSpaceSize = DefaultConfig::workSpaceSize;
cublasWorkSpaceSize = DefaultConfig::cublasWorkSpaceSize;
numNodes = DefaultConfig::numNodes;
cpusPerNode = DefaultConfig::cpusPerNode;
workersPerNode = DefaultConfig::workersPerNode;
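The new DefaultConfig::cublasWorkSpaceSize works out to 32 MiB, which appears to match the larger per-handle workspace that recent cuBLAS releases use by default on Hopper-class GPUs (the PR itself does not state a rationale). A quick check of the arithmetic:

// 4096 * 8 * 1024 bytes = 33,554,432 bytes = 32 MiB
static_assert((size_t)4096 * 8 * 1024 == (size_t)32 * 1024 * 1024,
              "default cuBLAS workspace is 32 MiB");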