Skip to content

Commit 0d0a50c

Browse files
committed
clean up tests
1 parent c209678 commit 0d0a50c

File tree

1 file changed

+65
-90
lines changed

1 file changed

+65
-90
lines changed

tests/test_cli_create_rft.py

Lines changed: 65 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -21,53 +21,18 @@ def _write_json(path: str, data: dict) -> None:
2121

2222

2323
@pytest.fixture
24-
def rft_test_harness(tmp_path, monkeypatch):
24+
def stub_fireworks(monkeypatch) -> dict[str, Any]:
2525
"""
26-
Common setup for create_rft_command tests:
27-
- Creates a temp project and chdirs into it
28-
- Sets FIREWORKS_* env vars
29-
- Stubs out upload / polling / evaluator activation to avoid real network calls
30-
"""
31-
# Isolate HOME and CWD
32-
monkeypatch.setenv("HOME", str(tmp_path / "home"))
33-
project = tmp_path / "proj"
34-
project.mkdir()
35-
monkeypatch.chdir(project)
36-
37-
# Environment required by command
38-
monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
39-
monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
40-
# Account id is derived from API key; mock the verify call to keep tests offline.
41-
monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: "acct123")
42-
43-
monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
44-
monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
45-
monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)
46-
monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True)
26+
Stub Fireworks SDK so tests stay offline and so create_rft.py can inspect a stable
27+
create() signature (it uses inspect.signature(Fireworks().reinforcement_fine_tuning_jobs.create)).
4728
48-
# Stub Fireworks client so tests stay offline (individual tests can override to capture kwargs)
49-
class _FakeJobs:
50-
def create(self, **kwargs):
51-
return {"name": f"accounts/{kwargs.get('account_id')}/reinforcementFineTuningJobs/xyz"}
52-
53-
class _FakeFW:
54-
def __init__(self, api_key=None, base_url=None):
55-
self.api_key = api_key
56-
self.base_url = base_url
57-
self.reinforcement_fine_tuning_jobs = _FakeJobs()
58-
59-
monkeypatch.setattr(cr, "Fireworks", _FakeFW)
60-
61-
return project
62-
63-
64-
def test_create_rft_passes_all_flags_into_request_body(rft_test_harness, monkeypatch):
65-
_ = rft_test_harness
29+
Returns:
30+
A dict containing the last captured create() kwargs under key "kwargs".
31+
"""
6632
captured: dict[str, Any] = {"kwargs": None}
6733

6834
class _FakeJobs:
69-
# Mirror the SDK method signature so create_rft.py can introspect parameter names
70-
# even when Fireworks is stubbed.
35+
# Mirror the SDK method signature for inspect.signature(...)
7136
def create(
7237
self,
7338
*,
@@ -113,6 +78,40 @@ def __init__(self, api_key=None, base_url=None):
11378
self.reinforcement_fine_tuning_jobs = _FakeJobs()
11479

11580
monkeypatch.setattr(cr, "Fireworks", _FakeFW)
81+
return captured
82+
83+
84+
@pytest.fixture
85+
def rft_test_harness(tmp_path, monkeypatch, stub_fireworks):
86+
"""
87+
Common setup for create_rft_command tests:
88+
- Creates a temp project and chdirs into it
89+
- Sets FIREWORKS_* env vars
90+
- Stubs out upload / polling / evaluator activation to avoid real network calls
91+
"""
92+
# Isolate HOME and CWD
93+
monkeypatch.setenv("HOME", str(tmp_path / "home"))
94+
project = tmp_path / "proj"
95+
project.mkdir()
96+
monkeypatch.chdir(project)
97+
98+
# Environment required by command
99+
monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
100+
monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
101+
# Account id is derived from API key; mock the verify call to keep tests offline.
102+
monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: "acct123")
103+
104+
monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
105+
monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
106+
monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)
107+
monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True)
108+
109+
return project
110+
111+
112+
def test_create_rft_passes_all_flags_into_request_body(rft_test_harness, stub_fireworks):
113+
_ = rft_test_harness
114+
captured = stub_fireworks
116115

117116
args = argparse.Namespace(
118117
# Required top-level SDK fields
@@ -458,17 +457,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
458457

459458
monkeypatch.setattr(cr, "create_dataset_from_jsonl", _fake_create_dataset_from_jsonl)
460459

461-
# Create job via SDK; stub Fireworks client
462-
class _FakeJobs:
463-
def create(self, **kwargs):
464-
return {"name": "jobs/123"}
465-
466-
class _FakeFW:
467-
def __init__(self, api_key=None, base_url=None):
468-
self.reinforcement_fine_tuning_jobs = _FakeJobs()
469-
470-
monkeypatch.setattr(cr, "Fireworks", _FakeFW)
471-
472460
# Build args: non_interactive (yes=True), no explicit evaluator_id, valid warm_start_from
473461
args = type("Args", (), {})()
474462
setattr(args, "evaluator", None)
@@ -647,7 +635,7 @@ def test_create_rft_interactive_selector_single_test(rft_test_harness, monkeypat
647635
assert captured["dataset_id"].startswith(expected_prefix)
648636

649637

650-
def test_create_rft_quiet_existing_evaluator_skips_upload(tmp_path, monkeypatch):
638+
def test_create_rft_quiet_existing_evaluator_skips_upload(tmp_path, monkeypatch, stub_fireworks):
651639
project = tmp_path / "proj"
652640
project.mkdir()
653641
monkeypatch.chdir(project)
@@ -680,7 +668,7 @@ def raise_for_status(self):
680668
{"name": f"accounts/{account_id}/datasets/{dataset_id}"},
681669
),
682670
)
683-
# Job creation is handled via the (stubbed) Fireworks SDK client in the fixture.
671+
_ = stub_fireworks
684672

685673
args = argparse.Namespace(
686674
evaluator="some-eval",
@@ -1051,7 +1039,7 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
10511039
assert captured["jsonl_path"] == str(jsonl_path)
10521040

10531041

1054-
def test_cli_full_command_style_evaluator_and_dataset_flags(tmp_path, monkeypatch):
1042+
def test_cli_full_command_style_evaluator_and_dataset_flags(tmp_path, monkeypatch, stub_fireworks):
10551043
# Isolate CWD so _discover_tests doesn't run pytest in the real project
10561044
project = tmp_path / "proj"
10571045
project.mkdir()
@@ -1074,20 +1062,7 @@ def raise_for_status(self):
10741062

10751063
monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp())
10761064

1077-
captured = {"url": None, "json": None}
1078-
1079-
class _RespPost:
1080-
status_code = 200
1081-
1082-
def json(self):
1083-
return {"name": "accounts/pyroworks-dev/reinforcementFineTuningJobs/xyz"}
1084-
1085-
def _fake_post(url, json=None, headers=None, timeout=None):
1086-
captured["url"] = url
1087-
captured["json"] = json
1088-
return _RespPost()
1089-
1090-
monkeypatch.setattr(fr.requests, "post", _fake_post)
1065+
captured = stub_fireworks
10911066

10921067
argv = [
10931068
"create",
@@ -1125,33 +1100,33 @@ def _fake_post(url, json=None, headers=None, timeout=None):
11251100
# Execute command
11261101
rc = cr.create_rft_command(args)
11271102
assert rc == 0
1128-
assert captured["json"] is not None
1129-
body = captured["json"]
1103+
assert captured["kwargs"] is not None
1104+
kw = cast(dict[str, Any], captured["kwargs"])
11301105

1131-
# Evaluator and dataset resources
1132-
assert body["evaluator"] == "accounts/pyroworks-dev/evaluators/test-livesvgbench-test-svg-combined-evaluation1"
1133-
assert body["dataset"] == "accounts/pyroworks-dev/datasets/svgbench-small"
1106+
# Evaluator and dataset resources (from CLI args)
1107+
assert kw["evaluator"] == "accounts/pyroworks-dev/evaluators/test-livesvgbench-test-svg-combined-evaluation1"
1108+
# NOTE: current create_rft.py seeds dataset_resource but then may be overridden by args.dataset;
1109+
# this assertion reflects the parsed CLI value.
1110+
assert kw["dataset"] in ("svgbench-small", "accounts/pyroworks-dev/datasets/svgbench-small")
11341111

1135-
# Training config mapping
1136-
tc = body["trainingConfig"]
1137-
assert tc["baseModel"] == "accounts/fireworks/models/qwen3-0p6b"
1138-
assert tc["outputModel"] == "accounts/pyroworks-dev/models/svgbench-agent-small-bchen-2"
1112+
# Training config mapping (snake_case; values come from prefixed args)
1113+
tc = kw["training_config"]
1114+
assert tc["base_model"] == "accounts/fireworks/models/qwen3-0p6b"
1115+
assert tc["output_model"] == "svgbench-agent-small-bchen-2"
11391116
assert tc["epochs"] == 4
1140-
assert tc["batchSize"] == 128000
1141-
assert abs(tc["learningRate"] - 0.00003) < 1e-12
1142-
assert tc["loraRank"] == 16
1143-
assert tc["maxContextLength"] == 65536
1117+
assert tc["batch_size"] == 128000
1118+
assert abs(tc["learning_rate"] - 0.00003) < 1e-12
1119+
assert tc["lora_rank"] == 16
1120+
assert tc["max_context_length"] == 65536
11441121

11451122
# Inference params mapping
1146-
ip = body["inferenceParameters"]
1147-
assert ip["responseCandidatesCount"] == 4
1148-
assert ip["maxOutputTokens"] == 32768
1123+
ip = kw["inference_parameters"]
1124+
assert ip["response_candidates_count"] == 4
1125+
assert ip["max_output_tokens"] == 32768
11491126

11501127
# Other top-level
1151-
assert body["chunkSize"] == 50
1152-
# Job id sent as query param
1153-
assert captured["url"] is not None and "reinforcementFineTuningJobId=custom-job-123" in captured["url"]
1154-
assert "jobId" not in body
1128+
assert kw["chunk_size"] == 50
1129+
assert kw["reinforcement_fine_tuning_job_id"] == "custom-job-123"
11551130

11561131

11571132
def test_create_rft_prefers_explicit_dataset_jsonl_over_input_dataset(rft_test_harness, monkeypatch):

0 commit comments

Comments
 (0)