Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Lib/asyncio/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
__all__ = ('BrokenBarrierError',
'CancelledError', 'InvalidStateError', 'TimeoutError',
'IncompleteReadError', 'LimitOverrunError',
'SendfileNotAvailableError')
'SendfileNotAvailableError',
'WouldBlock')


class CancelledError(BaseException):
Expand Down Expand Up @@ -60,3 +61,7 @@ def __reduce__(self):

class BrokenBarrierError(RuntimeError):
"""Barrier is broken by barrier.abort() call."""


class WouldBlock(Exception):
"""Raised by nowait functions when the operation would block."""
183 changes: 182 additions & 1 deletion Lib/asyncio/locks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Synchronization primitives."""

__all__ = ('Lock', 'Event', 'Condition', 'Semaphore',
'BoundedSemaphore', 'Barrier')
'BoundedSemaphore', 'Barrier',
'CapacityLimiter', 'CapacityLimiterStatistics')

import collections
import dataclasses
import enum
import math

from . import exceptions
from . import mixins
Expand Down Expand Up @@ -615,3 +618,181 @@ def n_waiting(self):
def broken(self):
"""Return True if the barrier is in a broken state."""
return self._state is _BarrierState.BROKEN


@dataclasses.dataclass(frozen=True)
class CapacityLimiterStatistics:
"""Statistics for a CapacityLimiter."""
borrowed_tokens: int
total_tokens: int | float
borrowers: tuple[object, ...]
tasks_waiting: int


class CapacityLimiter(_ContextManagerMixin, mixins._LoopBoundMixin):
"""A capacity limiter that tracks borrowers and supports dynamic capacity.
Unlike a Semaphore, a CapacityLimiter:
- Tracks which tasks hold tokens, preventing the same task from
acquiring twice (which would deadlock a semaphore).
- Allows dynamic adjustment of total_tokens at runtime.
- Supports acquiring/releasing on behalf of arbitrary objects.
Usage::
limiter = CapacityLimiter(10)
async with limiter:
# At most 10 tasks can be here concurrently
...
"""

def __init__(self, total_tokens: int | float):
self._validate_tokens(total_tokens)
self._total_tokens: int | float = total_tokens
self._borrowers: set[object] = set()
self._waiters: collections.OrderedDict[object, object] = (
collections.OrderedDict()
)

def __repr__(self):
res = super().__repr__()
extra = (f'borrowed:{self.borrowed_tokens}, '
f'total:{self._total_tokens}')
if self._waiters:
extra = f'{extra}, waiters:{len(self._waiters)}'
return f'<{res[1:-1]} [{extra}]>'

@staticmethod
def _validate_tokens(total_tokens):
if not isinstance(total_tokens, (int, float)):
raise TypeError("'total_tokens' must be an int or float")
if isinstance(total_tokens, float) and total_tokens != math.inf:
raise ValueError(
"'total_tokens' must be an integer or math.inf"
)
if total_tokens < 0:
raise ValueError("'total_tokens' must be >= 0")

@property
def total_tokens(self) -> int | float:
"""The total number of tokens available (read-write)."""
return self._total_tokens

@total_tokens.setter
def total_tokens(self, value: int | float):
self._validate_tokens(value)
self._total_tokens = value
self._notify_waiters()

@property
def borrowed_tokens(self) -> int:
"""The number of tokens currently borrowed."""
return len(self._borrowers)

@property
def available_tokens(self) -> int | float:
"""The number of tokens currently available."""
return self._total_tokens - len(self._borrowers)

def acquire_nowait(self) -> None:
"""Acquire a token on behalf of the current task without blocking.
Raises WouldBlock if a token is not immediately available.
Raises RuntimeError if the current task already holds a token.
"""
from . import tasks
self.acquire_on_behalf_of_nowait(tasks.current_task())

async def acquire(self) -> None:
"""Acquire a token on behalf of the current task.
Blocks until a token is available.
Raises RuntimeError if the current task already holds a token.
"""
from . import tasks
await self.acquire_on_behalf_of(tasks.current_task())

def acquire_on_behalf_of_nowait(self, borrower) -> None:
"""Acquire a token on behalf of the given borrower without blocking.
Raises WouldBlock if a token is not immediately available.
Raises RuntimeError if the borrower already holds a token.
"""
if borrower in self._borrowers:
raise RuntimeError(
"this borrower is already holding one of this "
"CapacityLimiter's tokens"
)
if self._waiters or len(self._borrowers) >= self._total_tokens:
raise exceptions.WouldBlock
self._borrowers.add(borrower)

async def acquire_on_behalf_of(self, borrower) -> None:
"""Acquire a token on behalf of the given borrower.
Blocks until a token is available.
Raises RuntimeError if the borrower already holds a token.
"""
try:
self.acquire_on_behalf_of_nowait(borrower)
except exceptions.WouldBlock:
pass
else:
return

fut = self._get_loop().create_future()
self._waiters[borrower] = fut
try:
await fut
except exceptions.CancelledError:
self._waiters.pop(borrower, None)
# If the future was already resolved before we got cancelled,
# we already hold the token — release it and wake the next waiter.
if fut.done() and not fut.cancelled():
self._borrowers.discard(borrower)
self._notify_waiters()
raise
else:
# Future completed successfully; borrower was added by
# _notify_waiters, nothing more to do.
pass

def release(self) -> None:
"""Release a token on behalf of the current task.
Raises RuntimeError if the current task does not hold a token.
"""
from . import tasks
self.release_on_behalf_of(tasks.current_task())

def release_on_behalf_of(self, borrower) -> None:
"""Release a token on behalf of the given borrower.
Raises RuntimeError if the borrower does not hold a token.
"""
if borrower not in self._borrowers:
raise RuntimeError(
"this borrower is not holding any of this "
"CapacityLimiter's tokens"
)
self._borrowers.discard(borrower)
self._notify_waiters()

def _notify_waiters(self):
"""Wake up waiters while capacity is available."""
while self._waiters and len(self._borrowers) < self._total_tokens:
borrower, fut = self._waiters.popitem(last=False)
if not fut.done():
self._borrowers.add(borrower)
fut.set_result(None)

def statistics(self) -> CapacityLimiterStatistics:
"""Return statistics about the current state of the limiter."""
return CapacityLimiterStatistics(
borrowed_tokens=len(self._borrowers),
total_tokens=self._total_tokens,
borrowers=tuple(self._borrowers),
tasks_waiting=len(self._waiters),
)
Loading
Loading