Skip to content

Commit 0d087f1

Browse files
authored
Merge branch 'main' into fix-doc-default-values
2 parents 996fcc9 + d1840d2 commit 0d087f1

File tree

12 files changed

+368
-111
lines changed

12 files changed

+368
-111
lines changed

.github/workflows/codeql.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343

4444
steps:
4545
- name: Checkout repository
46-
uses: actions/checkout@v5.0.1
46+
uses: actions/checkout@v6.0.0
4747
with:
4848
persist-credentials: false
4949

.github/workflows/deploy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
contents: read
1313

1414
steps:
15-
- uses: actions/checkout@v5.0.1
15+
- uses: actions/checkout@v6.0.0
1616
with:
1717
persist-credentials: false
1818

.github/workflows/test.yml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333
permissions:
3434
contents: read
3535
steps:
36-
- uses: actions/checkout@v5.0.1
36+
- uses: actions/checkout@v6.0.0
3737
with:
3838
ref: ${{ github.event.pull_request.head.sha || github.ref }}
3939
persist-credentials: false
@@ -77,7 +77,7 @@ jobs:
7777
- python-version: "3.12"
7878
airflow-version: "2.8"
7979
steps:
80-
- uses: actions/checkout@v5.0.1
80+
- uses: actions/checkout@v6.0.0
8181
with:
8282
ref: ${{ github.event.pull_request.head.sha || github.ref }}
8383
persist-credentials: false
@@ -140,7 +140,7 @@ jobs:
140140
ports:
141141
- 5432:5432
142142
steps:
143-
- uses: actions/checkout@v5.0.1
143+
- uses: actions/checkout@v6.0.0
144144
with:
145145
ref: ${{ github.event.pull_request.head.sha || github.ref }}
146146
persist-credentials: false
@@ -223,7 +223,7 @@ jobs:
223223
- 5432:5432
224224

225225
steps:
226-
- uses: actions/checkout@v5.0.1
226+
- uses: actions/checkout@v6.0.0
227227
with:
228228
ref: ${{ github.event.pull_request.head.sha || github.ref }}
229229
persist-credentials: false
@@ -309,7 +309,7 @@ jobs:
309309
- 5432:5432
310310

311311
steps:
312-
- uses: actions/checkout@v5.0.1
312+
- uses: actions/checkout@v6.0.0
313313
with:
314314
ref: ${{ github.event.pull_request.head.sha || github.ref }}
315315
persist-credentials: false
@@ -385,7 +385,7 @@ jobs:
385385
- 5432:5432
386386

387387
steps:
388-
- uses: actions/checkout@v5.0.1
388+
- uses: actions/checkout@v6.0.0
389389
with:
390390
ref: ${{ github.event.pull_request.head.sha || github.ref }}
391391
persist-credentials: false
@@ -455,7 +455,7 @@ jobs:
455455
dbt-version: ["2.0"] # dbt Fusion
456456

457457
steps:
458-
- uses: actions/checkout@v5.0.1
458+
- uses: actions/checkout@v6.0.0
459459
with:
460460
ref: ${{ github.event.pull_request.head.sha || github.ref }}
461461
persist-credentials: false
@@ -552,7 +552,7 @@ jobs:
552552
ports:
553553
- 5432:5432
554554
steps:
555-
- uses: actions/checkout@v5.0.1
555+
- uses: actions/checkout@v6.0.0
556556
with:
557557
ref: ${{ github.event.pull_request.head.sha || github.ref }}
558558
persist-credentials: false
@@ -614,7 +614,7 @@ jobs:
614614
airflow-version: [ "2.10", "3.0" ]
615615
dbt-version: [ "1.10" ]
616616
steps:
617-
- uses: actions/checkout@v5.0.1
617+
- uses: actions/checkout@v6.0.0
618618
with:
619619
ref: ${{ github.event.pull_request.head.sha || github.ref }}
620620
persist-credentials: false
@@ -685,7 +685,7 @@ jobs:
685685
permissions:
686686
contents: read
687687
steps:
688-
- uses: actions/checkout@v5.0.1
688+
- uses: actions/checkout@v6.0.0
689689
with:
690690
ref: ${{ github.event.pull_request.head.sha || github.ref }}
691691
persist-credentials: false

cosmos/_triggers/watcher.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from asgiref.sync import sync_to_async
1212
from packaging.version import Version
1313

14+
from cosmos._utils.watcher_state import build_producer_state_fetcher
15+
1416
AIRFLOW_VERSION = Version(airflow.__version__)
1517

1618

@@ -72,8 +74,10 @@ def _get_xcom_val() -> Any | None:
7274
task_id=self.producer_task_id,
7375
run_id=self.run_id,
7476
)
75-
.one()
77+
.one_or_none()
7678
)
79+
if ti is None:
80+
return None
7781
return ti.xcom_pull(task_ids=self.producer_task_id, key=key)
7882

7983
return await sync_to_async(_get_xcom_val)()
@@ -103,11 +107,26 @@ async def _parse_node_status(self) -> str | None:
103107
)
104108
return node_result.get("status")
105109

110+
async def _get_producer_task_status(self) -> str | None:
111+
"""Retrieve the producer task state for both Airflow 2 and Airflow 3."""
112+
113+
fetch_state = build_producer_state_fetcher(
114+
airflow_version=AIRFLOW_VERSION,
115+
dag_id=self.dag_id,
116+
run_id=self.run_id,
117+
producer_task_id=self.producer_task_id,
118+
logger=self.log,
119+
)
120+
if fetch_state is None:
121+
return None
122+
123+
return await sync_to_async(fetch_state)()
124+
106125
async def run(self) -> AsyncIterator[TriggerEvent]:
107126
self.log.info("Starting WatcherTrigger for model: %s", self.model_unique_id)
108127

109128
while True:
110-
producer_task_state = await self.get_xcom_val("state")
129+
producer_task_state = await self._get_producer_task_status()
111130
node_status = await self._parse_node_status()
112131
if node_status == "success":
113132
self.log.info("Model '%s' succeeded", self.model_unique_id)

cosmos/_utils/watcher_state.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from typing import Any, Callable
5+
6+
from packaging.version import Version
7+
8+
ProducerStateFetcher = Callable[[], str | None]
9+
10+
11+
def _load_airflow2_dependencies() -> tuple[Any, Callable[[], Any]]:
12+
from airflow.models import TaskInstance
13+
from airflow.utils.session import create_session
14+
15+
return TaskInstance, create_session
16+
17+
18+
def _load_airflow3_dependencies() -> Any:
19+
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
20+
21+
return RuntimeTaskInstance
22+
23+
24+
def build_producer_state_fetcher(
25+
*,
26+
airflow_version: Version,
27+
dag_id: str,
28+
run_id: str,
29+
producer_task_id: str,
30+
logger: logging.Logger,
31+
) -> ProducerStateFetcher | None:
32+
"""Return a callable that fetches the producer task state for the given Airflow version."""
33+
34+
if airflow_version < Version("3.0.0"):
35+
try:
36+
TaskInstance, create_session = _load_airflow2_dependencies()
37+
except ImportError as exc: # pragma: no cover - defensive guard for stripped test envs
38+
logger.warning("Could not import Airflow 2 state dependencies: %s", exc)
39+
return None
40+
41+
def fetch_state_airflow2() -> str | None:
42+
with create_session() as session:
43+
ti = (
44+
session.query(TaskInstance)
45+
.filter_by(
46+
dag_id=dag_id,
47+
task_id=producer_task_id,
48+
run_id=run_id,
49+
)
50+
.one_or_none()
51+
)
52+
if ti is not None:
53+
return str(ti.state)
54+
return None
55+
56+
return fetch_state_airflow2
57+
58+
try:
59+
RuntimeTaskInstance = _load_airflow3_dependencies()
60+
except (ImportError, NameError) as exc: # pragma: no cover - Airflow 3 libs missing
61+
logger.warning("Could not load Airflow 3 RuntimeTaskInstance: %s", exc)
62+
return None
63+
64+
def fetch_state_airflow3() -> str | None:
65+
try:
66+
task_states = RuntimeTaskInstance.get_task_states(
67+
dag_id=dag_id,
68+
task_ids=[producer_task_id],
69+
run_ids=[run_id],
70+
)
71+
except NameError as exc: # pragma: no cover - Airflow 3.0 missing supervisor comms
72+
logger.warning("RuntimeTaskInstance.get_task_states unavailable due to NameError: %s", exc)
73+
return None
74+
state = task_states.get(run_id, {}).get(producer_task_id)
75+
if state is not None:
76+
return str(state)
77+
return None
78+
79+
return fetch_state_airflow3

cosmos/operators/watcher.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from datetime import timedelta
88
from pathlib import Path
99
from threading import Lock
10-
from typing import TYPE_CHECKING, Any, Callable, List, Union
10+
from typing import TYPE_CHECKING, Any
1111

1212
from cosmos._triggers.watcher import WatcherTrigger, _parse_compressed_xcom
1313

@@ -30,8 +30,8 @@
3030
except ImportError: # pragma: no cover
3131
from airflow.operators.empty import EmptyOperator # type: ignore[no-redef]
3232

33-
from packaging.version import Version
3433

34+
from cosmos._utils.watcher_state import build_producer_state_fetcher
3535
from cosmos.config import ProfileConfig
3636
from cosmos.constants import AIRFLOW_VERSION, PRODUCER_WATCHER_TASK_ID, InvocationMode
3737
from cosmos.operators.base import (
@@ -110,27 +110,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
110110
kwargs["default_args"] = default_args
111111
kwargs["retries"] = 0
112112

113-
on_failure_callback = self._set_on_failure_callback(kwargs.pop("on_failure_callback", None))
114-
super().__init__(task_id=task_id, *args, on_failure_callback=on_failure_callback, **kwargs)
115-
116-
def _set_on_failure_callback(
117-
self, user_callback: Any
118-
) -> Union[Callable[[Context], None], List[Callable[[Context], None]]]:
119-
default_callback = self._store_producer_task_state
120-
121-
if AIRFLOW_VERSION < Version("2.6.0"):
122-
# Older versions only support a single callable
123-
return default_callback
124-
else:
125-
if user_callback is None:
126-
# No callback provided — use default in a list
127-
return [default_callback]
128-
elif isinstance(user_callback, list):
129-
# Append to existing list of callbacks (make a copy to avoid side effects)
130-
return user_callback + [default_callback]
131-
else:
132-
# Single callable provided — wrap it in a list and append ours
133-
return [user_callback, default_callback]
113+
super().__init__(task_id=task_id, *args, **kwargs)
134114

135115
@staticmethod
136116
def _serialize_event(event_message: EventMsg) -> dict[str, Any]:
@@ -179,9 +159,6 @@ def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> N
179159
if startup_events:
180160
safe_xcom_push(task_instance=context["ti"], key="dbt_startup_events", value=startup_events)
181161

182-
def _store_producer_task_state(self, context: Context) -> None:
183-
safe_xcom_push(task_instance=context["ti"], key="state", value="failed")
184-
185162
def execute(self, context: Context, **kwargs: Any) -> Any:
186163
task_instance = context.get("ti")
187164
if task_instance is None:
@@ -371,8 +348,27 @@ def _get_status_from_run_results(self, ti: Any, context: Context) -> Any:
371348

372349
return node_result.get("status")
373350

374-
def _get_producer_task_state(self, ti: Any) -> Any:
375-
return ti.xcom_pull(task_ids=self.producer_task_id, key="state")
351+
def _get_producer_task_status(self, context: Context) -> str | None:
352+
"""
353+
Get the task status of the producer task for both Airflow 2 and Airflow 3.
354+
355+
Returns the state of the producer task instance, or None if not found.
356+
"""
357+
ti = context["ti"]
358+
run_id = context["run_id"]
359+
dag_id = ti.dag_id
360+
361+
fetch_state = build_producer_state_fetcher(
362+
airflow_version=AIRFLOW_VERSION,
363+
dag_id=dag_id,
364+
run_id=run_id,
365+
producer_task_id=self.producer_task_id,
366+
logger=logger,
367+
)
368+
if fetch_state is None:
369+
return None
370+
371+
return fetch_state()
376372

377373
def execute(self, context: Context, **kwargs: Any) -> None:
378374
if not self.deferrable:
@@ -433,7 +429,7 @@ def poke(self, context: Context) -> bool:
433429
return self._fallback_to_local_run(try_number, context)
434430

435431
# We have assumption here that both the build producer and the sensor task will have same invocation mode
436-
producer_task_state = self._get_producer_task_state(ti)
432+
producer_task_state = self._get_producer_task_status(context)
437433
if self._use_event():
438434
status = self._get_status_from_events(ti, context)
439435
else:

docs/getting_started/async-execution-mode.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,5 +248,7 @@ Limitations
248248

249249
9. **TeardownAsyncOperator limitation**: When using a remote object location, in addition to the ``SetupAsyncOperator``, a ``TeardownAsyncOperator`` is also added to the DAG. This task will delete the SQL files from the remote location by the end of the DAG Run. This can lead to a limitation from a retry perspective, as described in the issue `#2066 <https://github.com/astronomer/astronomer-cosmos/issues/2066>`_. This can be avoided by setting the ``enable_teardown_async_task`` configuration to ``False``, as described in the :ref:`enable_teardown_async_task` section.
250250

251+
10. **Dataset events not emitted**: Dataset events are not currently emitted after dbt models complete when using ``ExecutionMode.AIRFLOW_ASYNC``. This means downstream DAGs scheduled with ``Dataset`` or ``DatasetAlias`` will not trigger automatically. This behaviour is present in ``ExecutionMode.LOCAL`` but is currently missing in async mode. This issue is being tracked in `#2141 <https://github.com/astronomer/astronomer-cosmos/issues/2141>`_.
252+
251253

252254
For a comparison between different Cosmos execution modes, please, check the :ref:`execution-modes-comparison` section.

0 commit comments

Comments
 (0)