common/convert_float_to_float16: optional io type casting
WolframRhodium committed Apr 19, 2024
1 parent a65bc0e commit 6b33f0e
Showing 2 changed files with 96 additions and 88 deletions.
170 changes: 88 additions & 82 deletions common/convert_float_to_float16.cpp
@@ -293,12 +293,14 @@ static ONNX_NAMESPACE::ValueInfoProto make_value_info_from_tensor(

 void convert_float_to_float16(
     ONNX_NAMESPACE::ModelProto & model,
-    bool force_fp16_initializers
-    // , bool keep_io_types = True
-    // , bool disable_shape_infer = True
-    // , const std::optional<std::unordered_set<std::string>> op_block_list = DEFAULT_OP_BLOCK_LIST
-    // , const std::optional<std::unordered_set<std::string>> op_block_list = {}
-    , const std::unordered_set<std::string> & op_block_list
+    bool force_fp16_initializers,
+    // bool keep_io_types = True,
+    // bool disable_shape_infer = True,
+    // const std::optional<std::unordered_set<std::string>> op_block_list = DEFAULT_OP_BLOCK_LIST,
+    // const std::optional<std::unordered_set<std::string>> op_block_list = {},
+    const std::unordered_set<std::string> & op_block_list,
+    bool cast_input,
+    bool cast_output
 ) noexcept {

     std::vector<ONNX_NAMESPACE::ValueInfoProto> value_info_list {};
@@ -307,97 +309,101 @@ void convert_float_to_float16(
     std::unordered_map<std::string, std::string> name_mapping {};
     std::unordered_set<std::string> graph_io_to_skip {};

-    const std::vector<std::string> fp32_inputs = [&]() {
-        std::vector<std::string> ret {};
-
-        for (const auto & n : model.graph().input()) {
-            if (n.type().tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto::FLOAT) {
-                ret.emplace_back(n.name());
-            }
-        }
-
-        return ret;
-    }();
-
-    for (const auto & n : model.graph().input()) {
-        if (auto idx = std::find(std::cbegin(fp32_inputs), std::cend(fp32_inputs), n.name());
-            idx != std::cend(fp32_inputs)
-        ) {
-            const auto i = idx - std::cbegin(fp32_inputs);
-            std::string node_name = "graph_input_cast_" + std::to_string(i);
-            name_mapping.emplace(n.name(), node_name);
-            graph_io_to_skip.emplace(n.name());
-
-            auto * new_value_info = model.mutable_graph()->mutable_value_info()->Add();
-            new_value_info->CopyFrom(n);
-            new_value_info->set_name(node_name);
-            new_value_info->mutable_type()->mutable_tensor_type()->set_elem_type(
-                ONNX_NAMESPACE::TensorProto::FLOAT16
-            );
-            // add Cast node (from tensor(float) to tensor(float16) after graph input
-            for (auto & node : *model.mutable_graph()->mutable_node()) {
-                for (auto & input : *node.mutable_input()) {
-                    if (input == n.name()) {
-                        input = node_name;
-                    }
-                }
-            }
-            auto new_node = make_node(
-                "Cast", {n.name()}, {node_name}, node_name,
-                "to", ONNX_NAMESPACE::TensorProto::FLOAT16
-            );
-            model.mutable_graph()->mutable_node()->Add();
-            for (int i = model.graph().node_size() - 2; i >= 0; --i) {
-                model.mutable_graph()->mutable_node()->SwapElements(i, i + 1);
-            }
-            *model.mutable_graph()->mutable_node(0) = std::move(new_node);
-            value_info_list.emplace_back(*new_value_info);
-            io_casts.emplace(std::move(node_name));
-        }
-    }
-
-    const std::vector<std::string> fp32_outputs = [&]() {
-        std::vector<std::string> ret {};
-
-        for (const auto & n : model.graph().output()) {
-            if (n.type().tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto::FLOAT) {
-                ret.emplace_back(n.name());
-            }
-        }
-
-        return ret;
-    }();
-
-    for (const auto & n : model.graph().output()) {
-        if (auto idx = std::find(std::cbegin(fp32_outputs), std::cend(fp32_outputs), n.name());
-            idx != std::cend(fp32_outputs)
-        ) {
-            const auto i = idx - std::cbegin(fp32_outputs);
-            std::string node_name = "graph_output_cast_" + std::to_string(i);
-            name_mapping.emplace(n.name(), node_name);
-            graph_io_to_skip.emplace(n.name());
-
-            auto * new_value_info = model.mutable_graph()->mutable_value_info()->Add();
-            new_value_info->CopyFrom(n);
-            new_value_info->set_name(node_name);
-            new_value_info->mutable_type()->mutable_tensor_type()->set_elem_type(
-                ONNX_NAMESPACE::TensorProto::FLOAT16
-            );
-            // add Cast node (from tensor(float16) to tensor(float) before graph output
-            for (auto & node : *model.mutable_graph()->mutable_node()) {
-                for (auto & output : *node.mutable_output()) {
-                    if (output == n.name()) {
-                        output = node_name;
-                    }
-                }
-            }
-            auto new_node = make_node(
-                "Cast", {node_name}, {n.name()}, node_name,
-                "to", ONNX_NAMESPACE::TensorProto::FLOAT
-            );
-            model.mutable_graph()->mutable_node()->Add(std::move(new_node));
-            value_info_list.emplace_back(*new_value_info);
-            io_casts.emplace(std::move(node_name));
-        }
-    }
+    if (cast_input) {
+        const std::vector<std::string> fp32_inputs = [&]() {
+            std::vector<std::string> ret {};
+
+            for (const auto & n : model.graph().input()) {
+                if (n.type().tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto::FLOAT) {
+                    ret.emplace_back(n.name());
+                }
+            }
+
+            return ret;
+        }();
+
+        for (const auto & n : model.graph().input()) {
+            if (auto idx = std::find(std::cbegin(fp32_inputs), std::cend(fp32_inputs), n.name());
+                idx != std::cend(fp32_inputs)
+            ) {
+                const auto i = idx - std::cbegin(fp32_inputs);
+                std::string node_name = "graph_input_cast_" + std::to_string(i);
+                name_mapping.emplace(n.name(), node_name);
+                graph_io_to_skip.emplace(n.name());
+
+                auto * new_value_info = model.mutable_graph()->mutable_value_info()->Add();
+                new_value_info->CopyFrom(n);
+                new_value_info->set_name(node_name);
+                new_value_info->mutable_type()->mutable_tensor_type()->set_elem_type(
+                    ONNX_NAMESPACE::TensorProto::FLOAT16
+                );
+                // add Cast node (from tensor(float) to tensor(float16) after graph input
+                for (auto & node : *model.mutable_graph()->mutable_node()) {
+                    for (auto & input : *node.mutable_input()) {
+                        if (input == n.name()) {
+                            input = node_name;
+                        }
+                    }
+                }
+                auto new_node = make_node(
+                    "Cast", {n.name()}, {node_name}, node_name,
+                    "to", ONNX_NAMESPACE::TensorProto::FLOAT16
+                );
+                model.mutable_graph()->mutable_node()->Add();
+                for (int i = model.graph().node_size() - 2; i >= 0; --i) {
+                    model.mutable_graph()->mutable_node()->SwapElements(i, i + 1);
+                }
+                *model.mutable_graph()->mutable_node(0) = std::move(new_node);
+                value_info_list.emplace_back(*new_value_info);
+                io_casts.emplace(std::move(node_name));
+            }
+        }
+    }
+
+    if (cast_output) {
+        const std::vector<std::string> fp32_outputs = [&]() {
+            std::vector<std::string> ret {};
+
+            for (const auto & n : model.graph().output()) {
+                if (n.type().tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto::FLOAT) {
+                    ret.emplace_back(n.name());
+                }
+            }
+
+            return ret;
+        }();
+
+        for (const auto & n : model.graph().output()) {
+            if (auto idx = std::find(std::cbegin(fp32_outputs), std::cend(fp32_outputs), n.name());
+                idx != std::cend(fp32_outputs)
+            ) {
+                const auto i = idx - std::cbegin(fp32_outputs);
+                std::string node_name = "graph_output_cast_" + std::to_string(i);
+                name_mapping.emplace(n.name(), node_name);
+                graph_io_to_skip.emplace(n.name());
+
+                auto * new_value_info = model.mutable_graph()->mutable_value_info()->Add();
+                new_value_info->CopyFrom(n);
+                new_value_info->set_name(node_name);
+                new_value_info->mutable_type()->mutable_tensor_type()->set_elem_type(
+                    ONNX_NAMESPACE::TensorProto::FLOAT16
+                );
+                // add Cast node (from tensor(float16) to tensor(float) before graph output
+                for (auto & node : *model.mutable_graph()->mutable_node()) {
+                    for (auto & output : *node.mutable_output()) {
+                        if (output == n.name()) {
+                            output = node_name;
+                        }
+                    }
+                }
+                auto new_node = make_node(
+                    "Cast", {node_name}, {n.name()}, node_name,
+                    "to", ONNX_NAMESPACE::TensorProto::FLOAT
+                );
+                model.mutable_graph()->mutable_node()->Add(std::move(new_node));
+                value_info_list.emplace_back(*new_value_info);
+                io_casts.emplace(std::move(node_name));
+            }
+        }
+    }

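The input branch above has to place each new Cast node at the front of graph.node, but protobuf repeated fields only support appending. A minimal standalone sketch of that append-then-rotate pattern follows; the <onnx/onnx_pb.h> include path is an assumption, not taken from this repository:

#include <utility>

#include <onnx/onnx_pb.h>  // assumption: ONNX protobuf generated headers

// Prepend a node to graph.node: append an empty slot, bubble it down to
// index 0 with SwapElements, then assign the new node into the front slot.
static void prepend_node(ONNX_NAMESPACE::ModelProto & model, ONNX_NAMESPACE::NodeProto node) {
    model.mutable_graph()->mutable_node()->Add();
    for (int i = model.graph().node_size() - 2; i >= 0; --i) {
        model.mutable_graph()->mutable_node()->SwapElements(i, i + 1);
    }
    *model.mutable_graph()->mutable_node(0) = std::move(node);
}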
14 changes: 8 additions & 6 deletions common/convert_float_to_float16.h
@@ -8,12 +8,14 @@

 void convert_float_to_float16(
     ONNX_NAMESPACE::ModelProto & model,
-    bool force_fp16_initializers
-    // , bool keep_io_types = True
-    // , bool disable_shape_infer = True
-    // , const std::optional<std::unordered_set<std::string>> op_block_list = DEFAULT_OP_BLOCK_LIST
-    // , const std::optional<std::unordered_set<std::string>> op_block_list = {}
-    , const std::unordered_set<std::string> & op_block_list
+    bool force_fp16_initializers,
+    // bool keep_io_types = True,
+    // bool disable_shape_infer = True,
+    // const std::optional<std::unordered_set<std::string>> op_block_list = DEFAULT_OP_BLOCK_LIST,
+    // const std::optional<std::unordered_set<std::string>> op_block_list = {},
+    const std::unordered_set<std::string> & op_block_list,
+    bool cast_input = true,
+    bool cast_output = true
 ) noexcept;

 #endif
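A caller-side sketch of the new flags: cast_input controls whether a Cast(float -> float16) node is inserted after each fp32 graph input, and cast_output whether a Cast(float16 -> float) node is inserted before each fp32 graph output; both default to true. The function name, block-list contents, and include paths below are illustrative assumptions, not part of this commit:

#include <string>
#include <unordered_set>

#include <onnx/onnx_pb.h>                      // assumption: ONNX protobuf generated headers
#include "common/convert_float_to_float16.h"   // assumption: repository-relative include path

// Hypothetical caller: keep fp32 graph inputs (with an inserted Cast to fp16),
// but skip the Cast insertion before graph outputs.
void convert_keeping_fp32_inputs_only(ONNX_NAMESPACE::ModelProto & model) {
    // Hypothetical block list; real callers supply their own ops to keep in fp32.
    const std::unordered_set<std::string> op_block_list { "NonMaxSuppression" };

    convert_float_to_float16(
        model,
        true,           // force_fp16_initializers
        op_block_list,
        true,           // cast_input: insert Cast(float -> float16) after fp32 graph inputs
        false           // cast_output: skip Cast(float16 -> float) before graph outputs
    );
}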
