Skip to content


fix: semantic tokens performance
Browse files Browse the repository at this point in the history
  • Loading branch information
mhuisi committed Apr 17, 2024
1 parent d3e0049 commit 869325a
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 84 deletions.
2 changes: 1 addition & 1 deletion src/Lean/Data/Lsp/LanguageFeatures.lean
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ inductive SemanticTokenType where
| decorator
-- Extensions
| leanSorryLike
deriving ToJson, FromJson
deriving ToJson, FromJson, BEq, Hashable

-- must be in the same order as the constructors
def SemanticTokenType.names : Array String :=
Expand Down
17 changes: 14 additions & 3 deletions src/Lean/Elab/App.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,17 @@ private def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Ar
argIdx := argIdx + 1
throwError "invalid field notation, function '{fullName}' does not have argument with type ({baseName} ...) that can be used, it must be explicit or implicit with a unique name"

/--
Registers `TermInfo` for the field part of a projection.

The given `stx` is wrapped in a synthetic node of kind `Lean.Parser.Term.identProjKind`
before it is passed to `addTermInfo`; every other argument is forwarded unchanged.
See `Lean.Parser.Term.identProjKind`.
-/
private def addProjTermInfo
    (stx            : Syntax)
    (e              : Expr)
    (expectedType?  : Option Expr := none)
    (lctx?          : Option LocalContext := none)
    (elaborator     : Name := Name.anonymous)
    (isBinder force : Bool := false)
    : TermElabM Expr := do
  -- Synthetic wrapper node that marks `stx` as the field of a projection.
  let projFieldStx := Syntax.node .none Parser.Term.identProjKind #[stx]
  addTermInfo projFieldStx e expectedType? lctx? elaborator isBinder force

private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (expectedType? : Option Expr) (explicit ellipsis : Bool)
(f : Expr) (lvals : List LVal) : TermElabM Expr :=
let rec loop : Expr → List LVal → TermElabM Expr
Expand All @@ -1214,7 +1225,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
if isPrivateNameFromImportedModule (← getEnv) info.projFn then
throwError "field '{fieldName}' from structure '{structName}' is private"
let projFn ← mkConst info.projFn
let projFn ← addTermInfo lval.getRef projFn
let projFn ← addProjTermInfo lval.getRef projFn
if lvals.isEmpty then
let namedArgs ← addNamedArg namedArgs { name := `self, val := Arg.expr f }
elabAppArgs projFn namedArgs args expectedType? explicit ellipsis
Expand All @@ -1226,7 +1237,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
| LValResolution.const baseStructName structName constName =>
let f ← if baseStructName != structName then mkBaseProjections baseStructName structName f else pure f
let projFn ← mkConst constName
let projFn ← addTermInfo lval.getRef projFn
let projFn ← addProjTermInfo lval.getRef projFn
if lvals.isEmpty then
let projFnType ← inferType projFn
let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFnType
Expand All @@ -1235,7 +1246,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
let f ← elabAppArgs projFn #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false)
loop f lvals
| LValResolution.localRec baseName fullName fvar =>
let fvar ← addTermInfo lval.getRef fvar
let fvar ← addProjTermInfo lval.getRef fvar
if lvals.isEmpty then
let fvarType ← inferType fvar
let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvarType
Expand Down
11 changes: 11 additions & 0 deletions src/Lean/Parser/Term.lean
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,17 @@ is short for accessing the `i`-th field (1-indexed) of `e` if it is of a structu
-- Trailing term parser `arrow`: a precedence check at level 25, then the arrow symbol
-- (` → `, with ` -> ` given as the second spelling), then a term parsed at level 25.
@[builtin_term_parser] def arrow := trailing_parser
checkPrec 25 >> unicodeSymbol " → " " -> " >> termParser 25

/--
Syntax kind for syntax nodes representing the field of a projection in the `InfoTree`.
Specifically, the `InfoTree` node for a projection `s.f` contains a child `InfoTree` node
with syntax ``(Syntax.node .none identProjKind #[`f])``.

This is necessary because projection syntax cannot always be detected purely syntactically
(`s.f` may refer to either the identifier `s.f` or a projection `s.f` depending on
the available context).
-/
def identProjKind := `Lean.Parser.Term.identProj

/-- Returns `true` if `stx` is an identifier or an antiquotation. -/
def isIdent (stx : Syntax) : Bool :=
  -- an antiquotation is accepted wherever an identifier is expected
  match stx.isAntiquot with
  | true  => true
  | false => stx.isIdent
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Server/FileWorker.lean
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ def runRefreshTask : WorkerM (Task (Except IO.Error Unit)) := do
IO.sleep 1000
sendServerRequest ctx "workspace/semanticTokens/refresh" (none : Option Nat)
IO.sleep 5000
IO.sleep 2000

def initAndRunWorker (i o e : FS.Stream) (opts : Options) : IO UInt32 := do
let i ← maybeTee "fwIn.txt" false i
Expand Down
204 changes: 125 additions & 79 deletions src/Lean/Server/FileWorker/RequestHandling.lean
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,9 @@ where
return toDocumentSymbols text stxs (syms.push sym) stack
toDocumentSymbols text stxs syms stack

/-- `SyntaxNodeKind`s for which the syntax node and its children receive no semantic highlighting. -/
def noHighlightKinds : Array SyntaxNodeKind := #[
-- usually have special highlighting by the client
Expand All @@ -429,25 +432,121 @@ def noHighlightKinds : Array SyntaxNodeKind := #[

/-- Read-only inputs available while semantic tokens are being emitted. -/
structure SemanticTokensContext where
  /-- Inclusive lower bound for the start position of reported tokens. -/
  beginPos : String.Pos
  /-- Exclusive upper bound for the start position of reported tokens; `none` means unbounded. -/
  endPos?  : Option String.Pos
  /-- File map used to convert UTF-8 positions into LSP positions. -/
  text     : FileMap
  /-- Snapshot whose info tree is consulted when highlighting identifiers. -/
  snap     : Snapshot

/-- Mutable state threaded through semantic token emission. -/
structure SemanticTokensState where
  /-- Encoded token data collected so far (five numbers are appended per token). -/
  data       : Array Nat
  /-- LSP position of the most recently emitted token; the next token is encoded relative to it. -/
  lastLspPos : Lsp.Position

-- TODO: make extensible, or don't
/--
Keywords for which a specific semantic token is provided.
Every keyword listed here is mapped to `SemanticTokenType.leanSorryLike`.
-/
def keywordSemanticTokenMap : RBMap String SemanticTokenType compare :=
  -- The `|>.insert` chain needs a map to start from: begin with the empty map.
  RBMap.empty
    |>.insert "sorry" .leanSorryLike
    |>.insert "admit" .leanSorryLike
    |>.insert "stop" .leanSorryLike
    |>.insert "#exit" .leanSorryLike

partial def handleSemanticTokens (beginPos : String.Pos) (endPos? : Option String.Pos)
/-- A semantic token identified by the `Syntax` it highlights, before any position conversion. -/
structure LeanSemanticToken where
  /-- The syntax that this token covers. -/
  stx  : Syntax
  /-- The semantic token type assigned to `stx`. -/
  type : SemanticTokenType

/-- Semantic token information with absolute LSP positions. -/
structure AbsoluteLspSemanticToken where
  /-- Start position of the semantic token. -/
  pos     : Lsp.Position
  /-- End position of the semantic token. -/
  tailPos : Lsp.Position
  /-- Type of the semantic token. -/
  type    : SemanticTokenType
  deriving BEq, Hashable, FromJson, ToJson

/--
Given a set of `LeanSemanticToken`, computes the `AbsoluteLspSemanticToken` with absolute
LSP position information for each token.

Tokens whose syntax has no start or tail position are dropped, as are tokens whose start
position lies outside the range `[beginPos, endPos?)` (`endPos? = none` means no upper bound).
-/
def computeAbsoluteLspSemanticTokens
    (text     : FileMap)
    (beginPos : String.Pos)
    (endPos?  : Option String.Pos)
    (tokens   : Array LeanSemanticToken)
    : Array AbsoluteLspSemanticToken :=
  tokens.filterMap fun ⟨stx, type⟩ => do
    -- Either position lookup failing makes the `Option` block yield `none`, so the token is skipped.
    let (pos, tailPos) := (← stx.getPos?, ← stx.getTailPos?)
    -- Only the start position is checked against the requested range.
    guard <| beginPos <= pos && endPos?.all (pos < ·)
    let (lspPos, lspTailPos) := (text.utf8PosToLspPos pos, text.utf8PosToLspPos tailPos)
    return ⟨lspPos, lspTailPos, type⟩

/--
Filters all duplicate semantic tokens with the same `pos`, `tailPos` and `type`.

The tokens are grouped by themselves (key `id`), and one representative per group (the key)
is returned. NOTE(review): the order of the result follows the map's iteration order and is
presumably unspecified — callers that need ordered tokens must sort afterwards.
-/
def filterDuplicateSemanticTokens (tokens : Array AbsoluteLspSemanticToken) : Array AbsoluteLspSemanticToken :=
  -- `groupByKey` yields a map from token to its duplicates; keep only the keys.
  tokens.groupByKey id |>.toArray.map (·.1)

/--
Given a set of `AbsoluteLspSemanticToken`, computes the LSP `SemanticTokens` data with
token-relative positioning.

The tokens are first sorted by start position (ties broken by end position). Each token then
contributes five numbers to `data`: `deltaLine`, `deltaStart`, `length`, `tokenType` and
`tokenModifiers` (always `0`), where the deltas are taken relative to the previous token's
start position (initially `⟨0, 0⟩`).
-/
def computeDeltaLspSemanticTokens (tokens : Array AbsoluteLspSemanticToken) : SemanticTokens := Id.run do
  let tokens := tokens.qsort fun ⟨pos1, tailPos1, _⟩ ⟨pos2, tailPos2, _⟩ =>
    pos1 < pos2 || pos1 == pos2 && tailPos1 <= tailPos2
  let mut data : Array Nat := Array.mkEmpty (5*tokens.size)
  let mut lastPos : Lsp.Position := ⟨0, 0⟩
  for ⟨pos, tailPos, type⟩ in tokens do
    let deltaLine := pos.line - lastPos.line
    -- The start column is only relative to the previous token when both are on the same line.
    let deltaStart := pos.character - (if pos.line == lastPos.line then lastPos.character else 0)
    -- NOTE(review): computed from character columns only, so this is meaningful only for
    -- tokens that start and end on the same line — confirm multi-line tokens cannot occur.
    let length := tailPos.character - pos.character
    let tokenType := type.toNat
    let tokenModifiers := 0
    data := data ++ #[deltaLine, deltaStart, length, tokenType, tokenModifiers]
    lastPos := pos
  return { data }

/--
Collects all semantic tokens that can be deduced purely from `Syntax`
without elaboration information.

* For `e.id` and `e |>.field`, the tokens of `e` are collected and the field is added as a
  `property` token.
* Nodes whose kind is listed in `noHighlightKinds` contribute no tokens (children included).
* For a choice node only the first alternative is visited; otherwise all children are visited.
* An atom that starts with a letter, or with `#` followed by a letter, is reported as a keyword
  token, using `keywordSemanticTokenMap` for keywords with a dedicated token type.
-/
partial def collectSyntaxBasedSemanticTokens : (stx : Syntax) → Array LeanSemanticToken
  | `($e.$id:ident) =>
    let tokens := collectSyntaxBasedSemanticTokens e
    tokens.push ⟨id, SemanticTokenType.property⟩
  | `($e |>.$field:ident) =>
    let tokens := collectSyntaxBasedSemanticTokens e
    tokens.push ⟨field, SemanticTokenType.property⟩
  | stx => Id.run do
    if noHighlightKinds.contains stx.getKind then
      return #[]
    let tokens :=
      if stx.isOfKind choiceKind then
        collectSyntaxBasedSemanticTokens stx[0]
      else
        stx.getArgs.map collectSyntaxBasedSemanticTokens |>.flatten
    -- Only atoms can be keywords; any other syntax just passes on its children's tokens.
    let Syntax.atom _ val := stx
      | return tokens
    let isRegularKeyword := val.length > 0 && val.front.isAlpha
    -- Support for keywords of the form `#<alpha>...`
    let isHashKeyword := val.length > 1 && val.front == '#' && (val.get ⟨1⟩).isAlpha
    if ! isRegularKeyword && ! isHashKeyword then
      return tokens
    return tokens.push ⟨stx, keywordSemanticTokenMap.findD val .keyword⟩

/--
Collects all semantic tokens from the given `Elab.InfoTree`.

Only `TermInfo` nodes whose syntax carries original source info are considered:
* an identifier that elaborated to a free variable with a local declaration is reported as
  `function` when the declaration is auxiliary and the info is a binder, and as `variable`
  when the declaration is neither auxiliary nor an implementation detail;
* syntax of kind `Parser.Term.identProjKind` (the field of a projection) is reported as
  `property`.

All other nodes yield no token.
-/
def collectInfoBasedSemanticTokens (i : Elab.InfoTree) : Array LeanSemanticToken :=
  List.toArray <| i.deepestNodes fun _ i _ => do
    let .ofTermInfo ti := i
      | none
    let .original .. := ti.stx.getHeadInfo
      | none
    if let `($_:ident) := ti.stx then
      if let Expr.fvar fvarId .. := ti.expr then
        if let some localDecl := ti.lctx.find? fvarId then
          -- Recall that `isAuxDecl` is an auxiliary declaration used to elaborate a recursive definition.
          if localDecl.isAuxDecl then
            if ti.isBinder then
              return ⟨ti.stx, SemanticTokenType.function⟩
          else if ! localDecl.isImplementationDetail then
            return ⟨ti.stx, SemanticTokenType.variable⟩
    if ti.stx.getKind == Parser.Term.identProjKind then
      return ⟨ti.stx, SemanticTokenType.property⟩
    -- No rule matched: this node contributes no semantic token.
    none

/-- Computes the semantic tokens in the range [beginPos, endPos?). -/
def handleSemanticTokens (beginPos : String.Pos) (endPos? : Option String.Pos)
: RequestM (RequestTask SemanticTokens) := do
let doc ← readDoc
match endPos? with
Expand All @@ -462,78 +561,25 @@ partial def handleSemanticTokens (beginPos : String.Pos) (endPos? : Option Strin
let t := doc.cmdSnaps.waitUntil (·.endPos >= endPos)
mapTask t fun (snaps, _) => run doc snaps
run doc snaps : RequestM SemanticTokens :=' (s := { data := #[], lastLspPos := ⟨0, 0⟩ : SemanticTokensState }) do
for s in snaps do
if s.endPos <= beginPos then
continue (r := beginPos endPos? doc.meta.text s) <|
go s.stx
return { data := (← get).data }
go (stx : Syntax) := do
match stx with
| `($e.$id:ident) => go e; addToken id
-- indistinguishable from next pattern
--| `(level|$id:ident) => addToken id SemanticTokenType.variable
| `($id:ident) => highlightId id
| _ =>
if !noHighlightKinds.contains stx.getKind then
highlightKeyword stx
if stx.isOfKind choiceKind then
go stx[0]
stx.getArgs.forM go
highlightId (stx : Syntax) : ReaderT SemanticTokensContext (StateT SemanticTokensState RequestM) _ := do
if let some range := stx.getRange? then
let mut lastPos := range.start
for ti in (← read).snap.infoTree.deepestNodes (fun
| _, i@(Elab.Info.ofTermInfo ti), _ => match i.pos? with
| some ipos => if range.contains ipos then some ti else none
| _ => none
| _, _, _ => none) do
let pos := ti.stx.getPos?.get!
-- avoid reporting same position twice; the info node can occur multiple times if
-- e.g. the term is elaborated multiple times
if pos < lastPos then
if let Expr.fvar fvarId .. := ti.expr then
if let some localDecl := ti.lctx.find? fvarId then
-- Recall that `isAuxDecl` is an auxiliary declaration used to elaborate a recursive definition.
if localDecl.isAuxDecl then
if ti.isBinder then
addToken ti.stx SemanticTokenType.function
addToken ti.stx SemanticTokenType.variable
else if ti.stx.getPos?.get! > lastPos then
-- any info after the start position: must be projection notation
addToken ti.stx
lastPos := ti.stx.getPos?.get!
highlightKeyword stx := do
if let Syntax.atom _ val := stx then
if (val.length > 0 && val.front.isAlpha) ||
-- Support for keywords of the form `#<alpha>...`
(val.length > 1 && val.front == '#' && (val.get ⟨1⟩).isAlpha) then
addToken stx (keywordSemanticTokenMap.findD val .keyword)
addToken stx type := do
let ⟨beginPos, endPos?, text, _⟩ ← read
if let (some pos, some tailPos) := (stx.getPos?, stx.getTailPos?) then
if beginPos <= pos && endPos?.all (pos < ·) then
let lspPos := (← get).lastLspPos
let lspPos' := text.utf8PosToLspPos pos
let deltaLine := lspPos'.line - lspPos.line
let deltaStart := lspPos'.character - (if lspPos'.line == lspPos.line then lspPos.character else 0)
let length := (text.utf8PosToLspPos tailPos).character - lspPos'.character
let tokenType := type.toNat
let tokenModifiers := 0
modify fun st => {
data := ++ #[deltaLine, deltaStart, length, tokenType, tokenModifiers]
lastLspPos := lspPos'

run doc snaps : RequestM SemanticTokens := do
let mut leanSemanticTokens := #[]
for s in snaps do
if s.endPos <= beginPos then
let syntaxBasedSemanticTokens := collectSyntaxBasedSemanticTokens s.stx
let infoBasedSemanticTokens := collectInfoBasedSemanticTokens s.infoTree
leanSemanticTokens := leanSemanticTokens ++ syntaxBasedSemanticTokens ++ infoBasedSemanticTokens
let absoluteLspSemanticTokens := computeAbsoluteLspSemanticTokens doc.meta.text beginPos endPos? leanSemanticTokens
let absoluteLspSemanticTokens := filterDuplicateSemanticTokens absoluteLspSemanticTokens
let semanticTokens := computeDeltaLspSemanticTokens absoluteLspSemanticTokens
return semanticTokens

/--
Computes the semantic tokens of the whole document by delegating to `handleSemanticTokens`
with start position `0` and no end position. The request parameters are not inspected.
-/
def handleSemanticTokensFull (_ : SemanticTokensParams)
    : RequestM (RequestTask SemanticTokens) :=
  handleSemanticTokens (beginPos := 0) (endPos? := none)

/-- Computes the semantic tokens in the range provided by `p`. -/
def handleSemanticTokensRange (p : SemanticTokensRangeParams)
: RequestM (RequestTask SemanticTokens) := do
let doc ← readDoc
Expand Down

0 comments on commit 869325a

Please sign in to comment.