diff --git a/src/onnx/parse_resize.cpp b/src/onnx/parse_resize.cpp index cf015a53414..785efd93aff 100644 --- a/src/onnx/parse_resize.cpp +++ b/src/onnx/parse_resize.cpp @@ -33,49 +33,28 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace onnx { -static std::vector -calc_neighbor_points(const std::vector>>& vvv_ind, - int i_dim, - std::vector> vec_dims, - const shape& in_s) +static void calc_neighbor_points(const std::vector>>& vvv_ind, + const std::size_t& n_dims, + const std::size_t& out_elements, + const shape& in_s, + std::vector& vec_ind) { - if(i_dim == vvv_ind.size()) + for(std::size_t start = 0; start < (std::size_t{1} << n_dims); start++) { - std::vector vec_ind(vec_dims.size()); - std::transform(vec_dims.begin(), vec_dims.end(), vec_ind.begin(), [&](auto idx) { - return static_cast(in_s.index(idx)); - }); - return vec_ind; - } - - const auto& vv_lo = vvv_ind[i_dim][0]; - std::vector> vec_dims1; - for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size()) - { - std::transform(vv_lo.begin(), - vv_lo.end(), - vec_dims.begin() + start, - std::back_inserter(vec_dims1), - [](auto i, auto dim) { - dim.push_back(i); - return dim; - }); - } - - const auto& vv_hi = vvv_ind[i_dim][1]; - for(std::size_t start = 0; start < vec_dims.size(); start += vv_hi.size()) - { - std::transform(vv_hi.begin(), - vv_hi.end(), - vec_dims.begin() + start, - std::back_inserter(vec_dims1), - [](auto i, auto dim) { - dim.push_back(i); - return dim; - }); + std::vector idx(n_dims); + for(std::size_t idx_e = 0; idx_e < out_elements; idx_e++) + { + std::size_t bi = start; + idx.clear(); + for(std::size_t dim = 0; dim < n_dims; dim++) + { + idx.push_back((bi & std::size_t{1}) ? vvv_ind[dim][1][idx_e] + : vvv_ind[dim][0][idx_e]); + bi = bi >> std::size_t{1}; + } + vec_ind.push_back(in_s.index(idx)); + } } - vec_dims.clear(); - return calc_neighbor_points(vvv_ind, i_dim + 1, std::move(vec_dims1), in_s); } static std::string get_coord_trans_mode(const onnx_parser::attribute_map& attr) @@ -375,8 +354,9 @@ struct parse_resize : op_parser } }); - auto ind = calc_neighbor_points( - vvv_ind, 0, std::vector>(out_elements), in_s); + std::vector ind; + calc_neighbor_points(vvv_ind, n_dim, out_elements, in_s, ind); + auto ind_lens = out_lens; ind_lens[0] *= (std::size_t{1} << n_dim); shape ind_s{shape::int32_type, ind_lens};