Skip to content

Commit

Permalink
Fix blas shadow differential use analysis (#1493)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Oct 20, 2023
1 parent 18d3da5 commit bb00501
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 5 deletions.
23 changes: 23 additions & 0 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,29 @@ bool hasDiffeRet(Init *resultTree) {
return false;
}

bool hasAdjoint(Init *resultTree, StringRef argName) {
if (DagInit *resultRoot = dyn_cast<DagInit>(resultTree)) {
auto opName = resultRoot->getOperator()->getAsString();
auto Def = cast<DefInit>(resultRoot->getOperator())->getDef();
if (Def->isSubClassOf("adj")) {
auto name = Def->getValueAsString("name");
return name == argName;
}
for (auto arg : resultRoot->getArgs()) {
if (hasAdjoint(arg, argName))
return true;
}
}
if (DefInit *DefArg = dyn_cast<DefInit>(resultTree)) {
auto Def = DefArg->getDef();
if (Def->isSubClassOf("adj")) {
auto name = Def->getValueAsString("name");
return name == argName;
}
}
return false;
}

static void checkBlasCallsInDag(const RecordKeeper &RK,
ArrayRef<Record *> blasPatterns,
StringRef blasName, const DagInit *toSearch) {
Expand Down
1 change: 1 addition & 0 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.h
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
void emitBlasDerivatives(const llvm::RecordKeeper &RK, llvm::raw_ostream &os);
bool hasDiffeRet(llvm::Init *resultTree);
bool hasAdjoint(llvm::Init *resultTree, llvm::StringRef argName);
53 changes: 48 additions & 5 deletions enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,61 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) {
auto name = nameVec[arg];
os << " bool active_" << name << " = !gutils->isConstantValue(arg_" << name
<< ");\n";
os << " if (EnzymeRuntimeActivityCheck && active_" << name
os << " if (!shadow && EnzymeRuntimeActivityCheck && active_" << name
<< ") return true;\n";
}

emit_need_cache_info(pattern, os);

for (size_t argPos = (lv23 ? 1 : 0); argPos < typeMap.size(); argPos++) {
auto users = argUsers.lookup(argPos);
auto name = nameVec[argPos];
os << " if (val == arg_" << name << " && need_" << name << " && !cache_"
<< name << ")\n"
<< " return true;\n";
auto argname = nameVec[argPos];

os << " if (val == arg_" << argname << ") {\n";

// We need the shadow of the value we're updating
if (typeMap[argPos] == ArgType::fp) {
os << " if (shadow && byRef && active_" << argname
<< ") return true;\n";
} else if (typeMap[argPos] == ArgType::vincData ||
typeMap[argPos] == ArgType::mldData) {
for (auto derivOp : pattern.getRules()) {
if (hasAdjoint(derivOp.getRuleDag(), argname)) {
os << " if (shadow && active_"
<< nameVec[derivOp.getHandledArgIdx()] << ") return true;\n";
} else {
bool isNoop = false;
if (DagInit *resultRoot = dyn_cast<DagInit>(derivOp.getRuleDag())) {
auto opName = resultRoot->getOperator()->getAsString();
auto Def = cast<DefInit>(resultRoot->getOperator())->getDef();
if (Def->getName() == "noop" || Def->getName() == "inactive") {
isNoop = true;
}
}
if (DefInit *DefArg = dyn_cast<DefInit>(derivOp.getRuleDag())) {
auto Def = DefArg->getDef();
if (Def->getName() == "noop" || Def->getName() == "inactive") {
isNoop = true;
}
}
// updates to a vector/matrix must definitionally use the shadow of
// the input, unless a noop-update
if (!isNoop) {
if (derivOp.getHandledArgIdx() == argPos) {
llvm::errs() << " fnname: " << name << " argPos: " << argPos
<< " argname: " << argname
<< " rule: " << *derivOp.getRuleDag() << "\n";
}
assert(derivOp.getHandledArgIdx() != argPos);
}
}
}
}

os << " if (!shadow && need_" << argname << " && !cache_" << argname
<< ")\n"
<< " return true;\n";
os << " }\n";
}

// If any of the rule uses DiffeRet, the primary function has a ret val
Expand Down

0 comments on commit bb00501

Please sign in to comment.