diff --git a/src/Lean/Compiler/LCNF/MonoTypes.lean b/src/Lean/Compiler/LCNF/MonoTypes.lean index ce366b187d23..88b65701d852 100644 --- a/src/Lean/Compiler/LCNF/MonoTypes.lean +++ b/src/Lean/Compiler/LCNF/MonoTypes.lean @@ -72,21 +72,23 @@ The type contains only `→` and constants. -/ partial def toMonoType (type : Expr) : CoreM Expr := do let type := type.headBeta - if type.isErased then - return erasedExpr - else if isTypeFormerType type then - return erasedExpr - else match type with - | .const .. => visitApp type #[] - | .app .. => type.withApp visitApp - | .forallE _ d b _ => mkArrow (← toMonoType d) (← toMonoType (b.instantiate1 erasedExpr)) - | _ => return erasedExpr + match type with + | .const .. => visitApp type #[] + | .app .. => type.withApp visitApp + | .forallE _ d b _ => + let monoB ← toMonoType (b.instantiate1 anyExpr) + match monoB with + | .const ``lcErased _ => return erasedExpr + | _ => mkArrow (← toMonoType d) monoB + | .sort _ => return erasedExpr + | _ => return anyExpr where visitApp (f : Expr) (args : Array Expr) : CoreM Expr := do match f with + | .const ``lcErased _ => return erasedExpr + | .const ``lcAny _ => return anyExpr + | .const ``Decidable _ => return mkConst ``Bool | .const declName us => - if declName == ``Decidable then - return mkConst ``Bool if let some info ← hasTrivialStructure? declName then let ctorType ← getOtherDeclBaseType info.ctorName [] toMonoType (getParamTypes (← instantiateForall ctorType args[:info.numParams]))[info.fieldIdx]! @@ -96,15 +98,13 @@ where for arg in args do let .forallE _ d b _ := type.headBeta | unreachable! let arg := arg.headBeta - if arg.isErased then - result := mkApp result arg - else if d.isErased || d matches .sort _ then + if d matches .const ``lcErased _ | .sort _ then result := mkApp result (← toMonoType arg) else result := mkApp result erasedExpr type := b.instantiate1 arg return result - | _ => return erasedExpr + | _ => return anyExpr /-- State for the environment extension used to save the LCNF mono phase type for declarations diff --git a/src/Lean/Compiler/LCNF/PrettyPrinter.lean b/src/Lean/Compiler/LCNF/PrettyPrinter.lean index 31d18785edd7..37ee6ce9b265 100644 --- a/src/Lean/Compiler/LCNF/PrettyPrinter.lean +++ b/src/Lean/Compiler/LCNF/PrettyPrinter.lean @@ -81,7 +81,10 @@ def ppLetDecl (letDecl : LetDecl) : M Format := do return f!"let {letDecl.binderName} := {← ppLetValue letDecl.value}" def getFunType (ps : Array Param) (type : Expr) : CoreM Expr := - instantiateForall type (ps.map (mkFVar ·.fvarId)) + if type.isErased then + pure type + else + instantiateForall type (ps.map (mkFVar ·.fvarId)) mutual partial def ppFunDecl (funDecl : FunDecl) : M Format := do diff --git a/src/Lean/Compiler/LCNF/Types.lean b/src/Lean/Compiler/LCNF/Types.lean index 8a8d2278e7cd..fe16d3414f78 100644 --- a/src/Lean/Compiler/LCNF/Types.lean +++ b/src/Lean/Compiler/LCNF/Types.lean @@ -13,6 +13,7 @@ scoped notation:max "◾" => lcErased namespace LCNF def erasedExpr := mkConst ``lcErased +def anyExpr := mkConst ``lcAny def _root_.Lean.Expr.isErased (e : Expr) := e.isAppOf ``lcErased diff --git a/tests/lean/lcnfTypes.lean.expected.out b/tests/lean/lcnfTypes.lean.expected.out index a0a17428e967..8ed0fe8d9b51 100644 --- a/tests/lean/lcnfTypes.lean.expected.out +++ b/tests/lean/lcnfTypes.lean.expected.out @@ -34,16 +34,14 @@ weird1 : Bool → ◾ lamAny₁ : Bool → Monad ◾ lamAny₂ : Bool → Monad ◾ Term.constFold : List Ty → Ty → _root_.Term lcErased lcErased → _root_.Term lcErased lcErased -Term.denote : List Ty → Ty → _root_.Term lcErased lcErased → HList Ty lcErased lcErased → lcErased -HList.get : lcErased → - lcErased → List lcErased → lcErased → HList lcErased lcErased lcErased → Member lcErased lcErased lcErased → lcErased -Member.head : lcErased → lcErased → List lcErased → Member lcErased lcErased lcErased +Term.denote : lcErased +HList.get : lcErased → lcErased → List lcAny → lcAny → HList lcAny lcErased lcErased → Member lcAny lcErased lcErased → lcAny +Member.head : lcErased → lcAny → List lcAny → Member lcAny lcErased lcErased Ty.denote : lcErased -MonadControl.liftWith : lcErased → - lcErased → MonadControl lcErased lcErased → lcErased → ((lcErased → lcErased → lcErased) → lcErased) → lcErased -MonadControl.restoreM : lcErased → lcErased → MonadControl lcErased lcErased → lcErased → lcErased → lcErased -Decidable.casesOn : lcErased → lcErased → Bool → (lcErased → lcErased) → (lcErased → lcErased) → lcErased -Lean.getConstInfo : lcErased → Monad lcErased → MonadEnv lcErased → MonadError lcErased → Name → lcErased +MonadControl.liftWith : lcErased → lcErased → MonadControl lcErased lcErased → lcErased → ((lcErased → lcAny → lcAny) → lcAny) → lcAny +MonadControl.restoreM : lcErased → lcErased → MonadControl lcErased lcErased → lcErased → lcAny → lcAny +Decidable.casesOn : lcErased → lcErased → Bool → (lcErased → lcAny) → (lcErased → lcAny) → lcAny +Lean.getConstInfo : lcErased → Monad lcErased → MonadEnv lcErased → MonadError lcErased → Name → lcAny Lean.Meta.instMonadMetaM : Monad lcErased Lean.Meta.inferType : Expr → Context → lcErased → Core.Context → lcErased → PUnit → EStateM.Result Exception PUnit Expr Lean.Elab.Term.elabTerm : Syntax → @@ -54,4 +52,4 @@ Lean.Elab.Term.elabTerm : Syntax → lcErased → Context → lcErased → Core.Context → lcErased → PUnit → EStateM.Result Exception PUnit Expr Nat.add : Nat → Nat → Nat Fin.add : Nat → Nat → Nat → Nat -Lean.HashSetBucket.update : lcErased → Array (List lcErased) → USize → List lcErased → lcErased → Array (List lcErased) +Lean.HashSetBucket.update : lcErased → Array (List lcAny) → USize → List lcAny → lcErased → Array (List lcAny) diff --git a/tests/lean/run/erased.lean b/tests/lean/run/erased.lean index d3176568c5a6..d34195214e7b 100644 --- a/tests/lean/run/erased.lean +++ b/tests/lean/run/erased.lean @@ -25,7 +25,7 @@ info: [Compiler.result] size: 1 let _x.1 : PSigma lcErased lcErased := PSigma.mk lcErased ◾ ◾ ◾; return _x.1 [Compiler.result] size: 1 - def Erased.mk (α : lcErased) (a : lcErased) : PSigma lcErased lcErased := + def Erased.mk (α : lcErased) (a : lcAny) : PSigma lcErased lcErased := let _x.1 : PSigma lcErased lcErased := PSigma.mk lcErased ◾ ◾ ◾; return _x.1 -/ diff --git a/tests/lean/run/lcnfErasure.lean b/tests/lean/run/lcnfErasure.lean new file mode 100644 index 000000000000..369facb47f9a --- /dev/null +++ b/tests/lean/run/lcnfErasure.lean @@ -0,0 +1,264 @@ +import Lean +import Lean.Compiler.LCNF.MonoTypes +import Lean.Compiler.LCNF.Types + +open Lean Meta +open Compiler.LCNF (toLCNFType toMonoType) + +def toMonoLCNFType (type : Expr) : MetaM Expr := do + toMonoType (← toLCNFType type) + +def checkMonoType! (type₁ type₂ : Expr) : MetaM Unit := do + let monoType ← toMonoLCNFType type₁ + if monoType != type₂ then + throwError f!"mono type for {type₁} is {monoType}, expected {type₂}" + let monoMonoType ← toMonoType monoType + if monoMonoType != monoType then + throwError f!"toMonoType is not idempotent: toMonoType of {monoType} is {monoMonoType}" + +-- Nat + +#eval checkMonoType! + (.const ``Nat []) + (.const ``Nat []) + +-- Decidable + +#eval checkMonoType! + (.const ``Decidable []) + (.const ``Bool []) + +-- Prop + +#eval checkMonoType! + (.sort .zero) + (.const ``lcErased []) + +-- Type + +#eval checkMonoType! + (.sort (.succ .zero)) + (.const ``lcErased []) + +-- Sort u + +#eval checkMonoType! + (.sort (.param `u)) + (.const ``lcErased []) + +-- List Nat + +#eval checkMonoType! + (.app (.const ``List [.succ .zero]) (.const ``Nat [])) + (.app (.const ``List []) (.const ``Nat [])) + +-- List Type + +#eval checkMonoType! + (.app (.const ``List [.succ (.succ .zero)]) (.sort (.succ .zero))) + (.app (.const ``List []) (.const ``lcErased [])) + +-- Inductive type with trivial structure + +inductive TrivialInductive : Type where + | constructor (a : Nat) : TrivialInductive + +#eval checkMonoType! + (.const ``TrivialInductive []) + (.const ``Nat []) + +-- Inductive type with trivial structure and irrelevant fields + +inductive TrivialInductivePropFields : Type where + | constructor (p₁ : Prop) (a : Nat) (p₂ : Prop) : TrivialInductivePropFields + +#eval checkMonoType! + (.const ``TrivialInductivePropFields []) + (.const ``Nat []) + +-- Structure type with trivial structure + +structure TrivialStructure : Type where + a : Nat + +#eval checkMonoType! + (.const ``TrivialStructure []) + (.const ``Nat []) + +-- Structure type with trivial structure and irrelevant fields + +structure TrivialStructurePropFields : Type where + p₁ : Prop + a : Nat + p₂ : Prop + +#eval checkMonoType! + (.const ``TrivialStructurePropFields []) + (.const ``Nat []) + +-- Nat → Nat + +#eval checkMonoType! + (.forallE `a (.const ``Nat []) (.const ``Nat []) .default) + (.forallE `a (.const ``Nat []) (.const ``Nat []) .default) + +-- Nat → List Nat + +#eval checkMonoType! + (.forallE `a (.const ``Nat []) (.app (.const ``List [.succ .zero]) (.const ``Nat [])) .default) + (.forallE `a (.const ``Nat []) (.app (.const ``List []) (.const ``Nat [])) .default) + +-- Nat → Prop + +#eval checkMonoType! + (.forallE `a (.const ``Nat []) (.sort .zero) .default) + (.const ``lcErased []) + +-- Nat → Type + +#eval checkMonoType! + (.forallE `a (.const ``Nat []) (.sort (.succ .zero)) .default) + (.const ``lcErased []) + +-- Nat → Bool → Type + +#eval checkMonoType! + (.forallE `a + (.const ``Nat []) + (.forallE `a (.const ``Bool []) (.sort (.succ .zero)) .default) + .default) + (.const ``lcErased []) + +-- (α : Type) → List α + +#eval checkMonoType! + (.forallE `α (.sort (.succ .zero)) (.app (.const ``List [.succ .zero]) (.bvar 0)) .default) + (.forallE `α (.const ``lcErased []) (.app (.const ``List []) (.const ``lcAny [])) .default) + +-- List Nat → List Bool + +#eval checkMonoType! + (.forallE `a + (.app (.const ``List [.succ .zero]) (.const ``Nat [])) + (.app (.const ``List [.succ .zero]) (.const ``Bool [])) + .default) + (.forallE `a + (.app (.const ``List []) (.const ``Nat [])) + (.app (.const ``List []) (.const ``Bool [])) + .default) + +-- List Nat → List Prop + +#eval checkMonoType! + (.forallE `a + (.app (.const ``List [.succ .zero]) (.const ``Nat [])) + (.app (.const ``List [.succ .zero]) (.sort .zero)) + .default) + (.forallE `a + (.app (.const ``List []) (.const ``Nat [])) + (.app (.const ``List []) (.const ``lcErased [])) + .default) + +-- (α : Type) → α → α + +#eval checkMonoType! + (.forallE `α + (.sort (.succ .zero)) + (.forallE `a (.bvar 0) (.bvar 1) .default) + .default) + (.forallE `α + (.const ``lcErased []) + (.forallE `a (.const ``lcAny []) (.const ``lcAny []) .default) + .default) + +-- Nat → (α : Type) → α → Bool + +#eval checkMonoType! + (.forallE `a + (.const ``Nat []) + (.forallE `α + (.sort (.succ .zero)) + (.forallE `a (.bvar 0) (.const ``Bool []) .default) + .default) + .default) + (.forallE `a + (.const ``Nat []) + (.forallE `α + (.const ``lcErased []) + (.forallE `a (.const ``lcAny []) (.const ``Bool []) .default) + .default) + .default) + +-- Nat → Bool → Type + +#eval checkMonoType! + (.forallE `a + (.const ``Nat []) + (.forallE `b (.const ``Bool []) (.sort (.succ .zero)) .default) + .default) + (.const ``lcErased []) + +-- Nat → Bool → (Nat → Type) + +#eval checkMonoType! + (.forallE `a + (.const ``Nat []) + (.forallE `b (.const ``Bool []) (.sort (.succ .zero)) .default) + .default) + (.const ``lcErased []) + +-- Nat → (Nat → Type) → Bool + +#eval checkMonoType! + (.forallE `a + (.const ``Nat []) + (.forallE `b + (.forallE `c (.const ``lcErased []) (.sort (.succ .zero)) .default) + (.const ``Bool []) + .default) + .default) + (.forallE `a + (.const ``Nat []) + (.forallE `b + (.const ``lcErased []) + (.const ``Bool []) + .default) + .default) + +-- (α : Sort u) → (β : α → Sort v) → (a : α) → ((x : α) → β x) → β a + +#eval checkMonoType! + (.forallE + `α + (.sort (.param `u)) + (.forallE + `β + (.forallE `f1 (.bvar 0) (.sort (.param `v)) .default) + (.forallE + `a + (.bvar 1) + (.forallE + `f2 + (.forallE `x (.bvar 2) (.app (.bvar 2) (.bvar 0)) .default) + (.app (.bvar 2) (.bvar 1)) + .default) + .default) + .default) + .default) + (.forallE + `α + (.const ``lcErased []) + (.forallE + `β + (.const ``lcErased []) + (.forallE + `a + (.const ``lcAny []) + (.forallE + `f2 + (.forallE `x (.const ``lcAny []) (.const ``lcAny []) .default) + (.const ``lcAny []) + .default) + .default) + .default) + .default)