Skip to content

Commit

Permalink
feat: attribute delaborators
Browse files Browse the repository at this point in the history
  • Loading branch information
digama0 committed Nov 18, 2024
1 parent 5a99cb3 commit 9b8cd67
Show file tree
Hide file tree
Showing 42 changed files with 261 additions and 52 deletions.
72 changes: 54 additions & 18 deletions src/Lean/Attributes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ structure AttributeImpl extends AttributeImplCore where
/-- This is run when the attribute is applied to a declaration `decl`. `stx` is the syntax of the attribute including arguments. -/
add (decl : Name) (stx : Syntax) (kind : AttributeKind) : AttrM Unit
erase (decl : Name) : AttrM Unit := throwError "attribute cannot be erased"
/-- Implementations should push an `attr` syntax corresponding to (a best approximation of)
the attribute state for the given definition. -/
delab (decl : Name) : StateT (Array (TSyntax `attr)) AttrM Unit
deriving Inhabited

builtin_initialize attributeMapRef : IO.Ref (Std.HashMap Name AttributeImpl) ← IO.mkRef {}
Expand Down Expand Up @@ -134,6 +137,15 @@ structure TagAttribute where
ext : PersistentEnvExtension Name Name NameSet
deriving Inhabited

namespace TagAttribute

private def hasTagCore (ext : PersistentEnvExtension Name Name NameSet) (env : Environment) (decl : Name) : Bool :=
match env.getModuleIdxFor? decl with
| some modIdx => (ext.getModuleEntries env modIdx).binSearchContains decl Name.quickLt
| none => (ext.getState env).contains decl

end TagAttribute

def registerTagAttribute (name : Name) (descr : String)
(validate : Name → AttrM Unit := fun _ => pure ()) (ref : Name := by exact decl_name%) (applicationTime := AttributeApplicationTime.afterTypeChecking) : IO TagAttribute := do
let ext : PersistentEnvExtension Name Name NameSet ← registerPersistentEnvExtension {
Expand All @@ -156,16 +168,17 @@ def registerTagAttribute (name : Name) (descr : String)
throwError "invalid attribute '{name}', declaration is in an imported module"
validate decl
modifyEnv fun env => ext.addEntry env decl
delab := fun decl => do
if TagAttribute.hasTagCore ext (← getEnv) decl then
modify (·.push <| Unhygienic.run `(attr| $(mkIdent name):ident))
}
registerBuiltinAttribute attrImpl
return { attr := attrImpl, ext := ext }

namespace TagAttribute

def hasTag (attr : TagAttribute) (env : Environment) (decl : Name) : Bool :=
match env.getModuleIdxFor? decl with
| some modIdx => (attr.ext.getModuleEntries env modIdx).binSearchContains decl Name.quickLt
| none => (attr.ext.getState env).contains decl
hasTagCore attr.ext env decl

end TagAttribute

Expand All @@ -182,10 +195,21 @@ structure ParametricAttribute (α : Type) where

structure ParametricAttributeImpl (α : Type) extends AttributeImplCore where
getParam : Name → Syntax → AttrM α
delabParam : Name → α → StateT (Array (TSyntax `attr)) AttrM Unit
afterSet : Name → α → AttrM Unit := fun _ _ _ => pure ()
afterImport : Array (Array (Name × α)) → ImportM Unit := fun _ => pure ()

def registerParametricAttribute (impl : ParametricAttributeImpl α) : IO (ParametricAttribute α) := do
def ParametricAttribute.getParam?Core [Inhabited α]
(ext : PersistentEnvExtension (Name × α) (Name × α) (NameMap α))
(env : Environment) (decl : Name) : Option α :=
match env.getModuleIdxFor? decl with
| some modIdx =>
match (ext.getModuleEntries env modIdx).binSearch (decl, default) (fun a b => Name.quickLt a.1 b.1) with
| some (_, val) => some val
| none => none
| none => (ext.getState env).find? decl

def registerParametricAttribute [Inhabited α] (impl : ParametricAttributeImpl α) : IO (ParametricAttribute α) := do
let ext : PersistentEnvExtension (Name × α) (Name × α) (NameMap α) ← registerPersistentEnvExtension {
name := impl.ref
mkInitial := pure {}
Expand All @@ -198,27 +222,25 @@ def registerParametricAttribute (impl : ParametricAttributeImpl α) : IO (Parame
}
let attrImpl : AttributeImpl := {
impl.toAttributeImplCore with
add := fun decl stx kind => do
add := fun decl stx kind => do
unless kind == AttributeKind.global do throwError "invalid attribute '{impl.name}', must be global"
let env ← getEnv
unless (env.getModuleIdxFor? decl).isNone do
throwError "invalid attribute '{impl.name}', declaration is in an imported module"
let val ← impl.getParam decl stx
modifyEnv fun env => ext.addEntry env (decl, val)
try impl.afterSet decl val catch _ => setEnv env
delab := fun decl => do
if let some val := ParametricAttribute.getParam?Core ext (← getEnv) decl then
impl.delabParam decl val
}
registerBuiltinAttribute attrImpl
pure { attr := attrImpl, ext := ext }

namespace ParametricAttribute

def getParam? [Inhabited α] (attr : ParametricAttribute α) (env : Environment) (decl : Name) : Option α :=
match env.getModuleIdxFor? decl with
| some modIdx =>
match (attr.ext.getModuleEntries env modIdx).binSearch (decl, default) (fun a b => Name.quickLt a.1 b.1) with
| some (_, val) => some val
| none => none
| none => (attr.ext.getState env).find? decl
getParam?Core attr.ext env decl

def setParam (attr : ParametricAttribute α) (env : Environment) (decl : Name) (param : α) : Except String Environment :=
if (env.getModuleIdxFor? decl).isSome then
Expand All @@ -239,7 +261,17 @@ structure EnumAttributes (α : Type) where
ext : PersistentEnvExtension (Name × α) (Name × α) (NameMap α)
deriving Inhabited

def registerEnumAttributes (attrDescrs : List (Name × String × α))
private def EnumAttributes.getValueCore [Inhabited α]
(ext : PersistentEnvExtension (Name × α) (Name × α) (NameMap α))
(env : Environment) (decl : Name) : Option α :=
match env.getModuleIdxFor? decl with
| some modIdx =>
match (ext.getModuleEntries env modIdx).binSearch (decl, default) (fun a b => Name.quickLt a.1 b.1) with
| some (_, val) => some val
| none => none
| none => (ext.getState env).find? decl

def registerEnumAttributes [Inhabited α] [BEq α] (attrDescrs : List (Name × String × α))
(validate : Name → α → AttrM Unit := fun _ _ => pure ())
(applicationTime := AttributeApplicationTime.afterTypeChecking)
(ref : Name := by exact decl_name%) : IO (EnumAttributes α) := do
Expand All @@ -265,6 +297,10 @@ def registerEnumAttributes (attrDescrs : List (Name × String × α))
throwError "invalid attribute '{name}', declaration is in an imported module"
validate decl val
modifyEnv fun env => ext.addEntry env (decl, val)
delab := fun decl => do
if let some v := EnumAttributes.getValueCore ext (← getEnv) decl then
if v == val then
modify (·.push <| Unhygienic.run `(attr| $(mkIdent name):ident))
applicationTime := applicationTime
: AttributeImpl
}
Expand All @@ -274,12 +310,7 @@ def registerEnumAttributes (attrDescrs : List (Name × String × α))
namespace EnumAttributes

def getValue [Inhabited α] (attr : EnumAttributes α) (env : Environment) (decl : Name) : Option α :=
match env.getModuleIdxFor? decl with
| some modIdx =>
match (attr.ext.getModuleEntries env modIdx).binSearch (decl, default) (fun a b => Name.quickLt a.1 b.1) with
| some (_, val) => some val
| none => none
| none => (attr.ext.getState env).find? decl
getValueCore attr.ext env decl

def setValue (attrs : EnumAttributes α) (env : Environment) (decl : Name) (val : α) : Except String Environment :=
if (env.getModuleIdxFor? decl).isSome then
Expand Down Expand Up @@ -377,6 +408,11 @@ def getBuiltinAttributeImpl (attrName : Name) : IO AttributeImpl := do
| some attr => pure attr
| none => throw (IO.userError ("unknown attribute '" ++ toString attrName ++ "'"))

def delabAttributesOfDecl (declName : Name) : AttrM (Array (TSyntax `attr)) := do
let m ← attributeMapRef.get
let act := m.forM fun _ attr => attr.delab declName
(·.2) <$> act.run #[]

@[export lean_attribute_application_time]
def getBuiltinAttributeApplicationTime (n : Name) : IO AttributeApplicationTime := do
let attr ← getBuiltinAttributeImpl n
Expand Down
1 change: 1 addition & 0 deletions src/Lean/BuiltinDocAttr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ builtin_initialize
add := fun decl stx _ => do
Attribute.Builtin.ensureNoArgs stx
declareBuiltinDocStringAndRanges decl
delab := fun _ => pure ()
}

end Lean
3 changes: 3 additions & 0 deletions src/Lean/Class.lean
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ builtin_initialize
unless kind == AttributeKind.global do throwError "invalid attribute 'class', must be global"
let env ← ofExcept (addClass env decl)
setEnv env
delab := fun decl => do
if isClass (← getEnv) decl then
modify (·.push <| Unhygienic.run `(attr| class))
}

end Lean
21 changes: 12 additions & 9 deletions src/Lean/Compiler/CSimpAttr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,6 @@ def add (declName : Name) (kind : AttributeKind) : CoreM Unit := do
else
throwError "invalid 'csimp' theorem, only constant replacement theorems (e.g., `@f = @g`) are currently supported."

builtin_initialize
registerBuiltinAttribute {
name := `csimp
descr := "simplification theorem for the compiler"
add := fun declName stx attrKind => do
Attribute.Builtin.ensureNoArgs stx
discard <| add declName attrKind
}

@[export lean_csimp_replace_constants]
def replaceConstants (env : Environment) (e : Expr) : Expr :=
let s := ext.getState env
Expand All @@ -73,4 +64,16 @@ end CSimp
def hasCSimpAttribute (env : Environment) (declName : Name) : Bool :=
CSimp.ext.getState env |>.thmNames.contains declName

builtin_initialize
registerBuiltinAttribute {
name := `csimp
descr := "simplification theorem for the compiler"
add := fun declName stx attrKind => do
Attribute.Builtin.ensureNoArgs stx
discard <| CSimp.add declName attrKind
delab := fun declName => do
if hasCSimpAttribute (← getEnv) declName then
modify (·.push <| Unhygienic.run `(attr| csimp))
}

end Lean.Compiler
2 changes: 2 additions & 0 deletions src/Lean/Compiler/ExportAttr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ builtin_initialize exportAttr : ParametricAttribute Name ←
unless isValidCppName exportName do
throwError "invalid 'export' function name, is not a valid C++ identifier"
return exportName
delabParam := fun _ exportName => do
modify (·.push <| Unhygienic.run `(attr| export $(mkIdent exportName)))
}

@[export lean_get_export_name_for]
Expand Down
17 changes: 17 additions & 0 deletions src/Lean/Compiler/ExternAttr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ private def syntaxToExternAttrData (stx : Syntax) : AttrM ExternAttrData := do
entries := entries.push <| ExternEntry.inline backend str
return { arity? := arity?, entries := entries.toList }

private def externAttrDataToSyntax (data : ExternAttrData) : AttrM (TSyntax `attr) := do
let arity? := data.arity?.map Syntax.mkNatLit
let mut entries := #[]
unless data matches {arity? := none, entries := [.adhoc `all]} do
for entry in data.entries do
let (inline, backend, str) ← match entry with
| .standard backend str => pure (none, backend, str)
| .inline backend str => pure (some (mkAtom "inline"), backend, str)
| _ => continue
let backend := if backend == `all then none else some (mkIdent backend)
entries := entries.push <| mkNode `Lean.Parser.Attr.externEntry #[
mkOptionalNode backend, mkOptionalNode inline, Syntax.mkStrLit str]
`(attr| extern $(arity?)? $entries*)

@[extern "lean_add_extern"]
opaque addExtern (env : Environment) (n : Name) : ExceptT String Id Environment

Expand All @@ -73,6 +87,9 @@ builtin_initialize externAttr : ParametricAttribute ExternAttrData ←
return ()
let env ← ofExcept <| addExtern env declName
setEnv env
delabParam := fun decl val => do
let stx ← externAttrDataToSyntax val
modify (·.push stx)
}

@[export lean_get_extern_attr_data]
Expand Down
3 changes: 3 additions & 0 deletions src/Lean/Compiler/ImplementedByAttr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ builtin_initialize implementedByAttr : ParametricAttribute Name ← registerPara
if decl.name == fnDecl.name then
throwError "invalid 'implemented_by' argument '{fnName}', function cannot be implemented by itself"
return fnName
delabParam := fun _ val => do
let val ← unresolveNameGlobal val
modify (·.push <| Unhygienic.run `(attr| implemented_by $(mkIdent val):ident))
}

@[export lean_get_implemented_by]
Expand Down
6 changes: 6 additions & 0 deletions src/Lean/Compiler/InitAttr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ unsafe def registerInitAttrUnsafe (attrName : Name) (runAfterImport : Bool) (ref
| none =>
if isIOUnit decl.type then pure Name.anonymous
else throwError "initialization function must have type `IO Unit`"
delabParam := fun _ val => do
if val.isAnonymous then
modify (·.push <| Unhygienic.run `(attr| $(mkIdent attrName):ident))
else
let val ← unresolveNameGlobal val
modify (·.push <| Unhygienic.run `(attr| $(mkIdent attrName):ident $(mkIdent val)))
afterImport := fun entries => do
let ctx ← read
if runAfterImport && (← isInitializerExecutionEnabled) then
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Compiler/LCNF/Passes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ builtin_initialize
Attribute.Builtin.ensureNoArgs stx
unless kind == AttributeKind.global do throwError "invalid attribute 'cpass', must be global"
discard <| addPass declName
delab := fun _ => pure ()
applicationTime := .afterCompilation
}

Expand Down
3 changes: 3 additions & 0 deletions src/Lean/Compiler/Specialize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ builtin_initialize specializeAttr : ParametricAttribute (Array Nat) ←
getParam := fun declName stx => do
let args := stx[1].getArgs
elabSpecArgs declName args |>.run'
delabParam := fun declName val => do
let val := val.map fun n => Syntax.mkNatLit (n + 1)
modify (·.push <| Unhygienic.run `(attr| specialize $val*))
}

def getSpecializationArgs? (env : Environment) (declName : Name) : Option (Array Nat) :=
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Elab/InheritDoc.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ builtin_initialize
| logWarningAt id m!"{← mkConstWithLevelParams declName} does not have a doc string"
addDocString decl doc
| _ => throwError "invalid `[inherit_doc]` attribute"
delab := fun _ => pure ()
}
16 changes: 11 additions & 5 deletions src/Lean/Elab/Print.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@ private def levelParamsToMessageData (levelParams : List Name) : MessageData :=
return m ++ "}"

private def mkHeader (kind : String) (id : Name) (levelParams : List Name) (type : Expr) (safety : DefinitionSafety) : CommandElabM MessageData := do
let attrs ← liftCoreM <| delabAttributesOfDecl id
let m : MessageData :=
match (← getReducibilityStatus id) with
| ReducibilityStatus.irreducible => "@[irreducible] "
| ReducibilityStatus.reducible => "@[reducible] "
| ReducibilityStatus.semireducible => ""
let m :=
if attrs.isEmpty then
""
else
-- This sorting is not perfect, but we need to do some sorting here because
-- otherwise the attributes come out in random order which causes problems for reproducibility
let key {k} (stx : TSyntax k) := (stx.raw.getKind, stx.raw[0].getKind)
have : Ord (Name × Name) := Ord.lex ⟨Name.cmp⟩ ⟨Name.cmp⟩
let attrs := attrs.qsort (fun a b => compare (key a) (key b) = .lt)
"@[" ++ MessageData.joinSep (attrs.toList.map fun s => .ofSyntax s.raw) ", " ++ "] "
let m :=
m ++
match safety with
| DefinitionSafety.unsafe => "unsafe "
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/Tactic/BVDecide/Frontend/Attr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ def addBVNormalizeProcBuiltinAttr (declName : Name) (post : Bool) (proc : Sum Si

builtin_initialize
registerBuiltinAttribute {
ref := by exact decl_name%
name := `bvNormalizeProcBuiltinAttr
descr := "Builtin bv_normalize simproc"
applicationTime := AttributeApplicationTime.afterCompilation
erase := fun _ => throwError "Not implemented yet, [-builtin_bv_normalize_proc]"
add := fun declName stx _ => addBuiltin declName stx ``addBVNormalizeProcBuiltinAttr
delab := fun _ => pure ()
}

end Frontend
Expand Down
6 changes: 6 additions & 0 deletions src/Lean/Elab/Tactic/Ext.lean
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,12 @@ builtin_initialize registerBuiltinAttribute {
-- Realize iff theorem
if iff then
discard <| liftCommandElabM <| withRef stx <| realizeExtIffTheorem declName
delab := fun decl => do
if let some {priority, .. : ExtTheorem} := (extExtension.getState (← getEnv)).tree.fold
(fun o _ v => o <|> guard (v.declName == decl) *> pure v) none
then
let prio := if priority = eval_prio default then none else some (Syntax.mkNatLit priority)
modify (·.push <| Unhygienic.run `(attr| ext $[$prio:num]?))
erase := fun declName => do
let s := extExtension.getState (← getEnv)
let s ← s.erase declName
Expand Down
3 changes: 3 additions & 0 deletions src/Lean/Elab/Term.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2158,6 +2158,9 @@ builtin_initialize
unless kind == AttributeKind.global do
throwError "invalid attribute 'builtin_incremental', must be global"
declareBuiltin decl <| mkApp (mkConst ``addBuiltinIncrementalElab) (toExpr decl)
delab := fun decl => do
if (← builtinIncrementalElabs.get).contains decl then
modify (·.push <| Unhygienic.run `(attr| builtin_incremental))
}

/-- Checks whether a declaration is annotated with `[builtin_incremental]` or `[incremental]`. -/
Expand Down
2 changes: 2 additions & 0 deletions src/Lean/KeyedDeclsAttribute.lean
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ protected unsafe def init {γ} (df : Def γ) (attrDeclName : Name := by exact de
ref := attrDeclName
name := df.builtinName
descr := "(builtin) " ++ df.descr
delab := fun _ => pure ()
add := fun declName stx kind => do
unless kind == AttributeKind.global do throwError "invalid attribute '{df.builtinName}', must be global"
let key ← df.evalKey true stx
Expand All @@ -145,6 +146,7 @@ protected unsafe def init {γ} (df : Def γ) (attrDeclName : Name := by exact de
let s := ext.getState (← getEnv)
let s ← s.erase df.name declName
modifyEnv fun env => ext.modifyState env fun _ => s
delab := fun _ => pure ()
add := fun declName stx attrKind => do
let key ← df.evalKey false stx
match IR.getSorryDep (← getEnv) declName with
Expand Down
3 changes: 3 additions & 0 deletions src/Lean/LabelAttribute.lean
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ registerBuiltinAttribute {
applicationTime := AttributeApplicationTime.afterCompilation
add := fun declName _ kind =>
ext.add declName kind
delab := fun declName => do
if (ext.getState (← getEnv)).contains declName then
modify (·.push <| Unhygienic.run `(attr| $(mkIdent attrName):ident))
erase := fun declName => do
let s := ext.getState (← getEnv)
modifyEnv fun env => ext.modifyState env fun _ => s.erase declName
Expand Down
5 changes: 5 additions & 0 deletions src/Lean/Linter/Deprecated.lean
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ builtin_initialize deprecatedAttr : ParametricAttribute DeprecationEntry ←
let text? := text?.map TSyntax.getString
let since? := since?.map TSyntax.getString
return { newName?, text?, since? }
delabParam := fun _ { newName?, text?, since? } => do
let id? ← newName?.mapM (mkIdent <$> unresolveNameGlobal ·)
let text? := text?.map Syntax.mkStrLit
let since? := since?.map Syntax.mkStrLit
modify (·.push <| Unhygienic.run `(attr| deprecated $(id?)? $(text?)? $[(since := $since?)]?))
}

def isDeprecated (env : Environment) (declName : Name) : Bool :=
Expand Down
Loading

0 comments on commit 9b8cd67

Please sign in to comment.