diff --git a/grain/_src/python/BUILD b/grain/_src/python/BUILD index ba918f2b7..d297e6892 100644 --- a/grain/_src/python/BUILD +++ b/grain/_src/python/BUILD @@ -368,3 +368,19 @@ py_library( "@pypi//etils:pkg", ], ) + +py_library( + name = "variable_size_queue", + srcs = ["variable_size_queue.py"], + srcs_version = "PY3", +) + +py_test( + name = "variable_size_queue_test", + srcs = ["variable_size_queue_test.py"], + srcs_version = "PY3", + deps = [ + ":variable_size_queue", + "@abseil-py//absl/testing:absltest", + ], +) diff --git a/grain/_src/python/variable_size_queue.py b/grain/_src/python/variable_size_queue.py new file mode 100644 index 000000000..5c521a690 --- /dev/null +++ b/grain/_src/python/variable_size_queue.py @@ -0,0 +1,172 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""This module provides variable size queue implementations.""" + +from multiprocessing import context +from multiprocessing import queues +from multiprocessing import sharedctypes +import queue +import threading +import time +from typing import Any, cast + + +class VariableSizeMultiprocessingQueue(queues.Queue): + """A multiprocessing queue whose max size can be dynamically changed.""" + + def __init__( + self, + max_size: int | sharedctypes.Synchronized, + ctx: context.BaseContext, + ): + super().__init__(maxsize=0, ctx=ctx) + self._max_size = ( + max_size + if isinstance(max_size, sharedctypes.Synchronized) + else ctx.Value("i", max_size) + ) + self._cond = ctx.Condition() + + def __getstate__(self): + return cast(tuple[Any, ...], super().__getstate__()) + ( + self._max_size, + self._cond, + ) + + def __setstate__(self, state): + super().__setstate__(state[:-2]) # pytype: disable=attribute-error + self._max_size, self._cond = state[-2:] + + def set_max_size(self, max_size: int): + with self._cond: + self._max_size.value = max_size + self._cond.notify_all() + + def put(self, obj, block: bool = True, timeout: float | None = None): + """Puts an item into the queue, similar to `queue.Queue.put`. + + This method behaves like `queue.Queue.put`, but respects the current + `_max_size` of this variable-size queue. If the queue is full based on + `_max_size`, this method can block or raise `queue.Full` depending on + `block` and `timeout`. + + Args: + obj: The object to put into the queue. + block: If True, block until a free slot is available. + timeout: If `block` is True, wait for at most `timeout` seconds. + + Raises: + queue.Full: If the queue is full and `block` is False or the `timeout` + is reached. 
+ """ + if not block: + with self._cond: + if self.qsize() >= self._max_size.value: + raise queue.Full + super().put(obj, block=False) + return + + deadline = None + if timeout is not None: + deadline = time.time() + timeout + + with self._cond: + while self.qsize() >= self._max_size.value: + if deadline is None: + self._cond.wait() + continue + remaining = deadline - time.time() + if remaining <= 0: + raise queue.Full + if not self._cond.wait(remaining): + if self.qsize() >= self._max_size.value: + raise queue.Full + else: + break + super().put(obj, block=False) + + def get(self, block: bool = True, timeout: float | None = None): + item = super().get(block=block, timeout=timeout) + with self._cond: + self._cond.notify() + return item + + def get_nowait(self): + item = super().get_nowait() + with self._cond: + self._cond.notify() + return item + + +class VariableSizeQueue(queue.Queue): + """A queue whose max size can be dynamically changed.""" + + def __init__(self, max_size: int): + super().__init__(maxsize=0) + self._max_size = max_size + self._cond = threading.Condition() + + def set_max_size(self, max_size: int): + with self._cond: + self._max_size = max_size + self._cond.notify_all() + + def put(self, item, block: bool = True, timeout: float | None = None): + """Puts an item into the queue, similar to `queue.Queue.put`. + + This method behaves like `queue.Queue.put`, but respects the current + `_max_size` of this variable-size queue. If the queue is full based on + `_max_size`, this method can block or raise `queue.Full` depending on + `block` and `timeout`. + + Args: + item: The object to put into the queue. + block: If True, block until a free slot is available. + timeout: If `block` is True, wait for at most `timeout` seconds. + + Raises: + queue.Full: If the queue is full and `block` is False or the `timeout` + is reached. 
+ """ + if not block: + with self._cond: + if self.qsize() >= self._max_size: + raise queue.Full + super().put(item, block=False) + return + + deadline = None + if timeout is not None: + deadline = time.time() + timeout + + with self._cond: + while self.qsize() >= self._max_size: + if deadline is None: + self._cond.wait() + continue + remaining = deadline - time.time() + if remaining <= 0: + raise queue.Full + if not self._cond.wait(remaining): + if self.qsize() >= self._max_size: + raise queue.Full + else: + break + super().put(item, block=False) + + def get(self, block: bool = True, timeout: float | None = None): + item = super().get(block=block, timeout=timeout) + with self._cond: + self._cond.notify() + return item diff --git a/grain/_src/python/variable_size_queue_test.py b/grain/_src/python/variable_size_queue_test.py new file mode 100644 index 000000000..6ace7e0fd --- /dev/null +++ b/grain/_src/python/variable_size_queue_test.py @@ -0,0 +1,249 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Tests for variable size queue implementations."""

import multiprocessing
import queue
import threading
import time

from absl.testing import absltest
from grain._src.python import variable_size_queue


class VariableSizeQueueTest(absltest.TestCase):
  """Covers the thread-based `VariableSizeQueue`."""

  def test_put_and_get(self):
    """A single item round-trips and `qsize` tracks it."""
    vsq = variable_size_queue.VariableSizeQueue(max_size=1)
    self.assertEqual(vsq.qsize(), 0)
    vsq.put(1)
    self.assertEqual(vsq.qsize(), 1)
    self.assertEqual(vsq.get(), 1)
    self.assertEqual(vsq.qsize(), 0)

  def test_put_non_blocking_to_full_queue_raises_full(self):
    """A non-blocking put into a full queue raises `queue.Full`."""
    vsq = variable_size_queue.VariableSizeQueue(max_size=1)
    vsq.put(1)
    self.assertRaises(queue.Full, vsq.put, 2, block=False)

  def test_put_blocking_with_timeout_to_full_queue_raises_full(self):
    """A timed put into a full queue raises `queue.Full` on expiry."""
    vsq = variable_size_queue.VariableSizeQueue(max_size=1)
    vsq.put(1)
    self.assertRaises(queue.Full, vsq.put, 2, block=True, timeout=0.1)

  def test_set_max_size_to_increase_capacity(self):
    """Raising the limit lets a previously rejected put succeed."""
    vsq = variable_size_queue.VariableSizeQueue(max_size=1)
    vsq.put(1)
    self.assertRaises(queue.Full, vsq.put, 2, block=False)
    vsq.set_max_size(2)
    vsq.put(2)  # Should not raise.
    self.assertEqual(vsq.qsize(), 2)
    self.assertEqual(vsq.get(), 1)
    self.assertEqual(vsq.get(), 2)

  def test_set_max_size_to_decrease_capacity(self):
    """Lowering the limit rejects puts until the queue drains below it."""
    vsq = variable_size_queue.VariableSizeQueue(max_size=2)
    vsq.put(1)
    vsq.put(2)
    self.assertEqual(vsq.qsize(), 2)
    vsq.set_max_size(1)
    # qsize is 2, max_size is 1. put should fail.
    self.assertRaises(queue.Full, vsq.put, 3, block=False)
    self.assertEqual(vsq.get(), 1)
    self.assertEqual(vsq.qsize(), 1)
    # qsize is 1, max_size is 1. put should fail.
    self.assertRaises(queue.Full, vsq.put, 3, block=False)
    self.assertEqual(vsq.get(), 2)
    self.assertEqual(vsq.qsize(), 0)
    # qsize is 0, max_size is 1. put should succeed.
    vsq.put(3)
    self.assertEqual(vsq.qsize(), 1)
    self.assertEqual(vsq.get(), 3)

  def test_put_blocks_until_item_is_retrieved(self):
    """A blocking put resumes once another thread removes an item."""
    vsq = variable_size_queue.VariableSizeQueue(max_size=1)
    vsq.put(1)
    consumed = []

    def delayed_get():
      time.sleep(0.1)
      consumed.append(vsq.get())

    worker = threading.Thread(target=delayed_get)
    worker.start()
    vsq.put(2)  # This should block until consumer gets item 1.
    self.assertEqual(vsq.qsize(), 1)
    self.assertEqual(vsq.get(), 2)
    worker.join()
    self.assertEqual(consumed, [1])

  def test_put_blocks_until_max_size_increases(self):
    """A blocking put resumes once another thread raises the limit."""
    vsq = variable_size_queue.VariableSizeQueue(max_size=1)
    vsq.put(1)

    def delayed_resize():
      time.sleep(0.1)
      vsq.set_max_size(2)

    worker = threading.Thread(target=delayed_resize)
    worker.start()
    vsq.put(2)  # This should block until max_size is increased.
    self.assertEqual(vsq.qsize(), 2)
    self.assertEqual(vsq.get(), 1)
    self.assertEqual(vsq.get(), 2)
    worker.join()


class VariableSizeMultiprocessingQueueTest(absltest.TestCase):
  """Covers the process-based `VariableSizeMultiprocessingQueue`."""

  def _fork_queue(self, max_size):
    """Builds a queue of the given capacity on the "fork" context."""
    return variable_size_queue.VariableSizeMultiprocessingQueue(
        max_size, ctx=multiprocessing.get_context("fork")
    )

  def test_put_and_get(self):
    """A single item round-trips and `qsize` tracks it."""
    mpq = self._fork_queue(1)
    self.assertEqual(mpq.qsize(), 0)
    mpq.put(1)
    self.assertEqual(mpq.qsize(), 1)
    self.assertEqual(mpq.get(), 1)
    self.assertEqual(mpq.qsize(), 0)

  def test_put_non_blocking_to_full_queue_raises_full(self):
    """A non-blocking put into a full queue raises `queue.Full`."""
    mpq = self._fork_queue(1)
    mpq.put(1)
    self.assertRaises(queue.Full, mpq.put, 2, block=False)

  def test_put_blocking_with_timeout_to_full_queue_raises_full(self):
    """A timed put into a full queue raises `queue.Full` on expiry."""
    mpq = self._fork_queue(1)
    mpq.put(1)
    self.assertRaises(queue.Full, mpq.put, 2, block=True, timeout=0.1)

  def test_set_max_size_to_increase_capacity(self):
    """Raising the limit lets a previously rejected put succeed."""
    mpq = self._fork_queue(1)
    mpq.put(1)
    self.assertRaises(queue.Full, mpq.put, 2, block=False)
    mpq.set_max_size(2)
    mpq.put(2)  # Should not raise.
    self.assertEqual(mpq.qsize(), 2)
    self.assertEqual(mpq.get(), 1)
    self.assertEqual(mpq.get(), 2)

  def test_set_max_size_to_decrease_capacity(self):
    """Lowering the limit rejects puts until the queue drains below it."""
    mpq = self._fork_queue(2)
    mpq.put(1)
    mpq.put(2)
    self.assertEqual(mpq.qsize(), 2)
    mpq.set_max_size(1)
    # qsize is 2, max_size is 1. put should fail.
    self.assertRaises(queue.Full, mpq.put, 3, block=False)
    self.assertEqual(mpq.get(), 1)
    self.assertEqual(mpq.qsize(), 1)
    # qsize is 1, max_size is 1. put should fail.
    self.assertRaises(queue.Full, mpq.put, 3, block=False)
    self.assertEqual(mpq.get(), 2)
    self.assertEqual(mpq.qsize(), 0)
    # qsize is 0, max_size is 1. put should succeed.
    mpq.put(3)
    self.assertEqual(mpq.qsize(), 1)
    self.assertEqual(mpq.get(), 3)

  def test_put_blocks_until_item_is_retrieved_from_process(self):
    """A blocking put resumes once a child process removes an item."""
    ctx = multiprocessing.get_context("fork")
    mpq = variable_size_queue.VariableSizeMultiprocessingQueue(1, ctx=ctx)
    mpq.put(1)

    with ctx.Manager() as manager:
      consumed = manager.list()

      def delayed_get(shared_queue, sink):
        time.sleep(0.1)
        sink.append(shared_queue.get())

      child = ctx.Process(target=delayed_get, args=(mpq, consumed))
      child.start()
      mpq.put(2)  # This should block until consumer gets item 1.
      self.assertEqual(mpq.qsize(), 1)
      self.assertEqual(mpq.get(), 2)
      child.join()
      self.assertEqual(list(consumed), [1])

  def test_put_blocks_until_max_size_increases_from_process(self):
    """A blocking put resumes once a child process raises the limit."""
    ctx = multiprocessing.get_context("fork")
    mpq = variable_size_queue.VariableSizeMultiprocessingQueue(1, ctx=ctx)
    mpq.put(1)

    def delayed_resize(shared_queue):
      time.sleep(0.1)
      shared_queue.set_max_size(2)

    child = ctx.Process(target=delayed_resize, args=(mpq,))
    child.start()
    # This should block until max_size is increased in the other process.
    mpq.put(2)
    self.assertEqual(mpq.qsize(), 2)
    self.assertEqual(mpq.get(), 1)
    self.assertEqual(mpq.get(), 2)
    child.join()

  def test_empty(self):
    """`empty` reflects the queue content once the item becomes visible."""
    mpq = self._fork_queue(1)
    self.assertTrue(mpq.empty())
    mpq.put(1, block=False)
    # Poll until `empty` reports the item that was just put.
    while mpq.empty():
      time.sleep(0.1)
    self.assertFalse(mpq.empty())
    mpq.get()
    self.assertTrue(mpq.empty())

  def test_get_nowait(self):
    """`get_nowait` raises on an empty queue and returns an available item."""
    mpq = self._fork_queue(1)
    self.assertRaises(queue.Empty, mpq.get_nowait)
    mpq.put(1)
    # Poll until `empty` reports the item that was just put.
    while mpq.empty():
      time.sleep(0.1)
    self.assertEqual(mpq.get_nowait(), 1)
    self.assertRaises(queue.Empty, mpq.get_nowait)

  def test_close_and_cancel_join_thread(self):
    """`close` followed by `cancel_join_thread` does not raise."""
    mpq = self._fork_queue(1)
    mpq.close()
    mpq.cancel_join_thread()


if __name__ == "__main__":
  absltest.main()