From f439a9c6c10cfd90256c614255f786f070f348e1 Mon Sep 17 00:00:00 2001 From: WolframRhodium Date: Fri, 19 Apr 2024 11:32:37 +0800 Subject: [PATCH] vsort/vs_onnxruntime.cpp: add `output_format` param --- vsort/vs_onnxruntime.cpp | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/vsort/vs_onnxruntime.cpp b/vsort/vs_onnxruntime.cpp index fc484d1..ed86816 100644 --- a/vsort/vs_onnxruntime.cpp +++ b/vsort/vs_onnxruntime.cpp @@ -941,6 +941,14 @@ static void VS_CC vsOrtCreate( use_cuda_graph = false; } + int output_format = int64ToIntS(vsapi->propGetInt(in, "output_format", 0, &error)); + if (error) { + output_format = 1; + } + if (output_format != 0 && output_format != 1) { + return set_error("\"output_format\" must be 0 or 1"); + } + std::string_view path_view; std::string path; if (path_is_serialization) { @@ -988,7 +996,13 @@ static void VS_CC vsOrtCreate( fp16_blacklist_ops.emplace(vsapi->propGetData(in, "fp16_blacklist_ops", i, nullptr)); } } - convert_float_to_float16(onnx_model, false, fp16_blacklist_ops); + convert_float_to_float16( + onnx_model, + false, + fp16_blacklist_ops, + in_vis.front()->format->bytesPerSample == 4, + output_format == 0 + ); } rename(onnx_model); @@ -1292,6 +1306,7 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit( "use_cuda_graph:int:opt;" "fp16_blacklist_ops:data[]:opt;" "prefer_nhwc:int:opt;" + "output_format:int:opt;" , vsOrtCreate, nullptr, plugin