Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented voluntary cancellation in worker threads #629

Merged
merged 12 commits into from
Nov 22, 2023
Merged
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Running asynchronous code from other threads

.. autofunction:: anyio.from_thread.run
.. autofunction:: anyio.from_thread.run_sync
.. autofunction:: anyio.from_thread.check_cancelled
.. autofunction:: anyio.from_thread.start_blocking_portal

.. autoclass:: anyio.from_thread.BlockingPortal
Expand Down
20 changes: 20 additions & 0 deletions docs/threads.rst
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,23 @@ maximum of 40 threads to be spawned. You can adjust this limit like this::

.. note:: AnyIO's default thread pool limiter does not affect the default thread pool
executor on :mod:`asyncio`.

Reacting to cancellation in worker threads
------------------------------------------

While there is no mechanism in Python to cancel code running in a thread, AnyIO provides a
mechanism that allows user code to voluntarily check if the host task's scope has been cancelled,
and if it has, raise a cancellation exception. This can be done by simply calling
:func:`from_thread.check_cancelled`::

from anyio import to_thread, from_thread

def sync_function():
while True:
from_thread.check_cancelled()
print("Not cancelled yet")
sleep(1)

async def foo():
with move_on_after(3):
await to_thread.run_sync(sync_function)
5 changes: 5 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- Call ``trio.to_thread.run_sync()`` using the ``abandon_on_cancel`` keyword argument
instead of ``cancellable``
- Removed a checkpoint when exiting a task group
- Renamed the ``cancellable`` argument in ``anyio.to_thread.run_sync()`` to
``abandon_on_cancel`` (and deprecated the old parameter name)
- Bumped minimum version of Trio to v0.23
- Added support for voluntary thread cancellation via
``anyio.from_thread.check_cancelled()``
- Bumped minimum version of trio to v0.23
- Exposed the ``ResourceGuard`` class in the public API

Expand Down
47 changes: 40 additions & 7 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
import sniffio

from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
from .._core._eventloop import claim_worker_thread
from .._core._eventloop import claim_worker_thread, threadlocals
from .._core._exceptions import (
BrokenResourceError,
BusyResourceError,
Expand Down Expand Up @@ -786,7 +786,7 @@ def __init__(
self.idle_workers = idle_workers
self.loop = root_task._loop
self.queue: Queue[
tuple[Context, Callable, tuple, asyncio.Future] | None
tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
] = Queue(2)
self.idle_since = AsyncIOBackend.current_time()
self.stopping = False
Expand Down Expand Up @@ -817,14 +817,17 @@ def run(self) -> None:
# Shutdown command received
return

context, func, args, future = item
context, func, args, future, cancel_scope = item
if not future.cancelled():
result = None
exception: BaseException | None = None
threadlocals.current_cancel_scope = cancel_scope
try:
result = context.run(func, *args)
except BaseException as exc:
exception = exc
finally:
del threadlocals.current_cancel_scope

if not self.loop.is_closed():
self.loop.call_soon_threadsafe(
Expand Down Expand Up @@ -2048,7 +2051,7 @@ async def run_sync_in_worker_thread(
cls,
func: Callable[..., T_Retval],
args: tuple[Any, ...],
cancellable: bool = False,
abandon_on_cancel: bool = False,
limiter: abc.CapacityLimiter | None = None,
) -> T_Retval:
await cls.checkpoint()
Expand All @@ -2065,7 +2068,7 @@ async def run_sync_in_worker_thread(
_threadpool_workers.set(workers)

async with limiter or cls.current_default_thread_limiter():
with CancelScope(shield=not cancellable):
with CancelScope(shield=not abandon_on_cancel) as scope:
future: asyncio.Future = asyncio.Future()
root_task = find_root_task()
if not idle_workers:
Expand Down Expand Up @@ -2094,21 +2097,51 @@ async def run_sync_in_worker_thread(

context = copy_context()
context.run(sniffio.current_async_library_cvar.set, None)
worker.queue.put_nowait((context, func, args, future))
if abandon_on_cancel or scope._parent_scope is None:
worker_scope = scope
else:
worker_scope = scope._parent_scope

worker.queue.put_nowait((context, func, args, future, worker_scope))
return await future

@classmethod
def check_cancelled(cls) -> None:
scope: CancelScope | None = threadlocals.current_cancel_scope
while scope is not None:
if scope.cancel_called:
raise CancelledError(f"Cancelled by cancel scope {id(scope):x}")

if scope.shield:
return

scope = scope._parent_scope

@classmethod
def run_async_from_thread(
cls,
func: Callable[..., Awaitable[T_Retval]],
args: tuple[Any, ...],
token: object,
) -> T_Retval:
async def task_wrapper(scope: CancelScope) -> T_Retval:
__tracebackhide__ = True
task = cast(asyncio.Task, current_task())
_task_states[task] = TaskState(None, scope)
scope._tasks.add(task)
try:
return await func(*args)
except CancelledError as exc:
raise concurrent.futures.CancelledError(str(exc)) from None
Comment on lines +2131 to +2132
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this break the chain of any exceptions attached to the CancelledError? I guess that's probably rare.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if the chain is preserved anyway, but I can check if you like.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked and it doesn't seem to matter. There's a C level black box between that reraise and the calling synchronous code.

finally:
scope._tasks.discard(task)

loop = cast(AbstractEventLoop, token)
context = copy_context()
context.run(sniffio.current_async_library_cvar.set, "asyncio")
wrapper = task_wrapper(threadlocals.current_cancel_scope)
f: concurrent.futures.Future[T_Retval] = context.run(
asyncio.run_coroutine_threadsafe, func(*args), loop
asyncio.run_coroutine_threadsafe, wrapper, loop
)
return f.result()

Expand Down
13 changes: 8 additions & 5 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import trio.lowlevel
from outcome import Error, Outcome, Value
from trio.lowlevel import (
TrioToken,
current_root_task,
current_task,
wait_readable,
Expand Down Expand Up @@ -869,7 +868,7 @@ async def run_sync_in_worker_thread(
cls,
func: Callable[..., T_Retval],
args: tuple[Any, ...],
cancellable: bool = False,
abandon_on_cancel: bool = False,
limiter: abc.CapacityLimiter | None = None,
) -> T_Retval:
def wrapper() -> T_Retval:
Expand All @@ -879,24 +878,28 @@ def wrapper() -> T_Retval:
token = TrioBackend.current_token()
return await run_sync(
wrapper,
abandon_on_cancel=cancellable,
abandon_on_cancel=abandon_on_cancel,
limiter=cast(trio.CapacityLimiter, limiter),
)

@classmethod
def check_cancelled(cls) -> None:
trio.from_thread.check_cancelled()

@classmethod
def run_async_from_thread(
cls,
func: Callable[..., Awaitable[T_Retval]],
args: tuple[Any, ...],
token: object,
) -> T_Retval:
return trio.from_thread.run(func, *args, trio_token=cast(TrioToken, token))
return trio.from_thread.run(func, *args)

@classmethod
def run_sync_from_thread(
cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object
) -> T_Retval:
return trio.from_thread.run_sync(func, *args, trio_token=cast(TrioToken, token))
return trio.from_thread.run_sync(func, *args)

@classmethod
def create_blocking_portal(cls) -> abc.BlockingPortal:
Expand Down
44 changes: 27 additions & 17 deletions src/anyio/_core/_fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ class _PathIterator(AsyncIterator["Path"]):
iterator: Iterator[PathLike[str]]

async def __anext__(self) -> Path:
nextval = await to_thread.run_sync(next, self.iterator, None, cancellable=True)
nextval = await to_thread.run_sync(
next, self.iterator, None, abandon_on_cancel=True
)
if nextval is None:
raise StopAsyncIteration from None

Expand Down Expand Up @@ -386,17 +388,19 @@ async def cwd(cls) -> Path:
return cls(path)

async def exists(self) -> bool:
return await to_thread.run_sync(self._path.exists, cancellable=True)
return await to_thread.run_sync(self._path.exists, abandon_on_cancel=True)

async def expanduser(self) -> Path:
return Path(await to_thread.run_sync(self._path.expanduser, cancellable=True))
return Path(
await to_thread.run_sync(self._path.expanduser, abandon_on_cancel=True)
)

def glob(self, pattern: str) -> AsyncIterator[Path]:
gen = self._path.glob(pattern)
return _PathIterator(gen)

async def group(self) -> str:
return await to_thread.run_sync(self._path.group, cancellable=True)
return await to_thread.run_sync(self._path.group, abandon_on_cancel=True)

async def hardlink_to(self, target: str | pathlib.Path | Path) -> None:
if isinstance(target, Path):
Expand All @@ -413,31 +417,37 @@ def is_absolute(self) -> bool:
return self._path.is_absolute()

async def is_block_device(self) -> bool:
return await to_thread.run_sync(self._path.is_block_device, cancellable=True)
return await to_thread.run_sync(
self._path.is_block_device, abandon_on_cancel=True
)

async def is_char_device(self) -> bool:
return await to_thread.run_sync(self._path.is_char_device, cancellable=True)
return await to_thread.run_sync(
self._path.is_char_device, abandon_on_cancel=True
)

async def is_dir(self) -> bool:
return await to_thread.run_sync(self._path.is_dir, cancellable=True)
return await to_thread.run_sync(self._path.is_dir, abandon_on_cancel=True)

async def is_fifo(self) -> bool:
return await to_thread.run_sync(self._path.is_fifo, cancellable=True)
return await to_thread.run_sync(self._path.is_fifo, abandon_on_cancel=True)

async def is_file(self) -> bool:
return await to_thread.run_sync(self._path.is_file, cancellable=True)
return await to_thread.run_sync(self._path.is_file, abandon_on_cancel=True)

async def is_mount(self) -> bool:
return await to_thread.run_sync(os.path.ismount, self._path, cancellable=True)
return await to_thread.run_sync(
os.path.ismount, self._path, abandon_on_cancel=True
)

def is_reserved(self) -> bool:
return self._path.is_reserved()

async def is_socket(self) -> bool:
return await to_thread.run_sync(self._path.is_socket, cancellable=True)
return await to_thread.run_sync(self._path.is_socket, abandon_on_cancel=True)

async def is_symlink(self) -> bool:
return await to_thread.run_sync(self._path.is_symlink, cancellable=True)
return await to_thread.run_sync(self._path.is_symlink, abandon_on_cancel=True)

def iterdir(self) -> AsyncIterator[Path]:
gen = self._path.iterdir()
Expand All @@ -450,7 +460,7 @@ async def lchmod(self, mode: int) -> None:
await to_thread.run_sync(self._path.lchmod, mode)

async def lstat(self) -> os.stat_result:
return await to_thread.run_sync(self._path.lstat, cancellable=True)
return await to_thread.run_sync(self._path.lstat, abandon_on_cancel=True)

async def mkdir(
self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False
Expand Down Expand Up @@ -493,7 +503,7 @@ async def open(
return AsyncFile(fp)

async def owner(self) -> str:
return await to_thread.run_sync(self._path.owner, cancellable=True)
return await to_thread.run_sync(self._path.owner, abandon_on_cancel=True)

async def read_bytes(self) -> bytes:
return await to_thread.run_sync(self._path.read_bytes)
Expand Down Expand Up @@ -526,7 +536,7 @@ async def replace(self, target: str | pathlib.PurePath | Path) -> Path:

async def resolve(self, strict: bool = False) -> Path:
func = partial(self._path.resolve, strict=strict)
return Path(await to_thread.run_sync(func, cancellable=True))
return Path(await to_thread.run_sync(func, abandon_on_cancel=True))

def rglob(self, pattern: str) -> AsyncIterator[Path]:
gen = self._path.rglob(pattern)
Expand All @@ -542,12 +552,12 @@ async def samefile(
other_path = other_path._path

return await to_thread.run_sync(
self._path.samefile, other_path, cancellable=True
self._path.samefile, other_path, abandon_on_cancel=True
)

async def stat(self, *, follow_symlinks: bool = True) -> os.stat_result:
func = partial(os.stat, follow_symlinks=follow_symlinks)
return await to_thread.run_sync(func, self._path, cancellable=True)
return await to_thread.run_sync(func, self._path, abandon_on_cancel=True)

async def symlink_to(
self,
Expand Down
4 changes: 2 additions & 2 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,9 +693,9 @@ async def setup_unix_local_socket(

if path_str is not None:
try:
await to_thread.run_sync(raw_socket.bind, path_str, cancellable=True)
await to_thread.run_sync(raw_socket.bind, path_str, abandon_on_cancel=True)
if mode is not None:
await to_thread.run_sync(chmod, path_str, mode, cancellable=True)
await to_thread.run_sync(chmod, path_str, mode, abandon_on_cancel=True)
except BaseException:
raw_socket.close()
raise
Expand Down
7 changes: 6 additions & 1 deletion src/anyio/abc/_eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,16 @@ async def run_sync_in_worker_thread(
cls,
func: Callable[..., T_Retval],
args: tuple[Any, ...],
cancellable: bool = False,
abandon_on_cancel: bool = False,
limiter: CapacityLimiter | None = None,
) -> T_Retval:
pass

@classmethod
@abstractmethod
def check_cancelled(cls) -> None:
pass

@classmethod
@abstractmethod
def run_async_from_thread(
Expand Down
Loading