Skip to content

Commit

Permalink
good-enough approximation of "true" heartbeats in case of reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
Kha committed Apr 24, 2024
1 parent a6b7d29 commit 3eb9190
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 67 deletions.
7 changes: 7 additions & 0 deletions src/Init/System/IO.lean
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,13 @@ local macro "nonempty_list" : tactic =>
/-- Helper method for implementing "deterministic" timeouts. It is the number of "small" memory allocations performed by the current execution thread. -/
@[extern "lean_io_get_num_heartbeats"] opaque getNumHeartbeats : BaseIO Nat

/--
Adjusts the heartbeat counter of the current thread by the given amount. This can be useful to give
allocation-avoiding code additional "weight" and is also used to adjust the counter after resuming
from a snapshot.
-/
@[extern "lean_io_add_heartbeats"] opaque addHeartbeats (count : UInt64) : BaseIO Unit

/--
The mode of a file handle (i.e., a set of `open` flags and an `fdopen` mode).
Expand Down
43 changes: 36 additions & 7 deletions src/Lean/CoreM.lean
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,45 @@ instance : MonadTrace CoreM where
getTraceState := return (← get).traceState
modifyTraceState f := modify fun s => { s with traceState := f s.traceState }

/-- Restore backtrackable parts of the state. -/
def restore (b : State) : CoreM Unit :=
modify fun s => { s with env := b.env, messages := b.messages, infoState := b.infoState }
structure SavedState extends State where
/-- Number of heartbeats passed inside `withRestoreOrSaveFull`, not used otherwise. -/
passedHearbeats : Nat
deriving Nonempty

def saveState : CoreM SavedState := do
let s ← get
return { toState := s, passedHearbeats := 0 }

/--
Restores full state including sources for unique identifiers. Only intended for incremental reuse
between elaboration runs, not for backtracking within a single run.
Incremental reuse primitive: if `old?` is `none`, runs `cont` with an action `save` that on
execution returns the saved monadic state at this point including the heartbeats used by `cont` so
far. If `old?` on the other hand is `some (a, state)`, restores full `state` including heartbeats
used and returns `a`.
The intention is for steps that support incremental reuse to initially pass `none` as `old?` and
call `save` as late as possible in `cont`. In a further run, if reuse is possible, `old?` should be
set to the previous state and result, ensuring that the state after running `withRestoreOrSaveFull`
is identical in both runs. Note however that necessarily this is only an approximation in the case
of heartbeats as heartbeats used by `withRestoreOrSaveFull`, by the remainder of `cont` after
calling `save`, as well as by reuse-handling code such as the one supplying `old?` are not accounted
for.
-/
def restoreFull (b : State) : CoreM Unit :=
set b
@[specialize] def withRestoreOrSaveFull (old? : Option (α × SavedState))
(cont : (save : CoreM SavedState) → CoreM α) : CoreM α := do
if let some (oldVal, oldState) := old? then
set oldState.toState
IO.addHeartbeats oldState.passedHearbeats.toUInt64
return oldVal

let s ← get
let startHeartbeats ← IO.getNumHeartbeats
cont (do
let stopHeartbeats ← IO.getNumHeartbeats
return { toState := s, passedHearbeats := stopHeartbeats - startHeartbeats })

/-- Restore backtrackable parts of the state. -/
def SavedState.restore (b : SavedState) : CoreM Unit :=
modify fun s => { s with env := b.env, messages := b.messages, infoState := b.infoState }

private def mkFreshNameImp (n : Name) : CoreM Name := do
let fresh ← modifyGet fun s => (s.nextMacroScope, { s with nextMacroScope := s.nextMacroScope + 1 })
Expand Down
56 changes: 29 additions & 27 deletions src/Lean/Elab/MutualDef.lean
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ private def elabHeaders (views : Array DefView)
let mut reuseBody := views.all (·.headerSnap?.any (·.old?.isSome))
for view in views, ⟨shortDeclName, declName, levelNames⟩ in expandedDeclIds,
tacPromise in tacPromises, bodyPromise in bodyPromises do
let mut reusableResult? := none
if let some snap := view.headerSnap? then
-- by the `DefView.headerSnap?` invariant, safe to reuse results at this point, so let's
-- wait for them!
if let some old := snap.old?.bind (·.val.get) then
old.state.restoreFull
let (tacStx?, newTacTask?) ← mkTacPromiseAndSnap view.value tacPromise
snap.new.resolve <| some { old with
tacStx?
Expand All @@ -161,7 +161,7 @@ private def elabHeaders (views : Array DefView)
-- we can reuse the result
reuseBody := reuseBody &&
view.value.structRangeEqWithTraceReuse (← getOptions) old.bodyStx
headers := headers.push { old.view, view with
let header := { old.view, view with
tacSnap? := some {
old? := do
guard reuseTac
Expand All @@ -174,11 +174,12 @@ private def elabHeaders (views : Array DefView)
new := bodyPromise
}
}
continue
reusableResult? := some (header, old.state)
else
reuseBody := false

let newHeader ← withRef view.ref do
let header ← withRestoreOrSaveFull reusableResult? fun save => do
withRef view.ref do
addDeclarationRanges declName view.ref
applyAttributesAt declName view.modifiers.attrs .beforeElaboration
withDeclName declName <| withAutoBoundImplicit <| withLevelNames levelNames <|
Expand Down Expand Up @@ -220,7 +221,7 @@ private def elabHeaders (views : Array DefView)
diagnostics :=
(← Language.Snapshot.Diagnostics.ofMessageLog (← Core.getAndEmptyMessageLog))
view := newHeader.toDefViewElabHeaderData
state := (← saveState)
state := (← save)
tacStx?
tacSnap? := newTacTask?
bodyStx := view.value
Expand All @@ -232,7 +233,7 @@ private def elabHeaders (views : Array DefView)
}
check headers newHeader
return newHeader
headers := headers.push newHeader
headers := headers.push header
return headers
where
getBodyTerm? (stx : Syntax) : Option Syntax :=
Expand Down Expand Up @@ -333,38 +334,39 @@ private def declValToTerminationHint (declVal : Syntax) : TermElabM WF.Terminati

private def elabFunValues (headers : Array DefViewElabHeader) : TermElabM (Array Expr) :=
headers.mapM fun header => do
let mut reusableResult? := none
if let some snap := header.bodySnap? then
if let some old := snap.old? then
-- guaranteed reusable as by the `bodySnap?` invariant, so let's wait on the previous
-- elaboration
if let some old := old.val.get then
old.state.restoreFull
snap.new.resolve <| some old
-- also make sure to reuse tactic snapshots if present so that body reuse does not lead to
-- missed tactic reuse on further changes
if let some tacSnap := header.tacSnap? then
if let some oldTacSnap := tacSnap.old? then
tacSnap.new.resolve oldTacSnap.val.get
return old.value

withDeclName header.declName <| withLevelNames header.levelNames do
let valStx ← liftMacroM <| declValToTerm header.value
forallBoundedTelescope header.type header.numParams fun xs type => do
-- Add new info nodes for new fvars. The server will detect all fvars of a binder by the binder's source location.
for i in [0:header.binderIds.size] do
-- skip auto-bound prefix in `xs`
addLocalVarInfo header.binderIds[i]! xs[header.numParams - header.binderIds.size + i]!
let val ← withReader ({ · with tacSnap? := header.tacSnap? }) do
elabTermEnsuringType valStx type <* Term.synthesizeSyntheticMVarsNoPostponing
let val ← mkLambdaFVars xs val
if let some snap := header.bodySnap? then
snap.new.resolve <| some {
diagnostics :=
(← Language.Snapshot.Diagnostics.ofMessageLog (← Core.getAndEmptyMessageLog))
state := (← saveState)
value := val
}
return val
reusableResult? := some (old.value, old.state)

withRestoreOrSaveFull reusableResult? fun save => do
withDeclName header.declName <| withLevelNames header.levelNames do
let valStx ← liftMacroM <| declValToTerm header.value
forallBoundedTelescope header.type header.numParams fun xs type => do
-- Add new info nodes for new fvars. The server will detect all fvars of a binder by the binder's source location.
for i in [0:header.binderIds.size] do
-- skip auto-bound prefix in `xs`
addLocalVarInfo header.binderIds[i]! xs[header.numParams - header.binderIds.size + i]!
let val ← withReader ({ · with tacSnap? := header.tacSnap? }) do
elabTermEnsuringType valStx type <* Term.synthesizeSyntheticMVarsNoPostponing
let val ← mkLambdaFVars xs val
if let some snap := header.bodySnap? then
snap.new.resolve <| some {
diagnostics :=
(← Language.Snapshot.Diagnostics.ofMessageLog (← Core.getAndEmptyMessageLog))
state := (← save)
value := val
}
return val

private def collectUsed (headers : Array DefViewElabHeader) (values : Array Expr) (toLift : List LetRecToLift)
: StateRefT CollectFVars.State MetaM Unit := do
Expand Down
16 changes: 9 additions & 7 deletions src/Lean/Elab/Tactic/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,15 @@ def SavedState.restore (b : SavedState) (restoreInfo := false) : TacticM Unit :=
b.term.restore restoreInfo
set b.tactic

/--
Restores full state including sources for unique identifiers. Only intended for incremental reuse
between elaboration runs, not for backtracking within a single run.
-/
def SavedState.restoreFull (b : SavedState) : TacticM Unit := do
b.term.restoreFull
set b.tactic
@[specialize, inherit_doc Core.withRestoreOrSaveFull]
def withRestoreOrSaveFull (old? : Option (α × SavedState))
(cont : TacticM SavedState → TacticM α) : TacticM α := do
if let some (_, oldState) := old? then
set oldState.tactic
let old? := old?.map (fun (oldVal, oldState) => (oldVal, oldState.term))
controlAt TermElabM fun runInBase =>
Term.withRestoreOrSaveFull old? fun restore =>
runInBase <| cont (return { term := (← restore), tactic := (← get) })

protected def getCurrMacroScope : TacticM MacroScope := do pure (← readThe Core.Context).currMacroScope
protected def getMainModule : TacticM Name := do pure (← getEnv).mainModule
Expand Down
9 changes: 4 additions & 5 deletions src/Lean/Elab/Tactic/BuiltinTactic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,13 @@ where
Term.withNarrowedTacticReuse (stx := stx) (fun stx => (stx[0], mkNullNode stx.getArgs[1:])) fun stxs => do
let some snap := (← readThe Term.Context).tacSnap?
| do evalTactic tac; goOdd stxs
let mut reused := false
let mut reusableResult? := none
let mut oldNext? := none
if let some old := snap.old? then
-- `tac` must be unchanged given the narrow above; let's reuse `finished`'s state!
let oldParsed := old.val.get
if let some state := oldParsed.data.finished.get.state? then
state.restoreFull
reused := true
reusableResult? := some ((), state)
-- only allow `next` reuse in this case
oldNext? := oldParsed.next.get? 1 |>.map (⟨old.stx, ·⟩)

Expand All @@ -89,15 +88,15 @@ where
{
range? := stxs |>.getRange?
task := next.result }]
unless reused do
withRestoreOrSaveFull reusableResult? fun save => do
withTheReader Term.Context ({ · with
tacSnap? := if (← builtinIncrementalTactics.get).contains tac.getKind then
some {
old? := oldInner?
new := inner
} else none }) do
evalTactic tac
finished.resolve { state? := (← saveState) }
finished.resolve { state? := (← save) }

withTheReader Term.Context ({ · with tacSnap? := some {
new := next
Expand Down
16 changes: 9 additions & 7 deletions src/Lean/Elab/Term.lean
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,15 @@ def SavedState.restore (s : SavedState) (restoreInfo : Bool := false) : TermElab
unless restoreInfo do
setInfoState infoState

/--
Restores full state including sources for unique identifiers. Only intended for incremental reuse
between elaboration runs, not for backtracking within a single run.
-/
def SavedState.restoreFull (s : SavedState) : TermElabM Unit := do
s.meta.restoreFull
set s.elab
@[specialize, inherit_doc Core.withRestoreOrSaveFull]
def withRestoreOrSaveFull (old? : Option (α × SavedState))
(cont : TermElabM SavedState → TermElabM α) : TermElabM α := do
if let some (_, oldState) := old? then
set oldState.elab
let old? := old?.map (fun (oldVal, oldState) => (oldVal, oldState.meta))
controlAt MetaM fun runInBase =>
Meta.withRestoreOrSaveFull old? fun restore =>
runInBase <| cont (return { meta := (← restore), «elab» := (← get) })

instance : MonadBacktrack SavedState TermElabM where
saveState := Term.saveState
Expand Down
22 changes: 12 additions & 10 deletions src/Lean/Meta/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ structure State where
Backtrackable state for the `MetaM` monad.
-/
structure SavedState where
core : Core.State
core : Core.SavedState
meta : State
deriving Nonempty

Expand Down Expand Up @@ -364,20 +364,22 @@ instance : AddMessageContext MetaM where
addMessageContext := addMessageContextFull

protected def saveState : MetaM SavedState :=
return { core := (← getThe Core.State), meta := (← get) }
return { core := (← Core.saveState), meta := (← get) }

/-- Restore backtrackable parts of the state. -/
def SavedState.restore (b : SavedState) : MetaM Unit := do
Core.restore b.core
b.core.restore
modify fun s => { s with mctx := b.meta.mctx, zetaDeltaFVarIds := b.meta.zetaDeltaFVarIds, postponed := b.meta.postponed }

/--
Restores full state including sources for unique identifiers. Only intended for incremental reuse
between elaboration runs, not for backtracking within a single run.
-/
def SavedState.restoreFull (b : SavedState) : MetaM Unit := do
Core.restoreFull b.core
set b.meta
@[specialize, inherit_doc Core.withRestoreOrSaveFull]
def withRestoreOrSaveFull (old? : Option (α × SavedState))
(cont : MetaM SavedState → MetaM α) : MetaM α := do
if let some (_, oldState) := old? then
set oldState.meta
let old? := old?.map (fun (oldVal, oldState) => (oldVal, oldState.core))
controlAt CoreM fun runInCoreM =>
Core.withRestoreOrSaveFull old? fun restore =>
runInCoreM <| cont (return { core := (← restore), meta := (← get) })

instance : MonadBacktrack SavedState MetaM where
saveState := Meta.saveState
Expand Down
12 changes: 8 additions & 4 deletions src/runtime/alloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,16 +467,20 @@ void finalize_alloc() {
LEAN_THREAD_VALUE(uint64_t, g_heartbeat, 0);
#endif

/* Helper function for increasing heartbeat even when LEAN_SMALL_ALLOCATOR is not defined */
extern "C" LEAN_EXPORT void lean_inc_heartbeat() {
void add_heartbeats(uint64_t count) {
#ifdef LEAN_SMALL_ALLOCATOR
if (g_heap)
g_heap->m_heartbeat++;
g_heap->m_heartbeat += count;
#else
g_heartbeat++;
g_heartbeat += count;
#endif
}

/* Helper function for increasing heartbeat even when LEAN_SMALL_ALLOCATOR is not defined */
extern "C" LEAN_EXPORT void lean_inc_heartbeat() {
add_heartbeats(1);
}

uint64_t get_num_heartbeats() {
#ifdef LEAN_SMALL_ALLOCATOR
if (g_heap)
Expand Down
1 change: 1 addition & 0 deletions src/runtime/alloc.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace lean {
void init_thread_heap();
void * alloc(size_t sz);
void dealloc(void * o, size_t sz);
void add_heartbeats(uint64_t count);
uint64_t get_num_heartbeats();
void initialize_alloc();
void finalize_alloc();
Expand Down
6 changes: 6 additions & 0 deletions src/runtime/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,12 @@ extern "C" LEAN_EXPORT obj_res lean_io_get_num_heartbeats(obj_arg /* w */) {
return io_result_mk_ok(lean_uint64_to_nat(get_num_heartbeats()));
}

/* addHeartbeats (count : Int64) : BaseIO Unit */
extern "C" LEAN_EXPORT obj_res lean_io_add_heartbeats(int64_t count, obj_arg /* w */) {
add_heartbeats(count);
return io_result_mk_ok(box(0));
}

extern "C" LEAN_EXPORT obj_res lean_io_getenv(b_obj_arg env_var, obj_arg) {
#if defined(LEAN_EMSCRIPTEN)
// HACK(WN): getenv doesn't seem to work in Emscripten even though it should
Expand Down

0 comments on commit 3eb9190

Please sign in to comment.