From f4500e5ec41b2c03ac63ab5533a2003f599f8dda Mon Sep 17 00:00:00 2001 From: Gershon Bialer Date: Wed, 2 Aug 2023 10:08:51 -0700 Subject: [PATCH 01/14] Add code to celery batches for events. --- celery_batches/__init__.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/celery_batches/__init__.py b/celery_batches/__init__.py index b8aa179..345ad90 100644 --- a/celery_batches/__init__.py +++ b/celery_batches/__init__.py @@ -17,6 +17,7 @@ from celery_batches.trace import apply_batches_task from celery import VERSION as CELERY_VERSION +from celery import signals from celery.app import Celery from celery.app.task import Task from celery.concurrency.base import BasePool @@ -200,6 +201,10 @@ def Strategy(self, task: "Batches", app: Celery, consumer: Consumer) -> Callable connection_errors = consumer.connection_errors eventer = consumer.event_dispatcher + events = eventer and eventer.enabled + send_event = eventer and eventer.send + task_sends_events = events and task.send_events + Request = symbol_by_name(task.Request) # Celery 5.1 added the app argument to create_request_cls. @@ -253,6 +258,19 @@ def task_message_handler( ) put_buffer(request) + signals.task_received.send(sender=consumer, request=request) + if task_sends_events: + send_event( + 'task-received', + uuid=request.id, name=request.name, + args=request.argsrepr, kwargs=request.kwargsrepr, + root_id=request.root_id, parent_id=request.parent_id, + retries=request.request_dict.get('retries', 0), + eta=request.eta and request.eta.isoformat(), + expires=request.expires and request.expires.isoformat(), + ) + + if self._tref is None: # first request starts flush timer. self._tref = timer.call_repeatedly(self.flush_interval, flush_buffer) @@ -351,15 +369,20 @@ def flush(self, requests: Collection[Request]) -> Any: # Ensure the requests can be serialized using pickle for the prefork pool. serializable_requests = ([SimpleRequest.from_request(r) for r in requests],) - + def on_accepted(pid: int, time_accepted: float) -> None: for req in acks_early: req.acknowledge() + for request in requests: + request.send_event('task-started') + + def on_return(failed__retval__runtime, **kwargs) -> None: + failed, retval, runtime = failed__retval__runtime - def on_return(result: Optional[Any]) -> None: for req in acks_late: req.acknowledge() - + for request in requests: + request.send_event('task-succeeded', ret_val=retval,runtime=runtime ) return self._pool.apply_async( apply_batches_task, (self, serializable_requests, 0, None), From 9dff13a0af30ed3a38ca818a73df884deb3063df Mon Sep 17 00:00:00 2001 From: Gershon Bialer Date: Wed, 2 Aug 2023 10:36:32 -0700 Subject: [PATCH 02/14] Fix events to pass tests. --- celery_batches/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/celery_batches/__init__.py b/celery_batches/__init__.py index 345ad90..0ef3684 100644 --- a/celery_batches/__init__.py +++ b/celery_batches/__init__.py @@ -270,7 +270,6 @@ def task_message_handler( expires=request.expires and request.expires.isoformat(), ) - if self._tref is None: # first request starts flush timer. self._tref = timer.call_repeatedly(self.flush_interval, flush_buffer) @@ -376,13 +375,14 @@ def on_accepted(pid: int, time_accepted: float) -> None: for request in requests: request.send_event('task-started') - def on_return(failed__retval__runtime, **kwargs) -> None: - failed, retval, runtime = failed__retval__runtime - + def on_return(result: Optional[Any]) -> None: for req in acks_late: req.acknowledge() for request in requests: - request.send_event('task-succeeded', ret_val=retval,runtime=runtime ) + runtime = 0 + if type(result) == int: + runtime = result + request.send_event('task-succeeded',result=None, runtime=runtime ) return self._pool.apply_async( apply_batches_task, (self, serializable_requests, 0, None), From 33f0275e9f5de201681f2c696d6cdd74d45a5356 Mon Sep 17 00:00:00 2001 From: Gershon Bialer Date: Wed, 31 Jul 2024 14:48:00 -0700 Subject: [PATCH 03/14] Fixed signals and added unit tests. --- celery_batches/trace.py | 34 +++++++-- t/unit/test_signals.py | 148 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+), 5 deletions(-) create mode 100644 t/unit/test_signals.py diff --git a/celery_batches/trace.py b/celery_batches/trace.py index 8f2fc38..f38acef 100644 --- a/celery_batches/trace.py +++ b/celery_batches/trace.py @@ -22,8 +22,11 @@ send_prerun = signals.task_prerun.send send_postrun = signals.task_postrun.send send_success = signals.task_success.send +send_failure = signals.task_failure.send +send_revoked = signals.task_revoked.send SUCCESS = states.SUCCESS FAILURE = states.FAILURE +REVOKED = states.REVOKED def apply_batches_task( @@ -38,6 +41,14 @@ def apply_batches_task( prerun_receivers = signals.task_prerun.receivers postrun_receivers = signals.task_postrun.receivers success_receivers = signals.task_success.receivers + failure_receivers = signals.task_failure.receivers + revoked_receivers = signals.task_revoked.receivers + + logger.debug(f"Debug: prerun_receivers: {prerun_receivers}") + logger.debug(f"Debug: postrun_receivers: {postrun_receivers}") + logger.debug(f"Debug: success_receivers: {success_receivers}") + logger.debug(f"Debug: failure_receivers: {failure_receivers}") + logger.debug(f"Debug: revoked_receivers: {revoked_receivers}") # Corresponds to multiple requests, so generate a new UUID. task_id = uuid() @@ -49,22 +60,35 @@ def apply_batches_task( try: # -*- PRE -*- if prerun_receivers: + logger.debug("Debug: Sending prerun signal") send_prerun(sender=task, task_id=task_id, task=task, args=args, kwargs={}) # -*- TRACE -*- try: - result = task(*args) - state = SUCCESS + result = task.run(*args) + if hasattr(task.request, 'state') and task.request.state == REVOKED: + state = REVOKED + else: + state = SUCCESS except Exception as exc: - result = None + result = exc state = FAILURE logger.error("Error: %r", exc, exc_info=True) + if failure_receivers: + logger.debug("Debug: Sending failure signal") + send_failure(sender=task, task_id=task_id, exception=exc, args=args, kwargs={}, einfo=None) else: - if success_receivers: + if state == REVOKED: + if revoked_receivers: + logger.debug("Debug: Sending revoked signal") + send_revoked(sender=task, request=task_request, terminated=True, signum=None, expired=False) + elif state == SUCCESS and success_receivers: + logger.debug("Debug: Sending success signal") send_success(sender=task, result=result) finally: try: if postrun_receivers: + logger.debug("Debug: Sending postrun signal") send_postrun( sender=task, task_id=task_id, @@ -78,4 +102,4 @@ def apply_batches_task( pop_task() pop_request() - return result + return result \ No newline at end of file diff --git a/t/unit/test_signals.py b/t/unit/test_signals.py new file mode 100644 index 0000000..9690986 --- /dev/null +++ b/t/unit/test_signals.py @@ -0,0 +1,148 @@ +import pytest +from unittest.mock import Mock, patch, DEFAULT +from celery import signals +from celery_batches import Batches +from celery_batches.trace import apply_batches_task +from celery.utils.collections import AttributeDict + + +class TestBatchTask: + + @property + def request_stack(self): + class RequestStack: + def __init__(self, stack): + self.stack = stack + + def push(self, item): + self.stack.append(item) + + def pop(self): + if self.stack: + return self.stack.pop() + return None + + return RequestStack(self._request_stack) + + @property + def request(self): + return self._request + + @request.setter + def request(self, value): + self._request = value + + def run(self, requests): + return [request['id'] for request in requests] + + +@pytest.fixture +def batch_task(): + tb = TestBatchTask() + tb._request_stack = [] + tb._request = AttributeDict({'state': None}) + + return tb + +@pytest.fixture +def simple_request(): + return { + 'id': "test_id", + 'name': "test_task", + 'args': (), + 'kwargs': {}, + 'delivery_info': {}, + 'hostname': "test_host", + 'ignore_result': False, + 'reply_to': None, + 'correlation_id': None, + 'request_dict': {}, + } + + +@pytest.fixture(autouse=True) +def setup_signal_receivers(): + def dummy_receiver(*args, **kwargs): + pass + + signals.task_prerun.connect(dummy_receiver) + signals.task_postrun.connect(dummy_receiver) + signals.task_success.connect(dummy_receiver) + signals.task_failure.connect(dummy_receiver) + signals.task_revoked.connect(dummy_receiver) + + yield + + signals.task_prerun.disconnect(dummy_receiver) + signals.task_postrun.disconnect(dummy_receiver) + signals.task_success.disconnect(dummy_receiver) + signals.task_failure.disconnect(dummy_receiver) + signals.task_revoked.disconnect(dummy_receiver) + + +def test_task_prerun_signal(batch_task, simple_request): + with patch('celery_batches.trace.send_prerun') as mock_send: + apply_batches_task(batch_task, ([simple_request],), 0, None) + mock_send.assert_called_once() + + +def test_task_postrun_signal(batch_task, simple_request): + with patch('celery_batches.trace.send_postrun') as mock_send: + apply_batches_task(batch_task, ([simple_request],), 0, None) + mock_send.assert_called_once() + + +def test_task_success_signal(batch_task, simple_request): + with patch('celery_batches.trace.send_success') as mock_send: + apply_batches_task(batch_task, ([simple_request],), 0, None) + mock_send.assert_called_once() + + +def test_task_failure_signal(batch_task, simple_request): + def failing_run(requests): + raise ValueError("Test exception") + + batch_task.run = failing_run + + with patch('celery_batches.trace.send_failure') as mock_send: + apply_batches_task(batch_task, ([simple_request],), 0, None) + mock_send.assert_called_once() + + +def test_task_revoked_signal(batch_task, simple_request): + def revoking_run(requests): + batch_task.request.state = 'REVOKED' + return [] + + batch_task.run = revoking_run + + with patch('celery_batches.trace.send_revoked') as mock_send: + apply_batches_task(batch_task, ([simple_request],), 0, None) + mock_send.assert_called_once() + + +def test_all_signals_sent(batch_task, simple_request): + with patch.multiple('celery_batches.trace', + send_prerun=DEFAULT, + send_postrun=DEFAULT, + send_success=DEFAULT) as mocks: + apply_batches_task(batch_task, ([simple_request],), 0, None) + + for mock in mocks.values(): + mock.assert_called_once() + + +def test_failure_signals_sent(batch_task, simple_request): + def failing_run(requests): + raise ValueError("Test exception") + + batch_task.run = failing_run + + with patch.multiple('celery_batches.trace', + send_prerun=DEFAULT, + send_postrun=DEFAULT, + send_failure=DEFAULT) as mocks: + apply_batches_task(batch_task, ([simple_request],), 0, None) + + for mock in mocks.values(): + mock.assert_called_once() \ No newline at end of file From c99f8766294694fef27ea507294ab7bee59298cc Mon Sep 17 00:00:00 2001 From: Gershon Bialer Date: Fri, 2 Aug 2024 09:33:26 -0700 Subject: [PATCH 04/14] Fix bugs. --- celery_batches/__init__.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/celery_batches/__init__.py b/celery_batches/__init__.py index f1ac522..4e218fa 100644 --- a/celery_batches/__init__.py +++ b/celery_batches/__init__.py @@ -260,16 +260,15 @@ def task_message_handler( signals.task_received.send(sender=consumer, request=req) - signals.task_received.send(sender=consumer, request=request) if task_sends_events: send_event( 'task-received', - uuid=request.id, name=request.name, - args=request.argsrepr, kwargs=request.kwargsrepr, - root_id=request.root_id, parent_id=request.parent_id, - retries=request.request_dict.get('retries', 0), - eta=request.eta and request.eta.isoformat(), - expires=request.expires and request.expires.isoformat(), + uuid=req.id, name=req.name, + args=req.argsrepr, kwargs=req.kwargsrepr, + root_id=req.root_id, parent_id=req.parent_id, + retries=req.request_dict.get('retries', 0), + eta=req.eta and request.eta.isoformat(), + expires=req.expires and req.expires.isoformat(), ) if self._tref is None: # first request starts flush timer. @@ -384,10 +383,10 @@ def on_return(result: Optional[Any]) -> None: runtime = 0 if type(result) == int: runtime = result - request.send_event('task-succeeded',result=None, runtime=runtime ) + request.send_event('task-succeeded', result=None, runtime=runtime ) return self._pool.apply_async( apply_batches_task, (self, serializable_requests, 0, None), accept_callback=on_accepted, callback=on_return, - ) + ) \ No newline at end of file From 3d140eafccddea295bd4c158cb8cebdbb0e1dbc7 Mon Sep 17 00:00:00 2001 From: Gershon Bialer Date: Fri, 2 Aug 2024 10:38:58 -0700 Subject: [PATCH 05/14] New tests and fix. --- celery_batches/trace.py | 13 +++++-- t/integration/test_events.py | 67 ++++++++++++++++++++++++++++++++++++ t/unit/test_signals.py | 8 ++--- 3 files changed, 82 insertions(+), 6 deletions(-) create mode 100644 t/integration/test_events.py diff --git a/celery_batches/trace.py b/celery_batches/trace.py index f38acef..29b909a 100644 --- a/celery_batches/trace.py +++ b/celery_batches/trace.py @@ -76,12 +76,21 @@ def apply_batches_task( logger.error("Error: %r", exc, exc_info=True) if failure_receivers: logger.debug("Debug: Sending failure signal") - send_failure(sender=task, task_id=task_id, exception=exc, args=args, kwargs={}, einfo=None) + send_failure(sender=task, + task_id=task_id, + exception=exc, + args=args, + kwargs={}, + einfo=None) else: if state == REVOKED: if revoked_receivers: logger.debug("Debug: Sending revoked signal") - send_revoked(sender=task, request=task_request, terminated=True, signum=None, expired=False) + send_revoked(sender=task, + request=task_request, + terminated=True, + signum=None, + expired=False) elif state == SUCCESS and success_receivers: logger.debug("Debug: Sending success signal") send_success(sender=task, result=result) diff --git a/t/integration/test_events.py b/t/integration/test_events.py new file mode 100644 index 0000000..cd01131 --- /dev/null +++ b/t/integration/test_events.py @@ -0,0 +1,67 @@ +import pytest + + +from celery import Celery, signals +from celery import shared_task +from celery.utils.log import get_task_logger +from typing import List +import sys + +from celery_batches import Batches, SimpleRequest + +def setup_celery(): + app = Celery('myapp') + app.conf.broker_url = 'memory://localhost/' + app.conf.result_backend = 'cache+memory://localhost/' + print("Created celeery app") + return app + +celery_app = setup_celery() + + + +@celery_app.task(base=Batches, flush_every=2, flush_interval=0.1) +def add(requests: List[SimpleRequest]) -> int: + """ + Add the first argument of each task. + + Marks the result of each task as the sum. + """ + print("add") + result = 0 + for request in requests: + result += sum(request.args) + sum(request.kwargs.values()) + + for request in requests: + celery_app.backend.mark_as_done(request.id, result, request=request) + + # TODO For EagerResults to work. + return result + +def test_tasks_for_add(): + # current_app.celery_broker_backend = 'memory' + print("test_tasks_for_add") + with celery_app.connection_for_write() as connection: + events_received = [0] + + def handler(event): + events_received[0] += 1 + + r = celery_app.events.Receiver(connection, + handlers={'*':handler}) + + result_1 = add.delay(1) + result_2 = add.delay(2) + + + print("READY") + assert result_1.get() == 3 + assert result_2.get() == 3 + + it = r.itercapture(limit=4,wakeup=True) + next(it) + assert events_received[0] > 0 + + + + diff --git a/t/unit/test_signals.py b/t/unit/test_signals.py index 9690986..f1e5880 100644 --- a/t/unit/test_signals.py +++ b/t/unit/test_signals.py @@ -1,7 +1,6 @@ import pytest -from unittest.mock import Mock, patch, DEFAULT +from unittest.mock import patch, DEFAULT from celery import signals -from celery_batches import Batches from celery_batches.trace import apply_batches_task from celery.utils.collections import AttributeDict @@ -44,6 +43,7 @@ def batch_task(): return tb + @pytest.fixture def simple_request(): return { @@ -143,6 +143,6 @@ def failing_run(requests): send_postrun=DEFAULT, send_failure=DEFAULT) as mocks: apply_batches_task(batch_task, ([simple_request],), 0, None) - + for mock in mocks.values(): - mock.assert_called_once() \ No newline at end of file + mock.assert_called_once() From cb59689557481ee32831271d63bd5ea513e24796 Mon Sep 17 00:00:00 2001 From: Gershon Bialer Date: Fri, 2 Aug 2024 11:03:19 -0700 Subject: [PATCH 06/14] Fix tests and formatting. --- celery_batches/__init__.py | 14 ++++-- celery_batches/trace.py | 34 +++++++++----- requirements/test.txt | 1 + t/integration/test_events.py | 90 ++++++++++++++++++------------------ 4 files changed, 78 insertions(+), 61 deletions(-) diff --git a/celery_batches/__init__.py b/celery_batches/__init__.py index 4e218fa..5a962c3 100644 --- a/celery_batches/__init__.py +++ b/celery_batches/__init__.py @@ -203,8 +203,9 @@ def Strategy(self, task: "Batches", app: Celery, consumer: Consumer) -> Callable eventer = consumer.event_dispatcher events = eventer and eventer.enabled send_event = eventer and eventer.send - task_sends_events = events and task.send_events - + task_sends_events = ( + events and task.send_events + ) Request = symbol_by_name(task.Request) # Celery 5.1 added the app argument to create_request_cls. @@ -263,9 +264,12 @@ def task_message_handler( if task_sends_events: send_event( 'task-received', - uuid=req.id, name=req.name, - args=req.argsrepr, kwargs=req.kwargsrepr, - root_id=req.root_id, parent_id=req.parent_id, + uuid=req.id, + name=req.name, + args=req.argsrepr, + kwargs=req.kwargsrepr, + root_id=req.root_id, + parent_id=req.parent_id, retries=req.request_dict.get('retries', 0), eta=req.eta and request.eta.isoformat(), expires=req.expires and req.expires.isoformat(), diff --git a/celery_batches/trace.py b/celery_batches/trace.py index 29b909a..5c54d80 100644 --- a/celery_batches/trace.py +++ b/celery_batches/trace.py @@ -61,7 +61,13 @@ def apply_batches_task( # -*- PRE -*- if prerun_receivers: logger.debug("Debug: Sending prerun signal") - send_prerun(sender=task, task_id=task_id, task=task, args=args, kwargs={}) + send_prerun( + sender=task, + task_id=task_id, + task=task, + args=args, + kwargs={} + ) # -*- TRACE -*- try: @@ -76,21 +82,25 @@ def apply_batches_task( logger.error("Error: %r", exc, exc_info=True) if failure_receivers: logger.debug("Debug: Sending failure signal") - send_failure(sender=task, - task_id=task_id, - exception=exc, - args=args, - kwargs={}, - einfo=None) + send_failure( + sender=task, + task_id=task_id, + exception=exc, + args=args, + kwargs={}, + einfo=None + ) else: if state == REVOKED: if revoked_receivers: logger.debug("Debug: Sending revoked signal") - send_revoked(sender=task, - request=task_request, - terminated=True, - signum=None, - expired=False) + send_revoked( + sender=task, + request=task_request, + terminated=True, + signum=None, + expired=False + ) elif state == SUCCESS and success_receivers: logger.debug("Debug: Sending success signal") send_success(sender=task, result=result) diff --git a/requirements/test.txt b/requirements/test.txt index a45cb0a..865e1ee 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -2,3 +2,4 @@ pytest-celery~=0.0.0 pytest~=6.2 coverage pytest-timeout +pytest-asyncio diff --git a/t/integration/test_events.py b/t/integration/test_events.py index cd01131..44a8ee9 100644 --- a/t/integration/test_events.py +++ b/t/integration/test_events.py @@ -1,25 +1,30 @@ import pytest - - -from celery import Celery, signals -from celery import shared_task -from celery.utils.log import get_task_logger +from celery import Celery +from celery_batches import Batches, SimpleRequest from typing import List -import sys +import asyncio +import logging -from celery_batches import Batches, SimpleRequest +pytest_plugins = ('pytest_asyncio',) + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) def setup_celery(): app = Celery('myapp') - app.conf.broker_url = 'memory://localhost/' - app.conf.result_backend = 'cache+memory://localhost/' - print("Created celeery app") + app.conf.update( + broker_url='memory://', + result_backend='cache+memory://', + task_always_eager=False, + worker_concurrency=1, + worker_prefetch_multiplier=1, + task_create_missing_queues=True, + broker_connection_retry_on_startup=True, + ) return app celery_app = setup_celery() - - @celery_app.task(base=Batches, flush_every=2, flush_interval=0.1) def add(requests: List[SimpleRequest]) -> int: """ @@ -27,41 +32,38 @@ def add(requests: List[SimpleRequest]) -> int: Marks the result of each task as the sum. """ - print("add") - result = 0 - for request in requests: - result += sum(request.args) + sum(request.kwargs.values()) + logger.debug(f"Processing {len(requests)} requests") + result = sum(sum(request.args) + sum(request.kwargs.values()) for request in requests) for request in requests: celery_app.backend.mark_as_done(request.id, result, request=request) - # TODO For EagerResults to work. + logger.debug(f"Finished processing. Result: {result}") return result -def test_tasks_for_add(): - # current_app.celery_broker_backend = 'memory' - print("test_tasks_for_add") - with celery_app.connection_for_write() as connection: - events_received = [0] - - def handler(event): - events_received[0] += 1 - - r = celery_app.events.Receiver(connection, - handlers={'*':handler}) - - result_1 = add.delay(1) - result_2 = add.delay(2) - - - print("READY") - assert result_1.get() == 3 - assert result_2.get() == 3 - - it = r.itercapture(limit=4,wakeup=True) - next(it) - assert events_received[0] > 0 - - - - +@pytest.mark.asyncio +async def test_tasks_for_add(celery_worker): + logger.debug("Starting test_tasks_for_add") + + # Send tasks + logger.debug("Sending tasks") + result_1 = add.delay(1) + result_2 = add.delay(2) + + logger.debug("Waiting for results") + try: + # Wait for the batch to be processed + results = await asyncio.wait_for(asyncio.gather( + asyncio.to_thread(result_1.get), + asyncio.to_thread(result_2.get) + ), timeout=5.0) + logger.debug(f"Results: {results}") + except asyncio.TimeoutError: + logger.error("Test timed out while waiting for results") + pytest.fail("Test timed out while waiting for results") + + # Check results + assert results[0] == 3, f"Expected 3, got {results[0]}" + assert results[1] == 3, f"Expected 3, got {results[1]}" + + logger.debug("Test completed successfully") From be3707699b05d33c55592431d395afe92f90deca Mon Sep 17 00:00:00 2001 From: Gershon Bialer Date: Fri, 2 Aug 2024 11:13:24 -0700 Subject: [PATCH 07/14] Fix flake8 issues. --- celery_batches/__init__.py | 10 +++++----- celery_batches/trace.py | 2 +- t/integration/test_events.py | 10 +++++++++- t/unit/test_signals.py | 1 - 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/celery_batches/__init__.py b/celery_batches/__init__.py index 5a962c3..ef6eadc 100644 --- a/celery_batches/__init__.py +++ b/celery_batches/__init__.py @@ -271,7 +271,7 @@ def task_message_handler( root_id=req.root_id, parent_id=req.parent_id, retries=req.request_dict.get('retries', 0), - eta=req.eta and request.eta.isoformat(), + eta=req.eta and req.eta.isoformat(), expires=req.expires and req.expires.isoformat(), ) @@ -373,7 +373,7 @@ def flush(self, requests: Collection[Request]) -> Any: # Ensure the requests can be serialized using pickle for the prefork pool. serializable_requests = ([SimpleRequest.from_request(r) for r in requests],) - + def on_accepted(pid: int, time_accepted: float) -> None: for req in acks_early: req.acknowledge() @@ -385,12 +385,12 @@ def on_return(result: Optional[Any]) -> None: req.acknowledge() for request in requests: runtime = 0 - if type(result) == int: + if isinstance(result, int): runtime = result - request.send_event('task-succeeded', result=None, runtime=runtime ) + request.send_event('task-succeeded', result=None, runtime=runtime) return self._pool.apply_async( apply_batches_task, (self, serializable_requests, 0, None), accept_callback=on_accepted, callback=on_return, - ) \ No newline at end of file + ) diff --git a/celery_batches/trace.py b/celery_batches/trace.py index 5c54d80..683f052 100644 --- a/celery_batches/trace.py +++ b/celery_batches/trace.py @@ -121,4 +121,4 @@ def apply_batches_task( pop_task() pop_request() - return result \ No newline at end of file + return result diff --git a/t/integration/test_events.py b/t/integration/test_events.py index 44a8ee9..654d3cf 100644 --- a/t/integration/test_events.py +++ b/t/integration/test_events.py @@ -5,11 +5,13 @@ import asyncio import logging + pytest_plugins = ('pytest_asyncio',) logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) + def setup_celery(): app = Celery('myapp') app.conf.update( @@ -23,8 +25,10 @@ def setup_celery(): ) return app + celery_app = setup_celery() + @celery_app.task(base=Batches, flush_every=2, flush_interval=0.1) def add(requests: List[SimpleRequest]) -> int: """ @@ -33,7 +37,10 @@ def add(requests: List[SimpleRequest]) -> int: Marks the result of each task as the sum. """ logger.debug(f"Processing {len(requests)} requests") - result = sum(sum(request.args) + sum(request.kwargs.values()) for request in requests) + result = sum( + sum(request.args) + sum(request.kwargs.values()) + for request in requests + ) for request in requests: celery_app.backend.mark_as_done(request.id, result, request=request) @@ -41,6 +48,7 @@ def add(requests: List[SimpleRequest]) -> int: logger.debug(f"Finished processing. Result: {result}") return result + @pytest.mark.asyncio async def test_tasks_for_add(celery_worker): logger.debug("Starting test_tasks_for_add") diff --git a/t/unit/test_signals.py b/t/unit/test_signals.py index f1e5880..2d60da3 100644 --- a/t/unit/test_signals.py +++ b/t/unit/test_signals.py @@ -127,7 +127,6 @@ def test_all_signals_sent(batch_task, simple_request): send_postrun=DEFAULT, send_success=DEFAULT) as mocks: apply_batches_task(batch_task, ([simple_request],), 0, None) - for mock in mocks.values(): mock.assert_called_once() From 43851f91370c9a16a180200b32261d47a312e0b1 Mon Sep 17 00:00:00 2001 From: Gershon Bialer Date: Fri, 2 Aug 2024 11:49:26 -0700 Subject: [PATCH 08/14] Format stuff properly with linter. --- celery_batches/__init__.py | 13 ++-- celery_batches/trace.py | 49 ++++++------- t/integration/test_events.py | 38 ++++++----- t/unit/test_signals.py | 129 +++++++++++++++++++---------------- 4 files changed, 120 insertions(+), 109 deletions(-) diff --git a/celery_batches/__init__.py b/celery_batches/__init__.py index ef6eadc..bbc7663 100644 --- a/celery_batches/__init__.py +++ b/celery_batches/__init__.py @@ -203,9 +203,7 @@ def Strategy(self, task: "Batches", app: Celery, consumer: Consumer) -> Callable eventer = consumer.event_dispatcher events = eventer and eventer.enabled send_event = eventer and eventer.send - task_sends_events = ( - events and task.send_events - ) + task_sends_events = events and task.send_events Request = symbol_by_name(task.Request) # Celery 5.1 added the app argument to create_request_cls. @@ -263,14 +261,14 @@ def task_message_handler( if task_sends_events: send_event( - 'task-received', + "task-received", uuid=req.id, name=req.name, args=req.argsrepr, kwargs=req.kwargsrepr, root_id=req.root_id, parent_id=req.parent_id, - retries=req.request_dict.get('retries', 0), + retries=req.request_dict.get("retries", 0), eta=req.eta and req.eta.isoformat(), expires=req.expires and req.expires.isoformat(), ) @@ -378,7 +376,7 @@ def on_accepted(pid: int, time_accepted: float) -> None: for req in acks_early: req.acknowledge() for request in requests: - request.send_event('task-started') + request.send_event("task-started") def on_return(result: Optional[Any]) -> None: for req in acks_late: @@ -387,7 +385,8 @@ def on_return(result: Optional[Any]) -> None: runtime = 0 if isinstance(result, int): runtime = result - request.send_event('task-succeeded', result=None, runtime=runtime) + request.send_event("task-succeeded", result=None, runtime=runtime) + return self._pool.apply_async( apply_batches_task, (self, serializable_requests, 0, None), diff --git a/celery_batches/trace.py b/celery_batches/trace.py index 683f052..a9cd1b3 100644 --- a/celery_batches/trace.py +++ b/celery_batches/trace.py @@ -6,7 +6,7 @@ Mimics some of the functionality found in celery.app.trace.trace_task. """ -from typing import TYPE_CHECKING, Any, List, Tuple +from typing import TYPE_CHECKING, Any, List, Tuple, Union from celery import signals, states from celery._state import _task_stack @@ -57,25 +57,19 @@ def apply_batches_task( task_request = Context(loglevel=loglevel, logfile=logfile) push_request(task_request) + result: Union[Any, Exception] + state: str + try: # -*- PRE -*- if prerun_receivers: logger.debug("Debug: Sending prerun signal") - send_prerun( - sender=task, - task_id=task_id, - task=task, - args=args, - kwargs={} - ) + send_prerun(sender=task, task_id=task_id, task=task, args=args, kwargs={}) # -*- TRACE -*- try: result = task.run(*args) - if hasattr(task.request, 'state') and task.request.state == REVOKED: - state = REVOKED - else: - state = SUCCESS + state = REVOKED if hasattr(task.request, "state") and task.request.state == REVOKED else SUCCESS except Exception as exc: result = exc state = FAILURE @@ -88,22 +82,23 @@ def apply_batches_task( exception=exc, args=args, kwargs={}, - einfo=None + einfo=None, ) - else: - if state == REVOKED: - if revoked_receivers: - logger.debug("Debug: Sending revoked signal") - send_revoked( - sender=task, - request=task_request, - terminated=True, - signum=None, - expired=False - ) - elif state == SUCCESS and success_receivers: - logger.debug("Debug: Sending success signal") - send_success(sender=task, result=result) + + # Handle signals based on the state + if state == REVOKED and revoked_receivers: + logger.debug("Debug: Sending revoked signal") + send_revoked( + sender=task, + request=task_request, + terminated=True, + signum=None, + expired=False, + ) + elif state == SUCCESS and success_receivers: + logger.debug("Debug: Sending success signal") + send_success(sender=task, result=result) + finally: try: if postrun_receivers: diff --git a/t/integration/test_events.py b/t/integration/test_events.py index 654d3cf..19c8f2b 100644 --- a/t/integration/test_events.py +++ b/t/integration/test_events.py @@ -1,22 +1,24 @@ -import pytest -from celery import Celery -from celery_batches import Batches, SimpleRequest -from typing import List import asyncio import logging +from typing import List, Any +from celery_batches import Batches, SimpleRequest -pytest_plugins = ('pytest_asyncio',) +from celery import Celery + +import pytest + +pytest_plugins = ("pytest_asyncio",) logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) -def setup_celery(): - app = Celery('myapp') +def setup_celery() -> Celery: + app = Celery("myapp") app.conf.update( - broker_url='memory://', - result_backend='cache+memory://', + broker_url="memory://", + result_backend="cache+memory://", task_always_eager=False, worker_concurrency=1, worker_prefetch_multiplier=1, @@ -37,10 +39,10 @@ def add(requests: List[SimpleRequest]) -> int: Marks the result of each task as the sum. """ logger.debug(f"Processing {len(requests)} requests") - result = sum( - sum(request.args) + sum(request.kwargs.values()) + result = int(sum( + sum(int(arg) for arg in request.args) + sum(int(value) for value in request.kwargs.values()) for request in requests - ) + )) for request in requests: celery_app.backend.mark_as_done(request.id, result, request=request) @@ -50,7 +52,7 @@ def add(requests: List[SimpleRequest]) -> int: @pytest.mark.asyncio -async def test_tasks_for_add(celery_worker): +async def test_tasks_for_add(celery_worker: Any) -> None: logger.debug("Starting test_tasks_for_add") # Send tasks @@ -61,10 +63,12 @@ async def test_tasks_for_add(celery_worker): logger.debug("Waiting for results") try: # Wait for the batch to be processed - results = await asyncio.wait_for(asyncio.gather( - asyncio.to_thread(result_1.get), - asyncio.to_thread(result_2.get) - ), timeout=5.0) + results = await asyncio.wait_for( + asyncio.gather( + asyncio.to_thread(result_1.get), asyncio.to_thread(result_2.get) + ), + timeout=5.0, + ) logger.debug(f"Results: {results}") except asyncio.TimeoutError: logger.error("Test timed out while waiting for results") diff --git a/t/unit/test_signals.py b/t/unit/test_signals.py index 2d60da3..548ca2b 100644 --- a/t/unit/test_signals.py +++ b/t/unit/test_signals.py @@ -1,22 +1,29 @@ +from unittest.mock import DEFAULT, patch +from typing import Any, Dict, List, Optional, Generator + +from celery_batches import Batches, SimpleRequest # type: ignore +from celery_batches.trace import apply_batches_task # type: ignore + +from celery import signals # type: ignore +from celery.utils.collections import AttributeDict # type: ignore + import pytest -from unittest.mock import patch, DEFAULT -from celery import signals -from celery_batches.trace import apply_batches_task -from celery.utils.collections import AttributeDict -class TestBatchTask: +class TestBatchTask(Batches): + _request_stack: List[Any] + _request: AttributeDict @property - def request_stack(self): + def request_stack(self) -> Any: class RequestStack: - def __init__(self, stack): + def __init__(self, stack: List[Any]): self.stack = stack - def push(self, item): + def push(self, item: Any) -> None: self.stack.append(item) - def pop(self): + def pop(self) -> Optional[Any]: if self.stack: return self.stack.pop() return None @@ -24,45 +31,47 @@ def pop(self): return RequestStack(self._request_stack) @property - def request(self): + def request(self) -> AttributeDict: return self._request @request.setter - def request(self, value): + def request(self, value: AttributeDict) -> None: self._request = value - def run(self, requests): - return [request['id'] for request in requests] + def run(self, *args: Any, **kwargs: Any) -> List[str]: + requests = args[0] if args else kwargs.get('requests', []) + result = [request.id for request in requests if hasattr(request, 'id')] + return result # Changed from raise NoReturn to return @pytest.fixture -def batch_task(): +def batch_task() -> TestBatchTask: tb = TestBatchTask() tb._request_stack = [] - tb._request = AttributeDict({'state': None}) + tb._request = AttributeDict({"state": None}) return tb @pytest.fixture -def simple_request(): - return { - 'id': "test_id", - 'name': "test_task", - 'args': (), - 'kwargs': {}, - 'delivery_info': {}, - 'hostname': "test_host", - 'ignore_result': False, - 'reply_to': None, - 'correlation_id': None, - 'request_dict': {}, - } +def simple_request() -> SimpleRequest: + return SimpleRequest( + id="test_id", + name="test_task", + args=(), + kwargs={}, + delivery_info={}, + hostname="test_host", + ignore_result=False, + reply_to=None, + correlation_id=None, + request_dict={}, + ) @pytest.fixture(autouse=True) -def setup_signal_receivers(): - def dummy_receiver(*args, **kwargs): +def setup_signal_receivers() -> Generator[None, None, None]: + def dummy_receiver(*args: Any, **kwargs: Any) -> None: pass signals.task_prerun.connect(dummy_receiver) @@ -80,67 +89,71 @@ def dummy_receiver(*args, **kwargs): signals.task_revoked.disconnect(dummy_receiver) -def test_task_prerun_signal(batch_task, simple_request): - with patch('celery_batches.trace.send_prerun') as mock_send: +def test_task_prerun_signal(batch_task: TestBatchTask, simple_request: SimpleRequest) -> None: + with patch("celery_batches.trace.send_prerun") as mock_send: apply_batches_task(batch_task, ([simple_request],), 0, None) mock_send.assert_called_once() -def test_task_postrun_signal(batch_task, simple_request): - with patch('celery_batches.trace.send_postrun') as mock_send: +def test_task_postrun_signal(batch_task: TestBatchTask, simple_request: SimpleRequest) -> None: + with patch("celery_batches.trace.send_postrun") as mock_send: apply_batches_task(batch_task, ([simple_request],), 0, None) mock_send.assert_called_once() -def test_task_success_signal(batch_task, simple_request): - with patch('celery_batches.trace.send_success') as mock_send: +def test_task_success_signal(batch_task: TestBatchTask, simple_request: SimpleRequest) -> None: + with patch("celery_batches.trace.send_success") as mock_send: apply_batches_task(batch_task, ([simple_request],), 0, None) mock_send.assert_called_once() -def test_task_failure_signal(batch_task, simple_request): - def failing_run(requests): +def test_task_failure_signal(batch_task: TestBatchTask, simple_request: SimpleRequest) -> None: + def failing_run(*args: Any, **kwargs: Any) -> None: raise ValueError("Test exception") - batch_task.run = failing_run + batch_task.run = failing_run # type: ignore - with patch('celery_batches.trace.send_failure') as mock_send: + with patch("celery_batches.trace.send_failure") as mock_send: apply_batches_task(batch_task, ([simple_request],), 0, None) mock_send.assert_called_once() -def test_task_revoked_signal(batch_task, simple_request): - def revoking_run(requests): - batch_task.request.state = 'REVOKED' - return [] +def test_task_revoked_signal(batch_task: TestBatchTask, simple_request: SimpleRequest) -> None: + def revoking_run(*args: Any, **kwargs: Any) -> None: + batch_task.request.state = "REVOKED" + return [] # Changed from raise NoReturn to return - batch_task.run = revoking_run + batch_task.run = revoking_run # type: ignore - with patch('celery_batches.trace.send_revoked') as mock_send: + with patch("celery_batches.trace.send_revoked") as mock_send: apply_batches_task(batch_task, ([simple_request],), 0, None) mock_send.assert_called_once() -def test_all_signals_sent(batch_task, simple_request): - with patch.multiple('celery_batches.trace', - send_prerun=DEFAULT, - send_postrun=DEFAULT, - send_success=DEFAULT) as mocks: +def test_all_signals_sent(batch_task: TestBatchTask, simple_request: SimpleRequest) -> None: + with patch.multiple( + "celery_batches.trace", + send_prerun=DEFAULT, + send_postrun=DEFAULT, + send_success=DEFAULT, + ) as mocks: apply_batches_task(batch_task, ([simple_request],), 0, None) for mock in mocks.values(): mock.assert_called_once() -def test_failure_signals_sent(batch_task, simple_request): - def failing_run(requests): +def test_failure_signals_sent(batch_task: TestBatchTask, simple_request: SimpleRequest) -> None: + def failing_run(*args: Any, **kwargs: Any) -> None: raise ValueError("Test exception") - batch_task.run = failing_run + batch_task.run = failing_run # type: ignore - with patch.multiple('celery_batches.trace', - send_prerun=DEFAULT, - send_postrun=DEFAULT, - send_failure=DEFAULT) as mocks: + with patch.multiple( + "celery_batches.trace", + send_prerun=DEFAULT, + send_postrun=DEFAULT, + send_failure=DEFAULT, + ) as mocks: apply_batches_task(batch_task, ([simple_request],), 0, None) for mock in mocks.values(): From b5ae1d9c2daf1ecae44147b5ec2ab28dfe1d8adb Mon Sep 17 00:00:00 2001 From: Gershon Bialer Date: Mon, 5 Aug 2024 12:44:58 -0700 Subject: [PATCH 09/14] Fix lintin issues. --- celery_batches/trace.py | 3 ++- setup.cfg | 5 +++++ t/integration/test_events.py | 3 ++- t/unit/test_signals.py | 23 +++++++++++++++-------- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/celery_batches/trace.py b/celery_batches/trace.py index a9cd1b3..007d8f8 100644 --- a/celery_batches/trace.py +++ b/celery_batches/trace.py @@ -69,7 +69,8 @@ def apply_batches_task( # -*- TRACE -*- try: result = task.run(*args) - state = REVOKED if hasattr(task.request, "state") and task.request.state == REVOKED else SUCCESS + state = REVOKED if (hasattr(task.request, "state") + and task.request.state == REVOKED) else SUCCESS except Exception as exc: result = exc state = FAILURE diff --git a/setup.cfg b/setup.cfg index 09da6f7..dabf3af 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,6 +39,11 @@ packages = install_requires = celery>=5.0,<5.5 python_requires = >=3.8 +[options.extras_require] +test = + pytest + pytest-asyncio + [flake8] extend-ignore = E203 max-line-length = 88 diff --git a/t/integration/test_events.py b/t/integration/test_events.py index 19c8f2b..739510e 100644 --- a/t/integration/test_events.py +++ b/t/integration/test_events.py @@ -40,7 +40,8 @@ def add(requests: List[SimpleRequest]) -> int: """ logger.debug(f"Processing {len(requests)} requests") result = int(sum( - sum(int(arg) for arg in request.args) + sum(int(value) for value in request.kwargs.values()) + sum(int(arg) for arg in request.args) + + sum(int(value) for value in request.kwargs.values()) for request in requests )) diff --git a/t/unit/test_signals.py b/t/unit/test_signals.py index 548ca2b..82f72de 100644 --- a/t/unit/test_signals.py +++ b/t/unit/test_signals.py @@ -1,5 +1,5 @@ from unittest.mock import DEFAULT, patch -from typing import Any, Dict, List, Optional, Generator +from typing import Any, List, Optional, Generator from celery_batches import Batches, SimpleRequest # type: ignore from celery_batches.trace import apply_batches_task # type: ignore @@ -89,25 +89,29 @@ def dummy_receiver(*args: Any, **kwargs: Any) -> None: signals.task_revoked.disconnect(dummy_receiver) -def test_task_prerun_signal(batch_task: TestBatchTask, simple_request: SimpleRequest) -> None: +def test_task_prerun_signal(batch_task: TestBatchTask, + simple_request: SimpleRequest) -> None: with patch("celery_batches.trace.send_prerun") as mock_send: apply_batches_task(batch_task, ([simple_request],), 0, None) mock_send.assert_called_once() -def test_task_postrun_signal(batch_task: TestBatchTask, simple_request: SimpleRequest) -> None: +def test_task_postrun_signal(batch_task: TestBatchTask, + simple_request: SimpleRequest) -> None: with patch("celery_batches.trace.send_postrun") as mock_send: apply_batches_task(batch_task, ([simple_request],), 0, None) mock_send.assert_called_once() -def test_task_success_signal(batch_task: TestBatchTask, simple_request: SimpleRequest) -> None: +def test_task_success_signal(batch_task: TestBatchTask, + simple_request: SimpleRequest) -> None: with patch("celery_batches.trace.send_success") as mock_send: apply_batches_task(batch_task, ([simple_request],), 0, None) mock_send.assert_called_once() -def test_task_failure_signal(batch_task: TestBatchTask, simple_request: SimpleRequest) -> None: +def test_task_failure_signal(batch_task: TestBatchTask, + simple_request: SimpleRequest) -> None: def failing_run(*args: Any, **kwargs: Any) -> None: raise ValueError("Test exception") @@ -118,7 +122,8 @@ def failing_run(*args: Any, **kwargs: Any) -> None: mock_send.assert_called_once() -def test_task_revoked_signal(batch_task: TestBatchTask, simple_request: SimpleRequest) -> None: +def test_task_revoked_signal(batch_task: TestBatchTask, + simple_request: SimpleRequest) -> None: def revoking_run(*args: Any, **kwargs: Any) -> None: batch_task.request.state = "REVOKED" return [] # Changed from raise NoReturn to return @@ -130,7 +135,8 @@ def revoking_run(*args: Any, **kwargs: Any) -> None: mock_send.assert_called_once() -def test_all_signals_sent(batch_task: TestBatchTask, simple_request: SimpleRequest) -> None: +def test_all_signals_sent(batch_task: TestBatchTask, + simple_request: SimpleRequest) -> None: with patch.multiple( "celery_batches.trace", send_prerun=DEFAULT, @@ -142,7 +148,8 @@ def test_all_signals_sent(batch_task: TestBatchTask, simple_request: SimpleReque mock.assert_called_once() -def test_failure_signals_sent(batch_task: TestBatchTask, simple_request: SimpleRequest) -> None: +def test_failure_signals_sent(batch_task: TestBatchTask, + simple_request: SimpleRequest) -> None: def failing_run(*args: Any, **kwargs: Any) -> None: raise ValueError("Test exception") From 722a5c7ee50f25a2d66e3b3700bd9bc12d1bdcdd Mon Sep 17 00:00:00 2001 From: Gershon Bialer Date: Mon, 5 Aug 2024 12:48:17 -0700 Subject: [PATCH 10/14] Fix sort order --- t/integration/test_events.py | 2 +- t/unit/test_signals.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/t/integration/test_events.py b/t/integration/test_events.py index 739510e..acddb5b 100644 --- a/t/integration/test_events.py +++ b/t/integration/test_events.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import List, Any +from typing import Any, List from celery_batches import Batches, SimpleRequest diff --git a/t/unit/test_signals.py b/t/unit/test_signals.py index 82f72de..a48a6bc 100644 --- a/t/unit/test_signals.py +++ b/t/unit/test_signals.py @@ -1,11 +1,11 @@ +from typing import Any, Generator, List, Optional from unittest.mock import DEFAULT, patch -from typing import Any, List, Optional, Generator -from celery_batches import Batches, SimpleRequest # type: ignore -from celery_batches.trace import apply_batches_task # type: ignore +from celery_batches import Batches, SimpleRequest +from celery_batches.trace import apply_batches_task -from celery import signals # type: ignore -from celery.utils.collections import AttributeDict # type: ignore +from celery import signals +from celery.utils.collections import AttributeDict import pytest From 9e51aa6cb6d82a8ea15136c8454292b7b602acb4 Mon Sep 17 00:00:00 2001 From: Gershon Bialer Date: Mon, 5 Aug 2024 13:58:50 -0700 Subject: [PATCH 11/14] Fix black liniting issues. --- celery_batches/trace.py | 7 +++++-- t/integration/test_events.py | 12 ++++++----- t/unit/test_signals.py | 39 +++++++++++++++++++++--------------- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/celery_batches/trace.py b/celery_batches/trace.py index 007d8f8..8f6d388 100644 --- a/celery_batches/trace.py +++ b/celery_batches/trace.py @@ -69,8 +69,11 @@ def apply_batches_task( # -*- TRACE -*- try: result = task.run(*args) - state = REVOKED if (hasattr(task.request, "state") - and task.request.state == REVOKED) else SUCCESS + state = ( + REVOKED + if (hasattr(task.request, "state") and task.request.state == REVOKED) + else SUCCESS + ) except Exception as exc: result = exc state = FAILURE diff --git a/t/integration/test_events.py b/t/integration/test_events.py index acddb5b..d3a86ff 100644 --- a/t/integration/test_events.py +++ b/t/integration/test_events.py @@ -39,11 +39,13 @@ def add(requests: List[SimpleRequest]) -> int: Marks the result of each task as the sum. """ logger.debug(f"Processing {len(requests)} requests") - result = int(sum( - sum(int(arg) for arg in request.args) - + sum(int(value) for value in request.kwargs.values()) - for request in requests - )) + result = int( + sum( + sum(int(arg) for arg in request.args) + + sum(int(value) for value in request.kwargs.values()) + for request in requests + ) + ) for request in requests: celery_app.backend.mark_as_done(request.id, result, request=request) diff --git a/t/unit/test_signals.py b/t/unit/test_signals.py index a48a6bc..d6135f8 100644 --- a/t/unit/test_signals.py +++ b/t/unit/test_signals.py @@ -39,8 +39,8 @@ def request(self, value: AttributeDict) -> None: self._request = value def run(self, *args: Any, **kwargs: Any) -> List[str]: - requests = args[0] if args else kwargs.get('requests', []) - result = [request.id for request in requests if hasattr(request, 'id')] + requests = args[0] if args else kwargs.get("requests", []) + result = [request.id for request in requests if hasattr(request, "id")] return result # Changed from raise NoReturn to return @@ -89,29 +89,33 @@ def dummy_receiver(*args: Any, **kwargs: Any) -> None: signals.task_revoked.disconnect(dummy_receiver) -def test_task_prerun_signal(batch_task: TestBatchTask, - simple_request: SimpleRequest) -> None: +def test_task_prerun_signal( + batch_task: TestBatchTask, simple_request: SimpleRequest +) -> None: with patch("celery_batches.trace.send_prerun") as mock_send: apply_batches_task(batch_task, ([simple_request],), 0, None) mock_send.assert_called_once() -def test_task_postrun_signal(batch_task: TestBatchTask, - simple_request: SimpleRequest) -> None: +def test_task_postrun_signal( + batch_task: TestBatchTask, simple_request: SimpleRequest +) -> None: with patch("celery_batches.trace.send_postrun") as mock_send: apply_batches_task(batch_task, ([simple_request],), 0, None) mock_send.assert_called_once() -def test_task_success_signal(batch_task: TestBatchTask, - simple_request: SimpleRequest) -> None: +def test_task_success_signal( + batch_task: TestBatchTask, simple_request: SimpleRequest +) -> None: with patch("celery_batches.trace.send_success") as mock_send: apply_batches_task(batch_task, ([simple_request],), 0, None) mock_send.assert_called_once() -def test_task_failure_signal(batch_task: TestBatchTask, - simple_request: SimpleRequest) -> None: +def test_task_failure_signal( + batch_task: TestBatchTask, simple_request: SimpleRequest +) -> None: def failing_run(*args: Any, **kwargs: Any) -> None: raise ValueError("Test exception") @@ -122,8 +126,9 @@ def failing_run(*args: Any, **kwargs: Any) -> None: mock_send.assert_called_once() -def test_task_revoked_signal(batch_task: TestBatchTask, - simple_request: SimpleRequest) -> None: +def test_task_revoked_signal( + batch_task: TestBatchTask, simple_request: SimpleRequest +) -> None: def revoking_run(*args: Any, **kwargs: Any) -> None: batch_task.request.state = "REVOKED" return [] # Changed from raise NoReturn to return @@ -135,8 +140,9 @@ def revoking_run(*args: Any, **kwargs: Any) -> None: mock_send.assert_called_once() -def test_all_signals_sent(batch_task: TestBatchTask, - simple_request: SimpleRequest) -> None: +def test_all_signals_sent( + batch_task: TestBatchTask, simple_request: SimpleRequest +) -> None: with patch.multiple( "celery_batches.trace", send_prerun=DEFAULT, @@ -148,8 +154,9 @@ def test_all_signals_sent(batch_task: TestBatchTask, mock.assert_called_once() -def test_failure_signals_sent(batch_task: TestBatchTask, - simple_request: SimpleRequest) -> None: +def test_failure_signals_sent( + batch_task: TestBatchTask, simple_request: SimpleRequest +) -> None: def failing_run(*args: Any, **kwargs: Any) -> None: raise ValueError("Test exception") From 67e07bc51c0e3c21443a14f11da284f4c7dedecd Mon Sep 17 00:00:00 2001 From: Gershon Bialer Date: Tue, 6 Aug 2024 10:00:02 -0700 Subject: [PATCH 12/14] Remove unused imports. --- celery_batches/__init__.py | 3 +-- celery_batches/trace.py | 2 +- t/unit/test_signals.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/celery_batches/__init__.py b/celery_batches/__init__.py index bbc7663..184d6f5 100644 --- a/celery_batches/__init__.py +++ b/celery_batches/__init__.py @@ -7,7 +7,6 @@ Collection, Dict, Iterable, - NoReturn, Optional, Set, Tuple, @@ -185,7 +184,7 @@ def __init__(self) -> None: self._tref: Optional[Timer] = None self._pool: BasePool = None - def run(self, *args: Any, **kwargs: Any) -> NoReturn: + def run(self, *args: Any, **kwargs: Any) -> list[str]: raise NotImplementedError("must implement run(requests)") def Strategy(self, task: "Batches", app: Celery, consumer: Consumer) -> Callable: diff --git a/celery_batches/trace.py b/celery_batches/trace.py index 8f6d388..370bdc7 100644 --- a/celery_batches/trace.py +++ b/celery_batches/trace.py @@ -68,7 +68,7 @@ def apply_batches_task( # -*- TRACE -*- try: - result = task.run(*args) + result = task(*args) state = ( REVOKED if (hasattr(task.request, "state") and task.request.state == REVOKED) diff --git a/t/unit/test_signals.py b/t/unit/test_signals.py index d6135f8..edbba24 100644 --- a/t/unit/test_signals.py +++ b/t/unit/test_signals.py @@ -129,7 +129,7 @@ def failing_run(*args: Any, **kwargs: Any) -> None: def test_task_revoked_signal( batch_task: TestBatchTask, simple_request: SimpleRequest ) -> None: - def revoking_run(*args: Any, **kwargs: Any) -> None: + def revoking_run(*args: Any, **kwargs: Any) -> List: batch_task.request.state = "REVOKED" return [] # Changed from raise NoReturn to return From 15ac37e0cee4481761470f389de034579a5444ef Mon Sep 17 00:00:00 2001 From: Gershon Bialer Date: Tue, 6 Aug 2024 10:21:55 -0700 Subject: [PATCH 13/14] Change return type for run to any. --- celery_batches/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/celery_batches/__init__.py b/celery_batches/__init__.py index 184d6f5..405128c 100644 --- a/celery_batches/__init__.py +++ b/celery_batches/__init__.py @@ -184,7 +184,7 @@ def __init__(self) -> None: self._tref: Optional[Timer] = None self._pool: BasePool = None - def run(self, *args: Any, **kwargs: Any) -> list[str]: + def run(self, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError("must implement run(requests)") def Strategy(self, task: "Batches", app: Celery, consumer: Consumer) -> Callable: From f5ff15d8aa762c2539da76911862c63d7a3d57e7 Mon Sep 17 00:00:00 2001 From: Gershon Bialer Date: Tue, 6 Aug 2024 10:48:07 -0700 Subject: [PATCH 14/14] Fix test to work with older version of Python. --- t/integration/test_events.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/t/integration/test_events.py b/t/integration/test_events.py index d3a86ff..1d6977c 100644 --- a/t/integration/test_events.py +++ b/t/integration/test_events.py @@ -68,7 +68,8 @@ async def test_tasks_for_add(celery_worker: Any) -> None: # Wait for the batch to be processed results = await asyncio.wait_for( asyncio.gather( - asyncio.to_thread(result_1.get), asyncio.to_thread(result_2.get) + asyncio.get_event_loop().run_in_executor(None, result_1.get), + asyncio.get_event_loop().run_in_executor(None, result_2.get), ), timeout=5.0, )