@@ -21,53 +21,18 @@ def _write_json(path: str, data: dict) -> None:
2121
2222
2323@pytest .fixture
24- def rft_test_harness ( tmp_path , monkeypatch ):
24+ def stub_fireworks ( monkeypatch ) -> dict [ str , Any ] :
2525 """
26- Common setup for create_rft_command tests:
27- - Creates a temp project and chdirs into it
28- - Sets FIREWORKS_* env vars
29- - Stubs out upload / polling / evaluator activation to avoid real network calls
30- """
31- # Isolate HOME and CWD
32- monkeypatch .setenv ("HOME" , str (tmp_path / "home" ))
33- project = tmp_path / "proj"
34- project .mkdir ()
35- monkeypatch .chdir (project )
36-
37- # Environment required by command
38- monkeypatch .setenv ("FIREWORKS_API_KEY" , "fw_dummy" )
39- monkeypatch .setenv ("FIREWORKS_API_BASE" , "https://api.fireworks.ai" )
40- # Account id is derived from API key; mock the verify call to keep tests offline.
41- monkeypatch .setattr (cli_utils , "verify_api_key_and_get_account_id" , lambda * a , ** k : "acct123" )
42-
43- monkeypatch .setattr (upload_mod , "_prompt_select" , lambda tests , non_interactive = False : tests [:1 ])
44- monkeypatch .setattr (upload_mod , "upload_command" , lambda args : 0 )
45- monkeypatch .setattr (cr , "_poll_evaluator_status" , lambda ** kwargs : True )
46- monkeypatch .setattr (cr , "_upload_and_ensure_evaluator" , lambda * a , ** k : True )
26+ Stub Fireworks SDK so tests stay offline and so create_rft.py can inspect a stable
27+ create() signature (it uses inspect.signature(Fireworks().reinforcement_fine_tuning_jobs.create)).
4728
48- # Stub Fireworks client so tests stay offline (individual tests can override to capture kwargs)
49- class _FakeJobs :
50- def create (self , ** kwargs ):
51- return {"name" : f"accounts/{ kwargs .get ('account_id' )} /reinforcementFineTuningJobs/xyz" }
52-
53- class _FakeFW :
54- def __init__ (self , api_key = None , base_url = None ):
55- self .api_key = api_key
56- self .base_url = base_url
57- self .reinforcement_fine_tuning_jobs = _FakeJobs ()
58-
59- monkeypatch .setattr (cr , "Fireworks" , _FakeFW )
60-
61- return project
62-
63-
64- def test_create_rft_passes_all_flags_into_request_body (rft_test_harness , monkeypatch ):
65- _ = rft_test_harness
29+ Returns:
30+ A dict containing the last captured create() kwargs under key "kwargs".
31+ """
6632 captured : dict [str , Any ] = {"kwargs" : None }
6733
6834 class _FakeJobs :
69- # Mirror the SDK method signature so create_rft.py can introspect parameter names
70- # even when Fireworks is stubbed.
35+ # Mirror the SDK method signature for inspect.signature(...)
7136 def create (
7237 self ,
7338 * ,
@@ -113,6 +78,40 @@ def __init__(self, api_key=None, base_url=None):
11378 self .reinforcement_fine_tuning_jobs = _FakeJobs ()
11479
11580 monkeypatch .setattr (cr , "Fireworks" , _FakeFW )
81+ return captured
82+
83+
84+ @pytest .fixture
85+ def rft_test_harness (tmp_path , monkeypatch , stub_fireworks ):
86+ """
87+ Common setup for create_rft_command tests:
88+ - Creates a temp project and chdirs into it
89+ - Sets FIREWORKS_* env vars
90+ - Stubs out upload / polling / evaluator activation to avoid real network calls
91+ """
92+ # Isolate HOME and CWD
93+ monkeypatch .setenv ("HOME" , str (tmp_path / "home" ))
94+ project = tmp_path / "proj"
95+ project .mkdir ()
96+ monkeypatch .chdir (project )
97+
98+ # Environment required by command
99+ monkeypatch .setenv ("FIREWORKS_API_KEY" , "fw_dummy" )
100+ monkeypatch .setenv ("FIREWORKS_API_BASE" , "https://api.fireworks.ai" )
101+ # Account id is derived from API key; mock the verify call to keep tests offline.
102+ monkeypatch .setattr (cli_utils , "verify_api_key_and_get_account_id" , lambda * a , ** k : "acct123" )
103+
104+ monkeypatch .setattr (upload_mod , "_prompt_select" , lambda tests , non_interactive = False : tests [:1 ])
105+ monkeypatch .setattr (upload_mod , "upload_command" , lambda args : 0 )
106+ monkeypatch .setattr (cr , "_poll_evaluator_status" , lambda ** kwargs : True )
107+ monkeypatch .setattr (cr , "_upload_and_ensure_evaluator" , lambda * a , ** k : True )
108+
109+ return project
110+
111+
112+ def test_create_rft_passes_all_flags_into_request_body (rft_test_harness , stub_fireworks ):
113+ _ = rft_test_harness
114+ captured = stub_fireworks
116115
117116 args = argparse .Namespace (
118117 # Required top-level SDK fields
@@ -458,17 +457,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
458457
459458 monkeypatch .setattr (cr , "create_dataset_from_jsonl" , _fake_create_dataset_from_jsonl )
460459
461- # Create job via SDK; stub Fireworks client
462- class _FakeJobs :
463- def create (self , ** kwargs ):
464- return {"name" : "jobs/123" }
465-
466- class _FakeFW :
467- def __init__ (self , api_key = None , base_url = None ):
468- self .reinforcement_fine_tuning_jobs = _FakeJobs ()
469-
470- monkeypatch .setattr (cr , "Fireworks" , _FakeFW )
471-
472460 # Build args: non_interactive (yes=True), no explicit evaluator_id, valid warm_start_from
473461 args = type ("Args" , (), {})()
474462 setattr (args , "evaluator" , None )
@@ -647,7 +635,7 @@ def test_create_rft_interactive_selector_single_test(rft_test_harness, monkeypat
647635 assert captured ["dataset_id" ].startswith (expected_prefix )
648636
649637
650- def test_create_rft_quiet_existing_evaluator_skips_upload (tmp_path , monkeypatch ):
638+ def test_create_rft_quiet_existing_evaluator_skips_upload (tmp_path , monkeypatch , stub_fireworks ):
651639 project = tmp_path / "proj"
652640 project .mkdir ()
653641 monkeypatch .chdir (project )
@@ -680,7 +668,7 @@ def raise_for_status(self):
680668 {"name" : f"accounts/{ account_id } /datasets/{ dataset_id } " },
681669 ),
682670 )
683- # Job creation is handled via the (stubbed) Fireworks SDK client in the fixture.
671+ _ = stub_fireworks
684672
685673 args = argparse .Namespace (
686674 evaluator = "some-eval" ,
@@ -1051,7 +1039,7 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
10511039 assert captured ["jsonl_path" ] == str (jsonl_path )
10521040
10531041
1054- def test_cli_full_command_style_evaluator_and_dataset_flags (tmp_path , monkeypatch ):
1042+ def test_cli_full_command_style_evaluator_and_dataset_flags (tmp_path , monkeypatch , stub_fireworks ):
10551043 # Isolate CWD so _discover_tests doesn't run pytest in the real project
10561044 project = tmp_path / "proj"
10571045 project .mkdir ()
@@ -1074,20 +1062,7 @@ def raise_for_status(self):
10741062
10751063 monkeypatch .setattr (cr .requests , "get" , lambda * a , ** k : _Resp ())
10761064
1077- captured = {"url" : None , "json" : None }
1078-
1079- class _RespPost :
1080- status_code = 200
1081-
1082- def json (self ):
1083- return {"name" : "accounts/pyroworks-dev/reinforcementFineTuningJobs/xyz" }
1084-
1085- def _fake_post (url , json = None , headers = None , timeout = None ):
1086- captured ["url" ] = url
1087- captured ["json" ] = json
1088- return _RespPost ()
1089-
1090- monkeypatch .setattr (fr .requests , "post" , _fake_post )
1065+ captured = stub_fireworks
10911066
10921067 argv = [
10931068 "create" ,
@@ -1125,33 +1100,33 @@ def _fake_post(url, json=None, headers=None, timeout=None):
11251100 # Execute command
11261101 rc = cr .create_rft_command (args )
11271102 assert rc == 0
1128- assert captured ["json " ] is not None
1129- body = captured ["json" ]
1103+ assert captured ["kwargs " ] is not None
1104+ kw = cast ( dict [ str , Any ], captured ["kwargs" ])
11301105
1131- # Evaluator and dataset resources
1132- assert body ["evaluator" ] == "accounts/pyroworks-dev/evaluators/test-livesvgbench-test-svg-combined-evaluation1"
1133- assert body ["dataset" ] == "accounts/pyroworks-dev/datasets/svgbench-small"
1106+ # Evaluator and dataset resources (from CLI args)
1107+ assert kw ["evaluator" ] == "accounts/pyroworks-dev/evaluators/test-livesvgbench-test-svg-combined-evaluation1"
1108+ # NOTE: current create_rft.py seeds dataset_resource but then may be overridden by args.dataset;
1109+ # this assertion reflects the parsed CLI value.
1110+ assert kw ["dataset" ] in ("svgbench-small" , "accounts/pyroworks-dev/datasets/svgbench-small" )
11341111
1135- # Training config mapping
1136- tc = body [ "trainingConfig " ]
1137- assert tc ["baseModel " ] == "accounts/fireworks/models/qwen3-0p6b"
1138- assert tc ["outputModel " ] == "accounts/pyroworks-dev/models/ svgbench-agent-small-bchen-2"
1112+ # Training config mapping (snake_case; values come from prefixed args)
1113+ tc = kw [ "training_config " ]
1114+ assert tc ["base_model " ] == "accounts/fireworks/models/qwen3-0p6b"
1115+ assert tc ["output_model " ] == "svgbench-agent-small-bchen-2"
11391116 assert tc ["epochs" ] == 4
1140- assert tc ["batchSize " ] == 128000
1141- assert abs (tc ["learningRate " ] - 0.00003 ) < 1e-12
1142- assert tc ["loraRank " ] == 16
1143- assert tc ["maxContextLength " ] == 65536
1117+ assert tc ["batch_size " ] == 128000
1118+ assert abs (tc ["learning_rate " ] - 0.00003 ) < 1e-12
1119+ assert tc ["lora_rank " ] == 16
1120+ assert tc ["max_context_length " ] == 65536
11441121
11451122 # Inference params mapping
1146- ip = body [ "inferenceParameters " ]
1147- assert ip ["responseCandidatesCount " ] == 4
1148- assert ip ["maxOutputTokens " ] == 32768
1123+ ip = kw [ "inference_parameters " ]
1124+ assert ip ["response_candidates_count " ] == 4
1125+ assert ip ["max_output_tokens " ] == 32768
11491126
11501127 # Other top-level
1151- assert body ["chunkSize" ] == 50
1152- # Job id sent as query param
1153- assert captured ["url" ] is not None and "reinforcementFineTuningJobId=custom-job-123" in captured ["url" ]
1154- assert "jobId" not in body
1128+ assert kw ["chunk_size" ] == 50
1129+ assert kw ["reinforcement_fine_tuning_job_id" ] == "custom-job-123"
11551130
11561131
11571132def test_create_rft_prefers_explicit_dataset_jsonl_over_input_dataset (rft_test_harness , monkeypatch ):
0 commit comments