From b429c1a3202dd3eff2eaad0124b327b3cad15b01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Fri, 8 Dec 2023 11:07:26 +0200 Subject: [PATCH] Fixed CapacityLimiter not waking up waiters on asyncio when setting total tokens The problem was that it didn't remove events from the wait queue, thus blocking acquire_on_behalf_of() from seeing that there are in fact free slots. Closes #646. --- docs/versionhistory.rst | 5 ++++- src/anyio/_backends/_asyncio.py | 15 +++++---------- tests/test_synchronization.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index c8b089c9..af625fd8 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -7,7 +7,10 @@ This library adheres to `Semantic Versioning 2.0 `_. - Add support for ``byte``-based paths in ``connect_unix``, ``create_unix_listeners``, ``create_unix_datagram_socket``, and ``create_connected_unix_datagram_socket``. (PR by - Lura Skye.) + Lura Skye) +- Fixed adjusting the total number of tokens in a ``CapacityLimiter`` on asyncio failing + to wake up tasks waiting to acquire the limiter in certain edge cases (fixed with help + from Egor Blagov) **4.1.0** diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 71d2ca2b..7a10e17b 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -1649,19 +1649,14 @@ def total_tokens(self, value: float) -> None: if value < 1: raise ValueError("total_tokens must be >= 1") - old_value = self._total_tokens + waiters_to_notify = max(value - self._total_tokens, 0) self._total_tokens = value - events = [] - for event in self._wait_queue.values(): - if value <= old_value: - break - - if not event.is_set(): - events.append(event) - old_value += 1 - for event in events: + # Notify waiting tasks that they have acquired the limiter + while self._wait_queue and waiters_to_notify: + event = self._wait_queue.popitem(last=False)[1] event.set() + waiters_to_notify -= 1 @property def borrowed_tokens(self) -> int: diff --git a/tests/test_synchronization.py b/tests/test_synchronization.py index a4c1207d..6011710a 100644 --- a/tests/test_synchronization.py +++ b/tests/test_synchronization.py @@ -12,6 +12,7 @@ Semaphore, WouldBlock, create_task_group, + fail_after, to_thread, wait_all_tasks_blocked, ) @@ -564,3 +565,33 @@ async def append(x: int, task_status: TaskStatus) -> None: event.set() assert results == [0, 1, 2] + + async def test_increase_tokens_lets_others_acquire(self) -> None: + limiter = CapacityLimiter(1) + entered_events = [Event() for _ in range(3)] + continue_event = Event() + + async def worker(entered_event: Event) -> None: + async with limiter: + entered_event.set() + await continue_event.wait() + + async with create_task_group() as tg: + for event in entered_events[:2]: + tg.start_soon(worker, event) + + # One task should be able to acquire the limiter while the other is left + # waiting + await wait_all_tasks_blocked() + assert sum(ev.is_set() for ev in entered_events) == 1 + + # Increase the total tokens and start another worker. + # All tasks should be able to acquire the limiter now. + limiter.total_tokens = 3 + tg.start_soon(worker, entered_events[2]) + with fail_after(1): + for ev in entered_events[1:]: + await ev.wait() + + # Allow all tasks to exit + continue_event.set()