
Commit c209678: "rft create uses sdk now"
Parent: 2bb176d

File tree: 5 files changed, +222 -308 lines

eval_protocol/cli.py

Lines changed: 6 additions & 3 deletions
@@ -402,7 +402,8 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
         help="Extra flags to pass to 'docker run' when validating evaluator (quoted string, e.g. \"--env-file .env --memory=8g\")",
     )
 
-    # Everything below has to manually be maintained, can't be auto-generated
+    # The flags below are Eval Protocol CLI workflow controls (not part of the Fireworks SDK `create()` signature),
+    # so they can't be auto-generated via signature introspection and must be maintained here.
     rft_parser.add_argument(
         "--source-job",
         metavar="",
@@ -419,11 +420,9 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
            "extra_query",
            "extra_body",
            "timeout",
-           "node_count",
            "display_name",
            "account_id",
        },
        "loss_config": {"kl_beta", "method"},
-        "training_config": {"region", "jinja_template"},
        "wandb_config": {"run_id"},
    }
@@ -433,11 +432,15 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
        "wandb_config.entity": ["--wandb-entity"],
        "wandb_config.enabled": ["--wandb"],
        "reinforcement_fine_tuning_job_id": ["--job-id"],
+       "loss_config.kl_beta": ["--rl-kl-beta"],
+       "loss_config.method": ["--rl-loss-method"],
+       "node_count": ["--nodes"],
    }
    help_overrides = {
        "training_config.gradient_accumulation_steps": "The number of batches to accumulate gradients before updating the model parameters. The effective batch size will be batch-size multiplied by this value.",
        "training_config.learning_rate_warmup_steps": "The number of learning rate warmup steps for the reinforcement fine-tuning job.",
        "mcp_server": "The MCP server resource name to use for the reinforcement fine-tuning job. (Optional)",
+       "loss_config.method": "RL loss method for underlying trainers. One of {grpo,dapo}.",
    }
 
    create_rft_job_fn = Fireworks().reinforcement_fine_tuning_jobs.create
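
For context on the maps above: most rft create flags are generated from the SDK `create()` signature by `add_args_from_callable_signature` (diffed under eval_protocol/cli_commands/utils.py below); the skip map hides SDK plumbing such as `extra_query` and `timeout`, `aliases` adds extra spellings, and `help_overrides` substitutes hand-written help text. A minimal sketch of the signature-to-flags idea, simplified and with hypothetical names (the real helper also walks one level of nested TypedDicts):

import argparse
import inspect

def add_flags_from_signature(parser, fn, skip=frozenset(), aliases=None):
    # One --kebab-case flag per keyword parameter of fn; aliases add
    # extra spellings (e.g. node_count -> --nodes).
    aliases = aliases or {}
    for name in inspect.signature(fn).parameters:
        if name in skip:
            continue
        flags = ["--" + name.replace("_", "-")] + aliases.get(name, [])
        parser.add_argument(*flags, dest=name)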

eval_protocol/cli_commands/create_rft.py

Lines changed: 35 additions & 110 deletions
@@ -1,10 +1,12 @@
 import argparse
+from fireworks._client import Fireworks
+from fireworks.types.reinforcement_fine_tuning_job import ReinforcementFineTuningJob
 import json
 import os
 import sys
 import time
 from typing import Any, Dict, Optional
-
+import inspect
 import requests
 from pydantic import ValidationError
 
@@ -13,7 +15,6 @@
 from ..fireworks_rft import (
     build_default_output_model,
     create_dataset_from_jsonl,
-    create_reinforcement_fine_tuning_job,
     detect_dataset_builder,
     materialize_dataset_via_builder,
 )
@@ -33,6 +34,8 @@
 )
 from .local_test import run_evaluator_test
 
+from fireworks import Fireworks
+
 
 def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) -> Optional[str]:
     """Import the test module and extract a JSONL path from data_loaders param if present.
@@ -619,126 +622,48 @@ def _create_rft_job(
     args: argparse.Namespace,
     dry_run: bool,
 ) -> int:
-    """Build and submit the RFT job request."""
-    # Build training config/body
-    # Exactly one of base-model or warm-start-from must be provided
-    base_model_raw = getattr(args, "base_model", None)
-    warm_start_from_raw = getattr(args, "warm_start_from", None)
-    # Treat empty/whitespace strings as not provided
-    base_model = base_model_raw.strip() if isinstance(base_model_raw, str) else base_model_raw
-    warm_start_from = warm_start_from_raw.strip() if isinstance(warm_start_from_raw, str) else warm_start_from_raw
-    has_base_model = bool(base_model)
-    has_warm_start = bool(warm_start_from)
-    if (not has_base_model and not has_warm_start) or (has_base_model and has_warm_start):
-        print("Error: exactly one of --base-model or --warm-start-from must be specified.")
-        return 1
+    """Build and submit the RFT job request (via Fireworks SDK)."""
 
-    training_config: Dict[str, Any] = {}
-    if has_base_model:
-        training_config["baseModel"] = base_model
-    if has_warm_start:
-        training_config["warmStartFrom"] = warm_start_from
-
-    # Optional hyperparameters
-    for key, arg_name in [
-        ("epochs", "epochs"),
-        ("batchSize", "batch_size"),
-        ("learningRate", "learning_rate"),
-        ("maxContextLength", "max_context_length"),
-        ("loraRank", "lora_rank"),
-        ("gradientAccumulationSteps", "gradient_accumulation_steps"),
-        ("learningRateWarmupSteps", "learning_rate_warmup_steps"),
-        ("acceleratorCount", "accelerator_count"),
-        ("region", "region"),
-    ]:
-        val = getattr(args, arg_name, None)
-        if val is not None:
-            training_config[key] = val
-
-    inference_params: Dict[str, Any] = {}
-    for key, arg_name in [
-        ("temperature", "temperature"),
-        ("topP", "top_p"),
-        ("topK", "top_k"),
-        ("maxOutputTokens", "max_output_tokens"),
-        ("responseCandidatesCount", "response_candidates_count"),
-    ]:
-        val = getattr(args, arg_name, None)
-        if val is not None:
-            inference_params[key] = val
-    if getattr(args, "extra_body", None):
-        extra = getattr(args, "extra_body")
-        if isinstance(extra, (dict, list)):
-            try:
-                inference_params["extraBody"] = json.dumps(extra, ensure_ascii=False)
-            except (TypeError, ValueError) as e:
-                print(f"Error: --extra-body dict/list must be JSON-serializable: {e}")
-                return 1
-        elif isinstance(extra, str):
-            inference_params["extraBody"] = extra
-        else:
-            print("Error: --extra-body must be a JSON string or a JSON-serializable dict/list.")
-            return 1
+    signature = inspect.signature(Fireworks().reinforcement_fine_tuning_jobs.create)
 
-    wandb_config: Optional[Dict[str, Any]] = None
-    if getattr(args, "enabled", False):
-        wandb_config = {
-            "enabled": True,
-            "apiKey": getattr(args, "api_key", None),
-            "project": getattr(args, "project", None),
-            "entity": getattr(args, "entity", None),
-            "runId": getattr(args, "run_id", None),
-        }
-
-    body: Dict[str, Any] = {
-        "displayName": getattr(args, "display_name", None),
-        "dataset": dataset_resource,
+    # Build top-level SDK kwargs
+    sdk_kwargs: Dict[str, Any] = {
         "evaluator": evaluator_resource_name,
-        "evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)),
-        "trainingConfig": training_config,
-        "inferenceParameters": inference_params or None,
-        "wandbConfig": wandb_config,
-        "chunkSize": getattr(args, "chunk_size", None),
-        "outputStats": None,
-        "outputMetrics": None,
-        "mcpServer": getattr(args, "mcp_server", None),
-        "jobId": getattr(args, "reinforcement_fine_tuning_job_id", None),
-        "sourceJob": getattr(args, "source_job", None),
-        "quiet": getattr(args, "quiet", False),
+        "dataset": dataset_resource,
     }
-    # Debug: print minimal summary
-    print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")
-    if getattr(args, "evaluation_dataset", None):
-        body["evaluationDataset"] = args.evaluation_dataset
 
-    output_model_arg = getattr(args, "output_model", None)
-    if output_model_arg:
-        if len(output_model_arg) > 63:
-            print(f"Error: Output model name '{output_model_arg}' exceeds 63 characters.")
-            return 1
-        body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{output_model_arg}"
-    else:
-        # Auto-generate output model name if not provided
-        auto_output_model = build_default_output_model(evaluator_id)
-        body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{auto_output_model}"
+    args_dict = vars(args)
+    for name in signature.parameters:
+        prefix = name + "_"
+
+        # Collect "flattened" argparse fields back into the nested dict expected by the SDK.
+        # Example: training_config_epochs=3 becomes sdk_kwargs["training_config"]["epochs"] = 3.
+        nested = {}
+        for k, v in args_dict.items():
+            if v is None:
+                continue
+            if not k.startswith(prefix):
+                continue
+            nested[k[len(prefix) :]] = v
+
+        if nested:
+            sdk_kwargs[name] = nested
+        elif args_dict.get(name) is not None:
+            sdk_kwargs[name] = args_dict[name]
 
-    # Clean None fields to avoid noisy payloads
-    body = {k: v for k, v in body.items() if v is not None}
+    print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")
 
     if dry_run:
-        print("--dry-run: would create RFT job with body:")
-        print(json.dumps(body, indent=2))
+        print("--dry-run: would call Fireworks().reinforcement_fine_tuning_jobs.create with kwargs:")
+        print(json.dumps(sdk_kwargs, indent=2))
         _print_links(evaluator_id, dataset_id, None)
        return 0
 
     try:
-        result = create_reinforcement_fine_tuning_job(
-            account_id=account_id, api_key=api_key, api_base=api_base, body=body
-        )
-        job_name = result.get("name") if isinstance(result, dict) else None
-        print("\n✅ Created Reinforcement Fine-tuning Job")
-        if job_name:
-            print(f"  name: {job_name}")
+        fw: Fireworks = Fireworks(api_key=api_key, base_url=api_base)
+        job: ReinforcementFineTuningJob = fw.reinforcement_fine_tuning_jobs.create(account_id=account_id, **sdk_kwargs)
+        job_name = job.name
+        print(f"\n✅ Created Reinforcement Fine-tuning Job: {job_name}")
         _print_links(evaluator_id, dataset_id, job_name)
         return 0
     except Exception as e:
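
The rebuilt `_create_rft_job` above leans on one trick: argparse dests that were flattened at parser-build time (nested field names prefixed with their parent parameter name) are regrouped into the nested kwargs the SDK expects. A standalone sketch of that round trip, with hypothetical dest names:

import inspect
from typing import Any, Dict

def nest_args(args_dict: Dict[str, Any], param_names) -> Dict[str, Any]:
    # Regroup flattened argparse dests (training_config_epochs=3) into
    # the nested kwargs the SDK expects ({"training_config": {"epochs": 3}}).
    sdk_kwargs: Dict[str, Any] = {}
    for name in param_names:
        prefix = name + "_"
        nested = {k[len(prefix):]: v for k, v in args_dict.items()
                  if v is not None and k.startswith(prefix)}
        if nested:
            sdk_kwargs[name] = nested
        elif args_dict.get(name) is not None:
            sdk_kwargs[name] = args_dict[name]
    return sdk_kwargs

# Example (hypothetical dest names):
# nest_args({"training_config_epochs": 3, "node_count": 2, "display_name": None},
#           ["training_config", "node_count", "display_name"])
# -> {"training_config": {"epochs": 3}, "node_count": 2}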

eval_protocol/cli_commands/utils.py

Lines changed: 5 additions & 4 deletions
@@ -660,7 +660,6 @@ def _add_flag(
 def add_args_from_callable_signature(
     parser: argparse.ArgumentParser,
     fn: Callable[..., Any],
-    *,
     overrides: dict[str, str] | None = None,
     skip_fields: dict[str, set[str]] | None = None,
     aliases: dict[str, list[str]] | None = None,
@@ -676,7 +675,7 @@ def add_args_from_callable_signature(
     help = _parse_args_section_from_doc(inspect.getdoc(fn) or "")
     hints = typing.get_type_hints(fn, include_extras=True)
 
-    for name, param in sig.parameters.items():
+    for name in sig.parameters.keys():
         resolved_type = unwrap_union(hints.get(name))
 
         # Allow one nested layer of TypedDicts
@@ -688,8 +687,10 @@ def add_args_from_callable_signature(
         for field_name, field_type in resolved_type.__annotations__.items():
             if field_name in field_skip:
                 continue
-            flag_name = "--" + field_name.replace("_", "-")
-            flags = [flag_name] + aliases.get(f"{name}.{field_name}", [])
+            prefix = name.replace("_", "-")
+            field_kebab = field_name.replace("_", "-")
+            flag_name = f"--{prefix}-{field_kebab}"
+            flags = [flag_name] + aliases.get(f"{name}.{field_name}", []) + [f"--{field_kebab}"]
             help_text = help_overrides.get(f"{name}.{field_name}", field_help.get(field_name))
 
             _add_flag(parser, flags, field_hints.get(field_name, field_type), help_text)
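
To make the new naming scheme concrete: for name="loss_config" and field_name="kl_beta", together with the alias map registered in eval_protocol/cli.py, the flags list works out as follows (a worked example, not repo output):

name, field_name = "loss_config", "kl_beta"
aliases = {"loss_config.kl_beta": ["--rl-kl-beta"]}

prefix = name.replace("_", "-")             # "loss-config"
field_kebab = field_name.replace("_", "-")  # "kl-beta"
flag_name = f"--{prefix}-{field_kebab}"     # "--loss-config-kl-beta"
flags = [flag_name] + aliases.get(f"{name}.{field_name}", []) + [f"--{field_kebab}"]
# ["--loss-config-kl-beta", "--rl-kl-beta", "--kl-beta"]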

eval_protocol/fireworks_rft.py

Lines changed: 4 additions & 32 deletions
@@ -181,33 +181,6 @@ def create_dataset_from_jsonl(
     return dataset_id, ds
 
 
-def create_reinforcement_fine_tuning_job(
-    account_id: str,
-    api_key: str,
-    api_base: str,
-    body: Dict[str, Any],
-) -> Dict[str, Any]:
-    url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/reinforcementFineTuningJobs"
-    # Move optional jobId from body to query parameter if provided
-    job_id = body.get("jobId")
-    if isinstance(job_id, str):
-        job_id = job_id.strip()
-    if job_id:
-        # Remove from body and append as query param
-        body.pop("jobId", None)
-        url = f"{url}?{urlencode({'reinforcementFineTuningJobId': job_id})}"
-    headers = {
-        "Authorization": f"Bearer {api_key}",
-        "Content-Type": "application/json",
-        "Accept": "application/json",
-        "User-Agent": get_user_agent(),
-    }
-    resp = requests.post(url, json=body, headers=headers, timeout=60)
-    if resp.status_code not in (200, 201):
-        raise RuntimeError(f"RFT job creation failed: {resp.status_code} {resp.text}")
-    return resp.json()
-
-
 def build_default_dataset_id(evaluator_id: str) -> str:
     ts = time.strftime("%Y%m%d%H%M%S")
     base = evaluator_id.lower().replace("_", "-")
@@ -217,22 +190,22 @@ def build_default_dataset_id(evaluator_id: str) -> str:
 def build_default_output_model(evaluator_id: str) -> str:
     base = evaluator_id.lower().replace("_", "-")
     uuid_suffix = str(uuid.uuid4())[:4]
-
+
     # suffix is "-rft-{4chars}" -> 9 chars
     suffix_len = 9
     max_len = 63
-
+
     # Check if we need to truncate
     if len(base) + suffix_len > max_len:
         # Calculate hash of the full base to preserve uniqueness
         hash_digest = hashlib.sha256(base.encode("utf-8")).hexdigest()[:6]
         # New structure: {truncated_base}-{hash}-{uuid_suffix}
         # Space needed for "-{hash}" is 1 + 6 = 7
         hash_part_len = 7
-
+
         allowed_base_len = max_len - suffix_len - hash_part_len
         truncated_base = base[:allowed_base_len].strip("-")
-
+
         return f"{truncated_base}-{hash_digest}-rft-{uuid_suffix}"
 
     return f"{base}-rft-{uuid_suffix}"
@@ -242,7 +215,6 @@ def build_default_output_model(evaluator_id: str) -> str:
     "detect_dataset_builder",
     "materialize_dataset_via_builder",
     "create_dataset_from_jsonl",
-    "create_reinforcement_fine_tuning_job",
     "build_default_dataset_id",
     "build_default_output_model",
     "_map_api_host_to_app_host",
