From 238a340d1237e1d5e7353eec36f7adc03746fc5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Fri, 8 Dec 2023 16:32:00 +0200 Subject: [PATCH] Fixed cancellation propagation when task group host is in a shielded scope Fixes #642. --- docs/versionhistory.rst | 3 + src/anyio/_backends/_asyncio.py | 104 ++++++++++++++++++-------------- tests/test_taskgroups.py | 23 +++++++ 3 files changed, 85 insertions(+), 45 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 78ef5105..94eaf15f 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -13,6 +13,9 @@ This library adheres to `Semantic Versioning 2.0 `_. from Egor Blagov) - Fixed ``loop_factory`` and ``use_uvloop`` options not being used on the asyncio backend (`#643 `_) +- Fixed cancellation propagating on asyncio from a task group to child tasks if the task + hosting the task group is in a shielded cancel scope + (`#642 `_) **4.1.0** diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index bb6a9bff..84502851 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -343,6 +343,7 @@ def __init__(self, deadline: float = math.inf, shield: bool = False): self._deadline = deadline self._shield = shield self._parent_scope: CancelScope | None = None + self._child_scopes: set[CancelScope] = set() self._cancel_called = False self._cancelled_caught = False self._active = False @@ -369,6 +370,9 @@ def __enter__(self) -> CancelScope: else: self._parent_scope = task_state.cancel_scope task_state.cancel_scope = self + if self._parent_scope is not None: + self._parent_scope._child_scopes.add(self) + self._parent_scope._tasks.remove(host_task) self._timeout() self._active = True @@ -377,7 +381,7 @@ def __enter__(self) -> CancelScope: # Start cancelling the host task if the scope was cancelled before entering if self._cancel_called: - self._deliver_cancellation() + self._deliver_cancellation(self) return self @@ -409,13 +413,15 @@ def __exit__( self._timeout_handle = None self._tasks.remove(self._host_task) + if self._parent_scope is not None: + self._parent_scope._child_scopes.remove(self) + self._parent_scope._tasks.add(self._host_task) host_task_state.cancel_scope = self._parent_scope # Restart the cancellation effort in the farthest directly cancelled parent # scope if this one was shielded - if self._shield: - self._deliver_cancellation_to_parent() + self._restart_cancellation_in_parent() if self._cancel_called and exc_val is not None: for exc in iterate_exceptions(exc_val): @@ -451,12 +457,16 @@ def _timeout(self) -> None: else: self._timeout_handle = loop.call_at(self._deadline, self._timeout) - def _deliver_cancellation(self) -> None: + def _deliver_cancellation(self, origin: CancelScope) -> bool: """ Deliver cancellation to directly contained tasks and nested cancel scopes. Schedule another run at the end if we still have tasks eligible for cancellation. + + :param origin: the cancel scope that originated the cancellation + :return: ``True`` if the delivery needs to be retried on the next cycle + """ should_retry = False current = current_task() @@ -464,42 +474,43 @@ def _deliver_cancellation(self) -> None: if task._must_cancel: # type: ignore[attr-defined] continue - # The task is eligible for cancellation if it has started and is not in a - # cancel scope shielded from this one - cancel_scope = _task_states[task].cancel_scope - while cancel_scope is not self: - if cancel_scope is None or cancel_scope._shield: - break - else: - cancel_scope = cancel_scope._parent_scope - else: - should_retry = True - if task is not current and ( - task is self._host_task or _task_started(task) - ): - waiter = task._fut_waiter # type: ignore[attr-defined] - if not isinstance(waiter, asyncio.Future) or not waiter.done(): - self._cancel_calls += 1 - if sys.version_info >= (3, 9): - task.cancel(f"Cancelled by cancel scope {id(self):x}") - else: - task.cancel() + # The task is eligible for cancellation if it has started + should_retry = True + if task is not current and (task is self._host_task or _task_started(task)): + waiter = task._fut_waiter # type: ignore[attr-defined] + if not isinstance(waiter, asyncio.Future) or not waiter.done(): + self._cancel_calls += 1 + if sys.version_info >= (3, 9): + task.cancel(f"Cancelled by cancel scope {id(origin):x}") + else: + task.cancel() + + # Deliver cancellation to child scopes that aren't shielded or running their own + # cancellation callbacks + for scope in self._child_scopes: + if not scope._shield and not scope.cancel_called: + should_retry = scope._deliver_cancellation(origin) or should_retry # Schedule another callback if there are still tasks left - if should_retry: - self._cancel_handle = get_running_loop().call_soon( - self._deliver_cancellation - ) - else: - self._cancel_handle = None + if origin is self: + if should_retry: + self._cancel_handle = get_running_loop().call_soon( + self._deliver_cancellation, origin + ) + else: + self._cancel_handle = None + + return should_retry - def _deliver_cancellation_to_parent(self) -> None: - """Start cancellation effort in the farthest directly cancelled parent scope""" + def _restart_cancellation_in_parent(self) -> None: + """Start cancellation effort in the closest directly cancelled parent scope""" scope = self._parent_scope - scope_to_cancel: CancelScope | None = None while scope is not None: - if scope._cancel_called and scope._cancel_handle is None: - scope_to_cancel = scope + if scope._cancel_called: + if scope._cancel_handle is None: + scope._deliver_cancellation(scope) + + break # No point in looking beyond any shielded scope if scope._shield: @@ -507,9 +518,6 @@ def _deliver_cancellation_to_parent(self) -> None: scope = scope._parent_scope - if scope_to_cancel is not None: - scope_to_cancel._deliver_cancellation() - def _parent_cancelled(self) -> bool: # Check whether any parent has been cancelled cancel_scope = self._parent_scope @@ -529,7 +537,7 @@ def cancel(self) -> None: self._cancel_called = True if self._host_task is not None: - self._deliver_cancellation() + self._deliver_cancellation(self) @property def deadline(self) -> float: @@ -562,7 +570,7 @@ def shield(self, value: bool) -> None: if self._shield != value: self._shield = value if not value: - self._deliver_cancellation_to_parent() + self._restart_cancellation_in_parent() # @@ -623,6 +631,7 @@ def __init__(self) -> None: self.cancel_scope: CancelScope = CancelScope() self._active = False self._exceptions: list[BaseException] = [] + self._tasks: set[asyncio.Task] = set() async def __aenter__(self) -> TaskGroup: self.cancel_scope.__enter__() @@ -642,9 +651,9 @@ async def __aexit__( self._exceptions.append(exc_val) cancelled_exc_while_waiting_tasks: CancelledError | None = None - while self.cancel_scope._tasks: + while self._tasks: try: - await asyncio.wait(self.cancel_scope._tasks) + await asyncio.wait(self._tasks) except CancelledError as exc: # This task was cancelled natively; reraise the CancelledError later # unless this task was already interrupted by another exception @@ -676,8 +685,11 @@ def _spawn( task_status_future: asyncio.Future | None = None, ) -> asyncio.Task: def task_done(_task: asyncio.Task) -> None: - assert _task in self.cancel_scope._tasks - self.cancel_scope._tasks.remove(_task) + task_state = _task_states[_task] + assert task_state.cancel_scope is not None + assert _task in task_state.cancel_scope._tasks + task_state.cancel_scope._tasks.remove(_task) + self._tasks.remove(task) del _task_states[_task] try: @@ -693,7 +705,8 @@ def task_done(_task: asyncio.Task) -> None: if not isinstance(exc, CancelledError): self._exceptions.append(exc) - self.cancel_scope.cancel() + if not self.cancel_scope._parent_cancelled(): + self.cancel_scope.cancel() else: task_status_future.set_exception(exc) elif task_status_future is not None and not task_status_future.done(): @@ -732,6 +745,7 @@ def task_done(_task: asyncio.Task) -> None: parent_id=parent_id, cancel_scope=self.cancel_scope ) self.cancel_scope._tasks.add(task) + self._tasks.add(task) return task def start_soon( diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index 221f00f4..abe93e23 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -1293,6 +1293,29 @@ def handler(excgrp: BaseExceptionGroup) -> None: await anyio.sleep_forever() +async def test_cancel_child_task_when_host_is_shielded() -> None: + # Regression test for #642 + # Tests that cancellation propagates to a child task even if the host task is within + # a shielded cancel scope. + cancelled = anyio.Event() + + async def wait_cancel() -> None: + try: + await anyio.sleep_forever() + except anyio.get_cancelled_exc_class(): + cancelled.set() + raise + + with CancelScope() as parent_scope: + async with anyio.create_task_group() as task_group: + task_group.start_soon(wait_cancel) + await wait_all_tasks_blocked() + + with CancelScope(shield=True), fail_after(1): + parent_scope.cancel() + await cancelled.wait() + + class TestTaskStatusTyping: """ These tests do not do anything at run time, but since the test suite is also checked