|
32 | 32 | preview_command = None # type: ignore[assignment] |
33 | 33 |
|
34 | 34 |
|
35 | | -def parse_args(args=None): |
36 | | - """Parse command line arguments""" |
37 | | - parser = argparse.ArgumentParser(description="eval-protocol: Tools for evaluation and reward modeling") |
def build_parser() -> argparse.ArgumentParser:
    """Create the CLI's top-level argument parser.

    A bare ``ArgumentParser`` carrying the tool description is created here and
    handed to ``_configure_parser``, which registers the arguments and
    subparsers on it.

    Returns:
        The parser returned by ``_configure_parser``.
    """
    cli_description = (
        "Inspect evaluation runs locally, upload evaluators, and create reinforcement fine-tuning jobs on Fireworks"
    )
    base_parser = argparse.ArgumentParser(description=cli_description)
    return _configure_parser(base_parser)
| 41 | + |
| 42 | + |
| 43 | +def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: |
| 44 | + """Configure all arguments and subparsers on the given parser.""" |
38 | 45 | parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose logging") |
39 | 46 | parser.add_argument( |
40 | 47 | "--server", |
@@ -392,39 +399,52 @@ def parse_args(args=None): |
392 | 399 | rft_parser.add_argument("--base-model", help="Base model resource id") |
393 | 400 | rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from") |
394 | 401 | rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)") |
395 | | - rft_parser.add_argument("--epochs", type=int, default=1) |
396 | | - rft_parser.add_argument("--batch-size", type=int, default=128000) |
397 | | - rft_parser.add_argument("--learning-rate", type=float, default=3e-5) |
398 | | - rft_parser.add_argument("--max-context-length", type=int, default=65536) |
399 | | - rft_parser.add_argument("--lora-rank", type=int, default=16) |
| 402 | + rft_parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs") |
| 403 | + rft_parser.add_argument("--batch-size", type=int, default=128000, help="Training batch size in tokens") |
| 404 | + rft_parser.add_argument("--learning-rate", type=float, default=3e-5, help="Learning rate for training") |
| 405 | + rft_parser.add_argument("--max-context-length", type=int, default=65536, help="Maximum context length in tokens") |
| 406 | + rft_parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank for fine-tuning") |
400 | 407 | rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps") |
401 | | - rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of LR warmup steps") |
402 | | - rft_parser.add_argument("--accelerator-count", type=int) |
403 | | - rft_parser.add_argument("--region", help="Fireworks region enum value") |
404 | | - rft_parser.add_argument("--display-name", help="RFT job display name") |
405 | | - rft_parser.add_argument("--evaluation-dataset", help="Optional separate eval dataset id") |
406 | | - rft_parser.add_argument("--eval-auto-carveout", dest="eval_auto_carveout", action="store_true", default=True) |
407 | | - rft_parser.add_argument("--no-eval-auto-carveout", dest="eval_auto_carveout", action="store_false") |
| 408 | + rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of learning rate warmup steps") |
| 409 | + rft_parser.add_argument("--accelerator-count", type=int, help="Number of accelerators (GPUs) to use") |
| 410 | + rft_parser.add_argument("--region", help="Fireworks region for training") |
| 411 | + rft_parser.add_argument("--display-name", help="Display name for the RFT job") |
| 412 | + rft_parser.add_argument("--evaluation-dataset", help="Separate dataset id for evaluation") |
| 413 | + rft_parser.add_argument( |
| 414 | + "--eval-auto-carveout", |
| 415 | + dest="eval_auto_carveout", |
| 416 | + action="store_true", |
| 417 | + default=True, |
| 418 | + help="Automatically carve out evaluation data from training set", |
| 419 | + ) |
| 420 | + rft_parser.add_argument( |
| 421 | + "--no-eval-auto-carveout", |
| 422 | + dest="eval_auto_carveout", |
| 423 | + action="store_false", |
| 424 | + help="Disable automatic evaluation data carveout", |
| 425 | + ) |
408 | 426 | # Rollout chunking |
409 | 427 | rft_parser.add_argument("--chunk-size", type=int, default=100, help="Data chunk size for rollout batching") |
410 | 428 | # Inference params |
411 | | - rft_parser.add_argument("--temperature", type=float) |
412 | | - rft_parser.add_argument("--top-p", type=float) |
413 | | - rft_parser.add_argument("--top-k", type=int) |
414 | | - rft_parser.add_argument("--max-output-tokens", type=int, default=32768) |
415 | | - rft_parser.add_argument("--response-candidates-count", type=int, default=8) |
| 429 | + rft_parser.add_argument("--temperature", type=float, help="Sampling temperature for rollouts") |
| 430 | + rft_parser.add_argument("--top-p", type=float, help="Top-p (nucleus) sampling parameter") |
| 431 | + rft_parser.add_argument("--top-k", type=int, help="Top-k sampling parameter") |
| 432 | + rft_parser.add_argument("--max-output-tokens", type=int, default=32768, help="Maximum output tokens per rollout") |
| 433 | + rft_parser.add_argument( |
| 434 | + "--response-candidates-count", type=int, default=8, help="Number of response candidates per prompt" |
| 435 | + ) |
416 | 436 | rft_parser.add_argument("--extra-body", help="JSON string for extra inference params") |
417 | 437 | # MCP server (optional) |
418 | 438 | rft_parser.add_argument( |
419 | 439 | "--mcp-server", |
420 | | - help="The MCP server resource name to use for the reinforcement fine-tuning job.", |
| 440 | + help="MCP server resource name for agentic rollouts", |
421 | 441 | ) |
422 | 442 | # Wandb |
423 | | - rft_parser.add_argument("--wandb-enabled", action="store_true") |
424 | | - rft_parser.add_argument("--wandb-project") |
425 | | - rft_parser.add_argument("--wandb-entity") |
426 | | - rft_parser.add_argument("--wandb-run-id") |
427 | | - rft_parser.add_argument("--wandb-api-key") |
| 443 | + rft_parser.add_argument("--wandb-enabled", action="store_true", help="Enable Weights & Biases logging") |
| 444 | + rft_parser.add_argument("--wandb-project", help="Weights & Biases project name") |
| 445 | + rft_parser.add_argument("--wandb-entity", help="Weights & Biases entity (username or team)") |
| 446 | + rft_parser.add_argument("--wandb-run-id", help="Weights & Biases run id for resuming") |
| 447 | + rft_parser.add_argument("--wandb-api-key", help="Weights & Biases API key") |
428 | 448 | # Misc |
429 | 449 | rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id") |
430 | 450 | rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode") |
@@ -490,6 +510,38 @@ def parse_args(args=None): |
490 | 510 | # help="Run an evaluation using a Hydra configuration. All arguments after 'run' are passed to Hydra.", |
491 | 511 | # ) |
492 | 512 |
|
| 513 | + # Hidden command: export-docs (for generating CLI reference documentation) |
| 514 | + export_docs_parser = subparsers.add_parser("export-docs", help=argparse.SUPPRESS) |
| 515 | + export_docs_parser.add_argument( |
| 516 | + "--output", |
| 517 | + "-o", |
| 518 | + default="./docs/cli-reference.md", |
| 519 | + help="Output markdown file path (default: ./docs/cli-reference.md)", |
| 520 | + ) |
| 521 | + |
| 522 | + # Update metavar to only show visible commands (exclude those with SUPPRESS) |
| 523 | + _hide_suppressed_subparsers(parser) |
| 524 | + |
| 525 | + return parser |
| 526 | + |
| 527 | + |
def _hide_suppressed_subparsers(parser: argparse.ArgumentParser) -> None:
    """Hide subcommands registered with ``help=argparse.SUPPRESS`` from help output.

    For every subparsers action on ``parser`` that has at least one suppressed
    subcommand, this:

    * drops the suppressed entries from the per-command help listing
      (``_choices_actions``), and
    * rewrites ``metavar`` so the ``{a,b,c}`` usage summary lists only the
      visible commands.

    The hidden commands stay in ``action.choices``, so they can still be
    invoked; only the help/usage rendering changes. Subparsers actions with
    nothing to hide are left untouched.

    Args:
        parser: The parser whose subparsers actions are updated in place.
    """
    for action in parser._actions:
        if not isinstance(action, argparse._SubParsersAction):
            continue

        choices_actions = getattr(action, "_choices_actions", [])
        hidden_names = {entry.dest for entry in choices_actions if entry.help == argparse.SUPPRESS}
        if not hidden_names:
            # Nothing to hide: keep argparse's default rendering (and any
            # caller-supplied metavar) as is.
            continue

        action._choices_actions = [entry for entry in choices_actions if entry.help != argparse.SUPPRESS]

        # argparse only records a help entry for subcommands created with a
        # ``help=`` keyword, so the visible names must come from
        # ``action.choices``; otherwise help-less commands would vanish from
        # the usage line. Aliases map to the same parser object as their
        # command, so comparing parser identity hides a suppressed command's
        # aliases too and lists each visible command once, under its primary
        # (first-registered) name.
        seen_parsers = {id(action.choices[name]) for name in hidden_names if name in action.choices}
        visible_names = []
        for name, subparser in action.choices.items():
            if id(subparser) in seen_parsers:
                continue
            seen_parsers.add(id(subparser))
            visible_names.append(name)

        if visible_names:
            action.metavar = "{" + ",".join(visible_names) + "}"
| 540 | + |
| 541 | + |
def parse_args(args=None):
    """Parse command line arguments with a freshly built CLI parser.

    Args:
        args: Argument list forwarded to ``parse_known_args``; defaults to
            ``None``.

    Returns:
        Whatever ``ArgumentParser.parse_known_args`` returns for ``args``.
    """
    # parse_known_args is used so that Hydra can handle its own arguments
    return build_parser().parse_known_args(args)
495 | 547 |
|
@@ -589,6 +641,10 @@ def _extract_flag_value(argv_list, flag_name): |
589 | 641 | from .cli_commands.local_test import local_test_command |
590 | 642 |
|
591 | 643 | return local_test_command(args) |
| 644 | + elif args.command == "export-docs": |
| 645 | + from .cli_commands.export_docs import export_docs_command |
| 646 | + |
| 647 | + return export_docs_command(args) |
592 | 648 | # elif args.command == "run": |
593 | 649 | # # For the 'run' command, Hydra takes over argument parsing. |
594 | 650 | # |
|
0 commit comments