vsort/vs_onnxruntime.cpp: enables fp16 i/o
WolframRhodium committed Apr 19, 2024
1 parent 4a4879e commit a65bc0e
Showing 1 changed file with 43 additions and 14 deletions.
57 changes: 43 additions & 14 deletions vsort/vs_onnxruntime.cpp
@@ -162,6 +162,19 @@ static std::variant<std::string, std::array<int64_t, 4>> getShape(
     return std::get<std::array<int64_t, 4>>(maybe_shape);
 }
 
+static size_t getNumBytes(int32_t type) {
+    using namespace ONNX_NAMESPACE;
+
+    switch (type) {
+        case TensorProto::FLOAT:
+            return 4;
+        case TensorProto::FLOAT16:
+            return 2;
+        default:
+            return 0;
+    }
+}
+
 
 static int numPlanes(
     const std::vector<const VSVideoInfo *> & vis
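
The helper's byte widths drive every buffer-size computation in this commit. A standalone sketch of that arithmetic, with hypothetical NCHW dimensions and stand-in constants for ONNX_NAMESPACE::TensorProto::{FLOAT, FLOAT16} (their values in onnx.proto are 1 and 10):

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Stand-ins for ONNX_NAMESPACE::TensorProto::{FLOAT, FLOAT16}.
constexpr int32_t kFloat = 1;
constexpr int32_t kFloat16 = 10;

// Same mapping as the getNumBytes() added above.
static size_t getNumBytes(int32_t type) {
    switch (type) {
        case kFloat:   return 4;
        case kFloat16: return 2;
        default:       return 0; // unsupported element type
    }
}

int main() {
    // Hypothetical 1x3x1080x1920 NCHW input, sized the same way as
    // resource.input.size further down in this diff.
    std::array<int64_t, 4> input_shape { 1, 3, 1080, 1920 };
    size_t size = static_cast<size_t>(
        input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]
    ) * getNumBytes(kFloat16);
    // fp16: 1 * 3 * 1080 * 1920 * 2 = 12,441,600 bytes, half the fp32 size
    return size == 12441600 ? 0 : 1;
}
```
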
@@ -183,8 +196,12 @@ static std::optional<std::string> checkNodes(
 ) noexcept {
 
     for (const auto & vi : vis) {
-        if (vi->format->sampleType != stFloat || vi->format->bitsPerSample != 32) {
-            return "expects clip with type fp32";
+        if (vi->format->sampleType != stFloat) {
+            return "expects clip with floating-point type";
+        }
+
+        if (vi->format->bitsPerSample != 32 && vi->format->bitsPerSample != 16) {
+            return "expects clip with type fp32 or fp16";
         }
 
         if (vi->width != vis[0]->width || vi->height != vis[0]->height) {
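
Splitting the old single fp32 test gives integer clips and wrong-depth float clips distinct diagnostics. A minimal sketch of the new acceptance logic, with Format, stInteger, and stFloat as stand-ins for the VapourSynth types (VSSampleType defines stInteger = 0, stFloat = 1):

```cpp
#include <cstdio>
#include <optional>
#include <string>

// Minimal stand-ins for the VapourSynth types used by checkNodes().
constexpr int stInteger = 0;
constexpr int stFloat = 1;
struct Format { int sampleType; int bitsPerSample; };

static std::optional<std::string> checkFormat(const Format & f) {
    if (f.sampleType != stFloat)
        return "expects clip with floating-point type";
    if (f.bitsPerSample != 32 && f.bitsPerSample != 16)
        return "expects clip with type fp32 or fp16";
    return std::nullopt;
}

int main() {
    std::printf("%s\n", checkFormat({stFloat, 32}) ? "reject" : "accept");   // accept (fp32)
    std::printf("%s\n", checkFormat({stFloat, 16}) ? "reject" : "accept");   // accept (fp16, new)
    std::printf("%s\n", checkFormat({stInteger, 16}) ? "reject" : "accept"); // reject (integer)
}
```
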
@@ -220,8 +237,8 @@ static std::optional<std::string> checkIOInfo(
     ONNXTensorElementDataType element_type;
     checkError(ortapi->GetTensorElementType(tensor_info, &element_type));
 
-    if (element_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
-        return set_error("expects network IO with type fp32");
+    if (element_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT && element_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
+        return set_error("expects network IO with type fp32 or fp16");
     }
 
     size_t num_dims;
@@ -337,16 +354,17 @@ static void setDimensions(
     const std::array<int64_t, 4> & input_shape,
     const std::array<int64_t, 4> & output_shape,
     VSCore * core,
-    const VSAPI * vsapi
+    const VSAPI * vsapi,
+    int32_t onnx_output_type
 ) noexcept {
 
     vi->height *= output_shape[2] / input_shape[2];
     vi->width *= output_shape[3] / input_shape[3];
 
     if (output_shape[1] == 1) {
-        vi->format = vsapi->registerFormat(cmGray, stFloat, 32, 0, 0, core);
+        vi->format = vsapi->registerFormat(cmGray, stFloat, 8 * getNumBytes(onnx_output_type), 0, 0, core);
     } else if (output_shape[1] == 3) {
-        vi->format = vsapi->registerFormat(cmRGB, stFloat, 32, 0, 0, core);
+        vi->format = vsapi->registerFormat(cmRGB, stFloat, 8 * getNumBytes(onnx_output_type), 0, 0, core);
     }
 }
 
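With an fp16 output tensor this evaluates to 8 * 2 = 16 bits per sample, so the registered formats become VapourSynth's half-precision GrayH/RGBH rather than GrayS/RGBS. A compile-time sketch of the arithmetic, again with stand-in TensorProto constants:

```cpp
#include <cstddef>
#include <cstdint>

// Stand-ins for ONNX_NAMESPACE::TensorProto::{FLOAT, FLOAT16}.
constexpr int32_t kFloat = 1;
constexpr int32_t kFloat16 = 10;

// constexpr variant of the getNumBytes() added earlier in this diff.
constexpr size_t getNumBytes(int32_t type) {
    return type == kFloat ? 4 : type == kFloat16 ? 2 : 0;
}

static_assert(8 * getNumBytes(kFloat) == 32, "fp32 output -> GrayS/RGBS");
static_assert(8 * getNumBytes(kFloat16) == 16, "fp16 output -> GrayH/RGBH");
```
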
@@ -975,6 +993,15 @@ static void VS_CC vsOrtCreate(
 
     rename(onnx_model);
 
+    auto onnx_input_type = onnx_model.graph().input()[0].type().tensor_type().elem_type();
+    auto onnx_output_type = onnx_model.graph().output()[0].type().tensor_type().elem_type();
+
+    if (onnx_input_type == ONNX_NAMESPACE::TensorProto::FLOAT && in_vis.front()->format->bitsPerSample != 32) {
+        return set_error("the onnx requires input to be of type fp32");
+    } else if (onnx_input_type == ONNX_NAMESPACE::TensorProto::FLOAT16 && in_vis.front()->format->bitsPerSample != 16) {
+        return set_error("the onnx requires input to be of type fp16");
+    }
+
     std::string onnx_data = onnx_model.SerializeAsString();
     if (std::size(onnx_data) == 0) {
         return set_error("proto serialization failed");
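
The new guard ties the model's declared input element type to the clip's bit depth before any buffers are allocated: an fp32 network only accepts 32-bit float clips, an fp16 network only 16-bit ones. A self-contained sketch of just that pairing rule, with stand-in constants for the TensorProto values:

```cpp
#include <cstdint>
#include <cstdio>

// Stand-ins for ONNX_NAMESPACE::TensorProto::{FLOAT, FLOAT16}.
constexpr int32_t kFloat = 1;
constexpr int32_t kFloat16 = 10;

// Returns an error message, or nullptr when the clip's bit depth
// matches the model's input element type (the rule enforced above).
static const char * matchInputType(int32_t onnx_input_type, int bits_per_sample) {
    if (onnx_input_type == kFloat && bits_per_sample != 32)
        return "the onnx requires input to be of type fp32";
    if (onnx_input_type == kFloat16 && bits_per_sample != 16)
        return "the onnx requires input to be of type fp16";
    return nullptr;
}

int main() {
    const char * ok = matchInputType(kFloat16, 16);  // nullptr: compatible
    const char * bad = matchInputType(kFloat16, 32); // mismatch
    std::printf("%s\n", ok ? ok : "compatible");
    std::printf("%s\n", bad ? bad : "compatible");
}
```
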
@@ -1142,7 +1169,7 @@ static void VS_CC vsOrtCreate(
                 input_shape[1] *
                 input_shape[2] *
                 input_shape[3]
-            ) * sizeof(float);
+            ) * getNumBytes(onnx_input_type);
 
             checkCUDAError(cudaMallocHost(
                 &resource.input.h_data, resource.input.size,
@@ -1154,15 +1181,16 @@
                 memory_info,
                 resource.input.d_data, resource.input.size,
                 std::data(input_shape), std::size(input_shape),
-                ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &resource.input_tensor
+                static_cast<ONNXTensorElementDataType>(onnx_input_type),
+                &resource.input_tensor
             ));
         } else
 #endif // ENABLE_CUDA
         {
             checkError(ortapi->CreateTensorAsOrtValue(
                 cpu_allocator,
                 std::data(input_shape), std::size(input_shape),
-                ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
+                static_cast<ONNXTensorElementDataType>(onnx_input_type),
                 &resource.input_tensor
             ));
         }
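
The static_cast from the model's elem_type to ONNXTensorElementDataType is value-preserving because ONNX Runtime's enum intentionally mirrors onnx.TensorProto.DataType. A compile-time sketch of the two values this plugin accepts, against ONNX Runtime's public C header:

```cpp
#include <onnxruntime_c_api.h>

// ONNXTensorElementDataType mirrors onnx.TensorProto.DataType
// value-for-value, so the cast above preserves FLOAT and FLOAT16.
static_assert(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT == 1,
              "matches TensorProto::FLOAT");
static_assert(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 == 10,
              "matches TensorProto::FLOAT16");
```
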
@@ -1178,7 +1206,7 @@
                 output_shape[1] *
                 output_shape[2] *
                 output_shape[3]
-            ) * sizeof(float);
+            ) * getNumBytes(onnx_output_type);
 
             checkCUDAError(cudaMallocHost(&resource.output.h_data, resource.output.size));
             checkCUDAError(cudaMalloc(&resource.output.d_data, resource.output.size));
@@ -1187,15 +1215,16 @@
                 memory_info,
                 resource.output.d_data, resource.output.size,
                 std::data(output_shape), std::size(output_shape),
-                ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &resource.output_tensor
+                static_cast<ONNXTensorElementDataType>(onnx_output_type),
+                &resource.output_tensor
             ));
         } else
 #endif // ENABLE_CUDA
         {
             checkError(ortapi->CreateTensorAsOrtValue(
                 cpu_allocator,
                 std::data(output_shape), std::size(output_shape),
-                ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
+                static_cast<ONNXTensorElementDataType>(onnx_output_type),
                 &resource.output_tensor
             ));
         }
@@ -1217,7 +1246,7 @@
         }
 
         if (i == 0) {
-            setDimensions(d->out_vi, input_shape, output_shape, core, vsapi);
+            setDimensions(d->out_vi, input_shape, output_shape, core, vsapi, onnx_output_type);
         }
 
         d->resources.push_back(resource);
