Skip to content

Commit 4790cde

Browse files
committed
Merge branch 'main' into derekx/take-out-auth-ini
2 parents 310fb85 + 29afd31 commit 4790cde

File tree

5 files changed

+429
-40
lines changed

5 files changed

+429
-40
lines changed

eval_protocol/cli.py

Lines changed: 82 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,16 @@
3232
preview_command = None # type: ignore[assignment]
3333

3434

35-
def parse_args(args=None):
36-
"""Parse command line arguments"""
37-
parser = argparse.ArgumentParser(description="eval-protocol: Tools for evaluation and reward modeling")
35+
def build_parser() -> argparse.ArgumentParser:
    """Create the top-level CLI parser and attach every argument and sub-command.

    Returns:
        The parser handed back by ``_configure_parser`` after it has been
        given the freshly created ``argparse.ArgumentParser``.
    """
    description = "Inspect evaluation runs locally, upload evaluators, and create reinforcement fine-tuning jobs on Fireworks"
    root_parser = argparse.ArgumentParser(description=description)
    # All flag and sub-command registration is delegated to _configure_parser.
    return _configure_parser(root_parser)
41+
42+
43+
def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
44+
"""Configure all arguments and subparsers on the given parser."""
3845
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose logging")
3946
parser.add_argument(
4047
"--server",
@@ -392,39 +399,52 @@ def parse_args(args=None):
392399
rft_parser.add_argument("--base-model", help="Base model resource id")
393400
rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from")
394401
rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)")
395-
rft_parser.add_argument("--epochs", type=int, default=1)
396-
rft_parser.add_argument("--batch-size", type=int, default=128000)
397-
rft_parser.add_argument("--learning-rate", type=float, default=3e-5)
398-
rft_parser.add_argument("--max-context-length", type=int, default=65536)
399-
rft_parser.add_argument("--lora-rank", type=int, default=16)
402+
rft_parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
403+
rft_parser.add_argument("--batch-size", type=int, default=128000, help="Training batch size in tokens")
404+
rft_parser.add_argument("--learning-rate", type=float, default=3e-5, help="Learning rate for training")
405+
rft_parser.add_argument("--max-context-length", type=int, default=65536, help="Maximum context length in tokens")
406+
rft_parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank for fine-tuning")
400407
rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps")
401-
rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of LR warmup steps")
402-
rft_parser.add_argument("--accelerator-count", type=int)
403-
rft_parser.add_argument("--region", help="Fireworks region enum value")
404-
rft_parser.add_argument("--display-name", help="RFT job display name")
405-
rft_parser.add_argument("--evaluation-dataset", help="Optional separate eval dataset id")
406-
rft_parser.add_argument("--eval-auto-carveout", dest="eval_auto_carveout", action="store_true", default=True)
407-
rft_parser.add_argument("--no-eval-auto-carveout", dest="eval_auto_carveout", action="store_false")
408+
rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of learning rate warmup steps")
409+
rft_parser.add_argument("--accelerator-count", type=int, help="Number of accelerators (GPUs) to use")
410+
rft_parser.add_argument("--region", help="Fireworks region for training")
411+
rft_parser.add_argument("--display-name", help="Display name for the RFT job")
412+
rft_parser.add_argument("--evaluation-dataset", help="Separate dataset id for evaluation")
413+
rft_parser.add_argument(
414+
"--eval-auto-carveout",
415+
dest="eval_auto_carveout",
416+
action="store_true",
417+
default=True,
418+
help="Automatically carve out evaluation data from training set",
419+
)
420+
rft_parser.add_argument(
421+
"--no-eval-auto-carveout",
422+
dest="eval_auto_carveout",
423+
action="store_false",
424+
help="Disable automatic evaluation data carveout",
425+
)
408426
# Rollout chunking
409427
rft_parser.add_argument("--chunk-size", type=int, default=100, help="Data chunk size for rollout batching")
410428
# Inference params
411-
rft_parser.add_argument("--temperature", type=float)
412-
rft_parser.add_argument("--top-p", type=float)
413-
rft_parser.add_argument("--top-k", type=int)
414-
rft_parser.add_argument("--max-output-tokens", type=int, default=32768)
415-
rft_parser.add_argument("--response-candidates-count", type=int, default=8)
429+
rft_parser.add_argument("--temperature", type=float, help="Sampling temperature for rollouts")
430+
rft_parser.add_argument("--top-p", type=float, help="Top-p (nucleus) sampling parameter")
431+
rft_parser.add_argument("--top-k", type=int, help="Top-k sampling parameter")
432+
rft_parser.add_argument("--max-output-tokens", type=int, default=32768, help="Maximum output tokens per rollout")
433+
rft_parser.add_argument(
434+
"--response-candidates-count", type=int, default=8, help="Number of response candidates per prompt"
435+
)
416436
rft_parser.add_argument("--extra-body", help="JSON string for extra inference params")
417437
# MCP server (optional)
418438
rft_parser.add_argument(
419439
"--mcp-server",
420-
help="The MCP server resource name to use for the reinforcement fine-tuning job.",
440+
help="MCP server resource name for agentic rollouts",
421441
)
422442
# Wandb
423-
rft_parser.add_argument("--wandb-enabled", action="store_true")
424-
rft_parser.add_argument("--wandb-project")
425-
rft_parser.add_argument("--wandb-entity")
426-
rft_parser.add_argument("--wandb-run-id")
427-
rft_parser.add_argument("--wandb-api-key")
443+
rft_parser.add_argument("--wandb-enabled", action="store_true", help="Enable Weights & Biases logging")
444+
rft_parser.add_argument("--wandb-project", help="Weights & Biases project name")
445+
rft_parser.add_argument("--wandb-entity", help="Weights & Biases entity (username or team)")
446+
rft_parser.add_argument("--wandb-run-id", help="Weights & Biases run id for resuming")
447+
rft_parser.add_argument("--wandb-api-key", help="Weights & Biases API key")
428448
# Misc
429449
rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id")
430450
rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
@@ -490,6 +510,38 @@ def parse_args(args=None):
490510
# help="Run an evaluation using a Hydra configuration. All arguments after 'run' are passed to Hydra.",
491511
# )
492512

513+
# Hidden command: export-docs (for generating CLI reference documentation)
514+
export_docs_parser = subparsers.add_parser("export-docs", help=argparse.SUPPRESS)
515+
export_docs_parser.add_argument(
516+
"--output",
517+
"-o",
518+
default="./docs/cli-reference.md",
519+
help="Output markdown file path (default: ./docs/cli-reference.md)",
520+
)
521+
522+
# Update metavar to only show visible commands (exclude those with SUPPRESS)
523+
_hide_suppressed_subparsers(parser)
524+
525+
return parser
526+
527+
528+
def _hide_suppressed_subparsers(parser: argparse.ArgumentParser) -> None:
    """Remove sub-commands registered with ``help=argparse.SUPPRESS`` from help output.

    For each sub-parser action found directly on ``parser``, entries whose help
    is ``argparse.SUPPRESS`` are dropped from the action's listed commands and
    the usage metavar is rebuilt from the names that remain. ``action.choices``
    is never modified by this function.

    Sub-parser actions that have nothing to hide are left completely untouched,
    so a caller-supplied ``metavar`` (or argparse's default one) is preserved.

    Args:
        parser: Parser whose sub-parser actions are updated in place.
    """
    for action in parser._actions:
        if not isinstance(action, argparse._SubParsersAction):
            continue

        listed = getattr(action, "_choices_actions", [])
        visible = [entry for entry in listed if entry.help != argparse.SUPPRESS]
        if len(visible) == len(listed):
            # Nothing is hidden here: do not overwrite the existing metavar.
            continue

        # Keep only the visible entries in the list of commands shown in help.
        action._choices_actions = visible
        if visible:
            # Rebuild the "{a,b,c}" usage placeholder from the visible names only.
            action.metavar = "{" + ",".join(entry.dest for entry in visible) + "}"
540+
541+
542+
def parse_args(args=None):
    """Parse command line arguments with the parser from ``build_parser``.

    Args:
        args: Argument strings passed straight through to
            ``parse_known_args``; defaults to ``None``.

    Returns:
        Whatever ``ArgumentParser.parse_known_args`` returns for ``args``.
    """
    # parse_known_args (rather than parse_args) lets Hydra handle its own arguments.
    return build_parser().parse_known_args(args)
495547

@@ -589,6 +641,10 @@ def _extract_flag_value(argv_list, flag_name):
589641
from .cli_commands.local_test import local_test_command
590642

591643
return local_test_command(args)
644+
elif args.command == "export-docs":
645+
from .cli_commands.export_docs import export_docs_command
646+
647+
return export_docs_command(args)
592648
# elif args.command == "run":
593649
# # For the 'run' command, Hydra takes over argument parsing.
594650
#

0 commit comments

Comments
 (0)