From d790fb530579682cd4f7a8a772cefba09282bdbf Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 3 Aug 2023 16:21:42 -0400 Subject: [PATCH 01/17] wip --- enzyme/Enzyme/Enzyme.cpp | 5 + enzyme/Enzyme/OptBlas.cpp | 125 +++++++++++++++++++++++++ enzyme/Enzyme/OptBlas.h | 48 ++++++++++ enzyme/test/Enzyme/BlasOpt/blasCOpt.ll | 34 +++++++ 4 files changed, 212 insertions(+) create mode 100644 enzyme/Enzyme/OptBlas.cpp create mode 100644 enzyme/Enzyme/OptBlas.h create mode 100644 enzyme/test/Enzyme/BlasOpt/blasCOpt.ll diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index a8ef011150db..1cbf84e499c6 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -2881,6 +2881,7 @@ class EnzymeNewPM final : public EnzymeBase, AnalysisKey EnzymeNewPM::Key; #include "ActivityAnalysisPrinter.h" +#include "OptBlas.h" #include "PreserveNVVM.h" #include "TypeAnalysis/TypeAnalysisPrinter.h" #include "llvm/Passes/PassBuilder.h" @@ -3417,6 +3418,10 @@ llvmGetPassPluginInfo() { PB.registerPipelineParsingCallback( [](llvm::StringRef Name, llvm::ModulePassManager &MPM, llvm::ArrayRef) { + if (Name == "blas-opt") { + MPM.addPass(createOptimizeBlasPass(/*Begin*/ true)); + return true; + } if (Name == "enzyme") { MPM.addPass(EnzymeNewPM()); return true; diff --git a/enzyme/Enzyme/OptBlas.cpp b/enzyme/Enzyme/OptBlas.cpp new file mode 100644 index 000000000000..c4a4dd4ae503 --- /dev/null +++ b/enzyme/Enzyme/OptBlas.cpp @@ -0,0 +1,125 @@ +//===- PreserveNVVM.cpp - Mark NVVM attributes for preservation. -------===// +// +// Enzyme Project +// +// Part of the Enzyme Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// If using this code in an academic setting, please cite the following: +// @incollection{enzymeNeurips, +// title = {Instead of Rewriting Foreign Code for Machine Learning, +// Automatically Synthesize Fast Gradients}, +// author = {Moses, William S. and Churavy, Valentin}, +// booktitle = {Advances in Neural Information Processing Systems 33}, +// year = {2020}, +// note = {To appear in}, +// } +// +//===----------------------------------------------------------------------===// +// +// This file contains createPreserveNVVM, a transformation pass that marks +// calls to __nv_* functions, marking them as noinline as implementing the llvm +// intrinsic. +// +//===----------------------------------------------------------------------===// +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" + +#include "llvm/ADT/SmallSet.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/raw_ostream.h" + +#include "llvm/Pass.h" + +#include "llvm/Transforms/Utils.h" + +#include + +#include "OptBlas.h" +#include "Utils.h" + +using namespace llvm; +#ifdef DEBUG_TYPE +#undef DEBUG_TYPE +#endif +#define DEBUG_TYPE "blas-opt" + +#if LLVM_VERSION_MAJOR >= 14 +#define addAttribute addAttributeAtIndex +#endif + +bool optimizeBlas(bool Begin, Function &F) { + bool changed = false; + return changed; +} + +namespace { + +bool optimizeFncsWithBlas(llvm::Module &M) { + + using namespace llvm; + + // assert(called); + + // IRBuilder<> BuilderZ(&called->getEntryBlock()); + // BuilderZ.setFastMathFlags(getFast()); + + // BasicBlock *bb = called->getEntryBlock(); + // assert(bb.getTerminator() != nullptr); + + Value *A = nullptr; + Value *B = nullptr; + + Value *alpha = nullptr; + Value *beta = nullptr; + + return false; +} + +class 
OptimizeBlas final : public ModulePass { +public: + static char ID; + // bool Begin; + // OptimizeBlas(bool Begin = true) : ModulePass(ID), Begin(Begin) {} + + OptimizeBlas() : ModulePass(ID) {} + OptimizeBlas(char &pid) : ModulePass(pid) {} + void getAnalysisUsage(AnalysisUsage &AU) const override {} + bool runOnModule(Module &M) override { return optimizeFncsWithBlas(M); } +}; + +} // namespace + +char OptimizeBlas::ID = 0; + +static RegisterPass X("blas-opt", "Optimize Blas Pass"); + +ModulePass *createOptimizeBlasPass(bool Begin) { + char pid = 0; + return new OptimizeBlas(pid); +} + +#include +#include + +#include "llvm/IR/LegacyPassManager.h" + +extern "C" void AddOptimizeBlasPass(LLVMPassManagerRef PM, uint8_t Begin) { + unwrap(PM)->add(createOptimizeBlasPass((bool)Begin)); +} + +//bool OptimizeBlasNewPM::run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) { +// bool changed = false; +// for (auto &F : M) +// changed |= optimizeBlas(Begin, F); +// return changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +//} +llvm::AnalysisKey OptimizeBlasNewPM::Key; diff --git a/enzyme/Enzyme/OptBlas.h b/enzyme/Enzyme/OptBlas.h new file mode 100644 index 000000000000..549c2873c949 --- /dev/null +++ b/enzyme/Enzyme/OptBlas.h @@ -0,0 +1,48 @@ +//===- OptBlas.h - Mark NVVM attributes for preservation. -------===// +// +// Enzyme Project +// +// Part of the Enzyme Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// If using this code in an academic setting, please cite the following: +// @incollection{enzymeNeurips, +// title = {Instead of Rewriting Foreign Code for Machine Learning, +// Automatically Synthesize Fast Gradients}, +// author = {Moses, William S. 
and Churavy, Valentin}, +// booktitle = {Advances in Neural Information Processing Systems 33}, +// year = {2020}, +// note = {To appear in}, +// } +// +//===----------------------------------------------------------------------===// +// +// This file contains createOptimizeBlas, +// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/PassManager.h" +#include "llvm/Passes/PassPlugin.h" + +namespace llvm { +class ModulePass; +} + +llvm::ModulePass *createOptimizeBlasPass(bool Begin); + +class OptimizeBlasNewPM final + : public llvm::AnalysisInfoMixin { + // friend struct llvm::AnalysisInfoMixin; + +private: + bool Begin; + static llvm::AnalysisKey Key; + +public: + OptimizeBlasNewPM(bool Begin) : Begin(Begin) {} + + bool runOnModule(llvm::Module &M); + + static bool isRequired() { return true; } +}; diff --git a/enzyme/test/Enzyme/BlasOpt/blasCOpt.ll b/enzyme/test/Enzyme/BlasOpt/blasCOpt.ll new file mode 100644 index 000000000000..746971630e26 --- /dev/null +++ b/enzyme/test/Enzyme/BlasOpt/blasCOpt.ll @@ -0,0 +1,34 @@ +;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -blasOpt -S | FileCheck %s; fi +;RUN: %opt < %s %newLoadEnzyme -passes="blasOpt" -S | FileCheck %s + +declare void @dgemm_64_(i8, i8, i64, i64, i64, double, double* nocapture readonly, i64, double* nocapture readonly, i64, double, double* nocapture readonly, i64) + +declare void @dger_64_(i64 , i64, double, double* nocapture readonly, i64, double* nocapture readonly, i64, double*, i64) + +define void @f(double* noalias %C, double %alpha, double %beta, double* %x, double *%y, double* %v, double *%w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb) { +entry: + %A = alloca double, align 16 + store double 0.000000e+00, double* %A + %B = alloca double, align 16 + store double 0.000000e+00, double* %B + call void @dger_64_(i64 %m, i64 %n, double %alpha, double* %x, i64 %incx, 
double* %y, i64 %incy, double* %A, i64 %lda) + call void @dger_64_(i64 %n, i64 %p, double %beta, double* %v, i64 %incv, double* %w, i64 %incw, double* %B, i64 %ldb) + call void @dgemm_64_(i8 %transa, i8 %transb, i64 %m, i64 %n, i64 %p, double %alpha, double* %A, i64 %lda, double* %B, i64 %ldb, double %beta, double* %C, i64 %ldc) + ;%ptr = bitcast i8* %A to double* + ;store double 0.0000000e+00, double* %ptr, align 8 + ret void +} + +; CHECK-COM: Target function: +; CHECK: define void @f(i8* noalias %C, i8* noalias %alpha, i8* noalias %beta, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw) +; CHECK-NEXT: entry: +; CHECK-NEXT: %A = alloca double, align 16 +; CHECK-NEXT: store double 0.000000e+00, double* %A +; CHECK-NEXT: %B = alloca double, align 16 +; CHECK-NEXT: store double 0.000000e+00, double* %B +; CHECK-NEXT: %0 = call double @ddot_64_(i64 %m, double* %y, i64 %incy, double* %v, i64 %incv) +; CHECK-NEXT: %1 = fmul double %0, %alpha +; CHECK-NEXT: %2 = fmul double %1, %beta +; CHECK-NEXT: call void @dger_64_(i64 %m, i64 %n, double %2, double* %x, i64 %incx, double* %w, i64 %incw, double* %C, i64 %ldc) +; CHECK-NEXT: ret void +; CHECK-NEXT: } From b0d01158144400c4ff8731e378fa8e8a5d17ca50 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 3 Aug 2023 19:21:57 -0400 Subject: [PATCH 02/17] starting somewhere --- enzyme/Enzyme/Enzyme.cpp | 8 +- enzyme/Enzyme/OptBlas.cpp | 151 ++++++++++++++++++++----- enzyme/Enzyme/OptBlas.h | 2 + enzyme/test/Enzyme/BlasOpt/blasCOpt.ll | 10 +- 4 files changed, 135 insertions(+), 36 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 1cbf84e499c6..47aa0c3f00f3 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -3418,10 +3418,10 @@ llvmGetPassPluginInfo() { PB.registerPipelineParsingCallback( [](llvm::StringRef Name, llvm::ModulePassManager &MPM, llvm::ArrayRef) { - if (Name == "blas-opt") { - 
MPM.addPass(createOptimizeBlasPass(/*Begin*/ true)); - return true; - } + // if (Name == "blas-opt") { + // MPM.addPass(createOptimizeBlasPass(/*Begin*/ true)); + // return true; + // } if (Name == "enzyme") { MPM.addPass(EnzymeNewPM()); return true; diff --git a/enzyme/Enzyme/OptBlas.cpp b/enzyme/Enzyme/OptBlas.cpp index c4a4dd4ae503..a3b6185232f0 100644 --- a/enzyme/Enzyme/OptBlas.cpp +++ b/enzyme/Enzyme/OptBlas.cpp @@ -63,27 +63,6 @@ bool optimizeBlas(bool Begin, Function &F) { namespace { -bool optimizeFncsWithBlas(llvm::Module &M) { - - using namespace llvm; - - // assert(called); - - // IRBuilder<> BuilderZ(&called->getEntryBlock()); - // BuilderZ.setFastMathFlags(getFast()); - - // BasicBlock *bb = called->getEntryBlock(); - // assert(bb.getTerminator() != nullptr); - - Value *A = nullptr; - Value *B = nullptr; - - Value *alpha = nullptr; - Value *beta = nullptr; - - return false; -} - class OptimizeBlas final : public ModulePass { public: static char ID; @@ -116,10 +95,128 @@ extern "C" void AddOptimizeBlasPass(LLVMPassManagerRef PM, uint8_t Begin) { unwrap(PM)->add(createOptimizeBlasPass((bool)Begin)); } -//bool OptimizeBlasNewPM::run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) { -// bool changed = false; -// for (auto &F : M) -// changed |= optimizeBlas(Begin, F); -// return changed ? 
PreservedAnalyses::none() : PreservedAnalyses::all(); -//} llvm::AnalysisKey OptimizeBlasNewPM::Key; + +bool optimizeFncsWithBlas(llvm::Module &M) { + + using namespace llvm; + + Function *F = M.getFunction("f"); + if (F) { + errs() << "Found function: " << F->getName() << "\n"; + } else { + return false; + } + + BasicBlock *bb = &F->getEntryBlock(); + + bool firstGer = true; + + Value *m = nullptr; + Value *n = nullptr; + Value *p = nullptr; + + Value *x = nullptr; + Value *y = nullptr; + Value *v = nullptr; + Value *w = nullptr; + + Value *incx = nullptr; + Value *incy = nullptr; + Value *incv = nullptr; + Value *incw = nullptr; + + Value *A = nullptr; + Value *B = nullptr; + Value *C = nullptr; + + Value *ldc = nullptr; + + Value *alpha = nullptr; + Value *beta = nullptr; + + // create a vector of calls I will delete after the loop + std::vector callsToDelete; + + // Store insertion point for new instructions + Instruction *insertionPoint = nullptr; + + for (auto &I : *bb) { + if (auto *call = dyn_cast(&I)) { + auto name = call->getCalledFunction()->getName(); + if (name.contains("dger_")) { + errs() << "Found dger_ call\n"; + if (firstGer) { + m = call->getArgOperand(0); + n = call->getArgOperand(1); + alpha = call->getArgOperand(2); + x = call->getArgOperand(3); + incx = call->getArgOperand(4); + y = call->getArgOperand(5); + incy = call->getArgOperand(6); + A = call->getArgOperand(7); + firstGer = false; + callsToDelete.push_back(call); + insertionPoint = call; + } else { + assert(n == call->getArgOperand(0)); + + p = call->getArgOperand(1); + beta = call->getArgOperand(2); + v = call->getArgOperand(3); + incv = call->getArgOperand(4); + w = call->getArgOperand(5); + incw = call->getArgOperand(6); + B = call->getArgOperand(7); + callsToDelete.push_back(call); + } + } + if (name.contains("dgemm_")) { + errs() << "Found dgemm_ call\n"; + assert(call->getArgOperand(2) == m); + assert(call->getArgOperand(3) == n); + assert(call->getArgOperand(4) == p); + 
assert(call->getArgOperand(5) == alpha); + assert(call->getArgOperand(6) == A); + assert(call->getArgOperand(8) == B); + assert(call->getArgOperand(10) == beta); + C = call->getArgOperand(11); + ldc = call->getArgOperand(12); + callsToDelete.push_back(call); + } + } + } + if (callsToDelete.size() == 0) { + return false; + } + insertionPoint = callsToDelete[0]->getPrevNode(); + for (auto call : callsToDelete) { + call->eraseFromParent(); + } + + bb->getTerminator()->eraseFromParent(); + FunctionType *FTDot = + FunctionType::get(Type::getDoubleTy(M.getContext()), + {m->getType(), y->getType(), incy->getType(), + v->getType(), incv->getType()}, + false); + std::string dot_name = "ddot_64_"; + Function *FDot = + cast(M.getOrInsertFunction(dot_name, FTDot).getCallee()); + + Function *FGer = M.getFunction("dger_64_"); + assert(FGer); + + // bb->setInsertPoint(insertionPoint); + + + IRBuilder<> B1(bb); + Value *dotRet = B1.CreateCall(FDot, {m, y, incy, v, incv}); + Value *alphaDotRet = B1.CreateFMul(alpha, dotRet); + Value *alphabeta = B1.CreateFMul(alphaDotRet, beta); + B1.CreateCall(FGer, {m, n, alphabeta, x, incx, w, incw, C, ldc}); + + B1.CreateRetVoid(); + + return true; +} diff --git a/enzyme/Enzyme/OptBlas.h b/enzyme/Enzyme/OptBlas.h index 549c2873c949..78d001b021e4 100644 --- a/enzyme/Enzyme/OptBlas.h +++ b/enzyme/Enzyme/OptBlas.h @@ -29,6 +29,8 @@ namespace llvm { class ModulePass; } +bool optimizeFncsWithBlas(llvm::Module &M); + llvm::ModulePass *createOptimizeBlasPass(bool Begin); class OptimizeBlasNewPM final diff --git a/enzyme/test/Enzyme/BlasOpt/blasCOpt.ll b/enzyme/test/Enzyme/BlasOpt/blasCOpt.ll index 746971630e26..2b0aca7fcdb4 100644 --- a/enzyme/test/Enzyme/BlasOpt/blasCOpt.ll +++ b/enzyme/test/Enzyme/BlasOpt/blasCOpt.ll @@ -1,5 +1,5 @@ -;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -blasOpt -S | FileCheck %s; fi -;RUN: %opt < %s %newLoadEnzyme -passes="blasOpt" -S | FileCheck %s +;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme 
-blas-opt -S | FileCheck %s; fi +;RUN: %opt < %s %newLoadEnzyme -passes="blas-opt" -S | FileCheck %s declare void @dgemm_64_(i8, i8, i64, i64, i64, double, double* nocapture readonly, i64, double* nocapture readonly, i64, double, double* nocapture readonly, i64) @@ -19,15 +19,15 @@ entry: ret void } -; CHECK-COM: Target function: -; CHECK: define void @f(i8* noalias %C, i8* noalias %alpha, i8* noalias %beta, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw) + +; CHECK: define void @f(double* noalias %C, double %alpha, double %beta, double* %x, double* %y, double* %v, double* %w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb) ; CHECK-NEXT: entry: ; CHECK-NEXT: %A = alloca double, align 16 ; CHECK-NEXT: store double 0.000000e+00, double* %A ; CHECK-NEXT: %B = alloca double, align 16 ; CHECK-NEXT: store double 0.000000e+00, double* %B ; CHECK-NEXT: %0 = call double @ddot_64_(i64 %m, double* %y, i64 %incy, double* %v, i64 %incv) -; CHECK-NEXT: %1 = fmul double %0, %alpha +; CHECK-NEXT: %1 = fmul double %alpha, %0 ; CHECK-NEXT: %2 = fmul double %1, %beta ; CHECK-NEXT: call void @dger_64_(i64 %m, i64 %n, double %2, double* %x, i64 %incx, double* %w, i64 %incw, double* %C, i64 %ldc) ; CHECK-NEXT: ret void From 94c3aa0bc83114355b869bca72cb1cbb32e4f8b1 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 3 Aug 2023 23:38:57 -0400 Subject: [PATCH 03/17] wip tablegen automation --- enzyme/BUILD | 13 +++ enzyme/Enzyme/BlasOpts.td | 27 +++++ enzyme/Enzyme/CMakeLists.txt | 3 + enzyme/Enzyme/InstructionDerivatives.td | 1 + enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 117 +++++++++++++++++++ 5 files changed, 161 insertions(+) create mode 100644 enzyme/Enzyme/BlasOpts.td diff --git a/enzyme/BUILD b/enzyme/BUILD index c54c2f6b1097..8f2240376255 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -113,6 +113,19 @@ gentbl( ":enzyme-tblgen", ], ) +gentbl( + name 
= "blas-optimizations", + tbl_outs = [( + "-gen-blas-optimizations", + "BlasOpts.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/InstructionDerivatives.td", + td_srcs = ["Enzyme/BlasDerivatives.td"], + deps = [ + ":enzyme-tblgen", + ], +) gentbl( name = "blas-diffuseanalysis", tbl_outs = [( diff --git a/enzyme/Enzyme/BlasOpts.td b/enzyme/Enzyme/BlasOpts.td new file mode 100644 index 000000000000..ee2ac16b277c --- /dev/null +++ b/enzyme/Enzyme/BlasOpts.td @@ -0,0 +1,27 @@ + +class BlasOptPattern< list _inputs, list _tmps, list _outputs> { + list inputs = _inputs; + // tmp variables will disappear during the transformation + // and therefore are not allowed to be read elsewhere + list tmps = _tmps; + list outputs = _outputs; +} + +class Value { + int i = _i; +} + +def first : BlasOptPattern< + [ + (b<"ger"> $layout, $m, $n, $alpha, $x, $incx, $y, $incy, $A, $lda), + (b<"ger"> $layout, $m, $k, $beta, $v, $incv, $w, $incw, $B, $ldb), + (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, $alpha, $A, $lda, $B, $ldb, $beta, $C, $ldc), + ], + ["A", "B"], + [ + (Value<1> (b<"dot"> $layout, $n, $v, $incv, $y, $incy)), + (Value<2> (FMul $alpha, Value<1>)), + (Value<3> (FMul $beta, Value<2>)), + (b<"ger"> $layout, $m, $k, Value<3>, $x, $incx, $w, $incw, $C, $ldc), + ] + >; diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index c895912bc0f9..9910f2d888a1 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -28,6 +28,7 @@ enzyme_tablegen(BlasDerivatives.inc -gen-blas-derivatives) enzyme_tablegen(BlasAttributor.inc -update-blas-declarations) enzyme_tablegen(BlasTA.inc -gen-blas-typeanalysis) enzyme_tablegen(BlasDiffUse.inc -gen-blas-diffuseanalysis) +enzyme_tablegen(BlasOpts.inc -gen-blas-optimizations) add_public_tablegen_target(BinopDerivativesIncGen) add_public_tablegen_target(IntrinsicDerivativesIncGen) add_public_tablegen_target(CallDerivativesIncGen) @@ -36,6 +37,7 @@ 
add_public_tablegen_target(BlasDerivativesIncGen) add_public_tablegen_target(BlasDeclarationsIncGen) add_public_tablegen_target(BlasTAIncGen) add_public_tablegen_target(BlasDiffUseIncGen) +add_public_tablegen_target(BlasOptsIncGen) include_directories(${CMAKE_CURRENT_BINARY_DIR}) @@ -127,6 +129,7 @@ if (${ENZYME_EXTERNAL_SHARED_LIB}) add_dependencies(Enzyme-${LLVM_VERSION_MAJOR} BlasDeclarationsIncGen) add_dependencies(Enzyme-${LLVM_VERSION_MAJOR} BlasTAIncGen) add_dependencies(Enzyme-${LLVM_VERSION_MAJOR} BlasDiffUseIncGen) + add_dependencies(Enzyme-${LLVM_VERSION_MAJOR} BlasOptsIncGen) target_link_libraries(Enzyme-${LLVM_VERSION_MAJOR} LLVM) install(TARGETS Enzyme-${LLVM_VERSION_MAJOR} EXPORT EnzymeTargets diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 1796645073fc..687db3c34da6 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -903,3 +903,4 @@ def : InstPattern<(Op $x), "Freeze", 10, 0, >; include "BlasDerivatives.td" +include "BlasOpts.td" diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index a733601c8203..738b29be9445 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -38,6 +38,7 @@ enum ActionType { UpdateBlasDecl, UpdateBlasTA, GenBlasDiffUse, + GenBlasOpts, }; static cl::opt @@ -56,6 +57,8 @@ static cl::opt "Generate binaryoperator derivative")), cl::values(clEnumValN(InstDerivatives, "gen-inst-derivatives", "Generate instruction derivative")), + cl::values(clEnumValN(GenBlasOpts, "gen-blas-optimizations", + "Generate BLAS optimizations")), cl::values(clEnumValN(CallDerivatives, "gen-call-derivatives", "Generate call derivative"))); @@ -1397,6 +1400,117 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, #include "blasDiffUseUpdater.h" #include "blasTAUpdater.h" +// class BlasOptPattern< list _inputs, list 
_tmps, list +// _outputs> { +// list inputs = _inputs; +// // tmp variables will disappear during the transformation +// // and therefore are not allowed to be read elsewhere +// list tmps = _tmps; +// list outputs = _outputs; +// } +// def first : BlasOptPattern<, +// [ +// (b<"ger"> $layout, $m, $n, $alpha, $x, $incx, $y, $incy, $A, $lda), +// (b<"ger"> $layout, $m, $k, $beta, $v, $incv, $w, $incw, $B, $ldb), +// (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, $alpha, $A, $lda, $B, $ldb, +// $beta, $C, $ldc), +// ], +// ["A", "B"], +// [ +// (Value<1> (b<"dot"> $layout, $n, $v, $incv, $y, $incy)), +// (Value<2> (FMul $alpha Value<1>)), +// (Value<3> (FMul $beta Value<2>)), +// (b<"ger"> $layout, $m, $k, Value<3>, $x, $incx, $w, $incw, $C, $ldc), +// ] +// >; +void emitBlasOpt(StringRef name, std::vector inputs, + std::vector, std::vector outputs, + raw_ostream &os) { + os << "bool opt" << name << "(llvm::Function *F) {\n"; + os << " using namespace llvm;\n"; + + StringSet usedArgs{}; + std::vector functions{}; + StringMap> unique_functions{}; + for (auto input : inputs) { + ArrayRef args = input->getArgNames(); + auto Def = cast(input->getOperator())->getDef(); + assert(Def->isSubClassOf("b")); + auto fnc_name = Def->getValueAsString("s"); + // auto fnc_name = input->getNameStr(); + functions.push_back(fnc_name); + unique_functions[fnc_name].push_back(input); + for (auto &arg : args) { + if (usedArgs.count(arg->getValue())) + continue; + os << " Value *" << arg->getValue() << " = nullptr;\n"; + usedArgs.insert(arg->getValue()); + } + } + + size_t idx = 0; + + for (auto fnc : unique_functions.keys()) { + if (unique_functions.count(fnc) > 1) + os << " size_t idx_" << fnc << " = 0;\n"; + } + + os << " for (auto &BB : *F) {\n" + << " for (auto &I : BB) {\n" + << " if (auto *CI = dyn_cast(&I)) {\n" + << " name = CI->getCalledFunction()->getName();\n"; + + for (auto fnc : unique_functions.keys()) { + os << " if (name == \"" << fnc << "\") {\n"; + auto fnc_vec = 
unique_functions[fnc]; + for (size_t i = 0; i < fnc_vec.size(); ++i) { + if (fnc_vec.size() > 1) { + os << " if (idx_" << fnc << " == " << i << ") {\n"; + } + auto input = fnc_vec[i]; + ArrayRef args = input->getArgNames(); + for (size_t j = 0; j < args.size(); ++j) { + os << " " << args[j]->getValue() << " = CI->getArgOperand(" + << j << ");\n"; + } + if (fnc_vec.size() > 1) { + os << " }\n"; + } + } + os << " }\n"; + } + os << " }\n"; + os << " }\n"; + os << " }\n"; + os << "}\n"; +} + +static void emitBlasOpts(const RecordKeeper &recordKeeper, raw_ostream &os) { + emitSourceFileHeader("Rewriters", os); + const char *patternNames = "BlasOptPattern"; + const auto &patterns = recordKeeper.getAllDerivedDefinitions(patternNames); + + for (Record *pattern : patterns) { + ListInit *inputs = pattern->getValueAsListInit("inputs"); + std::vector tmps = pattern->getValueAsListOfStrings("tmps"); + ListInit *outputs = pattern->getValueAsListInit("outputs"); + + std::vector inputDags; + for (auto input : *inputs) { + DagInit *dag = dyn_cast(input); + assert(dag); + inputDags.push_back(dag); + } + std::vector outputDags; + for (auto output : *outputs) { + DagInit *dag = dyn_cast(output); + assert(dag); + outputDags.push_back(dag); + } + emitBlasOpt(pattern->getName(), inputDags, tmps, outputDags, os); + } +} + static bool EnzymeTableGenMain(raw_ostream &os, RecordKeeper &records) { switch (action) { case CallDerivatives: @@ -1417,6 +1531,9 @@ static bool EnzymeTableGenMain(raw_ostream &os, RecordKeeper &records) { case UpdateBlasTA: emitBlasTAUpdater(records, os); return false; + case GenBlasOpts: + emitBlasOpts(records, os); + return false; default: errs() << "unknown tablegen action!\n"; From b7795b91a56f3facafbc06048f3ac98bde60431e Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 4 Aug 2023 00:02:58 -0400 Subject: [PATCH 04/17] tablegen more of the logic --- enzyme/Enzyme/BlasOpts.td | 2 +- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 46 ++++++++++++++++++-- 2 
files changed, 43 insertions(+), 5 deletions(-) diff --git a/enzyme/Enzyme/BlasOpts.td b/enzyme/Enzyme/BlasOpts.td index ee2ac16b277c..7d2c1bb590f0 100644 --- a/enzyme/Enzyme/BlasOpts.td +++ b/enzyme/Enzyme/BlasOpts.td @@ -14,7 +14,7 @@ class Value { def first : BlasOptPattern< [ (b<"ger"> $layout, $m, $n, $alpha, $x, $incx, $y, $incy, $A, $lda), - (b<"ger"> $layout, $m, $k, $beta, $v, $incv, $w, $incw, $B, $ldb), + (b<"ger"> $layout, $n, $k, $beta, $v, $incv, $w, $incw, $B, $ldb), (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, $alpha, $A, $lda, $B, $ldb, $beta, $C, $ldc), ], ["A", "B"], diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 738b29be9445..475a8adbe9c9 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1448,13 +1448,15 @@ void emitBlasOpt(StringRef name, std::vector inputs, } } - size_t idx = 0; + os << "\n"; for (auto fnc : unique_functions.keys()) { - if (unique_functions.count(fnc) > 1) - os << " size_t idx_" << fnc << " = 0;\n"; + os << " size_t idx_" << fnc << " = 0;\n"; } + os << " // create a vector of calls to delete\n"; + os << " std::vector todelete;\n"; + os << " for (auto &BB : *F) {\n" << " for (auto &I : BB) {\n" << " if (auto *CI = dyn_cast(&I)) {\n" @@ -1470,6 +1472,8 @@ void emitBlasOpt(StringRef name, std::vector inputs, auto input = fnc_vec[i]; ArrayRef args = input->getArgNames(); for (size_t j = 0; j < args.size(); ++j) { + if (fnc_vec.size() > 1) + os << " "; os << " " << args[j]->getValue() << " = CI->getArgOperand(" << j << ");\n"; } @@ -1477,11 +1481,45 @@ void emitBlasOpt(StringRef name, std::vector inputs, os << " }\n"; } } - os << " }\n"; + os << " idx_" << fnc << "++;\n" + << " todelete.push_back(CI);\n" + << " }\n"; } os << " }\n"; os << " }\n"; os << " }\n"; + + // check that all functions have been found + os << " bool found = true;\n"; + for (auto fnc : unique_functions.keys()) { + os << " if 
(idx_" << fnc << " != " << unique_functions[fnc].size() << ")\n" + << " found = false;\n"; + } + os << " if (!found)\n" + << " return false;\n"; + + // now that we found an optimization to apply, + // we can delete the old calls + os << " for (auto *CI : todelete) {\n" + << " CI->eraseFromParent();\n" + << " }\n"; + + // emit the actual optimization + // for (auto output : outputs) { + // auto Def = cast(output->getOperator())->getDef(); + // assert(Def->isSubClassOf("b")); + // auto fnc_name = Def->getValueAsString("s"); + // // auto fnc_name = output->getNameStr(); + // os << " " << fnc_name << "("; + // ArrayRef args = output->getArgNames(); + // for (size_t i = 0; i < args.size(); ++i) { + // if (i > 0) + // os << ", "; + // os << args[i]->getValue(); + // } + // os << ");\n"; + //} + os << "}\n"; } From 4482ceea2e35d9a934cb64590f89ceee2df68c81 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sat, 5 Aug 2023 18:46:49 -0400 Subject: [PATCH 05/17] also emit new rules --- enzyme/Enzyme/BlasOpts.td | 18 ++-- enzyme/Enzyme/OptBlas.cpp | 22 ++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 102 +++++++++++++------ 3 files changed, 102 insertions(+), 40 deletions(-) diff --git a/enzyme/Enzyme/BlasOpts.td b/enzyme/Enzyme/BlasOpts.td index 7d2c1bb590f0..9977d81a52d2 100644 --- a/enzyme/Enzyme/BlasOpts.td +++ b/enzyme/Enzyme/BlasOpts.td @@ -7,21 +7,21 @@ class BlasOptPattern< list _inputs, list _tmps, list _outputs> list outputs = _outputs; } -class Value { - int i = _i; +class Value { + string name = _i; } def first : BlasOptPattern< [ - (b<"ger"> $layout, $m, $n, $alpha, $x, $incx, $y, $incy, $A, $lda), - (b<"ger"> $layout, $n, $k, $beta, $v, $incv, $w, $incw, $B, $ldb), - (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, $alpha, $A, $lda, $B, $ldb, $beta, $C, $ldc), + (b<"dger_64_"> $m, $n, $alpha, $x, $incx, $y, $incy, $A, $lda), + (b<"dger_64_"> $n, $k, $beta, $v, $incv, $w, $incw, $B, $ldb), + (b<"dgemm_64_"> $transa, $transb, $m, $n, $k, $alpha, $A, $lda, $B, 
$ldb, $beta, $C, $ldc), ], ["A", "B"], [ - (Value<1> (b<"dot"> $layout, $n, $v, $incv, $y, $incy)), - (Value<2> (FMul $alpha, Value<1>)), - (Value<3> (FMul $beta, Value<2>)), - (b<"ger"> $layout, $m, $k, Value<3>, $x, $incx, $w, $incw, $C, $ldc), + (Value<"1"> (b<"dot"> $n, $v, $incv, $y, $incy)), + (Value<"2"> (FMul $alpha, Value<"1">)), + (Value<"3"> (FMul $beta, Value<"2">)), + (b<"ger"> $m, $k, Value<"3">, $x, $incx, $w, $incw, $C, $ldc), ] >; diff --git a/enzyme/Enzyme/OptBlas.cpp b/enzyme/Enzyme/OptBlas.cpp index a3b6185232f0..502211d02f72 100644 --- a/enzyme/Enzyme/OptBlas.cpp +++ b/enzyme/Enzyme/OptBlas.cpp @@ -97,6 +97,26 @@ extern "C" void AddOptimizeBlasPass(LLVMPassManagerRef PM, uint8_t Begin) { llvm::AnalysisKey OptimizeBlasNewPM::Key; +bool cmp_or_set(llvm::CallInst *CI, std::vector values) { + // first run through to see if the already set args match. + // second run if they do and then we set the nullptr. + for (size_t i = 0; i < values.size(); ++i) { + if (values[i] == nullptr) { + continue; + } + if (CI->getArgOperand(i) != values[i]) + return false; + } + for (size_t i = 0; i < values.size(); ++i) { + if (values[i] == nullptr) { + values[i] = CI->getArgOperand(i); + } + } + return true; +} + +#include "BlasOpts.inc" + bool optimizeFncsWithBlas(llvm::Module &M) { using namespace llvm; @@ -108,6 +128,8 @@ bool optimizeFncsWithBlas(llvm::Module &M) { return false; } + optfirst(F); + BasicBlock *bb = &F->getEntryBlock(); bool firstGer = true; diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 475a8adbe9c9..9d4b3c1925ff 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1400,6 +1400,21 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, #include "blasDiffUseUpdater.h" #include "blasTAUpdater.h" +void emit_args(DagInit *dag, raw_ostream &os) { + for (size_t i = 0; i < dag->getNumArgs(); ++i) { + if (i > 
0) + os << ", "; + if (auto def = dyn_cast(dag->getArg(i))) { + auto Def = def->getDef(); + if (Def->isSubClassOf("Value")) { + os << "tmp" << Def->getValueAsString("name"); + } + } else { + os << dag->getArgNameStr(i); + } + } +} + // class BlasOptPattern< list _inputs, list _tmps, list // _outputs> { // list inputs = _inputs; @@ -1460,29 +1475,33 @@ void emitBlasOpt(StringRef name, std::vector inputs, os << " for (auto &BB : *F) {\n" << " for (auto &I : BB) {\n" << " if (auto *CI = dyn_cast(&I)) {\n" - << " name = CI->getCalledFunction()->getName();\n"; - + << " auto name = CI->getCalledFunction()->getName();\n"; for (auto fnc : unique_functions.keys()) { os << " if (name == \"" << fnc << "\") {\n"; + std::string tab = " "; auto fnc_vec = unique_functions[fnc]; + bool multiple = fnc_vec.size() > 1; + os << tab << "assert(idx_" << fnc << " < " << fnc_vec.size() + << " && \"idx out of bounds\");\n"; + os << tab << "std::vector values;\n"; for (size_t i = 0; i < fnc_vec.size(); ++i) { - if (fnc_vec.size() > 1) { - os << " if (idx_" << fnc << " == " << i << ") {\n"; - } - auto input = fnc_vec[i]; - ArrayRef args = input->getArgNames(); - for (size_t j = 0; j < args.size(); ++j) { - if (fnc_vec.size() > 1) - os << " "; - os << " " << args[j]->getValue() << " = CI->getArgOperand(" - << j << ");\n"; + if (multiple) { + os << tab << "if (idx_" << fnc << " == " << i << ")\n "; } - if (fnc_vec.size() > 1) { - os << " }\n"; + os << tab << "values = {"; + ArrayRef args = fnc_vec[i]->getArgNames(); + bool first = true; + for (auto arg : args) { + os << (first ? 
"" : ", ") << arg->getValue(); + first = false; } + os << "};\n"; } - os << " idx_" << fnc << "++;\n" - << " todelete.push_back(CI);\n" + os << tab << "bool set = cmp_or_set(CI, values);\n"; + os << tab << "if (!set) continue;\n"; + os << tab << "llvm::errs() << \"found " << fnc << "\\n\";\n"; + os << tab << "idx_" << fnc << "++;\n" + << tab << "todelete.push_back(CI);\n" << " }\n"; } os << " }\n"; @@ -1498,28 +1517,49 @@ void emitBlasOpt(StringRef name, std::vector inputs, os << " if (!found)\n" << " return false;\n"; + os << " llvm::errs() << \"found optimization " << name << "\\n\";\n"; + // now that we found an optimization to apply, // we can delete the old calls os << " for (auto *CI : todelete) {\n" << " CI->eraseFromParent();\n" << " }\n"; - // emit the actual optimization - // for (auto output : outputs) { - // auto Def = cast(output->getOperator())->getDef(); - // assert(Def->isSubClassOf("b")); - // auto fnc_name = Def->getValueAsString("s"); - // // auto fnc_name = output->getNameStr(); - // os << " " << fnc_name << "("; - // ArrayRef args = output->getArgNames(); - // for (size_t i = 0; i < args.size(); ++i) { - // if (i > 0) - // os << ", "; - // os << args[i]->getValue(); - // } - // os << ");\n"; - //} + + os << " BasicBlock *bb = &F->getEntryBlock();\n" + << " IRBuilder<> B1(bb);\n"; + + for (auto outerOutput : outputs) { + DagInit *output = outerOutput; + auto Def = cast(output->getOperator())->getDef(); + if (Def->isSubClassOf("Value")) { + assert(output->getNumArgs() == 1); + auto name = Def->getValueAsString("name"); + os << " Value *tmp" << name << " = "; + // This is just wrapping the actual DagInit in a Value<>. 
+ // So now strip the Value wrapper to handle it in the next if/else + output = cast(output->getArg(0)); + Def = cast(output->getOperator())->getDef(); + } + if (Def->isSubClassOf("Inst")) { + auto name = Def->getValueAsString("name"); + os << " B1.CreateCall(" << name << ""; + emit_args(output, os); + os << ");\n"; + } else if (Def->isSubClassOf("b")) { + auto fnc_name = Def->getValueAsString("s"); + os << " " << fnc_name << "("; + emit_args(output, os); + os << ");\n"; + } else { + llvm::errs() << "failed with: " << Def->getName() << "\n"; + PrintFatalError(Def->getLoc(), "unknown output type"); + assert(false); + llvm_unreachable("unknown output type"); + } + } + os << " return true;\n"; os << "}\n"; } From 6adc1a0b4e91d5e58fa751be1fd67d5c45cdd425 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 7 Aug 2023 14:49:12 -0400 Subject: [PATCH 06/17] add c level test --- enzyme/test/Integration/BlasOpt/first.c | 42 +++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 enzyme/test/Integration/BlasOpt/first.c diff --git a/enzyme/test/Integration/BlasOpt/first.c b/enzyme/test/Integration/BlasOpt/first.c new file mode 100644 index 000000000000..a51ac37c6dc2 --- /dev/null +++ b/enzyme/test/Integration/BlasOpt/first.c @@ -0,0 +1,42 @@ +enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102}; +enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113}; + +void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N, + const double alpha, const double *X, const int incX, + const double *Y, const int incY, double *A, const int lda); + +void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const double alpha, const double *A, + const int lda, const double *B, const int ldb, + const double beta, double *C, const int ldc); + +void g(double *restrict x, double *restrict y, double *restrict v, double *restrict w, double 
*restrict C) { + double A[] = {0.00, 0.00, 0.00, 0.00, + 0.00, 0.00, 0.00, 0.00, + 0.00}; + double B[] = {0.00, 0.00, 0.00, 0.00, + 0.00, 0.00, 0.00, 0.00, + 0.00}; + int lda = 3, ldb = 3, ldc = 3; + int m = 3, n = 3, p = 3; + double alpha = 3.14, beta = 4.31; + int incx = 1, incy = 1, incv = 1, incw = 1; + + cblas_dger(CblasRowMajor, m, n, alpha, x, incx, y, incy, A, lda); + cblas_dger(CblasRowMajor, n, p, beta, v, incv, w, incw, B, ldb); + cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, p, alpha, A, lda, B, ldb, beta, C, ldc); +} + +int main() { + double x[] = {1.0, 2.0, 3.0}; + double y[] = {2.0, 3.0, 4.0}; + double v[] = {0.0, 2.5, 3.5}; + double w[] = {1.0, 1.0, 1.0}; + double C[] = {0.00, 0.00, 0.00, 0.00, + 0.00, 0.00, 0.00, 0.00, + 0.00}; + g(x,y,v,w,C); + for (int i = 0; i < 9; i++) + printf("%f\n", C[i]); +} From 9416f51eea44faf2bc8f2f64b8a056c4d68db815 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 7 Aug 2023 16:28:18 -0400 Subject: [PATCH 07/17] tablegen first opt --- enzyme/Enzyme/BlasOpts.td | 26 +- enzyme/Enzyme/OptBlas.cpp | 112 +-------- enzyme/tools/enzyme-tblgen/CMakeLists.txt | 2 +- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 190 +------------- enzyme/tools/enzyme-tblgen/opts.cpp | 247 +++++++++++++++++++ enzyme/tools/enzyme-tblgen/opts.h | 2 + 6 files changed, 265 insertions(+), 314 deletions(-) create mode 100644 enzyme/tools/enzyme-tblgen/opts.cpp create mode 100644 enzyme/tools/enzyme-tblgen/opts.h diff --git a/enzyme/Enzyme/BlasOpts.td b/enzyme/Enzyme/BlasOpts.td index 9977d81a52d2..f2ae4846af3f 100644 --- a/enzyme/Enzyme/BlasOpts.td +++ b/enzyme/Enzyme/BlasOpts.td @@ -12,16 +12,16 @@ class Value { } def first : BlasOptPattern< - [ - (b<"dger_64_"> $m, $n, $alpha, $x, $incx, $y, $incy, $A, $lda), - (b<"dger_64_"> $n, $k, $beta, $v, $incv, $w, $incw, $B, $ldb), - (b<"dgemm_64_"> $transa, $transb, $m, $n, $k, $alpha, $A, $lda, $B, $ldb, $beta, $C, $ldc), - ], - ["A", "B"], - [ - (Value<"1"> (b<"dot"> $n, $v, $incv, 
$y, $incy)), - (Value<"2"> (FMul $alpha, Value<"1">)), - (Value<"3"> (FMul $beta, Value<"2">)), - (b<"ger"> $m, $k, Value<"3">, $x, $incx, $w, $incw, $C, $ldc), - ] - >; + [ + (b<"ger"> $m, $n, $alpha, $x, $incx, $y, $incy, $A, $lda), + (b<"ger"> $n, $k, $beta, $v, $incv, $w, $incw, $B, $ldb), + (b<"gemm"> $transa, $transb, $m, $n, $k, $alpha, $A, $lda, $B, $ldb, $beta, $C, $ldc), + ], + ["A", "B"], + [ + (Value<"1"> (b<"dot"> $n, $v, $incv, $y, $incy)), + (Value<"2"> (FMul $alpha, Value<"1">)), + (Value<"3"> (FMul $beta, Value<"2">)), + (b<"ger"> $m, $k, Value<"3">, $x, $incx, $w, $incw, $C, $ldc), + ] + >; diff --git a/enzyme/Enzyme/OptBlas.cpp b/enzyme/Enzyme/OptBlas.cpp index 502211d02f72..aeb8d35333f2 100644 --- a/enzyme/Enzyme/OptBlas.cpp +++ b/enzyme/Enzyme/OptBlas.cpp @@ -128,117 +128,7 @@ bool optimizeFncsWithBlas(llvm::Module &M) { return false; } - optfirst(F); - - BasicBlock *bb = &F->getEntryBlock(); - - bool firstGer = true; - - Value *m = nullptr; - Value *n = nullptr; - Value *p = nullptr; - - Value *x = nullptr; - Value *y = nullptr; - Value *v = nullptr; - Value *w = nullptr; - - Value *incx = nullptr; - Value *incy = nullptr; - Value *incv = nullptr; - Value *incw = nullptr; - - Value *A = nullptr; - Value *B = nullptr; - Value *C = nullptr; - - Value *ldc = nullptr; - - Value *alpha = nullptr; - Value *beta = nullptr; - - // create a vector of calls I will delete after the loop - std::vector callsToDelete; - - // Store insertion point for new instructions - Instruction *insertionPoint = nullptr; - - for (auto &I : *bb) { - if (auto *call = dyn_cast(&I)) { - auto name = call->getCalledFunction()->getName(); - if (name.contains("dger_")) { - errs() << "Found dger_ call\n"; - if (firstGer) { - m = call->getArgOperand(0); - n = call->getArgOperand(1); - alpha = call->getArgOperand(2); - x = call->getArgOperand(3); - incx = call->getArgOperand(4); - y = call->getArgOperand(5); - incy = call->getArgOperand(6); - A = call->getArgOperand(7); - firstGer 
= false; - callsToDelete.push_back(call); - insertionPoint = call; - } else { - assert(n == call->getArgOperand(0)); - - p = call->getArgOperand(1); - beta = call->getArgOperand(2); - v = call->getArgOperand(3); - incv = call->getArgOperand(4); - w = call->getArgOperand(5); - incw = call->getArgOperand(6); - B = call->getArgOperand(7); - callsToDelete.push_back(call); - } - } - if (name.contains("dgemm_")) { - errs() << "Found dgemm_ call\n"; - assert(call->getArgOperand(2) == m); - assert(call->getArgOperand(3) == n); - assert(call->getArgOperand(4) == p); - assert(call->getArgOperand(5) == alpha); - assert(call->getArgOperand(6) == A); - assert(call->getArgOperand(8) == B); - assert(call->getArgOperand(10) == beta); - C = call->getArgOperand(11); - ldc = call->getArgOperand(12); - callsToDelete.push_back(call); - } - } - } - if (callsToDelete.size() == 0) { - return false; - } - insertionPoint = callsToDelete[0]->getPrevNode(); - for (auto call : callsToDelete) { - call->eraseFromParent(); - } - - bb->getTerminator()->eraseFromParent(); - FunctionType *FTDot = - FunctionType::get(Type::getDoubleTy(M.getContext()), - {m->getType(), y->getType(), incy->getType(), - v->getType(), incv->getType()}, - false); - std::string dot_name = "ddot_64_"; - Function *FDot = - cast(M.getOrInsertFunction(dot_name, FTDot).getCallee()); - - Function *FGer = M.getFunction("dger_64_"); - assert(FGer); - - // bb->setInsertPoint(insertionPoint); - - - IRBuilder<> B1(bb); - Value *dotRet = B1.CreateCall(FDot, {m, y, incy, v, incv}); - Value *alphaDotRet = B1.CreateFMul(alpha, dotRet); - Value *alphabeta = B1.CreateFMul(alphaDotRet, beta); - B1.CreateCall(FGer, {m, n, alphabeta, x, incx, w, incw, C, ldc}); - - B1.CreateRetVoid(); + optfirst(F, M); return true; } diff --git a/enzyme/tools/enzyme-tblgen/CMakeLists.txt b/enzyme/tools/enzyme-tblgen/CMakeLists.txt index b974488f80cb..3bb4df0b6f10 100644 --- a/enzyme/tools/enzyme-tblgen/CMakeLists.txt +++ 
b/enzyme/tools/enzyme-tblgen/CMakeLists.txt @@ -7,7 +7,7 @@ set(LLVM_LINK_COMPONENTS ) add_tablegen(enzyme-tblgen Enzyme - blas-tblgen.cpp datastructures.cpp caching.cpp enzyme-tblgen.cpp + opts.cpp blas-tblgen.cpp datastructures.cpp caching.cpp enzyme-tblgen.cpp ) get_target_property(TBL_LINKED_LIBS LLVMSupport INTERFACE_LINK_LIBRARIES) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 9d4b3c1925ff..e5f64c36bce1 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1399,195 +1399,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, #include "blasDeclUpdater.h" #include "blasDiffUseUpdater.h" #include "blasTAUpdater.h" - -void emit_args(DagInit *dag, raw_ostream &os) { - for (size_t i = 0; i < dag->getNumArgs(); ++i) { - if (i > 0) - os << ", "; - if (auto def = dyn_cast(dag->getArg(i))) { - auto Def = def->getDef(); - if (Def->isSubClassOf("Value")) { - os << "tmp" << Def->getValueAsString("name"); - } - } else { - os << dag->getArgNameStr(i); - } - } -} - -// class BlasOptPattern< list _inputs, list _tmps, list -// _outputs> { -// list inputs = _inputs; -// // tmp variables will dissapear during the transformation -// // and therefore are not allowed to be read elsewhere -// list tmps = _tmps; -// list outputs = _outputs; -// } -// def first : BlasOptPattern<, -// [ -// (b<"ger"> $layout, $m, $n, $alpha, $x, $incx, $y, $incy, $A, $lda), -// (b<"ger"> $layout, $m, $k, $beta, $v, $incv, $w, $incw, $B, $ldb), -// (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, $alpha, $A, $lda, $B, $ldb, -// $beta, $C, $ldc), -// ], -// ["A", "B"], -// [ -// (Value<1> (b<"dot"> $layout, $n, $v, $incv, $y, $incy)), -// (Value<2> (FMul $alpha Value<1>)), -// (Value<3> (FMul $beta Value<2>)), -// (b<"ger"> $layout, $m, $k, Value<3>, $x, $incx, $w, $incw, $C, $ldc), -// ] -// >; -void emitBlasOpt(StringRef name, std::vector inputs, - 
std::vector, std::vector outputs, - raw_ostream &os) { - os << "bool opt" << name << "(llvm::Function *F) {\n"; - os << " using namespace llvm;\n"; - - StringSet usedArgs{}; - std::vector functions{}; - StringMap> unique_functions{}; - for (auto input : inputs) { - ArrayRef args = input->getArgNames(); - auto Def = cast(input->getOperator())->getDef(); - assert(Def->isSubClassOf("b")); - auto fnc_name = Def->getValueAsString("s"); - // auto fnc_name = input->getNameStr(); - functions.push_back(fnc_name); - unique_functions[fnc_name].push_back(input); - for (auto &arg : args) { - if (usedArgs.count(arg->getValue())) - continue; - os << " Value *" << arg->getValue() << " = nullptr;\n"; - usedArgs.insert(arg->getValue()); - } - } - - os << "\n"; - - for (auto fnc : unique_functions.keys()) { - os << " size_t idx_" << fnc << " = 0;\n"; - } - - os << " // create a vector of calls to delete\n"; - os << " std::vector todelete;\n"; - - os << " for (auto &BB : *F) {\n" - << " for (auto &I : BB) {\n" - << " if (auto *CI = dyn_cast(&I)) {\n" - << " auto name = CI->getCalledFunction()->getName();\n"; - for (auto fnc : unique_functions.keys()) { - os << " if (name == \"" << fnc << "\") {\n"; - std::string tab = " "; - auto fnc_vec = unique_functions[fnc]; - bool multiple = fnc_vec.size() > 1; - os << tab << "assert(idx_" << fnc << " < " << fnc_vec.size() - << " && \"idx out of bounds\");\n"; - os << tab << "std::vector values;\n"; - for (size_t i = 0; i < fnc_vec.size(); ++i) { - if (multiple) { - os << tab << "if (idx_" << fnc << " == " << i << ")\n "; - } - os << tab << "values = {"; - ArrayRef args = fnc_vec[i]->getArgNames(); - bool first = true; - for (auto arg : args) { - os << (first ? 
"" : ", ") << arg->getValue(); - first = false; - } - os << "};\n"; - } - os << tab << "bool set = cmp_or_set(CI, values);\n"; - os << tab << "if (!set) continue;\n"; - os << tab << "llvm::errs() << \"found " << fnc << "\\n\";\n"; - os << tab << "idx_" << fnc << "++;\n" - << tab << "todelete.push_back(CI);\n" - << " }\n"; - } - os << " }\n"; - os << " }\n"; - os << " }\n"; - - // check that all functions have been found - os << " bool found = true;\n"; - for (auto fnc : unique_functions.keys()) { - os << " if (idx_" << fnc << " != " << unique_functions[fnc].size() << ")\n" - << " found = false;\n"; - } - os << " if (!found)\n" - << " return false;\n"; - - os << " llvm::errs() << \"found optimization " << name << "\\n\";\n"; - - // now that we found an optimization to apply, - // we can delete the old calls - os << " for (auto *CI : todelete) {\n" - << " CI->eraseFromParent();\n" - << " }\n"; - - - os << " BasicBlock *bb = &F->getEntryBlock();\n" - << " IRBuilder<> B1(bb);\n"; - - for (auto outerOutput : outputs) { - DagInit *output = outerOutput; - auto Def = cast(output->getOperator())->getDef(); - if (Def->isSubClassOf("Value")) { - assert(output->getNumArgs() == 1); - auto name = Def->getValueAsString("name"); - os << " Value *tmp" << name << " = "; - // This is just wrapping the actual DagInit in a Value<>. 
- // So now strip the Value wrapper to handle it in the next if/else - output = cast(output->getArg(0)); - Def = cast(output->getOperator())->getDef(); - } - if (Def->isSubClassOf("Inst")) { - auto name = Def->getValueAsString("name"); - os << " B1.CreateCall(" << name << ""; - emit_args(output, os); - os << ");\n"; - } else if (Def->isSubClassOf("b")) { - auto fnc_name = Def->getValueAsString("s"); - os << " " << fnc_name << "("; - emit_args(output, os); - os << ");\n"; - } else { - llvm::errs() << "failed with: " << Def->getName() << "\n"; - PrintFatalError(Def->getLoc(), "unknown output type"); - assert(false); - llvm_unreachable("unknown output type"); - } - } - - os << " return true;\n"; - os << "}\n"; -} - -static void emitBlasOpts(const RecordKeeper &recordKeeper, raw_ostream &os) { - emitSourceFileHeader("Rewriters", os); - const char *patternNames = "BlasOptPattern"; - const auto &patterns = recordKeeper.getAllDerivedDefinitions(patternNames); - - for (Record *pattern : patterns) { - ListInit *inputs = pattern->getValueAsListInit("inputs"); - std::vector tmps = pattern->getValueAsListOfStrings("tmps"); - ListInit *outputs = pattern->getValueAsListInit("outputs"); - - std::vector inputDags; - for (auto input : *inputs) { - DagInit *dag = dyn_cast(input); - assert(dag); - inputDags.push_back(dag); - } - std::vector outputDags; - for (auto output : *outputs) { - DagInit *dag = dyn_cast(output); - assert(dag); - outputDags.push_back(dag); - } - emitBlasOpt(pattern->getName(), inputDags, tmps, outputDags, os); - } -} +#include "opts.h" static bool EnzymeTableGenMain(raw_ostream &os, RecordKeeper &records) { switch (action) { diff --git a/enzyme/tools/enzyme-tblgen/opts.cpp b/enzyme/tools/enzyme-tblgen/opts.cpp new file mode 100644 index 000000000000..7272289807db --- /dev/null +++ b/enzyme/tools/enzyme-tblgen/opts.cpp @@ -0,0 +1,247 @@ + +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringSet.h" +#include 
"llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/PrettyStackTrace.h" +#include "llvm/Support/Signals.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Main.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +#include "opts.h" + +using namespace llvm; + +void emit_args(DagInit *dag, raw_ostream &os) { + for (size_t i = 0; i < dag->getNumArgs(); ++i) { + if (i > 0) + os << ", "; + if (auto def = dyn_cast(dag->getArg(i))) { + auto Def = def->getDef(); + if (Def->isSubClassOf("Value")) { + os << "tmp" << Def->getValueAsString("name"); + } + } else { + os << dag->getArgNameStr(i); + } + } +} +void emit_types(DagInit *dag, raw_ostream &os) { + for (size_t i = 0; i < dag->getNumArgs(); ++i) { + if (i > 0) + os << ", "; + if (auto def = dyn_cast(dag->getArg(i))) { + auto Def = def->getDef(); + if (Def->isSubClassOf("Value")) { + os << "tmp" << Def->getValueAsString("name") << "->getType()"; + } + } else { + os << dag->getArgNameStr(i) << "->getType()"; + } + } +} + +// class BlasOptPattern< list _inputs, list _tmps, list +// _outputs> { +// list inputs = _inputs; +// // tmp variables will dissapear during the transformation +// // and therefore are not allowed to be read elsewhere +// list tmps = _tmps; +// list outputs = _outputs; +// } +// def first : BlasOptPattern<, +// [ +// (b<"ger"> $layout, $m, $n, $alpha, $x, $incx, $y, $incy, $A, $lda), +// (b<"ger"> $layout, $m, $k, $beta, $v, $incv, $w, $incw, $B, $ldb), +// (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, $alpha, $A, $lda, $B, $ldb, +// $beta, $C, $ldc), +// ], +// ["A", "B"], +// [ +// (Value<1> (b<"dot"> $layout, $n, $v, $incv, $y, $incy)), +// (Value<2> (FMul $alpha Value<1>)), +// (Value<3> (FMul $beta Value<2>)), +// (b<"ger"> $layout, $m, $k, Value<3>, $x, $incx, $w, $incw, $C, $ldc), +// ] +// >; +void emitBlasOpt(StringRef name, std::vector inputs, + std::vector, std::vector outputs, + raw_ostream &os) { + os << 
"bool opt" << name << "(llvm::Function *F, llvm::Module &M) {\n"; + os << " using namespace llvm;\n"; + + StringSet usedArgs{}; + std::vector functions{}; + StringMap> unique_functions{}; + for (auto input : inputs) { + ArrayRef args = input->getArgNames(); + auto Def = cast(input->getOperator())->getDef(); + assert(Def->isSubClassOf("b")); + auto fnc_name = Def->getValueAsString("s"); + // auto fnc_name = input->getNameStr(); + functions.push_back(fnc_name); + unique_functions[fnc_name].push_back(input); + for (auto &arg : args) { + if (usedArgs.count(arg->getValue())) + continue; + os << " Value *" << arg->getValue() << " = nullptr;\n"; + usedArgs.insert(arg->getValue()); + } + } + + os << "\n"; + + for (auto fnc : unique_functions.keys()) { + os << " size_t idx_" << fnc << " = 0;\n"; + } + + os << " // create a vector of calls to delete\n"; + os << " std::vector todelete;\n"; + + os << " for (auto &BB : *F) {\n" + << " for (auto &I : BB) {\n" + << " if (auto *CI = dyn_cast(&I)) {\n" + << " auto name = CI->getCalledFunction()->getName();\n"; + for (auto fnc : unique_functions.keys()) { + os << " if (name == \"" << fnc << "\") {\n"; + std::string tab = " "; + auto fnc_vec = unique_functions[fnc]; + bool multiple = fnc_vec.size() > 1; + os << tab << "assert(idx_" << fnc << " < " << fnc_vec.size() + << " && \"idx out of bounds\");\n"; + os << tab << "std::vector values;\n"; + for (size_t i = 0; i < fnc_vec.size(); ++i) { + if (multiple) { + os << tab << "if (idx_" << fnc << " == " << i << ")\n "; + } + os << tab << "values = {"; + ArrayRef args = fnc_vec[i]->getArgNames(); + bool first = true; + for (auto arg : args) { + os << (first ? 
"" : ", ") << arg->getValue(); + first = false; + } + os << "};\n"; + } + os << tab << "bool set = cmp_or_set(CI, values);\n"; + os << tab << "if (!set) continue;\n"; + os << tab << "llvm::errs() << \"found " << fnc << "\\n\";\n"; + os << tab << "idx_" << fnc << "++;\n" + << tab << "todelete.push_back(CI);\n" + << " }\n"; + } + os << " }\n"; + os << " }\n"; + os << " }\n"; + + // check that all functions have been found + os << " bool found = true;\n"; + for (auto fnc : unique_functions.keys()) { + os << " if (idx_" << fnc << " != " << unique_functions[fnc].size() << ")\n" + << " found = false;\n"; + } + os << " if (!found)\n" + << " return false;\n"; + + os << " llvm::errs() << \"found optimization " << name << "\\n\";\n"; + + // now that we found an optimization to apply, + // we can delete the old calls + os << " for (auto *CI : todelete) {\n" + << " CI->eraseFromParent();\n" + << " }\n"; + + os << " BasicBlock *bb = &F->getEntryBlock();\n" + << " IRBuilder<> B1(bb);\n"; + + for (auto outerOutput : outputs) { + DagInit *output = outerOutput; + auto buffer = std::string(""); + auto Def = cast(output->getOperator())->getDef(); + if (Def->isSubClassOf("Value")) { + assert(output->getNumArgs() == 1); + auto name = Def->getValueAsString("name"); + buffer = (Twine(" Value *tmp") + name + " = ").str(); + // This is just wrapping the actual DagInit in a Value<>. 
+ // So now strip the Value wrapper to handle it in the next if/else + output = cast(output->getArg(0)); + Def = cast(output->getOperator())->getDef(); + } + if (Def->isSubClassOf("Inst")) { + auto name = Def->getValueAsString("name"); + os << buffer << " B1.Create" << name << "("; + emit_args(output, os); + os << ");\n"; + } else if (Def->isSubClassOf("b")) { + auto fnc_name = Def->getValueAsString("s"); + if (unique_functions[fnc_name].size() >= 1) { + // function decl already existed in the module + os << " Function *Fnc_" << fnc_name << " = M.getFunction(\"" + << fnc_name << "\");\n" + << " assert(Fnc_" << fnc_name << ");\n"; + } else { + // if the function decl did not exist, we need to create it + // iff the buffer is empty, we return void, othewise + // the buffer would be equal to 'Value *tmp = ...' + std::string retTy = "Type::getVoidTy(M.getContext())"; + if (!buffer.empty()) { + if (fnc_name.contains("64")) { + retTy = "Type::getDoubleTy(M.getContext())"; + } else { + retTy = "Type::getFloatTy(M.getContext())"; + } + } + os << " FunctionType *FT" << fnc_name << " = FunctionType::get(" + << retTy << ", {"; + emit_types(output, os); + os << "}, false);\n"; + os << " Function *Fnc_" << fnc_name + << " = cast(M.getOrInsertFunction(\"" << fnc_name + << "\", FT" << fnc_name << ").getCallee());\n"; + } + os << buffer << " B1.CreateCall(Fnc_" << fnc_name << ", {"; + emit_args(output, os); + os << "});\n"; + } else { + llvm::errs() << "failed with: " << Def->getName() << "\n"; + PrintFatalError(Def->getLoc(), "unknown output type"); + assert(false); + llvm_unreachable("unknown output type"); + } + } + //// bb->setInsertPoint(insertionPoint); + + //B1.CreateRetVoid(); + + os << " return true;\n"; + os << "}\n"; +} + +void emitBlasOpts(const RecordKeeper &recordKeeper, raw_ostream &os) { + emitSourceFileHeader("Rewriters", os); + const char *patternNames = "BlasOptPattern"; + const auto &patterns = recordKeeper.getAllDerivedDefinitions(patternNames); + + for 
(Record *pattern : patterns) { + ListInit *inputs = pattern->getValueAsListInit("inputs"); + std::vector tmps = pattern->getValueAsListOfStrings("tmps"); + ListInit *outputs = pattern->getValueAsListInit("outputs"); + + std::vector inputDags; + for (auto input : *inputs) { + DagInit *dag = dyn_cast(input); + assert(dag); + inputDags.push_back(dag); + } + std::vector outputDags; + for (auto output : *outputs) { + DagInit *dag = dyn_cast(output); + assert(dag); + outputDags.push_back(dag); + } + emitBlasOpt(pattern->getName(), inputDags, tmps, outputDags, os); + } +} diff --git a/enzyme/tools/enzyme-tblgen/opts.h b/enzyme/tools/enzyme-tblgen/opts.h new file mode 100644 index 000000000000..b6764ae131d3 --- /dev/null +++ b/enzyme/tools/enzyme-tblgen/opts.h @@ -0,0 +1,2 @@ +void emitBlasOpts(const llvm::RecordKeeper &recordKeeper, + llvm::raw_ostream &os); From 304b9ceb695ab74fab3c03a027a48c3864c79357 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 7 Aug 2023 16:30:22 -0400 Subject: [PATCH 08/17] fmt --- enzyme/Enzyme/OptBlas.h | 2 +- enzyme/tools/enzyme-tblgen/opts.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/OptBlas.h b/enzyme/Enzyme/OptBlas.h index 78d001b021e4..928c0c6939bf 100644 --- a/enzyme/Enzyme/OptBlas.h +++ b/enzyme/Enzyme/OptBlas.h @@ -18,7 +18,7 @@ // //===----------------------------------------------------------------------===// // -// This file contains createOptimizeBlas, +// This file contains createOptimizeBlas, // //===----------------------------------------------------------------------===// diff --git a/enzyme/tools/enzyme-tblgen/opts.cpp b/enzyme/tools/enzyme-tblgen/opts.cpp index 7272289807db..85df98981858 100644 --- a/enzyme/tools/enzyme-tblgen/opts.cpp +++ b/enzyme/tools/enzyme-tblgen/opts.cpp @@ -214,7 +214,7 @@ void emitBlasOpt(StringRef name, std::vector inputs, } //// bb->setInsertPoint(insertionPoint); - //B1.CreateRetVoid(); + // B1.CreateRetVoid(); os << " return true;\n"; os << "}\n"; 
From 4de22bc4929de668a5844fdee70e59f2816ce5e8 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 9 Aug 2023 15:49:51 -0400 Subject: [PATCH 09/17] adjust file descriptions --- enzyme/Enzyme/OptBlas.cpp | 6 ++---- enzyme/Enzyme/OptBlas.h | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/enzyme/Enzyme/OptBlas.cpp b/enzyme/Enzyme/OptBlas.cpp index aeb8d35333f2..c38bc3550cc2 100644 --- a/enzyme/Enzyme/OptBlas.cpp +++ b/enzyme/Enzyme/OptBlas.cpp @@ -1,4 +1,4 @@ -//===- PreserveNVVM.cpp - Mark NVVM attributes for preservation. -------===// +//===- OptBlas.cpp - Rewrite BLAS calls for better performance. -------===// // // Enzyme Project // @@ -18,9 +18,7 @@ // //===----------------------------------------------------------------------===// // -// This file contains createPreserveNVVM, a transformation pass that marks -// calls to __nv_* functions, marking them as noinline as implementing the llvm -// intrinsic. +// This file contains code to handle this new blas optimization pass. // //===----------------------------------------------------------------------===// #include diff --git a/enzyme/Enzyme/OptBlas.h b/enzyme/Enzyme/OptBlas.h index 928c0c6939bf..01adfe12f9eb 100644 --- a/enzyme/Enzyme/OptBlas.h +++ b/enzyme/Enzyme/OptBlas.h @@ -1,4 +1,4 @@ -//===- OptBlas.h - Mark NVVM attributes for preservation. -------===// +//===- OptBlas.h - Rewrite BLAS calls for better performance. -------===// // // Enzyme Project // @@ -18,7 +18,7 @@ // //===----------------------------------------------------------------------===// // -// This file contains createOptimizeBlas, +// This file contains code to handle this new blas optimization pass. 
// //===----------------------------------------------------------------------===// From 0ed2504c9030c17953dff3a77f96ba391e351581 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 9 Aug 2023 17:07:10 -0400 Subject: [PATCH 10/17] support new pm --- enzyme/Enzyme/Enzyme.cpp | 10 ++++++---- enzyme/Enzyme/OptBlas.cpp | 14 ++++++++++---- enzyme/Enzyme/OptBlas.h | 5 +++-- enzyme/test/Integration/BlasOpt/first.c | 2 ++ 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 59f72cd5a1bb..01325fe70a48 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -3199,6 +3199,8 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { #endif MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizerPM))); MPM.addPass(EnzymeNewPM(/*PostOpt=*/true)); + // Manuel, new + MPM.addPass(OptimizeBlasNewPM(/*Begin*/ false)); MPM.addPass(PreserveNVVMNewPM(/*Begin*/ false)); #if LLVM_VERSION_MAJOR >= 16 OptimizerPM2.addPass(llvm::GVNPass()); @@ -3491,10 +3493,10 @@ llvmGetPassPluginInfo() { PB.registerPipelineParsingCallback( [](llvm::StringRef Name, llvm::ModulePassManager &MPM, llvm::ArrayRef) { - // if (Name == "blas-opt") { - // MPM.addPass(createOptimizeBlasPass(/*Begin*/ true)); - // return true; - // } + if (Name == "blas-opt") { + MPM.addPass(OptimizeBlasNewPM(/*Begin*/ true)); + return true; + } if (Name == "enzyme") { MPM.addPass(EnzymeNewPM()); return true; diff --git a/enzyme/Enzyme/OptBlas.cpp b/enzyme/Enzyme/OptBlas.cpp index c38bc3550cc2..c1dcea74c424 100644 --- a/enzyme/Enzyme/OptBlas.cpp +++ b/enzyme/Enzyme/OptBlas.cpp @@ -62,11 +62,10 @@ bool optimizeBlas(bool Begin, Function &F) { namespace { class OptimizeBlas final : public ModulePass { + // class OptimizeBlas final : public PassInfoMixin { public: static char ID; // bool Begin; - // OptimizeBlas(bool Begin = true) : ModulePass(ID), Begin(Begin) {} - OptimizeBlas() : ModulePass(ID) {} OptimizeBlas(char &pid) : ModulePass(pid) {} 
void getAnalysisUsage(AnalysisUsage &AU) const override {} @@ -93,8 +92,6 @@ extern "C" void AddOptimizeBlasPass(LLVMPassManagerRef PM, uint8_t Begin) { unwrap(PM)->add(createOptimizeBlasPass((bool)Begin)); } -llvm::AnalysisKey OptimizeBlasNewPM::Key; - bool cmp_or_set(llvm::CallInst *CI, std::vector values) { // first run trough to see if the already set args match. // second run if they do and then we set the nullptr. @@ -119,6 +116,7 @@ bool optimizeFncsWithBlas(llvm::Module &M) { using namespace llvm; + errs() << "asdf\n"; Function *F = M.getFunction("f"); if (F) { errs() << "Found function: " << F->getName() << "\n"; @@ -130,3 +128,11 @@ bool optimizeFncsWithBlas(llvm::Module &M) { return true; } + +OptimizeBlasNewPM::Result +OptimizeBlasNewPM::run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) { + llvm::errs() << "fooBar\n"; + bool changed = optimizeFncsWithBlas(M); + return changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} +llvm::AnalysisKey OptimizeBlasNewPM::Key; diff --git a/enzyme/Enzyme/OptBlas.h b/enzyme/Enzyme/OptBlas.h index 01adfe12f9eb..f354c22b0d5f 100644 --- a/enzyme/Enzyme/OptBlas.h +++ b/enzyme/Enzyme/OptBlas.h @@ -35,16 +35,17 @@ llvm::ModulePass *createOptimizeBlasPass(bool Begin); class OptimizeBlasNewPM final : public llvm::AnalysisInfoMixin { - // friend struct llvm::AnalysisInfoMixin; + friend struct llvm::AnalysisInfoMixin; private: bool Begin; static llvm::AnalysisKey Key; public: + using Result = llvm::PreservedAnalyses; OptimizeBlasNewPM(bool Begin) : Begin(Begin) {} - bool runOnModule(llvm::Module &M); + Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM); static bool isRequired() { return true; } }; diff --git a/enzyme/test/Integration/BlasOpt/first.c b/enzyme/test/Integration/BlasOpt/first.c index a51ac37c6dc2..37f98b43bb35 100644 --- a/enzyme/test/Integration/BlasOpt/first.c +++ b/enzyme/test/Integration/BlasOpt/first.c @@ -1,3 +1,5 @@ +#include + enum CBLAS_ORDER {CblasRowMajor=101, 
CblasColMajor=102}; enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113}; From f1d2291370f7de9bc0f08d3ec20d14bf8484daa9 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 9 Aug 2023 17:37:02 -0400 Subject: [PATCH 11/17] add blasOpt in more cases --- enzyme/Enzyme/Enzyme.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index d78eb3b6ff9f..7e89eb77df3c 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -3140,6 +3140,7 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { auto loadPass = [prePass](ModulePassManager &MPM) #endif { + MPM.addPass(OptimizeBlasNewPM(/*Begin*/ false)); MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true)); #if LLVM_VERSION_MAJOR >= 12 @@ -3192,6 +3193,18 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { PB.registerPipelineStartEPCallback(loadPass); #endif +#if LLVM_VERSION_MAJOR >= 12 + auto optBLAS = [](ModulePassManager &MPM, OptimizationLevel) +#else + auto optBLAS = [](ModulePassManager &MPM) +#endif + { MPM.addPass(OptimizeBlasNewPM(/*Begin*/ true)); }; + + // We should register at vectorizer start for consistency, however, + // that requires a functionpass, and we have a modulepass. 
+ // PB.registerVectorizerStartEPCallback(loadPass); + PB.registerPipelineStartEPCallback(optBLAS); + #if LLVM_VERSION_MAJOR >= 12 auto loadNVVM = [](ModulePassManager &MPM, OptimizationLevel) #else @@ -3204,6 +3217,7 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { // PB.registerVectorizerStartEPCallback(loadPass); PB.registerPipelineStartEPCallback(loadNVVM); #if LLVM_VERSION_MAJOR >= 15 + PB.registerFullLinkTimeOptimizationEarlyEPCallback(optBLAS); PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadNVVM); auto preLTOPass = [](ModulePassManager &MPM, OptimizationLevel Level) { From 115df3f53a1f41fbdebaef72cb17e6b7152a8cde Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 11 Aug 2023 10:52:28 -0400 Subject: [PATCH 12/17] wip --- enzyme/Enzyme/Clang/EnzymePassLoader.cpp | 7 +++++++ enzyme/Enzyme/Enzyme.cpp | 26 ++++++++---------------- enzyme/Enzyme/OptBlas.cpp | 4 ++-- enzyme/Enzyme/OptBlas.h | 7 +++---- enzyme/test/Integration/BlasOpt/first.c | 4 ++-- 5 files changed, 22 insertions(+), 26 deletions(-) diff --git a/enzyme/Enzyme/Clang/EnzymePassLoader.cpp b/enzyme/Enzyme/Clang/EnzymePassLoader.cpp index b77608f06c81..c689f4af2ed1 100644 --- a/enzyme/Enzyme/Clang/EnzymePassLoader.cpp +++ b/enzyme/Enzyme/Clang/EnzymePassLoader.cpp @@ -34,6 +34,7 @@ #include "llvm/Transforms/Scalar/GVN.h" #include "../Enzyme.h" +#include "../OptBlas.h" #include "../PreserveNVVM.h" using namespace llvm; @@ -57,6 +58,10 @@ static void loadNVVMPass(const PassManagerBuilder &Builder, legacy::PassManagerBase &PM) { PM.add(createPreserveNVVMPass(/*Begin=*/true)); } +static void loadBLASPass(const PassManagerBuilder &Builder, + legacy::PassManagerBase &PM) { + PM.add(createOptimizeBlasPass(/*Begin=*/true)); +} // These constructors add our pass to a list of global extensions. 
static RegisterStandardPasses @@ -66,6 +71,8 @@ static RegisterStandardPasses static RegisterStandardPasses clangtoolLoader_OEarly(PassManagerBuilder::EP_EarlyAsPossible, loadNVVMPass); +static RegisterStandardPasses + clangtoolLoader_Ox(PassManagerBuilder::EP_VectorizerStart, loadBLASPass); static void loadLTOPass(const PassManagerBuilder &Builder, legacy::PassManagerBase &PM) { diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 7e89eb77df3c..63e28b12963a 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -3140,7 +3140,7 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { auto loadPass = [prePass](ModulePassManager &MPM) #endif { - MPM.addPass(OptimizeBlasNewPM(/*Begin*/ false)); + // MPM.addPass(OptimizeBlasNewPM(/*Begin*/ false)); MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true)); #if LLVM_VERSION_MAJOR >= 12 @@ -3163,8 +3163,6 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { #endif MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizerPM))); MPM.addPass(EnzymeNewPM(/*PostOpt=*/true)); - // Manuel, new - MPM.addPass(OptimizeBlasNewPM(/*Begin*/ false)); MPM.addPass(PreserveNVVMNewPM(/*Begin*/ false)); #if LLVM_VERSION_MAJOR >= 16 OptimizerPM2.addPass(llvm::GVNPass()); @@ -3193,18 +3191,6 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { PB.registerPipelineStartEPCallback(loadPass); #endif -#if LLVM_VERSION_MAJOR >= 12 - auto optBLAS = [](ModulePassManager &MPM, OptimizationLevel) -#else - auto optBLAS = [](ModulePassManager &MPM) -#endif - { MPM.addPass(OptimizeBlasNewPM(/*Begin*/ true)); }; - - // We should register at vectorizer start for consistency, however, - // that requires a functionpass, and we have a modulepass. 
- // PB.registerVectorizerStartEPCallback(loadPass); - PB.registerPipelineStartEPCallback(optBLAS); - #if LLVM_VERSION_MAJOR >= 12 auto loadNVVM = [](ModulePassManager &MPM, OptimizationLevel) #else @@ -3217,7 +3203,6 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { // PB.registerVectorizerStartEPCallback(loadPass); PB.registerPipelineStartEPCallback(loadNVVM); #if LLVM_VERSION_MAJOR >= 15 - PB.registerFullLinkTimeOptimizationEarlyEPCallback(optBLAS); PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadNVVM); auto preLTOPass = [](ModulePassManager &MPM, OptimizationLevel Level) { @@ -3467,10 +3452,12 @@ llvmGetPassPluginInfo() { #ifdef ENZYME_RUNPASS augmentPassBuilder(PB); #endif + llvm::errs() << "CCCCCC\n"; PB.registerPipelineParsingCallback( [](llvm::StringRef Name, llvm::ModulePassManager &MPM, llvm::ArrayRef) { if (Name == "blas-opt") { + llvm::errs() << "AAAA\n"; MPM.addPass(OptimizeBlasNewPM(/*Begin*/ true)); return true; } @@ -3486,7 +3473,9 @@ llvmGetPassPluginInfo() { MPM.addPass(TypeAnalysisPrinterNewPM()); return true; } - return false; + llvm::errs() << "BBBB\n"; + return true; + // return false; }); PB.registerPipelineParsingCallback( [](llvm::StringRef Name, llvm::FunctionPassManager &FPM, @@ -3495,7 +3484,8 @@ llvmGetPassPluginInfo() { FPM.addPass(ActivityAnalysisPrinterNewPM()); return true; } - return false; + return true; + // return false; }); }}; } diff --git a/enzyme/Enzyme/OptBlas.cpp b/enzyme/Enzyme/OptBlas.cpp index c1dcea74c424..d264fc8aaf61 100644 --- a/enzyme/Enzyme/OptBlas.cpp +++ b/enzyme/Enzyme/OptBlas.cpp @@ -131,8 +131,8 @@ bool optimizeFncsWithBlas(llvm::Module &M) { OptimizeBlasNewPM::Result OptimizeBlasNewPM::run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) { - llvm::errs() << "fooBar\n"; + llvm::errs() << "newPM opt-blas\n"; bool changed = optimizeFncsWithBlas(M); return changed ? 
PreservedAnalyses::none() : PreservedAnalyses::all(); } -llvm::AnalysisKey OptimizeBlasNewPM::Key; +// llvm::AnalysisKey OptimizeBlasNewPM::Key; diff --git a/enzyme/Enzyme/OptBlas.h b/enzyme/Enzyme/OptBlas.h index f354c22b0d5f..31d16b7a4f12 100644 --- a/enzyme/Enzyme/OptBlas.h +++ b/enzyme/Enzyme/OptBlas.h @@ -33,13 +33,12 @@ bool optimizeFncsWithBlas(llvm::Module &M); llvm::ModulePass *createOptimizeBlasPass(bool Begin); -class OptimizeBlasNewPM final - : public llvm::AnalysisInfoMixin { - friend struct llvm::AnalysisInfoMixin; +class OptimizeBlasNewPM final : public llvm::PassInfoMixin { + friend struct llvm::PassInfoMixin; private: bool Begin; - static llvm::AnalysisKey Key; + // static llvm::AnalysisKey Key; public: using Result = llvm::PreservedAnalyses; diff --git a/enzyme/test/Integration/BlasOpt/first.c b/enzyme/test/Integration/BlasOpt/first.c index 37f98b43bb35..259a43bfaa98 100644 --- a/enzyme/test/Integration/BlasOpt/first.c +++ b/enzyme/test/Integration/BlasOpt/first.c @@ -13,7 +13,7 @@ void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA const int lda, const double *B, const int ldb, const double beta, double *C, const int ldc); -void g(double *restrict x, double *restrict y, double *restrict v, double *restrict w, double *restrict C) { +void f(double *restrict x, double *restrict y, double *restrict v, double *restrict w, double *restrict C) { double A[] = {0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00}; @@ -38,7 +38,7 @@ int main() { double C[] = {0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00}; - g(x,y,v,w,C); + f(x,y,v,w,C); for (int i = 0; i < 9; i++) printf("%f\n", C[i]); } From 30073cc0929ce66df623c9843238fffb1d980884 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 19 Oct 2023 21:58:58 -0400 Subject: [PATCH 13/17] fix var updates, fix test ci --- enzyme/Enzyme/Clang/EnzymePassLoader.cpp | 3 +- enzyme/Enzyme/OptBlas.cpp | 5 ++- enzyme/Enzyme/Utils.cpp | 2 +- 
enzyme/test/Enzyme/BlasOpt/CMakeLists.txt | 12 +++++++ enzyme/test/Enzyme/CMakeLists.txt | 1 + enzyme/tools/enzyme-tblgen/opts.cpp | 38 ++++++++++++++++++++--- 6 files changed, 53 insertions(+), 8 deletions(-) create mode 100644 enzyme/test/Enzyme/BlasOpt/CMakeLists.txt diff --git a/enzyme/Enzyme/Clang/EnzymePassLoader.cpp b/enzyme/Enzyme/Clang/EnzymePassLoader.cpp index 03966778bfb3..b54dc6fce88f 100644 --- a/enzyme/Enzyme/Clang/EnzymePassLoader.cpp +++ b/enzyme/Enzyme/Clang/EnzymePassLoader.cpp @@ -77,7 +77,8 @@ static RegisterStandardPasses clangtoolLoader_OEarly(PassManagerBuilder::EP_EarlyAsPossible, loadNVVMPass); static RegisterStandardPasses - clangtoolLoader_Ox(PassManagerBuilder::EP_VectorizerStart, loadBLASPass); + clangtoolBlasLoader_Ox(PassManagerBuilder::EP_VectorizerStart, + loadBLASPass); static void loadLTOPass(const PassManagerBuilder &Builder, legacy::PassManagerBase &PM) { diff --git a/enzyme/Enzyme/OptBlas.cpp b/enzyme/Enzyme/OptBlas.cpp index d264fc8aaf61..d713513f9933 100644 --- a/enzyme/Enzyme/OptBlas.cpp +++ b/enzyme/Enzyme/OptBlas.cpp @@ -99,8 +99,11 @@ bool cmp_or_set(llvm::CallInst *CI, std::vector values) { if (values[i] == nullptr) { continue; } - if (CI->getArgOperand(i) != values[i]) + if (CI->getArgOperand(i) != values[i]) { + llvm::errs() << "cmp_or_set: " << i << " " << *CI->getArgOperand(i) + << " != " << *values[i] << "\n"; return false; + } } for (size_t i = 0; i < values.size(); ++i) { if (values[i] == nullptr) { diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 5e68f0b2bcdb..26616a32fd13 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2354,7 +2354,7 @@ std::optional extractBLAS(llvm::StringRef in) llvm::Optional extractBLAS(llvm::StringRef in) #endif { - const char *extractable[] = {"dot", "scal", "axpy", "gemv", "gemm", "spmv"}; + const char *extractable[] = {"dot", "scal", "axpy", "gemv", "gemm", "spmv", "ger"}; const char *floatType[] = {"s", "d"}; // c, z const char *prefixes[] = 
{"" /*Fortran*/, "cblas_"}; const char *suffixes[] = {"", "_", "64_", "_64_"}; diff --git a/enzyme/test/Enzyme/BlasOpt/CMakeLists.txt b/enzyme/test/Enzyme/BlasOpt/CMakeLists.txt new file mode 100644 index 000000000000..1ced95e82b4a --- /dev/null +++ b/enzyme/test/Enzyme/BlasOpt/CMakeLists.txt @@ -0,0 +1,12 @@ +# Run regression and unit tests +add_lit_testsuite(check-enzyme-blasopt "Running enzyme blas optimization regression tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${ENZYME_TEST_DEPS} + ARGS -v +) + +set_target_properties(check-enzyme-blasopt PROPERTIES FOLDER "Tests") + +# add_lit_testsuites(ENZYME ${CMAKE_CURRENT_SOURCE_DIR} +# DEPENDS ${ENZYME_TEST_DEPS} +# ) diff --git a/enzyme/test/Enzyme/CMakeLists.txt b/enzyme/test/Enzyme/CMakeLists.txt index 0187644409f2..f54a64049f03 100644 --- a/enzyme/test/Enzyme/CMakeLists.txt +++ b/enzyme/test/Enzyme/CMakeLists.txt @@ -6,6 +6,7 @@ add_subdirectory(ForwardModeSplit) add_subdirectory(ForwardModeVector) add_subdirectory(BatchMode) add_subdirectory(ProbProg) +add_subdirectory(BlasOpt) # Run regression and unit tests add_lit_testsuite(check-enzyme "Running enzyme regression tests" diff --git a/enzyme/tools/enzyme-tblgen/opts.cpp b/enzyme/tools/enzyme-tblgen/opts.cpp index 85df98981858..6eb29d808123 100644 --- a/enzyme/tools/enzyme-tblgen/opts.cpp +++ b/enzyme/tools/enzyme-tblgen/opts.cpp @@ -100,13 +100,23 @@ void emitBlasOpt(StringRef name, std::vector inputs, os << " // create a vector of calls to delete\n"; os << " std::vector todelete;\n"; + os << " int num_calls = 0;\n"; os << " for (auto &BB : *F) {\n" << " for (auto &I : BB) {\n" << " if (auto *CI = dyn_cast(&I)) {\n" - << " auto name = CI->getCalledFunction()->getName();\n"; + << " num_calls++;\n" + << " auto CIname = CI->getCalledFunction()->getName();\n" + << " auto blasOption = extractBLAS(CIname);\n" + << "#if LLVM_VERSION_MAJOR >= 16\n" + << " if (!blasOption.has_value()) continue;\n" + << " auto blas = blasOption.value();\n" + << "#else\n" + << " if 
(!blasOption.hasValue()) continue;\n" + << " auto blas = blasOption.getValue();\n" + << "#endif\n"; for (auto fnc : unique_functions.keys()) { - os << " if (name == \"" << fnc << "\") {\n"; + os << " if (blas.function == \"" << fnc << "\") {\n"; std::string tab = " "; auto fnc_vec = unique_functions[fnc]; bool multiple = fnc_vec.size() > 1; @@ -127,12 +137,28 @@ void emitBlasOpt(StringRef name, std::vector inputs, os << "};\n"; } os << tab << "bool set = cmp_or_set(CI, values);\n"; - os << tab << "if (!set) continue;\n"; + os << tab << "if (!set) {\n"; + os << tab << " llvm::errs() << \"args missmatch: " << fnc << "\";\n"; + os << tab << " continue;\n"; + os << tab << "}\n"; + for (size_t i = 0; i < fnc_vec[0]->getNumArgs(); i++) { + os << tab << "values[" << i << "] = CI->getArgOperand(" << i << ");\n"; + } + for (size_t i = 0; i < fnc_vec.size(); ++i) { + ArrayRef args = fnc_vec[i]->getArgNames(); + size_t pos = 0; + for (auto arg : args) { + os << tab << arg->getValue() << " = CI->getArgOperand(" << i << ");\n"; + pos++; + } + } os << tab << "llvm::errs() << \"found " << fnc << "\\n\";\n"; os << tab << "idx_" << fnc << "++;\n" << tab << "todelete.push_back(CI);\n" + << tab << "continue;\n" << " }\n"; } + os << " llvm::errs() << \"unhandled: \" << blas.function << \"\\n\";\n"; os << " }\n"; os << " }\n"; os << " }\n"; @@ -143,8 +169,10 @@ void emitBlasOpt(StringRef name, std::vector inputs, os << " if (idx_" << fnc << " != " << unique_functions[fnc].size() << ")\n" << " found = false;\n"; } - os << " if (!found)\n" - << " return false;\n"; + os << " if (!found) {\n" + << " llvm::errs() << \"num calls: \" << num_calls << \"\\n\";\n" + << " return false;\n" + << " }\n"; os << " llvm::errs() << \"found optimization " << name << "\\n\";\n"; From ec50cd46d2fb2b99221dc669bc3638b1e45d65e7 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 20 Oct 2023 13:39:37 -0400 Subject: [PATCH 14/17] more tests, cleanup codegen --- .../test/Enzyme/BlasOpt/PlusEqualIntoZero.ll 
| 149 ++++++++++++++++++ enzyme/test/Enzyme/BlasOpt/blasCOpt.ll | 2 +- enzyme/tools/enzyme-tblgen/opts.cpp | 3 - 3 files changed, 150 insertions(+), 4 deletions(-) create mode 100644 enzyme/test/Enzyme/BlasOpt/PlusEqualIntoZero.ll diff --git a/enzyme/test/Enzyme/BlasOpt/PlusEqualIntoZero.ll b/enzyme/test/Enzyme/BlasOpt/PlusEqualIntoZero.ll new file mode 100644 index 000000000000..5277141bf332 --- /dev/null +++ b/enzyme/test/Enzyme/BlasOpt/PlusEqualIntoZero.ll @@ -0,0 +1,149 @@ +;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -blas-opt -S | FileCheck %s; fi +;RUN: %opt < %s %newLoadEnzyme -passes="blas-opt" -S | FileCheck %s + +declare void @dgemv_64_(i8, i64, i64, double, double* nocapture readonly, i64, double* nocapture readonly, i64, double, double* nocapture, i64) +declare void @dgemm_64_(i8, i8, i64, i64, i64, double, double* nocapture readonly, i64, double* nocapture readonly, i64, double, double* nocapture, i64) +declare void @daxpy(i32, double, double* nocapture readonly, i32, double* nocapture, i32) + +declare void @llvm.memset.p0.i64(double* nocapture writeonly, i8, i64, i1) +declare void @llvm.memset.inline.p0.i64(double* nocapture writeonly, i8, i64, i1) + + +; y is zeroed, so we can replace the beta arg with const 0.0 +define void @h(double* noalias %C, double %alpha, double %beta, double* %x, double *%y, double* %v, double *%w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb) { +entry: + %A = alloca double, align 16 + store double 0.000000e+00, double* %A + %B = alloca double, align 16 + store double 0.000000e+00, double* %B + %size = mul i64 %incy, %n + call void @llvm.memset.p0.i64(double* %y, i8 0, i64 %size, i1 false) + call void @dgemv_64_(i8 %transa, i64 %m, i64 %n, double %alpha, double* %A, i64 %lda, double* %x, i64 %incx, double %beta, double* %y, i64 %incy) + ret void +} + +define void @h2(double* noalias %C, double %alpha, double %beta, double* %x, double *%y, 
double* %v, double *%w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb) { +entry: + %A = alloca double, align 16 + store double 0.000000e+00, double* %A + %B = alloca double, align 16 + store double 0.000000e+00, double* %B + %size = mul i64 %incy, %n + call void @llvm.memset.p0.i64(double* %y, i8 0, i64 %size, i1 false) + store double 1.000000e+00, double* %y + call void @dgemv_64_(i8 %transa, i64 %m, i64 %n, double %alpha, double* %A, i64 %lda, double* %x, i64 %incx, double %beta, double* %y, i64 %incy) + ret void +} + +define void @h3(double* noalias %C, double %alpha, double %beta, double* %x, double *%y, double* %v, double *%w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb) { +entry: + %A = alloca double, align 16 + store double 0.000000e+00, double* %A + %B = alloca double, align 16 + store double 0.000000e+00, double* %B + call void @llvm.memset.p0.i64(double* %y, i8 0, i64 2, i1 false) + call void @dgemv_64_(i8 %transa, i64 %m, i64 %n, double %alpha, double* %A, i64 %lda, double* %x, i64 %incx, double %beta, double* %y, i64 %incy) + ret void +} + +; x is zeroed, so y = alpha * x + y is a no-op, so we remove the call to daxpy +define void @g(double* noalias %C, double %alpha, double %beta, double* %x, double *%y, double* %v, double *%w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb) { +entry: + %A = alloca double, align 16 + store double 0.000000e+00, double* %A + %B = alloca double, align 16 + store double 0.000000e+00, double* %B + %size = mul i64 %incx, %n + call void @llvm.memset.p0.i64(double* %x, i8 0, i64 %size, i1 false) + call void @daxpy(i64 %n, double %alpha, double* %x, i64 %incx, double* %y, i64 %incy) + ret void +} + + +; C is zeroed, so we can replace the beta arg with const 0.0 +define void @f(double* noalias %C, double 
%alpha, double %beta, double* %x, double *%y, double* %v, double *%w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb) { +entry: + %A = alloca double, align 16 + store double 0.000000e+00, double* %A + %B = alloca double, align 16 + store double 0.000000e+00, double* %B + %size = mul i64 %ldc, %n + call void @llvm.memset.p0.i64(double* %C, i8 0, i64 %size, i1 false) + call void @dgemm_64_(i8 %transa, i8 %transb, i64 %m, i64 %n, i64 %p, double %alpha, double* %A, i64 %lda, double* %B, i64 %ldb, double %beta, double* %C, i64 %ldc) + ret void +} + +; C is zeroed, so we can replace the beta arg with const 0.0 +; use inline memset here to check we still recognize it +define void @f2(double* noalias %C, double %alpha, double %beta, double* %x, double *%y, double* %v, double *%w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb) { +entry: + %A = alloca double, align 16 + store double 0.000000e+00, double* %A + %B = alloca double, align 16 + store double 0.000000e+00, double* %B + %size = mul i64 %ldc, %n + call void @llvm.memset.inline.p0.i64(double* %C, i8 0, i64 %size, i1 false) + call void @dgemm_64_(i8 %transa, i8 %transb, i64 %m, i64 %n, i64 %p, double %alpha, double* %A, i64 %lda, double* %B, i64 %ldb, double %beta, double* %C, i64 %ldc) + ret void +} + +; CHECK: define void @h(double* noalias %C, double %alpha, double %beta, double* %x, double* %y, double* %v, double* %w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb) +; CHECK-NEXT: entry: +; CHECK-NEXT: %A = alloca double, align 16 +; CHECK-NEXT: store double 0.000000e+00, double* %A +; CHECK-NEXT: %B = alloca double, align 16 +; CHECK-NEXT: store double 0.000000e+00, double* %B +; CHECK-NEXT: call void @dgemv_64_(i8 %transa, i64 %m, i64 %n, double %alpha, double* %A, i64 %lda, double* %x, i64 
%incx, double 0.0, double* %y, i64 %incy) +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define void @h2(double* noalias %C, double %alpha, double %beta, double* %x, double* %y, double* %v, double* %w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb) +; CHECK-NEXT: entry: +; CHECK-NEXT: %A = alloca double, align 16 +; CHECK-NEXT: store double 0.000000e+00, double* %A +; CHECK-NEXT: %B = alloca double, align 16 +; CHECK-NEXT: store double 0.000000e+00, double* %B +; CHECK-NEXT: call void @dgemv_64_(i8 %transa, i64 %m, i64 %n, double %alpha, double* %A, i64 %lda, double* %x, i64 %incx, double 0.0, double* %y, i64 %incy) +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define void @h3(double* noalias %C, double %alpha, double %beta, double* %x, double* %y, double* %v, double* %w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb) +; CHECK-NEXT: entry: +; CHECK-NEXT: %A = alloca double, align 16 +; CHECK-NEXT: store double 0.000000e+00, double* %A +; CHECK-NEXT: %B = alloca double, align 16 +; CHECK-NEXT: store double 0.000000e+00, double* %B +; CHECK-NEXT: call void @dgemv_64_(i8 %transa, i64 %m, i64 %n, double %alpha, double* %A, i64 %lda, double* %x, i64 %incx, double 0.0, double* %y, i64 %incy) +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define void @g(double* noalias %C, double %alpha, double %beta, double* %x, double* %y, double* %v, double* %w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb) +; CHECK-NEXT: entry: +; CHECK-NEXT: %A = alloca double, align 16 +; CHECK-NEXT: store double 0.000000e+00, double* %A +; CHECK-NEXT: %B = alloca double, align 16 +; CHECK-NEXT: store double 0.000000e+00, double* %B +; CHECK-NEXT: %size = mul i64 %incx, %n +; CHECK-NEXT: call void @llvm.memset.p0.i64(double* %x, i8 0, i64 %size, i1 false) +; 
CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define void @f(double* noalias %C, double %alpha, double %beta, double* %x, double* %y, double* %v, double* %w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb) +; CHECK-NEXT: entry: +; CHECK-NEXT: %A = alloca double, align 16 +; CHECK-NEXT: store double 0.000000e+00, double* %A +; CHECK-NEXT: %B = alloca double, align 16 +; CHECK-NEXT: store double 0.000000e+00, double* %B +; CHECK-NEXT: call void @dgemm_64_(i8 %transa, i8 %transb, i64 %m, i64 %n, i64 %p, double %alpha, double* %A, i64 %lda, double* %B, i64 %ldb, double 0.0, double* %C, i64 %ldc) +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define void @f2(double* noalias %C, double %alpha, double %beta, double* %x, double* %y, double* %v, double* %w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb) +; CHECK-NEXT: entry: +; CHECK-NEXT: %A = alloca double, align 16 +; CHECK-NEXT: store double 0.000000e+00, double* %A +; CHECK-NEXT: %B = alloca double, align 16 +; CHECK-NEXT: store double 0.000000e+00, double* %B +; CHECK-NEXT: call void @dgemm_64_(i8 %transa, i8 %transb, i64 %m, i64 %n, i64 %p, double %alpha, double* %A, i64 %lda, double* %B, i64 %ldb, double 0.0, double* %C, i64 %ldc) +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/BlasOpt/blasCOpt.ll b/enzyme/test/Enzyme/BlasOpt/blasCOpt.ll index 2b0aca7fcdb4..7ae2afb434eb 100644 --- a/enzyme/test/Enzyme/BlasOpt/blasCOpt.ll +++ b/enzyme/test/Enzyme/BlasOpt/blasCOpt.ll @@ -11,8 +11,8 @@ entry: store double 0.000000e+00, double* %A %B = alloca double, align 16 store double 0.000000e+00, double* %B - call void @dger_64_(i64 %m, i64 %n, double %alpha, double* %x, i64 %incx, double* %y, i64 %incy, double* %A, i64 %lda) call void @dger_64_(i64 %n, i64 %p, double %beta, double* %v, i64 %incv, double* %w, i64 %incw, double* %B, i64 %ldb) + call 
void @dger_64_(i64 %m, i64 %n, double %alpha, double* %x, i64 %incx, double* %y, i64 %incy, double* %A, i64 %lda) call void @dgemm_64_(i8 %transa, i8 %transb, i64 %m, i64 %n, i64 %p, double %alpha, double* %A, i64 %lda, double* %B, i64 %ldb, double %beta, double* %C, i64 %ldc) ;%ptr = bitcast i8* %A to double* ;store double 0.0000000e+00, double* %ptr, align 8 diff --git a/enzyme/tools/enzyme-tblgen/opts.cpp b/enzyme/tools/enzyme-tblgen/opts.cpp index 6eb29d808123..68e9840d94c0 100644 --- a/enzyme/tools/enzyme-tblgen/opts.cpp +++ b/enzyme/tools/enzyme-tblgen/opts.cpp @@ -141,9 +141,6 @@ void emitBlasOpt(StringRef name, std::vector inputs, os << tab << " llvm::errs() << \"args missmatch: " << fnc << "\";\n"; os << tab << " continue;\n"; os << tab << "}\n"; - for (size_t i = 0; i < fnc_vec[0]->getNumArgs(); i++) { - os << tab << "values[" << i << "] = CI->getArgOperand(" << i << ");\n"; - } for (size_t i = 0; i < fnc_vec.size(); ++i) { ArrayRef args = fnc_vec[i]->getArgNames(); size_t pos = 0; From c2b4c1e43fa009074b9e1d85eece4e8e1409af46 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 20 Oct 2023 16:22:03 -0400 Subject: [PATCH 15/17] first zeroedMem optimizations --- enzyme/Enzyme/BlasOpts.td | 36 +++++++++++++++++++++++++ enzyme/Enzyme/Utils.cpp | 42 +++++++++++++++++++++++++++++ enzyme/Enzyme/Utils.h | 3 +++ enzyme/tools/enzyme-tblgen/opts.cpp | 27 +++---------------- 4 files changed, 85 insertions(+), 23 deletions(-) diff --git a/enzyme/Enzyme/BlasOpts.td b/enzyme/Enzyme/BlasOpts.td index f2ae4846af3f..308f5128202e 100644 --- a/enzyme/Enzyme/BlasOpts.td +++ b/enzyme/Enzyme/BlasOpts.td @@ -11,6 +11,42 @@ class Value { string name = _i; } +class ZeroInst : Inst<"blas">; +def ZeroVec : ZeroInst; +def ZeroMat : ZeroInst; + +def zeroOne : BlasOptPattern< + [ + (ZeroVec $n, $y, $incy), + (b<"gemv"> $transa, $m, $n, $alpha, $A, $lda, $x, $incx, $beta, $y, $incy), + ], + [], + [ + (b<"gemv"> $transa, $m, $n, $alpha, $A, $lda, $x, $incx, Constant<"0.0">, $y, 
$incy), + ] + >; +def zeroTwo : BlasOptPattern< + [ + (ZeroVec $n, $x, $incx), + (b<"axpy"> $n, $alpha, $x, $incx, $y, $incy), + ], + [], + [ + // y = alpha * x + y becomes no-op if x is zero. + (ZeroVec $n, $x, $incx), + ] + >; +def zeroThree : BlasOptPattern< + [ + (ZeroMat $n, $C, $ldc), + (b<"gemm"> $transa, $transb, $m, $n, $k, $alpha, $A, $lda, $B, $ldb, $beta, $C, $ldc), + ], + [], + [ + (b<"gemm"> $transa, $transb, $m, $n, $k, $alpha, $A, $lda, $B, $ldb, Constant<"0.0">, $C, $ldc), + ] + >; + def first : BlasOptPattern< [ (b<"ger"> $m, $n, $alpha, $x, $incx, $y, $incy, $A, $lda), diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 26616a32fd13..20048d8b78d0 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2677,6 +2677,48 @@ SmallVector get_blas_row(llvm::IRBuilder<> &B, return toreturn; } +bool isConstVecZeroing(llvm::CallInst *CI, llvm::Value *y, llvm::Value *n, + llvm::Value *incy) { + // assert CI is memset. + if (isa(CI)) { + auto II = cast(CI); + if (II->getIntrinsicID() != Intrinsic::memset) + return false; + } else { + auto name = CI->getCalledFunction()->getName(); + if (!name.startswith("llvm.memset")) + return false; + } + + if (!isa(n) || !isa(incy)) + return false; + + //declare void @llvm.memset.p0.i32(ptr , i8 , i32 , i1 ) + auto memset_len = CI->getArgOperand(2); + if (!isa(memset_len)) + return false; + auto memset_lenC = cast(memset_len); + auto nC = cast(n); + auto incyC = cast(incy); + if (memset_lenC->getZExtValue() < nC->getZExtValue() * incyC->getZExtValue()) + return false; + return true; + +} + +// Used for BlasOpts +Value *isVecZeroing(IRBuilder<> &B, llvm::CallInst *CI, llvm::Value *y, llvm::Value *n, llvm::Value *incy) { + Value *size = B.CreateMul(n, incy); + // assert CI is memset. 
+ + // conservative + return ConstantInt::getFalse(CI->getContext()); +} +llvm::Value *isMatZeroing() { + // conservative + return ConstantInt::getFalse(CI->getContext()); +} + // return how many Special pointers are in T (count > 0), // and if there is anything else in T (all == false) CountTrackedPointers::CountTrackedPointers(Type *T) { diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 1fef2c7530b5..7c3193d5abbc 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1677,6 +1677,9 @@ get_blas_row(llvm::IRBuilder<> &B, llvm::ArrayRef trans, llvm::ArrayRef row, llvm::ArrayRef col, bool byRef, bool cublas); +llvm::Value *isVecZeroing(llvm::IRBuilder<> &B, llvm::CallInst *CI, llvm::Value *y, llvm::Value *n, llvm::Value *incy); +llvm::Value *isMatZeroing(llvm::IRBuilder<> &B, llvm::CallInst *CI, llvm::Value *C, llvm::Value *n, llvm::Value *ldc); + // Parameter attributes from the original function/call that // we should preserve on the primal of the derivative code. 
static inline llvm::Attribute::AttrKind PrimalParamAttrsToPreserve[] = { diff --git a/enzyme/tools/enzyme-tblgen/opts.cpp b/enzyme/tools/enzyme-tblgen/opts.cpp index 68e9840d94c0..3fada470900c 100644 --- a/enzyme/tools/enzyme-tblgen/opts.cpp +++ b/enzyme/tools/enzyme-tblgen/opts.cpp @@ -44,29 +44,8 @@ void emit_types(DagInit *dag, raw_ostream &os) { } } -// class BlasOptPattern< list _inputs, list _tmps, list -// _outputs> { -// list inputs = _inputs; -// // tmp variables will dissapear during the transformation -// // and therefore are not allowed to be read elsewhere -// list tmps = _tmps; -// list outputs = _outputs; -// } -// def first : BlasOptPattern<, -// [ -// (b<"ger"> $layout, $m, $n, $alpha, $x, $incx, $y, $incy, $A, $lda), -// (b<"ger"> $layout, $m, $k, $beta, $v, $incv, $w, $incw, $B, $ldb), -// (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, $alpha, $A, $lda, $B, $ldb, -// $beta, $C, $ldc), -// ], -// ["A", "B"], -// [ -// (Value<1> (b<"dot"> $layout, $n, $v, $incv, $y, $incy)), -// (Value<2> (FMul $alpha Value<1>)), -// (Value<3> (FMul $beta Value<2>)), -// (b<"ger"> $layout, $m, $k, Value<3>, $x, $incx, $w, $incw, $C, $ldc), -// ] -// >; + +// (ZeroMat $n, $C, $ldc), void emitBlasOpt(StringRef name, std::vector inputs, std::vector, std::vector outputs, raw_ostream &os) { @@ -212,6 +191,8 @@ void emitBlasOpt(StringRef name, std::vector inputs, // iff the buffer is empty, we return void, othewise // the buffer would be equal to 'Value *tmp = ...' 
std::string retTy = "Type::getVoidTy(M.getContext())"; + // this is wrong, needs to be decided at compile time, not codegen time, + // based on the blas fnc names we see from the user if (!buffer.empty()) { if (fnc_name.contains("64")) { retTy = "Type::getDoubleTy(M.getContext())"; From 55c4bfe14e39d5be2e4bcccad22fe6d18639ccdd Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 23 Oct 2023 21:24:30 -0400 Subject: [PATCH 16/17] add two scal opts --- enzyme/Enzyme/BlasOpts.td | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/enzyme/Enzyme/BlasOpts.td b/enzyme/Enzyme/BlasOpts.td index 308f5128202e..dda9105c5a60 100644 --- a/enzyme/Enzyme/BlasOpts.td +++ b/enzyme/Enzyme/BlasOpts.td @@ -61,3 +61,21 @@ def first : BlasOptPattern< (b<"ger"> $m, $k, Value<"3">, $x, $incx, $w, $incw, $C, $ldc), ] >; +def noopScal : BlasOptPattern< + [ + (b<"scal"> $n, Constant<"1.0">, adj<"x">) + ], + [], + [ + ] + >; +def zeroScal : BlasOptPattern< + [ + (b<"scal"> $n, Constant<"0.0">, adj<"x">) + ], + [], + [ + (memset $x, $n, Constant<"0.0">) + ] + ] + >; From d8b9ebc88e6aced0a80a4a298a1c859477ae5057 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 23 Oct 2023 23:01:37 -0400 Subject: [PATCH 17/17] blas_copy between two continuous vecs is memcpy --- enzyme/Enzyme/BlasOpts.td | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/enzyme/Enzyme/BlasOpts.td b/enzyme/Enzyme/BlasOpts.td index dda9105c5a60..fd1d0a500770 100644 --- a/enzyme/Enzyme/BlasOpts.td +++ b/enzyme/Enzyme/BlasOpts.td @@ -79,3 +79,13 @@ def zeroScal : BlasOptPattern< ] ] >; +def continuousCopy : BlasOptPattern< + [ + (b<"copy"> $n, $x, ConstantInt<1>, $y, ConstantInt<1>)) + ], + [], + [ + (memcopy $n, $x, $y), + ] + ] + >;