|
7 | 7 | from datetime import timedelta |
8 | 8 | from pathlib import Path |
9 | 9 | from threading import Lock |
10 | | -from typing import TYPE_CHECKING, Any, Callable, List, Union |
| 10 | +from typing import TYPE_CHECKING, Any |
11 | 11 |
|
12 | 12 | from cosmos._triggers.watcher import WatcherTrigger, _parse_compressed_xcom |
13 | 13 |
|
|
30 | 30 | except ImportError: # pragma: no cover |
31 | 31 | from airflow.operators.empty import EmptyOperator # type: ignore[no-redef] |
32 | 32 |
|
33 | | -from packaging.version import Version |
34 | 33 |
|
| 34 | +from cosmos._utils.watcher_state import build_producer_state_fetcher |
35 | 35 | from cosmos.config import ProfileConfig |
36 | 36 | from cosmos.constants import AIRFLOW_VERSION, PRODUCER_WATCHER_TASK_ID, InvocationMode |
37 | 37 | from cosmos.operators.base import ( |
@@ -110,27 +110,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: |
110 | 110 | kwargs["default_args"] = default_args |
111 | 111 | kwargs["retries"] = 0 |
112 | 112 |
|
113 | | - on_failure_callback = self._set_on_failure_callback(kwargs.pop("on_failure_callback", None)) |
114 | | - super().__init__(task_id=task_id, *args, on_failure_callback=on_failure_callback, **kwargs) |
115 | | - |
116 | | - def _set_on_failure_callback( |
117 | | - self, user_callback: Any |
118 | | - ) -> Union[Callable[[Context], None], List[Callable[[Context], None]]]: |
119 | | - default_callback = self._store_producer_task_state |
120 | | - |
121 | | - if AIRFLOW_VERSION < Version("2.6.0"): |
122 | | - # Older versions only support a single callable |
123 | | - return default_callback |
124 | | - else: |
125 | | - if user_callback is None: |
126 | | - # No callback provided — use default in a list |
127 | | - return [default_callback] |
128 | | - elif isinstance(user_callback, list): |
129 | | - # Append to existing list of callbacks (make a copy to avoid side effects) |
130 | | - return user_callback + [default_callback] |
131 | | - else: |
132 | | - # Single callable provided — wrap it in a list and append ours |
133 | | - return [user_callback, default_callback] |
| 113 | + super().__init__(task_id=task_id, *args, **kwargs) |
134 | 114 |
|
135 | 115 | @staticmethod |
136 | 116 | def _serialize_event(event_message: EventMsg) -> dict[str, Any]: |
@@ -179,9 +159,6 @@ def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> N |
179 | 159 | if startup_events: |
180 | 160 | safe_xcom_push(task_instance=context["ti"], key="dbt_startup_events", value=startup_events) |
181 | 161 |
|
182 | | - def _store_producer_task_state(self, context: Context) -> None: |
183 | | - safe_xcom_push(task_instance=context["ti"], key="state", value="failed") |
184 | | - |
185 | 162 | def execute(self, context: Context, **kwargs: Any) -> Any: |
186 | 163 | task_instance = context.get("ti") |
187 | 164 | if task_instance is None: |
@@ -371,8 +348,27 @@ def _get_status_from_run_results(self, ti: Any, context: Context) -> Any: |
371 | 348 |
|
372 | 349 | return node_result.get("status") |
373 | 350 |
|
374 | | - def _get_producer_task_state(self, ti: Any) -> Any: |
375 | | - return ti.xcom_pull(task_ids=self.producer_task_id, key="state") |
| 351 | + def _get_producer_task_status(self, context: Context) -> str | None: |
| 352 | + """ |
| 353 | + Get the task status of the producer task for both Airflow 2 and Airflow 3. |
| 354 | +
|
| 355 | + Returns the state of the producer task instance, or None if not found. |
| 356 | + """ |
| 357 | + ti = context["ti"] |
| 358 | + run_id = context["run_id"] |
| 359 | + dag_id = ti.dag_id |
| 360 | + |
| 361 | + fetch_state = build_producer_state_fetcher( |
| 362 | + airflow_version=AIRFLOW_VERSION, |
| 363 | + dag_id=dag_id, |
| 364 | + run_id=run_id, |
| 365 | + producer_task_id=self.producer_task_id, |
| 366 | + logger=logger, |
| 367 | + ) |
| 368 | + if fetch_state is None: |
| 369 | + return None |
| 370 | + |
| 371 | + return fetch_state() |
376 | 372 |
|
377 | 373 | def execute(self, context: Context, **kwargs: Any) -> None: |
378 | 374 | if not self.deferrable: |
@@ -433,7 +429,7 @@ def poke(self, context: Context) -> bool: |
433 | 429 | return self._fallback_to_local_run(try_number, context) |
434 | 430 |
|
435 | 431 | # We have assumption here that both the build producer and the sensor task will have same invocation mode |
436 | | - producer_task_state = self._get_producer_task_state(ti) |
| 432 | + producer_task_state = self._get_producer_task_status(context) |
437 | 433 | if self._use_event(): |
438 | 434 | status = self._get_status_from_events(ti, context) |
439 | 435 | else: |
|
0 commit comments