Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented voluntary cancellation in worker threads #629

Merged
merged 12 commits into from
Nov 22, 2023
Merged
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Running asynchronous code from other threads

.. autofunction:: anyio.from_thread.run
.. autofunction:: anyio.from_thread.run_sync
.. autofunction:: anyio.from_thread.check_cancelled
.. autofunction:: anyio.from_thread.start_blocking_portal

.. autoclass:: anyio.from_thread.BlockingPortal
Expand Down
20 changes: 20 additions & 0 deletions docs/threads.rst
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,23 @@ maximum of 40 threads to be spawned. You can adjust this limit like this::

.. note:: AnyIO's default thread pool limiter does not affect the default thread pool
executor on :mod:`asyncio`.

Reacting to cancellation in worker threads
------------------------------------------

While there is no mechanism in Python to cancel code running in a thread, AnyIO provides a
mechanism that allows user code to voluntarily check if the host task's scope has been cancelled,
and if it has, raise a cancellation exception. This can be done by simply calling
:func:`from_thread.check_cancelled`::

from time import sleep

from anyio import from_thread, move_on_after, to_thread

def sync_function():
while True:
from_thread.check_cancelled()
print("Not cancelled yet")
sleep(1)

async def foo():
with move_on_after(3):
await to_thread.run_sync(sync_function)
6 changes: 5 additions & 1 deletion docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- Call ``trio.to_thread.run_sync()`` using the ``abandon_on_cancel`` keyword argument
instead of ``cancellable``
- Removed a checkpoint when exiting a task group
- Bumped minimum version of trio to v0.23
- Renamed the ``cancellable`` argument in ``anyio.to_thread.run_sync()`` to
``abandon_on_cancel`` (and deprecated the old parameter name)
- Bumped minimum version of Trio to v0.23
- Added support for voluntary thread cancellation via
``anyio.from_thread.check_cancelled()``

**4.0.0**

Expand Down
20 changes: 14 additions & 6 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
import sniffio

from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
from .._core._eventloop import claim_worker_thread
from .._core._eventloop import claim_worker_thread, threadlocals
from .._core._exceptions import (
BrokenResourceError,
BusyResourceError,
Expand Down Expand Up @@ -786,7 +786,7 @@ def __init__(
self.idle_workers = idle_workers
self.loop = root_task._loop
self.queue: Queue[
tuple[Context, Callable, tuple, asyncio.Future] | None
tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
] = Queue(2)
self.idle_since = AsyncIOBackend.current_time()
self.stopping = False
Expand Down Expand Up @@ -817,14 +817,17 @@ def run(self) -> None:
# Shutdown command received
return

context, func, args, future = item
context, func, args, future, cancel_scope = item
if not future.cancelled():
result = None
exception: BaseException | None = None
threadlocals.current_cancel_scope = cancel_scope
try:
result = context.run(func, *args)
except BaseException as exc:
exception = exc
finally:
del threadlocals.current_cancel_scope

if not self.loop.is_closed():
self.loop.call_soon_threadsafe(
Expand Down Expand Up @@ -2048,7 +2051,7 @@ async def run_sync_in_worker_thread(
cls,
func: Callable[..., T_Retval],
args: tuple[Any, ...],
cancellable: bool = False,
abandon_on_cancel: bool = False,
limiter: abc.CapacityLimiter | None = None,
) -> T_Retval:
await cls.checkpoint()
Expand All @@ -2065,7 +2068,7 @@ async def run_sync_in_worker_thread(
_threadpool_workers.set(workers)

async with limiter or cls.current_default_thread_limiter():
with CancelScope(shield=not cancellable):
with CancelScope(shield=not abandon_on_cancel) as scope:
future: asyncio.Future = asyncio.Future()
root_task = find_root_task()
if not idle_workers:
Expand Down Expand Up @@ -2094,9 +2097,14 @@ async def run_sync_in_worker_thread(

context = copy_context()
context.run(sniffio.current_async_library_cvar.set, None)
worker.queue.put_nowait((context, func, args, future))
worker.queue.put_nowait((context, func, args, future, scope))
return await future

@classmethod
def check_cancelled(cls) -> None:
    """
    Raise ``CancelledError`` if the current worker thread's job was cancelled.

    Reads the cancel scope stored in ``threadlocals.current_cancel_scope`` and asks
    it, through ``_parent_cancelled()``, whether cancellation applies.
    NOTE(review): presumably this means an enclosing scope of the host task was
    cancelled -- confirm against ``CancelScope._parent_cancelled``.

    :raises CancelledError: if ``_parent_cancelled()`` returns a true value

    """
    worker_scope = threadlocals.current_cancel_scope
    if worker_scope._parent_cancelled():
        raise CancelledError

@classmethod
def run_async_from_thread(
cls,
Expand Down
8 changes: 6 additions & 2 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ async def run_sync_in_worker_thread(
cls,
func: Callable[..., T_Retval],
args: tuple[Any, ...],
cancellable: bool = False,
abandon_on_cancel: bool = False,
limiter: abc.CapacityLimiter | None = None,
) -> T_Retval:
def wrapper() -> T_Retval:
Expand All @@ -879,10 +879,14 @@ def wrapper() -> T_Retval:
token = TrioBackend.current_token()
return await run_sync(
wrapper,
abandon_on_cancel=cancellable,
abandon_on_cancel=abandon_on_cancel,
limiter=cast(trio.CapacityLimiter, limiter),
)

@classmethod
def check_cancelled(cls) -> None:
    """
    Trio implementation of the worker thread cancellation check.

    Delegates directly to ``trio.from_thread.check_cancelled()``; anything that
    call raises propagates to the caller unchanged.

    """
    trio.from_thread.check_cancelled()

@classmethod
def run_async_from_thread(
cls,
Expand Down
44 changes: 27 additions & 17 deletions src/anyio/_core/_fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ class _PathIterator(AsyncIterator["Path"]):
iterator: Iterator[PathLike[str]]

async def __anext__(self) -> Path:
nextval = await to_thread.run_sync(next, self.iterator, None, cancellable=True)
nextval = await to_thread.run_sync(
next, self.iterator, None, abandon_on_cancel=True
)
if nextval is None:
raise StopAsyncIteration from None

Expand Down Expand Up @@ -386,17 +388,19 @@ async def cwd(cls) -> Path:
return cls(path)

async def exists(self) -> bool:
return await to_thread.run_sync(self._path.exists, cancellable=True)
return await to_thread.run_sync(self._path.exists, abandon_on_cancel=True)

async def expanduser(self) -> Path:
return Path(await to_thread.run_sync(self._path.expanduser, cancellable=True))
return Path(
await to_thread.run_sync(self._path.expanduser, abandon_on_cancel=True)
)

def glob(self, pattern: str) -> AsyncIterator[Path]:
gen = self._path.glob(pattern)
return _PathIterator(gen)

async def group(self) -> str:
return await to_thread.run_sync(self._path.group, cancellable=True)
return await to_thread.run_sync(self._path.group, abandon_on_cancel=True)

async def hardlink_to(self, target: str | pathlib.Path | Path) -> None:
if isinstance(target, Path):
Expand All @@ -413,31 +417,37 @@ def is_absolute(self) -> bool:
return self._path.is_absolute()

async def is_block_device(self) -> bool:
return await to_thread.run_sync(self._path.is_block_device, cancellable=True)
return await to_thread.run_sync(
self._path.is_block_device, abandon_on_cancel=True
)

async def is_char_device(self) -> bool:
return await to_thread.run_sync(self._path.is_char_device, cancellable=True)
return await to_thread.run_sync(
self._path.is_char_device, abandon_on_cancel=True
)

async def is_dir(self) -> bool:
return await to_thread.run_sync(self._path.is_dir, cancellable=True)
return await to_thread.run_sync(self._path.is_dir, abandon_on_cancel=True)

async def is_fifo(self) -> bool:
return await to_thread.run_sync(self._path.is_fifo, cancellable=True)
return await to_thread.run_sync(self._path.is_fifo, abandon_on_cancel=True)

async def is_file(self) -> bool:
return await to_thread.run_sync(self._path.is_file, cancellable=True)
return await to_thread.run_sync(self._path.is_file, abandon_on_cancel=True)

async def is_mount(self) -> bool:
return await to_thread.run_sync(os.path.ismount, self._path, cancellable=True)
return await to_thread.run_sync(
os.path.ismount, self._path, abandon_on_cancel=True
)

def is_reserved(self) -> bool:
return self._path.is_reserved()

async def is_socket(self) -> bool:
return await to_thread.run_sync(self._path.is_socket, cancellable=True)
return await to_thread.run_sync(self._path.is_socket, abandon_on_cancel=True)

async def is_symlink(self) -> bool:
return await to_thread.run_sync(self._path.is_symlink, cancellable=True)
return await to_thread.run_sync(self._path.is_symlink, abandon_on_cancel=True)

def iterdir(self) -> AsyncIterator[Path]:
gen = self._path.iterdir()
Expand All @@ -450,7 +460,7 @@ async def lchmod(self, mode: int) -> None:
await to_thread.run_sync(self._path.lchmod, mode)

async def lstat(self) -> os.stat_result:
return await to_thread.run_sync(self._path.lstat, cancellable=True)
return await to_thread.run_sync(self._path.lstat, abandon_on_cancel=True)

async def mkdir(
self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False
Expand Down Expand Up @@ -493,7 +503,7 @@ async def open(
return AsyncFile(fp)

async def owner(self) -> str:
return await to_thread.run_sync(self._path.owner, cancellable=True)
return await to_thread.run_sync(self._path.owner, abandon_on_cancel=True)

async def read_bytes(self) -> bytes:
return await to_thread.run_sync(self._path.read_bytes)
Expand Down Expand Up @@ -526,7 +536,7 @@ async def replace(self, target: str | pathlib.PurePath | Path) -> Path:

async def resolve(self, strict: bool = False) -> Path:
func = partial(self._path.resolve, strict=strict)
return Path(await to_thread.run_sync(func, cancellable=True))
return Path(await to_thread.run_sync(func, abandon_on_cancel=True))

def rglob(self, pattern: str) -> AsyncIterator[Path]:
gen = self._path.rglob(pattern)
Expand All @@ -542,12 +552,12 @@ async def samefile(
other_path = other_path._path

return await to_thread.run_sync(
self._path.samefile, other_path, cancellable=True
self._path.samefile, other_path, abandon_on_cancel=True
)

async def stat(self, *, follow_symlinks: bool = True) -> os.stat_result:
func = partial(os.stat, follow_symlinks=follow_symlinks)
return await to_thread.run_sync(func, self._path, cancellable=True)
return await to_thread.run_sync(func, self._path, abandon_on_cancel=True)

async def symlink_to(
self,
Expand Down
4 changes: 2 additions & 2 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,9 +693,9 @@ async def setup_unix_local_socket(

if path_str is not None:
try:
await to_thread.run_sync(raw_socket.bind, path_str, cancellable=True)
await to_thread.run_sync(raw_socket.bind, path_str, abandon_on_cancel=True)
if mode is not None:
await to_thread.run_sync(chmod, path_str, mode, cancellable=True)
await to_thread.run_sync(chmod, path_str, mode, abandon_on_cancel=True)
except BaseException:
raw_socket.close()
raise
Expand Down
7 changes: 6 additions & 1 deletion src/anyio/abc/_eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,16 @@ async def run_sync_in_worker_thread(
cls,
func: Callable[..., T_Retval],
args: tuple[Any, ...],
cancellable: bool = False,
abandon_on_cancel: bool = False,
limiter: CapacityLimiter | None = None,
) -> T_Retval:
pass

@classmethod
@abstractmethod
def check_cancelled(cls) -> None:
    """
    Backend hook for the voluntary cancellation check made from a worker thread.

    This is abstract; each backend supplies the actual check.
    NOTE(review): implementations are presumably expected to raise the
    backend-specific cancellation exception when the host task's cancel scope has
    been cancelled, and to return ``None`` otherwise -- confirm against
    ``anyio.from_thread.check_cancelled()``.

    """
    pass

@classmethod
@abstractmethod
def run_async_from_thread(
Expand Down
31 changes: 29 additions & 2 deletions src/anyio/from_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ._core._eventloop import get_async_backend, get_cancelled_exc_class, threadlocals
from ._core._synchronization import Event
from ._core._tasks import CancelScope, create_task_group
from .abc import AsyncBackend
from .abc._tasks import TaskStatus

T_Retval = TypeVar("T_Retval")
Expand All @@ -40,7 +41,9 @@ def run(func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval:
async_backend = threadlocals.current_async_backend
token = threadlocals.current_token
except AttributeError:
raise RuntimeError("This function can only be run from an AnyIO worker thread")
raise RuntimeError(
"This function can only be run from an AnyIO worker thread"
) from None

return async_backend.run_async_from_thread(func, args, token=token)

Expand All @@ -58,7 +61,9 @@ def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval:
async_backend = threadlocals.current_async_backend
token = threadlocals.current_token
except AttributeError:
raise RuntimeError("This function can only be run from an AnyIO worker thread")
raise RuntimeError(
"This function can only be run from an AnyIO worker thread"
) from None

return async_backend.run_sync_from_thread(func, args, token=token)

Expand Down Expand Up @@ -422,3 +427,25 @@ async def run_portal() -> None:
pass

run_future.result()


def check_cancelled() -> None:
    """
    Check whether the host task that is running the current worker thread has had
    its cancel scope cancelled.

    The check itself is delegated to the async backend recorded for this thread. If
    the host task's current cancel scope has been cancelled, the backend-specific
    cancellation exception is raised; otherwise the call returns normally.

    :raises RuntimeError: if the current thread was not spawned by
        :func:`.to_thread.run_sync`

    """
    try:
        # Only AnyIO worker threads have a backend recorded in the thread locals.
        backend: AsyncBackend = threadlocals.current_async_backend
    except AttributeError:
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None
    else:
        backend.check_cancelled()
21 changes: 18 additions & 3 deletions src/anyio/to_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Callable
from typing import TypeVar
from warnings import warn

from ._core._eventloop import get_async_backend
from .abc import CapacityLimiter
Expand All @@ -12,7 +13,8 @@
async def run_sync(
func: Callable[..., T_Retval],
*args: object,
cancellable: bool = False,
abandon_on_cancel: bool = False,
cancellable: bool | None = None,
limiter: CapacityLimiter | None = None,
) -> T_Retval:
"""
Expand All @@ -24,14 +26,27 @@ async def run_sync(

:param func: a callable
:param args: positional arguments for the callable
:param cancellable: ``True`` to allow cancellation of the operation
:param abandon_on_cancel: ``True`` to abandon the thread (leaving it to run
unchecked on its own) if the host task is cancelled, ``False`` to ignore
cancellations in the host task until the operation has completed in the worker
thread
:param cancellable: deprecated alias of ``abandon_on_cancel``
:param limiter: capacity limiter to use to limit the total amount of threads running
(if omitted, the default limiter is used)
:return: an awaitable that yields the return value of the function.

"""
if cancellable is not None:
abandon_on_cancel = cancellable
warn(
"The `cancellable=` keyword argument to `anyio.to_thread.run_sync` is "
"deprecated since AnyIO 4.1.0; use `abandon_on_cancel=` instead",
DeprecationWarning,
stacklevel=2,
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Allowing both aliases to be passed is a valid choice but departs from trio. Maybe also document/test which one overrides the other.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree if both are passed it'd be nice to throw an error, but then you need to change the types a little 🤷‍♀️ happy with either implementation in practice since a warning will still be raised.

return await get_async_backend().run_sync_in_worker_thread(
func, args, cancellable=cancellable, limiter=limiter
func, args, abandon_on_cancel=abandon_on_cancel, limiter=limiter
)


Expand Down
Loading