Skip to content

Commit

Permalink
feat: add max_queue_size argument to set a limit for buffered items
Browse files Browse the repository at this point in the history
  • Loading branch information
hussein-awala committed Feb 28, 2024
1 parent 26cb5c7 commit 7fd55c0
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
9 changes: 8 additions & 1 deletion async_batcher/batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from collections import namedtuple
from typing import TYPE_CHECKING, Generic, TypeVar

from async_batcher.exceptions import QueueFullException

if TYPE_CHECKING:
from concurrent.futures import Executor

Expand All @@ -23,6 +25,8 @@ class AsyncBatcher(Generic[T, S], abc.ABC):
max_queue_time (float, optional): The max time for a task to stay in the queue before processing
it if the batch is not full and the number of running batches is less than the concurrency.
Defaults to 0.01.
max_queue_size (int, optional): The max number of items to keep in the queue.
Defaults to -1 (no limit).
concurrency (int, optional): The max number of concurrent batches to process.
Defaults to 1. If -1, it will process all batches concurrently.
executor (Executor, optional): The executor to use to process the batch if the `process_batch` method
Expand All @@ -38,6 +42,7 @@ def __init__(
max_batch_size: int = -1,
max_queue_time: float = 0.01,
concurrency: int = 1,
max_queue_size: int = -1,
executor: Executor | None = None,
**kwargs,
):
Expand Down Expand Up @@ -71,7 +76,7 @@ def __init__(
self.max_queue_time = max_queue_time
self.concurrency = concurrency
self.executor = executor
self._queue = asyncio.Queue()
self._queue = asyncio.Queue(maxsize=max_queue_size)
self._current_task: asyncio.Task | None = None
self._running_batches: dict[int, asyncio.Task] = {}
self._concurrency_semaphore = asyncio.Semaphore(concurrency) if concurrency > 0 else None
Expand Down Expand Up @@ -100,6 +105,8 @@ async def process(self, item: T) -> S:
self._current_task = asyncio.get_running_loop().create_task(self.run())
logging.debug(item)
future = asyncio.get_running_loop().create_future()
if self._queue.full():
raise QueueFullException("The queue is full, cannot process more items at the moment.")
await self._queue.put(self.QueueItem(item, future))
await future
return future.result()
Expand Down
9 changes: 9 additions & 0 deletions async_batcher/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations


class AsyncBatchException(Exception):
    """Base class for every exception raised by the async_batcher package.

    Callers can catch this single type to handle any batcher-specific failure.
    """


class QueueFullException(AsyncBatchException):
    """Raised when an item cannot be enqueued because the batcher's internal
    queue already holds ``max_queue_size`` pending items.
    """
32 changes: 31 additions & 1 deletion tests/test_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys

import pytest
from async_batcher.exceptions import QueueFullException

from tests.conftest import MockAsyncBatcher, SlowAsyncBatcher

Expand All @@ -24,10 +25,16 @@ def __init__(self, batcher, sleep_time, start_range, end_range):
self.start_range = start_range
self.end_range = end_range

async def process_and_catch_exception(self, item):
    """Submit *item* to the batcher and return the result.

    Instead of propagating, any exception raised by ``process`` is returned
    as a value so the test can inspect it after ``asyncio.gather``.
    """
    try:
        outcome = await self.batcher.process(item=item)
    except Exception as exc:  # deliberately broad: the caller inspects the exception object
        return exc
    return outcome

async def arun(self):
    """Sleep for ``self.sleep_time``, then submit items ``start_range..end_range-1``
    to the batcher concurrently, collecting results (or caught exceptions) in
    order into ``self.result``.

    Fix: the flattened diff left both the removed call (``self.batcher.process``)
    and the added call (``self.process_and_catch_exception``) in the gather,
    which is invalid syntax; only the added version from the commit is kept.
    """
    await asyncio.sleep(self.sleep_time)
    result = await asyncio.gather(
        *[self.process_and_catch_exception(item=i) for i in range(self.start_range, self.end_range)]
    )
    self.result = result

Expand Down Expand Up @@ -176,3 +183,26 @@ async def test_force_stop_batcher():
for task in batcher._running_batches.values():
assert task.cancelled() or task.cancelling()
batcher.mock_batch_processor.reset_mock()


@pytest.mark.asyncio(scope="session")
async def test_max_queue_size():
    """Items submitted while the queue already holds max_queue_size entries
    must be rejected with QueueFullException instead of being buffered.
    """
    batcher = SlowAsyncBatcher(
        sleep_time=1,
        max_batch_size=10,
        max_queue_time=0.2,
        concurrency=1,
        max_queue_size=15,
    )
    batcher.mock_batch_processor.reset_mock()
    # Three staggered waves of 10 submissions each; the third wave arrives
    # while the queue is nearly full, so only part of it fits.
    wave_one = CallsMaker(batcher, 0, 0, 10)
    wave_two = CallsMaker(batcher, 0.25, 10, 20)
    wave_three = CallsMaker(batcher, 0.4, 20, 30)
    await asyncio.gather(wave_one.arun(), wave_two.arun(), wave_three.arun())
    # Exactly three batches were processed in total.
    assert batcher.mock_batch_processor.call_count == 3
    # The first two waves complete normally (expected values mirror the mock
    # processor's doubling of each item — defined in conftest).
    assert wave_one.result == [2 * i for i in range(10)]
    assert wave_two.result == [2 * i for i in range(10, 20)]
    # Only the first five items of the third wave fit in the queue...
    assert wave_three.result[:5] == [2 * i for i in range(20, 25)]
    # ...and the remainder are rejected with QueueFullException.
    assert all(isinstance(err, QueueFullException) for err in wave_three.result[5:])
    batcher.mock_batch_processor.reset_mock()
    await batcher.stop()

0 comments on commit 7fd55c0

Please sign in to comment.