diff --git a/taskbadger/celery.py b/taskbadger/celery.py index 98f4a1a..dbf8072 100644 --- a/taskbadger/celery.py +++ b/taskbadger/celery.py @@ -1,5 +1,6 @@ import collections import functools +import json import logging import celery @@ -10,6 +11,7 @@ task_retry, task_success, ) +from kombu import serialization from .internal.models import StatusEnum from .mug import Badger @@ -18,7 +20,7 @@ KWARG_PREFIX = "taskbadger_" TB_KWARGS_ARG = f"{KWARG_PREFIX}kwargs" -IGNORE_ARGS = {TB_KWARGS_ARG, f"{KWARG_PREFIX}task", f"{KWARG_PREFIX}task_id"} +IGNORE_ARGS = {TB_KWARGS_ARG, f"{KWARG_PREFIX}task", f"{KWARG_PREFIX}task_id", f"{KWARG_PREFIX}record_task_args"} TB_TASK_ID = f"{KWARG_PREFIX}task_id" TERMINAL_STATES = { @@ -124,6 +126,8 @@ def apply_async(self, *args, **kwargs): if Badger.is_configured(): headers["taskbadger_track"] = True headers[TB_KWARGS_ARG] = tb_kwargs + if "record_task_args" in tb_kwargs: + headers["taskbadger_record_task_args"] = tb_kwargs.pop("record_task_args") result = super().apply_async(*args, **kwargs) @@ -187,6 +191,20 @@ def task_publish_handler(sender=None, headers=None, body=None, **kwargs): kwargs["status"] = StatusEnum.PENDING name = kwargs.pop("name", headers["task"]) + global_record_task_args = celery_system and celery_system.record_task_args + if headers.get("taskbadger_record_task_args", global_record_task_args): + data = { + "celery_task_args": body[0], + "celery_task_kwargs": body[1], + } + try: + _, _, value = serialization.dumps(data, serializer="json") + data = json.loads(value) + except Exception: + log.error("Error serializing task arguments for task '%s'", name) + else: + kwargs.setdefault("data", {}).update(data) + task = create_task_safe(name, **kwargs) if task: meta = {TB_TASK_ID: task.id} diff --git a/taskbadger/systems/celery.py b/taskbadger/systems/celery.py index 0c431a6..0cb06c9 100644 --- a/taskbadger/systems/celery.py +++ b/taskbadger/systems/celery.py @@ -6,7 +6,7 @@ class CelerySystemIntegration(System): identifier = "celery" - def __init__(self, auto_track_tasks=True, includes=None, excludes=None): + def __init__(self, auto_track_tasks=True, includes=None, excludes=None, record_task_args=False): """ Args: auto_track_tasks: Automatically track all Celery tasks regardless of whether they are using the @@ -16,10 +16,12 @@ def __init__(self, auto_track_tasks=True, includes=None, excludes=None): matches both an include and an exclude, it will be excluded. excludes: A list of task names to exclude from tracking. As with `includes`, these can be either the full task name or a regular expression. Exclusions take precedence over inclusions. + record_task_args: Record the arguments passed to each task. """ self.auto_track_tasks = auto_track_tasks self.includes = includes self.excludes = excludes + self.record_task_args = record_task_args if auto_track_tasks: # Importing this here ensures that the Celery signal handlers are registered diff --git a/tests/test_celery.py b/tests/test_celery.py index 16a1e78..cdea029 100644 --- a/tests/test_celery.py +++ b/tests/test_celery.py @@ -13,6 +13,7 @@ import celery import pytest +from kombu.utils.json import register_type from taskbadger import Action, EmailIntegration, StatusEnum from taskbadger.celery import Task @@ -111,6 +112,109 @@ def add_with_task_args(self, a, b): create.assert_called_once_with("new_name", value_max=10, actions=actions, status=StatusEnum.PENDING) +def test_celery_record_args(celery_session_app, celery_session_worker, bind_settings): + @celery_session_app.task(bind=True, base=Task) + def add_with_task_args(self, a, b): + assert self.taskbadger_task is not None + return a + b + + celery_session_worker.reload() + + with ( + mock.patch("taskbadger.celery.create_task_safe") as create, + mock.patch("taskbadger.celery.update_task_safe"), + mock.patch("taskbadger.celery.get_task"), + ): + create.return_value = task_for_test() + + result = add_with_task_args.apply_async( + (2, 2), + taskbadger_name="new_name", + taskbadger_value_max=10, + taskbadger_kwargs={"data": {"foo": "bar"}}, + taskbadger_record_task_args=True, + ) + assert result.get(timeout=10, propagate=True) == 4 + + create.assert_called_once_with( + "new_name", + value_max=10, + data={"foo": "bar", "celery_task_args": [2, 2], "celery_task_kwargs": {}}, + status=StatusEnum.PENDING, + ) + + +def test_celery_record_task_kwargs(celery_session_app, celery_session_worker, bind_settings): + @celery_session_app.task(bind=True, base=Task) + def add_with_task_kwargs(self, a, b, c=0): + assert self.taskbadger_task is not None + return a + b + c + + celery_session_worker.reload() + + with ( + mock.patch("taskbadger.celery.create_task_safe") as create, + mock.patch("taskbadger.celery.update_task_safe"), + mock.patch("taskbadger.celery.get_task"), + ): + create.return_value = task_for_test() + + actions = [Action("stale", integration=EmailIntegration(to="test@test.com"))] + result = add_with_task_kwargs.delay( + 2, + 2, + c=3, + taskbadger_name="new_name", + taskbadger_value_max=10, + taskbadger_kwargs={"actions": actions}, + taskbadger_record_task_args=True, + ) + assert result.get(timeout=10, propagate=True) == 7 + + create.assert_called_once_with( + "new_name", + value_max=10, + data={"celery_task_args": [2, 2], "celery_task_kwargs": {"c": 3}}, + actions=actions, + status=StatusEnum.PENDING, + ) + + +def test_celery_record_task_args_custom_serialization(celery_session_app, celery_session_worker, bind_settings): + class A: + def __init__(self, a, b): + self.a = a + self.b = b + + register_type(A, "A", lambda o: [o.a, o.b], lambda o: A(*o)) + + @celery_session_app.task(bind=True, base=Task) + def add_task_custom_serialization(self, a): + assert self.taskbadger_task is not None + return a.a + a.b + + celery_session_worker.reload() + + with ( + mock.patch("taskbadger.celery.create_task_safe") as create, + mock.patch("taskbadger.celery.update_task_safe"), + mock.patch("taskbadger.celery.get_task"), + ): + create.return_value = task_for_test() + + result = add_task_custom_serialization.delay( + A(2, 2), + taskbadger_record_task_args=True, + ) + assert result.get(timeout=10, propagate=True) == 4 + + create.assert_called_once_with( + "tests.test_celery.add_task_custom_serialization", + data={"celery_task_args": [{"__type__": "A", "__value__": [2, 2]}], "celery_task_kwargs": {}}, + status=StatusEnum.PENDING, + ) + + def test_celery_task_with_args_in_decorator(celery_session_app, celery_session_worker, bind_settings): @celery_session_app.task( bind=True, diff --git a/tests/test_celery_system_integration.py b/tests/test_celery_system_integration.py index 0ccd8c4..cd0829a 100644 --- a/tests/test_celery_system_integration.py +++ b/tests/test_celery_system_integration.py @@ -16,6 +16,8 @@ import pytest from celery.signals import task_prerun +from taskbadger import StatusEnum +from taskbadger.celery import Task from taskbadger.mug import Badger, Settings from taskbadger.systems.celery import CelerySystemIntegration from tests.utils import task_for_test @@ -73,6 +75,76 @@ def add_normal(self, a, b): assert Badger.current.session().client is None +@pytest.mark.usefixtures("_bind_settings_with_system") +def test_celery_record_task_args(celery_session_app, celery_session_worker): + @celery_session_app.task(bind=True) + def add_normal(self, a, b): + assert self.request.get("taskbadger_task_id") is not None, "missing task in request" + assert not hasattr(self, "taskbadger_task") + assert Badger.current.session().client is not None, "missing client" + return a + b + + celery_session_worker.reload() + + celery_system = Badger.current.settings.get_system_by_id("celery") + celery_system.record_task_args = True + + with ( + mock.patch("taskbadger.celery.create_task_safe") as create, + mock.patch("taskbadger.celery.update_task_safe") as update, + mock.patch("taskbadger.celery.get_task") as get_task, + ): + tb_task = task_for_test() + create.return_value = tb_task + result = add_normal.delay(2, 2) + assert result.info.get("taskbadger_task_id") == tb_task.id + assert result.get(timeout=10, propagate=True) == 4 + + create.assert_called_once_with( + "tests.test_celery_system_integration.add_normal", + status=StatusEnum.PENDING, + data={"celery_task_args": [2, 2], "celery_task_kwargs": {}}, + ) + assert get_task.call_count == 1 + assert update.call_count == 2 + assert Badger.current.session().client is None + + +@pytest.mark.usefixtures("_bind_settings_with_system") +def test_celery_record_task_args_local_override(celery_session_app, celery_session_worker): + """Test that passing `taskbadger_record_task_args` overrides the integration value""" + + @celery_session_app.task(bind=True, base=Task) + def add_normal_with_override(self, a, b): + assert self.request.get("taskbadger_task_id") is not None, "missing task in request" + assert hasattr(self, "taskbadger_task") + assert Badger.current.session().client is not None, "missing client" + return a + b + + celery_session_worker.reload() + + celery_system = Badger.current.settings.get_system_by_id("celery") + celery_system.record_task_args = True + + with ( + mock.patch("taskbadger.celery.create_task_safe") as create, + mock.patch("taskbadger.celery.update_task_safe") as update, + mock.patch("taskbadger.celery.get_task") as get_task, + ): + tb_task = task_for_test() + create.return_value = tb_task + result = add_normal_with_override.delay(2, 2, taskbadger_record_task_args=False) + assert result.info.get("taskbadger_task_id") == tb_task.id + assert result.get(timeout=10, propagate=True) == 4 + + create.assert_called_once_with( + "tests.test_celery_system_integration.add_normal_with_override", status=StatusEnum.PENDING + ) + assert get_task.call_count == 1 + assert update.call_count == 2 + assert Badger.current.session().client is None + + @pytest.mark.parametrize( ("include", "exclude", "expected"), [