taskbadger/celery.py (20 changes: 19 additions & 1 deletion)
@@ -1,5 +1,6 @@
import collections
import functools
import json
import logging

import celery
@@ -10,6 +11,7 @@
task_retry,
task_success,
)
from kombu import serialization

from .internal.models import StatusEnum
from .mug import Badger
@@ -18,7 +20,7 @@

KWARG_PREFIX = "taskbadger_"
TB_KWARGS_ARG = f"{KWARG_PREFIX}kwargs"
IGNORE_ARGS = {TB_KWARGS_ARG, f"{KWARG_PREFIX}task", f"{KWARG_PREFIX}task_id"}
IGNORE_ARGS = {TB_KWARGS_ARG, f"{KWARG_PREFIX}task", f"{KWARG_PREFIX}task_id", f"{KWARG_PREFIX}record_task_args"}
TB_TASK_ID = f"{KWARG_PREFIX}task_id"

TERMINAL_STATES = {
@@ -124,6 +126,8 @@ def apply_async(self, *args, **kwargs):
if Badger.is_configured():
headers["taskbadger_track"] = True
headers[TB_KWARGS_ARG] = tb_kwargs
if "record_task_args" in tb_kwargs:
headers["taskbadger_record_task_args"] = tb_kwargs.pop("record_task_args")

result = super().apply_async(*args, **kwargs)

@@ -187,6 +191,20 @@ def task_publish_handler(sender=None, headers=None, body=None, **kwargs):
kwargs["status"] = StatusEnum.PENDING
name = kwargs.pop("name", headers["task"])

global_record_task_args = celery_system and celery_system.record_task_args
if headers.get("taskbadger_record_task_args", global_record_task_args):
data = {
"celery_task_args": body[0],
"celery_task_kwargs": body[1],
}
try:
_, _, value = serialization.dumps(data, serializer="json")
data = json.loads(value)
except Exception:
log.error("Error serializing task arguments for task '%s'", name)
else:
kwargs.setdefault("data", {}).update(data)

task = create_task_safe(name, **kwargs)
if task:
meta = {TB_TASK_ID: task.id}
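For context, a usage sketch of the new per-call flag. This is not part of the diff: the Celery app, broker URL, and task name below are placeholders. `Task.apply_async` pops `record_task_args` from the collected taskbadger kwargs and forwards it as the `taskbadger_record_task_args` message header, which `task_publish_handler` then uses to attach the call's args and kwargs to the task's `data`.

    from celery import Celery
    from taskbadger.celery import Task

    app = Celery("tasks", broker="memory://")  # placeholder app and broker

    @app.task(bind=True, base=Task)
    def add(self, a, b):
        return a + b

    # Record this call's positional and keyword arguments on the Task Badger task.
    add.apply_async((2, 2), taskbadger_record_task_args=True)
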
taskbadger/systems/celery.py (4 changes: 3 additions & 1 deletion)
@@ -6,7 +6,7 @@
class CelerySystemIntegration(System):
identifier = "celery"

def __init__(self, auto_track_tasks=True, includes=None, excludes=None):
def __init__(self, auto_track_tasks=True, includes=None, excludes=None, record_task_args=False):
"""
Args:
auto_track_tasks: Automatically track all Celery tasks regardless of whether they are using the
@@ -16,10 +16,12 @@ def __init__(self, auto_track_tasks=True, includes=None, excludes=None):
matches both an include and an exclude, it will be excluded.
excludes: A list of task names to exclude from tracking. As with `includes`, these can be either
the full task name or a regular expression. Exclusions take precedence over inclusions.
record_task_args: Record the arguments passed to each task.
"""
self.auto_track_tasks = auto_track_tasks
self.includes = includes
self.excludes = excludes
self.record_task_args = record_task_args

if auto_track_tasks:
# Importing this here ensures that the Celery signal handlers are registered
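The flag can also be enabled globally through the integration. A minimal configuration sketch, assuming the library's usual `taskbadger.init(...)` entry point accepts a `systems` list; the slugs and token are placeholders:

    import taskbadger
    from taskbadger.systems.celery import CelerySystemIntegration

    taskbadger.init(
        organization_slug="my-org",    # placeholder
        project_slug="my-project",     # placeholder
        token="***",                   # placeholder
        systems=[CelerySystemIntegration(record_task_args=True)],
    )
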
tests/test_celery.py (104 changes: 104 additions & 0 deletions)
@@ -13,6 +13,7 @@

import celery
import pytest
from kombu.utils.json import register_type

from taskbadger import Action, EmailIntegration, StatusEnum
from taskbadger.celery import Task
@@ -111,6 +112,109 @@ def add_with_task_args(self, a, b):
create.assert_called_once_with("new_name", value_max=10, actions=actions, status=StatusEnum.PENDING)


def test_celery_record_args(celery_session_app, celery_session_worker, bind_settings):
@celery_session_app.task(bind=True, base=Task)
def add_with_task_args(self, a, b):
assert self.taskbadger_task is not None
return a + b

celery_session_worker.reload()

with (
mock.patch("taskbadger.celery.create_task_safe") as create,
mock.patch("taskbadger.celery.update_task_safe"),
mock.patch("taskbadger.celery.get_task"),
):
create.return_value = task_for_test()

result = add_with_task_args.apply_async(
(2, 2),
taskbadger_name="new_name",
taskbadger_value_max=10,
taskbadger_kwargs={"data": {"foo": "bar"}},
taskbadger_record_task_args=True,
)
assert result.get(timeout=10, propagate=True) == 4

create.assert_called_once_with(
"new_name",
value_max=10,
data={"foo": "bar", "celery_task_args": [2, 2], "celery_task_kwargs": {}},
status=StatusEnum.PENDING,
)


def test_celery_record_task_kwargs(celery_session_app, celery_session_worker, bind_settings):
@celery_session_app.task(bind=True, base=Task)
def add_with_task_kwargs(self, a, b, c=0):
assert self.taskbadger_task is not None
return a + b + c

celery_session_worker.reload()

with (
mock.patch("taskbadger.celery.create_task_safe") as create,
mock.patch("taskbadger.celery.update_task_safe"),
mock.patch("taskbadger.celery.get_task"),
):
create.return_value = task_for_test()

actions = [Action("stale", integration=EmailIntegration(to="[email protected]"))]
result = add_with_task_kwargs.delay(
2,
2,
c=3,
taskbadger_name="new_name",
taskbadger_value_max=10,
taskbadger_kwargs={"actions": actions},
taskbadger_record_task_args=True,
)
assert result.get(timeout=10, propagate=True) == 7

create.assert_called_once_with(
"new_name",
value_max=10,
data={"celery_task_args": [2, 2], "celery_task_kwargs": {"c": 3}},
actions=actions,
status=StatusEnum.PENDING,
)


def test_celery_record_task_args_custom_serialization(celery_session_app, celery_session_worker, bind_settings):
class A:
def __init__(self, a, b):
self.a = a
self.b = b

register_type(A, "A", lambda o: [o.a, o.b], lambda o: A(*o))

@celery_session_app.task(bind=True, base=Task)
def add_task_custom_serialization(self, a):
assert self.taskbadger_task is not None
return a.a + a.b

celery_session_worker.reload()

with (
mock.patch("taskbadger.celery.create_task_safe") as create,
mock.patch("taskbadger.celery.update_task_safe"),
mock.patch("taskbadger.celery.get_task"),
):
create.return_value = task_for_test()

result = add_task_custom_serialization.delay(
A(2, 2),
taskbadger_record_task_args=True,
)
assert result.get(timeout=10, propagate=True) == 4

create.assert_called_once_with(
"tests.test_celery.add_task_custom_serialization",
data={"celery_task_args": [{"__type__": "A", "__value__": [2, 2]}], "celery_task_kwargs": {}},
status=StatusEnum.PENDING,
)
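
The custom-serialization test relies on kombu round-tripping registered types into plain JSON, which is what the new handler code does with `serialization.dumps` followed by `json.loads`. A standalone sketch of that round-trip (outside the test suite):

    import json

    from kombu import serialization
    from kombu.utils.json import register_type

    class A:
        def __init__(self, a, b):
            self.a, self.b = a, b

    # Teach kombu's JSON serializer how to encode and decode `A`.
    register_type(A, "A", lambda o: [o.a, o.b], lambda o: A(*o))

    payload = {"celery_task_args": [A(2, 2)], "celery_task_kwargs": {}}
    _, _, value = serialization.dumps(payload, serializer="json")
    print(json.loads(value))
    # {'celery_task_args': [{'__type__': 'A', '__value__': [2, 2]}], 'celery_task_kwargs': {}}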


def test_celery_task_with_args_in_decorator(celery_session_app, celery_session_worker, bind_settings):
@celery_session_app.task(
bind=True,
tests/test_celery_system_integration.py (72 changes: 72 additions & 0 deletions)
@@ -16,6 +16,8 @@
import pytest
from celery.signals import task_prerun

from taskbadger import StatusEnum
from taskbadger.celery import Task
from taskbadger.mug import Badger, Settings
from taskbadger.systems.celery import CelerySystemIntegration
from tests.utils import task_for_test
@@ -73,6 +75,76 @@ def add_normal(self, a, b):
assert Badger.current.session().client is None


@pytest.mark.usefixtures("_bind_settings_with_system")
def test_celery_record_task_args(celery_session_app, celery_session_worker):
@celery_session_app.task(bind=True)
def add_normal(self, a, b):
assert self.request.get("taskbadger_task_id") is not None, "missing task in request"
assert not hasattr(self, "taskbadger_task")
assert Badger.current.session().client is not None, "missing client"
return a + b

celery_session_worker.reload()

celery_system = Badger.current.settings.get_system_by_id("celery")
celery_system.record_task_args = True

with (
mock.patch("taskbadger.celery.create_task_safe") as create,
mock.patch("taskbadger.celery.update_task_safe") as update,
mock.patch("taskbadger.celery.get_task") as get_task,
):
tb_task = task_for_test()
create.return_value = tb_task
result = add_normal.delay(2, 2)
assert result.info.get("taskbadger_task_id") == tb_task.id
assert result.get(timeout=10, propagate=True) == 4

create.assert_called_once_with(
"tests.test_celery_system_integration.add_normal",
status=StatusEnum.PENDING,
data={"celery_task_args": [2, 2], "celery_task_kwargs": {}},
)
assert get_task.call_count == 1
assert update.call_count == 2
assert Badger.current.session().client is None


@pytest.mark.usefixtures("_bind_settings_with_system")
def test_celery_record_task_args_local_override(celery_session_app, celery_session_worker):
"""Test that passing `taskbadger_record_task_args` overrides the integration value"""

@celery_session_app.task(bind=True, base=Task)
def add_normal_with_override(self, a, b):
assert self.request.get("taskbadger_task_id") is not None, "missing task in request"
assert hasattr(self, "taskbadger_task")
assert Badger.current.session().client is not None, "missing client"
return a + b

celery_session_worker.reload()

celery_system = Badger.current.settings.get_system_by_id("celery")
celery_system.record_task_args = True

with (
mock.patch("taskbadger.celery.create_task_safe") as create,
mock.patch("taskbadger.celery.update_task_safe") as update,
mock.patch("taskbadger.celery.get_task") as get_task,
):
tb_task = task_for_test()
create.return_value = tb_task
result = add_normal_with_override.delay(2, 2, taskbadger_record_task_args=False)
assert result.info.get("taskbadger_task_id") == tb_task.id
assert result.get(timeout=10, propagate=True) == 4

create.assert_called_once_with(
"tests.test_celery_system_integration.add_normal_with_override", status=StatusEnum.PENDING
)
assert get_task.call_count == 1
assert update.call_count == 2
assert Badger.current.session().client is None
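
Taken together, these two tests pin down the precedence rule: a per-call `taskbadger_record_task_args` value overrides the integration-level `record_task_args` default, mirroring `headers.get("taskbadger_record_task_args", global_record_task_args)` in `task_publish_handler`. Illustrated with the tasks defined in the tests above:

    # With CelerySystemIntegration(record_task_args=True) configured:
    add_normal.delay(2, 2)  # recorded: the integration default applies
    add_normal_with_override.delay(2, 2, taskbadger_record_task_args=False)  # per-call override: not recorded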


@pytest.mark.parametrize(
("include", "exclude", "expected"),
[