Skip to content

Commit

Permalink
feat: add proper erasure of type dependencies in LCNF (#6678)
Browse files Browse the repository at this point in the history
This PR modifies LCNF.toMonoType to use a more refined type erasure
scheme, which distinguishes between irrelevant/erased information
(represented by lcErased) and erased type dependencies (represented by
lcAny). This corresponds to the irrelevant/object distinction in the old
code generator.
  • Loading branch information
zwarich authored Jan 21, 2025
1 parent e3771e3 commit c54287f
Show file tree
Hide file tree
Showing 6 changed files with 293 additions and 27 deletions.
30 changes: 15 additions & 15 deletions src/Lean/Compiler/LCNF/MonoTypes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,23 @@ The type contains only `→` and constants.
-/
partial def toMonoType (type : Expr) : CoreM Expr := do
let type := type.headBeta
if type.isErased then
return erasedExpr
else if isTypeFormerType type then
return erasedExpr
else match type with
| .const .. => visitApp type #[]
| .app .. => type.withApp visitApp
| .forallE _ d b _ => mkArrow (← toMonoType d) (← toMonoType (b.instantiate1 erasedExpr))
| _ => return erasedExpr
match type with
| .const .. => visitApp type #[]
| .app .. => type.withApp visitApp
| .forallE _ d b _ =>
let monoB ← toMonoType (b.instantiate1 anyExpr)
match monoB with
| .const ``lcErased _ => return erasedExpr
| _ => mkArrow (← toMonoType d) monoB
| .sort _ => return erasedExpr
| _ => return anyExpr
where
visitApp (f : Expr) (args : Array Expr) : CoreM Expr := do
match f with
| .const ``lcErased _ => return erasedExpr
| .const ``lcAny _ => return anyExpr
| .const ``Decidable _ => return mkConst ``Bool
| .const declName us =>
if declName == ``Decidable then
return mkConst ``Bool
if let some info ← hasTrivialStructure? declName then
let ctorType ← getOtherDeclBaseType info.ctorName []
toMonoType (getParamTypes (← instantiateForall ctorType args[:info.numParams]))[info.fieldIdx]!
Expand All @@ -96,15 +98,13 @@ where
for arg in args do
let .forallE _ d b _ := type.headBeta | unreachable!
let arg := arg.headBeta
if arg.isErased then
result := mkApp result arg
else if d.isErased || d matches .sort _ then
if d matches .const ``lcErased _ | .sort _ then
result := mkApp result (← toMonoType arg)
else
result := mkApp result erasedExpr
type := b.instantiate1 arg
return result
| _ => return erasedExpr
| _ => return anyExpr

/--
State for the environment extension used to save the LCNF mono phase type for declarations
Expand Down
5 changes: 4 additions & 1 deletion src/Lean/Compiler/LCNF/PrettyPrinter.lean
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ def ppLetDecl (letDecl : LetDecl) : M Format := do
return f!"let {letDecl.binderName} := {← ppLetValue letDecl.value}"

def getFunType (ps : Array Param) (type : Expr) : CoreM Expr :=
instantiateForall type (ps.map (mkFVar ·.fvarId))
if type.isErased then
pure type
else
instantiateForall type (ps.map (mkFVar ·.fvarId))

mutual
partial def ppFunDecl (funDecl : FunDecl) : M Format := do
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Compiler/LCNF/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ scoped notation:max "◾" => lcErased
namespace LCNF

def erasedExpr := mkConst ``lcErased
def anyExpr := mkConst ``lcAny

def _root_.Lean.Expr.isErased (e : Expr) :=
e.isAppOf ``lcErased
Expand Down
18 changes: 8 additions & 10 deletions tests/lean/lcnfTypes.lean.expected.out
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,14 @@ weird1 : Bool → ◾
lamAny₁ : Bool → Monad ◾
lamAny₂ : Bool → Monad ◾
Term.constFold : List Ty → Ty → _root_.Term lcErased lcErased → _root_.Term lcErased lcErased
Term.denote : List Ty → Ty → _root_.Term lcErased lcErased → HList Ty lcErased lcErased → lcErased
HList.get : lcErased →
lcErased → List lcErased → lcErased → HList lcErased lcErased lcErased → Member lcErased lcErased lcErased → lcErased
Member.head : lcErased → lcErased → List lcErased → Member lcErased lcErased lcErased
Term.denote : lcErased
HList.get : lcErased → lcErased → List lcAny → lcAny → HList lcAny lcErased lcErased → Member lcAny lcErased lcErased → lcAny
Member.head : lcErased → lcAny → List lcAny → Member lcAny lcErased lcErased
Ty.denote : lcErased
MonadControl.liftWith : lcErased →
lcErased → MonadControl lcErased lcErased → lcErased → ((lcErased → lcErased → lcErased) → lcErased) → lcErased
MonadControl.restoreM : lcErased → lcErased → MonadControl lcErased lcErased → lcErased → lcErased → lcErased
Decidable.casesOn : lcErased → lcErased → Bool → (lcErased → lcErased) → (lcErased → lcErased) → lcErased
Lean.getConstInfo : lcErased → Monad lcErased → MonadEnv lcErased → MonadError lcErased → Name → lcErased
MonadControl.liftWith : lcErased → lcErased → MonadControl lcErased lcErased → lcErased → ((lcErased → lcAny → lcAny) → lcAny) → lcAny
MonadControl.restoreM : lcErased → lcErased → MonadControl lcErased lcErased → lcErased → lcAny → lcAny
Decidable.casesOn : lcErased → lcErased → Bool → (lcErased → lcAny) → (lcErased → lcAny) → lcAny
Lean.getConstInfo : lcErased → Monad lcErased → MonadEnv lcErased → MonadError lcErased → Name → lcAny
Lean.Meta.instMonadMetaM : Monad lcErased
Lean.Meta.inferType : Expr → Context → lcErased → Core.Context → lcErased → PUnit → EStateM.Result Exception PUnit Expr
Lean.Elab.Term.elabTerm : Syntax →
Expand All @@ -54,4 +52,4 @@ Lean.Elab.Term.elabTerm : Syntax →
lcErased → Context → lcErased → Core.Context → lcErased → PUnit → EStateM.Result Exception PUnit Expr
Nat.add : Nat → Nat → Nat
Fin.add : Nat → Nat → Nat → Nat
Lean.HashSetBucket.update : lcErased → Array (List lcErased) → USize → List lcErased → lcErased → Array (List lcErased)
Lean.HashSetBucket.update : lcErased → Array (List lcAny) → USize → List lcAny → lcErased → Array (List lcAny)
2 changes: 1 addition & 1 deletion tests/lean/run/erased.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ info: [Compiler.result] size: 1
let _x.1 : PSigma lcErased lcErased := PSigma.mk lcErased ◾ ◾ ◾;
return _x.1
[Compiler.result] size: 1
def Erased.mk (α : lcErased) (a : lcErased) : PSigma lcErased lcErased :=
def Erased.mk (α : lcErased) (a : lcAny) : PSigma lcErased lcErased :=
let _x.1 : PSigma lcErased lcErased := PSigma.mk lcErased ◾ ◾ ◾;
return _x.1
-/
Expand Down
264 changes: 264 additions & 0 deletions tests/lean/run/lcnfErasure.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
import Lean
import Lean.Compiler.LCNF.MonoTypes
import Lean.Compiler.LCNF.Types

open Lean Meta
open Compiler.LCNF (toLCNFType toMonoType)

def toMonoLCNFType (type : Expr) : MetaM Expr := do
toMonoType (← toLCNFType type)

def checkMonoType! (type₁ type₂ : Expr) : MetaM Unit := do
let monoType ← toMonoLCNFType type₁
if monoType != type₂ then
throwError f!"mono type for {type₁} is {monoType}, expected {type₂}"
let monoMonoType ← toMonoType monoType
if monoMonoType != monoType then
throwError f!"toMonoType is not idempotent: toMonoType of {monoType} is {monoMonoType}"

-- Nat

#eval checkMonoType!
(.const ``Nat [])
(.const ``Nat [])

-- Decidable

#eval checkMonoType!
(.const ``Decidable [])
(.const ``Bool [])

-- Prop

#eval checkMonoType!
(.sort .zero)
(.const ``lcErased [])

-- Type

#eval checkMonoType!
(.sort (.succ .zero))
(.const ``lcErased [])

-- Sort u

#eval checkMonoType!
(.sort (.param `u))
(.const ``lcErased [])

-- List Nat

#eval checkMonoType!
(.app (.const ``List [.succ .zero]) (.const ``Nat []))
(.app (.const ``List []) (.const ``Nat []))

-- List Type

#eval checkMonoType!
(.app (.const ``List [.succ (.succ .zero)]) (.sort (.succ .zero)))
(.app (.const ``List []) (.const ``lcErased []))

-- Inductive type with trivial structure

inductive TrivialInductive : Type where
| constructor (a : Nat) : TrivialInductive

#eval checkMonoType!
(.const ``TrivialInductive [])
(.const ``Nat [])

-- Inductive type with trivial structure and irrelevant fields

inductive TrivialInductivePropFields : Type where
| constructor (p₁ : Prop) (a : Nat) (p₂ : Prop) : TrivialInductivePropFields

#eval checkMonoType!
(.const ``TrivialInductivePropFields [])
(.const ``Nat [])

-- Structure type with trivial structure

structure TrivialStructure : Type where
a : Nat

#eval checkMonoType!
(.const ``TrivialStructure [])
(.const ``Nat [])

-- Structure type with trivial structure and irrelevant fields

structure TrivialStructurePropFields : Type where
p₁ : Prop
a : Nat
p₂ : Prop

#eval checkMonoType!
(.const ``TrivialStructurePropFields [])
(.const ``Nat [])

-- Nat → Nat

#eval checkMonoType!
(.forallE `a (.const ``Nat []) (.const ``Nat []) .default)
(.forallE `a (.const ``Nat []) (.const ``Nat []) .default)

-- Nat → List Nat

#eval checkMonoType!
(.forallE `a (.const ``Nat []) (.app (.const ``List [.succ .zero]) (.const ``Nat [])) .default)
(.forallE `a (.const ``Nat []) (.app (.const ``List []) (.const ``Nat [])) .default)

-- Nat → Prop

#eval checkMonoType!
(.forallE `a (.const ``Nat []) (.sort .zero) .default)
(.const ``lcErased [])

-- Nat → Type

#eval checkMonoType!
(.forallE `a (.const ``Nat []) (.sort (.succ .zero)) .default)
(.const ``lcErased [])

-- Nat → Bool → Type

#eval checkMonoType!
(.forallE `a
(.const ``Nat [])
(.forallE `a (.const ``Bool []) (.sort (.succ .zero)) .default)
.default)
(.const ``lcErased [])

-- (α : Type) → List α

#eval checkMonoType!
(.forallE `α (.sort (.succ .zero)) (.app (.const ``List [.succ .zero]) (.bvar 0)) .default)
(.forallE `α (.const ``lcErased []) (.app (.const ``List []) (.const ``lcAny [])) .default)

-- List Nat → List Bool

#eval checkMonoType!
(.forallE `a
(.app (.const ``List [.succ .zero]) (.const ``Nat []))
(.app (.const ``List [.succ .zero]) (.const ``Bool []))
.default)
(.forallE `a
(.app (.const ``List []) (.const ``Nat []))
(.app (.const ``List []) (.const ``Bool []))
.default)

-- List Nat → List Prop

#eval checkMonoType!
(.forallE `a
(.app (.const ``List [.succ .zero]) (.const ``Nat []))
(.app (.const ``List [.succ .zero]) (.sort .zero))
.default)
(.forallE `a
(.app (.const ``List []) (.const ``Nat []))
(.app (.const ``List []) (.const ``lcErased []))
.default)

-- (α : Type) → α → α

#eval checkMonoType!
(.forallE `α
(.sort (.succ .zero))
(.forallE `a (.bvar 0) (.bvar 1) .default)
.default)
(.forallE `α
(.const ``lcErased [])
(.forallE `a (.const ``lcAny []) (.const ``lcAny []) .default)
.default)

-- Nat → (α : Type) → α → Bool

#eval checkMonoType!
(.forallE `a
(.const ``Nat [])
(.forallE `α
(.sort (.succ .zero))
(.forallE `a (.bvar 0) (.const ``Bool []) .default)
.default)
.default)
(.forallE `a
(.const ``Nat [])
(.forallE `α
(.const ``lcErased [])
(.forallE `a (.const ``lcAny []) (.const ``Bool []) .default)
.default)
.default)

-- Nat → Bool → Type

#eval checkMonoType!
(.forallE `a
(.const ``Nat [])
(.forallE `b (.const ``Bool []) (.sort (.succ .zero)) .default)
.default)
(.const ``lcErased [])

-- Nat → Bool → (Nat → Type)

#eval checkMonoType!
(.forallE `a
(.const ``Nat [])
(.forallE `b (.const ``Bool []) (.sort (.succ .zero)) .default)
.default)
(.const ``lcErased [])

-- Nat → (Nat → Type) → Bool

#eval checkMonoType!
(.forallE `a
(.const ``Nat [])
(.forallE `b
(.forallE `c (.const ``lcErased []) (.sort (.succ .zero)) .default)
(.const ``Bool [])
.default)
.default)
(.forallE `a
(.const ``Nat [])
(.forallE `b
(.const ``lcErased [])
(.const ``Bool [])
.default)
.default)

-- (α : Sort u) → (β : α → Sort v) → (a : α) → ((x : α) → β x) → β a

#eval checkMonoType!
(.forallE
(.sort (.param `u))
(.forallE
(.forallE `f1 (.bvar 0) (.sort (.param `v)) .default)
(.forallE
`a
(.bvar 1)
(.forallE
`f2
(.forallE `x (.bvar 2) (.app (.bvar 2) (.bvar 0)) .default)
(.app (.bvar 2) (.bvar 1))
.default)
.default)
.default)
.default)
(.forallE
(.const ``lcErased [])
(.forallE
(.const ``lcErased [])
(.forallE
`a
(.const ``lcAny [])
(.forallE
`f2
(.forallE `x (.const ``lcAny []) (.const ``lcAny []) .default)
(.const ``lcAny [])
.default)
.default)
.default)
.default)

0 comments on commit c54287f

Please sign in to comment.