Skip to content

Commit cc8666e

Browse files
dphuang2xzrderek
andauthored
Propagate error status (#253)
* Fireworks Tracing * update path * add status handling from ECS * various changes * Refactor remote server startup to use argparse for host and port configuration; add tests for fireworks status propagation. * fix test * add dataloaderconfig * test_remote_rollout_and_fetch_fireworks_propagate_status * sync on latest * use get * run CI when parent is another PR * Implement rollout status handling in rollout processor; add helper function to preserve error status during updates. * make work for GH action (test) * disable test in regulaR CI / increase setup timeout * smoke test * for testing * test correctly * udpate * fix test * update test name * remove unnecessary secret * ensure it runs * remove from PRs * run on all pull requests * update name --------- Co-authored-by: Derek Xu <[email protected]>
1 parent 3aa96ae commit cc8666e

File tree

8 files changed

+212
-12
lines changed

8 files changed

+212
-12
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ on:
1111
- "docs/**"
1212
- "*.md"
1313
pull_request:
14-
branches: [main]
1514
paths-ignore:
1615
- "docs/**"
1716
- "*.md"
@@ -110,6 +109,7 @@ jobs:
110109
--ignore=tests/test_tau_bench_airline_smoke.py \
111110
--ignore=tests/pytest/test_svgbench.py \
112111
--ignore=tests/pytest/test_livesvgbench.py \
112+
--ignore=tests/remote_server/test_remote_fireworks_propagate_status.py \
113113
--ignore=eval_protocol/benchmarks/ \
114114
--cov=eval_protocol --cov-append --cov-report=xml --cov-report=term-missing -v --durations=10
115115
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
name: RemoteRolloutProcessor Propagate Status Test
2+
3+
on:
4+
push:
5+
branches: [main]
6+
paths-ignore:
7+
- "docs/**"
8+
- "*.md"
9+
pull_request: # Run on all pull requests
10+
paths-ignore:
11+
- "docs/**"
12+
- "*.md"
13+
workflow_dispatch: # Allow manual triggering
14+
15+
jobs:
16+
remote-rollout-processor-propagate-status-smoke-test:
17+
name: Fireworks Propagate Status Smoke Test
18+
runs-on: ubuntu-latest
19+
20+
steps:
21+
- name: Checkout repository
22+
uses: actions/checkout@v4
23+
with:
24+
fetch-depth: 0
25+
26+
- name: Set up Python 3.10
27+
uses: actions/setup-python@v5
28+
with:
29+
python-version: "3.10"
30+
31+
- name: Install uv
32+
uses: astral-sh/setup-uv@v6
33+
with:
34+
enable-cache: true
35+
36+
- name: Install the project
37+
run: uv sync --locked --all-extras --dev
38+
39+
- name: Run RemoteRolloutProcessor Propagate Status Smoke Test
40+
env:
41+
FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
42+
PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning"
43+
run: |
44+
uv run pytest tests/remote_server/test_remote_fireworks_propagate_status.py::test_remote_rollout_and_fetch_fireworks_propagate_status \
45+
-v --tb=short

eval_protocol/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
RolloutMetadata,
4141
StatusResponse,
4242
create_langfuse_config_tags,
43+
DataLoaderConfig,
4344
)
4445

4546
try:
@@ -67,6 +68,7 @@
6768
__all__ = [
6869
"ElasticsearchDirectHttpHandler",
6970
"RolloutIdFilter",
71+
"DataLoaderConfig",
7072
"Status",
7173
"RemoteRolloutProcessor",
7274
"InputMetadata",

eval_protocol/pytest/elasticsearch_setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _setup_initialized_docker_elasticsearch(self, env_file_path: str) -> Elastic
7676
# Use set -o pipefail to ensure we get the return code of the first failing command
7777
process = subprocess.Popen(
7878
[
79-
"sh",
79+
"bash",
8080
"-c",
8181
f"set -o pipefail; curl -fsSL https://elastic.co/start-local | sh -s -- --esonly | tee {temp_file_path}",
8282
],

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,19 @@ def _get_status() -> Dict[str, Any]:
262262
hits = search_results["hits"]["hits"] if search_results else []
263263

264264
if hits:
265-
# log all statuses found
265+
# log all statuses found and update rollout status from the last hit
266266
for hit in hits:
267267
document = hit["_source"]
268268
logger.info(
269269
f"Found log for rollout {row.execution_metadata.rollout_id} with status code {document['status_code']}"
270270
)
271+
# Update rollout status from the document
272+
if "status_code" in document:
273+
row.rollout_status = Status(
274+
code=Status.Code(document["status_code"]),
275+
message=document.get("status_message", ""),
276+
details=document.get("status_details", []),
277+
)
271278
logger.info("Stopping status polling for rollout %s", row.execution_metadata.rollout_id)
272279
break
273280

eval_protocol/pytest/utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,16 @@ def deep_update_dict(base: dict[str, Any], override: dict[str, Any]) -> dict[str
312312
return base
313313

314314

315+
def _set_rollout_status_to_finished(result: EvaluationRow) -> None:
    """Mark ``result`` as finished unless its status was already changed.

    The finished status is applied only while the row still reports a
    running status. Any other status -- for example an error set by the
    rollout processor -- is left untouched so that it is preserved for the
    caller. test_remote_fireworks_propagate_status.py verifies this.
    """
    if not result.rollout_status.is_running():
        # Something else already updated the status; keep it as-is.
        return
    result.rollout_status = Status.rollout_finished()
323+
324+
315325
async def rollout_processor_with_retry(
316326
rollout_processor: RolloutProcessor,
317327
fresh_dataset: list[EvaluationRow],
@@ -359,7 +369,9 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu
359369
try:
360370
# Try original task first
361371
result = await task # pyright: ignore[reportUnknownVariableType]
362-
result.rollout_status = Status.rollout_finished()
372+
373+
_set_rollout_status_to_finished(result)
374+
363375
return result # pyright: ignore[reportUnknownVariableType]
364376
except Exception as e:
365377
# NOTE: we perform these checks because we don't put the backoff decorator on initial batch call. we don't want to retry whole batch if anything fails.
@@ -372,7 +384,9 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu
372384
# Use shared backoff function for retryable exceptions
373385
try:
374386
result = await execute_row_with_backoff_retry(row)
375-
result.rollout_status = Status.rollout_finished()
387+
388+
_set_rollout_status_to_finished(result)
389+
376390
return result
377391
except Exception as retry_error:
378392
# Backoff gave up

tests/remote_server/remote_server.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import random
33
import threading
4+
import argparse
45

56
import uvicorn
67
from fastapi import FastAPI
@@ -17,6 +18,9 @@
1718
logging.getLogger().addHandler(handler)
1819

1920

21+
force_early_error_message = None
22+
23+
2024
@app.post("/init")
2125
def init(req: InitRequest):
2226
if req.elastic_search_config:
@@ -46,24 +50,56 @@ def _worker():
4650
completion = client.chat.completions.create(**completion_kwargs)
4751
logger.info(f"Completed response: {completion}")
4852

53+
# If force_early_error is set via command-line arg, log the error and return early
54+
if force_early_error_message:
55+
logger.error(
56+
force_early_error_message,
57+
extra={"status": Status.rollout_error(force_early_error_message)},
58+
)
59+
raise RuntimeError(force_early_error_message)
60+
4961
except Exception as e:
5062
# Best-effort; mark as done even on error to unblock polling
5163
print(f"❌ Error in rollout {req.metadata.rollout_id}: {e}")
5264
pass
5365
finally:
54-
logger.info(
55-
f"Rollout {req.metadata.rollout_id} completed",
56-
extra={"status": Status.rollout_finished()},
57-
)
66+
if not force_early_error_message:
67+
logger.info(
68+
f"Rollout {req.metadata.rollout_id} completed",
69+
extra={"status": Status.rollout_finished()},
70+
)
5871

5972
t = threading.Thread(target=_worker, daemon=True)
6073
t.start()
6174

6275

6376
def main():
64-
host = os.getenv("REMOTE_SERVER_HOST", "127.0.0.1")
65-
port = int(os.getenv("REMOTE_SERVER_PORT", "3000"))
66-
uvicorn.run(app, host=host, port=port)
77+
global force_early_error_message
78+
79+
parser = argparse.ArgumentParser(description="Run the remote server for evaluation protocol")
80+
parser.add_argument(
81+
"--host",
82+
type=str,
83+
default=os.getenv("REMOTE_SERVER_HOST", "127.0.0.1"),
84+
help="Host to bind the server to (default: 127.0.0.1 or REMOTE_SERVER_HOST env var)",
85+
)
86+
parser.add_argument(
87+
"--port",
88+
type=int,
89+
default=int(os.getenv("REMOTE_SERVER_PORT", "3000")),
90+
help="Port to bind the server to (default: 3000 or REMOTE_SERVER_PORT env var)",
91+
)
92+
parser.add_argument(
93+
"--force-early-error",
94+
type=str,
95+
default=None,
96+
help="If set, /init will immediately return after logging a rollout_error with this message",
97+
)
98+
99+
args = parser.parse_args()
100+
force_early_error_message = args.force_early_error
101+
102+
uvicorn.run(app, host=args.host, port=args.port)
67103

68104

69105
if __name__ == "__main__":
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# MANUAL SERVER STARTUP REQUIRED:
2+
#
3+
# For Python server testing, start:
4+
# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000)
5+
#
6+
# For TypeScript server testing, start:
7+
# cd tests/remote_server/typescript-server
8+
# npm install
9+
# npm start
10+
#
11+
# The TypeScript server should be running on http://127.0.0.1:3000
12+
# You only need to start one of the servers!
13+
14+
import subprocess
15+
import socket
16+
import time
17+
from typing import List
18+
19+
import pytest
20+
import requests
21+
22+
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
23+
from eval_protocol.models import EvaluationRow, Message, Status
24+
from eval_protocol.pytest import evaluation_test
25+
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
26+
27+
28+
def find_available_port() -> int:
    """Return a TCP port that is currently free on the loopback interface.

    The probe socket is bound to 127.0.0.1 (the address the test server and
    client use) with port 0, which lets the OS pick an unused port.

    Returns:
        The port number chosen by the OS.

    Note:
        The probe socket is closed before returning, so the port is only
        known to have been free at the time of the call; another process
        could claim it before the caller binds to it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]
34+
35+
36+
SERVER_PORT = find_available_port()
37+
38+
39+
def wait_for_server_to_startup(timeout: int = 120):
    """Block until the server on ``SERVER_PORT`` answers an HTTP request.

    The server is polled roughly once per second. Any HTTP response counts
    as "started", whatever its status code; only connection-level failures
    trigger a retry.

    Args:
        timeout: Maximum number of seconds to keep polling.

    Raises:
        TimeoutError: If no request succeeded before ``timeout`` elapsed.
    """
    url = f"http://127.0.0.1:{SERVER_PORT}"
    # Monotonic clock so wall-clock adjustments cannot skew the deadline.
    deadline = time.monotonic() + timeout
    while True:
        try:
            # Per-request timeout: without it a connection that is accepted
            # but never answered would block forever and defeat ``timeout``.
            requests.get(url, timeout=5)
            return
        except requests.exceptions.RequestException:
            if time.monotonic() > deadline:
                raise TimeoutError(f"Server did not start within {timeout} seconds")
            time.sleep(1)
49+
50+
51+
@pytest.fixture(autouse=True)
def setup_remote_server():
    """Start the remote server for a test and stop it afterwards.

    The server is launched with ``--force-early-error "test error"`` on
    ``SERVER_PORT``. The fixture yields once the server answers HTTP
    requests. The child process is always stopped, including when the
    startup wait itself fails.
    """
    # Kill all Python processes matching
    # "python -m tests.remote_server.remote_server" left over from earlier runs.
    subprocess.run(["pkill", "-f", "python -m tests.remote_server.remote_server"])

    host = "127.0.0.1"
    process = subprocess.Popen(
        [
            "python",
            "-m",
            "tests.remote_server.remote_server",
            "--host",
            host,
            "--port",
            str(SERVER_PORT),
            "--force-early-error",
            "test error",
        ]
    )
    try:
        # Wait for the server to start up by polling.
        wait_for_server_to_startup()
        yield
    finally:
        # Runs on normal teardown and when startup fails, so the child
        # process is never leaked.
        process.terminate()
        try:
            process.wait(timeout=10)
        except subprocess.TimeoutExpired:
            # The server ignored SIGTERM; force it down so teardown cannot hang.
            process.kill()
            process.wait()
76+
77+
78+
def rows() -> List[EvaluationRow]:
    """Build the one-row dataset: a single user question."""
    question = Message(role="user", content="What is the capital of France?")
    return [EvaluationRow(messages=[question])]
81+
82+
83+
@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}])
@evaluation_test(
    data_loaders=DynamicDataLoader(generators=[rows]),
    rollout_processor=RemoteRolloutProcessor(
        remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
        timeout_seconds=30,
    ),
)
async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow:
    """Check that the error status reported by the remote server reaches the row.

    The row must carry the INTERNAL status code and the "test error" message
    instead of being overwritten with a finished status.
    """
    status = row.rollout_status
    assert status.code == Status.Code.INTERNAL
    assert status.message == "test error"
    return row

0 commit comments

Comments
 (0)