|
16 | 16 | import pytest |
17 | 17 | from celery.signals import task_prerun |
18 | 18 |
|
| 19 | +from taskbadger import StatusEnum |
| 20 | +from taskbadger.celery import Task |
19 | 21 | from taskbadger.mug import Badger, Settings |
20 | 22 | from taskbadger.systems.celery import CelerySystemIntegration |
21 | 23 | from tests.utils import task_for_test |
@@ -73,6 +75,74 @@ def add_normal(self, a, b): |
73 | 75 | assert Badger.current.session().client is None |
74 | 76 |
|
75 | 77 |
|
| 78 | +@pytest.mark.usefixtures("_bind_settings_with_system") |
| 79 | +def test_celery_record_task_args(celery_session_app, celery_session_worker): |
| 80 | + @celery_session_app.task(bind=True) |
| 81 | + def add_normal(self, a, b): |
| 82 | + assert self.request.get("taskbadger_task_id") is not None, "missing task in request" |
| 83 | + assert not hasattr(self, "taskbadger_task") |
| 84 | + assert Badger.current.session().client is not None, "missing client" |
| 85 | + return a + b |
| 86 | + |
| 87 | + celery_session_worker.reload() |
| 88 | + |
| 89 | + celery_system = Badger.current.settings.get_system_by_id("celery") |
| 90 | + celery_system.record_task_args = True |
| 91 | + |
| 92 | + with ( |
| 93 | + mock.patch("taskbadger.celery.create_task_safe") as create, |
| 94 | + mock.patch("taskbadger.celery.update_task_safe") as update, |
| 95 | + mock.patch("taskbadger.celery.get_task") as get_task, |
| 96 | + ): |
| 97 | + tb_task = task_for_test() |
| 98 | + create.return_value = tb_task |
| 99 | + result = add_normal.delay(2, 2) |
| 100 | + assert result.info.get("taskbadger_task_id") == tb_task.id |
| 101 | + assert result.get(timeout=10, propagate=True) == 4 |
| 102 | + |
| 103 | + create.assert_called_once_with( |
| 104 | + "tests.test_celery_system_integration.add_normal", |
| 105 | + status=StatusEnum.PENDING, |
| 106 | + data={"celery_task_args": [2, 2], "celery_task_kwargs": {}}, |
| 107 | + ) |
| 108 | + assert get_task.call_count == 1 |
| 109 | + assert update.call_count == 2 |
| 110 | + assert Badger.current.session().client is None |
| 111 | + |
| 112 | + |
| 113 | +@pytest.mark.usefixtures("_bind_settings_with_system") |
| 114 | +def test_celery_record_task_args_local_override(celery_session_app, celery_session_worker): |
| 115 | + """Test that passing `taskbadger_record_task_args` overrides the integration value""" |
| 116 | + |
| 117 | + @celery_session_app.task(bind=True, base=Task) |
| 118 | + def add_normal(self, a, b): |
| 119 | + assert self.request.get("taskbadger_task_id") is not None, "missing task in request" |
| 120 | + assert hasattr(self, "taskbadger_task") |
| 121 | + assert Badger.current.session().client is not None, "missing client" |
| 122 | + return a + b |
| 123 | + |
| 124 | + celery_session_worker.reload() |
| 125 | + |
| 126 | + celery_system = Badger.current.settings.get_system_by_id("celery") |
| 127 | + celery_system.record_task_args = True |
| 128 | + |
| 129 | + with ( |
| 130 | + mock.patch("taskbadger.celery.create_task_safe") as create, |
| 131 | + mock.patch("taskbadger.celery.update_task_safe") as update, |
| 132 | + mock.patch("taskbadger.celery.get_task") as get_task, |
| 133 | + ): |
| 134 | + tb_task = task_for_test() |
| 135 | + create.return_value = tb_task |
| 136 | + result = add_normal.delay(2, 2, taskbadger_record_task_args=False) |
| 137 | + assert result.info.get("taskbadger_task_id") == tb_task.id |
| 138 | + assert result.get(timeout=10, propagate=True) == 4 |
| 139 | + |
| 140 | + create.assert_called_once_with("tests.test_celery_system_integration.add_normal", status=StatusEnum.PENDING) |
| 141 | + assert get_task.call_count == 1 |
| 142 | + assert update.call_count == 2 |
| 143 | + assert Badger.current.session().client is None |
| 144 | + |
| 145 | + |
76 | 146 | @pytest.mark.parametrize( |
77 | 147 | ("include", "exclude", "expected"), |
78 | 148 | [ |
|
0 commit comments