diff --git a/vsort/vs_onnxruntime.cpp b/vsort/vs_onnxruntime.cpp
index 2a62436..fc484d1 100644
--- a/vsort/vs_onnxruntime.cpp
+++ b/vsort/vs_onnxruntime.cpp
@@ -162,6 +162,19 @@ static std::variant> getShape(
     return std::get>(maybe_shape);
 }
 
+static size_t getNumBytes(int32_t type) { // bytes per element for an ONNX TensorProto elem_type
+    using namespace ONNX_NAMESPACE;
+
+    switch (type) {
+        case TensorProto::FLOAT:
+            return 4;
+        case TensorProto::FLOAT16:
+            return 2;
+        default:
+            return 0; // any other element type is reported as size 0
+    }
+}
+
 
 static int numPlanes(
     const std::vector & vis
@@ -183,8 +196,12 @@ static std::optional checkNodes(
 ) noexcept {
 
     for (const auto & vi : vis) {
-        if (vi->format->sampleType != stFloat || vi->format->bitsPerSample != 32) {
-            return "expects clip with type fp32";
+        if (vi->format->sampleType != stFloat) {
+            return "expects clip with floating-point type";
+        }
+
+        if (vi->format->bitsPerSample != 32 && vi->format->bitsPerSample != 16) {
+            return "expects clip with type fp32 or fp16";
         }
 
         if (vi->width != vis[0]->width || vi->height != vis[0]->height) {
@@ -220,8 +237,8 @@ static std::optional checkIOInfo(
     ONNXTensorElementDataType element_type;
     checkError(ortapi->GetTensorElementType(tensor_info, &element_type));
 
-    if (element_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
-        return set_error("expects network IO with type fp32");
+    if (element_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT && element_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
+        return set_error("expects network IO with type fp32 or fp16");
     }
 
     size_t num_dims;
@@ -337,16 +354,17 @@ static void setDimensions(
     const std::array & input_shape,
     const std::array & output_shape,
     VSCore * core,
-    const VSAPI * vsapi
+    const VSAPI * vsapi,
+    int32_t onnx_output_type
 ) noexcept {
 
     vi->height *= output_shape[2] / input_shape[2];
     vi->width *= output_shape[3] / input_shape[3];
 
     if (output_shape[1] == 1) {
-        vi->format = vsapi->registerFormat(cmGray, stFloat, 32, 0, 0, core);
+        vi->format = vsapi->registerFormat(cmGray, stFloat, 8 * getNumBytes(onnx_output_type), 0, 0, core);
     } else if (output_shape[1] == 3) {
-        vi->format = vsapi->registerFormat(cmRGB, stFloat, 32, 0, 0, core);
+        vi->format = vsapi->registerFormat(cmRGB, stFloat, 8 * getNumBytes(onnx_output_type), 0, 0, core);
     }
 }
 
@@ -975,6 +993,15 @@ static void VS_CC vsOrtCreate(
 
     rename(onnx_model);
 
+    auto onnx_input_type = onnx_model.graph().input()[0].type().tensor_type().elem_type();
+    auto onnx_output_type = onnx_model.graph().output()[0].type().tensor_type().elem_type();
+
+    if (onnx_input_type == ONNX_NAMESPACE::TensorProto::FLOAT && in_vis.front()->format->bitsPerSample != 32) {
+        return set_error("the onnx requires input to be of type fp32");
+    } else if (onnx_input_type == ONNX_NAMESPACE::TensorProto::FLOAT16 && in_vis.front()->format->bitsPerSample != 16) {
+        return set_error("the onnx requires input to be of type fp16");
+    }
+
     std::string onnx_data = onnx_model.SerializeAsString();
     if (std::size(onnx_data) == 0) {
         return set_error("proto serialization failed");
@@ -1142,7 +1169,7 @@ static void VS_CC vsOrtCreate(
                 input_shape[1] *
                 input_shape[2] *
                 input_shape[3]
-            ) * sizeof(float);
+            ) * getNumBytes(onnx_input_type);
 
             checkCUDAError(cudaMallocHost(
                 &resource.input.h_data, resource.input.size,
@@ -1154,7 +1181,8 @@ static void VS_CC vsOrtCreate(
                 memory_info,
                 resource.input.d_data, resource.input.size,
                 std::data(input_shape), std::size(input_shape),
-                ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &resource.input_tensor
+                static_cast<ONNXTensorElementDataType>(onnx_input_type), // NOTE(review): assumes TensorProto elem_type values match the ORT enum - confirm
+                &resource.input_tensor
             ));
         } else
 #endif // ENALBE_CUDA
@@ -1162,7 +1190,7 @@ static void VS_CC vsOrtCreate(
             checkError(ortapi->CreateTensorAsOrtValue(
                 cpu_allocator,
                 std::data(input_shape), std::size(input_shape),
-                ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
+                static_cast<ONNXTensorElementDataType>(onnx_input_type),
                 &resource.input_tensor
             ));
         }
@@ -1178,7 +1206,7 @@ static void VS_CC vsOrtCreate(
                 output_shape[1] *
                 output_shape[2] *
                 output_shape[3]
-            ) * sizeof(float);
+            ) * getNumBytes(onnx_output_type);
 
             checkCUDAError(cudaMallocHost(&resource.output.h_data, resource.output.size));
             checkCUDAError(cudaMalloc(&resource.output.d_data, resource.output.size));
@@ -1187,7 +1215,8 @@ static void VS_CC vsOrtCreate(
                 memory_info,
                 resource.output.d_data, resource.output.size,
                 std::data(output_shape), std::size(output_shape),
-                ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &resource.output_tensor
+                static_cast<ONNXTensorElementDataType>(onnx_output_type), // output type, matching resource.output.size
+                &resource.output_tensor
             ));
         } else
 #endif // ENABLE_CUDA
@@ -1195,7 +1224,7 @@ static void VS_CC vsOrtCreate(
             checkError(ortapi->CreateTensorAsOrtValue(
                 cpu_allocator,
                 std::data(output_shape), std::size(output_shape),
-                ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
+                static_cast<ONNXTensorElementDataType>(onnx_output_type), // output type, matching the format set by setDimensions
                 &resource.output_tensor
             ));
         }
@@ -1217,7 +1246,7 @@ static void VS_CC vsOrtCreate(
         }
 
         if (i == 0) {
-            setDimensions(d->out_vi, input_shape, output_shape, core, vsapi);
+            setDimensions(d->out_vi, input_shape, output_shape, core, vsapi, onnx_output_type);
         }
 
         d->resources.push_back(resource);