Skip to content

Commit 44ffe72

Browse files
authored
correlate evaluation logs with traces (#282)
* correlate evaluation logs with traces * Fix rollout logging filter fallbacks (#283)
1 parent e769b6a commit 44ffe72

File tree

7 files changed

+387
-16
lines changed

7 files changed

+387
-16
lines changed

eval_protocol/log_utils/elasticsearch_direct_http_handler.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,17 @@ def emit(self, record: logging.LogRecord) -> None:
6060
if status_info:
6161
data.update(status_info)
6262

63+
# Optional correlation enrichment
64+
experiment_id = getattr(record, "experiment_id", None)
65+
if experiment_id is not None:
66+
data["experiment_id"] = experiment_id
67+
run_id = getattr(record, "run_id", None)
68+
if run_id is not None:
69+
data["run_id"] = run_id
70+
rollout_ids = getattr(record, "rollout_ids", None)
71+
if rollout_ids is not None:
72+
data["rollout_ids"] = rollout_ids
73+
6374
# Schedule the HTTP request to run asynchronously
6475
self._schedule_async_send(data, record)
6576
except Exception as e:

eval_protocol/log_utils/fireworks_tracing_http_handler.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,36 @@ def _build_payload(self, record: logging.LogRecord, rollout_id: str) -> Dict[str
4646
tags.append(f"experiment_id:{cast(Any, getattr(record, 'experiment_id'))}")
4747
if hasattr(record, "run_id") and cast(Any, getattr(record, "run_id")):
4848
tags.append(f"run_id:{cast(Any, getattr(record, 'run_id'))}")
49+
# Groupwise list of rollout_ids
50+
if hasattr(record, "rollout_ids") and cast(Any, getattr(record, "rollout_ids")):
51+
try:
52+
for rid in cast(List[str], getattr(record, "rollout_ids")):
53+
tags.append(f"rollout_id:{rid}")
54+
except Exception:
55+
pass
4956
program = cast(Optional[str], getattr(record, "program", None)) or "eval_protocol"
5057
status_val = cast(Any, getattr(record, "status", None))
5158
status = status_val if isinstance(status_val, str) else None
59+
# Capture optional structured status fields if present
60+
metadata: Dict[str, Any] = {}
61+
status_code = cast(Any, getattr(record, "status_code", None))
62+
if isinstance(status_code, int):
63+
metadata["status_code"] = status_code
64+
status_message = cast(Any, getattr(record, "status_message", None))
65+
if isinstance(status_message, str):
66+
metadata["status_message"] = status_message
67+
status_details = getattr(record, "status_details", None)
68+
if status_details is not None:
69+
metadata["status_details"] = status_details
70+
extra_metadata = cast(Any, getattr(record, "metadata", None))
71+
if isinstance(extra_metadata, dict):
72+
metadata.update(extra_metadata)
5273
return {
5374
"program": program,
5475
"status": status,
5576
"message": message,
5677
"tags": tags,
57-
"metadata": cast(Any, getattr(record, "metadata", None)),
78+
"metadata": metadata or None,
5879
"extras": {
5980
"logger_name": record.name,
6081
"level": record.levelname,

eval_protocol/log_utils/init.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import logging
2+
import os
3+
from typing import Optional
4+
5+
from eval_protocol.log_utils.fireworks_tracing_http_handler import (
6+
FireworksTracingHttpHandler,
7+
)
8+
from eval_protocol.log_utils.elasticsearch_direct_http_handler import (
9+
ElasticsearchDirectHttpHandler,
10+
)
11+
from eval_protocol.log_utils.rollout_context import ContextRolloutIdFilter
12+
from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig
13+
14+
15+
_INITIALIZED = False
16+
17+
18+
def _get_env(name: str) -> Optional[str]:
19+
val = os.getenv(name)
20+
return val if val and val.strip() else None
21+
22+
23+
def init_external_logging_from_env() -> None:
24+
"""
25+
Initialize external logging sinks (Fireworks tracing, optional Elasticsearch) from env vars.
26+
27+
Idempotent: safe to call multiple times.
28+
29+
Environment variables:
30+
- FW_TRACING_GATEWAY_BASE_URL: enable Fireworks tracing handler when set
31+
- EP_ELASTICSEARCH_URL, EP_ELASTICSEARCH_API_KEY, EP_ELASTICSEARCH_INDEX: enable ES when all set
32+
"""
33+
global _INITIALIZED
34+
if _INITIALIZED:
35+
return
36+
37+
root_logger = logging.getLogger()
38+
39+
# Ensure we do not add duplicate handlers if already present
40+
existing_handler_types = {type(h).__name__ for h in root_logger.handlers}
41+
42+
# Fireworks tracing
43+
fw_url = _get_env("FW_TRACING_GATEWAY_BASE_URL")
44+
if fw_url and "FireworksTracingHttpHandler" not in existing_handler_types:
45+
fw_handler = FireworksTracingHttpHandler(gateway_base_url=fw_url)
46+
fw_handler.setLevel(logging.INFO)
47+
fw_handler.addFilter(ContextRolloutIdFilter())
48+
root_logger.addHandler(fw_handler)
49+
50+
# Elasticsearch
51+
es_url = _get_env("EP_ELASTICSEARCH_URL")
52+
es_api_key = _get_env("EP_ELASTICSEARCH_API_KEY")
53+
es_index = _get_env("EP_ELASTICSEARCH_INDEX")
54+
if es_url and es_api_key and es_index and "ElasticsearchDirectHttpHandler" not in existing_handler_types:
55+
es_config = ElasticsearchConfig(url=es_url, api_key=es_api_key, index_name=es_index)
56+
es_handler = ElasticsearchDirectHttpHandler(elasticsearch_config=es_config)
57+
es_handler.setLevel(logging.INFO)
58+
es_handler.addFilter(ContextRolloutIdFilter())
59+
root_logger.addHandler(es_handler)
60+
61+
_INITIALIZED = True
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import logging
2+
import os
3+
from contextlib import asynccontextmanager
4+
from typing import List, Optional
5+
6+
import contextvars
7+
8+
9+
# Context variables used to correlate logs with rollouts under concurrency
10+
current_rollout_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar("ep_rollout_id", default=None)
11+
current_rollout_ids: contextvars.ContextVar[Optional[List[str]]] = contextvars.ContextVar(
12+
"ep_rollout_ids", default=None
13+
)
14+
current_experiment_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar("ep_experiment_id", default=None)
15+
current_run_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar("ep_run_id", default=None)
16+
17+
18+
class ContextRolloutIdFilter(logging.Filter):
19+
"""
20+
Logging filter that injects correlation fields into a LogRecord from ContextVars.
21+
22+
The filter is intended to be attached ONLY to external sink handlers (e.g.,
23+
Fireworks or Elasticsearch). If there is no active rollout context, it drops
24+
the record for that handler to avoid shipping uncorrelated logs.
25+
"""
26+
27+
def filter(self, record: logging.LogRecord) -> bool: # type: ignore[override]
28+
rollout_id = current_rollout_id.get()
29+
if not rollout_id:
30+
# Allow explicit rollout IDs on the record or via environment fallback.
31+
rollout_id = getattr(record, "rollout_id", None) or os.getenv("EP_ROLLOUT_ID")
32+
if not rollout_id:
33+
# No correlation context → do not emit to external sink
34+
return False
35+
36+
# Inject primary correlation fields
37+
setattr(record, "rollout_id", rollout_id)
38+
39+
rollout_ids = current_rollout_ids.get()
40+
if rollout_ids:
41+
setattr(record, "rollout_ids", rollout_ids)
42+
43+
experiment_id = current_experiment_id.get()
44+
if experiment_id:
45+
setattr(record, "experiment_id", experiment_id)
46+
47+
run_id = current_run_id.get()
48+
if run_id:
49+
setattr(record, "run_id", run_id)
50+
51+
return True
52+
53+
54+
@asynccontextmanager
55+
async def rollout_logging_context(
56+
rollout_id: str,
57+
*,
58+
experiment_id: Optional[str] = None,
59+
run_id: Optional[str] = None,
60+
rollout_ids: Optional[List[str]] = None,
61+
):
62+
"""
63+
Async context manager to set correlation ContextVars for the current task.
64+
65+
Args:
66+
rollout_id: Primary rollout identifier for correlation.
67+
experiment_id: Optional experiment ID for tagging.
68+
run_id: Optional run ID for tagging.
69+
rollout_ids: Optional list of related rollout IDs (e.g., groupwise mode).
70+
"""
71+
t_rollout = current_rollout_id.set(rollout_id)
72+
t_rollouts = current_rollout_ids.set(rollout_ids) if rollout_ids is not None else None
73+
t_experiment = current_experiment_id.set(experiment_id) if experiment_id is not None else None
74+
t_run = current_run_id.set(run_id) if run_id is not None else None
75+
try:
76+
yield
77+
finally:
78+
current_rollout_id.reset(t_rollout)
79+
if t_rollouts is not None:
80+
current_rollout_ids.reset(t_rollouts)
81+
if t_experiment is not None:
82+
current_experiment_id.reset(t_experiment)
83+
if t_run is not None:
84+
current_run_id.reset(t_run)

eval_protocol/pytest/evaluation_test.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363
run_tasks_with_run_progress,
6464
)
6565
from eval_protocol.utils.show_results_url import store_local_ui_results_url, generate_invocation_filter_url
66+
from eval_protocol.log_utils.init import init_external_logging_from_env
67+
from eval_protocol.log_utils.rollout_context import rollout_logging_context
6668
from eval_protocol.utils.browser_utils import is_logs_server_running, open_browser_tab
6769

6870
from ..common_utils import load_jsonl
@@ -254,6 +256,9 @@ def create_wrapper_with_signature() -> Callable[[], None]:
254256
async def wrapper_body(**kwargs: Unpack[ParameterizedTestKwargs]) -> None:
255257
nonlocal browser_opened_for_invocation
256258

259+
# Initialize external logging sinks (Fireworks/ES) from env (idempotent)
260+
init_external_logging_from_env()
261+
257262
# Store URL for viewing results (after all postprocessing is complete)
258263
store_local_ui_results_url(invocation_id)
259264

@@ -419,11 +424,16 @@ async def _execute_pointwise_eval_with_semaphore(
419424
) -> EvaluationRow:
420425
async with semaphore:
421426
evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {}
422-
result = await execute_pytest(
423-
test_func,
424-
processed_row=row,
425-
evaluation_test_kwargs=evaluation_test_kwargs,
426-
)
427+
async with rollout_logging_context(
428+
row.execution_metadata.rollout_id or "",
429+
experiment_id=experiment_id,
430+
run_id=run_id,
431+
):
432+
result = await execute_pytest(
433+
test_func,
434+
processed_row=row,
435+
evaluation_test_kwargs=evaluation_test_kwargs,
436+
)
427437
if not isinstance(result, EvaluationRow):
428438
raise ValueError(
429439
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
@@ -435,11 +445,21 @@ async def _execute_groupwise_eval_with_semaphore(
435445
) -> list[EvaluationRow]:
436446
async with semaphore:
437447
evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {}
438-
results = await execute_pytest(
439-
test_func,
440-
processed_dataset=rows,
441-
evaluation_test_kwargs=evaluation_test_kwargs,
442-
)
448+
primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None
449+
group_rollout_ids = [
450+
r.execution_metadata.rollout_id for r in rows if r.execution_metadata.rollout_id
451+
]
452+
async with rollout_logging_context(
453+
primary_rollout_id or "",
454+
experiment_id=experiment_id,
455+
run_id=run_id,
456+
rollout_ids=group_rollout_ids or None,
457+
):
458+
results = await execute_pytest(
459+
test_func,
460+
processed_dataset=rows,
461+
evaluation_test_kwargs=evaluation_test_kwargs,
462+
)
443463
if not isinstance(results, list):
444464
raise ValueError(
445465
f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
@@ -516,11 +536,25 @@ async def _collect_result(config, lst):
516536
input_dataset.append(row)
517537
# NOTE: we will still evaluate errored rows (give users control over this)
518538
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
519-
results = await execute_pytest(
520-
test_func,
521-
processed_dataset=input_dataset,
522-
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
539+
primary_rollout_id = (
540+
input_dataset[0].execution_metadata.rollout_id if input_dataset else None
523541
)
542+
group_rollout_ids = [
543+
r.execution_metadata.rollout_id
544+
for r in input_dataset
545+
if r.execution_metadata.rollout_id
546+
]
547+
async with rollout_logging_context(
548+
primary_rollout_id or "",
549+
experiment_id=experiment_id,
550+
run_id=run_id,
551+
rollout_ids=group_rollout_ids or None,
552+
):
553+
results = await execute_pytest(
554+
test_func,
555+
processed_dataset=input_dataset,
556+
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
557+
)
524558
if (
525559
results is None
526560
or not isinstance(results, list)

tests/chinook/pydantic/test_pydantic_complex_queries_responses.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Awaitable, Callable
2+
import logging
23
import os
34
from typing_extensions import cast
45
from pydantic_ai import Agent
@@ -36,7 +37,7 @@ def agent_factory(config: RolloutProcessorConfig) -> Agent:
3637
input_rows=[collect_dataset()],
3738
completion_params=[
3839
{
39-
"model": "gpt-5",
40+
"model": "gpt-5-nano",
4041
},
4142
],
4243
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
@@ -45,6 +46,19 @@ async def test_pydantic_complex_queries_responses(row: EvaluationRow) -> Evaluat
4546
"""
4647
Evaluation of complex queries for the Chinook database using PydanticAI
4748
"""
49+
50+
logger = logging.getLogger("tests.chinook.pydantic.complex_queries_responses")
51+
logger.info(
52+
"Starting chinook responses evaluation",
53+
extra={"status": {"code": 101, "message": "RUNNING"}},
54+
)
55+
4856
casted_evaluation_test = cast(Callable[[EvaluationRow], Awaitable[EvaluationRow]], eval)
4957
evaluated_row = await casted_evaluation_test(row)
58+
59+
logger.info(
60+
"Finished chinook responses evaluation",
61+
extra={"status": {"code": 100, "message": "FINISHED"}},
62+
)
63+
5064
return evaluated_row

0 commit comments

Comments
 (0)