From 187249d59cdb5c6733983adc19fc80b95fd3bcde Mon Sep 17 00:00:00 2001 From: WolframRhodium Date: Sat, 20 Apr 2024 11:21:42 +0800 Subject: [PATCH] vsort/vs_onnxruntime.cpp: disable ort cuda stream sync --- vsort/vs_onnxruntime.cpp | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/vsort/vs_onnxruntime.cpp b/vsort/vs_onnxruntime.cpp index ecd5725..c775c06 100644 --- a/vsort/vs_onnxruntime.cpp +++ b/vsort/vs_onnxruntime.cpp @@ -25,6 +25,7 @@ using namespace std::chrono_literals; #define NOMINMAX #include +#include #ifdef ENABLE_CUDA #include @@ -571,9 +572,23 @@ static const VSFrameRef *VS_CC vsOrtGetFrame( return nullptr; }; + OrtRunOptions * run_options {}; + #ifdef ENABLE_CUDA if (d->backend == Backend::CUDA) { checkCUDAError(cudaSetDevice(d->device_id)); + +#if ORT_API_VERSION >= 16 + checkError(ortapi->CreateRunOptions(&run_options)); + if (run_options == nullptr) { + return set_error("create run_options failed"); + } + checkError(ortapi->AddRunConfigEntry( + run_options, + kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, + "1" + )); +#endif // ORT_API_VERSION >= 16 } #endif // ENABLE_CUDA @@ -650,17 +665,17 @@ static const VSFrameRef *VS_CC vsOrtGetFrame( // note that this applies only to stream capture from the ort library // this fails when another plugin also uses global-mode stream capture std::lock_guard _ { capture_lock }; - checkError(ortapi->RunWithBinding(resource.session, nullptr, resource.binding)); + checkError(ortapi->RunWithBinding(resource.session, run_options, resource.binding)); // onnxruntime replays the graph itself in CUDAExecutionProvider::OnRunEnd } else #endif // ENABLE_CUDA if (d->backend == Backend::CPU || d->backend == Backend::CUDA) { - checkError(ortapi->RunWithBinding(resource.session, nullptr, resource.binding)); + checkError(ortapi->RunWithBinding(resource.session, run_options, resource.binding)); } else { checkError(ortapi->Run( resource.session, - nullptr, + run_options, &resource.input_name, &resource.input_tensor, 1, @@ -741,6 +756,10 @@ static const VSFrameRef *VS_CC vsOrtGetFrame( y = std::min(y + step_h, src_height - src_tile_h); } + if (run_options) { + ortapi->ReleaseRunOptions(run_options); + } + d->release(ticket); for (const auto & frame : src_frames) {