diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 63945b33..11120fa1 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -2097,12 +2097,18 @@ async def run_sync_in_worker_thread( context = copy_context() context.run(sniffio.current_async_library_cvar.set, None) - worker.queue.put_nowait((context, func, args, future, scope)) + if abandon_on_cancel or scope._parent_scope is None: + worker_scope = scope + else: + worker_scope = scope._parent_scope + + worker.queue.put_nowait((context, func, args, future, worker_scope)) return await future @classmethod def check_cancelled(cls) -> None: - if threadlocals.current_cancel_scope._parent_cancelled(): + scope = threadlocals.current_cancel_scope + if scope.cancel_called or scope._parent_cancelled(): raise CancelledError @classmethod @@ -2112,13 +2118,27 @@ def run_async_from_thread( args: tuple[Any, ...], token: object, ) -> T_Retval: + async def task_wrapper(scope: CancelScope) -> T_Retval: + __tracebackhide__ = True + task = cast(asyncio.Task, current_task()) + _task_states[task] = TaskState(None, scope) + scope._tasks.add(task) + try: + return await func(*args) + finally: + scope._tasks.discard(task) + loop = cast(AbstractEventLoop, token) context = copy_context() context.run(sniffio.current_async_library_cvar.set, "asyncio") + wrapper = task_wrapper(threadlocals.current_cancel_scope) f: concurrent.futures.Future[T_Retval] = context.run( - asyncio.run_coroutine_threadsafe, func(*args), loop + asyncio.run_coroutine_threadsafe, wrapper, loop ) - return f.result() + try: + return f.result() + except concurrent.futures.CancelledError: + raise CancelledError from None @classmethod def run_sync_from_thread( diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index d7556560..3127140c 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -36,7 +36,6 @@ import trio.lowlevel from outcome import Error, Outcome, Value from trio.lowlevel import ( - TrioToken, current_root_task, current_task, wait_readable, @@ -894,13 +893,13 @@ def run_async_from_thread( args: tuple[Any, ...], token: object, ) -> T_Retval: - return trio.from_thread.run(func, *args, trio_token=cast(TrioToken, token)) + return trio.from_thread.run(func, *args) @classmethod def run_sync_from_thread( cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object ) -> T_Retval: - return trio.from_thread.run_sync(func, *args, trio_token=cast(TrioToken, token)) + return trio.from_thread.run_sync(func, *args) @classmethod def create_blocking_portal(cls) -> abc.BlockingPortal: diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py index c6b1f0aa..0e580462 100644 --- a/tests/test_from_thread.py +++ b/tests/test_from_thread.py @@ -16,8 +16,10 @@ from _pytest.logging import LogCaptureFixture from anyio import ( + CancelScope, Event, create_task_group, + fail_after, from_thread, get_all_backends, get_cancelled_exc_class, @@ -65,7 +67,8 @@ def thread_worker_sync(func: Callable[..., T_Retval], *args: Any) -> T_Retval: return from_thread.run_sync(func, *args) -async def test_thread_cancelled() -> None: +@pytest.mark.parametrize("cancel", [True, False]) +async def test_thread_cancelled(cancel: bool) -> None: event = threading.Event() thread_finished_future: Future[None] = Future() @@ -81,10 +84,15 @@ def sync_function() -> None: async with create_task_group() as tg: tg.start_soon(to_thread.run_sync, sync_function) await wait_all_tasks_blocked() - tg.cancel_scope.cancel() + if cancel: + tg.cancel_scope.cancel() + event.set() - with pytest.raises(get_cancelled_exc_class()): + if cancel: + with pytest.raises(get_cancelled_exc_class()): + thread_finished_future.result(3) + else: thread_finished_future.result(3) @@ -111,6 +119,40 @@ def sync_function() -> None: thread_finished_future.result(3) +async def test_cancelscope_propagation() -> None: + async def async_time_bomb() -> None: + cancel_scope.cancel() + with fail_after(1): + await sleep(3) + + with CancelScope() as cancel_scope: + await to_thread.run_sync(from_thread.run, async_time_bomb) + + assert cancel_scope.cancelled_caught + + +async def test_cancelscope_propagation_when_abandoned() -> None: + host_cancelled_event = Event() + completed_event = Event() + + async def async_time_bomb() -> None: + cancel_scope.cancel() + with fail_after(3): + await host_cancelled_event.wait() + + completed_event.set() + + with CancelScope() as cancel_scope: + await to_thread.run_sync( + from_thread.run, async_time_bomb, abandon_on_cancel=True + ) + + assert cancel_scope.cancelled_caught + host_cancelled_event.set() + with fail_after(3): + await completed_event.wait() + + class TestRunAsyncFromThread: async def test_run_corofunc_from_thread(self) -> None: result = await to_thread.run_sync(thread_worker_async, async_add, 1, 2)