Skip to content

Commit 0b538a5

Browse files
authored
Add support for Kubernetes on_warning_callback (#673)
To make `on_warning_callback` work with pod operators, we need to read the logs of the dbt test runs. This is done by ensuring the pod is kept alive, and `on_success_callback` the log is read and analysed for warnings. Afterwards, the pod is cleaned up based on the original settings from the user. If `on_warning_callback` is not set, everything stays the way it always was. This feature only work with `apache-airflow-providers-cncf-kubernetes >= 7.4.0`.
1 parent a83911f commit 0b538a5

File tree

2 files changed

+214
-5
lines changed

2 files changed

+214
-5
lines changed

cosmos/operators/kubernetes.py

Lines changed: 96 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@
44
from typing import Any, Callable, Sequence
55

66
import yaml
7-
from airflow.utils.context import Context
7+
from airflow.utils.context import Context, context_merge
88

99
from cosmos.log import get_logger
1010
from cosmos.config import ProfileConfig
1111
from cosmos.operators.base import DbtBaseOperator
1212

13+
from airflow.models import TaskInstance
14+
from cosmos.dbt.parser.output import extract_log_issues
15+
16+
DBT_NO_TESTS_MSG = "Nothing to do"
17+
DBT_WARN_MSG = "WARN"
1318

1419
logger = get_logger(__name__)
1520

@@ -19,6 +24,7 @@
1924
convert_env_vars,
2025
)
2126
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
27+
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
2228
except ImportError:
2329
try:
2430
# apache-airflow-providers-cncf-kubernetes < 7.4.0
@@ -158,10 +164,96 @@ class DbtTestKubernetesOperator(DbtKubernetesBaseOperator):
158164
ui_color = "#8194E0"
159165

160166
def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwargs: Any) -> None:
161-
super().__init__(**kwargs)
167+
if not on_warning_callback:
168+
super().__init__(**kwargs)
169+
else:
170+
self.on_warning_callback = on_warning_callback
171+
self.is_delete_operator_pod_original = kwargs.get("is_delete_operator_pod", None)
172+
if self.is_delete_operator_pod_original is not None:
173+
self.on_finish_action_original = (
174+
OnFinishAction.DELETE_POD if self.is_delete_operator_pod_original else OnFinishAction.KEEP_POD
175+
)
176+
else:
177+
self.on_finish_action_original = OnFinishAction(kwargs.get("on_finish_action", "delete_pod"))
178+
self.is_delete_operator_pod_original = self.on_finish_action_original == OnFinishAction.DELETE_POD
179+
# In order to read the pod logs, we need to keep the pod around.
180+
# Depending on the on_finish_action & is_delete_operator_pod settings,
181+
# we will clean up the pod later in the _handle_warnings method, which
182+
# is called in on_success_callback.
183+
kwargs["is_delete_operator_pod"] = False
184+
kwargs["on_finish_action"] = OnFinishAction.KEEP_POD
185+
186+
# Add an additional callback to both success and failure callbacks.
187+
# In case of success, check for a warning in the logs and clean up the pod.
188+
self.on_success_callback = kwargs.get("on_success_callback", None) or []
189+
if isinstance(self.on_success_callback, list):
190+
self.on_success_callback += [self._handle_warnings]
191+
else:
192+
self.on_success_callback = [self.on_success_callback, self._handle_warnings]
193+
kwargs["on_success_callback"] = self.on_success_callback
194+
# In case of failure, clean up the pod.
195+
self.on_failure_callback = kwargs.get("on_failure_callback", None) or []
196+
if isinstance(self.on_failure_callback, list):
197+
self.on_failure_callback += [self._cleanup_pod]
198+
else:
199+
self.on_failure_callback = [self.on_failure_callback, self._cleanup_pod]
200+
kwargs["on_failure_callback"] = self.on_failure_callback
201+
202+
super().__init__(**kwargs)
203+
162204
self.base_cmd = ["test"]
163-
# as of now, on_warning_callback in kubernetes executor does nothing
164-
self.on_warning_callback = on_warning_callback
205+
206+
def _handle_warnings(self, context: Context) -> None:
207+
"""
208+
Handles warnings by extracting log issues, creating additional context, and calling the
209+
on_warning_callback with the updated context.
210+
211+
:param context: The original airflow context in which the build and run command was executed.
212+
"""
213+
if not (
214+
isinstance(context["task_instance"], TaskInstance)
215+
and isinstance(context["task_instance"].task, DbtTestKubernetesOperator)
216+
):
217+
return
218+
task = context["task_instance"].task
219+
logs = [
220+
log.decode("utf-8") for log in task.pod_manager.read_pod_logs(task.pod, "base") if log.decode("utf-8") != ""
221+
]
222+
223+
should_trigger_callback = all(
224+
[
225+
logs,
226+
self.on_warning_callback,
227+
DBT_NO_TESTS_MSG not in logs[-1],
228+
DBT_WARN_MSG in logs[-1],
229+
]
230+
)
231+
232+
if should_trigger_callback:
233+
warnings = int(logs[-1].split(f"{DBT_WARN_MSG}=")[1].split()[0])
234+
if warnings > 0:
235+
test_names, test_results = extract_log_issues(logs)
236+
context_merge(context, test_names=test_names, test_results=test_results)
237+
self.on_warning_callback(context)
238+
239+
self._cleanup_pod(context)
240+
241+
def _cleanup_pod(self, context: Context) -> None:
242+
"""
243+
Handles the cleaning up of the pod after success or failure, if
244+
there is a on_warning_callback function defined.
245+
246+
:param context: The original airflow context in which the build and run command was executed.
247+
"""
248+
if not (
249+
isinstance(context["task_instance"], TaskInstance)
250+
and isinstance(context["task_instance"].task, DbtTestKubernetesOperator)
251+
):
252+
return
253+
task = context["task_instance"].task
254+
if task.pod:
255+
task.on_finish_action = self.on_finish_action_original
256+
task.cleanup(pod=task.pod, remote_pod=task.remote_pod)
165257

166258

167259
class DbtRunOperationKubernetesOperator(DbtKubernetesBaseOperator):

tests/operators/test_kubernetes.py

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pathlib import Path
22
from unittest.mock import MagicMock, patch
33

4-
from airflow.utils.context import Context
4+
import pytest
55
from pendulum import datetime
66

77
from cosmos.operators.kubernetes import (
@@ -12,6 +12,16 @@
1212
DbtTestKubernetesOperator,
1313
)
1414

15+
from airflow.utils.context import Context, context_merge
16+
from airflow.models import TaskInstance
17+
18+
try:
19+
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
20+
21+
module_available = True
22+
except ImportError:
23+
module_available = False
24+
1525

1626
def test_dbt_kubernetes_operator_add_global_flags() -> None:
1727
dbt_kube_operator = DbtKubernetesBaseOperator(
@@ -103,6 +113,113 @@ def test_dbt_kubernetes_build_command():
103113
]
104114

105115

116+
@pytest.mark.parametrize(
117+
"additional_kwargs,expected_results",
118+
[
119+
({"on_success_callback": None, "is_delete_operator_pod": True}, (1, 1, True, "delete_pod")),
120+
(
121+
{"on_success_callback": (lambda **kwargs: None), "is_delete_operator_pod": False},
122+
(2, 1, False, "keep_pod"),
123+
),
124+
(
125+
{"on_success_callback": [(lambda **kwargs: None), (lambda **kwargs: None)], "is_delete_operator_pod": None},
126+
(3, 1, True, "delete_pod"),
127+
),
128+
(
129+
{"on_failure_callback": None, "is_delete_operator_pod": True, "on_finish_action": "keep_pod"},
130+
(1, 1, True, "delete_pod"),
131+
),
132+
(
133+
{
134+
"on_failure_callback": (lambda **kwargs: None),
135+
"is_delete_operator_pod": None,
136+
"on_finish_action": "delete_pod",
137+
},
138+
(1, 2, True, "delete_pod"),
139+
),
140+
(
141+
{
142+
"on_failure_callback": [(lambda **kwargs: None), (lambda **kwargs: None)],
143+
"is_delete_operator_pod": None,
144+
"on_finish_action": "delete_succeeded_pod",
145+
},
146+
(1, 3, False, "delete_succeeded_pod"),
147+
),
148+
({"is_delete_operator_pod": None, "on_finish_action": "keep_pod"}, (1, 1, False, "keep_pod")),
149+
({}, (1, 1, True, "delete_pod")),
150+
],
151+
)
152+
@pytest.mark.skipif(
153+
not module_available, reason="Kubernetes module `airflow.providers.cncf.kubernetes.utils.pod_manager` not available"
154+
)
155+
def test_dbt_test_kubernetes_operator_constructor(additional_kwargs, expected_results):
156+
test_operator = DbtTestKubernetesOperator(
157+
on_warning_callback=(lambda **kwargs: None), **additional_kwargs, **base_kwargs
158+
)
159+
160+
print(additional_kwargs, test_operator.__dict__)
161+
162+
assert isinstance(test_operator.on_success_callback, list)
163+
assert isinstance(test_operator.on_failure_callback, list)
164+
assert test_operator._handle_warnings in test_operator.on_success_callback
165+
assert test_operator._cleanup_pod in test_operator.on_failure_callback
166+
assert len(test_operator.on_success_callback) == expected_results[0]
167+
assert len(test_operator.on_failure_callback) == expected_results[1]
168+
assert test_operator.is_delete_operator_pod_original == expected_results[2]
169+
assert test_operator.on_finish_action_original == OnFinishAction(expected_results[3])
170+
171+
172+
class FakePodManager:
173+
def read_pod_logs(self, pod, container):
174+
assert pod == "pod"
175+
assert container == "base"
176+
log_string = """
177+
19:48:25 Concurrency: 4 threads (target='target')
178+
19:48:25
179+
19:48:25 1 of 2 START test dbt_utils_accepted_range_table_col__12__0 ................... [RUN]
180+
19:48:25 2 of 2 START test unique_table__uuid .......................................... [RUN]
181+
19:48:27 1 of 2 WARN 252 dbt_utils_accepted_range_table_col__12__0 ..................... [WARN 117 in 1.83s]
182+
19:48:27 2 of 2 PASS unique_table__uuid ................................................ [PASS in 1.85s]
183+
19:48:27
184+
19:48:27 Finished running 2 tests, 1 hook in 0 hours 0 minutes and 12.86 seconds (12.86s).
185+
19:48:27
186+
19:48:27 Completed with 1 warning:
187+
19:48:27
188+
19:48:27 Warning in test dbt_utils_accepted_range_table_col__12__0 (models/ads/ads.yaml)
189+
19:48:27 Got 252 results, configured to warn if >0
190+
19:48:27
191+
19:48:27 compiled Code at target/compiled/model/models/table/table.yaml/dbt_utils_accepted_range_table_col__12__0.sql
192+
19:48:27
193+
19:48:27 Done. PASS=1 WARN=1 ERROR=0 SKIP=0 TOTAL=2
194+
"""
195+
return (log.encode("utf-8") for log in log_string.split("\n"))
196+
197+
198+
@pytest.mark.skipif(
199+
not module_available, reason="Kubernetes module `airflow.providers.cncf.kubernetes.utils.pod_manager` not available"
200+
)
201+
def test_dbt_test_kubernetes_operator_handle_warnings_and_cleanup_pod():
202+
def on_warning_callback(context: Context):
203+
assert context["test_names"] == ["dbt_utils_accepted_range_table_col__12__0"]
204+
assert context["test_results"] == ["Got 252 results, configured to warn if >0"]
205+
206+
def cleanup(pod: str, remote_pod: str):
207+
assert pod == remote_pod
208+
209+
test_operator = DbtTestKubernetesOperator(
210+
is_delete_operator_pod=True, on_warning_callback=on_warning_callback, **base_kwargs
211+
)
212+
task_instance = TaskInstance(test_operator)
213+
task_instance.task.pod_manager = FakePodManager()
214+
task_instance.task.pod = task_instance.task.remote_pod = "pod"
215+
task_instance.task.cleanup = cleanup
216+
217+
context = Context()
218+
context_merge(context, task_instance=task_instance)
219+
220+
test_operator._handle_warnings(context)
221+
222+
106223
@patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.hook")
107224
def test_created_pod(test_hook):
108225
test_hook.is_in_cluster = False

0 commit comments

Comments
 (0)