Skip to content

Commit

Permalink
refactor(langserver): implement cancelable threads and remove async c…
Browse files Browse the repository at this point in the history
…ode from definition, implementations, hover, folding
  • Loading branch information
d-biehl committed Dec 30, 2023
1 parent 4b3e65c commit 2fae8e3
Show file tree
Hide file tree
Showing 18 changed files with 240 additions and 195 deletions.
68 changes: 65 additions & 3 deletions packages/core/src/robotcode/core/utils/threading.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,27 @@
import inspect
from typing import Any, Callable, TypeVar
from concurrent.futures import CancelledError, Future
from threading import Thread, current_thread, local
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar, cast

_F = TypeVar("_F", bound=Callable[..., Any])
_TResult = TypeVar("_TResult")

__THREADED_MARKER = "__threaded__"


class FutureEx(Future, Generic[_TResult]): # type: ignore[type-arg]
def __init__(self) -> None:
super().__init__()
self.cancelation_requested = False

def cancel(self) -> bool:
self.cancelation_requested = True
return super().cancel()

def result(self, timeout: Optional[float] = None) -> _TResult:
return cast(_TResult, super().result(timeout))


def threaded(enabled: bool = True) -> Callable[[_F], _F]:
def decorator(func: _F) -> _F:
setattr(func, __THREADED_MARKER, enabled)
Expand All @@ -14,5 +30,51 @@ def decorator(func: _F) -> _F:
return decorator


def is_threaded_callable(func: Callable[..., Any]) -> bool:
return getattr(func, __THREADED_MARKER, False) or inspect.ismethod(func) and getattr(func, __THREADED_MARKER, False)
def is_threaded_callable(callable: Callable[..., Any]) -> bool:
return (
getattr(callable, __THREADED_MARKER, False)
or inspect.ismethod(callable)
and getattr(callable, __THREADED_MARKER, False)
)


class _Local(local):
def __init__(self) -> None:
super().__init__()
self._local_future: Optional[FutureEx[Any]] = None


_local_storage = _Local()


def _run_callable_in_thread_handler(
future: FutureEx[_TResult], callable: Callable[..., _TResult], args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> None:
_local_storage._local_future = future
future.set_running_or_notify_cancel()
try:
future.set_result(callable(*args, **kwargs))
except Exception as e:
# TODO: add traceback to exception e.traceback = format_exc()

future.set_exception(e)
finally:
_local_storage._local_future = None


def is_thread_cancelled() -> bool:
return _local_storage._local_future is not None and _local_storage._local_future.cancelation_requested


def check_thread_canceled() -> None:
if _local_storage._local_future is not None and _local_storage._local_future.cancelation_requested:
name = current_thread().name
raise CancelledError(f"Thread {name+' ' if name else ' '}Cancelled")


def run_callable_in_thread(callable: Callable[..., _TResult], *args: Any, **kwargs: Any) -> FutureEx[_TResult]:
future: FutureEx[_TResult] = FutureEx()

Thread(target=_run_callable_in_thread_handler, args=(future, callable, args, kwargs), name=str(callable)).start()

return future
114 changes: 66 additions & 48 deletions packages/jsonrpc2/src/robotcode/jsonrpc2/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import asyncio
import concurrent.futures
import functools
import inspect
import json
import re
import threading
import time
import weakref
from abc import ABC, abstractmethod
from collections import OrderedDict
Expand All @@ -21,7 +21,6 @@
Iterator,
List,
Mapping,
NamedTuple,
Optional,
Protocol,
Set,
Expand All @@ -42,7 +41,7 @@
from robotcode.core.utils.dataclasses import as_json, from_dict
from robotcode.core.utils.inspect import ensure_coroutine, iter_methods
from robotcode.core.utils.logging import LoggingDescriptor
from robotcode.core.utils.threading import is_threaded_callable
from robotcode.core.utils.threading import is_threaded_callable, run_callable_in_thread

__all__ = [
"JsonRPCErrors",
Expand Down Expand Up @@ -344,15 +343,23 @@ def get_param_type(self, name: str) -> Optional[Type[Any]]:
return result.param_type


class SendedRequestEntry(NamedTuple):
future: concurrent.futures.Future[Any]
result_type: Optional[Type[Any]]
class SendedRequestEntry:
def __init__(self, future: concurrent.futures.Future[Any], result_type: Optional[Type[Any]]) -> None:
self.future = future
self.result_type = result_type


class ReceivedRequestEntry(NamedTuple):
future: asyncio.Future[Any]
request: Optional[Any]
cancelable: bool
class ReceivedRequestEntry:
def __init__(self, future: asyncio.Future[Any], request: JsonRPCRequest, cancelable: bool) -> None:
self.future = future
self.request = request
self.cancelable = cancelable
self.cancel_requested = False

def cancel(self) -> None:
self.cancel_requested = True
if self.future is not None and not self.future.cancelled():
self.future.cancel()


class JsonRPCProtocolBase(asyncio.Protocol, ABC):
Expand Down Expand Up @@ -711,7 +718,6 @@ def _convert_params(
return args, kw_args

async def handle_request(self, message: JsonRPCRequest) -> None:
start = time.monotonic_ns()
try:
e = self.registry.get_entry(message.method)

Expand All @@ -725,13 +731,18 @@ async def handle_request(self, message: JsonRPCRequest) -> None:

params = self._convert_params(e.method, e.param_type, message.params)

if not e.is_coroutine:
is_threaded_method = is_threaded_callable(e.method)

if not is_threaded_method and not e.is_coroutine:
self.send_response(message.id, e.method(*params[0], **params[1]))
else:
if is_threaded_callable(e.method):
task = run_coroutine_in_thread(
ensure_coroutine(cast(Callable[..., Any], e.method)), *params[0], **params[1]
)
if is_threaded_method:
if e.is_coroutine:
task = run_coroutine_in_thread(
ensure_coroutine(cast(Callable[..., Any], e.method)), *params[0], **params[1]
)
else:
task = asyncio.wrap_future(run_callable_in_thread(e.method, *params[0], **params[1]))
else:
task = create_sub_task(
ensure_coroutine(e.method)(*params[0], **params[1]),
Expand All @@ -741,49 +752,56 @@ async def handle_request(self, message: JsonRPCRequest) -> None:
with self._received_request_lock:
self._received_request[message.id] = ReceivedRequestEntry(task, message, e.cancelable)

def done(t: asyncio.Future[Any]) -> None:
try:
if not t.cancelled():
ex = t.exception()
if ex is not None:
self.__logger.exception(ex, exc_info=ex)
raise JsonRPCErrorException(
JsonRPCErrors.INTERNAL_ERROR, f"{type(ex).__name__}: {ex}"
) from ex

self.send_response(message.id, t.result())
except asyncio.CancelledError:
self.__logger.debug(lambda: f"request message {message!r} canceled")
self.send_error(JsonRPCErrors.REQUEST_CANCELLED, "Request canceled.", id=message.id)
except (SystemExit, KeyboardInterrupt):
raise
except JsonRPCErrorException as e:
self.send_error(e.code, e.message or f"{type(e).__name__}: {e}", id=message.id, data=e.data)
except BaseException as e:
self.__logger.exception(e)
self.send_error(JsonRPCErrors.INTERNAL_ERROR, f"{type(e).__name__}: {e}", id=message.id)
finally:
with self._received_request_lock:
self._received_request.pop(message.id, None)

task.add_done_callback(done)
task.add_done_callback(functools.partial(self._received_request_done, message))

await task
finally:
self.__logger.debug(lambda: f"request message {message!r} done in {time.monotonic_ns() - start}ns")
except (SystemExit, KeyboardInterrupt, asyncio.CancelledError):
raise
except BaseException as e:
self.__logger.exception(e)

def _received_request_done(self, message: JsonRPCRequest, t: asyncio.Future[Any]) -> None:
try:
with self._received_request_lock:
entry = self._received_request.pop(message.id, None)

if entry is None:
self.__logger.critical(lambda: f"unknown request {message!r}")
return

if entry.cancel_requested:
self.__logger.debug(lambda: f"request {message!r} canceled")
self.send_error(JsonRPCErrors.REQUEST_CANCELLED, "Request canceled.", id=message.id)
else:
if not t.cancelled():
ex = t.exception()
if ex is not None:
self.__logger.exception(ex, exc_info=ex)
raise JsonRPCErrorException(JsonRPCErrors.INTERNAL_ERROR, f"{type(ex).__name__}: {ex}") from ex

self.send_response(message.id, t.result())
except asyncio.CancelledError:
self.__logger.debug(lambda: f"request message {message!r} canceled")
self.send_error(JsonRPCErrors.REQUEST_CANCELLED, "Request canceled.", id=message.id)
except (SystemExit, KeyboardInterrupt):
raise
except JsonRPCErrorException as e:
self.send_error(e.code, e.message or f"{type(e).__name__}: {e}", id=message.id, data=e.data)
except BaseException as e:
self.__logger.exception(e)
self.send_error(JsonRPCErrors.INTERNAL_ERROR, f"{type(e).__name__}: {e}", id=message.id)

def cancel_request(self, id: Union[int, str, None]) -> None:
with self._received_request_lock:
entry = self._received_request.get(id, None)

if entry is not None and entry.future is not None and not entry.future.cancelled():
if entry is not None:
self.__logger.debug(lambda: f"try to cancel request {entry.request if entry is not None else ''}")
entry.future.cancel()
entry.cancel()

def cancel_all_received_request(self) -> None:
for entry in self._received_request.values():
if entry is not None and entry.cancelable and entry.future is not None and not entry.future.cancelled():
entry.future.cancel()
entry.cancel()

@__logger.call
async def handle_notification(self, message: JsonRPCNotification) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

from asyncio import CancelledError
from concurrent.futures import CancelledError
from typing import TYPE_CHECKING, Any, Final, List, Optional, Union

from robotcode.core.async_tools import async_tasking_event
from robotcode.core.event import event
from robotcode.core.lsp.types import (
DefinitionParams,
Location,
Expand All @@ -13,7 +13,7 @@
TextDocumentIdentifier,
)
from robotcode.core.utils.logging import LoggingDescriptor
from robotcode.core.utils.threading import threaded
from robotcode.core.utils.threading import check_thread_canceled, threaded
from robotcode.jsonrpc2.protocol import rpc_method
from robotcode.language_server.common.decorators import language_id_filter
from robotcode.language_server.common.has_extend_capabilities import HasExtendCapabilities
Expand All @@ -31,8 +31,8 @@ def __init__(self, parent: LanguageServerProtocol) -> None:
super().__init__(parent)
self.link_support = False

@async_tasking_event
async def collect(
@event
def collect(
sender, document: TextDocument, position: Position # NOSONAR
) -> Union[Location, List[Location], List[LocationLink], None]:
...
Expand All @@ -50,7 +50,7 @@ def extend_capabilities(self, capabilities: ServerCapabilities) -> None:

@rpc_method(name="textDocument/definition", param_type=DefinitionParams)
@threaded()
async def _text_document_definition(
def _text_document_definition(
self, text_document: TextDocumentIdentifier, position: Position, *args: Any, **kwargs: Any
) -> Optional[Union[Location, List[Location], List[LocationLink]]]:
locations: List[Location] = []
Expand All @@ -60,9 +60,11 @@ async def _text_document_definition(
if document is None:
return None

for result in await self.collect(
for result in self.collect(
self, document, document.position_from_utf16(position), callback_filter=language_id_filter(document)
):
check_thread_canceled()

if isinstance(result, BaseException):
if not isinstance(result, CancelledError):
self._logger.exception(result, exc_info=result)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from __future__ import annotations

from asyncio import CancelledError
from concurrent.futures import CancelledError
from typing import TYPE_CHECKING, Any, Final, List, Optional

from robotcode.core.async_tools import async_tasking_event
from robotcode.core.event import event
from robotcode.core.lsp.types import (
FoldingRange,
FoldingRangeParams,
Expand All @@ -26,11 +24,11 @@
class FoldingRangeProtocolPart(LanguageServerProtocolPart, HasExtendCapabilities):
_logger: Final = LoggingDescriptor()

def __init__(self, parent: LanguageServerProtocol) -> None:
def __init__(self, parent: "LanguageServerProtocol") -> None:
super().__init__(parent)

@async_tasking_event
async def collect(sender, document: TextDocument) -> Optional[List[FoldingRange]]: # pragma: no cover, NOSONAR
@event
def collect(sender, document: TextDocument) -> Optional[List[FoldingRange]]: # pragma: no cover, NOSONAR
...

def extend_capabilities(self, capabilities: ServerCapabilities) -> None:
Expand All @@ -39,15 +37,15 @@ def extend_capabilities(self, capabilities: ServerCapabilities) -> None:

@rpc_method(name="textDocument/foldingRange", param_type=FoldingRangeParams)
@threaded()
async def _text_document_folding_range(
def _text_document_folding_range(
self, text_document: TextDocumentIdentifier, *args: Any, **kwargs: Any
) -> Optional[List[FoldingRange]]:
results: List[FoldingRange] = []
document = self.parent.documents.get(text_document.uri)
if document is None:
return None

for result in await self.collect(self, document, callback_filter=language_id_filter(document)):
for result in self.collect(self, document, callback_filter=language_id_filter(document)):
if isinstance(result, BaseException):
if not isinstance(result, CancelledError):
self._logger.exception(result, exc_info=result)
Expand All @@ -58,14 +56,10 @@ async def _text_document_folding_range(
if not results:
return None

for result in results:
if result.start_character is not None:
result.start_character = document.position_to_utf16(
Position(result.start_line, result.start_character)
).character
if result.end_character is not None:
result.end_character = document.position_to_utf16(
Position(result.end_line, result.end_character)
).character
for r in results:
if r.start_character is not None:
r.start_character = document.position_to_utf16(Position(r.start_line, r.start_character)).character
if r.end_character is not None:
r.end_character = document.position_to_utf16(Position(r.end_line, r.end_character)).character

return results
Loading

0 comments on commit 2fae8e3

Please sign in to comment.