Skip to content

Commit

Permalink
[compiler] register merge_two_modules in python binding (#158)
Browse files Browse the repository at this point in the history
* register `merge_two_modules`.
* throw exception in pybind11 when meet failures.
  • Loading branch information
qingyunqu authored Mar 22, 2024
1 parent 68d06b5 commit c4eeef0
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 14 deletions.
4 changes: 4 additions & 0 deletions compiler/include/byteir-c/Translation.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ MLIR_CAPI_EXPORTED bool byteirSerializeByre(MlirModule module,

MLIR_CAPI_EXPORTED MlirModule byteirDeserializeByre(MlirStringRef artifactStr,
MlirContext context);

MLIR_CAPI_EXPORTED MlirModule byteirMergeTwoModules(MlirModule module0,
MlirModule module1);

#ifdef __cplusplus
}
#endif
Expand Down
3 changes: 2 additions & 1 deletion compiler/include/byteir/Utils/ModuleUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ constexpr llvm::StringRef getByteIREntryPointName() {
// order
// 3. if there are `byteir.entry_point` in only one module, return std::nullopt
// 4. check arguments' shape and dtype on the border
ModuleOp mergeTwoModulesByNameOrOrder(ModuleOp module0, ModuleOp module1);
OwningOpRef<ModuleOp> mergeTwoModulesByNameOrOrder(ModuleOp module0,
ModuleOp module1);

} // namespace mlir

Expand Down
6 changes: 6 additions & 0 deletions compiler/lib/CAPI/Translation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "byteir/Dialect/Byre/Serialization.h"
#include "byteir/Dialect/Byre/Serialization/Versioning.h"
#include "byteir/Target/PTX/ToPTX.h"
#include "byteir/Utils/ModuleUtils.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
Expand Down Expand Up @@ -148,3 +149,8 @@ MlirModule byteirDeserializeByre(MlirStringRef artifactStr,
}
return {m.release()};
}

MlirModule byteirMergeTwoModules(MlirModule module0, MlirModule module1) {
auto m = mergeTwoModulesByNameOrOrder(unwrap(module0), unwrap(module1));
return {m.release()};
}
4 changes: 2 additions & 2 deletions compiler/lib/Utils/ModuleUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ ModuleOp mergeTwoModulesByOrder(ModuleOp module0, ModuleOp module1,

} // namespace

ModuleOp mlir::mergeTwoModulesByNameOrOrder(ModuleOp module0,
ModuleOp module1) {
OwningOpRef<ModuleOp> mlir::mergeTwoModulesByNameOrOrder(ModuleOp module0,
ModuleOp module1) {
assert(module0.getContext() == module1.getContext() &&
"module0 and module1 should have same context");
MLIRContext *context = module0.getContext();
Expand Down
48 changes: 39 additions & 9 deletions compiler/python/ByteIRModules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,29 +99,59 @@ PYBIND11_MODULE(_byteir, m) {
m.def(
"translate_to_ptx",
[](MlirModule module, const std::string &ptxPrefixFileName,
const std::string &gpuArch) -> bool {
return byteirTranslateToPTX(module, toMlirStringRef(ptxPrefixFileName),
toMlirStringRef(gpuArch));
const std::string &gpuArch) {
if (!byteirTranslateToPTX(module, toMlirStringRef(ptxPrefixFileName),
toMlirStringRef(gpuArch))) {
PyErr_SetString(PyExc_ValueError, "failed to translate to ptx");
return;
}
return;
},
py::arg("module"), py::arg("ptx_prefix_file_name"),
py::arg("gpu_arch") = "sm_70");
m.def(
"translate_to_llvmbc",
[](MlirModule module, const std::string &outputFile) -> bool {
return byteirTranslateToLLVMBC(module, toMlirStringRef(outputFile));
[](MlirModule module, const std::string &outputFile) {
if (!byteirTranslateToLLVMBC(module, toMlirStringRef(outputFile))) {
PyErr_SetString(PyExc_ValueError,
"failed to translate to llvm bytecode");
return;
}
return;
},
py::arg("module"), py::arg("output_file"));

m.def(
"serialize_byre",
[](MlirModule module, const std::string &targetVersion,
const std::string &outputFile) -> bool {
return byteirSerializeByre(module, toMlirStringRef(targetVersion),
toMlirStringRef(outputFile));
const std::string &outputFile) {
if (!byteirSerializeByre(module, toMlirStringRef(targetVersion),
toMlirStringRef(outputFile))) {
PyErr_SetString(PyExc_ValueError, "failed to serialize byre");
return;
}
return;
},
py::arg("module"), py::arg("target_version"), py::arg("output_file"));
m.def("deserialize_byre",
[](const std::string &artifactStr, MlirContext context) -> MlirModule {
return byteirDeserializeByre(toMlirStringRef(artifactStr), context);
auto module =
byteirDeserializeByre(toMlirStringRef(artifactStr), context);
if (mlirModuleIsNull(module)) {
PyErr_SetString(PyExc_ValueError, "failed to deserialize byre");
}
return module;
});

m.def(
"merge_two_modules",
[](MlirModule module0, MlirModule module1) -> MlirModule {
auto module = byteirMergeTwoModules(module0, module1);
if (mlirModuleIsNull(module)) {
PyErr_SetString(PyExc_ValueError, "failed to merge two modules");
return {};
}
return module;
},
py::arg("module0"), py::arg("module1"));
}
4 changes: 2 additions & 2 deletions compiler/python/test/api/test_py_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ def test_translate_to_llvmbc():
context = ir.Context()
with open(path, "r") as f:
module = ir.Module.parse(f.read(), context)
assert byteir.translate_to_llvmbc(module, temp_dir.name + "/test.ll.bc")
byteir.translate_to_llvmbc(module, temp_dir.name + "/test.ll.bc")

def test_serialize_byre():
path = TEST_ROOT_DIR + "Dialect/Byre/Serialization/Compatibility/version_1_0_0.mlir"
context = ir.Context()
with open(path, "r") as f:
module = ir.Module.parse(f.read(), context)
assert byteir.serialize_byre(module, "1.0.0", temp_dir.name + "/test.mlir.bc")
byteir.serialize_byre(module, "1.0.0", temp_dir.name + "/test.mlir.bc")

def test_deserialize_byre():
paths = [TEST_ROOT_DIR + "Dialect/Byre/Serialization/Compatibility/version_1_0_0.mlir.bc",
Expand Down

0 comments on commit c4eeef0

Please sign in to comment.