diff --git a/application_sdk/activities/__init__.py b/application_sdk/activities/__init__.py index 4a33c0fc5..ab837d1fc 100644 --- a/application_sdk/activities/__init__.py +++ b/application_sdk/activities/__init__.py @@ -206,6 +206,12 @@ async def get_workflow_args( ) workflow_args["workflow_id"] = workflow_id workflow_args["workflow_run_id"] = get_workflow_run_id() + + # Preserve atlan- prefixed keys from workflow_config for logging context + for key, value in workflow_config.items(): + if key.startswith("atlan-") and value: + workflow_args[key] = str(value) + return workflow_args except Exception as e: diff --git a/application_sdk/clients/temporal.py b/application_sdk/clients/temporal.py index bad4e05ca..3e2a6abcb 100644 --- a/application_sdk/clients/temporal.py +++ b/application_sdk/clients/temporal.py @@ -33,6 +33,9 @@ WorkerTokenRefreshEventData, ) from application_sdk.interceptors.cleanup import CleanupInterceptor, cleanup +from application_sdk.interceptors.correlation_context import ( + CorrelationContextInterceptor, +) from application_sdk.interceptors.events import EventInterceptor, publish_event from application_sdk.interceptors.lock import RedisLockInterceptor from application_sdk.observability.logger_adaptor import get_logger @@ -430,6 +433,7 @@ def create_worker( max_concurrent_activities=max_concurrent_activities, activity_executor=activity_executor, interceptors=[ + CorrelationContextInterceptor(), EventInterceptor(), CleanupInterceptor(), RedisLockInterceptor(activities_dict), diff --git a/application_sdk/interceptors/correlation_context.py b/application_sdk/interceptors/correlation_context.py new file mode 100644 index 000000000..0c4391ee2 --- /dev/null +++ b/application_sdk/interceptors/correlation_context.py @@ -0,0 +1,143 @@ +"""Correlation context interceptor for Temporal workflows. 
+ +Propagates atlan-* correlation context fields from workflow arguments to activities +via Temporal headers, ensuring all activity logs include correlation identifiers +""" + +from dataclasses import replace +from typing import Any, Dict, Optional, Type + +from temporalio import workflow +from temporalio.api.common.v1 import Payload +from temporalio.converter import default as default_converter +from temporalio.worker import ( + ActivityInboundInterceptor, + ExecuteActivityInput, + ExecuteWorkflowInput, + Interceptor, + StartActivityInput, + WorkflowInboundInterceptor, + WorkflowInterceptorClassInput, + WorkflowOutboundInterceptor, +) + +from application_sdk.observability.context import correlation_context +from application_sdk.observability.logger_adaptor import get_logger + +logger = get_logger(__name__) + +ATLAN_HEADER_PREFIX = "atlan-" + + +class CorrelationContextOutboundInterceptor(WorkflowOutboundInterceptor): + """Outbound interceptor that injects atlan-* context into activity headers.""" + + def __init__( + self, + next: WorkflowOutboundInterceptor, + inbound: "CorrelationContextWorkflowInboundInterceptor", + ): + """Initialize the outbound interceptor.""" + super().__init__(next) + self.inbound = inbound + + def start_activity(self, input: StartActivityInput) -> workflow.ActivityHandle[Any]: + """Inject atlan-* headers and trace_id into activity calls.""" + try: + if self.inbound.correlation_data: + new_headers: Dict[str, Payload] = dict(input.headers) + payload_converter = default_converter().payload_converter + + for key, value in self.inbound.correlation_data.items(): + # Include atlan-* prefixed headers and trace_id + if ( + key.startswith(ATLAN_HEADER_PREFIX) or key == "trace_id" + ) and value: + payload = payload_converter.to_payload(value) + new_headers[key] = payload + + input = replace(input, headers=new_headers) + except Exception as e: + logger.warning(f"Failed to inject correlation context headers: {e}") + + return 
self.next.start_activity(input) + + +class CorrelationContextWorkflowInboundInterceptor(WorkflowInboundInterceptor): + """Inbound workflow interceptor that extracts atlan-* context from workflow args.""" + + def __init__(self, next: WorkflowInboundInterceptor): + """Initialize the inbound interceptor.""" + super().__init__(next) + self.correlation_data: Dict[str, str] = {} + + def init(self, outbound: WorkflowOutboundInterceptor) -> None: + """Initialize with correlation context outbound interceptor.""" + context_outbound = CorrelationContextOutboundInterceptor(outbound, self) + super().init(context_outbound) + + async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: + """Execute workflow and extract atlan-* fields and trace_id from arguments.""" + try: + if input.args and len(input.args) > 0: + workflow_config = input.args[0] + if isinstance(workflow_config, dict): + # Extract atlan-* prefixed fields + self.correlation_data = { + k: str(v) + for k, v in workflow_config.items() + if k.startswith(ATLAN_HEADER_PREFIX) and v + } + # Extract trace_id separately (not atlan- prefixed) + trace_id = workflow_config.get("trace_id", "") + if trace_id: + self.correlation_data["trace_id"] = str(trace_id) + if self.correlation_data: + correlation_context.set(self.correlation_data) + except Exception as e: + logger.warning(f"Failed to extract correlation context from args: {e}") + + return await super().execute_workflow(input) + + +class CorrelationContextActivityInboundInterceptor(ActivityInboundInterceptor): + """Activity interceptor that reads atlan-* headers and trace_id, sets correlation_context.""" + + async def execute_activity(self, input: ExecuteActivityInput) -> Any: + """Execute activity after extracting atlan-* headers and trace_id.""" + try: + atlan_fields: Dict[str, str] = {} + payload_converter = default_converter().payload_converter + + for key, payload in input.headers.items(): + # Extract atlan-* prefixed headers and trace_id + if 
key.startswith(ATLAN_HEADER_PREFIX) or key == "trace_id": + value = payload_converter.from_payload(payload, type_hint=str) + atlan_fields[key] = value + + if atlan_fields: + correlation_context.set(atlan_fields) + + except Exception as e: + logger.warning(f"Failed to extract correlation context from headers: {e}") + + return await super().execute_activity(input) + + +class CorrelationContextInterceptor(Interceptor): + """Main interceptor for propagating atlan-* correlation context. + + Ensures atlan-* fields are propagated from workflow arguments to all activities via Temporal headers. + """ + + def workflow_interceptor_class( + self, input: WorkflowInterceptorClassInput + ) -> Optional[Type[WorkflowInboundInterceptor]]: + """Get the workflow interceptor class.""" + return CorrelationContextWorkflowInboundInterceptor + + def intercept_activity( + self, next: ActivityInboundInterceptor + ) -> ActivityInboundInterceptor: + """Intercept activity executions to read correlation context.""" + return CorrelationContextActivityInboundInterceptor(next) diff --git a/application_sdk/interceptors/events.py b/application_sdk/interceptors/events.py index dd10ce15b..95d2fcb36 100644 --- a/application_sdk/interceptors/events.py +++ b/application_sdk/interceptors/events.py @@ -66,8 +66,6 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any: Returns: Any: The result of the activity execution. """ - # Extract activity information for tracking - start_event = Event( event_type=EventTypes.APPLICATION_EVENT.value, event_name=ApplicationEventNames.ACTIVITY_START.value, diff --git a/application_sdk/observability/context.py b/application_sdk/observability/context.py new file mode 100644 index 000000000..ac038bd8c --- /dev/null +++ b/application_sdk/observability/context.py @@ -0,0 +1,18 @@ +"""Shared context variables for observability. + +This module contains ContextVar definitions that are shared across +multiple observability modules to avoid circular imports. 
+""" + +from contextvars import ContextVar +from typing import Any, Dict + +# Context variable for request-scoped data (e.g., request_id from HTTP middleware) +request_context: ContextVar[Dict[str, Any] | None] = ContextVar( + "request_context", default=None +) + +# Context variable for correlation context (atlan- prefixed headers for distributed tracing) +correlation_context: ContextVar[Dict[str, Any] | None] = ContextVar( + "correlation_context", default=None +) diff --git a/application_sdk/observability/logger_adaptor.py b/application_sdk/observability/logger_adaptor.py index 8bdea9a3a..2f0d38278 100644 --- a/application_sdk/observability/logger_adaptor.py +++ b/application_sdk/observability/logger_adaptor.py @@ -2,7 +2,6 @@ import logging import sys import threading -from contextvars import ContextVar from time import time_ns from typing import Any, Dict, Optional, Tuple @@ -34,6 +33,7 @@ SERVICE_NAME, SERVICE_VERSION, ) +from application_sdk.observability.context import correlation_context, request_context from application_sdk.observability.observability import AtlanObservability from application_sdk.observability.utils import ( get_observability_dir, @@ -42,7 +42,13 @@ class LogExtraModel(BaseModel): - """Pydantic model for log extra fields.""" + """Pydantic model for log extra fields. + + This model allows arbitrary extra fields (prefixed with atlan-) to be included + for correlation context propagation to OTEL. + """ + + model_config = {"extra": "allow"} client_host: Optional[str] = None duration_ms: Optional[int] = None @@ -67,52 +73,8 @@ class LogExtraModel(BaseModel): heartbeat_timeout: Optional[str] = None # Other fields log_type: Optional[str] = None - - class Config: - """Pydantic model configuration for LogExtraModel. - - Provides custom parsing logic for converting dictionary values to appropriate types. - Handles type conversion for various fields like integers, strings, and other data types. 
- """ - - @classmethod - def parse_obj(cls, obj): - if isinstance(obj, dict): - # Define type mappings for each field - type_mappings = { - # Integer fields - "attempt": int, - "duration_ms": int, - "status_code": int, - # String fields - "client_host": str, - "method": str, - "path": str, - "request_id": str, - "url": str, - "workflow_id": str, - "run_id": str, - "workflow_type": str, - "namespace": str, - "task_queue": str, - "activity_id": str, - "activity_type": str, - "schedule_to_close_timeout": str, - "start_to_close_timeout": str, - "schedule_to_start_timeout": str, - "heartbeat_timeout": str, - "log_type": str, - } - - # Process each field with its type conversion - for field, type_func in type_mappings.items(): - if field in obj and obj[field] is not None: - try: - obj[field] = type_func(obj[field]) - except (ValueError, TypeError): - obj[field] = None - - return super().parse_obj(obj) + # Trace context + trace_id: Optional[str] = None class LogRecordModel(BaseModel): @@ -142,6 +104,9 @@ def from_loguru_message(cls, message: Any) -> "LogRecordModel": for k, v in message.record["extra"].items(): if k != "logger_name" and hasattr(extra, k): setattr(extra, k, v) + # Include atlan- prefixed fields as extra attributes (correlation context) + elif k.startswith("atlan-") and v is not None: + setattr(extra, k, str(v)) return cls( timestamp=message.record["time"].timestamp(), @@ -160,8 +125,9 @@ class Config: arbitrary_types_allowed = True -# Create a context variable for request_id -request_context: ContextVar[Dict[str, Any]] = ContextVar("request_context", default={}) +# Re-exported from context.py for backward compatibility: +# - request_context: ContextVar for request-scoped data (e.g., request_id) +# - correlation_context: ContextVar for atlan- prefixed headers # Add a Loguru handler for the Python logging system @@ -286,23 +252,41 @@ def __init__(self, logger_name: str) -> None: "TRACING", no=SEVERITY_MAPPING["TRACING"], color="", icon="🔍" ) - # Update 
format string to use the bound logger_name - atlan_format_str_color = "{time:YYYY-MM-DD HH:mm:ss} [{level}] {extra[logger_name]} - {message}" - atlan_format_str_plain = ( - "{time:YYYY-MM-DD HH:mm:ss} [{level}] {extra[logger_name]} - {message}" - ) + # Colorize the logs only if the log level is DEBUG + colorize = LOG_LEVEL == "DEBUG" - colorize = False - format_str = atlan_format_str_plain + def get_log_format(record: Any) -> str: + """Generate log format string with trace_id for correlation. - # Colorize the logs only if the log level is DEBUG - if LOG_LEVEL == "DEBUG": - colorize = True - format_str = atlan_format_str_color + Args: + record: Loguru record dictionary containing log information. + + Returns: + Format string for the log message. + """ + # Build trace_id display string (only trace_id is printed, atlan-* go to OTEL) + trace_id = record["extra"].get("trace_id", "") + record["extra"]["_trace_id_str"] = ( + f" trace_id={trace_id}" if trace_id else "" + ) + + if colorize: + return ( + "{time:YYYY-MM-DD HH:mm:ss} " + "[{level}]" + "{extra[_trace_id_str]} " + "{extra[logger_name]}" + " - {message}\n" + ) + return ( + "{time:YYYY-MM-DD HH:mm:ss} [{level}]" + "{extra[_trace_id_str]} {extra[logger_name]}" + " - {message}\n" + ) self.logger.add( sys.stderr, - format=format_str, + format=get_log_format, level=SEVERITY_MAPPING[LOG_LEVEL], colorize=colorize, ) @@ -538,16 +522,14 @@ def process(self, msg: Any, kwargs: Dict[str, Any]) -> Tuple[Any, Dict[str, Any] - Adds request context if available - Adds workflow context if in a workflow - Adds activity context if in an activity + - Adds correlation context if available """ kwargs["logger_name"] = self.logger_name # Get request context - try: - ctx = request_context.get() - if ctx and "request_id" in ctx: - kwargs["request_id"] = ctx["request_id"] - except Exception: - pass + ctx = request_context.get() + if ctx and "request_id" in ctx: + kwargs["request_id"] = ctx["request_id"] workflow_context = 
get_workflow_context() @@ -569,6 +551,17 @@ def process(self, msg: Any, kwargs: Dict[str, Any]) -> Tuple[Any, Dict[str, Any] except Exception: pass + # Add correlation context (atlan- prefixed keys and trace_id) to kwargs + corr_ctx = correlation_context.get() + if corr_ctx: + # Add trace_id if present (for log format display) + if "trace_id" in corr_ctx and corr_ctx["trace_id"]: + kwargs["trace_id"] = str(corr_ctx["trace_id"]) + # Add atlan-* headers for OTEL + for key, value in corr_ctx.items(): + if key.startswith("atlan-") and value: + kwargs[key] = str(value) + return msg, kwargs def debug(self, msg: str, *args: Any, **kwargs: Any): diff --git a/application_sdk/observability/utils.py b/application_sdk/observability/utils.py index 90bc17ce0..24c82b423 100644 --- a/application_sdk/observability/utils.py +++ b/application_sdk/observability/utils.py @@ -9,10 +9,17 @@ OBSERVABILITY_DIR, TEMPORARY_PATH, ) +from application_sdk.observability.context import correlation_context class WorkflowContext(BaseModel): - """Workflow context.""" + """Workflow context. + + This model supports dynamic correlation context fields (atlan- prefixed) + through Pydantic's extra="allow" configuration. 
+ """ + + model_config = {"extra": "allow"} in_workflow: str = Field(default="false") in_activity: str = Field(default="false") @@ -75,4 +82,12 @@ def get_workflow_context() -> WorkflowContext: except Exception: pass + # Get correlation context from context variable (atlan- prefixed headers) + corr_ctx = correlation_context.get() + if corr_ctx: + # Add all correlation context fields as extra attributes + for key, value in corr_ctx.items(): + if key.startswith("atlan-") and value: + setattr(context, key, str(value)) + return context diff --git a/application_sdk/server/fastapi/middleware/logmiddleware.py b/application_sdk/server/fastapi/middleware/logmiddleware.py index dee1be8d1..2995eac8c 100644 --- a/application_sdk/server/fastapi/middleware/logmiddleware.py +++ b/application_sdk/server/fastapi/middleware/logmiddleware.py @@ -6,7 +6,8 @@ from starlette.responses import Response from starlette.types import ASGIApp -from application_sdk.observability.logger_adaptor import get_logger, request_context +from application_sdk.observability.context import request_context +from application_sdk.observability.logger_adaptor import get_logger logger = get_logger(__name__) diff --git a/tests/unit/interceptors/__init__.py b/tests/unit/interceptors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/interceptors/test_correlation_context.py b/tests/unit/interceptors/test_correlation_context.py new file mode 100644 index 000000000..25936865d --- /dev/null +++ b/tests/unit/interceptors/test_correlation_context.py @@ -0,0 +1,430 @@ +"""Unit tests for the correlation context interceptor. + +Tests the propagation of atlan-* correlation context fields from workflow +arguments to activities via Temporal headers. 
+""" + +from dataclasses import dataclass, field +from typing import Any, Mapping, Sequence +from unittest import mock + +import pytest +from temporalio.api.common.v1 import Payload +from temporalio.converter import default as default_converter + +from application_sdk.interceptors.correlation_context import ( + ATLAN_HEADER_PREFIX, + CorrelationContextActivityInboundInterceptor, + CorrelationContextInterceptor, + CorrelationContextOutboundInterceptor, + CorrelationContextWorkflowInboundInterceptor, +) +from application_sdk.observability.context import correlation_context + + +@dataclass +class MockExecuteWorkflowInput: + """Mock ExecuteWorkflowInput for testing.""" + + args: Sequence[Any] = field(default_factory=list) + headers: Mapping[str, Payload] = field(default_factory=dict) + + +@dataclass +class MockStartActivityInput: + """Mock StartActivityInput for testing.""" + + activity: str = "test_activity" + args: Sequence[Any] = field(default_factory=list) + headers: Mapping[str, Payload] = field(default_factory=dict) + + +@dataclass +class MockExecuteActivityInput: + """Mock ExecuteActivityInput for testing.""" + + fn: Any = None + args: Sequence[Any] = field(default_factory=list) + headers: Mapping[str, Payload] = field(default_factory=dict) + executor: Any = None + + +class TestCorrelationContextWorkflowInboundInterceptor: + """Tests for CorrelationContextWorkflowInboundInterceptor.""" + + @pytest.fixture + def mock_next_inbound(self): + """Create a mock next inbound interceptor.""" + mock_next = mock.AsyncMock() + mock_next.execute_workflow = mock.AsyncMock(return_value="workflow_result") + return mock_next + + @pytest.fixture + def interceptor(self, mock_next_inbound): + """Create the interceptor instance.""" + return CorrelationContextWorkflowInboundInterceptor(mock_next_inbound) + + @pytest.mark.asyncio + async def test_extracts_atlan_fields_from_workflow_args( + self, interceptor, mock_next_inbound + ): + """Test that atlan-* fields are extracted from 
workflow config.""" + workflow_config = { + "workflow_id": "test-workflow-123", + "atlan-ignore": "redshift-test-1.41", + "atlan-argo-workflow-id": "redshift-test-1.09", + "atlan-argo-workflow-node": "redshift-test.1(0).(2).(3)", + "other_field": "should_be_ignored", + } + input_data = MockExecuteWorkflowInput(args=[workflow_config]) + + await interceptor.execute_workflow(input_data) + + # Verify correlation data was extracted + assert interceptor.correlation_data == { + "atlan-ignore": "redshift-test-1.41", + "atlan-argo-workflow-id": "redshift-test-1.09", + "atlan-argo-workflow-node": "redshift-test.1(0).(2).(3)", + } + + @pytest.mark.asyncio + async def test_extracts_trace_id_from_workflow_args( + self, interceptor, mock_next_inbound + ): + """Test that trace_id is extracted from workflow config.""" + workflow_config = { + "workflow_id": "test-workflow-123", + "trace_id": "my-trace-id-abc", + "atlan-ignore": "redshift-test-1.41", + "other_field": "should_be_ignored", + } + input_data = MockExecuteWorkflowInput(args=[workflow_config]) + + await interceptor.execute_workflow(input_data) + + # Verify trace_id was extracted along with atlan-* fields + assert interceptor.correlation_data["trace_id"] == "my-trace-id-abc" + assert interceptor.correlation_data["atlan-ignore"] == "redshift-test-1.41" + + @pytest.mark.asyncio + async def test_handles_workflow_config_without_trace_id( + self, interceptor, mock_next_inbound + ): + """Test workflow config without trace_id is handled gracefully.""" + workflow_config = { + "workflow_id": "test-workflow-123", + "atlan-ignore": "redshift-test-1.41", + } + input_data = MockExecuteWorkflowInput(args=[workflow_config]) + + await interceptor.execute_workflow(input_data) + + # Verify trace_id is not in correlation data + assert "trace_id" not in interceptor.correlation_data + assert interceptor.correlation_data["atlan-ignore"] == "redshift-test-1.41" + + @pytest.mark.asyncio + async def test_handles_empty_workflow_args(self, 
interceptor, mock_next_inbound): + """Test that empty workflow args are handled gracefully.""" + input_data = MockExecuteWorkflowInput(args=[]) + + await interceptor.execute_workflow(input_data) + + assert interceptor.correlation_data == {} + + @pytest.mark.asyncio + async def test_handles_non_dict_workflow_args(self, interceptor, mock_next_inbound): + """Test that non-dict workflow args are handled gracefully.""" + input_data = MockExecuteWorkflowInput(args=["not_a_dict"]) + + await interceptor.execute_workflow(input_data) + + assert interceptor.correlation_data == {} + + @pytest.mark.asyncio + async def test_handles_workflow_config_without_atlan_fields( + self, interceptor, mock_next_inbound + ): + """Test workflow config without any atlan-* fields.""" + workflow_config = { + "workflow_id": "test-workflow-123", + "other_field": "value", + } + input_data = MockExecuteWorkflowInput(args=[workflow_config]) + + await interceptor.execute_workflow(input_data) + + assert interceptor.correlation_data == {} + + @pytest.mark.asyncio + async def test_filters_out_none_values(self, interceptor, mock_next_inbound): + """Test that None values in atlan-* fields are filtered out.""" + workflow_config = { + "atlan-ignore": "valid-value", + "atlan-empty": None, + "atlan-also-empty": "", + } + input_data = MockExecuteWorkflowInput(args=[workflow_config]) + + await interceptor.execute_workflow(input_data) + + assert interceptor.correlation_data == { + "atlan-ignore": "valid-value", + } + + @pytest.mark.asyncio + async def test_sets_correlation_context(self, interceptor, mock_next_inbound): + """Test that correlation context is set for workflow-level logging.""" + workflow_config = { + "atlan-ignore": "test-value", + } + input_data = MockExecuteWorkflowInput(args=[workflow_config]) + + # Reset correlation context before test + correlation_context.set({}) + + await interceptor.execute_workflow(input_data) + + # Verify correlation context was set + ctx = correlation_context.get() + 
assert ctx == {"atlan-ignore": "test-value"} + + +class TestCorrelationContextOutboundInterceptor: + """Tests for CorrelationContextOutboundInterceptor.""" + + @pytest.fixture + def mock_next_outbound(self): + """Create a mock next outbound interceptor.""" + mock_next = mock.MagicMock() + mock_next.start_activity = mock.MagicMock(return_value="activity_handle") + return mock_next + + @pytest.fixture + def mock_inbound(self): + """Create a mock inbound interceptor with correlation data.""" + mock_inbound = mock.MagicMock() + mock_inbound.correlation_data = { + "atlan-ignore": "test-value", + "atlan-argo-workflow-id": "workflow-123", + } + return mock_inbound + + @pytest.fixture + def mock_inbound_with_trace_id(self): + """Create a mock inbound interceptor with correlation data including trace_id.""" + mock_inbound = mock.MagicMock() + mock_inbound.correlation_data = { + "trace_id": "my-trace-id-123", + "atlan-ignore": "test-value", + } + return mock_inbound + + @pytest.fixture + def interceptor(self, mock_next_outbound, mock_inbound): + """Create the outbound interceptor instance.""" + return CorrelationContextOutboundInterceptor(mock_next_outbound, mock_inbound) + + def test_injects_headers_into_activity_calls( + self, interceptor, mock_next_outbound, mock_inbound + ): + """Test that atlan-* headers are injected into activity calls.""" + input_data = MockStartActivityInput(headers={}) + + interceptor.start_activity(input_data) + + # Verify that start_activity was called + mock_next_outbound.start_activity.assert_called_once() + + # Get the modified input + called_input = mock_next_outbound.start_activity.call_args[0][0] + + # Verify headers were injected + payload_converter = default_converter().payload_converter + assert "atlan-ignore" in called_input.headers + assert "atlan-argo-workflow-id" in called_input.headers + + # Verify payload values + ignore_value = payload_converter.from_payload( + called_input.headers["atlan-ignore"], type_hint=str + ) + assert 
ignore_value == "test-value" + + def test_preserves_existing_headers( + self, interceptor, mock_next_outbound, mock_inbound + ): + """Test that existing headers are preserved.""" + payload_converter = default_converter().payload_converter + existing_payload = payload_converter.to_payload("existing-value") + input_data = MockStartActivityInput( + headers={"existing-header": existing_payload} + ) + + interceptor.start_activity(input_data) + + called_input = mock_next_outbound.start_activity.call_args[0][0] + + # Verify existing header is preserved + assert "existing-header" in called_input.headers + # Verify new headers were added + assert "atlan-ignore" in called_input.headers + + def test_handles_empty_correlation_data(self, mock_next_outbound): + """Test that empty correlation data is handled gracefully.""" + mock_inbound = mock.MagicMock() + mock_inbound.correlation_data = {} + interceptor = CorrelationContextOutboundInterceptor( + mock_next_outbound, mock_inbound + ) + + input_data = MockStartActivityInput(headers={}) + + interceptor.start_activity(input_data) + + mock_next_outbound.start_activity.assert_called_once() + + def test_injects_trace_id_into_activity_headers( + self, mock_next_outbound, mock_inbound_with_trace_id + ): + """Test that trace_id is injected into activity headers.""" + interceptor = CorrelationContextOutboundInterceptor( + mock_next_outbound, mock_inbound_with_trace_id + ) + input_data = MockStartActivityInput(headers={}) + + interceptor.start_activity(input_data) + + # Get the modified input + called_input = mock_next_outbound.start_activity.call_args[0][0] + + # Verify trace_id was injected + payload_converter = default_converter().payload_converter + assert "trace_id" in called_input.headers + assert "atlan-ignore" in called_input.headers + + # Verify trace_id payload value + trace_id_value = payload_converter.from_payload( + called_input.headers["trace_id"], type_hint=str + ) + assert trace_id_value == "my-trace-id-123" + + +class 
TestCorrelationContextActivityInboundInterceptor: + """Tests for CorrelationContextActivityInboundInterceptor.""" + + @pytest.fixture + def mock_next_activity(self): + """Create a mock next activity interceptor.""" + mock_next = mock.AsyncMock() + mock_next.execute_activity = mock.AsyncMock(return_value="activity_result") + return mock_next + + @pytest.fixture + def interceptor(self, mock_next_activity): + """Create the activity interceptor instance.""" + return CorrelationContextActivityInboundInterceptor(mock_next_activity) + + @pytest.mark.asyncio + async def test_extracts_headers_and_sets_context( + self, interceptor, mock_next_activity + ): + """Test that atlan-* headers are extracted and correlation context is set.""" + payload_converter = default_converter().payload_converter + + headers = { + "atlan-ignore": payload_converter.to_payload("test-value"), + "atlan-argo-workflow-id": payload_converter.to_payload("workflow-123"), + "other-header": payload_converter.to_payload("should-be-ignored"), + } + input_data = MockExecuteActivityInput(headers=headers) + + # Reset correlation context before test + correlation_context.set({}) + + await interceptor.execute_activity(input_data) + + # Verify correlation context was set with only atlan-* headers + ctx = correlation_context.get() + assert ctx == { + "atlan-ignore": "test-value", + "atlan-argo-workflow-id": "workflow-123", + } + + @pytest.mark.asyncio + async def test_extracts_trace_id_from_headers( + self, interceptor, mock_next_activity + ): + """Test that trace_id is extracted from headers and set in correlation context.""" + payload_converter = default_converter().payload_converter + + headers = { + "trace_id": payload_converter.to_payload("my-trace-id-456"), + "atlan-ignore": payload_converter.to_payload("test-value"), + } + input_data = MockExecuteActivityInput(headers=headers) + + # Reset correlation context before test + correlation_context.set({}) + + await interceptor.execute_activity(input_data) + + # 
Verify trace_id was extracted and set in correlation context + ctx = correlation_context.get() + assert ctx["trace_id"] == "my-trace-id-456" + assert ctx["atlan-ignore"] == "test-value" + + @pytest.mark.asyncio + async def test_handles_empty_headers(self, interceptor, mock_next_activity): + """Test that empty headers are handled gracefully.""" + input_data = MockExecuteActivityInput(headers={}) + + # Reset correlation context before test + correlation_context.set({}) + + await interceptor.execute_activity(input_data) + + # Verify activity was still executed + mock_next_activity.execute_activity.assert_called_once() + + @pytest.mark.asyncio + async def test_calls_next_interceptor(self, interceptor, mock_next_activity): + """Test that the next interceptor is always called.""" + input_data = MockExecuteActivityInput(headers={}) + + result = await interceptor.execute_activity(input_data) + + mock_next_activity.execute_activity.assert_called_once_with(input_data) + assert result == "activity_result" + + +class TestCorrelationContextInterceptor: + """Tests for the main CorrelationContextInterceptor class.""" + + @pytest.fixture + def interceptor(self): + """Create the main interceptor instance.""" + return CorrelationContextInterceptor() + + def test_returns_workflow_interceptor_class(self, interceptor): + """Test that workflow_interceptor_class returns the correct class.""" + mock_input = mock.MagicMock() + + result = interceptor.workflow_interceptor_class(mock_input) + + assert result == CorrelationContextWorkflowInboundInterceptor + + def test_intercept_activity_wraps_next(self, interceptor): + """Test that intercept_activity wraps the next interceptor.""" + mock_next = mock.MagicMock() + + result = interceptor.intercept_activity(mock_next) + + assert isinstance(result, CorrelationContextActivityInboundInterceptor) + + +class TestAtlanHeaderPrefix: + """Tests for the ATLAN_HEADER_PREFIX constant.""" + + def test_prefix_value(self): + """Test that the prefix constant 
class TestCorrelationContext:
    """Tests for how ``process()`` merges context variables into log kwargs."""

    WORKFLOW_NAME_HEADER = "atlan-workflow-name"
    WORKFLOW_NODE_HEADER = "atlan-workflow-node"
    WORKFLOW_NAME = "test-workflow-123"
    WORKFLOW_NODE = "test-workflow-123.node-1"
    WORKFLOW_ID = "test-workflow-123"
    TRACE_ID = "my-trace-id-123"

    # Module whose context variables are replaced with mocks in these tests.
    _ADAPTOR_MODULE = "application_sdk.observability.logger_adaptor"

    def _process_with_patched_context(
        self, context_value, context_name="correlation_context"
    ):
        """Run ``process("Test message", {})`` with one context variable mocked.

        Args:
            context_value: Value the mocked ``<context_name>.get()`` returns.
            context_name: Attribute of the logger_adaptor module to patch.

        Returns:
            The ``(msg, kwargs)`` pair produced by ``logger_adapter.process``.
        """
        with create_logger_adapter() as logger_adapter:
            with mock.patch(
                f"{self._ADAPTOR_MODULE}.{context_name}"
            ) as mocked_context:
                mocked_context.get.return_value = context_value
                return logger_adapter.process("Test message", {})

    def test_process_with_correlation_context(self):
        """Test process() when correlation context is set."""
        msg, kwargs = self._process_with_patched_context(
            {
                self.WORKFLOW_NAME_HEADER: self.WORKFLOW_NAME,
                self.WORKFLOW_NODE_HEADER: self.WORKFLOW_NODE,
            }
        )

        assert kwargs["logger_name"] == "test_logger"
        assert msg == "Test message"
        # The point of this test: the correlation headers must reach the kwargs.
        assert kwargs[self.WORKFLOW_NAME_HEADER] == self.WORKFLOW_NAME
        assert self.WORKFLOW_NODE_HEADER in kwargs

    def test_process_without_correlation_context(self):
        """Test process() when correlation context is empty."""
        msg, kwargs = self._process_with_patched_context({})

        assert kwargs["logger_name"] == "test_logger"
        assert msg == "Test message"
        # An empty context must not contribute any correlation fields.
        assert self.WORKFLOW_NAME_HEADER not in kwargs
        assert "trace_id" not in kwargs

    def test_process_extracts_trace_id_from_correlation_context(self):
        """Test process() extracts trace_id from correlation context."""
        _, kwargs = self._process_with_patched_context(
            {
                "trace_id": self.TRACE_ID,
                self.WORKFLOW_NAME_HEADER: self.WORKFLOW_NAME,
            }
        )

        assert kwargs["trace_id"] == self.TRACE_ID
        assert kwargs[self.WORKFLOW_NAME_HEADER] == self.WORKFLOW_NAME

    def test_process_without_trace_id(self):
        """Test process() when trace_id is not in correlation context."""
        _, kwargs = self._process_with_patched_context(
            {self.WORKFLOW_NAME_HEADER: self.WORKFLOW_NAME}
        )

        assert "trace_id" not in kwargs
        assert kwargs[self.WORKFLOW_NAME_HEADER] == self.WORKFLOW_NAME

    def test_process_handles_none_correlation_context(self):
        """Test process() gracefully handles None when correlation context is unset."""
        # The mocked get() returns None, standing in for an unset context;
        # process() must not raise.
        msg, kwargs = self._process_with_patched_context(None)

        # Verify basic functionality still works
        assert msg == "Test message"
        assert kwargs["logger_name"] == "test_logger"
        # Verify no correlation context data was added
        assert self.WORKFLOW_NAME_HEADER not in kwargs
        assert "trace_id" not in kwargs

    def test_process_handles_none_request_context(self):
        """Test process() gracefully handles None when request context is unset."""
        # Same None scenario, but for the request context variable.
        msg, kwargs = self._process_with_patched_context(
            None, context_name="request_context"
        )

        # Verify basic functionality still works
        assert msg == "Test message"
        assert kwargs["logger_name"] == "test_logger"
        # Verify no request context data was added
        assert "request_id" not in kwargs
"application_sdk.observability.logger_adaptor.request_context" + ) as mock_req_context: + # ContextVar returns None when not set + mock_req_context.get.return_value = None + + # Should not raise exception - None is handled gracefully + msg, kwargs = logger_adapter.process("Test message", {}) + + # Verify basic functionality still works + assert msg == "Test message" + assert kwargs["logger_name"] == "test_logger" + # Verify no request context data was added + assert "request_id" not in kwargs + + +class TestLogFormatFunction: + """Tests for the conditional log format function with trace_id.""" + + TRACE_ID = "my-workflow-trace-123" + + def test_format_includes_trace_id_when_present(self): + """Format should include trace_id when present.""" + record = { + "extra": { + "logger_name": "test_logger", + "trace_id": self.TRACE_ID, + } + } + + # Build trace_id display string (mimics logger_adaptor logic) + trace_id = record["extra"].get("trace_id", "") + trace_id_str = f" trace_id={trace_id}" if trace_id else "" + + assert "trace_id=" in trace_id_str + assert self.TRACE_ID in trace_id_str + + def test_format_excludes_trace_id_when_missing(self): + """Format should exclude trace_id when not present.""" + record = {"extra": {"logger_name": "test_logger"}} + + # Build trace_id display string + trace_id = record["extra"].get("trace_id", "") + trace_id_str = f" trace_id={trace_id}" if trace_id else "" + + assert trace_id_str == "" + + def test_format_excludes_trace_id_when_empty(self): + """Format should exclude trace_id when empty string.""" + record = {"extra": {"logger_name": "test_logger", "trace_id": ""}} + + # Build trace_id display string + trace_id = record["extra"].get("trace_id", "") + trace_id_str = f" trace_id={trace_id}" if trace_id else "" + + assert trace_id_str == "" + + def test_format_trace_id_does_not_include_atlan_headers(self): + """Format should only include trace_id, not atlan-* headers in display.""" + record = { + "extra": { + "logger_name": 
"test_logger", + "trace_id": self.TRACE_ID, + "atlan-tenant": "test-tenant", + "atlan-user": "test-user", + } + } + + # Build trace_id display string (only trace_id, not atlan-*) + trace_id = record["extra"].get("trace_id", "") + trace_id_str = f" trace_id={trace_id}" if trace_id else "" + + assert "trace_id=" in trace_id_str + assert self.TRACE_ID in trace_id_str + # atlan-* headers should NOT be in the display string + assert "atlan-tenant" not in trace_id_str + assert "atlan-user" not in trace_id_str + + +class TestCorrelationContextIntegration: + """Tests for correlation context combined with workflow/activity context.""" + + WORKFLOW_NAME_HEADER = "atlan-workflow-name" + WORKFLOW_NODE_HEADER = "atlan-workflow-node" + WORKFLOW_NAME = "test-workflow-123" + WORKFLOW_NODE = "test-workflow-123.node-1" + WORKFLOW_ID = "test-workflow-123" + + def test_correlation_context_with_workflow_context(self): + """Correlation context should work alongside workflow context.""" + with create_logger_adapter() as logger_adapter: + with mock.patch("temporalio.workflow.info") as mock_workflow_info: + with mock.patch( + "application_sdk.observability.logger_adaptor.correlation_context" + ) as mock_corr_context: + workflow_info = mock.Mock( + workflow_id=self.WORKFLOW_ID, + run_id="019b04bd-ac10-7989-87d7-06427dc0616c", + workflow_type="RedshiftMetadataExtractionWorkflow", + namespace="default", + task_queue="atlan-redshift-local", + attempt=1, + ) + mock_workflow_info.return_value = workflow_info + mock_corr_context.get.return_value = { + self.WORKFLOW_NAME_HEADER: self.WORKFLOW_NAME, + self.WORKFLOW_NODE_HEADER: self.WORKFLOW_NODE, + } + + msg, kwargs = logger_adapter.process("Test message", {}) + + assert kwargs["workflow_id"] == self.WORKFLOW_ID + assert ( + kwargs["workflow_run_id"] + == "019b04bd-ac10-7989-87d7-06427dc0616c" + ) + assert self.WORKFLOW_NAME_HEADER in kwargs + assert self.WORKFLOW_NODE_HEADER in kwargs + + def test_correlation_context_with_activity_context(self): 
+ """Correlation context should work alongside activity context.""" + with create_logger_adapter() as logger_adapter: + with mock.patch("temporalio.activity.info") as mock_activity_info: + with mock.patch( + "application_sdk.observability.logger_adaptor.correlation_context" + ) as mock_corr_context: + activity_info = mock.Mock( + workflow_id=self.WORKFLOW_ID, + workflow_run_id="019b04bd-ac10-7989-87d7-06427dc0616c", + activity_id="fetch_databases", + activity_type="fetch_databases", + task_queue="atlan-redshift-local", + attempt=1, + ) + mock_activity_info.return_value = activity_info + mock_corr_context.get.return_value = { + self.WORKFLOW_NAME_HEADER: self.WORKFLOW_NAME, + self.WORKFLOW_NODE_HEADER: self.WORKFLOW_NODE, + } + + msg, kwargs = logger_adapter.process("Test message", {}) + + assert kwargs["activity_id"] == "fetch_databases" + assert kwargs["workflow_id"] == self.WORKFLOW_ID + assert self.WORKFLOW_NAME_HEADER in kwargs + assert self.WORKFLOW_NODE_HEADER in kwargs