diff --git a/src/Init/System/IO.lean b/src/Init/System/IO.lean index 4a3d0c5bce8b..f37167e8143b 100644 --- a/src/Init/System/IO.lean +++ b/src/Init/System/IO.lean @@ -228,6 +228,13 @@ local macro "nonempty_list" : tactic => /-- Helper method for implementing "deterministic" timeouts. It is the number of "small" memory allocations performed by the current execution thread. -/ @[extern "lean_io_get_num_heartbeats"] opaque getNumHeartbeats : BaseIO Nat +/-- +Adjusts the heartbeat counter of the current thread by the given amount. This can be useful to give +allocation-avoiding code additional "weight" and is also used to adjust the counter after resuming +from a snapshot. +-/ +@[extern "lean_io_add_heartbeats"] opaque addHeartbeats (count : UInt64) : BaseIO Unit + /-- The mode of a file handle (i.e., a set of `open` flags and an `fdopen` mode). diff --git a/src/Lean/CoreM.lean b/src/Lean/CoreM.lean index ba8b73258558..9123fbf84599 100644 --- a/src/Lean/CoreM.lean +++ b/src/Lean/CoreM.lean @@ -173,16 +173,45 @@ instance : MonadTrace CoreM where getTraceState := return (← get).traceState modifyTraceState f := modify fun s => { s with traceState := f s.traceState } -/-- Restore backtrackable parts of the state. -/ -def restore (b : State) : CoreM Unit := - modify fun s => { s with env := b.env, messages := b.messages, infoState := b.infoState } +structure SavedState extends State where + /-- Number of heartbeats passed inside `withRestoreOrSaveFull`, not used otherwise. -/ + passedHearbeats : Nat +deriving Nonempty + +def saveState : CoreM SavedState := do + let s ← get + return { toState := s, passedHearbeats := 0 } /-- -Restores full state including sources for unique identifiers. Only intended for incremental reuse -between elaboration runs, not for backtracking within a single run. +Incremental reuse primitive: if `old?` is `none`, runs `cont` with an action `save` that on +execution returns the saved monadic state at this point including the heartbeats used by `cont` so +far. If `old?` on the other hand is `some (a, state)`, restores full `state` including heartbeats +used and returns `a`. + +The intention is for steps that support incremental reuse to initially pass `none` as `old?` and +call `save` as late as possible in `cont`. In a further run, if reuse is possible, `old?` should be +set to the previous state and result, ensuring that the state after running `withRestoreOrSaveFull` +is identical in both runs. Note however that necessarily this is only an approximation in the case +of heartbeats as heartbeats used by `withRestoreOrSaveFull`, by the remainder of `cont` after +calling `save`, as well as by reuse-handling code such as the one supplying `old?` are not accounted +for. -/ -def restoreFull (b : State) : CoreM Unit := - set b +@[specialize] def withRestoreOrSaveFull (old? : Option (α × SavedState)) + (cont : (save : CoreM SavedState) → CoreM α) : CoreM α := do + if let some (oldVal, oldState) := old? then + set oldState.toState + IO.addHeartbeats oldState.passedHearbeats.toUInt64 + return oldVal + + let s ← get + let startHeartbeats ← IO.getNumHeartbeats + cont (do + let stopHeartbeats ← IO.getNumHeartbeats + return { toState := s, passedHearbeats := stopHeartbeats - startHeartbeats }) + +/-- Restore backtrackable parts of the state. -/ +def SavedState.restore (b : SavedState) : CoreM Unit := + modify fun s => { s with env := b.env, messages := b.messages, infoState := b.infoState } private def mkFreshNameImp (n : Name) : CoreM Name := do let fresh ← modifyGet fun s => (s.nextMacroScope, { s with nextMacroScope := s.nextMacroScope + 1 }) diff --git a/src/Lean/Elab/MutualDef.lean b/src/Lean/Elab/MutualDef.lean index 85e00c5c4009..1f3c974e8a72 100644 --- a/src/Lean/Elab/MutualDef.lean +++ b/src/Lean/Elab/MutualDef.lean @@ -140,11 +140,11 @@ private def elabHeaders (views : Array DefView) let mut reuseBody := views.all (·.headerSnap?.any (·.old?.isSome)) for view in views, ⟨shortDeclName, declName, levelNames⟩ in expandedDeclIds, tacPromise in tacPromises, bodyPromise in bodyPromises do + let mut reusableResult? := none if let some snap := view.headerSnap? then -- by the `DefView.headerSnap?` invariant, safe to reuse results at this point, so let's -- wait for them! if let some old := snap.old?.bind (·.val.get) then - old.state.restoreFull let (tacStx?, newTacTask?) ← mkTacPromiseAndSnap view.value tacPromise snap.new.resolve <| some { old with tacStx? @@ -161,7 +161,7 @@ private def elabHeaders (views : Array DefView) -- we can reuse the result reuseBody := reuseBody && view.value.structRangeEqWithTraceReuse (← getOptions) old.bodyStx - headers := headers.push { old.view, view with + let header := { old.view, view with tacSnap? := some { old? := do guard reuseTac @@ -174,11 +174,12 @@ private def elabHeaders (views : Array DefView) new := bodyPromise } } - continue + reusableResult? := some (header, old.state) else reuseBody := false - let newHeader ← withRef view.ref do + let header ← withRestoreOrSaveFull reusableResult? fun save => do + withRef view.ref do addDeclarationRanges declName view.ref applyAttributesAt declName view.modifiers.attrs .beforeElaboration withDeclName declName <| withAutoBoundImplicit <| withLevelNames levelNames <| @@ -220,7 +221,7 @@ private def elabHeaders (views : Array DefView) diagnostics := (← Language.Snapshot.Diagnostics.ofMessageLog (← Core.getAndEmptyMessageLog)) view := newHeader.toDefViewElabHeaderData - state := (← saveState) + state := (← save) tacStx? tacSnap? := newTacTask? bodyStx := view.value @@ -232,7 +233,7 @@ private def elabHeaders (views : Array DefView) } check headers newHeader return newHeader - headers := headers.push newHeader + headers := headers.push header return headers where getBodyTerm? (stx : Syntax) : Option Syntax := @@ -333,38 +334,39 @@ private def declValToTerminationHint (declVal : Syntax) : TermElabM WF.Terminati private def elabFunValues (headers : Array DefViewElabHeader) : TermElabM (Array Expr) := headers.mapM fun header => do + let mut reusableResult? := none if let some snap := header.bodySnap? then if let some old := snap.old? then -- guaranteed reusable as by the `bodySnap?` invariant, so let's wait on the previous -- elaboration if let some old := old.val.get then - old.state.restoreFull snap.new.resolve <| some old -- also make sure to reuse tactic snapshots if present so that body reuse does not lead to -- missed tactic reuse on further changes if let some tacSnap := header.tacSnap? then if let some oldTacSnap := tacSnap.old? then tacSnap.new.resolve oldTacSnap.val.get - return old.value - - withDeclName header.declName <| withLevelNames header.levelNames do - let valStx ← liftMacroM <| declValToTerm header.value - forallBoundedTelescope header.type header.numParams fun xs type => do - -- Add new info nodes for new fvars. The server will detect all fvars of a binder by the binder's source location. - for i in [0:header.binderIds.size] do - -- skip auto-bound prefix in `xs` - addLocalVarInfo header.binderIds[i]! xs[header.numParams - header.binderIds.size + i]! - let val ← withReader ({ · with tacSnap? := header.tacSnap? }) do - elabTermEnsuringType valStx type <* Term.synthesizeSyntheticMVarsNoPostponing - let val ← mkLambdaFVars xs val - if let some snap := header.bodySnap? then - snap.new.resolve <| some { - diagnostics := - (← Language.Snapshot.Diagnostics.ofMessageLog (← Core.getAndEmptyMessageLog)) - state := (← saveState) - value := val - } - return val + reusableResult? := some (old.value, old.state) + + withRestoreOrSaveFull reusableResult? fun save => do + withDeclName header.declName <| withLevelNames header.levelNames do + let valStx ← liftMacroM <| declValToTerm header.value + forallBoundedTelescope header.type header.numParams fun xs type => do + -- Add new info nodes for new fvars. The server will detect all fvars of a binder by the binder's source location. + for i in [0:header.binderIds.size] do + -- skip auto-bound prefix in `xs` + addLocalVarInfo header.binderIds[i]! xs[header.numParams - header.binderIds.size + i]! + let val ← withReader ({ · with tacSnap? := header.tacSnap? }) do + elabTermEnsuringType valStx type <* Term.synthesizeSyntheticMVarsNoPostponing + let val ← mkLambdaFVars xs val + if let some snap := header.bodySnap? then + snap.new.resolve <| some { + diagnostics := + (← Language.Snapshot.Diagnostics.ofMessageLog (← Core.getAndEmptyMessageLog)) + state := (← save) + value := val + } + return val private def collectUsed (headers : Array DefViewElabHeader) (values : Array Expr) (toLift : List LetRecToLift) : StateRefT CollectFVars.State MetaM Unit := do diff --git a/src/Lean/Elab/Tactic/Basic.lean b/src/Lean/Elab/Tactic/Basic.lean index 546b85fda505..a8305c7e2b76 100644 --- a/src/Lean/Elab/Tactic/Basic.lean +++ b/src/Lean/Elab/Tactic/Basic.lean @@ -96,13 +96,15 @@ def SavedState.restore (b : SavedState) (restoreInfo := false) : TacticM Unit := b.term.restore restoreInfo set b.tactic -/-- -Restores full state including sources for unique identifiers. Only intended for incremental reuse -between elaboration runs, not for backtracking within a single run. --/ -def SavedState.restoreFull (b : SavedState) : TacticM Unit := do - b.term.restoreFull - set b.tactic +@[specialize, inherit_doc Core.withRestoreOrSaveFull] +def withRestoreOrSaveFull (old? : Option (α × SavedState)) + (cont : TacticM SavedState → TacticM α) : TacticM α := do + if let some (_, oldState) := old? then + set oldState.tactic + let old? := old?.map (fun (oldVal, oldState) => (oldVal, oldState.term)) + controlAt TermElabM fun runInBase => + Term.withRestoreOrSaveFull old? fun restore => + runInBase <| cont (return { term := (← restore), tactic := (← get) }) protected def getCurrMacroScope : TacticM MacroScope := do pure (← readThe Core.Context).currMacroScope protected def getMainModule : TacticM Name := do pure (← getEnv).mainModule diff --git a/src/Lean/Elab/Tactic/BuiltinTactic.lean b/src/Lean/Elab/Tactic/BuiltinTactic.lean index 46a669600c2a..f6d8d8f1309f 100644 --- a/src/Lean/Elab/Tactic/BuiltinTactic.lean +++ b/src/Lean/Elab/Tactic/BuiltinTactic.lean @@ -64,14 +64,13 @@ where Term.withNarrowedTacticReuse (stx := stx) (fun stx => (stx[0], mkNullNode stx.getArgs[1:])) fun stxs => do let some snap := (← readThe Term.Context).tacSnap? | do evalTactic tac; goOdd stxs - let mut reused := false + let mut reusableResult? := none let mut oldNext? := none if let some old := snap.old? then -- `tac` must be unchanged given the narrow above; let's reuse `finished`'s state! let oldParsed := old.val.get if let some state := oldParsed.data.finished.get.state? then - state.restoreFull - reused := true + reusableResult? := some ((), state) -- only allow `next` reuse in this case oldNext? := oldParsed.next.get? 1 |>.map (⟨old.stx, ·⟩) @@ -89,7 +88,7 @@ where { range? := stxs |>.getRange? task := next.result }] - unless reused do + withRestoreOrSaveFull reusableResult? fun save => do withTheReader Term.Context ({ · with tacSnap? := if (← builtinIncrementalTactics.get).contains tac.getKind then some { @@ -97,7 +96,7 @@ where new := inner } else none }) do evalTactic tac - finished.resolve { state? := (← saveState) } + finished.resolve { state? := (← save) } withTheReader Term.Context ({ · with tacSnap? := some { new := next diff --git a/src/Lean/Elab/Term.lean b/src/Lean/Elab/Term.lean index 3d41a282e0d2..bdc19a971714 100644 --- a/src/Lean/Elab/Term.lean +++ b/src/Lean/Elab/Term.lean @@ -305,13 +305,15 @@ def SavedState.restore (s : SavedState) (restoreInfo : Bool := false) : TermElab unless restoreInfo do setInfoState infoState -/-- -Restores full state including sources for unique identifiers. Only intended for incremental reuse -between elaboration runs, not for backtracking within a single run. --/ -def SavedState.restoreFull (s : SavedState) : TermElabM Unit := do - s.meta.restoreFull - set s.elab +@[specialize, inherit_doc Core.withRestoreOrSaveFull] +def withRestoreOrSaveFull (old? : Option (α × SavedState)) + (cont : TermElabM SavedState → TermElabM α) : TermElabM α := do + if let some (_, oldState) := old? then + set oldState.elab + let old? := old?.map (fun (oldVal, oldState) => (oldVal, oldState.meta)) + controlAt MetaM fun runInBase => + Meta.withRestoreOrSaveFull old? fun restore => + runInBase <| cont (return { meta := (← restore), «elab» := (← get) }) instance : MonadBacktrack SavedState TermElabM where saveState := Term.saveState diff --git a/src/Lean/Meta/Basic.lean b/src/Lean/Meta/Basic.lean index 0882164276b5..1c69ba85a10a 100644 --- a/src/Lean/Meta/Basic.lean +++ b/src/Lean/Meta/Basic.lean @@ -274,7 +274,7 @@ structure State where Backtrackable state for the `MetaM` monad. -/ structure SavedState where - core : Core.State + core : Core.SavedState meta : State deriving Nonempty @@ -364,20 +364,22 @@ instance : AddMessageContext MetaM where addMessageContext := addMessageContextFull protected def saveState : MetaM SavedState := - return { core := (← getThe Core.State), meta := (← get) } + return { core := (← Core.saveState), meta := (← get) } /-- Restore backtrackable parts of the state. -/ def SavedState.restore (b : SavedState) : MetaM Unit := do - Core.restore b.core + b.core.restore modify fun s => { s with mctx := b.meta.mctx, zetaDeltaFVarIds := b.meta.zetaDeltaFVarIds, postponed := b.meta.postponed } -/-- -Restores full state including sources for unique identifiers. Only intended for incremental reuse -between elaboration runs, not for backtracking within a single run. --/ -def SavedState.restoreFull (b : SavedState) : MetaM Unit := do - Core.restoreFull b.core - set b.meta +@[specialize, inherit_doc Core.withRestoreOrSaveFull] +def withRestoreOrSaveFull (old? : Option (α × SavedState)) + (cont : MetaM SavedState → MetaM α) : MetaM α := do + if let some (_, oldState) := old? then + set oldState.meta + let old? := old?.map (fun (oldVal, oldState) => (oldVal, oldState.core)) + controlAt CoreM fun runInCoreM => + Core.withRestoreOrSaveFull old? fun restore => + runInCoreM <| cont (return { core := (← restore), meta := (← get) }) instance : MonadBacktrack SavedState MetaM where saveState := Meta.saveState diff --git a/src/runtime/alloc.cpp b/src/runtime/alloc.cpp index ea0a438311f4..6ac4804e5eaf 100644 --- a/src/runtime/alloc.cpp +++ b/src/runtime/alloc.cpp @@ -467,16 +467,20 @@ void finalize_alloc() { LEAN_THREAD_VALUE(uint64_t, g_heartbeat, 0); #endif -/* Helper function for increasing heartbeat even when LEAN_SMALL_ALLOCATOR is not defined */ -extern "C" LEAN_EXPORT void lean_inc_heartbeat() { +void add_heartbeats(uint64_t count) { #ifdef LEAN_SMALL_ALLOCATOR if (g_heap) - g_heap->m_heartbeat++; + g_heap->m_heartbeat += count; #else - g_heartbeat++; + g_heartbeat += count; #endif } +/* Helper function for increasing heartbeat even when LEAN_SMALL_ALLOCATOR is not defined */ +extern "C" LEAN_EXPORT void lean_inc_heartbeat() { + add_heartbeats(1); +} + uint64_t get_num_heartbeats() { #ifdef LEAN_SMALL_ALLOCATOR if (g_heap) diff --git a/src/runtime/alloc.h b/src/runtime/alloc.h index c626b5f7ad82..5aaa19cd71f3 100644 --- a/src/runtime/alloc.h +++ b/src/runtime/alloc.h @@ -12,6 +12,7 @@ namespace lean { void init_thread_heap(); void * alloc(size_t sz); void dealloc(void * o, size_t sz); +void add_heartbeats(uint64_t count); uint64_t get_num_heartbeats(); void initialize_alloc(); void finalize_alloc(); diff --git a/src/runtime/io.cpp b/src/runtime/io.cpp index 2b3fb8e7cfe3..eb81d5929273 100644 --- a/src/runtime/io.cpp +++ b/src/runtime/io.cpp @@ -615,6 +615,12 @@ extern "C" LEAN_EXPORT obj_res lean_io_get_num_heartbeats(obj_arg /* w */) { return io_result_mk_ok(lean_uint64_to_nat(get_num_heartbeats())); } +/* addHeartbeats (count : Int64) : BaseIO Unit */ +extern "C" LEAN_EXPORT obj_res lean_io_add_heartbeats(int64_t count, obj_arg /* w */) { + add_heartbeats(count); + return io_result_mk_ok(box(0)); +} + extern "C" LEAN_EXPORT obj_res lean_io_getenv(b_obj_arg env_var, obj_arg) { #if defined(LEAN_EMSCRIPTEN) // HACK(WN): getenv doesn't seem to work in Emscripten even though it should