diff --git a/streamable/_afunctions.py b/streamable/_afunctions.py index 5ff792c7..b581838b 100644 --- a/streamable/_afunctions.py +++ b/streamable/_afunctions.py @@ -33,7 +33,7 @@ def buffer( aiterator: AsyncIterator[T], - up_to: int, + up_to: Optional[int] = None, ) -> AsyncIterator[T]: return _aiterators.BufferAsyncIterator(aiterator, up_to) diff --git a/streamable/_aiterators.py b/streamable/_aiterators.py index eee57894..6a90ead0 100644 --- a/streamable/_aiterators.py +++ b/streamable/_aiterators.py @@ -2,6 +2,7 @@ from asyncio.futures import Future from contextlib import suppress import datetime +import sys import time from abc import ABC, abstractmethod from collections import defaultdict, deque @@ -37,11 +38,10 @@ FIFOFutureResults, FutureResults, ) -from streamable._tools._async import AsyncFunction, empty_aiter +from streamable._tools._async import AsyncFunction, anext, empty_aiter from streamable._tools._context import noop_context_manager from streamable._tools._error import ExceptionContainer, RaisingAsyncIterator -from streamable._tools._async import anext T = TypeVar("T") U = TypeVar("U") @@ -59,10 +59,10 @@ class _BufferAsyncIterable(AsyncIterable[Union[T, ExceptionContainer]]): def __init__( self, iterator: AsyncIterator[T], - up_to: int, + up_to: Optional[int], ) -> None: self.iterator = iterator - self.up_to = up_to + self.up_to = up_to or sys.maxsize self._buffer: "Optional[asyncio.Queue[Union[T, ExceptionContainer]]]" = None self._slots: Optional[asyncio.Semaphore] = None self._stopped = False @@ -115,7 +115,7 @@ class BufferAsyncIterator(RaisingAsyncIterator[T]): def __init__( self, iterator: AsyncIterator[T], - up_to: int, + up_to: Optional[int], ) -> None: super().__init__(_BufferAsyncIterable(iterator, up_to).__aiter__()) diff --git a/streamable/_functions.py b/streamable/_functions.py index 6dd7cc14..af1e26b2 100644 --- a/streamable/_functions.py +++ b/streamable/_functions.py @@ -30,7 +30,7 @@ def buffer( iterator: Iterator[T], - up_to: 
int, + up_to: Optional[int] = None, ) -> Iterator[T]: return _iterators.BufferIterator(iterator, up_to) diff --git a/streamable/_iterators.py b/streamable/_iterators.py index 8f6d8936..38778339 100644 --- a/streamable/_iterators.py +++ b/streamable/_iterators.py @@ -1,5 +1,6 @@ import datetime import queue +import sys from threading import Semaphore, Thread import time from abc import ABC, abstractmethod @@ -55,9 +56,10 @@ class _BufferIterable(Iterable[Union[T, ExceptionContainer]]): def __init__( self, iterator: Iterator[T], - up_to: int, + up_to: Optional[int], ) -> None: self.iterator = iterator + up_to = up_to or sys.maxsize self._buffer: "queue.Queue[Union[T, ExceptionContainer]]" = queue.Queue() self._slots = Semaphore(up_to) self._stopped = False @@ -99,7 +101,7 @@ class BufferIterator(RaisingIterator[T]): def __init__( self, iterator: Iterator[T], - up_to: int, + up_to: Optional[int], ) -> None: super().__init__(_BufferIterable(iterator, up_to).__iter__()) diff --git a/streamable/_stream.py b/streamable/_stream.py index 3a17b945..e4faf752 100644 --- a/streamable/_stream.py +++ b/streamable/_stream.py @@ -300,17 +300,15 @@ def cast(self, into: Type[U]) -> "stream[U]": def buffer( self, - up_to: int, + up_to: Optional[int] = None, ) -> "stream[T]": """ - Buffer upstream elements into a bounded queue (max size ``up_to``), via a background task. - - Allow to decouple the upstream production rate from the downstream consumption rate. + Buffer upstream elements into a queue, via a background task, decoupling upstream production rate from downstream consumption rate. The background task is a thread during a sync iteration, and an async task during an async iteration. Args: - up_to (``int``): The buffer size. Must be >= 1. When reached, upstream pulling pauses until an element is yielded out of the buffer. + up_to (``int | None``): The buffer size. If ``None`` (default), the buffer is effectively unbounded; otherwise must be >= 1. When reached, upstream pulling pauses until an element is yielded out of the buffer. 
Returns: ``stream[T]``: Upstream with buffering. @@ -327,7 +325,8 @@ def buffer( time.sleep(1e-3) assert pulled == [0, 1, 2, 3, 4, 5] """ - validate_int(up_to, gte=1, name="up_to") + if up_to is not None: + validate_int(up_to, gte=1, name="up_to") return BufferStream(self, up_to) @overload @@ -1082,7 +1081,7 @@ class BufferStream(DownStream[T, T]): def __init__( self, upstream: stream[T], - up_to: int, + up_to: Optional[int], ) -> None: super().__init__(upstream) self._up_to = up_to diff --git a/streamable/_tools/_async.py b/streamable/_tools/_async.py index d11e456d..931406c3 100644 --- a/streamable/_tools/_async.py +++ b/streamable/_tools/_async.py @@ -1,11 +1,4 @@ -from typing import ( - Any, - AsyncIterator, - Awaitable, - Callable, - Coroutine, - TypeVar, -) +from typing import Any, AsyncIterator, Awaitable, Callable, Coroutine, TypeVar T = TypeVar("T") R = TypeVar("R") diff --git a/tests/test_buffer.py b/tests/test_buffer.py index 620d4da6..21808697 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -1,3 +1,4 @@ +import sys from typing import Any, AsyncIterable, Callable, Iterable, List import pytest @@ -26,7 +27,7 @@ def test_buffer_preserves_elements(itype: IterableType) -> None: assert alist_or_list(ints.buffer(5), itype) == list(INTEGERS) -@pytest.mark.parametrize("buffer_size", [1, 10]) +@pytest.mark.parametrize("buffer_size", [1, 10, None]) @pytest.mark.parametrize( "itype, slow_identity", [(Iterable, slow_identity), (AsyncIterable, async_slow_identity)], @@ -39,9 +40,13 @@ def test_buffer_size_is_respected( buffering_ints_iter = aiter_or_iter(buffering_ints, itype) assert buffered == [] assert anext_or_next(buffering_ints_iter, itype) == 0 - assert buffered == list(INTEGERS)[: buffer_size + 1] + assert ( + buffered == list(INTEGERS)[: (buffer_size + 1) if buffer_size else sys.maxsize] + ) assert anext_or_next(buffering_ints_iter, itype) == 1 - assert buffered == list(INTEGERS)[: buffer_size + 2] + assert ( + buffered == list(INTEGERS)[: 
(buffer_size + 2) if buffer_size else sys.maxsize] + ) @pytest.mark.parametrize(