Skip to content

Commit

Permalink
vsort/vs_onnxruntime.cpp: add tf32 flag to the cuda provider
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Apr 20, 2024
1 parent 61682d2 commit 7d55aa7
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions vsort/vs_onnxruntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,11 @@ static void VS_CC vsOrtCreate(
prefer_nhwc = false;
}
#endif // ORT_API_VERSION >= 17

bool tf32 = !!(vsapi->propGetInt(in, "tf32", 0, &error));
if (error) {
tf32 = false;
}
#endif // ENABLE_CUDA

if (auto err = ortInit(); err.has_value()) {
Expand Down Expand Up @@ -1121,6 +1126,7 @@ static void VS_CC vsOrtCreate(
"enable_cuda_graph",
#if ORT_API_VERSION >= 17
"prefer_nhwc",
"use_tf32",
#endif // ORT_API_VERSION >= 17
};
auto device_id_str = std::to_string(d->device_id);
Expand All @@ -1132,6 +1138,7 @@ static void VS_CC vsOrtCreate(
"0",
#if ORT_API_VERSION >= 17
"0",
"0",
#endif // ORT_API_VERSION >= 17
};
if (!cudnn_benchmark) {
Expand All @@ -1147,6 +1154,9 @@ static void VS_CC vsOrtCreate(
if (prefer_nhwc) {
values[5] = "1";
}
if (tf32) {
values[6] = "1";
}
#endif // ORT_API_VERSION >= 17
checkError(ortapi->UpdateCUDAProviderOptions(cuda_options, keys, values, std::size(keys)));

Expand Down Expand Up @@ -1327,6 +1337,7 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit(
"fp16_blacklist_ops:data[]:opt;"
"prefer_nhwc:int:opt;"
"output_format:int:opt;"
"use_tf32:int:opt;"
, vsOrtCreate,
nullptr,
plugin
Expand Down

0 comments on commit 7d55aa7

Please sign in to comment.