Skip to content

Commit

Permalink
Create c++ nofree attribute (#1474)
Browse files Browse the repository at this point in the history
* Create c++ nofree attribute

* fix build
  • Loading branch information
wsmoses authored Oct 10, 2023
1 parent a1f4041 commit b255008
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 0 deletions.
108 changes: 108 additions & 0 deletions enzyme/Enzyme/Clang/EnzymeClang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,114 @@ struct EnzymeInactiveAttrInfo : public ParsedAttrInfo {
static ParsedAttrInfoRegistry::Add<EnzymeInactiveAttrInfo> X4("enzyme_inactive",
"");

struct EnzymeNoFreeAttrInfo : public ParsedAttrInfo {
EnzymeNoFreeAttrInfo() {
OptArgs = 1;
// GNU-style __attribute__(("example")) and C++/C2x-style [[example]] and
// [[plugin::example]] supported.
static constexpr Spelling S[] = {
{ParsedAttr::AS_GNU, "enzyme_nofree"},
#if LLVM_VERSION_MAJOR > 17
{ParsedAttr::AS_C23, "enzyme_nofree"},
#else
{ParsedAttr::AS_C2x, "enzyme_nofree"},
#endif
{ParsedAttr::AS_CXX11, "enzyme_nofree"},
{ParsedAttr::AS_CXX11, "enzyme::nofree"}
};
Spellings = S;
}

bool diagAppertainsToDecl(Sema &S, const ParsedAttr &Attr,
const Decl *D) const override {
// This attribute appertains to functions only.
if (isa<FunctionDecl>(D))
return true;
if (auto VD = dyn_cast<VarDecl>(D)) {
if (VD->hasGlobalStorage())
return true;
}
S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str)
<< Attr << "functions and globals";
return false;
}

AttrHandling handleDeclAttribute(Sema &S, Decl *D,
const ParsedAttr &Attr) const override {
if (Attr.getNumArgs() != 0) {
unsigned ID = S.getDiagnostics().getCustomDiagID(
DiagnosticsEngine::Error,
"'enzyme_nofree' attribute requires zero arguments");
S.Diag(Attr.getLoc(), ID);
return AttributeNotApplied;
}

auto &AST = S.getASTContext();
DeclContext *declCtx = D->getDeclContext();
auto loc = D->getLocation();
RecordDecl *RD;
if (S.getLangOpts().CPlusPlus)
RD = CXXRecordDecl::Create(AST, clang::TagTypeKind::TTK_Struct, declCtx,
loc, loc, nullptr); // rId);
else
RD = RecordDecl::Create(AST, clang::TagTypeKind::TTK_Struct, declCtx, loc,
loc, nullptr); // rId);
RD->setAnonymousStructOrUnion(true);
RD->setImplicit();
RD->startDefinition();
auto T = isa<FunctionDecl>(D) ? cast<FunctionDecl>(D)->getType()
: cast<VarDecl>(D)->getType();
auto Name = isa<FunctionDecl>(D) ? cast<FunctionDecl>(D)->getNameAsString()
: cast<VarDecl>(D)->getNameAsString();
auto FT = AST.getPointerType(T);
auto &Id = AST.Idents.get(
(StringRef("__enzyme_nofree") + "_autoreg_" + Name).str());
auto V = VarDecl::Create(AST, declCtx, loc, loc, &Id, FT, nullptr, SC_None);
V->setStorageClass(SC_PrivateExtern);
V->addAttr(clang::UsedAttr::CreateImplicit(AST));
TemplateArgumentListInfo *TemplateArgs = nullptr;
auto DR = DeclRefExpr::Create(
AST, NestedNameSpecifierLoc(), loc, cast<ValueDecl>(D), false, loc, T,
ExprValueKind::VK_LValue, cast<NamedDecl>(D), TemplateArgs);
#if LLVM_VERSION_MAJOR >= 13
auto rval = ExprValueKind::VK_PRValue;
#else
auto rval = ExprValueKind::VK_RValue;
#endif
Expr *expr = nullptr;
if (isa<FunctionDecl>(D)) {
#if LLVM_VERSION_MAJOR >= 12
expr =
ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay,
DR, nullptr, rval, FPOptionsOverride());
#else
expr = ImplicitCastExpr::Create(
AST, FT, CastKind::CK_FunctionToPointerDecay, DR, nullptr, rval);
#endif
} else {
expr =
UnaryOperator::Create(AST, DR, UnaryOperatorKind::UO_AddrOf, FT, rval,
clang::ExprObjectKind ::OK_Ordinary, loc,
/*canoverflow*/ false, FPOptionsOverride());
}

if (expr->isValueDependent()) {
unsigned ID = S.getDiagnostics().getCustomDiagID(
DiagnosticsEngine::Error, "use of attribute 'enzyme_nofree' "
"in a templated context not yet supported");
S.Diag(Attr.getLoc(), ID);
return AttributeNotApplied;
}
V->setInit(expr);
V->dump();
S.MarkVariableReferenced(loc, V);
S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V));
return AttributeApplied;
}
};

static ParsedAttrInfoRegistry::Add<EnzymeNoFreeAttrInfo> X5("enzyme_nofree",
"");
} // namespace

#endif
29 changes: 29 additions & 0 deletions enzyme/Enzyme/PreserveNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ bool preserveNVVM(bool Begin, Function &F) {
if (g.getName().contains(gradient_handler_name) ||
g.getName().contains(derivative_handler_name) ||
g.getName().contains(splitderivative_handler_name) ||
g.getName().contains("__enzyme_nofree") ||
g.getName().contains("__enzyme_inactivefn") ||
g.getName().contains("__enzyme_function_like") ||
g.getName().contains("__enzyme_allocation_like")) {
Expand Down Expand Up @@ -455,6 +456,34 @@ bool preserveNVVM(bool Begin, Function &F) {
}
}
}
if (g.getName().contains("__enzyme_nofree")) {
if (g.hasInitializer()) {
Value *V = g.getInitializer();
while (1) {
if (auto CE = dyn_cast<ConstantExpr>(V)) {
V = CE->getOperand(0);
continue;
}
if (auto CA = dyn_cast<ConstantAggregate>(V)) {
V = CA->getOperand(0);
continue;
}
break;
}
if (auto F = cast<Function>(V)) {
F->addAttribute(AttributeList::FunctionIndex,
Attribute::get(g.getContext(), Attribute::NoFree));
toErase.push_back(&g);
changed = true;
} else {
llvm::errs() << "Param of __enzyme_nofree must be a "
"constant function"
<< g << "\n"
<< *V << "\n";
llvm_unreachable("__enzyme_nofree");
}
}
}
if (g.getName().contains("__enzyme_function_like")) {
if (g.hasInitializer()) {
auto CA = dyn_cast<ConstantAggregate>(g.getInitializer());
Expand Down

0 comments on commit b255008

Please sign in to comment.