From 8d43b2e4d98caec3d5bd36b11e6d885b90e0bcf4 Mon Sep 17 00:00:00 2001 From: Kyle Miller Date: Sun, 17 Nov 2024 18:35:31 -0800 Subject: [PATCH] chore: document `Lean.Elab.StructInst`, refactor This PR does some mild refactoring of the `Lean.Elab.StructInst` module while adding documentation. Documentation is drawn from @thorimur's #1928. --- src/Lean/Elab/StructInst.lean | 484 ++++++++++++++++++++++------------ 1 file changed, 316 insertions(+), 168 deletions(-) diff --git a/src/Lean/Elab/StructInst.lean b/src/Lean/Elab/StructInst.lean index b97cf272a80c..9301428acc18 100644 --- a/src/Lean/Elab/StructInst.lean +++ b/src/Lean/Elab/StructInst.lean @@ -11,21 +11,40 @@ import Lean.Elab.App import Lean.Elab.Binders import Lean.PrettyPrinter +/-! +# Structure instance elaborator + +A *structure instance* is notation to construct a term of a `structure`. +Examples: `{ x := 2, y.z := true }`, `{ s with cache := c' }`, and `{ s with values[2] := v }`. +Structure instances are the preferred way to invoke a `structure`'s constructor, +since they hide Lean implementation details such as whether parents are represented as subobjects, +and also they do correct processing of default values, which are complicated due to the fact that `structure`s can override default values of their parents. + +This module elaborates structure instance notation. +Note that the `where` syntax to define structures (`Lean.Parser.Command.whereStructInst`) +macro expands into the structure instance notation elaborated by this module. +-/ + namespace Lean.Elab.Term.StructInst open Meta open TSyntax.Compat -/- - Structure instances are of the form: - - "{" >> optional (atomic (sepBy1 termParser ", " >> " with ")) - >> manyIndent (group ((structInstFieldAbbrev <|> structInstField) >> optional ", ")) - >> optEllipsis - >> optional (" : " >> termParser) - >> " }" +/-! +Recall that structure instances are of the form: +``` +"{" >> optional (atomic (sepBy1 termParser ", " >> " with ")) + >> manyIndent (group ((structInstFieldAbbrev <|> structInstField) >> optional ", ")) + >> optEllipsis + >> optional (" : " >> termParser) + >> " }" +``` -/ +/-- +Transforms structure instances such as `{ x := 0 : Foo }` into `({ x := 0 } : Foo)`. +Structure instance notation makes use of the expected type. +-/ @[builtin_macro Lean.Parser.Term.structInst] def expandStructInstExpectedType : Macro := fun stx => let expectedArg := stx[4] if expectedArg.isNone then @@ -35,7 +54,10 @@ open TSyntax.Compat let stxNew := stx.setArg 4 mkNullNode `(($stxNew : $expected)) -/-- Expand field abbreviations. Example: `{ x, y := 0 }` expands to `{ x := x, y := 0 }` -/ +/-- +Expands field abbreviation notation. +Example: `{ x, y := 0 }` expands to `{ x := x, y := 0 }`. +-/ @[builtin_macro Lean.Parser.Term.structInst] def expandStructInstFieldAbbrev : Macro | `({ $[$srcs,* with]? $fields,* $[..%$ell]? $[: $ty]? }) => if fields.getElems.raw.any (·.getKind == ``Lean.Parser.Term.structInstFieldAbbrev) then do @@ -49,9 +71,12 @@ open TSyntax.Compat | _ => Macro.throwUnsupported /-- - If `stx` is of the form `{ s₁, ..., sₙ with ... }` and `sᵢ` is not a local variable, expand into `let src := sᵢ; { ..., src, ... with ... }`. +If `stx` is of the form `{ s₁, ..., sₙ with ... }` and `sᵢ` is not a local variable, +expands into `let __src := sᵢ; { ..., __src, ... with ... }`. +The significance of `__src` is that the variable is treated as an implementation-detail local variable, +which can be unfolded by `simp` when `zetaDelta := false`. - Note that this one is not a `Macro` because we need to access the local context. +Note that this one is not a `Macro` because we need to access the local context. -/ private def expandNonAtomicExplicitSources (stx : Syntax) : TermElabM (Option Syntax) := do let sourcesOpt := stx[1] @@ -100,27 +125,44 @@ where let r ← go sources (sourcesNew.push sourceNew) `(let __src := $source; $r) -structure ExplicitSourceInfo where +/-- +An *explicit source* is one of the structures `sᵢ` that appear in `{ s₁, …, sₙ with … }`. +-/ +structure ExplicitSourceView where + /-- The syntax of the explicit source. -/ stx : Syntax + /-- The name of the structure for the type of the explicit source. -/ structName : Name deriving Inhabited -structure Source where - explicit : Array ExplicitSourceInfo -- `s₁ ... sₙ with` - implicit : Option Syntax -- `..` +/-- +A view of the sources of fields for the structure instance notation. +-/ +structure SourcesView where + /-- Explicit sources (i.e., one of the structures `sᵢ` that appear in `{ s₁, …, sₙ with … }`). -/ + explicit : Array ExplicitSourceView + /-- The syntax for a trailing `..`. This is "ellipsis mode" for missing fields, similar to ellipsis mode for applications. -/ + implicit : Option Syntax deriving Inhabited -def Source.isNone : Source → Bool +/-- Returns `true` if the structure instance has no sources (neither explicit sources nor a `..`). -/ +def SourcesView.isNone : SourcesView → Bool | { explicit := #[], implicit := none } => true | _ => false -/-- `optional (atomic (sepBy1 termParser ", " >> " with ")` -/ +/-- +Given an array of explicit sources, returns syntax of the form +`optional (atomic (sepBy1 termParser ", " >> " with ")` +-/ private def mkSourcesWithSyntax (sources : Array Syntax) : Syntax := let ref := sources[0]! let stx := Syntax.mkSep sources (mkAtomFrom ref ", ") mkNullNode #[stx, mkAtomFrom ref "with "] -private def getStructSource (structStx : Syntax) : TermElabM Source := +/-- +Creates a structure source view from structure instance notation. +-/ +private def getStructSources (structStx : Syntax) : TermElabM SourcesView := withRef structStx do let explicitSource := structStx[1] let implicitSource := structStx[3] @@ -138,10 +180,10 @@ private def getStructSource (structStx : Syntax) : TermElabM Source := return { explicit, implicit } /-- - We say a `{ ... }` notation is a `modifyOp` if it contains only one - ``` - def structInstArrayRef := leading_parser "[" >> termParser >>"]" - ``` +We say a structure instance notation is a "modifyOp" if it contains only a single array update. +```lean +def structInstArrayRef := leading_parser "[" >> termParser >>"]" +``` -/ private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do let s? ← stx[2].getSepArgs.foldlM (init := none) fun s? arg => do @@ -177,7 +219,11 @@ private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do | none => return none | some s => if s[0][0].getKind == ``Lean.Parser.Term.structInstArrayRef then return s? else return none -private def elabModifyOp (stx modifyOp : Syntax) (sources : Array ExplicitSourceInfo) (expectedType? : Option Expr) : TermElabM Expr := do +/-- +Given a `stx` that is a structure instance notation that's a modifyOp (according to `isModifyOp?`), elaborates it. +Only supports structure instances with a single source. +-/ +private def elabModifyOp (stx modifyOp : Syntax) (sources : Array ExplicitSourceView) (expectedType? : Option Expr) : TermElabM Expr := do if sources.size > 1 then throwError "invalid \{...} notation, multiple sources and array update is not supported." let cont (val : Syntax) : TermElabM Expr := do @@ -204,12 +250,13 @@ private def elabModifyOp (stx modifyOp : Syntax) (sources : Array ExplicitSource cont val /-- - Get structure name. - This method triest to postpone execution if the expected type is not available. +Gets the structure name for the structure instance from the expected type and the sources. +This method tries to postpone execution if the expected type is not available. - If the expected type is available and it is a structure, then we use it. - Otherwise, we use the type of the first source. -/ -private def getStructName (expectedType? : Option Expr) (sourceView : Source) : TermElabM Name := do +If the expected type is available and it is a structure, then we use it. +Otherwise, we use the type of the first source. +-/ +private def getStructName (expectedType? : Option Expr) (sourceView : SourcesView) : TermElabM Name := do tryPostponeIfNoneOrMVar expectedType? let useSource : Unit → TermElabM Name := fun _ => do unless sourceView.explicit.isEmpty do @@ -226,7 +273,7 @@ private def getStructName (expectedType? : Option Expr) (sourceView : Source) : unless isStructure (← getEnv) constName do throwError "invalid \{...} notation, structure type expected{indentExpr expectedType}" return constName - | _ => useSource () + | _ => useSource () where throwUnknownExpectedType := throwError "invalid \{...} notation, expected type is not known" @@ -237,72 +284,92 @@ where else throwError "invalid \{...} notation, {kind} type is not of the form (C ...){indentExpr type}" +/-- +A component of a left-hand side for a field appearing in structure instance syntax. +-/ inductive FieldLHS where + /-- A name component for a field left-hand side. For example, `x` and `y` in `{ x.y := v }`. -/ | fieldName (ref : Syntax) (name : Name) + /-- A numeric index component for a field left-hand side. For example `3` in `{ x.3 := v }`. -/ | fieldIndex (ref : Syntax) (idx : Nat) + /-- An array indexing component for a field left-hand side. For example `[3]` in `{ arr[3] := v }`. -/ | modifyOp (ref : Syntax) (index : Syntax) deriving Inhabited -instance : ToFormat FieldLHS := ⟨fun lhs => - match lhs with - | .fieldName _ n => format n - | .fieldIndex _ i => format i - | .modifyOp _ i => "[" ++ i.prettyPrint ++ "]"⟩ +instance : ToFormat FieldLHS where + format + | .fieldName _ n => format n + | .fieldIndex _ i => format i + | .modifyOp _ i => "[" ++ i.prettyPrint ++ "]" +/-- +`FieldVal StructInstView` is a representation of a field value in the structure instance. +-/ inductive FieldVal (σ : Type) where - | term (stx : Syntax) : FieldVal σ + /-- A `term` to use for the value of the field. -/ + | term (stx : Syntax) : FieldVal σ + /-- A `StructInstView` to use for the value of a subobject field. -/ | nested (s : σ) : FieldVal σ - | default : FieldVal σ -- mark that field must be synthesized using default value + /-- A field that was not provided and should be synthesized using default values. -/ + | default : FieldVal σ deriving Inhabited +/-- +`Field StructInstView` is a representation of a field in the structure instance. +-/ structure Field (σ : Type) where + /-- The whole field syntax. -/ ref : Syntax + /-- The LHS decomposed into components. -/ lhs : List FieldLHS + /-- The value of the field. -/ val : FieldVal σ + /-- The elaborated field value, filled in at `elabStruct`. + Missing fields use a metavariable for the elaborated value and are later solved for in `DefaultFields.propagate`. -/ expr? : Option Expr := none deriving Inhabited +/-- +Returns if the field has a single component in its LHS. +-/ def Field.isSimple {σ} : Field σ → Bool | { lhs := [_], .. } => true | _ => false -inductive Struct where - /-- Remark: the field `params` is use for default value propagation. It is initially empty, and then set at `elabStruct`. -/ - | mk (ref : Syntax) (structName : Name) (params : Array (Name × Expr)) (fields : List (Field Struct)) (source : Source) +/-- +The view for structure instance notation. +-/ +structure StructInstView where + /-- The syntax for the whole structure instance. -/ + ref : Syntax + /-- The name of the structure for the type of the structure instance. -/ + structName : Name + /-- Used for default values, to propagate structure type parameters. It is initially empty, and then set at `elabStruct`. -/ + params : Array (Name × Expr) + /-- The fields of the structure instance. -/ + fields : List (Field StructInstView) + /-- The additional sources for fields for the structure instance. -/ + sources : SourcesView deriving Inhabited -abbrev Fields := List (Field Struct) - -def Struct.ref : Struct → Syntax - | ⟨ref, _, _, _, _⟩ => ref - -def Struct.structName : Struct → Name - | ⟨_, structName, _, _, _⟩ => structName - -def Struct.params : Struct → Array (Name × Expr) - | ⟨_, _, params, _, _⟩ => params - -def Struct.fields : Struct → Fields - | ⟨_, _, _, fields, _⟩ => fields - -def Struct.source : Struct → Source - | ⟨_, _, _, _, s⟩ => s +/-- Abbreviation for the type of `StructInstView.fields`, namely `List (Field StructInstView)`. -/ +abbrev Fields := List (Field StructInstView) /-- `true` iff all fields of the given structure are marked as `default` -/ -partial def Struct.allDefault (s : Struct) : Bool := +partial def StructInstView.allDefault (s : StructInstView) : Bool := s.fields.all fun { val := val, .. } => match val with | .term _ => false | .default => true | .nested s => allDefault s -def formatField (formatStruct : Struct → Format) (field : Field Struct) : Format := +def formatField (formatStruct : StructInstView → Format) (field : Field StructInstView) : Format := Format.joinSep field.lhs " . " ++ " := " ++ match field.val with | .term v => v.prettyPrint | .nested s => formatStruct s | .default => "" -partial def formatStruct : Struct → Format +partial def formatStruct : StructInstView → Format | ⟨_, _, _, fields, source⟩ => let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", " let implicitFmt := if source.implicit.isSome then " .. " else "" @@ -311,31 +378,39 @@ partial def formatStruct : Struct → Format else "{" ++ format (source.explicit.map (·.stx)) ++ " with " ++ fieldsFmt ++ implicitFmt ++ "}" -instance : ToFormat Struct := ⟨formatStruct⟩ -instance : ToString Struct := ⟨toString ∘ format⟩ +instance : ToFormat StructInstView := ⟨formatStruct⟩ +instance : ToString StructInstView := ⟨toString ∘ format⟩ -instance : ToFormat (Field Struct) := ⟨formatField formatStruct⟩ -instance : ToString (Field Struct) := ⟨toString ∘ format⟩ +instance : ToFormat (Field StructInstView) := ⟨formatField formatStruct⟩ +instance : ToString (Field StructInstView) := ⟨toString ∘ format⟩ + +/-- +Converts a `FieldLHS` back into syntax. This assumes the `ref` fields have the correct structure. -/- Recall that `structInstField` elements have the form -``` - def structInstField := leading_parser structInstLVal >> " := " >> termParser - def structInstLVal := leading_parser (ident <|> numLit <|> structInstArrayRef) >> many (("." >> (ident <|> numLit)) <|> structInstArrayRef) - def structInstArrayRef := leading_parser "[" >> termParser >>"]" +```lean +def structInstField := leading_parser structInstLVal >> " := " >> termParser +def structInstLVal := leading_parser (ident <|> numLit <|> structInstArrayRef) >> many (("." >> (ident <|> numLit)) <|> structInstArrayRef) +def structInstArrayRef := leading_parser "[" >> termParser >>"]" ``` -/ -- Remark: this code relies on the fact that `expandStruct` only transforms `fieldLHS.fieldName` -def FieldLHS.toSyntax (first : Bool) : FieldLHS → Syntax +private def FieldLHS.toSyntax (first : Bool) : FieldLHS → Syntax | .modifyOp stx _ => stx | .fieldName stx name => if first then mkIdentFrom stx name else mkGroupNode #[mkAtomFrom stx ".", mkIdentFrom stx name] | .fieldIndex stx _ => if first then stx else mkGroupNode #[mkAtomFrom stx ".", stx] -def FieldVal.toSyntax : FieldVal Struct → Syntax +/-- +Converts a `FieldVal StructInstView` back into syntax. Only supports `.term`, and it assumes the `stx` field has the correct structure. +-/ +private def FieldVal.toSyntax : FieldVal Struct → Syntax | .term stx => stx - | _ => unreachable! + | _ => unreachable! -def Field.toSyntax : Field Struct → Syntax +/-- +Converts a `Field StructInstView` back into syntax. Used to construct synthetic structure instance notation for subobjects in `StructInst.expandStruct` processing. +-/ +private def Field.toSyntax : Field Struct → Syntax | field => let stx := field.ref let stx := stx.setArg 2 field.val.toSyntax @@ -343,6 +418,7 @@ def Field.toSyntax : Field Struct → Syntax | first::rest => stx.setArg 0 <| mkNullNode #[first.toSyntax true, mkNullNode <| rest.toArray.map (FieldLHS.toSyntax false) ] | _ => unreachable! +/-- Creates a view of a field left-hand side. -/ private def toFieldLHS (stx : Syntax) : MacroM FieldLHS := if stx.getKind == ``Lean.Parser.Term.structInstArrayRef then return FieldLHS.modifyOp stx stx[1] @@ -355,7 +431,12 @@ private def toFieldLHS (stx : Syntax) : MacroM FieldLHS := | some idx => return FieldLHS.fieldIndex stx idx | none => Macro.throwError "unexpected structure syntax" -private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : MacroM Struct := do +/-- +Creates a structure instance view from structure instance notation +and the computed structure name (from `Lean.Elab.Term.StructInst.getStructName`) +and structure source view (from `Lean.Elab.Term.StructInst.getStructSources`). +-/ +private def mkStructView (stx : Syntax) (structName : Name) (sources : SourcesView) : MacroM StructInstView := do /- Recall that `stx` is of the form ``` leading_parser "{" >> optional (atomic (sepBy1 termParser ", " >> " with ")) @@ -371,24 +452,18 @@ private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : let val := fieldStx[2] let first ← toFieldLHS fieldStx[0][0] let rest ← fieldStx[0][1].getArgs.toList.mapM toFieldLHS - return { ref := fieldStx, lhs := first :: rest, val := FieldVal.term val : Field Struct } - return ⟨stx, structName, #[], fields, source⟩ + return { ref := fieldStx, lhs := first :: rest, val := FieldVal.term val : Field StructInstView } + return { ref := stx, structName, params := #[], fields, sources } -def Struct.modifyFieldsM {m : Type → Type} [Monad m] (s : Struct) (f : Fields → m Fields) : m Struct := +def StructInstView.modifyFieldsM {m : Type → Type} [Monad m] (s : StructInstView) (f : Fields → m Fields) : m StructInstView := match s with - | ⟨ref, structName, params, fields, source⟩ => return ⟨ref, structName, params, (← f fields), source⟩ + | { ref, structName, params, fields, sources } => return { ref, structName, params, fields := (← f fields), sources } -def Struct.modifyFields (s : Struct) (f : Fields → Fields) : Struct := +def StructInstView.modifyFields (s : StructInstView) (f : Fields → Fields) : StructInstView := Id.run <| s.modifyFieldsM f -def Struct.setFields (s : Struct) (fields : Fields) : Struct := - s.modifyFields fun _ => fields - -def Struct.setParams (s : Struct) (ps : Array (Name × Expr)) : Struct := - match s with - | ⟨ref, structName, _, fields, source⟩ => ⟨ref, structName, ps, fields, source⟩ - -private def expandCompositeFields (s : Struct) : Struct := +/-- Expands name field LHSs with multi-component names into multi-component LHSs. -/ +private def expandCompositeFields (s : StructInstView) : StructInstView := s.modifyFields fun fields => fields.map fun field => match field with | { lhs := .fieldName _ (.str Name.anonymous ..) :: _, .. } => field | { lhs := .fieldName ref n@(.str ..) :: rest, .. } => @@ -396,7 +471,8 @@ private def expandCompositeFields (s : Struct) : Struct := { field with lhs := newEntries ++ rest } | _ => field -private def expandNumLitFields (s : Struct) : TermElabM Struct := +/-- Replaces numeric index field LHSs with the corresponding named field, or throws an error if no such field exists. -/ +private def expandNumLitFields (s : StructInstView) : TermElabM StructInstView := s.modifyFieldsM fun fields => do let env ← getEnv let fieldNames := getStructureFields env s.structName @@ -407,28 +483,31 @@ private def expandNumLitFields (s : Struct) : TermElabM Struct := else return { field with lhs := .fieldName ref fieldNames[idx - 1]! :: rest } | _ => return field -/-- For example, consider the following structures: - ``` - structure A where - x : Nat - - structure B extends A where - y : Nat - - structure C extends B where - z : Bool - ``` - This method expands parent structure fields using the path to the parent structure. - For example, - ``` - { x := 0, y := 0, z := true : C } - ``` - is expanded into - ``` - { toB.toA.x := 0, toB.y := 0, z := true : C } - ``` --/ -private def expandParentFields (s : Struct) : TermElabM Struct := do +/-- +Expands fields that are actually represented as fields of subobject fields. + +For example, consider the following structures: +``` +structure A where + x : Nat + +structure B extends A where + y : Nat + +structure C extends B where + z : Bool +``` +This method expands parent structure fields using the path to the parent structure. +For example, +``` +{ x := 0, y := 0, z := true : C } +``` +is expanded into +``` +{ toB.toA.x := 0, toB.y := 0, z := true : C } +``` +-/ +private def expandParentFields (s : StructInstView) : TermElabM StructInstView := do let env ← getEnv s.modifyFieldsM fun fields => fields.mapM fun field => do match field with | { lhs := .fieldName ref fieldName :: _, .. } => @@ -448,6 +527,11 @@ private def expandParentFields (s : Struct) : TermElabM Struct := do private abbrev FieldMap := Std.HashMap Name Fields +/-- +Creates a hash map collecting all fields with the same first name component. +Throws an error if there are multiple simple fields with the same name. +Used by `StructInst.expandStruct` processing. +-/ private def mkFieldMap (fields : Fields) : TermElabM FieldMap := fields.foldlM (init := {}) fun fieldMap field => match field.lhs with @@ -461,15 +545,16 @@ private def mkFieldMap (fields : Fields) : TermElabM FieldMap := | _ => return fieldMap.insert fieldName [field] | _ => unreachable! -private def isSimpleField? : Fields → Option (Field Struct) +/-- +Given a value of the hash map created by `mkFieldMap`, returns true if the value corresponds to a simple field. +-/ +private def isSimpleField? : Fields → Option (Field StructInstView) | [field] => if field.isSimple then some field else none | _ => none -private def getFieldIdx (structName : Name) (fieldNames : Array Name) (fieldName : Name) : TermElabM Nat := do - match fieldNames.findIdx? fun n => n == fieldName with - | some idx => return idx - | none => throwError "field '{fieldName}' is not a valid field of '{structName}'" - +/-- +Creates projection notation for the given structure field. Used +-/ def mkProjStx? (s : Syntax) (structName : Name) (fieldName : Name) : TermElabM (Option Syntax) := do if (findField? (← getEnv) structName fieldName).isNone then return none @@ -478,7 +563,10 @@ def mkProjStx? (s : Syntax) (structName : Name) (fieldName : Name) : TermElabM ( #[mkAtomFrom s "@", mkNode ``Parser.Term.proj #[s, mkAtomFrom s ".", mkIdentFrom s fieldName]] -def findField? (fields : Fields) (fieldName : Name) : Option (Field Struct) := +/-- +Finds a simple field of the given name. +-/ +def findField? (fields : Fields) (fieldName : Name) : Option (Field StructInstView) := fields.find? fun field => match field.lhs with | [.fieldName _ n] => n == fieldName @@ -486,7 +574,10 @@ def findField? (fields : Fields) (fieldName : Name) : Option (Field Struct) := mutual - private partial def groupFields (s : Struct) : TermElabM Struct := do + /-- + Groups compound fields according to which subobject they are from. + -/ + private partial def groupFields (s : StructInstView) : TermElabM StructInstView := do let env ← getEnv withRef s.ref do s.modifyFieldsM fun fields => do @@ -499,14 +590,14 @@ mutual let field := fields.head! match Lean.isSubobjectField? env s.structName fieldName with | some substructName => - let substruct := Struct.mk s.ref substructName #[] substructFields s.source + let substruct := { ref := s.ref, structName := substructName, params := #[], fields := substructFields, sources := s.sources } let substruct ← expandStruct substruct pure { field with lhs := [field.lhs.head!], val := FieldVal.nested substruct } | none => let updateSource (structStx : Syntax) : TermElabM Syntax := do - let sourcesNew ← s.source.explicit.filterMapM fun source => mkProjStx? source.stx source.structName fieldName + let sourcesNew ← s.sources.explicit.filterMapM fun source => mkProjStx? source.stx source.structName fieldName let explicitSourceStx := if sourcesNew.isEmpty then mkNullNode else mkSourcesWithSyntax sourcesNew - let implicitSourceStx := s.source.implicit.getD mkNullNode + let implicitSourceStx := s.sources.implicit.getD mkNullNode return (structStx.setArg 1 explicitSourceStx).setArg 3 implicitSourceStx let valStx := s.ref -- construct substructure syntax using s.ref as template let valStx := valStx.setArg 4 mkNullNode -- erase optional expected type @@ -518,7 +609,7 @@ mutual Adds in the missing fields using the explicit sources. Invariant: a missing field always comes from the first source that can provide it. -/ - private partial def addMissingFields (s : Struct) : TermElabM Struct := do + private partial def addMissingFields (s : StructInstView) : TermElabM StructInstView := do let env ← getEnv let fieldNames := getStructureFields env s.structName let ref := s.ref.mkSynthetic @@ -527,7 +618,7 @@ mutual match findField? s.fields fieldName with | some field => return field::fields | none => - let addField (val : FieldVal Struct) : TermElabM Fields := do + let addField (val : FieldVal StructInstView) : TermElabM Fields := do return { ref, lhs := [FieldLHS.fieldName ref fieldName], val := val } :: fields match Lean.isSubobjectField? env s.structName fieldName with | some substructName => @@ -535,8 +626,8 @@ mutual let downFields := getStructureFieldsFlattened env substructName false -- Filter out all explicit sources that do not share a leaf field keeping -- structure with no fields - let filtered := s.source.explicit.filter fun source => - let sourceFields := getStructureFieldsFlattened env source.structName false + let filtered := s.sources.explicit.filter fun sources => + let sourceFields := getStructureFieldsFlattened env sources.structName false sourceFields.any (fun name => downFields.contains name) || sourceFields.isEmpty -- Take the first such one remaining match filtered[0]? with @@ -550,27 +641,30 @@ mutual -- No sources could provide this subobject in the proper order. -- Recurse to handle default values for fields. else - let substruct := Struct.mk ref substructName #[] [] s.source + let substruct := { ref, structName := substructName, params := #[], fields := [], sources := s.sources } let substruct ← expandStruct substruct addField (FieldVal.nested substruct) -- No sources could provide this subobject. -- Recurse to handle default values for fields. | none => - let substruct := Struct.mk ref substructName #[] [] s.source + let substruct := { ref, structName := substructName, params := #[], fields := [], sources := s.sources } let substruct ← expandStruct substruct addField (FieldVal.nested substruct) -- Since this is not a subobject field, we are free to use the first source that can -- provide it. | none => - if let some val ← s.source.explicit.findSomeM? fun source => mkProjStx? source.stx source.structName fieldName then + if let some val ← s.sources.explicit.findSomeM? fun source => mkProjStx? source.stx source.structName fieldName then addField (FieldVal.term val) - else if s.source.implicit.isSome then + else if s.sources.implicit.isSome then addField (FieldVal.term (mkHole ref)) else addField FieldVal.default - return s.setFields fields.reverse + return { s with fields := fields.reverse } - private partial def expandStruct (s : Struct) : TermElabM Struct := do + /-- + Expands all fields of the structure instance, consolidates compound fields into subobject fields, and adds missing fields. + -/ + private partial def expandStruct (s : StructInstView) : TermElabM StructInstView := do let s := expandCompositeFields s let s ← expandNumLitFields s let s ← expandParentFields s @@ -579,10 +673,17 @@ mutual end +/-- +The constructor to use for the structure instance notation. +-/ structure CtorHeaderResult where + /-- The constructor function with applied structure parameters. -/ ctorFn : Expr + /-- The type of `ctorFn` -/ ctorFnType : Expr + /-- Instance metavariables for structure parameters that are instance implicit. -/ instMVars : Array MVarId + /-- Type parameter names and metavariables for each parameter. Used to seed `StructInstView.params`. -/ params : Array (Name × Expr) private def mkCtorHeaderAux : Nat → Expr → Expr → Array MVarId → Array (Name × Expr) → TermElabM CtorHeaderResult @@ -604,6 +705,7 @@ private partial def getForallBody : Nat → Expr → Option Expr | _+1, _ => none | 0, type => type +/-- Attempts to use the expected type to solve for structure parameters. -/ private def propagateExpectedType (type : Expr) (numFields : Nat) (expectedType? : Option Expr) : TermElabM Unit := do match expectedType? with | none => return () @@ -614,6 +716,7 @@ private def propagateExpectedType (type : Expr) (numFields : Nat) (expectedType? unless typeBody.hasLooseBVars do discard <| isDefEq expectedType typeBody +/-- Elaborates the structure constructor using the expected type, filling in all structure parameters. -/ private def mkCtorHeader (ctorVal : ConstructorVal) (expectedType? : Option Expr) : TermElabM CtorHeaderResult := do let us ← mkFreshLevelMVars ctorVal.levelParams.length let val := Lean.mkConst ctorVal.name us @@ -623,32 +726,43 @@ private def mkCtorHeader (ctorVal : ConstructorVal) (expectedType? : Option Expr synthesizeAppInstMVars r.instMVars r.ctorFn return r +/-- Annotates an expression that it is a value for a missing field. -/ def markDefaultMissing (e : Expr) : Expr := mkAnnotation `structInstDefault e +/-- If the expression has been annotated by `markDefaultMissing`, returns the unannotated expression. -/ def defaultMissing? (e : Expr) : Option Expr := annotation? `structInstDefault e +/-- Throws "failed to elaborate field" error. -/ def throwFailedToElabField {α} (fieldName : Name) (structName : Name) (msgData : MessageData) : TermElabM α := throwError "failed to elaborate field '{fieldName}' of '{structName}, {msgData}" -def trySynthStructInstance? (s : Struct) (expectedType : Expr) : TermElabM (Option Expr) := do +/-- If the struct has all-missing fields, tries to synthesize the structure using typeclass inference. -/ +def trySynthStructInstance? (s : StructInstView) (expectedType : Expr) : TermElabM (Option Expr) := do if !s.allDefault then return none else try synthInstance? expectedType catch _ => return none +/-- The result of elaborating a `StructInstView` structure instance view. -/ structure ElabStructResult where + /-- The elaborated value. -/ val : Expr - struct : Struct + /-- The modified `StructInstView` view after elaboration. -/ + struct : StructInstView + /-- Metavariables for instance implicit fields. These will be registered after default value propagation. -/ instMVars : Array MVarId -private partial def elabStruct (s : Struct) (expectedType? : Option Expr) : TermElabM ElabStructResult := withRef s.ref do +/-- +Main elaborator for structure instances. +-/ +private partial def elabStructInstView (s : StructInstView) (expectedType? : Option Expr) : TermElabM ElabStructResult := withRef s.ref do let env ← getEnv let ctorVal := getStructureCtor env s.structName if isPrivateNameFromImportedModule env ctorVal.name then throwError "invalid \{...} notation, constructor for `{s.structName}` is marked as private" - -- We store the parameters at the resulting `Struct`. We use this information during default value propagation. + -- We store the parameters at the resulting `StructInstView`. We use this information during default value propagation. let { ctorFn, ctorFnType, params, .. } ← mkCtorHeader ctorVal expectedType? let (e, _, fields, instMVars) ← s.fields.foldlM (init := (ctorFn, ctorFnType, [], #[])) fun (e, type, fields, instMVars) field => do match field.lhs with @@ -657,7 +771,7 @@ private partial def elabStruct (s : Struct) (expectedType? : Option Expr) : Term trace[Elab.struct] "elabStruct {field}, {type}" match type with | .forallE _ d b bi => - let cont (val : Expr) (field : Field Struct) (instMVars := instMVars) : TermElabM (Expr × Expr × Fields × Array MVarId) := do + let cont (val : Expr) (field : Field StructInstView) (instMVars := instMVars) : TermElabM (Expr × Expr × Fields × Array MVarId) := do pushInfoTree <| InfoTree.node (children := {}) <| Info.ofFieldInfo { projName := s.structName.append fieldName, fieldName, lctx := (← getLCtx), val, stx := ref } let e := mkApp e val @@ -671,7 +785,7 @@ private partial def elabStruct (s : Struct) (expectedType? : Option Expr) : Term match (← trySynthStructInstance? s d) with | some val => cont val { field with val := FieldVal.term (mkHole field.ref) } | none => - let { val, struct := sNew, instMVars := instMVarsNew } ← elabStruct s (some d) + let { val, struct := sNew, instMVars := instMVarsNew } ← elabStructInstView s (some d) let val ← ensureHasType d val cont val { field with val := FieldVal.nested sNew } (instMVars ++ instMVarsNew) | .default => @@ -700,17 +814,21 @@ private partial def elabStruct (s : Struct) (expectedType? : Option Expr) : Term cont (markDefaultMissing val) field | _ => withRef field.ref <| throwFailedToElabField fieldName s.structName m!"unexpected constructor type{indentExpr type}" | _ => throwErrorAt field.ref "unexpected unexpanded structure field" - return { val := e, struct := s.setFields fields.reverse |>.setParams params, instMVars } + return { val := e, struct := { s with fields := fields.reverse, params }, instMVars } namespace DefaultFields +/-- +Context for default value propagation. +-/ structure Context where - -- We must search for default values overridden in derived structures - structs : Array Struct := #[] + /-- The current path through `.nested` subobject structures. We must search for default values overridden in derived structures. -/ + structs : Array StructInstView := #[] + /-- The collection of structures that could provide a default value. -/ allStructNames : Array Name := #[] /-- Consider the following example: - ``` + ```lean structure A where x : Nat := 1 @@ -736,22 +854,29 @@ structure Context where -/ maxDistance : Nat := 0 +/-- +State for default value propagation +-/ structure State where + /-- Whether progress has been made so far on this round of the propagation loop. -/ progress : Bool := false -partial def collectStructNames (struct : Struct) (names : Array Name) : Array Name := +/-- Collects all structures that may provide default values for fields. -/ +partial def collectStructNames (struct : StructInstView) (names : Array Name) : Array Name := let names := names.push struct.structName struct.fields.foldl (init := names) fun names field => match field.val with | .nested struct => collectStructNames struct names | _ => names -partial def getHierarchyDepth (struct : Struct) : Nat := +/-- Gets the maximum nesting depth of subobjects. -/ +partial def getHierarchyDepth (struct : StructInstView) : Nat := struct.fields.foldl (init := 0) fun max field => match field.val with | .nested struct => Nat.max max (getHierarchyDepth struct + 1) | _ => max +/-- Returns whether the field is still missing. -/ def isDefaultMissing? [Monad m] [MonadMCtx m] (field : Field Struct) : m Bool := do if let some expr := field.expr? then if let some (.mvar mvarId) := defaultMissing? expr then @@ -759,40 +884,51 @@ def isDefaultMissing? [Monad m] [MonadMCtx m] (field : Field Struct) : m Bool := return true return false -partial def findDefaultMissing? [Monad m] [MonadMCtx m] (struct : Struct) : m (Option (Field Struct)) := +/-- Returns a field that is still missing. -/ +partial def findDefaultMissing? [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Option (Field StructInstView)) := struct.fields.findSomeM? fun field => do match field.val with | .nested struct => findDefaultMissing? struct | _ => return if (← isDefaultMissing? field) then field else none -partial def allDefaultMissing [Monad m] [MonadMCtx m] (struct : Struct) : m (Array (Field Struct)) := +/-- Returns all fields that are still missing. -/ +partial def allDefaultMissing [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Array (Field StructInstView)) := go struct *> get |>.run' #[] where - go (struct : Struct) : StateT (Array (Field Struct)) m Unit := + go (struct : StructInstView) : StateT (Array (Field StructInstView)) m Unit := for field in struct.fields do if let .nested struct := field.val then go struct else if (← isDefaultMissing? field) then modify (·.push field) -def getFieldName (field : Field Struct) : Name := +/-- Returns the name of the field. Assumes all fields under consideration are simple and named. -/ +def getFieldName (field : Field StructInstView) : Name := match field.lhs with | [.fieldName _ fieldName] => fieldName | _ => unreachable! abbrev M := ReaderT Context (StateRefT State TermElabM) +/-- Returns whether we should interrupt the round because we have made progress allowing nonzero depth. -/ def isRoundDone : M Bool := do return (← get).progress && (← read).maxDistance > 0 -def getFieldValue? (struct : Struct) (fieldName : Name) : Option Expr := +/-- Returns the `expr?` for the given field. -/ +def getFieldValue? (struct : StructInstView) (fieldName : Name) : Option Expr := struct.fields.findSome? fun field => if getFieldName field == fieldName then field.expr? else none -partial def mkDefaultValueAux? (struct : Struct) : Expr → TermElabM (Option Expr) +/-- Instantiates a default value from the given default value declaration, if applicable. -/ +partial def mkDefaultValue? (struct : StructInstView) (cinfo : ConstantInfo) : TermElabM (Option Expr) := + withRef struct.ref do + let us ← mkFreshLevelMVarsFor cinfo + process (← instantiateValueLevelParams cinfo us) +where + process : Expr → TermElabM (Option Expr) | .lam n d b c => withRef struct.ref do if c.isExplicit then let fieldName := n @@ -801,29 +937,26 @@ partial def mkDefaultValueAux? (struct : Struct) : Expr → TermElabM (Option Ex | some val => let valType ← inferType val if (← isDefEq valType d) then - mkDefaultValueAux? struct (b.instantiate1 val) + process (b.instantiate1 val) else return none else if let some (_, param) := struct.params.find? fun (paramName, _) => paramName == n then -- Recall that we did not use to have support for parameter propagation here. if (← isDefEq (← inferType param) d) then - mkDefaultValueAux? struct (b.instantiate1 param) + process (b.instantiate1 param) else return none else let arg ← mkFreshExprMVar d - mkDefaultValueAux? struct (b.instantiate1 arg) + process (b.instantiate1 arg) | e => let_expr id _ a := e | return some e return some a -def mkDefaultValue? (struct : Struct) (cinfo : ConstantInfo) : TermElabM (Option Expr) := - withRef struct.ref do - let us ← mkFreshLevelMVarsFor cinfo - mkDefaultValueAux? struct (← instantiateValueLevelParams cinfo us) - -/-- Reduce default value. It performs beta reduction and projections of the given structures. -/ +/-- +Reduces a default value. It performs beta reduction and projections of the given structures to reduce them to the provided values for fields. +-/ partial def reduce (structNames : Array Name) (e : Expr) : MetaM Expr := do match e with | .forallE .. => @@ -880,7 +1013,10 @@ where else k -partial def tryToSynthesizeDefault (structs : Array Struct) (allStructNames : Array Name) (maxDistance : Nat) (fieldName : Name) (mvarId : MVarId) : TermElabM Bool := +/-- +Attempts to synthesize a default value for a missing field `fieldName` using default values from each structure in `structs`. +-/ +def tryToSynthesizeDefault (structs : Array StructInstView) (allStructNames : Array Name) (maxDistance : Nat) (fieldName : Name) (mvarId : MVarId) : TermElabM Bool := let rec loop (i : Nat) (dist : Nat) := do if dist > maxDistance then return false @@ -915,7 +1051,10 @@ partial def tryToSynthesizeDefault (structs : Array Struct) (allStructNames : Ar return false loop 0 0 -partial def step (struct : Struct) : M Unit := +/-- +Performs one step of default value synthesis. +-/ +partial def step (struct : StructInstView) : M Unit := unless (← isRoundDone) do withReader (fun ctx => { ctx with structs := ctx.structs.push struct }) do for field in struct.fields do @@ -932,7 +1071,10 @@ partial def step (struct : Struct) : M Unit := modify fun _ => { progress := true } | _ => pure () -partial def propagateLoop (hierarchyDepth : Nat) (d : Nat) (struct : Struct) : M Unit := do +/-- +Main entry point to default value synthesis in the `M` monad. +-/ +partial def propagateLoop (hierarchyDepth : Nat) (d : Nat) (struct : StructInstView) : M Unit := do match (← findDefaultMissing? struct) with | none => return () -- Done | some field => @@ -955,16 +1097,22 @@ partial def propagateLoop (hierarchyDepth : Nat) (d : Nat) (struct : Struct) : M else propagateLoop hierarchyDepth (d+1) struct -def propagate (struct : Struct) : TermElabM Unit := +/-- +Synthesizes default values for all missing fields, if possible. +-/ +def propagate (struct : StructInstView) : TermElabM Unit := let hierarchyDepth := getHierarchyDepth struct let structNames := collectStructNames struct #[] propagateLoop hierarchyDepth 0 struct { allStructNames := structNames } |>.run' {} end DefaultFields -private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (source : Source) : TermElabM Expr := do - let structName ← getStructName expectedType? source - let struct ← liftMacroM <| mkStructView stx structName source +/-- +Main entry point to elaborator for structure instance notation, unless the structure instance is a modifyOp. +-/ +private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (sources : SourcesView) : TermElabM Expr := do + let structName ← getStructName expectedType? sources + let struct ← liftMacroM <| mkStructView stx structName sources let struct ← expandStruct struct trace[Elab.struct] "{struct}" /- We try to synthesize pending problems with `withSynthesize` combinator before trying to use default values. @@ -982,7 +1130,7 @@ private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (sour TODO: investigate whether this design decision may have unintended side effects or produce confusing behavior. -/ - let { val := r, struct, instMVars } ← withSynthesize (postpone := .yes) <| elabStruct struct expectedType? + let { val := r, struct, instMVars } ← withSynthesize (postpone := .yes) <| elabStructInstView struct expectedType? trace[Elab.struct] "before propagate {r}" DefaultFields.propagate struct synthesizeAppInstMVars instMVars r @@ -992,13 +1140,13 @@ private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (sour match (← expandNonAtomicExplicitSources stx) with | some stxNew => withMacroExpansion stx stxNew <| elabTerm stxNew expectedType? | none => - let sourceView ← getStructSource stx + let sourcesView ← getStructSources stx if let some modifyOp ← isModifyOp? stx then - if sourceView.explicit.isEmpty then + if sourcesView.explicit.isEmpty then throwError "invalid \{...} notation, explicit source is required when using '[] := '" - elabModifyOp stx modifyOp sourceView.explicit expectedType? + elabModifyOp stx modifyOp sourcesView.explicit expectedType? else - elabStructInstAux stx expectedType? sourceView + elabStructInstAux stx expectedType? sourcesView builtin_initialize registerTraceClass `Elab.struct