diff --git a/cinn/hlir/op/reduction.cc b/cinn/hlir/op/reduction.cc index f4d0f82a9a..21ff118abd 100644 --- a/cinn/hlir/op/reduction.cc +++ b/cinn/hlir/op/reduction.cc @@ -465,17 +465,18 @@ std::vector InferShapeForReduction(const std::vector &inputs_s if (attrs.find("keep_dim") != attrs.end()) { keep_dim = absl::get(attrs.at("keep_dim")); } - CHECK(!dim.empty()) << "should have reduce dim, please check!"; - CHECK_LE(dim.size(), inputs_shape[0].size()) << "reduce dim should no more than the input size"; std::vector out_shapes; - auto ndim = inputs_shape[0].size(); - for (size_t i = 0; i < ndim; ++i) { - if (std::find(dim.begin(), dim.end(), i) != dim.end()) { - if (keep_dim) { - out_shapes.push_back(1); + if (!dim.empty()) { + CHECK_LE(dim.size(), inputs_shape[0].size()) << "reduce dim should no more than the input size"; + auto ndim = inputs_shape[0].size(); + for (size_t i = 0; i < ndim; ++i) { + if (std::find(dim.begin(), dim.end(), i) != dim.end()) { + if (keep_dim) { + out_shapes.push_back(1); + } + } else { + out_shapes.push_back(inputs_shape[0][i]); } - } else { - out_shapes.push_back(inputs_shape[0][i]); } } diff --git a/cinn/pybind/frontend.cc b/cinn/pybind/frontend.cc old mode 100755 new mode 100644 index 6e9c9da1cf..0d0abbecc1 --- a/cinn/pybind/frontend.cc +++ b/cinn/pybind/frontend.cc @@ -431,6 +431,24 @@ void BindFrontend(pybind11::module *m) { py::arg("padding_algorithm") = "EXPLICIT") .def("sum", &NetBuilder::sum, py::arg("inputs")); + py::enum_(*m, "ComparisonKind") + .value("kUnk", ComparisonKind::kUnk) + .value("kEq", ComparisonKind::kEq) + .value("kNe", ComparisonKind::kNe) + .value("kGe", ComparisonKind::kGe) + .value("kGt", ComparisonKind::kGt) + .value("kLe", ComparisonKind::kLe) + .value("kLt", ComparisonKind::kLt) + .export_values(); + + py::enum_(*m, "ReduceKind") + .value("kUnk", ReduceKind::kUnk) + .value("kSum", ReduceKind::kSum) + .value("kProd", ReduceKind::kProd) + .value("kMax", ReduceKind::kMax) + .value("kMin", ReduceKind::kMin) + .export_values(); + py::class_(*m, "CinnBuilder") .def(py::init(), py::arg("name") = "") .def("const_scalar", &CinnBuilder::ConstScalar) @@ -455,12 +473,12 @@ void BindFrontend(pybind11::module *m) { py::arg("data_format") = "NCHW", py::arg("padding_algorithm") = "EXPLICIT", py::arg("output_shape") = std::vector{}) - .def("compare", &CinnBuilder::Compare, py::arg("lhs"), py::arg("rhs"), py::arg("kind")) + .def("compare", &CinnBuilder::Compare, py::arg("lhs"), py::arg("rhs"), py::arg("kind") = ComparisonKind::kEq) .def("reduce", &CinnBuilder::Reduce, py::arg("operand"), - py::arg("kind"), - py::arg("dim"), + py::arg("kind") = ReduceKind::kSum, + py::arg("dim") = std::vector{}, py::arg("keep_dim") = false) .def("broadcast_to", &CinnBuilder::BroadcastTo, diff --git a/python/tests/test_cinnbuilder.py b/python/tests/test_cinnbuilder.py index a92881e553..f1944c6f92 100644 --- a/python/tests/test_cinnbuilder.py +++ b/python/tests/test_cinnbuilder.py @@ -32,6 +32,32 @@ enable_gpu = sys.argv.pop() +class TestCinnBuildBasic(unittest.TestCase): + def setUp(self): + pass + + def test_compare(self): + builder = CinnBuilder("test_compare") + a = builder.create_input(Float(32), (1, 24, 56, 56), "A") + b = builder.create_input(Float(32), (1, 24, 56, 56), "B") + # default compare kind is ComparisonKind.kEq + c = builder.compare(a, b) + d = builder.compare(a, c, ComparisonKind.kNe) + prog = builder.build() + for i in range(prog.size()): + print(prog[i]) + + def test_reduce(self): + builder = CinnBuilder("test_compare") + a = builder.create_input(Float(32), (1, 24, 56, 56), "A") + b = builder.reduce(a) + c = builder.reduce(a, ReduceKind.kMax) + d = builder.add(b, c) + prog = builder.build() + for i in range(prog.size()): + print(prog[i]) + + class TestCinnBuilder(unittest.TestCase): def setUp(self): if enable_gpu == "ON":