Skip to content

Commit

Permalink
vsort/vs_onnxruntime.cpp: use custom cuda stream
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Apr 20, 2024
1 parent 84539e5 commit 1a0eb7c
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions vsort/vs_onnxruntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,9 +634,8 @@ static const VSFrameRef *VS_CC vsOrtGetFrame(
));

#if ORT_API_VERSION < 16
// TODO: fix cudnn issue
#endif // ORT_API_VERSION < 16
checkCUDAError(cudaStreamSynchronize(resource.stream));
#endif // ORT_API_VERSION < 16
}
#endif // ENABLE_CUDA

Expand Down Expand Up @@ -1078,6 +1077,8 @@ static void VS_CC vsOrtCreate(
// TODO: other providers
#ifdef ENABLE_CUDA
if (d->backend == Backend::CUDA) {
checkCUDAError(cudaStreamCreateWithFlags(&resource.stream, cudaStreamNonBlocking));

OrtCUDAProviderOptionsV2 * cuda_options;
checkError(ortapi->CreateCUDAProviderOptions(&cuda_options));
#ifdef _MSC_VER
Expand Down Expand Up @@ -1131,11 +1132,11 @@ static void VS_CC vsOrtCreate(
checkError(ortapi->UpdateCUDAProviderOptions(cuda_options, keys, values, std::size(keys)));

#if ORT_API_VERSION >= 16
// checkError(ortapi->UpdateCUDAProviderOptionsWithValue(
// cuda_options,
// "user_compute_stream",
// resource.stream
// ));
checkError(ortapi->UpdateCUDAProviderOptionsWithValue(
cuda_options,
"user_compute_stream",
resource.stream
));
#endif // ORT_API_VERSION >= 16

checkError(ortapi->SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, cuda_options));
Expand Down Expand Up @@ -1178,8 +1179,6 @@ static void VS_CC vsOrtCreate(

#ifdef ENABLE_CUDA
if (d->backend == Backend::CUDA) {
checkCUDAError(cudaStreamCreateWithFlags(&resource.stream, cudaStreamNonBlocking));

resource.input.size = (
input_shape[0] *
input_shape[1] *
Expand Down

0 comments on commit 1a0eb7c

Please sign in to comment.