Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion streamable/_afunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

def buffer(
aiterator: AsyncIterator[T],
up_to: int,
up_to: Optional[int] = None,
) -> AsyncIterator[T]:
return _aiterators.BufferAsyncIterator(aiterator, up_to)

Expand Down
10 changes: 5 additions & 5 deletions streamable/_aiterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from asyncio.futures import Future
from contextlib import suppress
import datetime
import sys
import time
from abc import ABC, abstractmethod
from collections import defaultdict, deque
Expand Down Expand Up @@ -37,11 +38,10 @@
FIFOFutureResults,
FutureResults,
)
from streamable._tools._async import AsyncFunction, empty_aiter
from streamable._tools._async import AsyncFunction, anext, empty_aiter
from streamable._tools._context import noop_context_manager
from streamable._tools._error import ExceptionContainer, RaisingAsyncIterator

from streamable._tools._async import anext

T = TypeVar("T")
U = TypeVar("U")
Expand All @@ -59,10 +59,10 @@ class _BufferAsyncIterable(AsyncIterable[Union[T, ExceptionContainer]]):
def __init__(
self,
iterator: AsyncIterator[T],
up_to: int,
up_to: Optional[int],
) -> None:
self.iterator = iterator
self.up_to = up_to
self.up_to = up_to or sys.maxsize
self._buffer: "Optional[asyncio.Queue[Union[T, ExceptionContainer]]]" = None
self._slots: Optional[asyncio.Semaphore] = None
self._stopped = False
Expand Down Expand Up @@ -115,7 +115,7 @@ class BufferAsyncIterator(RaisingAsyncIterator[T]):
def __init__(
self,
iterator: AsyncIterator[T],
up_to: int,
up_to: Optional[int],
) -> None:
super().__init__(_BufferAsyncIterable(iterator, up_to).__aiter__())

Expand Down
2 changes: 1 addition & 1 deletion streamable/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

def buffer(
iterator: Iterator[T],
up_to: int,
up_to: Optional[int] = None,
) -> Iterator[T]:
return _iterators.BufferIterator(iterator, up_to)

Expand Down
6 changes: 4 additions & 2 deletions streamable/_iterators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import queue
import sys
from threading import Semaphore, Thread
import time
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -55,9 +56,10 @@ class _BufferIterable(Iterable[Union[T, ExceptionContainer]]):
def __init__(
self,
iterator: Iterator[T],
up_to: int,
up_to: Optional[int],
) -> None:
self.iterator = iterator
up_to = up_to or sys.maxsize
self._buffer: "queue.Queue[Union[T, ExceptionContainer]]" = queue.Queue()
self._slots = Semaphore(up_to)
self._stopped = False
Expand Down Expand Up @@ -99,7 +101,7 @@ class BufferIterator(RaisingIterator[T]):
def __init__(
self,
iterator: Iterator[T],
up_to: int,
up_to: Optional[int],
) -> None:
super().__init__(_BufferIterable(iterator, up_to).__iter__())

Expand Down
13 changes: 6 additions & 7 deletions streamable/_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,17 +300,15 @@ def cast(self, into: Type[U]) -> "stream[U]":

def buffer(
self,
up_to: int,
up_to: Optional[int] = None,
) -> "stream[T]":
"""
Buffer upstream elements into a bounded queue (max size ``up_to``), via a background task.

Allow to decouple the upstream production rate from the downstream consumption rate.
Buffer upstream elements into a queue, via a background task, decoupling upstream production rate from downstream consumption rate.

The background task is a thread during a sync iteration, and an async task during an async iteration.

Args:
up_to (``int``): The buffer size. Must be >= 1. When reached, upstream pulling pauses until an element is yielded out of the buffer.
up_to (``int | None``): The buffer size. Must be >= 1. When reached, upstream pulling pauses until an element is yielded out of the buffer.

Returns:
``stream[T]``: Upstream with buffering.
Expand All @@ -327,7 +325,8 @@ def buffer(
time.sleep(1e-3)
assert pulled == [0, 1, 2, 3, 4, 5]
"""
validate_int(up_to, gte=1, name="up_to")
if up_to is not None:
validate_int(up_to, gte=1, name="up_to")
return BufferStream(self, up_to)

@overload
Expand Down Expand Up @@ -1082,7 +1081,7 @@ class BufferStream(DownStream[T, T]):
def __init__(
self,
upstream: stream[T],
up_to: int,
up_to: Optional[int],
) -> None:
super().__init__(upstream)
self._up_to = up_to
Expand Down
9 changes: 1 addition & 8 deletions streamable/_tools/_async.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
TypeVar,
)
from typing import Any, AsyncIterator, Awaitable, Callable, Coroutine, TypeVar

T = TypeVar("T")
R = TypeVar("R")
Expand Down
11 changes: 8 additions & 3 deletions tests/test_buffer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from typing import Any, AsyncIterable, Callable, Iterable, List

import pytest
Expand Down Expand Up @@ -26,7 +27,7 @@ def test_buffer_preserves_elements(itype: IterableType) -> None:
assert alist_or_list(ints.buffer(5), itype) == list(INTEGERS)


@pytest.mark.parametrize("buffer_size", [1, 10])
@pytest.mark.parametrize("buffer_size", [1, 10, None])
@pytest.mark.parametrize(
"itype, slow_identity",
[(Iterable, slow_identity), (AsyncIterable, async_slow_identity)],
Expand All @@ -39,9 +40,13 @@ def test_buffer_size_is_respected(
buffering_ints_iter = aiter_or_iter(buffering_ints, itype)
assert buffered == []
assert anext_or_next(buffering_ints_iter, itype) == 0
assert buffered == list(INTEGERS)[: buffer_size + 1]
assert (
buffered == list(INTEGERS)[: (buffer_size + 1) if buffer_size else sys.maxsize]
)
assert anext_or_next(buffering_ints_iter, itype) == 1
assert buffered == list(INTEGERS)[: buffer_size + 2]
assert (
buffered == list(INTEGERS)[: (buffer_size + 2) if buffer_size else sys.maxsize]
)


@pytest.mark.parametrize(
Expand Down
Loading