diff --git a/vsort/vs_onnxruntime.cpp b/vsort/vs_onnxruntime.cpp
index fc484d1..3541cee 100644
--- a/vsort/vs_onnxruntime.cpp
+++ b/vsort/vs_onnxruntime.cpp
@@ -941,6 +941,14 @@ static void VS_CC vsOrtCreate(
         use_cuda_graph = false;
     }
 
+    int output_format = int64ToIntS(vsapi->propGetInt(in, "output_format", 0, &error));
+    if (error) {
+        output_format = 1;
+    }
+    if (output_format != 0 && output_format != 1) {
+        return set_error("\"output_format\" must be 0 or 1");
+    }
+
     std::string_view path_view;
     std::string path;
     if (path_is_serialization) {
@@ -988,7 +996,13 @@ static void VS_CC vsOrtCreate(
                 fp16_blacklist_ops.emplace(vsapi->propGetData(in, "fp16_blacklist_ops", i, nullptr));
             }
         }
-        convert_float_to_float16(onnx_model, false, fp16_blacklist_ops);
+        convert_float_to_float16(
+            onnx_model,
+            false,
+            fp16_blacklist_ops,
+            in_vis.front()->format->bytesPerSample == 4,
+            output_format == 0
+        );
     }
 
     rename(onnx_model);
@@ -1215,7 +1229,7 @@ static void VS_CC vsOrtCreate(
                 memory_info,
                 resource.output.d_data, resource.output.size,
                 std::data(output_shape), std::size(output_shape),
-                static_cast<ONNXTensorElementDataType>(onnx_input_type),
+                static_cast<ONNXTensorElementDataType>(onnx_output_type),
                 &resource.output_tensor
             ));
         } else
@@ -1224,7 +1238,7 @@ static void VS_CC vsOrtCreate(
             checkError(ortapi->CreateTensorAsOrtValue(
                 cpu_allocator,
                 std::data(output_shape), std::size(output_shape),
-                static_cast<ONNXTensorElementDataType>(onnx_input_type),
+                static_cast<ONNXTensorElementDataType>(onnx_output_type),
                 &resource.output_tensor
             ));
         }
@@ -1292,6 +1306,7 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit(
         "use_cuda_graph:int:opt;"
         "fp16_blacklist_ops:data[]:opt;"
         "prefer_nhwc:int:opt;"
+        "output_format:int:opt;"
         , vsOrtCreate,
         nullptr,
         plugin