From 51aad7b80ca7cc9ffeec646415df259a626cc18f Mon Sep 17 00:00:00 2001 From: Alex Keizer Date: Thu, 26 Sep 2024 20:46:44 -0500 Subject: [PATCH] refactor: track `h_run` as an `Expr` --- Tactics/Sym.lean | 85 +++++++++++++++++++++------------------- Tactics/Sym/Context.lean | 49 +++++++++++++++-------- 2 files changed, 76 insertions(+), 58 deletions(-) diff --git a/Tactics/Sym.lean b/Tactics/Sym.lean index 5a62f834..3820f6fe 100644 --- a/Tactics/Sym.lean +++ b/Tactics/Sym.lean @@ -87,7 +87,7 @@ for some metavariable `?runSteps`, then create the proof obligation `?runSteps = _ + 1`, and attempt to close it using `whileTac`. Finally, we use this proof to change the type of `h_run` accordingly. -/ -def unfoldRun (whileTac : Unit → TacticM Unit) : SymReaderM Unit := do +def unfoldRun (whileTac : Unit → TacticM Unit) : SymM Unit := do let c ← readThe SymContext let msg := m!"unfoldRun (runSteps? := {c.runSteps?})" withTraceNode `Tactic.sym (fun _ => pure msg) <| @@ -102,13 +102,14 @@ def unfoldRun (whileTac : Unit → TacticM Unit) : SymReaderM Unit := do -- `sym_n` that, if the number of runSteps is statically known, -- that we never simulate more than that many steps | none => withMainContext' do - let mut goal :: originalGoals ← getGoals - | throwNoGoalsToBeSolved - let hRunDecl ← c.hRunDecl + let hRun := c.hRun -- Assert that `h_run : = run ?runSteps ` let runSteps ← mkFreshExprMVar (mkConst ``Nat) - guard <|← isDefEq hRunDecl.type ( + -- TODO(@alexkeizer): if the following guard should never fail, and + -- doesn't assign mvars, then why do we do it? + -- We should hide it behind the `Tactic.sym.debug` option + guard <|← isDefEq hRun ( mkApp3 (.const ``Eq [1]) (mkConst ``ArmState) c.finalState (mkApp2 (mkConst ``_root_.run) runSteps (← getCurrentState))) @@ -119,46 +120,44 @@ def unfoldRun (whileTac : Unit → TacticM Unit) : SymReaderM Unit := do let runStepsPredId ← mkFreshMVarId let runStepsPred ← mkFreshExprMVarWithId runStepsPredId (mkConst ``Nat) let subGoalTyRhs := mkApp (mkConst ``Nat.succ) runStepsPred - let subGoalTy := -- `?runSteps = ?runStepsPred + 1` + let runStepsEqTy := -- `?runSteps = ?runStepsPred + 1` mkApp3 (.const ``Eq [1]) (mkConst ``Nat) runSteps subGoalTyRhs - let subGoal ← mkFreshMVarId - let _ ← mkFreshExprMVarWithId subGoal subGoalTy + let runStepsEq ← mkFreshMVarId + let _ ← mkFreshExprMVarWithId runStepsEq runStepsEqTy let msg := m!"runSteps is not statically known, so attempt to prove:\ - {subGoal}" - withTraceNode `Tactic.sym (fun _ => pure msg) <| subGoal.withContext <| do - setGoals [subGoal] + {runStepsEq}" + withTraceNode `Tactic.sym (fun _ => pure msg) <| runStepsEq.withContext <| do + setGoals [runStepsEq] whileTac () -- run `whileTac` to attempt to close `subGoal` -- Ensure `runStepsPred` is assigned, by giving it a default value -- This is important because of the use of `replaceLocalDecl` below + -- TODO(@alexkeizer): we got rid of replaceLocalDecl, so we probably + -- can get rid of this, too, leaving the mvar unassigned if !(← runStepsPredId.isAssigned) then let default := mkApp (mkConst ``Nat.pred) runSteps trace[Tactic.sym] "{runStepsPred} is unassigned, \ so we assign to the default value ({default})" runStepsPredId.assign default - -- Change the type of `h_run` - let state ← getCurrentState - let typeNew ← do - let rhs := mkApp2 (mkConst ``_root_.run) subGoalTyRhs state - mkEq c.finalState rhs - let eqProof ← do - let f := -- `fun s => = s` - let eq := mkApp3 (.const ``Eq [1]) (mkConst ``ArmState) - c.finalState (.bvar 0) - .lam `s (mkConst ``ArmState) eq .default - let g := mkConst ``_root_.run - let h ← instantiateMVars (.mvar subGoal) - mkCongrArg f (←mkCongrFun (←mkCongrArg g h) state) - let res ← goal.replaceLocalDecl hRunDecl.fvarId typeNew eqProof + -- Change the type of `hRun` + let goal :: originalGoals ← getGoals + | throwNoGoalsToBeSolved + let rwRes ← goal.rewrite hRun (.mvar runStepsEq) + modifyThe SymContext ({ · with + hRun := rwRes.eNew + }) -- Restore goal state - if !(←subGoal.isAssigned) then - trace[Tactic.sym] "Subgoal {subGoal} was not closed yet, \ - so add it as a goal for the user to solve" - originalGoals := originalGoals.concat subGoal - setGoals (res.mvarId :: originalGoals) + let newGoal ← do + if (←runStepsEq.isAssigned) then + pure [] + else + trace[Tactic.sym] "Subgoal {runStepsEq} was not closed yet, \ + so add it as a goal for the user to solve" + pure [runStepsEq] + setGoals (rwRes.mvarIds ++ originalGoals ++ newGoal) /-- Break an equality `h_step : s{i+1} = w ... (... (w ... s{i})...)` into an `AxEffects` that characterizes the effects in terms of reads from `s{i+1}`, @@ -265,12 +264,19 @@ def sym1 (whileTac : TSyntax `tactic) : SymM Unit := do -- Add new state to local context let hRunId := mkIdent <|← getHRunName let nextStateId := mkIdent <|← getNextStateName - evalTacticAndTrace <|← `(tactic| + withMainContext' <| evalTacticAndTrace <|← `(tactic| init_next_step $hRunId:ident $stepi_eq:ident $nextStateId:ident ) - - -- Apply relevant pre-generated `stepi` lemma withMainContext' <| do + -- Update `hRun` + let hRun := hRunId.getId + let some hRun := (← getLCtx).findFromUserName? hRun + | throwError "internal error: couldn't find {hRun}" + modifyThe SymContext ({ · with + hRun := hRun.toExpr + }) + + -- Apply relevant pre-generated `stepi` lemma let stepiEq ← SymContext.findFromUserName stepi_eq.getId stepiTac stepiEq.toExpr h_step.getId @@ -326,9 +332,9 @@ def ensureAtMostRunSteps (n : Nat) : SymM Nat := do pure n else withMainContext <| do - let hRun ← ctx.hRunDecl + let hRun := ctx.hRun logWarning m!"Symbolic simulation is limited to at most {runSteps} \ - steps, because {hRun.toExpr} is of type:\n {hRun.type}" + steps, because {hRun} is of type:\n {← inferType hRun}" pure runSteps return n @@ -421,17 +427,14 @@ elab "sym_n" whileTac?:(sym_while)? n:num s:(sym_at)? : tactic => do let c ← getThe SymContext -- Check if we can substitute the final state if c.runSteps? = some 0 then - let msg := do - let hRun ← userNameToMessageData c.h_run - pure m!"runSteps := 0, substituting along {hRun}" - withTraceNode `Tactic.sym (fun _ => msg) <| withMainContext' do + let msg := pure m!"runSteps := 0, substituting along {c.hRun}" + withMainContext' <| withTraceNode `Tactic.sym (fun _ => msg) <| do let sfEq ← mkEq (← getCurrentState) c.finalState let goal ← getMainGoal trace[Tactic.sym] "original goal:\n{goal}" let ⟨hEqId, goal⟩ ← do - let hRun ← SymContext.findFromUserName c.h_run - goal.note `this (← mkEqSymm hRun.toExpr) sfEq + goal.note `this (← mkEqSymm c.hRun) sfEq goal.withContext <| do trace[Tactic.sym] "added {← userNameToMessageData `this} of type \ {sfEq} in:\n{goal}" diff --git a/Tactics/Sym/Context.lean b/Tactics/Sym/Context.lean index 39f6f4ed..709ac45d 100644 --- a/Tactics/Sym/Context.lean +++ b/Tactics/Sym/Context.lean @@ -43,15 +43,15 @@ structure SymContext where If `runSteps?` is `some n`, where `n` is a meta-level `Nat`, then we expect that `` in type of `h_run` is the literal `n`. Otherwise, if `runSteps?` is `none`, - then `` is allowed to be anything, even a symbolic value. + then `` is allowed to be anything, including a symbolic value. See also `SymContext.h_run` -/ runSteps? : Option Nat - /-- `h_run` is a local hypothesis of the form - `finalState = run state` + /-- `hRun` is an expression of type + ` = run state` See also `SymContext.runSteps?` -/ - h_run : Name + hRun : Expr /-- `programInfo` is the relevant cached `ProgramInfo` -/ programInfo : ProgramInfo @@ -162,19 +162,33 @@ def findFromUserName (name : Name) : MetaM LocalDecl := do | throwError "Unknown local variable `{name}`" return decl -/-- Find the local declaration that corresponds to `c.h_run`, -or throw an error if no local variable of that name exists -/ -def hRunDecl : MetaM LocalDecl := do - findFromUserName c.h_run - section Monad variable {m} [Monad m] [MonadReaderOf SymContext m] def getCurrentStateNumber : m Nat := do return (← read).currentStateNumber -/-- Return the name of the hypothesis - `h_run : = run ` -/ -def getHRunName : m Name := do return (← read).h_run +/-- Return an expression of type + ` = run ` -/ +def getHRun : m Expr := do return (← read).hRun + +/-- Return the `Name` of a variable of type + ` = run ` + +This will return the name of `hRun`, if its an fvar. +Otherwise, add a new variable to the local context, and return the new name. +Note that `hRun` is not modified in either case. -/ +def getHRunName [MonadLiftT TacticM m] [MonadLiftT MetaM m] [MonadError m] + [MonadLCtx m] : + m Name := do + let hRun ← getHRun + if let Expr.fvar id := hRun then + let some decl := (← getLCtx).find? id + | throwError "Unknown fvar {Expr.fvar id}" + return decl.userName + else + let ⟨_id, goal⟩ ← (← getMainGoal).note `h_run hRun none + replaceMainGoal [goal] + return `h_run /-- Retrieve the name for the next state @@ -194,12 +208,11 @@ end /-- Convert a `SymContext` to `MessageData` for tracing. This is not a `ToMessageData` instance because we need access to `MetaM` -/ def toMessageData (c : SymContext) : MetaM MessageData := do - let h_run ← userNameToMessageData c.h_run let h_sp? ← c.h_sp?.mapM userNameToMessageData return m!"\{ finalState := {c.finalState}, runSteps? := {c.runSteps?}, - h_run := {h_run}, + hRun := {c.hRun}, program := {c.program}, pc := {c.pc}, h_sp? := {h_sp?}, @@ -353,7 +366,7 @@ def fromLocalContext (state? : Option Name) : TacticM SymContext := do } let c : SymContext := { finalState, runSteps?, pc, - h_run := h_run.userName, + hRun := h_run.toExpr, h_sp? := (·.userName) <$> h_sp?, programInfo, effects, @@ -390,8 +403,6 @@ def canonicalizeHypothesisTypes : SymReaderM Unit := withMainContext' do let state := c.effects.currentState let mut hyps := #[] - if let some runSteps := c.runSteps? then - hyps := hyps.push (c.h_run, h_run_type c.finalState (toExpr runSteps) state) if let some h_sp := c.h_sp? then hyps := hyps.push (h_sp, h_sp_type state) @@ -400,6 +411,10 @@ def canonicalizeHypothesisTypes : SymReaderM Unit := withMainContext' do | throwError "Unknown local hypothesis `{name}`" pure (decl.fvarId, type) + + if let some runSteps := c.runSteps? then + hypIds := hypIds.push + (c.hRun.fvarId!, h_run_type c.finalState (toExpr runSteps) state) let errHyp ← AxEffects.getFieldM .ERR if let Expr.fvar id := errHyp.proof then hypIds := hypIds.push (id, h_err_type state)