Skip to content

Commit acb168f

Browse files
authored
weave integration (#202)
more fixes
1 parent b120611 commit acb168f

File tree

9 files changed

+314
-0
lines changed

9 files changed

+314
-0
lines changed

eval_protocol/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@
6464
except ImportError:
6565
LangSmithAdapter = None
6666

67+
68+
try:
69+
from .adapters import WeaveAdapter
70+
except ImportError:
71+
WeaveAdapter = None
72+
6773
warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
6874

6975
__all__ = [

eval_protocol/adapters/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,10 @@
9292
__all__.extend(["LangSmithAdapter"])
9393
except ImportError:
9494
pass
95+
96+
try:
97+
from .weave import WeaveAdapter
98+
99+
__all__.extend(["WeaveAdapter"])
100+
except ImportError:
101+
pass

eval_protocol/adapters/weave.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""Weave (Weights & Biases) adapter for Eval Protocol.
2+
3+
This adapter fetches recent root traces from Weave Trace API and converts them
4+
to `EvaluationRow` format for use in evaluation pipelines. It is intentionally
5+
minimal and depends only on requests.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from typing import Any, Dict, List, Optional
11+
import os
12+
import requests
13+
14+
from eval_protocol.models import EvaluationRow, InputMetadata, Message, ExecutionMetadata
15+
from .base import BaseAdapter
16+
17+
18+
def _extract_messages_from_trace(trace: Dict[str, Any], include_tool_calls: bool = True) -> List[Message]:
    """Build the conversation carried by a Weave trace as ``Message`` objects.

    Input messages (``trace["inputs"]["messages"]``) come first, followed by
    the output messages (``trace["output"]["messages"]``) or, when there are
    none, the message of the first entry in ``trace["output"]["choices"]``.

    Args:
        trace: Raw trace payload as returned by the provider.
        include_tool_calls: When False, ``tool_calls`` on output messages are
            dropped (set to ``None``).

    Returns:
        Messages in conversation order; an empty list if the trace carries
        no recognisable messages.
    """
    # Nothing guarantees that ``inputs``/``output`` are mappings (a traced call
    # may return e.g. a plain string), so any other type is treated as
    # carrying no messages instead of raising AttributeError on ``.get``.
    output = trace.get("output")
    if not isinstance(output, dict):
        output = {}
    inputs = trace.get("inputs")
    if not isinstance(inputs, dict):
        inputs = {}

    messages: List[Message] = []

    # Prefer explicit output messages if provided; skip malformed entries.
    raw_out = output.get("messages")
    out_dicts = [m for m in raw_out if isinstance(m, dict)] if isinstance(raw_out, list) else []
    for m in out_dicts:
        messages.append(
            Message(
                role=m.get("role"),
                content=m.get("content"),
                tool_calls=m.get("tool_calls") if include_tool_calls else None,
                tool_call_id=m.get("tool_call_id"),
                name=m.get("name"),
            )
        )

    # If no explicit output messages, fall back to the first choice's message.
    if not messages:
        choices = output.get("choices")
        if isinstance(choices, list) and choices and isinstance(choices[0], dict):
            msg = choices[0].get("message")
            if isinstance(msg, dict) and msg:
                messages.append(Message(role=msg.get("role"), content=msg.get("content")))

    # Prepend input messages unless the output messages already start with
    # them (same role/content sequence); otherwise the history would be
    # duplicated when the provider echoes the full conversation in the output.
    raw_in = inputs.get("messages")
    if isinstance(raw_in, list):
        in_dicts = [m for m in raw_in if isinstance(m, dict)]
        in_keys = [(m.get("role"), m.get("content")) for m in in_dicts]
        out_keys = [(m.get("role"), m.get("content")) for m in out_dicts[: len(in_dicts)]]
        if in_keys != out_keys:
            prefixed = [Message(role=role, content=content) for role, content in in_keys]
            messages = prefixed + messages

    return messages
52+
53+
54+
def _convert_trace_to_evaluation_row(
    trace: Dict[str, Any], include_tool_calls: bool = True
) -> Optional[EvaluationRow]:
    """Convert one raw Weave trace into an ``EvaluationRow``.

    Args:
        trace: Raw trace payload as returned by the provider.
        include_tool_calls: When False, tool calls on messages and the
            ``tools`` definition from the trace inputs are left out.

    Returns:
        The populated row, or ``None`` when no messages could be extracted
        from the trace.
    """
    messages = _extract_messages_from_trace(trace, include_tool_calls=include_tool_calls)
    if not messages:
        return None

    # ``inputs``/``output`` are only read as mappings below; any other payload
    # type is treated as empty rather than raising AttributeError on ``.get``.
    inputs = trace.get("inputs")
    if not isinstance(inputs, dict):
        inputs = {}
    output = trace.get("output")
    if not isinstance(output, dict):
        output = {}

    # Provider-native IDs for UI joinability
    session_data = {
        "weave_trace_id": trace.get("id"),
        "weave_project_id": trace.get("project_id"),
    }

    # Optional EP identifiers (if present in provider payload). Output
    # metadata is merged last so it overrides input metadata on key clashes.
    metadata: Dict[str, Any] = {}
    for section in (inputs, output):
        section_meta = section.get("metadata")
        if isinstance(section_meta, dict):
            metadata.update(section_meta)

    input_metadata = InputMetadata(row_id=metadata.get("row_id"), session_data=session_data)

    # Preserve default factory behavior by only setting provided fields
    exec_kwargs: Dict[str, Any] = {
        key: metadata[key]
        for key in ("invocation_id", "experiment_id", "rollout_id", "run_id")
        if metadata.get(key) is not None
    }
    execution_metadata = ExecutionMetadata(**exec_kwargs)

    # Capture tools if provider exposes them (taken from inputs only)
    tools = inputs.get("tools") if include_tool_calls else None

    return EvaluationRow(
        messages=messages, tools=tools, input_metadata=input_metadata, execution_metadata=execution_metadata
    )
90+
91+
92+
class WeaveAdapter(BaseAdapter):
    """Adapter to pull data from Weave Trace API and convert to EvaluationRow format."""

    def __init__(
        self, base_url: Optional[str] = None, api_token: Optional[str] = None, project_id: Optional[str] = None
    ):
        """Resolve connection settings from arguments, falling back to env vars.

        Args:
            base_url: Trace API base URL; falls back to ``WEAVE_TRACE_BASE_URL``
                and then to ``https://trace.wandb.ai``.
            api_token: API token; falls back to ``WANDB_API_KEY``.
            project_id: Project in the form ``"<entity>/<project>"``; falls
                back to ``WANDB_ENTITY`` and ``WANDB_PROJECT``.

        Raises:
            ValueError: If the token or a well-formed project id is missing.
        """
        # Strip a trailing slash so the endpoint URL never contains "//calls".
        self.base_url = (base_url or os.getenv("WEAVE_TRACE_BASE_URL") or "https://trace.wandb.ai").rstrip("/")
        self.api_token = api_token or os.getenv("WANDB_API_KEY")
        # project_id is in form "<entity>/<project>"
        if not project_id:
            entity = os.getenv("WANDB_ENTITY")
            project = os.getenv("WANDB_PROJECT")
            # Only build the id when both parts exist; formatting missing
            # values would yield the string "None/None", which passes the
            # validation below.
            project_id = f"{entity}/{project}" if entity and project else None
        self.project_id = project_id
        if not self.api_token or not self.project_id or "/" not in self.project_id:
            raise ValueError("Missing Weave credentials or project (WANDB_API_KEY and WANDB_ENTITY/WANDB_PROJECT)")

    def _fetch_traces(self, limit: int = 100) -> List[Dict[str, Any]]:
        """POST a root-trace query and return the raw trace records.

        Args:
            limit: Maximum number of traces requested (newest first).

        Returns:
            The trace payloads as dicts; non-dict records are dropped.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        # Local import: only needed for the line-by-line fallback below.
        import json

        url = f"{self.base_url}/calls/stream_query"
        payload = {
            "project_id": self.project_id,
            "filter": {"trace_roots_only": True},
            "limit": limit,
            "offset": 0,
            "sort_by": [{"field": "started_at", "direction": "desc"}],
            "include_feedback": False,
        }
        headers = {"Authorization": f"Bearer {self.api_token}", "Content-Type": "application/json"}
        resp = requests.post(url, json=payload, headers=headers, timeout=30)
        resp.raise_for_status()

        try:
            body = resp.json()
        except ValueError:
            # Not a single JSON document.
            # NOTE(review): the stream endpoint appears to answer with
            # newline-delimited JSON (one call per line) - confirm against the
            # Weave Service API docs. Parse each non-blank line on its own.
            records: List[Any] = [json.loads(line) for line in resp.text.splitlines() if line.strip()]
        else:
            if isinstance(body, dict) and isinstance(body.get("data"), list):
                # Envelope form: {"data": [...]}
                records = body["data"]
            elif isinstance(body, list):
                records = body
            elif body:
                # Presumably a one-line stream holding a single call record.
                records = [body]
            else:
                records = []
        return [rec for rec in records if isinstance(rec, dict)]

    def get_evaluation_rows(self, *args, **kwargs) -> List[EvaluationRow]:
        """Fetch recent root traces and convert them to evaluation rows.

        Keyword Args:
            limit: Maximum number of traces to fetch (default 100).
            include_tool_calls: Whether to keep tool calls/tools (default True).

        Returns:
            One row per trace that yielded messages; other traces are skipped.
        """
        limit = kwargs.get("limit", 100)
        include_tool_calls = kwargs.get("include_tool_calls", True)
        converted = (
            _convert_trace_to_evaluation_row(tr, include_tool_calls=include_tool_calls)
            for tr in self._fetch_traces(limit=limit)
        )
        # Compare against None explicitly instead of relying on row truthiness.
        return [row for row in converted if row is not None]

examples/README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,19 @@ The typical lifecycle of working with or developing an example involves these ke
7070
2. Model your structure and documentation after `examples/math_example/`.
7171
3. Ensure your example has its own clear `README.md` and necessary `conf/` files.
7272
4. Test thoroughly.
73+
74+
## Tracing provider IO references
75+
76+
Provider-specific IO references (input logging + output pulling) live under:
77+
78+
- `examples/tracing/<provider>/`
79+
80+
Current providers:
81+
82+
- `examples/tracing/weave/`: Input/Output reference for Weave (W&B) tracing
83+
84+
Each provider folder includes:
85+
86+
- `produce_input_trace.py`: Minimal script to log a chat completion
87+
- `pull_output_traces.py`: Script to fetch traces and convert to `EvaluationRow`
88+
- `converter.py`: Provider-to-EP message+metadata mapping

examples/adapters/README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,16 @@ Loads datasets from HuggingFace Hub and converts them to EvaluationRow format.
4343
pip install 'eval-protocol[huggingface]'
4444
```
4545

46+
## Tracing provider IO references
47+
48+
Provider-specific IO references (input logging + output pulling) have moved under:
49+
50+
- `examples/tracing/<provider>/`
51+
52+
For Weave, see `examples/tracing/weave/` which contains a focused `converter.py` illustrating how to map provider payloads to EP messages and metadata.
53+
54+
These examples are designed to be self-contained and usable as references for building or validating provider adapters.
55+
4656
## Running the Examples
4757

4858
### Basic Usage

examples/tracing/__init__.py

Whitespace-only changes.

examples/tracing/weave/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Weave (Weights & Biases) tracing examples.
2+
3+
This package contains a focused `converter.py` that illustrates how to map
4+
Weave provider payloads to Eval Protocol `EvaluationRow` objects. Use it as a
5+
reference when building or validating provider adapters.
6+
"""
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from typing import Any, Dict, List, Optional
2+
3+
from eval_protocol.models import EvaluationRow, InputMetadata, Message, ExecutionMetadata
4+
5+
6+
def _extract_messages_from_trace(trace: Dict[str, Any], include_tool_calls: bool = True) -> List[Message]:
    """Build the conversation carried by a Weave trace as ``Message`` objects.

    Input messages (``trace["inputs"]["messages"]``) come first, followed by
    the output messages (``trace["output"]["messages"]``) or, when there are
    none, the message of the first entry in ``trace["output"]["choices"]``.

    Args:
        trace: Raw trace payload as returned by the provider.
        include_tool_calls: When False, ``tool_calls`` on output messages are
            dropped (set to ``None``).

    Returns:
        Messages in conversation order; an empty list if the trace carries
        no recognisable messages.
    """
    # Nothing guarantees that ``inputs``/``output`` are mappings (a traced call
    # may return e.g. a plain string), so any other type is treated as
    # carrying no messages instead of raising AttributeError on ``.get``.
    output = trace.get("output")
    if not isinstance(output, dict):
        output = {}
    inputs = trace.get("inputs")
    if not isinstance(inputs, dict):
        inputs = {}

    messages: List[Message] = []

    # Prefer explicit output messages if provided; skip malformed entries.
    raw_out = output.get("messages")
    out_dicts = [m for m in raw_out if isinstance(m, dict)] if isinstance(raw_out, list) else []
    for m in out_dicts:
        messages.append(
            Message(
                role=m.get("role"),
                content=m.get("content"),
                tool_calls=m.get("tool_calls") if include_tool_calls else None,
                tool_call_id=m.get("tool_call_id"),
                name=m.get("name"),
            )
        )

    # If no explicit output messages, fall back to the first choice's message.
    if not messages:
        choices = output.get("choices")
        if isinstance(choices, list) and choices and isinstance(choices[0], dict):
            msg = choices[0].get("message")
            if isinstance(msg, dict) and msg:
                messages.append(Message(role=msg.get("role"), content=msg.get("content")))

    # Prepend input messages unless the output messages already start with
    # them (same role/content sequence); otherwise the history would be
    # duplicated when the provider echoes the full conversation in the output.
    raw_in = inputs.get("messages")
    if isinstance(raw_in, list):
        in_dicts = [m for m in raw_in if isinstance(m, dict)]
        in_keys = [(m.get("role"), m.get("content")) for m in in_dicts]
        out_keys = [(m.get("role"), m.get("content")) for m in out_dicts[: len(in_dicts)]]
        if in_keys != out_keys:
            prefixed = [Message(role=role, content=content) for role, content in in_keys]
            messages = prefixed + messages

    return messages
40+
41+
42+
def convert_trace_to_evaluation_row(trace: Dict[str, Any], include_tool_calls: bool = True) -> Optional[EvaluationRow]:
    """Convert one raw Weave trace into an ``EvaluationRow``.

    Args:
        trace: Raw trace payload as returned by the provider.
        include_tool_calls: When False, tool calls on messages and the
            ``tools`` definition from the trace inputs are left out.

    Returns:
        The populated row, or ``None`` when no messages could be extracted
        from the trace.
    """
    messages = _extract_messages_from_trace(trace, include_tool_calls=include_tool_calls)
    if not messages:
        return None

    # ``inputs``/``output`` are only read as mappings below; any other payload
    # type is treated as empty rather than raising AttributeError on ``.get``.
    inputs = trace.get("inputs")
    if not isinstance(inputs, dict):
        inputs = {}
    output = trace.get("output")
    if not isinstance(output, dict):
        output = {}

    # Provider-native IDs for UI joinability
    session_data = {
        "weave_trace_id": trace.get("id"),
        "weave_project_id": trace.get("project_id"),
    }

    # Optional EP identifiers (if present in provider payload). Output
    # metadata is merged last so it overrides input metadata on key clashes.
    metadata: Dict[str, Any] = {}
    for section in (inputs, output):
        section_meta = section.get("metadata")
        if isinstance(section_meta, dict):
            metadata.update(section_meta)

    input_metadata = InputMetadata(row_id=metadata.get("row_id"), session_data=session_data)

    # Preserve default factory behavior by only setting provided fields
    exec_kwargs: Dict[str, Any] = {
        key: metadata[key]
        for key in ("invocation_id", "experiment_id", "rollout_id", "run_id")
        if metadata.get(key) is not None
    }
    execution_metadata = ExecutionMetadata(**exec_kwargs)

    # Capture tools if provider exposes them (taken from inputs only)
    tools = inputs.get("tools") if include_tool_calls else None

    return EvaluationRow(
        messages=messages,
        tools=tools,
        input_metadata=input_metadata,
        execution_metadata=execution_metadata,
    )
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import os
2+
import importlib.util
3+
from pathlib import Path
4+
5+
import pytest
6+
7+
8+
def _load_module_from_path(name: str, path: str):
    """Import the Python file at ``path`` as a module called ``name``.

    The module is executed but not registered in ``sys.modules``.

    Args:
        name: Module name to assign to the loaded module.
        path: Filesystem path of the source file.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be built for ``path``.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    # Raise explicitly instead of asserting: asserts vanish under ``python -O``
    # and would let a ``None`` spec fail later with a less helpful error.
    if spec is None or spec.loader is None:
        raise ImportError(f"Failed to load module spec for {name} from {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
14+
15+
16+
@pytest.mark.skip(reason="Weave example only: converter IO smoke-test placeholder (no live fetch script).")
def test_weave_converter_basic_messages():
    """Smoke-test the example converter on a minimal hand-written trace."""
    root = Path(__file__).resolve().parents[2]
    converter_path = root / "examples" / "tracing" / "weave" / "converter.py"
    mod = _load_module_from_path("weave_converter", str(converter_path))
    convert = mod.convert_trace_to_evaluation_row

    trace = {
        "id": "tr_123",
        "project_id": "team/proj",
        "inputs": {"messages": [{"role": "user", "content": "Hi"}]},
        "output": {"choices": [{"message": {"role": "assistant", "content": "Hello"}}]},
    }

    row = convert(trace)
    # Fail with a clear assertion (not AttributeError) if conversion bails out.
    assert row is not None
    # One input message plus the assistant reply taken from ``choices``.
    assert len(row.messages) == 2
    assert row.input_metadata.session_data.get("weave_trace_id") == "tr_123"
33+
34+
35+
@pytest.mark.skip(reason="Credential-gated live fetch; enable locally with WANDB creds.")
def test_weave_fetch_and_convert_live():
    """Fetch one live trace with the example scripts and convert it.

    Skips (rather than fails) when credentials, project configuration, or the
    example fetch script are not available.
    """
    # Require explicit env to avoid CI failures
    if not os.getenv("WANDB_API_KEY"):
        pytest.skip("WANDB_API_KEY not set")

    team = os.getenv("WANDB_ENTITY") or os.getenv("WEAVE_TEAM_ID")
    project = os.getenv("WANDB_PROJECT") or os.getenv("WEAVE_PROJECT_ID")
    if not team or not project:
        pytest.skip("Weave project not configured")

    base_url = os.getenv("WEAVE_TRACE_BASE_URL", "https://trace.wandb.ai")
    root = Path(__file__).resolve().parents[2]
    pull_path = root / "examples" / "tracing" / "weave" / "pull_output_traces.py"
    conv_path = root / "examples" / "tracing" / "weave" / "converter.py"

    # The fetch script is optional; without it there is nothing to run, so
    # skip instead of erroring out while loading a missing file.
    if not pull_path.exists():
        pytest.skip(f"Weave fetch script not found: {pull_path}")

    pull_mod = _load_module_from_path("weave_pull", str(pull_path))
    conv_mod = _load_module_from_path("weave_converter", str(conv_path))

    fetch_weave_traces = pull_mod.fetch_weave_traces
    convert = conv_mod.convert_trace_to_evaluation_row

    traces = fetch_weave_traces(
        base_url=base_url, project_id=f"{team}/{project}", api_token=os.environ["WANDB_API_KEY"], limit=1
    )
    rows = [convert(tr) for tr in traces]
    assert any(r is not None for r in rows)

0 commit comments

Comments
 (0)