Skip to content

Commit

Permalink
Implemented cancel scope propagation to tasks spawned from worker thr…
Browse files Browse the repository at this point in the history
…eads
  • Loading branch information
agronholm committed Nov 19, 2023
1 parent 527cf01 commit 793095e
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 10 deletions.
28 changes: 24 additions & 4 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2097,12 +2097,18 @@ async def run_sync_in_worker_thread(

context = copy_context()
context.run(sniffio.current_async_library_cvar.set, None)
worker.queue.put_nowait((context, func, args, future, scope))
if abandon_on_cancel or scope._parent_scope is None:
worker_scope = scope
else:
worker_scope = scope._parent_scope

worker.queue.put_nowait((context, func, args, future, worker_scope))
return await future

@classmethod
def check_cancelled(cls) -> None:
if threadlocals.current_cancel_scope._parent_cancelled():
scope = threadlocals.current_cancel_scope
if scope.cancel_called or scope._parent_cancelled():
raise CancelledError

@classmethod
Expand All @@ -2112,13 +2118,27 @@ def run_async_from_thread(
args: tuple[Any, ...],
token: object,
) -> T_Retval:
async def task_wrapper(scope: CancelScope) -> T_Retval:
__tracebackhide__ = True
task = cast(asyncio.Task, current_task())
_task_states[task] = TaskState(None, scope)
scope._tasks.add(task)
try:
return await func(*args)
finally:
scope._tasks.discard(task)

loop = cast(AbstractEventLoop, token)
context = copy_context()
context.run(sniffio.current_async_library_cvar.set, "asyncio")
wrapper = task_wrapper(threadlocals.current_cancel_scope)
f: concurrent.futures.Future[T_Retval] = context.run(
asyncio.run_coroutine_threadsafe, func(*args), loop
asyncio.run_coroutine_threadsafe, wrapper, loop
)
return f.result()
try:
return f.result()
except concurrent.futures.CancelledError:
raise CancelledError from None

@classmethod
def run_sync_from_thread(
Expand Down
5 changes: 2 additions & 3 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import trio.lowlevel
from outcome import Error, Outcome, Value
from trio.lowlevel import (
TrioToken,
current_root_task,
current_task,
wait_readable,
Expand Down Expand Up @@ -894,13 +893,13 @@ def run_async_from_thread(
args: tuple[Any, ...],
token: object,
) -> T_Retval:
return trio.from_thread.run(func, *args, trio_token=cast(TrioToken, token))
return trio.from_thread.run(func, *args)

@classmethod
def run_sync_from_thread(
cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object
) -> T_Retval:
return trio.from_thread.run_sync(func, *args, trio_token=cast(TrioToken, token))
return trio.from_thread.run_sync(func, *args)

@classmethod
def create_blocking_portal(cls) -> abc.BlockingPortal:
Expand Down
48 changes: 45 additions & 3 deletions tests/test_from_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
from _pytest.logging import LogCaptureFixture

from anyio import (
CancelScope,
Event,
create_task_group,
fail_after,
from_thread,
get_all_backends,
get_cancelled_exc_class,
Expand Down Expand Up @@ -65,7 +67,8 @@ def thread_worker_sync(func: Callable[..., T_Retval], *args: Any) -> T_Retval:
return from_thread.run_sync(func, *args)


async def test_thread_cancelled() -> None:
@pytest.mark.parametrize("cancel", [True, False])
async def test_thread_cancelled(cancel: bool) -> None:
event = threading.Event()
thread_finished_future: Future[None] = Future()

Expand All @@ -81,10 +84,15 @@ def sync_function() -> None:
async with create_task_group() as tg:
tg.start_soon(to_thread.run_sync, sync_function)
await wait_all_tasks_blocked()
tg.cancel_scope.cancel()
if cancel:
tg.cancel_scope.cancel()

event.set()

with pytest.raises(get_cancelled_exc_class()):
if cancel:
with pytest.raises(get_cancelled_exc_class()):
thread_finished_future.result(3)
else:
thread_finished_future.result(3)


Expand All @@ -111,6 +119,40 @@ def sync_function() -> None:
thread_finished_future.result(3)


async def test_cancelscope_propagation() -> None:
async def async_time_bomb() -> None:
cancel_scope.cancel()
with fail_after(1):
await sleep(3)

with CancelScope() as cancel_scope:
await to_thread.run_sync(from_thread.run, async_time_bomb)

assert cancel_scope.cancelled_caught


async def test_cancelscope_propagation_when_abandoned() -> None:
host_cancelled_event = Event()
completed_event = Event()

async def async_time_bomb() -> None:
cancel_scope.cancel()
with fail_after(3):
await host_cancelled_event.wait()

completed_event.set()

with CancelScope() as cancel_scope:
await to_thread.run_sync(
from_thread.run, async_time_bomb, abandon_on_cancel=True
)

assert cancel_scope.cancelled_caught
host_cancelled_event.set()
with fail_after(3):
await completed_event.wait()


class TestRunAsyncFromThread:
async def test_run_corofunc_from_thread(self) -> None:
result = await to_thread.run_sync(thread_worker_async, async_add, 1, 2)
Expand Down

0 comments on commit 793095e

Please sign in to comment.