Skip to content

Commit

Permalink
fix: do not introduced additional lets that block match generalization
Browse files Browse the repository at this point in the history
Match generalization is still blocked by additional `let`s inserted by Lean itself, but this will hopefully change in a future release.
  • Loading branch information
eric-wieser committed Dec 31, 2023
1 parent ccba5d3 commit 0ad8d21
Showing 1 changed file with 12 additions and 23 deletions.
35 changes: 12 additions & 23 deletions Qq/Match.lean
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,9 @@ scoped elab "_qq_match" pat:term " ← " e:term " | " alt:term " in " body:term
makeMatchCode q($inst2) inst oldPatVarDecls argLvlExpr argTyExpr synthed q($e') alt expectedType fun expectedType =>
return Quoted.unsafeMk (← elabTerm body expectedType)

scoped syntax "_qq_match" term " " term " | " doSeq : term
scoped syntax "_qq_match" term " := " term " | " doSeq : term
macro_rules
| `(assert! (_qq_match $pat $e | $alt); $x) => `(_qq_match $pat ← $e | (do $alt) in $x)
| `(assert! (_qq_match $pat := $e | $alt); $x) => `(_qq_match $pat ← $e | (do $alt) in $x)

partial def isIrrefutablePattern : Term → Bool
| `(($stx)) => isIrrefutablePattern stx
Expand All @@ -257,14 +257,14 @@ macro_rules | `(assert! (_comefrom $n do $b); $body) => `(_comefrom $n do $b in
scoped macro "comefrom" n:ident "do" b:doSeq : doElem =>
`(doElem| assert! (_comefrom $n do $b))

def mkLetDoSeqItem [Monad m] [MonadQuotation m] (pat : Term) (rhs : TSyntax `doElem) (alt : TSyntax ``doSeq) : m (List (TSyntax ``doSeqItem)) := do
def mkLetDoSeqItem [Monad m] [MonadQuotation m] (pat : Term) (rhs : TSyntax `term) (alt : TSyntax ``doSeq) : m (List (TSyntax ``doSeqItem)) := do
match pat with
| `(_) => return []
| _ =>
if isIrrefutablePattern pat then
return [← `(doSeqItem| let $pat:term $rhs)]
return [← `(doSeqItem| let $pat:term := $rhs)]
else
return [← `(doSeqItem| let $pat:term $rhs | $alt)]
return [← `(doSeqItem| let $pat:term := $rhs | $alt)]

end Impl

Expand Down Expand Up @@ -299,7 +299,7 @@ private partial def floatLevelAntiquot (stx : Syntax.Level) : StateT (Array (TSy
if stx.1.isAntiquot && !stx.1.isEscapedAntiquot then
if !stx.1.getAntiquotTerm.isIdent then
withFreshMacroScope do
push <|<- `(doSeqItem| let u : Level := $(⟨stx.1.getAntiquotTerm⟩))
push <| `(doSeqItem| let u : Level := $(⟨stx.1.getAntiquotTerm⟩))
`(level| u)
else
pure stx
Expand Down Expand Up @@ -327,32 +327,24 @@ private partial def floatExprAntiquot (depth : Nat) : Term → StateT (Array (TS
return ⟨addSyntaxDollar id⟩
| none => pure ()
withFreshMacroScope do
push <|<- `(doSeqItem| let a : Quoted _ := $term)
push <| `(doSeqItem| let a : Quoted _ := $term)
return ⟨addSyntaxDollar (← `(a))⟩
else
match stx with
| ⟨.node i k args⟩ => return ⟨.node i k (← args.mapM (floatExprAntiquot depth ⟨·⟩))⟩
| stx => return stx

macro_rules
| `(doElem| let $pat:term $_) => do
| `(doElem| let $pat:term := $_) => do
if !hasQMatch pat then Macro.throwUnsupported
Macro.throwError "let-bindings with ~q(.) require an explicit alternative"

| `(doElem| let $pat:term $rhs:doElem | $alt:doSeq) => do
| `(doElem| let $pat:term := $rhs:term | $alt:doSeq) => do
if !hasQMatch pat then Macro.throwUnsupported
match pat with
| `(~q($pat)) =>
let (pat, lifts) ← floatExprAntiquot 0 pat #[]

let mut t ← (do
match rhs with
| `(doElem| $id:ident $rhs:term) =>
if id.getId.eraseMacroScopes == `pure then -- TODO: super hacky
return ← `(doSeqItem| assert! (_qq_match $pat ← $rhs | $alt))
| _ => pure ()
`(doSeqItem| do let rhs ← $rhs; assert! (_qq_match $pat ← rhs | $alt)))

let t ← `(doSeqItem| do assert! (_qq_match $pat := $rhs | $alt))
`(doElem| do $(lifts.push t):doSeqItem*)

| _ =>
Expand All @@ -369,15 +361,12 @@ macro_rules

| `(doElem| match $[$discrs:term],* with $[| $[$patss],* => $rhss]*) => do
if !patss.any (·.any (hasQMatch ·)) then Macro.throwUnsupported
let discrs ← discrs.mapM fun d => withFreshMacroScope do
pure (← `(x), ← `(doSeqItem| let x := $d:term))
let mut items := discrs.map (·.2)
let discrs := discrs.map (·.1)
let mut items := #[]
items := items.push (← `(doSeqItem| comefrom alt do throwError "nonexhaustive match"))
for pats in patss.reverse, rhs in rhss.reverse do
let mut subItems : Array (TSyntax ``doSeqItem) := #[]
for discr in discrs, pat in pats do
subItems := subItems ++ (← mkLetDoSeqItem pat (← `(doElem| pure $discr:term)) (← `(doSeq| alt)))
subItems := subItems ++ (← mkLetDoSeqItem pat discr (← `(doSeq| alt)))
subItems := subItems.push (← `(doSeqItem| do $rhs))
items := items.push (← `(doSeqItem| comefrom alt do $subItems:doSeqItem*))
items := items.push (← `(doSeqItem| alt))
Expand Down

0 comments on commit 0ad8d21

Please sign in to comment.