diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index 69dfdfa6cc39..45f891eecd18 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -57,6 +57,29 @@ bool hasDiffeRet(Init *resultTree) { return false; } +bool hasAdjoint(Init *resultTree, StringRef argName) { + if (DagInit *resultRoot = dyn_cast(resultTree)) { + auto opName = resultRoot->getOperator()->getAsString(); + auto Def = cast(resultRoot->getOperator())->getDef(); + if (Def->isSubClassOf("adj")) { + auto name = Def->getValueAsString("name"); + return name == argName; + } + for (auto arg : resultRoot->getArgs()) { + if (hasAdjoint(arg, argName)) + return true; + } + } + if (DefInit *DefArg = dyn_cast(resultTree)) { + auto Def = DefArg->getDef(); + if (Def->isSubClassOf("adj")) { + auto name = Def->getValueAsString("name"); + return name == argName; + } + } + return false; +} + static void checkBlasCallsInDag(const RecordKeeper &RK, ArrayRef blasPatterns, StringRef blasName, const DagInit *toSearch) { diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.h b/enzyme/tools/enzyme-tblgen/blas-tblgen.h index e0c2c966e008..f9726b396561 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.h +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.h @@ -1,2 +1,3 @@ void emitBlasDerivatives(const llvm::RecordKeeper &RK, llvm::raw_ostream &os); bool hasDiffeRet(llvm::Init *resultTree); +bool hasAdjoint(llvm::Init *resultTree, llvm::StringRef argName); diff --git a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h index 439ddc583bfe..e071222ea19a 100644 --- a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h @@ -60,7 +60,7 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) { auto name = nameVec[arg]; os << " bool active_" << name << " = !gutils->isConstantValue(arg_" << name << ");\n"; - os << " if (EnzymeRuntimeActivityCheck && active_" << name + os << " if (!shadow && EnzymeRuntimeActivityCheck && active_" << name << ") return true;\n"; } @@ -68,10 +68,53 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) { for (size_t argPos = (lv23 ? 1 : 0); argPos < typeMap.size(); argPos++) { auto users = argUsers.lookup(argPos); - auto name = nameVec[argPos]; - os << " if (val == arg_" << name << " && need_" << name << " && !cache_" - << name << ")\n" - << " return true;\n"; + auto argname = nameVec[argPos]; + + os << " if (val == arg_" << argname << ") {\n"; + + // We need the shadow of the value we're updating + if (typeMap[argPos] == ArgType::fp) { + os << " if (shadow && byRef && active_" << argname + << ") return true;\n"; + } else if (typeMap[argPos] == ArgType::vincData || + typeMap[argPos] == ArgType::mldData) { + for (auto derivOp : pattern.getRules()) { + if (hasAdjoint(derivOp.getRuleDag(), argname)) { + os << " if (shadow && active_" + << nameVec[derivOp.getHandledArgIdx()] << ") return true;\n"; + } else { + bool isNoop = false; + if (DagInit *resultRoot = dyn_cast(derivOp.getRuleDag())) { + auto opName = resultRoot->getOperator()->getAsString(); + auto Def = cast(resultRoot->getOperator())->getDef(); + if (Def->getName() == "noop" || Def->getName() == "inactive") { + isNoop = true; + } + } + if (DefInit *DefArg = dyn_cast(derivOp.getRuleDag())) { + auto Def = DefArg->getDef(); + if (Def->getName() == "noop" || Def->getName() == "inactive") { + isNoop = true; + } + } + // updates to a vector/matrix must definitionally use the shadow of + // the input, unless a noop-update + if (!isNoop) { + if (derivOp.getHandledArgIdx() == argPos) { + llvm::errs() << " fnname: " << name << " argPos: " << argPos + << " argname: " << argname + << " rule: " << *derivOp.getRuleDag() << "\n"; + } + assert(derivOp.getHandledArgIdx() != argPos); + } + } + } + } + + os << " if (!shadow && need_" << argname << " && !cache_" << argname + << ")\n" + << " return true;\n"; + os << " }\n"; } // If any of the rule uses DiffeRet, the primary function has a ret val