Skip to content

Commit 9a88211

Browse files
committed
2 parents 9de46b3 + 2ab6308 commit 9a88211

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import base64
23
import time
34
from typing import Any, Dict, List, Optional, Callable
45

@@ -25,7 +26,9 @@
2526
logger = logging.getLogger(__name__)
2627

2728

28-
def _build_fireworks_tracing_url(base_url: str, metadata: RolloutMetadata) -> str:
29+
def _build_fireworks_tracing_url(
30+
base_url: str, metadata: RolloutMetadata, completion_params_base_url: Optional[str] = None
31+
) -> str:
2932
"""Build a Fireworks tracing URL by appending rollout metadata to the base URL path,
3033
allowing the Fireworks tracing proxy to automatically tag traces.
3134
@@ -35,15 +38,24 @@ def _build_fireworks_tracing_url(base_url: str, metadata: RolloutMetadata) -> st
3538
base_url: Fireworks tracing proxy URL (we expect this to be https://tracing.fireworks.ai or
3639
https://tracing.fireworks.ai/project_id/{project_id})
3740
metadata: Rollout metadata containing IDs to embed in the URL
41+
completion_params_base_url: Optional LLM base URL to encode and append to the final URL
3842
"""
39-
return (
43+
url = (
4044
f"{base_url}/rollout_id/{metadata.rollout_id}"
4145
f"/invocation_id/{metadata.invocation_id}"
4246
f"/experiment_id/{metadata.experiment_id}"
4347
f"/run_id/{metadata.run_id}"
4448
f"/row_id/{metadata.row_id}"
4549
)
4650

51+
if (
52+
completion_params_base_url
53+
): # The final URL is both tracing.fireworks.ai and the actual LLM base URL we want to use
54+
encoded_base_url = base64.urlsafe_b64encode(completion_params_base_url.encode()).decode()
55+
url = f"{url}/encoded_base_url/{encoded_base_url}"
56+
57+
return url
58+
4759

4860
def _default_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
4961
"""Default output data loader that fetches traces from Fireworks tracing proxy.
@@ -164,6 +176,13 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
164176
"Model must be provided in row.input_metadata.completion_params or config.completion_params"
165177
)
166178

179+
# Extract base_url from completion_params if provided. If we're using tracing.fireworks.ai, this base_url gets encoded and passed to LiteLLM inside the proxy.
180+
completion_params_base_url: Optional[str] = None
181+
if row.input_metadata and row.input_metadata.completion_params:
182+
completion_params_base_url = row.input_metadata.completion_params.get("base_url")
183+
if completion_params_base_url is None and config.completion_params:
184+
completion_params_base_url = config.completion_params.get("base_url")
185+
167186
# Strip non-OpenAI fields from messages before sending to remote
168187
allowed_message_fields = {"role", "content", "tool_calls", "tool_call_id", "name"}
169188
clean_messages = []
@@ -192,7 +211,7 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
192211
model_base_url.startswith("https://tracing.fireworks.ai")
193212
or model_base_url.startswith("http://localhost")
194213
):
195-
final_model_base_url = _build_fireworks_tracing_url(model_base_url, meta)
214+
final_model_base_url = _build_fireworks_tracing_url(model_base_url, meta, completion_params_base_url)
196215

197216
init_payload: InitRequest = InitRequest(
198217
model=model,

0 commit comments

Comments
 (0)