Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Blas opt #1366

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions enzyme/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,19 @@ gentbl(
":enzyme-tblgen",
],
)
gentbl(
name = "blas-optimizations",
tbl_outs = [(
"-gen-blas-optimizations",
"BlasOpts.inc",
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/InstructionDerivatives.td",
td_srcs = ["Enzyme/BlasDerivatives.td"],
deps = [
":enzyme-tblgen",
],
)
gentbl(
name = "blas-diffuseanalysis",
tbl_outs = [(
Expand Down
27 changes: 27 additions & 0 deletions enzyme/Enzyme/BlasOpts.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

class BlasOptPattern< list<dag> _inputs, list<string> _tmps, list<dag> _outputs> {
list<dag> inputs = _inputs;
// tmp variables will dissapear during the transformation
// and therefore are not allowed to be read elsewhere
list<string> tmps = _tmps;
list<dag> outputs = _outputs;
}

class Value<string _i> {
string name = _i;
}

def first : BlasOptPattern<
[
(b<"ger"> $m, $n, $alpha, $x, $incx, $y, $incy, $A, $lda),
(b<"ger"> $n, $k, $beta, $v, $incv, $w, $incw, $B, $ldb),
(b<"gemm"> $transa, $transb, $m, $n, $k, $alpha, $A, $lda, $B, $ldb, $beta, $C, $ldc),
],
["A", "B"],
[
(Value<"1"> (b<"dot"> $n, $v, $incv, $y, $incy)),
(Value<"2"> (FMul $alpha, Value<"1">)),
(Value<"3"> (FMul $beta, Value<"2">)),
(b<"ger"> $m, $k, Value<"3">, $x, $incx, $w, $incw, $C, $ldc),
]
>;
3 changes: 3 additions & 0 deletions enzyme/Enzyme/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ enzyme_tablegen(BlasDerivatives.inc -gen-blas-derivatives)
enzyme_tablegen(BlasAttributor.inc -update-blas-declarations)
enzyme_tablegen(BlasTA.inc -gen-blas-typeanalysis)
enzyme_tablegen(BlasDiffUse.inc -gen-blas-diffuseanalysis)
enzyme_tablegen(BlasOpts.inc -gen-blas-optimizations)
add_public_tablegen_target(BinopDerivativesIncGen)
add_public_tablegen_target(IntrinsicDerivativesIncGen)
add_public_tablegen_target(CallDerivativesIncGen)
Expand All @@ -36,6 +37,7 @@ add_public_tablegen_target(BlasDerivativesIncGen)
add_public_tablegen_target(BlasDeclarationsIncGen)
add_public_tablegen_target(BlasTAIncGen)
add_public_tablegen_target(BlasDiffUseIncGen)
add_public_tablegen_target(BlasOptsIncGen)

include_directories(${CMAKE_CURRENT_BINARY_DIR})

Expand Down Expand Up @@ -127,6 +129,7 @@ if (${ENZYME_EXTERNAL_SHARED_LIB})
add_dependencies(Enzyme-${LLVM_VERSION_MAJOR} BlasDeclarationsIncGen)
add_dependencies(Enzyme-${LLVM_VERSION_MAJOR} BlasTAIncGen)
add_dependencies(Enzyme-${LLVM_VERSION_MAJOR} BlasDiffUseIncGen)
add_dependencies(Enzyme-${LLVM_VERSION_MAJOR} BlasOptsIncGen)
target_link_libraries(Enzyme-${LLVM_VERSION_MAJOR} LLVM)
install(TARGETS Enzyme-${LLVM_VERSION_MAJOR}
EXPORT EnzymeTargets
Expand Down
5 changes: 5 additions & 0 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2954,6 +2954,7 @@ class EnzymeNewPM final : public EnzymeBase,
AnalysisKey EnzymeNewPM::Key;

#include "ActivityAnalysisPrinter.h"
#include "OptBlas.h"
#include "PreserveNVVM.h"
#include "TypeAnalysis/TypeAnalysisPrinter.h"
#include "llvm/Passes/PassBuilder.h"
Expand Down Expand Up @@ -3490,6 +3491,10 @@ llvmGetPassPluginInfo() {
PB.registerPipelineParsingCallback(
[](llvm::StringRef Name, llvm::ModulePassManager &MPM,
llvm::ArrayRef<llvm::PassBuilder::PipelineElement>) {
// if (Name == "blas-opt") {
ZuseZ4 marked this conversation as resolved.
Show resolved Hide resolved
// MPM.addPass(createOptimizeBlasPass(/*Begin*/ true));
// return true;
// }
if (Name == "enzyme") {
MPM.addPass(EnzymeNewPM());
return true;
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/InstructionDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -903,3 +903,4 @@ def : InstPattern<(Op $x), "Freeze", 10, 0,
>;

include "BlasDerivatives.td"
include "BlasOpts.td"
134 changes: 134 additions & 0 deletions enzyme/Enzyme/OptBlas.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
//===- PreserveNVVM.cpp - Mark NVVM attributes for preservation. -------===//
ZuseZ4 marked this conversation as resolved.
Show resolved Hide resolved
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
// Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
//
// This file contains createPreserveNVVM, a transformation pass that marks
// calls to __nv_* functions, marking them as noinline as implementing the llvm
// intrinsic.
//
//===----------------------------------------------------------------------===//
#include <llvm/Config/llvm-config.h>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"

#include "llvm/ADT/SmallSet.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

#include "llvm/Pass.h"

#include "llvm/Transforms/Utils.h"

#include <map>

#include "OptBlas.h"
#include "Utils.h"

using namespace llvm;
#ifdef DEBUG_TYPE
#undef DEBUG_TYPE
#endif
#define DEBUG_TYPE "blas-opt"

#if LLVM_VERSION_MAJOR >= 14
#define addAttribute addAttributeAtIndex
#endif

bool optimizeBlas(bool Begin, Function &F) {
bool changed = false;
return changed;
}

namespace {

class OptimizeBlas final : public ModulePass {
public:
static char ID;
// bool Begin;
// OptimizeBlas(bool Begin = true) : ModulePass(ID), Begin(Begin) {}

OptimizeBlas() : ModulePass(ID) {}
OptimizeBlas(char &pid) : ModulePass(pid) {}
void getAnalysisUsage(AnalysisUsage &AU) const override {}
bool runOnModule(Module &M) override { return optimizeFncsWithBlas(M); }
};

} // namespace

char OptimizeBlas::ID = 0;

static RegisterPass<OptimizeBlas> X("blas-opt", "Optimize Blas Pass");

ModulePass *createOptimizeBlasPass(bool Begin) {
char pid = 0;
return new OptimizeBlas(pid);
}

#include <llvm-c/Core.h>
#include <llvm-c/Types.h>

#include "llvm/IR/LegacyPassManager.h"

extern "C" void AddOptimizeBlasPass(LLVMPassManagerRef PM, uint8_t Begin) {
unwrap(PM)->add(createOptimizeBlasPass((bool)Begin));
}

llvm::AnalysisKey OptimizeBlasNewPM::Key;

bool cmp_or_set(llvm::CallInst *CI, std::vector<llvm::Value *> values) {
// first run trough to see if the already set args match.
// second run if they do and then we set the nullptr.
for (size_t i = 0; i < values.size(); ++i) {
if (values[i] == nullptr) {
continue;
}
if (CI->getArgOperand(i) != values[i])
return false;
}
for (size_t i = 0; i < values.size(); ++i) {
if (values[i] == nullptr) {
values[i] = CI->getArgOperand(i);
}
}
return true;
}

#include "BlasOpts.inc"

bool optimizeFncsWithBlas(llvm::Module &M) {

using namespace llvm;

Function *F = M.getFunction("f");
if (F) {
errs() << "Found function: " << F->getName() << "\n";
} else {
return false;
}

optfirst(F, M);

return true;
}
50 changes: 50 additions & 0 deletions enzyme/Enzyme/OptBlas.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
//===- OptBlas.h - Mark NVVM attributes for preservation. -------===//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
// Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
//
// This file contains createOptimizeBlas,
//
//===----------------------------------------------------------------------===//

#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassPlugin.h"

namespace llvm {
class ModulePass;
}

bool optimizeFncsWithBlas(llvm::Module &M);

llvm::ModulePass *createOptimizeBlasPass(bool Begin);

class OptimizeBlasNewPM final
: public llvm::AnalysisInfoMixin<OptimizeBlasNewPM> {
// friend struct llvm::AnalysisInfoMixin<OptimizeBlasNewPM>;

private:
bool Begin;
static llvm::AnalysisKey Key;

public:
OptimizeBlasNewPM(bool Begin) : Begin(Begin) {}

bool runOnModule(llvm::Module &M);

static bool isRequired() { return true; }
};
34 changes: 34 additions & 0 deletions enzyme/test/Enzyme/BlasOpt/blasCOpt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -blas-opt -S | FileCheck %s; fi
;RUN: %opt < %s %newLoadEnzyme -passes="blas-opt" -S | FileCheck %s

declare void @dgemm_64_(i8, i8, i64, i64, i64, double, double* nocapture readonly, i64, double* nocapture readonly, i64, double, double* nocapture readonly, i64)

declare void @dger_64_(i64 , i64, double, double* nocapture readonly, i64, double* nocapture readonly, i64, double*, i64)

define void @f(double* noalias %C, double %alpha, double %beta, double* %x, double *%y, double* %v, double *%w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb) {
entry:
%A = alloca double, align 16
store double 0.000000e+00, double* %A
%B = alloca double, align 16
store double 0.000000e+00, double* %B
call void @dger_64_(i64 %m, i64 %n, double %alpha, double* %x, i64 %incx, double* %y, i64 %incy, double* %A, i64 %lda)
call void @dger_64_(i64 %n, i64 %p, double %beta, double* %v, i64 %incv, double* %w, i64 %incw, double* %B, i64 %ldb)
call void @dgemm_64_(i8 %transa, i8 %transb, i64 %m, i64 %n, i64 %p, double %alpha, double* %A, i64 %lda, double* %B, i64 %ldb, double %beta, double* %C, i64 %ldc)
;%ptr = bitcast i8* %A to double*
;store double 0.0000000e+00, double* %ptr, align 8
ret void
}


; CHECK: define void @f(double* noalias %C, double %alpha, double %beta, double* %x, double* %y, double* %v, double* %w, i64 %m, i64 %n, i64 %p, i64 %lda, i64 %ldb, i64 %ldc, i64 %incx, i64 %incy, i64 %incv, i64 %incw, i8 %transa, i8 %transb)
; CHECK-NEXT: entry:
; CHECK-NEXT: %A = alloca double, align 16
; CHECK-NEXT: store double 0.000000e+00, double* %A
; CHECK-NEXT: %B = alloca double, align 16
; CHECK-NEXT: store double 0.000000e+00, double* %B
; CHECK-NEXT: %0 = call double @ddot_64_(i64 %m, double* %y, i64 %incy, double* %v, i64 %incv)
; CHECK-NEXT: %1 = fmul double %alpha, %0
; CHECK-NEXT: %2 = fmul double %1, %beta
; CHECK-NEXT: call void @dger_64_(i64 %m, i64 %n, double %2, double* %x, i64 %incx, double* %w, i64 %incw, double* %C, i64 %ldc)
; CHECK-NEXT: ret void
; CHECK-NEXT: }
42 changes: 42 additions & 0 deletions enzyme/test/Integration/BlasOpt/first.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
ZuseZ4 marked this conversation as resolved.
Show resolved Hide resolved
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};

void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N,
const double alpha, const double *X, const int incX,
const double *Y, const int incY, double *A, const int lda);

void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const double alpha, const double *A,
const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc);

void g(double *restrict x, double *restrict y, double *restrict v, double *restrict w, double *restrict C) {
double A[] = {0.00, 0.00, 0.00, 0.00,
0.00, 0.00, 0.00, 0.00,
0.00};
double B[] = {0.00, 0.00, 0.00, 0.00,
0.00, 0.00, 0.00, 0.00,
0.00};
int lda = 3, ldb = 3, ldc = 3;
int m = 3, n = 3, p = 3;
double alpha = 3.14, beta = 4.31;
int incx = 1, incy = 1, incv = 1, incw = 1;

cblas_dger(CblasRowMajor, m, n, alpha, x, incx, y, incy, A, lda);
cblas_dger(CblasRowMajor, n, p, beta, v, incv, w, incw, B, ldb);
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, p, alpha, A, lda, B, ldb, beta, C, ldc);
}

int main() {
double x[] = {1.0, 2.0, 3.0};
double y[] = {2.0, 3.0, 4.0};
double v[] = {0.0, 2.5, 3.5};
double w[] = {1.0, 1.0, 1.0};
double C[] = {0.00, 0.00, 0.00, 0.00,
0.00, 0.00, 0.00, 0.00,
0.00};
g(x,y,v,w,C);
for (int i = 0; i < 9; i++)
printf("%f\n", C[i]);
}
2 changes: 1 addition & 1 deletion enzyme/tools/enzyme-tblgen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ set(LLVM_LINK_COMPONENTS
)

add_tablegen(enzyme-tblgen Enzyme
blas-tblgen.cpp datastructures.cpp caching.cpp enzyme-tblgen.cpp
opts.cpp blas-tblgen.cpp datastructures.cpp caching.cpp enzyme-tblgen.cpp
)

get_target_property(TBL_LINKED_LIBS LLVMSupport INTERFACE_LINK_LIBRARIES)
Expand Down
Loading