Skip to content

Commit

Permalink
refactor: track h_run as an Expr
Browse files Browse the repository at this point in the history
  • Loading branch information
alexkeizer committed Sep 27, 2024
1 parent 3f6b34e commit 51aad7b
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 58 deletions.
85 changes: 44 additions & 41 deletions Tactics/Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ for some metavariable `?runSteps`, then create the proof obligation
`?runSteps = _ + 1`, and attempt to close it using `whileTac`.
Finally, we use this proof to change the type of `h_run` accordingly.
-/
def unfoldRun (whileTac : Unit → TacticM Unit) : SymReaderM Unit := do
def unfoldRun (whileTac : Unit → TacticM Unit) : SymM Unit := do
let c ← readThe SymContext
let msg := m!"unfoldRun (runSteps? := {c.runSteps?})"
withTraceNode `Tactic.sym (fun _ => pure msg) <|
Expand All @@ -102,13 +102,14 @@ def unfoldRun (whileTac : Unit → TacticM Unit) : SymReaderM Unit := do
-- `sym_n` that, if the number of runSteps is statically known,
-- that we never simulate more than that many steps
| none => withMainContext' do
let mut goal :: originalGoals ← getGoals
| throwNoGoalsToBeSolved
let hRunDecl ← c.hRunDecl
let hRun := c.hRun

-- Assert that `h_run : <finalState> = run ?runSteps <state>`
let runSteps ← mkFreshExprMVar (mkConst ``Nat)
guard <|← isDefEq hRunDecl.type (
-- TODO(@alexkeizer): if the following guard should never fail, and
-- doesn't assign mvars, then why do we do it?
-- We should hide it behind the `Tactic.sym.debug` option
guard <|← isDefEq hRun (
mkApp3 (.const ``Eq [1]) (mkConst ``ArmState)
c.finalState
(mkApp2 (mkConst ``_root_.run) runSteps (← getCurrentState)))
Expand All @@ -119,46 +120,44 @@ def unfoldRun (whileTac : Unit → TacticM Unit) : SymReaderM Unit := do
let runStepsPredId ← mkFreshMVarId
let runStepsPred ← mkFreshExprMVarWithId runStepsPredId (mkConst ``Nat)
let subGoalTyRhs := mkApp (mkConst ``Nat.succ) runStepsPred
let subGoalTy := -- `?runSteps = ?runStepsPred + 1`
let runStepsEqTy := -- `?runSteps = ?runStepsPred + 1`
mkApp3 (.const ``Eq [1]) (mkConst ``Nat) runSteps subGoalTyRhs
let subGoal ← mkFreshMVarId
let _ ← mkFreshExprMVarWithId subGoal subGoalTy
let runStepsEq ← mkFreshMVarId
let _ ← mkFreshExprMVarWithId runStepsEq runStepsEqTy

let msg := m!"runSteps is not statically known, so attempt to prove:\
{subGoal}"
withTraceNode `Tactic.sym (fun _ => pure msg) <| subGoal.withContext <| do
setGoals [subGoal]
{runStepsEq}"
withTraceNode `Tactic.sym (fun _ => pure msg) <| runStepsEq.withContext <| do
setGoals [runStepsEq]
whileTac () -- run `whileTac` to attempt to close `subGoal`

-- Ensure `runStepsPred` is assigned, by giving it a default value
-- This is important because of the use of `replaceLocalDecl` below
-- TODO(@alexkeizer): we got rid of replaceLocalDecl, so we probably
-- can get rid of this, too, leaving the mvar unassigned
if !(← runStepsPredId.isAssigned) then
let default := mkApp (mkConst ``Nat.pred) runSteps
trace[Tactic.sym] "{runStepsPred} is unassigned, \
so we assign to the default value ({default})"
runStepsPredId.assign default

-- Change the type of `h_run`
let state ← getCurrentState
let typeNew ← do
let rhs := mkApp2 (mkConst ``_root_.run) subGoalTyRhs state
mkEq c.finalState rhs
let eqProof ← do
let f := -- `fun s => <finalState> = s`
let eq := mkApp3 (.const ``Eq [1]) (mkConst ``ArmState)
c.finalState (.bvar 0)
.lam `s (mkConst ``ArmState) eq .default
let g := mkConst ``_root_.run
let h ← instantiateMVars (.mvar subGoal)
mkCongrArg f (←mkCongrFun (←mkCongrArg g h) state)
let res ← goal.replaceLocalDecl hRunDecl.fvarId typeNew eqProof
-- Change the type of `hRun`
let goal :: originalGoals ← getGoals
| throwNoGoalsToBeSolved
let rwRes ← goal.rewrite hRun (.mvar runStepsEq)
modifyThe SymContext ({ · with
hRun := rwRes.eNew
})

-- Restore goal state
if !(←subGoal.isAssigned) then
trace[Tactic.sym] "Subgoal {subGoal} was not closed yet, \
so add it as a goal for the user to solve"
originalGoals := originalGoals.concat subGoal
setGoals (res.mvarId :: originalGoals)
let newGoal ← do
if (←runStepsEq.isAssigned) then
pure []
else
trace[Tactic.sym] "Subgoal {runStepsEq} was not closed yet, \
so add it as a goal for the user to solve"
pure [runStepsEq]
setGoals (rwRes.mvarIds ++ originalGoals ++ newGoal)

/-- Break an equality `h_step : s{i+1} = w ... (... (w ... s{i})...)` into an
`AxEffects` that characterizes the effects in terms of reads from `s{i+1}`,
Expand Down Expand Up @@ -265,12 +264,19 @@ def sym1 (whileTac : TSyntax `tactic) : SymM Unit := do
-- Add new state to local context
let hRunId := mkIdent <|← getHRunName
let nextStateId := mkIdent <|← getNextStateName
evalTacticAndTrace <|← `(tactic|
withMainContext' <| evalTacticAndTrace <|← `(tactic|
init_next_step $hRunId:ident $stepi_eq:ident $nextStateId:ident
)

-- Apply relevant pre-generated `stepi` lemma
withMainContext' <| do
-- Update `hRun`
let hRun := hRunId.getId
let some hRun := (← getLCtx).findFromUserName? hRun
| throwError "internal error: couldn't find {hRun}"
modifyThe SymContext ({ · with
hRun := hRun.toExpr
})

-- Apply relevant pre-generated `stepi` lemma
let stepiEq ← SymContext.findFromUserName stepi_eq.getId
stepiTac stepiEq.toExpr h_step.getId

Expand Down Expand Up @@ -326,9 +332,9 @@ def ensureAtMostRunSteps (n : Nat) : SymM Nat := do
pure n
else
withMainContext <| do
let hRun ctx.hRunDecl
let hRun := ctx.hRun
logWarning m!"Symbolic simulation is limited to at most {runSteps} \
steps, because {hRun.toExpr} is of type:\n {hRun.type}"
steps, because {hRun} is of type:\n {← inferType hRun}"
pure runSteps
return n

Expand Down Expand Up @@ -421,17 +427,14 @@ elab "sym_n" whileTac?:(sym_while)? n:num s:(sym_at)? : tactic => do
let c ← getThe SymContext
-- Check if we can substitute the final state
if c.runSteps? = some 0 then
let msg := do
let hRun ← userNameToMessageData c.h_run
pure m!"runSteps := 0, substituting along {hRun}"
withTraceNode `Tactic.sym (fun _ => msg) <| withMainContext' do
let msg := pure m!"runSteps := 0, substituting along {c.hRun}"
withMainContext' <| withTraceNode `Tactic.sym (fun _ => msg) <| do
let sfEq ← mkEq (← getCurrentState) c.finalState

let goal ← getMainGoal
trace[Tactic.sym] "original goal:\n{goal}"
let ⟨hEqId, goal⟩ ← do
let hRun ← SymContext.findFromUserName c.h_run
goal.note `this (← mkEqSymm hRun.toExpr) sfEq
goal.note `this (← mkEqSymm c.hRun) sfEq
goal.withContext <| do
trace[Tactic.sym] "added {← userNameToMessageData `this} of type \
{sfEq} in:\n{goal}"
Expand Down
49 changes: 32 additions & 17 deletions Tactics/Sym/Context.lean
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ structure SymContext where
If `runSteps?` is `some n`, where `n` is a meta-level `Nat`,
then we expect that `<runSteps>` in type of `h_run` is the literal `n`.
Otherwise, if `runSteps?` is `none`,
then `<runSteps>` is allowed to be anything, even a symbolic value.
then `<runSteps>` is allowed to be anything, including a symbolic value.
See also `SymContext.h_run` -/
runSteps? : Option Nat
/-- `h_run` is a local hypothesis of the form
`finalState = run <runSteps> state`
/-- `hRun` is an expression of type
`<finalState> = run <runSteps> state`
See also `SymContext.runSteps?` -/
h_run : Name
hRun : Expr
/-- `programInfo` is the relevant cached `ProgramInfo` -/
programInfo : ProgramInfo

Expand Down Expand Up @@ -162,19 +162,33 @@ def findFromUserName (name : Name) : MetaM LocalDecl := do
| throwError "Unknown local variable `{name}`"
return decl

/-- Find the local declaration that corresponds to `c.h_run`,
or throw an error if no local variable of that name exists -/
def hRunDecl : MetaM LocalDecl := do
findFromUserName c.h_run

section Monad
variable {m} [Monad m] [MonadReaderOf SymContext m]

def getCurrentStateNumber : m Nat := do return (← read).currentStateNumber

/-- Return the name of the hypothesis
`h_run : <finalState> = run <runSteps> <initialState>` -/
def getHRunName : m Name := do return (← read).h_run
/-- Return an expression of type
`<finalState> = run <runSteps> <initialState>` -/
def getHRun : m Expr := do return (← read).hRun

/-- Return the `Name` of a variable of type
`<finalState> = run <runSteps> <initialState>`
This will return the name of `hRun`, if its an fvar.
Otherwise, add a new variable to the local context, and return the new name.
Note that `hRun` is not modified in either case. -/
def getHRunName [MonadLiftT TacticM m] [MonadLiftT MetaM m] [MonadError m]
[MonadLCtx m] :
m Name := do
let hRun ← getHRun
if let Expr.fvar id := hRun then
let some decl := (← getLCtx).find? id
| throwError "Unknown fvar {Expr.fvar id}"
return decl.userName
else
let ⟨_id, goal⟩ ← (← getMainGoal).note `h_run hRun none
replaceMainGoal [goal]
return `h_run

/-- Retrieve the name for the next state
Expand All @@ -194,12 +208,11 @@ end
/-- Convert a `SymContext` to `MessageData` for tracing.
This is not a `ToMessageData` instance because we need access to `MetaM` -/
def toMessageData (c : SymContext) : MetaM MessageData := do
let h_run ← userNameToMessageData c.h_run
let h_sp? ← c.h_sp?.mapM userNameToMessageData

return m!"\{ finalState := {c.finalState},
runSteps? := {c.runSteps?},
h_run := {h_run},
hRun := {c.hRun},
program := {c.program},
pc := {c.pc},
h_sp? := {h_sp?},
Expand Down Expand Up @@ -353,7 +366,7 @@ def fromLocalContext (state? : Option Name) : TacticM SymContext := do
}
let c : SymContext := {
finalState, runSteps?, pc,
h_run := h_run.userName,
hRun := h_run.toExpr,
h_sp? := (·.userName) <$> h_sp?,
programInfo,
effects,
Expand Down Expand Up @@ -390,8 +403,6 @@ def canonicalizeHypothesisTypes : SymReaderM Unit := withMainContext' do
let state := c.effects.currentState

let mut hyps := #[]
if let some runSteps := c.runSteps? then
hyps := hyps.push (c.h_run, h_run_type c.finalState (toExpr runSteps) state)
if let some h_sp := c.h_sp? then
hyps := hyps.push (h_sp, h_sp_type state)

Expand All @@ -400,6 +411,10 @@ def canonicalizeHypothesisTypes : SymReaderM Unit := withMainContext' do
| throwError "Unknown local hypothesis `{name}`"
pure (decl.fvarId, type)


if let some runSteps := c.runSteps? then
hypIds := hypIds.push
(c.hRun.fvarId!, h_run_type c.finalState (toExpr runSteps) state)
let errHyp ← AxEffects.getFieldM .ERR
if let Expr.fvar id := errHyp.proof then
hypIds := hypIds.push (id, h_err_type state)
Expand Down

0 comments on commit 51aad7b

Please sign in to comment.