diff --git a/vsort/vs_onnxruntime.cpp b/vsort/vs_onnxruntime.cpp index 290c1d2..0bf0c9d 100644 --- a/vsort/vs_onnxruntime.cpp +++ b/vsort/vs_onnxruntime.cpp @@ -227,7 +227,8 @@ static std::optional checkNodes( [[nodiscard]] static std::optional checkIOInfo( const OrtTypeInfo * info, - bool is_output + bool is_output, + bool flexible_output ) noexcept { const auto set_error = [](const std::string & error_message) { @@ -262,8 +263,8 @@ static std::optional checkIOInfo( if (is_output) { int64_t out_channels = shape[1]; - if (out_channels != 1 && out_channels != 3) { - return "output dimensions must be 1 or 3"; + if (out_channels != 1 && out_channels != 3 && !flexible_output) { + return "output dimensions must be 1 or 3, or enable \"flexible_output\""; } } @@ -273,7 +274,8 @@ static std::optional checkIOInfo( [[nodiscard]] static std::optional checkSession( - const OrtSession * session + const OrtSession * session, + bool flexible_output ) noexcept { const auto set_error = [](const std::string & error_message) { @@ -290,7 +292,7 @@ static std::optional checkSession( OrtTypeInfo * input_type_info; checkError(ortapi->SessionGetInputTypeInfo(session, 0, &input_type_info)); - if (auto err = checkIOInfo(input_type_info, false); err.has_value()) { + if (auto err = checkIOInfo(input_type_info, false, flexible_output); err.has_value()) { return set_error(err.value()); } @@ -306,7 +308,7 @@ static std::optional checkSession( OrtTypeInfo * output_type_info; checkError(ortapi->SessionGetOutputTypeInfo(session, 0, &output_type_info)); - if (auto err = checkIOInfo(output_type_info, true); err.has_value()) { + if (auto err = checkIOInfo(output_type_info, true, flexible_output); err.has_value()) { return set_error(err.value()); } @@ -375,13 +377,14 @@ static void setDimensions( const std::array & output_shape, VSCore * core, const VSAPI * vsapi, - int32_t onnx_output_type + int32_t onnx_output_type, + bool flexible_output ) noexcept { vi->height *= output_shape[2] / input_shape[2]; vi->width *= output_shape[3] / input_shape[3]; - if (output_shape[1] == 1) { + if (output_shape[1] == 1 || flexible_output) { vi->format = vsapi->registerFormat(cmGray, stFloat, 8 * getNumBytes(onnx_output_type), 0, 0, core); } else if (output_shape[1] == 3) { vi->format = vsapi->registerFormat(cmRGB, stFloat, 8 * getNumBytes(onnx_output_type), 0, 0, core); @@ -463,6 +466,8 @@ struct vsOrtData { std::mutex ticket_lock; TicketSemaphore semaphore; + std::string flexible_output_prop; + int acquire() noexcept { semaphore.acquire(); { @@ -535,6 +540,9 @@ static const VSFrameRef *VS_CC vsOrtGetFrame( d->out_vi->format, d->out_vi->width, d->out_vi->height, src_frames.front(), core ); + + std::vector dst_frames; + auto dst_stride = vsapi->getStride(dst_frame, 0); auto dst_bytes = vsapi->getFrameFormat(dst_frame)->bytesPerSample; @@ -564,9 +572,21 @@ static const VSFrameRef *VS_CC vsOrtGetFrame( auto dst_tile_w_bytes = dst_tile_w * dst_bytes; auto dst_tile_bytes = dst_tile_h * dst_tile_w_bytes; auto dst_planes = dst_tile_shape[1]; - uint8_t * dst_ptrs[3] {}; - for (int i = 0; i < dst_planes; ++i) { - dst_ptrs[i] = vsapi->getWritePtr(dst_frame, i); + + std::vector dst_ptrs; + if (d->flexible_output_prop.empty()) { + for (int i = 0; i < dst_planes; ++i) { + dst_ptrs.emplace_back(vsapi->getWritePtr(dst_frame, i)); + } + } else { + for (int i = 0; i < dst_planes; ++i) { + auto frame { vsapi->newVideoFrame( + d->out_vi->format, d->out_vi->width, d->out_vi->height, + src_frames[0], core + )}; + dst_frames.emplace_back(frame); + dst_ptrs.emplace_back(vsapi->getWritePtr(frame, 0)); + } } auto h_scale = dst_tile_h / src_tile_h; @@ -580,6 +600,10 @@ static const VSFrameRef *VS_CC vsOrtGetFrame( d->release(ticket); + for (const auto & frame : dst_frames) { + vsapi->freeFrame(frame); + } + vsapi->freeFrame(dst_frame); for (const auto & frame : src_frames) { @@ -783,6 +807,16 @@ static const VSFrameRef *VS_CC vsOrtGetFrame( vsapi->freeFrame(frame); } + if (!d->flexible_output_prop.empty()) { + auto prop = vsapi->getFramePropsRW(dst_frame); + + for (int i = 0; i < dst_planes; i++) { + auto key { d->flexible_output_prop + std::to_string(i) }; + vsapi->propSetFrame(prop, key.c_str(), dst_frames[i], paReplace); + vsapi->freeFrame(dst_frames[i]); + } + } + return dst_frame; } @@ -991,6 +1025,11 @@ static void VS_CC vsOrtCreate( return set_error("\"output_format\" must be 0 or 1"); } + auto flexible_output_prop = vsapi->propGetData(in, "flexible_output_prop", 0, &error); + if (!error) { + d->flexible_output_prop = flexible_output_prop; + } + std::string_view path_view; std::string path; if (path_is_serialization) { @@ -1215,7 +1254,7 @@ static void VS_CC vsOrtCreate( ortapi->ReleaseSessionOptions(session_options); - if (auto err = checkSession(resource.session); err.has_value()) { + if (auto err = checkSession(resource.session, !d->flexible_output_prop.empty()); err.has_value()) { return set_error(err.value()); } @@ -1307,7 +1346,19 @@ static void VS_CC vsOrtCreate( } if (i == 0) { - setDimensions(d->out_vi, input_shape, output_shape, core, vsapi, onnx_output_type); + setDimensions( + d->out_vi, + input_shape, + output_shape, + core, vsapi, + onnx_output_type, + !d->flexible_output_prop.empty() + ); + + if (!d->flexible_output_prop.empty()) { + auto num_planes = output_shape[1]; + vsapi->propSetInt(out, "num_planes", static_cast(num_planes), paReplace); + } } d->resources.push_back(resource); @@ -1355,6 +1406,7 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit( "prefer_nhwc:int:opt;" "output_format:int:opt;" "tf32:int:opt;" + "flexible_output_prop:data:opt;" , vsOrtCreate, nullptr, plugin