From 325f54fa5e9c95adbdbbfa949d4067932725070b Mon Sep 17 00:00:00 2001 From: WolframRhodium Date: Mon, 13 May 2024 16:44:06 +0800 Subject: [PATCH] vsov: implement flexible output --- vsov/vs_openvino.cpp | 71 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 58 insertions(+), 13 deletions(-) diff --git a/vsov/vs_openvino.cpp b/vsov/vs_openvino.cpp index ce6dfe3..0e564c1 100644 --- a/vsov/vs_openvino.cpp +++ b/vsov/vs_openvino.cpp @@ -104,7 +104,8 @@ static std::optional checkNodes( [[nodiscard]] static std::optional checkIOInfo( const ov::Output & info, - bool is_output + bool is_output, + bool flexible_output ) { if (info.get_element_type() != ov::element::f32) { @@ -124,8 +125,8 @@ static std::optional checkIOInfo( if (is_output) { auto out_channels = dims[1]; - if (out_channels != 1 && out_channels != 3) { - return "output dimensions must be 1 or 3"; + if (out_channels != 1 && out_channels != 3 && !flexible_output) { + return "output dimensions must be 1 or 3, or enable \"flexible_output\""; } } @@ -135,7 +136,8 @@ static std::optional checkIOInfo( [[nodiscard]] static std::optional checkNetwork( - const std::shared_ptr & network + const std::shared_ptr & network, + bool flexible_output ) { if (auto num_inputs = std::size(network->inputs()); num_inputs != 1) { @@ -143,7 +145,7 @@ static std::optional checkNetwork( } const auto & input_info = network->input(); - if (auto err = checkIOInfo(input_info, false); err.has_value()) { + if (auto err = checkIOInfo(input_info, false, flexible_output); err.has_value()) { return err.value(); } @@ -152,7 +154,7 @@ static std::optional checkNetwork( } const auto & output_info = network->output(); - if (auto err = checkIOInfo(output_info, true); err.has_value()) { + if (auto err = checkIOInfo(output_info, true, flexible_output); err.has_value()) { return err.value(); } @@ -193,7 +195,8 @@ static void setDimensions( std::unique_ptr & vi, const ov::CompiledModel & network, VSCore * core, - const VSAPI * vsapi + const VSAPI * vsapi, + bool flexible_output ) { const auto & in_dims = network.input().get_shape(); @@ -202,7 +205,7 @@ static void setDimensions( vi->height *= out_dims[2] / in_dims[2]; vi->width *= out_dims[3] / in_dims[3]; - if (out_dims[1] == 1) { + if (out_dims[1] == 1 || flexible_output) { vi->format = vsapi->registerFormat(cmGray, stFloat, 32, 0, 0, core); } else if (out_dims[1] == 3) { vi->format = vsapi->registerFormat(cmRGB, stFloat, 32, 0, 0, core); @@ -273,6 +276,8 @@ struct OVData { ov::CompiledModel executable_network; std::unordered_map infer_requests; std::shared_mutex infer_requests_lock; + + std::string flexible_output_prop; }; @@ -344,6 +349,9 @@ static const VSFrameRef *VS_CC vsOvGetFrame( d->out_vi->format, d->out_vi->width, d->out_vi->height, src_frames.front(), core ); + + std::vector dst_frames; + auto dst_stride = vsapi->getStride(dst_frame, 0); auto dst_bytes = vsapi->getFrameFormat(dst_frame)->bytesPerSample; auto dst_tile_shape = getShape(d->executable_network, false); @@ -352,9 +360,21 @@ static const VSFrameRef *VS_CC vsOvGetFrame( auto dst_tile_w_bytes = dst_tile_w * dst_bytes; auto dst_tile_bytes = dst_tile_h * dst_tile_w_bytes; auto dst_planes = dst_tile_shape[1]; - std::array dst_ptrs {}; - for (int i = 0; i < dst_planes; ++i) { - dst_ptrs[i] = vsapi->getWritePtr(dst_frame, i); + + std::vector dst_ptrs; + if (d->flexible_output_prop.empty()) { + for (int i = 0; i < dst_planes; ++i) { + dst_ptrs.emplace_back(vsapi->getWritePtr(dst_frame, i)); + } + } else { + for (int i = 0; i < dst_planes; ++i) { + auto frame { vsapi->newVideoFrame( + d->out_vi->format, d->out_vi->width, d->out_vi->height, + src_frames[0], core + )}; + dst_frames.emplace_back(frame); + dst_ptrs.emplace_back(vsapi->getWritePtr(frame, 0)); + } } auto h_scale = dst_tile_h / src_tile_h; @@ -368,6 +388,10 @@ static const VSFrameRef *VS_CC vsOvGetFrame( vsapi->freeFrame(dst_frame); + for (const auto & frame : dst_frames) { + vsapi->freeFrame(frame); + } + for (const auto & frame : src_frames) { vsapi->freeFrame(frame); } @@ -474,6 +498,16 @@ static const VSFrameRef *VS_CC vsOvGetFrame( vsapi->freeFrame(frame); } + if (!d->flexible_output_prop.empty()) { + auto prop = vsapi->getFramePropsRW(dst_frame); + + for (int i = 0; i < dst_planes; i++) { + auto key { d->flexible_output_prop + std::to_string(i) }; + vsapi->propSetFrame(prop, key.c_str(), dst_frames[i], paReplace); + vsapi->freeFrame(dst_frames[i]); + } + } + return dst_frame; } @@ -615,6 +649,11 @@ static void VS_CC vsOvCreate( path_view = path; } + auto flexible_output_prop = vsapi->propGetData(in, "flexible_output_prop", 0, &error); + if (!error) { + d->flexible_output_prop = flexible_output_prop; + } + auto result = loadONNX(path_view, tile_w, tile_h, path_is_serialization); if (std::holds_alternative(result)) { return set_error(std::get(result)); @@ -657,7 +696,7 @@ static void VS_CC vsOvCreate( return set_error("[Standard exception] ReadNetwork(): "s + e.what()); } - if (auto err = checkNetwork(network); err.has_value()) { + if (auto err = checkNetwork(network, !d->flexible_output_prop.empty()); err.has_value()) { return set_error(err.value()); } @@ -696,13 +735,18 @@ static void VS_CC vsOvCreate( return set_error(err.value()); } - setDimensions(d->out_vi, d->executable_network, core, vsapi); + setDimensions(d->out_vi, d->executable_network, core, vsapi, !d->flexible_output_prop.empty()); VSCoreInfo core_info; vsapi->getCoreInfo2(core, &core_info); d->infer_requests.reserve(core_info.numThreads); } + if (!d->flexible_output_prop.empty()) { + auto num_planes = d->executable_network.output(0).get_shape()[1]; + vsapi->propSetInt(out, "num_planes", static_cast(num_planes), paReplace); + } + vsapi->createFilter( in, out, "Model", vsOvInit, vsOvGetFrame, vsOvFree, @@ -738,6 +782,7 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit( #ifdef ENABLE_VISUALIZATION "dot_path:data:opt;" #endif + "flexible_output_prop:data:opt;" , vsOvCreate, nullptr, plugin