diff --git a/common/convert_float_to_float16.cpp b/common/convert_float_to_float16.cpp index 0cd0294..23a216e 100644 --- a/common/convert_float_to_float16.cpp +++ b/common/convert_float_to_float16.cpp @@ -6,7 +6,6 @@ #include #include #include -#include #include #include #include @@ -19,11 +18,7 @@ #include - -void convert_float_to_float16( - ONNX_NAMESPACE::ModelProto & model, - bool force_fp16_initializers -) noexcept; +#include "convert_float_to_float16.h" namespace { diff --git a/common/convert_float_to_float16.h b/common/convert_float_to_float16.h new file mode 100644 index 0000000..670012a --- /dev/null +++ b/common/convert_float_to_float16.h @@ -0,0 +1,19 @@ +#ifndef CONVERT_FLOAT_TO_FLOAT16_H +#define CONVERT_FLOAT_TO_FLOAT16_H + +#include +#include + +#include + +void convert_float_to_float16( + ONNX_NAMESPACE::ModelProto & model, + bool force_fp16_initializers + // , bool keep_io_types = True + // , bool disable_shape_infer = True + // , const std::optional> op_block_list = DEFAULT_OP_BLOCK_LIST + // , const std::optional> op_block_list = {} + , const std::unordered_set & op_block_list +) noexcept; + +#endif diff --git a/common/onnx_utils.cpp b/common/onnx_utils.cpp index 9c22898..5de7b67 100644 --- a/common/onnx_utils.cpp +++ b/common/onnx_utils.cpp @@ -8,6 +8,8 @@ #include #include +#include "onnx_utils.h" + using namespace std::string_literals; diff --git a/common/onnx_utils.h b/common/onnx_utils.h new file mode 100644 index 0000000..7041ab7 --- /dev/null +++ b/common/onnx_utils.h @@ -0,0 +1,18 @@ +#ifndef ONNX_UTILS_H +#define ONNX_UTILS_H + +#include +#include +#include +#include + +#include + +std::variant loadONNX( + const std::string_view & path, + int64_t tile_w, + int64_t tile_h, + bool path_is_serialization +) noexcept; + +#endif diff --git a/vsncnn/vs_ncnn.cpp b/vsncnn/vs_ncnn.cpp index 8b06821..cabe182 100644 --- a/vsncnn/vs_ncnn.cpp +++ b/vsncnn/vs_ncnn.cpp @@ -22,17 +22,12 @@ #include #include -#include "config.h" // generated by cmake #include -#include "onnx2ncnn.hpp" +#include "../common/onnx_utils.h" +#include "onnx2ncnn.hpp" -extern std::variant loadONNX( - const std::string_view & path, - int64_t tile_w, - int64_t tile_h, - bool path_is_serialization -) noexcept; +#include "config.h" // generated by cmake static const VSPlugin * myself = nullptr; diff --git a/vsort/vs_onnxruntime.cpp b/vsort/vs_onnxruntime.cpp index f463d27..2a62436 100644 --- a/vsort/vs_onnxruntime.cpp +++ b/vsort/vs_onnxruntime.cpp @@ -34,21 +34,10 @@ using namespace std::chrono_literals; #include #endif // ENABLE_DML -#include "config.h" - +#include "../common/convert_float_to_float16.h" +#include "../common/onnx_utils.h" -extern std::variant loadONNX( - const std::string_view & path, - int64_t tile_w, - int64_t tile_h, - bool path_is_serialization -) noexcept; - -extern void convert_float_to_float16( - ONNX_NAMESPACE::ModelProto & model, - bool force_fp16_initializers, - const std::unordered_set & op_block_list -) noexcept; +#include "config.h" #ifdef ENABLE_COREML diff --git a/vsov/vs_openvino.cpp b/vsov/vs_openvino.cpp index 063c448..ce6dfe3 100644 --- a/vsov/vs_openvino.cpp +++ b/vsov/vs_openvino.cpp @@ -27,21 +27,10 @@ #include #endif // ENABLE_VISUALIZATION -#include "config.h" - +#include "../common/convert_float_to_float16.h" +#include "../common/onnx_utils.h" -extern std::variant loadONNX( - const std::string_view & path, - int64_t tile_w, - int64_t tile_h, - bool path_is_serialization -) noexcept; - -extern void convert_float_to_float16( - ONNX_NAMESPACE::ModelProto & model, - bool force_fp16_initializers, - const std::unordered_set & op_block_list -) noexcept; +#include "config.h" using namespace std::string_literals;