Skip to content

Commit

Permalink
vsort: implement flexible output
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed May 13, 2024
1 parent 325f54f commit 21a842a
Showing 1 changed file with 65 additions and 13 deletions.
78 changes: 65 additions & 13 deletions vsort/vs_onnxruntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ static std::optional<std::string> checkNodes(
[[nodiscard]]
static std::optional<std::string> checkIOInfo(
const OrtTypeInfo * info,
bool is_output
bool is_output,
bool flexible_output
) noexcept {

const auto set_error = [](const std::string & error_message) {
Expand Down Expand Up @@ -262,8 +263,8 @@ static std::optional<std::string> checkIOInfo(

if (is_output) {
int64_t out_channels = shape[1];
if (out_channels != 1 && out_channels != 3) {
return "output dimensions must be 1 or 3";
if (out_channels != 1 && out_channels != 3 && !flexible_output) {
return "output dimensions must be 1 or 3, or enable \"flexible_output\"";
}
}

Expand All @@ -273,7 +274,8 @@ static std::optional<std::string> checkIOInfo(

[[nodiscard]]
static std::optional<std::string> checkSession(
const OrtSession * session
const OrtSession * session,
bool flexible_output
) noexcept {

const auto set_error = [](const std::string & error_message) {
Expand All @@ -290,7 +292,7 @@ static std::optional<std::string> checkSession(
OrtTypeInfo * input_type_info;
checkError(ortapi->SessionGetInputTypeInfo(session, 0, &input_type_info));

if (auto err = checkIOInfo(input_type_info, false); err.has_value()) {
if (auto err = checkIOInfo(input_type_info, false, flexible_output); err.has_value()) {
return set_error(err.value());
}

Expand All @@ -306,7 +308,7 @@ static std::optional<std::string> checkSession(
OrtTypeInfo * output_type_info;
checkError(ortapi->SessionGetOutputTypeInfo(session, 0, &output_type_info));

if (auto err = checkIOInfo(output_type_info, true); err.has_value()) {
if (auto err = checkIOInfo(output_type_info, true, flexible_output); err.has_value()) {
return set_error(err.value());
}

Expand Down Expand Up @@ -375,13 +377,14 @@ static void setDimensions(
const std::array<int64_t, 4> & output_shape,
VSCore * core,
const VSAPI * vsapi,
int32_t onnx_output_type
int32_t onnx_output_type,
bool flexible_output
) noexcept {

vi->height *= output_shape[2] / input_shape[2];
vi->width *= output_shape[3] / input_shape[3];

if (output_shape[1] == 1) {
if (output_shape[1] == 1 || flexible_output) {
vi->format = vsapi->registerFormat(cmGray, stFloat, 8 * getNumBytes(onnx_output_type), 0, 0, core);
} else if (output_shape[1] == 3) {
vi->format = vsapi->registerFormat(cmRGB, stFloat, 8 * getNumBytes(onnx_output_type), 0, 0, core);
Expand Down Expand Up @@ -463,6 +466,8 @@ struct vsOrtData {
std::mutex ticket_lock;
TicketSemaphore semaphore;

std::string flexible_output_prop;

int acquire() noexcept {
semaphore.acquire();
{
Expand Down Expand Up @@ -535,6 +540,9 @@ static const VSFrameRef *VS_CC vsOrtGetFrame(
d->out_vi->format, d->out_vi->width, d->out_vi->height,
src_frames.front(), core
);

std::vector<VSFrameRef *> dst_frames;

auto dst_stride = vsapi->getStride(dst_frame, 0);
auto dst_bytes = vsapi->getFrameFormat(dst_frame)->bytesPerSample;

Expand Down Expand Up @@ -564,9 +572,21 @@ static const VSFrameRef *VS_CC vsOrtGetFrame(
auto dst_tile_w_bytes = dst_tile_w * dst_bytes;
auto dst_tile_bytes = dst_tile_h * dst_tile_w_bytes;
auto dst_planes = dst_tile_shape[1];
uint8_t * dst_ptrs[3] {};
for (int i = 0; i < dst_planes; ++i) {
dst_ptrs[i] = vsapi->getWritePtr(dst_frame, i);

std::vector<uint8_t *> dst_ptrs;
if (d->flexible_output_prop.empty()) {
for (int i = 0; i < dst_planes; ++i) {
dst_ptrs.emplace_back(vsapi->getWritePtr(dst_frame, i));
}
} else {
for (int i = 0; i < dst_planes; ++i) {
auto frame { vsapi->newVideoFrame(
d->out_vi->format, d->out_vi->width, d->out_vi->height,
src_frames[0], core
)};
dst_frames.emplace_back(frame);
dst_ptrs.emplace_back(vsapi->getWritePtr(frame, 0));
}
}

auto h_scale = dst_tile_h / src_tile_h;
Expand All @@ -580,6 +600,10 @@ static const VSFrameRef *VS_CC vsOrtGetFrame(

d->release(ticket);

for (const auto & frame : dst_frames) {
vsapi->freeFrame(frame);
}

vsapi->freeFrame(dst_frame);

for (const auto & frame : src_frames) {
Expand Down Expand Up @@ -783,6 +807,16 @@ static const VSFrameRef *VS_CC vsOrtGetFrame(
vsapi->freeFrame(frame);
}

if (!d->flexible_output_prop.empty()) {
auto prop = vsapi->getFramePropsRW(dst_frame);

for (int i = 0; i < dst_planes; i++) {
auto key { d->flexible_output_prop + std::to_string(i) };
vsapi->propSetFrame(prop, key.c_str(), dst_frames[i], paReplace);
vsapi->freeFrame(dst_frames[i]);
}
}

return dst_frame;
}

Expand Down Expand Up @@ -991,6 +1025,11 @@ static void VS_CC vsOrtCreate(
return set_error("\"output_format\" must be 0 or 1");
}

auto flexible_output_prop = vsapi->propGetData(in, "flexible_output_prop", 0, &error);
if (!error) {
d->flexible_output_prop = flexible_output_prop;
}

std::string_view path_view;
std::string path;
if (path_is_serialization) {
Expand Down Expand Up @@ -1215,7 +1254,7 @@ static void VS_CC vsOrtCreate(

ortapi->ReleaseSessionOptions(session_options);

if (auto err = checkSession(resource.session); err.has_value()) {
if (auto err = checkSession(resource.session, !d->flexible_output_prop.empty()); err.has_value()) {
return set_error(err.value());
}

Expand Down Expand Up @@ -1307,7 +1346,19 @@ static void VS_CC vsOrtCreate(
}

if (i == 0) {
setDimensions(d->out_vi, input_shape, output_shape, core, vsapi, onnx_output_type);
setDimensions(
d->out_vi,
input_shape,
output_shape,
core, vsapi,
onnx_output_type,
!d->flexible_output_prop.empty()
);

if (!d->flexible_output_prop.empty()) {
auto num_planes = output_shape[1];
vsapi->propSetInt(out, "num_planes", static_cast<int>(num_planes), paReplace);
}
}

d->resources.push_back(resource);
Expand Down Expand Up @@ -1355,6 +1406,7 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit(
"prefer_nhwc:int:opt;"
"output_format:int:opt;"
"tf32:int:opt;"
"flexible_output_prop:data:opt;"
, vsOrtCreate,
nullptr,
plugin
Expand Down

0 comments on commit 21a842a

Please sign in to comment.