108 changes: 82 additions & 26 deletions eval_protocol/cli.py
@@ -32,9 +32,16 @@
preview_command = None # type: ignore[assignment]


def parse_args(args=None):
"""Parse command line arguments"""
parser = argparse.ArgumentParser(description="eval-protocol: Tools for evaluation and reward modeling")
def build_parser() -> argparse.ArgumentParser:
"""Build and return the argument parser for the CLI."""
parser = argparse.ArgumentParser(
description="Inspect evaluation runs locally, upload evaluators, and create reinforcement fine-tuning jobs on Fireworks"
)
return _configure_parser(parser)


def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Configure all arguments and subparsers on the given parser."""
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose logging")
parser.add_argument(
"--profile",
@@ -396,39 +403,52 @@ def parse_args(args=None):
rft_parser.add_argument("--base-model", help="Base model resource id")
rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from")
rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)")
rft_parser.add_argument("--epochs", type=int, default=1)
rft_parser.add_argument("--batch-size", type=int, default=128000)
rft_parser.add_argument("--learning-rate", type=float, default=3e-5)
rft_parser.add_argument("--max-context-length", type=int, default=65536)
rft_parser.add_argument("--lora-rank", type=int, default=16)
rft_parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
rft_parser.add_argument("--batch-size", type=int, default=128000, help="Training batch size in tokens")
rft_parser.add_argument("--learning-rate", type=float, default=3e-5, help="Learning rate for training")
rft_parser.add_argument("--max-context-length", type=int, default=65536, help="Maximum context length in tokens")
rft_parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank for fine-tuning")
rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps")
rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of LR warmup steps")
rft_parser.add_argument("--accelerator-count", type=int)
rft_parser.add_argument("--region", help="Fireworks region enum value")
rft_parser.add_argument("--display-name", help="RFT job display name")
rft_parser.add_argument("--evaluation-dataset", help="Optional separate eval dataset id")
rft_parser.add_argument("--eval-auto-carveout", dest="eval_auto_carveout", action="store_true", default=True)
rft_parser.add_argument("--no-eval-auto-carveout", dest="eval_auto_carveout", action="store_false")
rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of learning rate warmup steps")
rft_parser.add_argument("--accelerator-count", type=int, help="Number of accelerators (GPUs) to use")
rft_parser.add_argument("--region", help="Fireworks region for training")
rft_parser.add_argument("--display-name", help="Display name for the RFT job")
rft_parser.add_argument("--evaluation-dataset", help="Separate dataset id for evaluation")
rft_parser.add_argument(
"--eval-auto-carveout",
dest="eval_auto_carveout",
action="store_true",
default=True,
help="Automatically carve out evaluation data from training set",
)
rft_parser.add_argument(
"--no-eval-auto-carveout",
dest="eval_auto_carveout",
action="store_false",
help="Disable automatic evaluation data carveout",
)
# Rollout chunking
rft_parser.add_argument("--chunk-size", type=int, default=100, help="Data chunk size for rollout batching")
# Inference params
rft_parser.add_argument("--temperature", type=float)
rft_parser.add_argument("--top-p", type=float)
rft_parser.add_argument("--top-k", type=int)
rft_parser.add_argument("--max-output-tokens", type=int, default=32768)
rft_parser.add_argument("--response-candidates-count", type=int, default=8)
rft_parser.add_argument("--temperature", type=float, help="Sampling temperature for rollouts")
rft_parser.add_argument("--top-p", type=float, help="Top-p (nucleus) sampling parameter")
rft_parser.add_argument("--top-k", type=int, help="Top-k sampling parameter")
rft_parser.add_argument("--max-output-tokens", type=int, default=32768, help="Maximum output tokens per rollout")
rft_parser.add_argument(
"--response-candidates-count", type=int, default=8, help="Number of response candidates per prompt"
)
rft_parser.add_argument("--extra-body", help="JSON string for extra inference params")
# MCP server (optional)
rft_parser.add_argument(
"--mcp-server",
help="The MCP server resource name to use for the reinforcement fine-tuning job.",
help="MCP server resource name for agentic rollouts",
)
# Wandb
rft_parser.add_argument("--wandb-enabled", action="store_true")
rft_parser.add_argument("--wandb-project")
rft_parser.add_argument("--wandb-entity")
rft_parser.add_argument("--wandb-run-id")
rft_parser.add_argument("--wandb-api-key")
rft_parser.add_argument("--wandb-enabled", action="store_true", help="Enable Weights & Biases logging")
rft_parser.add_argument("--wandb-project", help="Weights & Biases project name")
rft_parser.add_argument("--wandb-entity", help="Weights & Biases entity (username or team)")
rft_parser.add_argument("--wandb-run-id", help="Weights & Biases run id for resuming")
rft_parser.add_argument("--wandb-api-key", help="Weights & Biases API key")
# Misc
rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id")
rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
@@ -494,6 +514,38 @@ def parse_args(args=None):
# help="Run an evaluation using a Hydra configuration. All arguments after 'run' are passed to Hydra.",
# )

# Hidden command: export-docs (for generating CLI reference documentation)
export_docs_parser = subparsers.add_parser("export-docs", help=argparse.SUPPRESS)
export_docs_parser.add_argument(
"--output",
"-o",
default="./docs/cli-reference.md",
help="Output markdown file path (default: ./docs/cli-reference.md)",
)

# Update metavar to only show visible commands (exclude those with SUPPRESS)
_hide_suppressed_subparsers(parser)

return parser


def _hide_suppressed_subparsers(parser: argparse.ArgumentParser) -> None:
"""Update subparsers to exclude commands with help=SUPPRESS from help output."""
for action in parser._actions:
if isinstance(action, argparse._SubParsersAction):
# Filter _choices_actions to only visible commands
choices_actions = getattr(action, "_choices_actions", [])
visible_actions = [a for a in choices_actions if a.help != argparse.SUPPRESS]
action._choices_actions = visible_actions
# Update metavar to match
visible_names = [a.dest for a in visible_actions]
if visible_names:
action.metavar = "{" + ",".join(visible_names) + "}"


def parse_args(args=None):
"""Parse command line arguments."""
parser = build_parser()
# Use parse_known_args to allow Hydra to handle its own arguments
return parser.parse_known_args(args)

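Two behaviours introduced in this hunk are worth calling out: `parse_known_args()` returns unrecognized tokens instead of erroring (so Hydra-style overrides can be handled elsewhere), and a subparser registered with `help=argparse.SUPPRESS` is stripped from the top-level metavar. A self-contained sketch that mirrors, rather than imports, the code above (subcommand names are illustrative):

```python
import argparse

parser = argparse.ArgumentParser(prog="ep-sketch")
subparsers = parser.add_subparsers(dest="command")
subparsers.add_parser("local-test", help="Run a local test")
subparsers.add_parser("export-docs", help=argparse.SUPPRESS)  # hidden command

# Same metavar trick as _hide_suppressed_subparsers above.
for action in parser._actions:
    if isinstance(action, argparse._SubParsersAction):
        visible = [a for a in action._choices_actions if a.help != argparse.SUPPRESS]
        action._choices_actions = visible
        action.metavar = "{" + ",".join(a.dest for a in visible) + "}"

args, extra = parser.parse_known_args(["local-test", "hydra.run.dir=/tmp"])
print(args.command)           # local-test
print(extra)                  # ['hydra.run.dir=/tmp'] -- left for Hydra to consume
print(parser.format_usage())  # usage line shows {local-test}, not export-docs
```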
@@ -623,6 +675,10 @@ def _extract_flag_value(argv_list, flag_name):
from .cli_commands.local_test import local_test_command

return local_test_command(args)
elif args.command == "export-docs":
from .cli_commands.export_docs import export_docs_command

return export_docs_command(args)
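The `export_docs_command` implementation lives in `eval_protocol/cli_commands/export_docs.py` and is not part of this diff. A hypothetical sketch of a parser-walking exporter, assuming it only needs `build_parser()` and the `--output` path added above:

```python
import argparse


def export_docs_command_sketch(args) -> int:
    # Hypothetical: dump top-level help plus one section per subcommand.
    parser = build_parser()
    sections = ["# CLI reference", "", parser.format_help()]
    for action in parser._actions:
        if isinstance(action, argparse._SubParsersAction):
            for name, subparser in action.choices.items():
                sections += [f"## `{name}`", "", subparser.format_help()]
    with open(args.output, "w", encoding="utf-8") as fh:
        fh.write("\n".join(sections))
    return 0  # assumed success code; the real command's return value is not shown
```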
# elif args.command == "run":
# # For the 'run' command, Hydra takes over argument parsing.
#