Skip to content

Commit

Permalink
vsort/vs_onnxruntime.cpp: use cuda stream in ort
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Apr 18, 2024
1 parent cb9d488 commit 0b3aa2b
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions vsort/vs_onnxruntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -625,10 +625,9 @@ static const VSFrameRef *VS_CC vsOrtGetFrame(
));

#if ORT_API_VERSION < 16
// OrtCUDAProviderOptionsV2 disallows using custom user stream
// and the inference is executed on a private non-blocking stream
checkCUDAError(cudaStreamSynchronize(resource.stream));
// TODO: fix cudnn issue
#endif // ORT_API_VERSION < 16
checkCUDAError(cudaStreamSynchronize(resource.stream));
}
#endif // ENABLE_CUDA

Expand Down Expand Up @@ -908,10 +907,12 @@ static void VS_CC vsOrtCreate(
cudnn_benchmark = true;
}

#if ORT_API_VERSION >= 17
bool prefer_nhwc = !!(vsapi->propGetInt(in, "prefer_nhwc", 0, &error));
if (error) {
prefer_nhwc = false;
}
#endif // ORT_API_VERSION >= 17
#endif // ENABLE_CUDA

if (auto err = ortInit(); err.has_value()) {
Expand Down Expand Up @@ -1066,41 +1067,43 @@ static void VS_CC vsOrtCreate(
"cudnn_conv_use_max_workspace",
"arena_extend_strategy",
"enable_cuda_graph",
#if ORT_API_VERSION >= 16
"do_copy_in_default_stream",
#endif // ORT_API_VERSION >= 16
#if ORT_API_VERSION >= 17
"prefer_nhwc",
#endif // ORT_API_VERSION >= 17
};
auto device_id_str = std::to_string(d->device_id);
const char * values [] {
device_id_str.c_str(),
cudnn_benchmark ? "EXHAUSTIVE" : "HEURISTIC",
"EXHAUSTIVE",
"1",
"kSameAsRequested",
"0",
#if ORT_API_VERSION >= 16
"0",
#endif // ORT_API_VERSION >= 16
#if ORT_API_VERSION >= 17
prefer_nhwc ? "1" : "0",
"0",
#endif // ORT_API_VERSION >= 17
};
if (!cudnn_benchmark) {
values[1] = "HEURISTIC";
}
if (use_cuda_graph) {
values[4] = "1";
resource.require_replay = true;
} else {
resource.require_replay = false;
}
#if ORT_API_VERSION >= 17
if (prefer_nhwc) {
values[5] = "1";
}
#endif // ORT_API_VERSION >= 17
checkError(ortapi->UpdateCUDAProviderOptions(cuda_options, keys, values, std::size(keys)));

#if ORT_API_VERSION >= 16
checkError(ortapi->UpdateCUDAProviderOptionsWithValue(
cuda_options,
"user_compute_stream",
resource.stream
));
// checkError(ortapi->UpdateCUDAProviderOptionsWithValue(
// cuda_options,
// "user_compute_stream",
// resource.stream
// ));
#endif // ORT_API_VERSION >= 16

checkError(ortapi->SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, cuda_options));
Expand Down

0 comments on commit 0b3aa2b

Please sign in to comment.