Skip to content

Commit

Permalink
vsort/vs_onnxruntime.cpp: disable ort cuda stream sync
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Apr 20, 2024
1 parent 1a0eb7c commit 187249d
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions vsort/vs_onnxruntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ using namespace std::chrono_literals;
#define NOMINMAX

#include <onnxruntime_c_api.h>
#include <onnxruntime_run_options_config_keys.h>

#ifdef ENABLE_CUDA
#include <cuda_runtime.h>
Expand Down Expand Up @@ -571,9 +572,23 @@ static const VSFrameRef *VS_CC vsOrtGetFrame(
return nullptr;
};

OrtRunOptions * run_options {};

#ifdef ENABLE_CUDA
if (d->backend == Backend::CUDA) {
checkCUDAError(cudaSetDevice(d->device_id));

#if ORT_API_VERSION >= 16
checkError(ortapi->CreateRunOptions(&run_options));
if (run_options == nullptr) {
return set_error("create run_options failed");
}
checkError(ortapi->AddRunConfigEntry(
run_options,
kOrtRunOptionsConfigDisableSynchronizeExecutionProviders,
"1"
));
#endif // ORT_API_VERSION >= 16
}
#endif // ENABLE_CUDA

Expand Down Expand Up @@ -650,17 +665,17 @@ static const VSFrameRef *VS_CC vsOrtGetFrame(
// note that this applies only to stream capture from the ort library
// this fails when another plugin also uses global-mode stream capture
std::lock_guard _ { capture_lock };
checkError(ortapi->RunWithBinding(resource.session, nullptr, resource.binding));
checkError(ortapi->RunWithBinding(resource.session, run_options, resource.binding));

// onnxruntime replays the graph itself in CUDAExecutionProvider::OnRunEnd
} else
#endif // ENABLE_CUDA
if (d->backend == Backend::CPU || d->backend == Backend::CUDA) {
checkError(ortapi->RunWithBinding(resource.session, nullptr, resource.binding));
checkError(ortapi->RunWithBinding(resource.session, run_options, resource.binding));
} else {
checkError(ortapi->Run(
resource.session,
nullptr,
run_options,
&resource.input_name,
&resource.input_tensor,
1,
Expand Down Expand Up @@ -741,6 +756,10 @@ static const VSFrameRef *VS_CC vsOrtGetFrame(
y = std::min(y + step_h, src_height - src_tile_h);
}

if (run_options) {
ortapi->ReleaseRunOptions(run_options);
}

d->release(ticket);

for (const auto & frame : src_frames) {
Expand Down

0 comments on commit 187249d

Please sign in to comment.