Skip to content

Commit

Permalink
fix var updates, fix test ci
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Oct 20, 2023
1 parent 036f3f6 commit 30073cc
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 8 deletions.
3 changes: 2 additions & 1 deletion enzyme/Enzyme/Clang/EnzymePassLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ static RegisterStandardPasses
clangtoolLoader_OEarly(PassManagerBuilder::EP_EarlyAsPossible,
loadNVVMPass);
static RegisterStandardPasses
clangtoolLoader_Ox(PassManagerBuilder::EP_VectorizerStart, loadBLASPass);
clangtoolBlasLoader_Ox(PassManagerBuilder::EP_VectorizerStart,
loadBLASPass);

static void loadLTOPass(const PassManagerBuilder &Builder,
legacy::PassManagerBase &PM) {
Expand Down
5 changes: 4 additions & 1 deletion enzyme/Enzyme/OptBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,11 @@ bool cmp_or_set(llvm::CallInst *CI, std::vector<llvm::Value *> values) {
if (values[i] == nullptr) {
continue;
}
if (CI->getArgOperand(i) != values[i])
if (CI->getArgOperand(i) != values[i]) {
llvm::errs() << "cmp_or_set: " << i << " " << *CI->getArgOperand(i)
<< " != " << *values[i] << "\n";
return false;
}
}
for (size_t i = 0; i < values.size(); ++i) {
if (values[i] == nullptr) {
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2354,7 +2354,7 @@ std::optional<BlasInfo> extractBLAS(llvm::StringRef in)
llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in)
#endif
{
const char *extractable[] = {"dot", "scal", "axpy", "gemv", "gemm", "spmv"};
const char *extractable[] = {"dot", "scal", "axpy", "gemv", "gemm", "spmv", "ger"};
const char *floatType[] = {"s", "d"}; // c, z
const char *prefixes[] = {"" /*Fortran*/, "cblas_"};
const char *suffixes[] = {"", "_", "64_", "_64_"};
Expand Down
12 changes: 12 additions & 0 deletions enzyme/test/Enzyme/BlasOpt/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Run regression and unit tests
add_lit_testsuite(check-enzyme-blasopt "Running enzyme blas optimization regression tests"
${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${ENZYME_TEST_DEPS}
ARGS -v
)

set_target_properties(check-enzyme-blasopt PROPERTIES FOLDER "Tests")

# add_lit_testsuites(ENZYME ${CMAKE_CURRENT_SOURCE_DIR}
# DEPENDS ${ENZYME_TEST_DEPS}
# )
1 change: 1 addition & 0 deletions enzyme/test/Enzyme/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_subdirectory(ForwardModeSplit)
add_subdirectory(ForwardModeVector)
add_subdirectory(BatchMode)
add_subdirectory(ProbProg)
add_subdirectory(BlasOpt)

# Run regression and unit tests
add_lit_testsuite(check-enzyme "Running enzyme regression tests"
Expand Down
38 changes: 33 additions & 5 deletions enzyme/tools/enzyme-tblgen/opts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,23 @@ void emitBlasOpt(StringRef name, std::vector<DagInit *> inputs,

os << " // create a vector of calls to delete\n";
os << " std::vector<CallInst *> todelete;\n";
os << " int num_calls = 0;\n";

os << " for (auto &BB : *F) {\n"
<< " for (auto &I : BB) {\n"
<< " if (auto *CI = dyn_cast<CallInst>(&I)) {\n"
<< " auto name = CI->getCalledFunction()->getName();\n";
<< " num_calls++;\n"
<< " auto CIname = CI->getCalledFunction()->getName();\n"
<< " auto blasOption = extractBLAS(CIname);\n"
<< "#if LLVM_VERSION_MAJOR >= 16\n"
<< " if (!blasOption.has_value()) continue;\n"
<< " auto blas = blasOption.value();\n"
<< "#else\n"
<< " if (!blasOption.hasValue()) continue;\n"
<< " auto blas = blasOption.getValue();\n"
<< "#endif\n";
for (auto fnc : unique_functions.keys()) {
os << " if (name == \"" << fnc << "\") {\n";
os << " if (blas.function == \"" << fnc << "\") {\n";
std::string tab = " ";
auto fnc_vec = unique_functions[fnc];
bool multiple = fnc_vec.size() > 1;
Expand All @@ -127,12 +137,28 @@ void emitBlasOpt(StringRef name, std::vector<DagInit *> inputs,
os << "};\n";
}
os << tab << "bool set = cmp_or_set(CI, values);\n";
os << tab << "if (!set) continue;\n";
os << tab << "if (!set) {\n";
os << tab << " llvm::errs() << \"args missmatch: " << fnc << "\";\n";
os << tab << " continue;\n";
os << tab << "}\n";
for (size_t i = 0; i < fnc_vec[0]->getNumArgs(); i++) {
os << tab << "values[" << i << "] = CI->getArgOperand(" << i << ");\n";
}
for (size_t i = 0; i < fnc_vec.size(); ++i) {
ArrayRef<StringInit *> args = fnc_vec[i]->getArgNames();
size_t pos = 0;
for (auto arg : args) {
os << tab << arg->getValue() << " = CI->getArgOperand(" << i << ");\n";
pos++;
}
}
os << tab << "llvm::errs() << \"found " << fnc << "\\n\";\n";
os << tab << "idx_" << fnc << "++;\n"
<< tab << "todelete.push_back(CI);\n"
<< tab << "continue;\n"
<< " }\n";
}
os << " llvm::errs() << \"unhandled: \" << blas.function << \"\\n\";\n";
os << " }\n";
os << " }\n";
os << " }\n";
Expand All @@ -143,8 +169,10 @@ void emitBlasOpt(StringRef name, std::vector<DagInit *> inputs,
os << " if (idx_" << fnc << " != " << unique_functions[fnc].size() << ")\n"
<< " found = false;\n";
}
os << " if (!found)\n"
<< " return false;\n";
os << " if (!found) {\n"
<< " llvm::errs() << \"num calls: \" << num_calls << \"\\n\";\n"
<< " return false;\n"
<< " }\n";

os << " llvm::errs() << \"found optimization " << name << "\\n\";\n";

Expand Down

0 comments on commit 30073cc

Please sign in to comment.