Skip to content

Commit

Permalink
Fixed CapacityLimiter not waking up waiters on asyncio when setting t…
Browse files Browse the repository at this point in the history
…otal tokens

The problem was that it didn't remove events from the wait queue, thus blocking acquire_on_behalf_of() from seeing that there are in fact free slots.

Closes #646.
  • Loading branch information
agronholm committed Dec 8, 2023
1 parent ed3d307 commit b429c1a
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 11 deletions.
5 changes: 4 additions & 1 deletion docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

- Add support for ``byte``-based paths in ``connect_unix``, ``create_unix_listeners``,
``create_unix_datagram_socket``, and ``create_connected_unix_datagram_socket``. (PR by
Lura Skye.)
Lura Skye)
- Fixed adjusting the total number of tokens in a ``CapacityLimiter`` on asyncio failing
to wake up tasks waiting to acquire the limiter in certain edge cases (fixed with help
from Egor Blagov)

**4.1.0**

Expand Down
15 changes: 5 additions & 10 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,19 +1649,14 @@ def total_tokens(self, value: float) -> None:
if value < 1:
raise ValueError("total_tokens must be >= 1")

old_value = self._total_tokens
waiters_to_notify = max(value - self._total_tokens, 0)
self._total_tokens = value
events = []
for event in self._wait_queue.values():
if value <= old_value:
break

if not event.is_set():
events.append(event)
old_value += 1

for event in events:
# Notify waiting tasks that they have acquired the limiter
while self._wait_queue and waiters_to_notify:
event = self._wait_queue.popitem(last=False)[1]
event.set()
waiters_to_notify -= 1

@property
def borrowed_tokens(self) -> int:
Expand Down
31 changes: 31 additions & 0 deletions tests/test_synchronization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Semaphore,
WouldBlock,
create_task_group,
fail_after,
to_thread,
wait_all_tasks_blocked,
)
Expand Down Expand Up @@ -564,3 +565,33 @@ async def append(x: int, task_status: TaskStatus) -> None:
event.set()

assert results == [0, 1, 2]

async def test_increase_tokens_lets_others_acquire(self) -> None:
limiter = CapacityLimiter(1)
entered_events = [Event() for _ in range(3)]
continue_event = Event()

async def worker(entered_event: Event) -> None:
async with limiter:
entered_event.set()
await continue_event.wait()

async with create_task_group() as tg:
for event in entered_events[:2]:
tg.start_soon(worker, event)

# One task should be able to acquire the limiter while the other is left
# waiting
await wait_all_tasks_blocked()
assert sum(ev.is_set() for ev in entered_events) == 1

# Increase the total tokens and start another worker.
# All tasks should be able to acquire the limiter now.
limiter.total_tokens = 3
tg.start_soon(worker, entered_events[2])
with fail_after(1):
for ev in entered_events[1:]:
await ev.wait()

# Allow all tasks to exit
continue_event.set()

0 comments on commit b429c1a

Please sign in to comment.