|
| 1 | +"""Weave (Weights & Biases) adapter for Eval Protocol. |
| 2 | +
|
| 3 | +This adapter fetches recent root traces from Weave Trace API and converts them |
| 4 | +to `EvaluationRow` format for use in evaluation pipelines. It is intentionally |
| 5 | +minimal and depends only on requests. |
| 6 | +""" |
| 7 | + |
| 8 | +from __future__ import annotations |
| 9 | + |
| 10 | +from typing import Any, Dict, List, Optional |
| 11 | +import os |
| 12 | +import requests |
| 13 | + |
| 14 | +from eval_protocol.models import EvaluationRow, InputMetadata, Message, ExecutionMetadata |
| 15 | +from .base import BaseAdapter |
| 16 | + |
| 17 | + |
def _extract_messages_from_trace(trace: Dict[str, Any], include_tool_calls: bool = True) -> List[Message]:
    """Build the conversation held in a trace as a list of ``Message`` objects.

    Output messages are taken from ``trace["output"]["messages"]`` when that is
    a non-empty list; otherwise the message of the first entry in
    ``trace["output"]["choices"]`` is used. Messages found in
    ``trace["inputs"]["messages"]`` are placed in front of them, unless the
    explicit output messages already start with the same ``(role, content)``
    sequence, in which case they are not repeated.

    Args:
        trace: A single trace record (a dict) as returned by the trace API.
        include_tool_calls: When False, ``tool_calls`` is set to ``None`` on
            every produced message.

    Returns:
        The input messages followed by the output messages; an empty list when
        the trace carries neither.
    """

    def _dict_items(value: Any) -> List[Dict[str, Any]]:
        # Only dict entries can be read with .get(); anything else is skipped
        # rather than raising AttributeError on an unexpected payload shape.
        if not isinstance(value, list):
            return []
        return [item for item in value if isinstance(item, dict)]

    def _to_message(raw: Dict[str, Any]) -> Message:
        # One conversion path for every source, so tool-call fields are kept
        # consistently for input, output and "choices" messages alike.
        return Message(
            role=raw.get("role"),
            content=raw.get("content"),
            tool_calls=raw.get("tool_calls") if include_tool_calls else None,
            tool_call_id=raw.get("tool_call_id"),
            name=raw.get("name"),
        )

    def _key(raw: Dict[str, Any]) -> Any:
        return (raw.get("role"), raw.get("content"))

    # "output" / "inputs" may be missing, None, or a non-dict value.
    output = trace.get("output")
    if not isinstance(output, dict):
        output = {}
    inputs = trace.get("inputs")
    if not isinstance(inputs, dict):
        inputs = {}

    # Prefer explicit output messages if provided.
    out_raw = _dict_items(output.get("messages"))
    has_explicit_output = bool(out_raw)

    # Otherwise fall back to the message of the first choice.
    if not out_raw:
        choices = output.get("choices")
        if isinstance(choices, list) and choices:
            first_choice = choices[0]
            final = first_choice.get("message") if isinstance(first_choice, dict) else None
            if isinstance(final, dict) and final:
                out_raw = [final]

    # Prepend input messages unless the explicit output transcript already
    # begins with them (which would duplicate the prompt).
    in_raw = _dict_items(inputs.get("messages"))
    if (
        has_explicit_output
        and in_raw
        and len(out_raw) >= len(in_raw)
        and all(_key(a) == _key(b) for a, b in zip(in_raw, out_raw))
    ):
        in_raw = []

    return [_to_message(raw) for raw in in_raw + out_raw]
| 52 | + |
| 53 | + |
def _convert_trace_to_evaluation_row(
    trace: Dict[str, Any], include_tool_calls: bool = True
) -> Optional[EvaluationRow]:
    """Convert one trace record into an ``EvaluationRow``.

    Args:
        trace: A single trace record (a dict) as returned by the trace API.
        include_tool_calls: Forwarded to ``_extract_messages_from_trace``; when
            False the row's ``tools`` is also left as ``None``.

    Returns:
        An ``EvaluationRow`` holding the extracted messages, the tools found in
        ``trace["inputs"]["tools"]`` (if any) and the collected metadata, or
        ``None`` when no messages could be extracted from the trace.
    """
    messages = _extract_messages_from_trace(trace, include_tool_calls=include_tool_calls)
    if not messages:
        return None

    # "inputs" / "output" may be missing, None, or a non-dict value; normalise
    # once so every lookup below is safe.
    inputs = trace.get("inputs")
    if not isinstance(inputs, dict):
        inputs = {}
    output = trace.get("output")
    if not isinstance(output, dict):
        output = {}

    # Provider-native IDs for UI joinability
    session_data = {
        "weave_trace_id": trace.get("id"),
        "weave_project_id": trace.get("project_id"),
    }

    # Optional EP identifiers (if present in provider payload). Output metadata
    # is applied last, so it wins over input metadata on key collisions.
    metadata: Dict[str, Any] = {}
    for candidate in (inputs.get("metadata"), output.get("metadata")):
        if isinstance(candidate, dict):
            metadata.update(candidate)

    input_metadata = InputMetadata(row_id=metadata.get("row_id"), session_data=session_data)

    # Preserve default factory behavior by only setting provided fields
    exec_kwargs: Dict[str, Any] = {
        key: metadata[key]
        for key in ("invocation_id", "experiment_id", "rollout_id", "run_id")
        if metadata.get(key) is not None
    }
    execution_metadata = ExecutionMetadata(**exec_kwargs)

    # Capture tools if provider exposes them (only inputs are consulted)
    tools = inputs.get("tools") if include_tool_calls else None

    return EvaluationRow(
        messages=messages, tools=tools, input_metadata=input_metadata, execution_metadata=execution_metadata
    )
| 90 | + |
| 91 | + |
class WeaveAdapter(BaseAdapter):
    """Adapter to pull data from Weave Trace API and convert to EvaluationRow format."""

    def __init__(
        self, base_url: Optional[str] = None, api_token: Optional[str] = None, project_id: Optional[str] = None
    ):
        """Resolve connection settings from arguments, falling back to the environment.

        Args:
            base_url: Trace API base URL. Falls back to ``WEAVE_TRACE_BASE_URL``
                and then to ``https://trace.wandb.ai``.
            api_token: API key. Falls back to ``WANDB_API_KEY``.
            project_id: Project in the form ``"<entity>/<project>"``. Falls back
                to ``WANDB_ENTITY`` and ``WANDB_PROJECT``.

        Raises:
            ValueError: If the API key is missing, or no ``"<entity>/<project>"``
                project id can be determined.
        """
        # `or` chain so that an empty WEAVE_TRACE_BASE_URL also gets the default.
        self.base_url = base_url or os.getenv("WEAVE_TRACE_BASE_URL") or "https://trace.wandb.ai"
        self.api_token = api_token or os.getenv("WANDB_API_KEY")

        # project_id is in form "<entity>/<project>"
        if not project_id:
            entity = os.getenv("WANDB_ENTITY")
            project = os.getenv("WANDB_PROJECT")
            # Build the id only when both parts exist; formatting missing
            # variables would give "None/None", which passes the "/" check.
            project_id = f"{entity}/{project}" if entity and project else None
        self.project_id = project_id

        if not self.api_token or not self.project_id or "/" not in self.project_id:
            raise ValueError("Missing Weave credentials or project (WANDB_API_KEY and WANDB_ENTITY/WANDB_PROJECT)")

    def _fetch_traces(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Fetch the most recently started root traces of the configured project.

        Args:
            limit: Maximum number of traces requested from the API.

        Returns:
            The trace records as dicts, newest first as requested via ``sort_by``.

        Raises:
            requests.HTTPError: If the API responds with an error status.
        """
        import json  # local import: used only to decode the response body

        # Tolerate a base URL configured with a trailing slash.
        url = f"{self.base_url.rstrip('/')}/calls/stream_query"
        payload = {
            "project_id": self.project_id,
            "filter": {"trace_roots_only": True},
            "limit": limit,
            "offset": 0,
            "sort_by": [{"field": "started_at", "direction": "desc"}],
            "include_feedback": False,
        }
        # NOTE(review): confirm the target deployment accepts Bearer auth; the
        # Weave service API documentation should be checked for the expected scheme.
        headers = {"Authorization": f"Bearer {self.api_token}", "Content-Type": "application/json"}
        resp = requests.post(url, json=payload, headers=headers, timeout=30)
        resp.raise_for_status()

        text = (resp.text or "").strip()
        if not text:
            return []

        try:
            body = json.loads(text)
        except ValueError:
            # Not a single JSON document: treat the body as JSON Lines, one
            # record per non-empty line (streaming response format).
            body = [json.loads(line) for line in text.splitlines() if line.strip()]

        if isinstance(body, dict):
            # Either a {"data": [...]} envelope or a lone streamed record.
            body = body["data"] if "data" in body else [body]
        if not isinstance(body, list):
            return []
        # Downstream conversion reads records with .get(), so keep dicts only.
        return [record for record in body if isinstance(record, dict)]

    def get_evaluation_rows(self, *args, **kwargs) -> List[EvaluationRow]:
        """Fetch recent traces and convert them to evaluation rows.

        Keyword Args:
            limit: Maximum number of traces to fetch (default 100).
            include_tool_calls: Whether tool calls and tools are kept on the
                resulting rows (default True).

        Returns:
            One ``EvaluationRow`` per trace that yielded at least one message;
            traces without messages are skipped.
        """
        limit = kwargs.get("limit", 100)
        include_tool_calls = kwargs.get("include_tool_calls", True)
        rows: List[EvaluationRow] = []
        for trace in self._fetch_traces(limit=limit):
            row = _convert_trace_to_evaluation_row(trace, include_tool_calls=include_tool_calls)
            if row:
                rows.append(row)
        return rows
0 commit comments