Skip to content

Commit 1df9e72

Browse files
committed
Auto-generated CLI
1 parent 686ed67 commit 1df9e72

File tree

4 files changed

+324
-441
lines changed

4 files changed

+324
-441
lines changed

eval_protocol/cli.py

Lines changed: 60 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
"""
44

55
import argparse
6+
import inspect
7+
import json
68
import logging
79
import os
810
import sys
911
from pathlib import Path
1012
from typing import Any, cast
13+
from .cli_commands.utils import add_args_from_callable_signature
14+
15+
from fireworks import Fireworks
1116

1217
logger = logging.getLogger(__name__)
1318

@@ -374,87 +379,11 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
374379
"rft",
375380
help="Create a Reinforcement Fine-tuning Job on Fireworks",
376381
)
377-
rft_parser.add_argument(
378-
"--evaluator",
379-
help="Evaluator ID or fully-qualified resource (accounts/{acct}/evaluators/{id}); if omitted, derive from local tests",
380-
)
381-
# Dataset options
382-
rft_parser.add_argument(
383-
"--dataset",
384-
help="Use existing dataset (ID or resource 'accounts/{acct}/datasets/{id}') to skip local materialization",
385-
)
386-
rft_parser.add_argument(
387-
"--dataset-jsonl",
388-
help="Path to JSONL to upload as a new Fireworks dataset",
389-
)
390-
rft_parser.add_argument(
391-
"--dataset-builder",
392-
help="Explicit dataset builder spec (module::function or path::function)",
393-
)
394-
rft_parser.add_argument(
395-
"--dataset-display-name",
396-
help="Display name for dataset on Fireworks (defaults to dataset id)",
397-
)
398-
# Training config and evaluator/job settings
399-
rft_parser.add_argument("--base-model", help="Base model resource id")
400-
rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from")
401-
rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)")
402-
rft_parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
403-
rft_parser.add_argument("--batch-size", type=int, default=128000, help="Training batch size in tokens")
404-
rft_parser.add_argument("--learning-rate", type=float, default=3e-5, help="Learning rate for training")
405-
rft_parser.add_argument("--max-context-length", type=int, default=65536, help="Maximum context length in tokens")
406-
rft_parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank for fine-tuning")
407-
rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps")
408-
rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of learning rate warmup steps")
409-
rft_parser.add_argument("--accelerator-count", type=int, help="Number of accelerators (GPUs) to use")
410-
rft_parser.add_argument("--region", help="Fireworks region for training")
411-
rft_parser.add_argument("--display-name", help="Display name for the RFT job")
412-
rft_parser.add_argument("--evaluation-dataset", help="Separate dataset id for evaluation")
413-
rft_parser.add_argument(
414-
"--eval-auto-carveout",
415-
dest="eval_auto_carveout",
416-
action="store_true",
417-
default=True,
418-
help="Automatically carve out evaluation data from training set",
419-
)
420-
rft_parser.add_argument(
421-
"--no-eval-auto-carveout",
422-
dest="eval_auto_carveout",
423-
action="store_false",
424-
help="Disable automatic evaluation data carveout",
425-
)
426-
# Rollout chunking
427-
rft_parser.add_argument("--chunk-size", type=int, default=100, help="Data chunk size for rollout batching")
428-
# Inference params
429-
rft_parser.add_argument("--temperature", type=float, help="Sampling temperature for rollouts")
430-
rft_parser.add_argument("--top-p", type=float, help="Top-p (nucleus) sampling parameter")
431-
rft_parser.add_argument("--top-k", type=int, help="Top-k sampling parameter")
432-
rft_parser.add_argument("--max-output-tokens", type=int, default=32768, help="Maximum output tokens per rollout")
433-
rft_parser.add_argument(
434-
"--response-candidates-count", type=int, default=8, help="Number of response candidates per prompt"
435-
)
436-
rft_parser.add_argument("--extra-body", help="JSON string for extra inference params")
437-
# MCP server (optional)
438-
rft_parser.add_argument(
439-
"--mcp-server",
440-
help="MCP server resource name for agentic rollouts",
441-
)
442-
# Wandb
443-
rft_parser.add_argument("--wandb-enabled", action="store_true", help="Enable Weights & Biases logging")
444-
rft_parser.add_argument("--wandb-project", help="Weights & Biases project name")
445-
rft_parser.add_argument("--wandb-entity", help="Weights & Biases entity (username or team)")
446-
rft_parser.add_argument("--wandb-run-id", help="Weights & Biases run id for resuming")
447-
rft_parser.add_argument("--wandb-api-key", help="Weights & Biases API key")
448-
# Misc
449-
rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id")
382+
450383
rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
451-
rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending")
384+
rft_parser.add_argument("--dry-run", action="store_true", help="Print planned SDK call without sending")
452385
rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID")
453-
rft_parser.add_argument(
454-
"--skip-validation",
455-
action="store_true",
456-
help="Skip local dataset and evaluator validation before creating the RFT job",
457-
)
386+
rft_parser.add_argument("--skip-validation", action="store_true", help="Skip local dataset/evaluator validation")
458387
rft_parser.add_argument(
459388
"--ignore-docker",
460389
action="store_true",
@@ -463,12 +392,62 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
463392
rft_parser.add_argument(
464393
"--docker-build-extra",
465394
default="",
466-
help="Extra flags to pass to 'docker build' when validating evaluator (quoted string, e.g. \"--no-cache --pull --progress=plain\")",
395+
metavar="",
396+
help="Extra flags to pass to 'docker build' when validating evaluator",
467397
)
468398
rft_parser.add_argument(
469399
"--docker-run-extra",
470400
default="",
471-
help="Extra flags to pass to 'docker run' when validating evaluator (quoted string, e.g. \"--env-file .env --memory=8g\")",
401+
metavar="",
402+
help="Extra flags to pass to 'docker run' when validating evaluator",
403+
)
404+
405+
# Everything below must be maintained manually; it cannot be auto-generated
406+
rft_parser.add_argument(
407+
"--source-job",
408+
metavar="",
409+
help="The source reinforcement fine-tuning job to copy configuration from. If other flags are set, they will override the source job's configuration.",
410+
)
411+
rft_parser.add_argument(
412+
"--quiet",
413+
action="store_true",
414+
help="If set, only errors will be printed.",
415+
)
416+
skip_fields = {
417+
"__top_level__": {
418+
"extra_headers",
419+
"extra_query",
420+
"extra_body",
421+
"timeout",
422+
"node_count",
423+
"display_name",
424+
"account_id",
425+
},
426+
"loss_config": {"kl_beta", "method"},
427+
"training_config": {"region", "jinja_template"},
428+
"wandb_config": {"run_id"},
429+
}
430+
aliases = {
431+
"wandb_config.api_key": ["--wandb-api-key"],
432+
"wandb_config.project": ["--wandb-project"],
433+
"wandb_config.entity": ["--wandb-entity"],
434+
"wandb_config.enabled": ["--wandb"],
435+
"reinforcement_fine_tuning_job_id": ["--job-id"],
436+
}
437+
help_overrides = {
438+
"training_config.gradient_accumulation_steps": "The number of batches to accumulate gradients before updating the model parameters. The effective batch size will be batch-size multiplied by this value.",
439+
"training_config.learning_rate_warmup_steps": "The number of learning rate warmup steps for the reinforcement fine-tuning job.",
440+
"mcp_server": "The MCP server resource name to use for the reinforcement fine-tuning job. (Optional)",
441+
}
442+
443+
create_rft_job_fn = Fireworks().reinforcement_fine_tuning_jobs.create
444+
445+
add_args_from_callable_signature(
446+
rft_parser,
447+
create_rft_job_fn,
448+
skip_fields=skip_fields,
449+
aliases=aliases,
450+
help_overrides=help_overrides,
472451
)
473452

474453
# Local test command

eval_protocol/cli_commands/utils.py

Lines changed: 201 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import os
2+
import ast
23
import sys
34
import time
45
import inspect
5-
import subprocess
6+
import argparse
7+
import typing
8+
import types
69
from dataclasses import dataclass
710
from pathlib import Path
8-
from typing import Any, Dict, Iterable, List, Optional, Tuple
9-
11+
from typing import Any, List, Optional, is_typeddict
12+
import typing_extensions
13+
import inspect
14+
from collections.abc import Callable
1015
import pytest
1116

1217
from ..auth import (
@@ -505,3 +510,196 @@ def _build_entry_point(project_root: str, source_file_path: Optional[str], func_
505510
return f"{rel}::{func_name}"
506511
# Fallback: use filename only
507512
return f"{func_name}.py::{func_name}"
513+
514+
515+
def unwrap_union(tp):
516+
origin = typing.get_origin(tp)
517+
518+
# Handles both typing.Union[...] and PEP604 unions (A | B)
519+
if origin is typing.Union or origin is types.UnionType:
520+
args = [a for a in typing.get_args(tp) if getattr(a, "__name__", "") != "Omit" and a is not type(None)]
521+
return args[0] if args else None
522+
523+
return tp
524+
525+
526+
def argparse_type_from_hint(t: Any) -> Any:
527+
"""Return a callable argparse type for a type hint (minimal unwrapping + fallback).
528+
529+
- Drops Omit/None from unions
530+
- Unwraps Annotated[T, ...] => T
531+
- Falls back to str when the result isn't callable
532+
"""
533+
t = unwrap_union(t)
534+
if typing.get_origin(t) is typing.Annotated:
535+
args = typing.get_args(t)
536+
t = args[0] if args else str
537+
return t if callable(t) else str
538+
539+
540+
def typed_dict_field_docs(typed_dict_cls: type) -> dict[str, str]:
    """
    Extract per-field docstrings from a TypedDict class that uses the pattern:

        field: Type
        'doc...'

    Returns { "field": "doc..." }; returns {} when the source is unavailable
    or cannot be parsed.
    """
    import textwrap  # for dedenting source of classes defined at nested scope

    try:
        src = inspect.getsource(typed_dict_cls)
    except Exception:
        # Source unavailable (e.g. class built dynamically or in a REPL).
        return {}

    try:
        # Dedent first: getsource() of a class defined inside a function or
        # another class is indented, and bare ast.parse() would raise an
        # IndentationError (a SyntaxError) and silently return {}.
        mod = ast.parse(textwrap.dedent(src))
    except SyntaxError:
        return {}

    # find the class definition
    cls_node = None
    for node in mod.body:
        if isinstance(node, ast.ClassDef) and node.name == typed_dict_cls.__name__:
            cls_node = node
            break
    if cls_node is None:
        return {}

    docs: dict[str, str] = {}
    body = cls_node.body

    i = 0
    while i < len(body):
        node = body[i]

        # field: Annotated[...] or field: T
        if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
            field_name = node.target.id

            # next node is a string literal expression => treat as "field doc"
            if i + 1 < len(body):
                nxt = body[i + 1]
                if (
                    isinstance(nxt, ast.Expr)
                    and isinstance(nxt.value, ast.Constant)
                    and isinstance(nxt.value.value, str)
                ):
                    docs[field_name] = nxt.value.value.strip()
                    i += 2
                    continue

        i += 1

    return docs
594+
595+
596+
def _parse_args_section_from_doc(doc: str) -> dict[str, str]:
    """Parse the Google-style ``Args:`` section of *doc* into {name: description}.

    Multi-line descriptions are joined with single spaces. Returns {} when
    *doc* is empty or contains no ``Args:`` header. Parsing is heuristic:
    it relies on ``name: text`` headers and indentation-free section
    terminators such as ``Returns:``.
    """
    if not doc:
        return {}

    lines = doc.splitlines()

    # find "Args:"
    try:
        start = next(i for i, line in enumerate(lines) if line.strip() == "Args:")
    except StopIteration:
        return {}

    out: dict[str, str] = {}
    cur_name: str | None = None  # argument currently being accumulated
    cur_parts: list[str] = []  # its description fragments, joined on flush

    for line in lines[start + 1 :]:
        # stop if we hit another top-level section header like "Returns:"
        if line and not line.startswith(" ") and line.endswith(":"):
            break

        if not line.strip():
            continue

        stripped = line.strip()

        # New arg header like "dataset: blah"
        if ":" in stripped:
            name, rest = stripped.split(":", 1)
            name = name.strip()
            if name.replace("_", "").isalnum():
                # Flush the previous argument before starting a new one.
                if cur_name:
                    out[cur_name] = " ".join(cur_parts).strip()
                cur_name = name
                cur_parts = [rest.strip()]
            # NOTE(review): every colon-containing line is consumed here, so a
            # continuation line such as "e.g.: foo" (prefix not a plain
            # identifier) is dropped rather than appended — confirm intended.
            continue

        # Continuation
        if cur_name:
            cur_parts.append(stripped)

    # Flush the final accumulated argument.
    if cur_name:
        out[cur_name] = " ".join(cur_parts).strip()

    return out
641+
642+
643+
def _add_flag(
    parser: argparse.ArgumentParser,
    flags: list[str],
    hint: Any,
    help_text: str | None,
) -> None:
    """Register a single option on *parser* derived from a type hint.

    Boolean hints become ``store_true`` switches; every other hint becomes a
    value-taking option whose ``type=`` is derived from the hint, with an
    empty metavar so ``--help`` output shows only the flag names.
    """
    if unwrap_union(hint) is bool:
        parser.add_argument(*flags, action="store_true", help=help_text)
    else:
        parser.add_argument(
            *flags,
            type=argparse_type_from_hint(hint),
            help=help_text,
            metavar="",
        )
658+
659+
660+
def add_args_from_callable_signature(
    parser: argparse.ArgumentParser,
    fn: Callable[..., Any],
    *,
    overrides: dict[str, str] | None = None,
    skip_fields: dict[str, set[str]] | None = None,
    aliases: dict[str, list[str]] | None = None,
    help_overrides: dict[str, str] | None = None,
) -> None:
    """Auto-generate argparse options from *fn*'s signature and docstring.

    Each parameter becomes a ``--kebab-case`` flag. A parameter whose
    (union-unwrapped) type is a TypedDict is expanded one level deep into one
    flag per field, using the TypedDict's per-field docstrings as help text.

    Args:
        parser: Parser the options are added to.
        fn: Callable to introspect.
        overrides: Currently unused; accepted for interface compatibility.
        skip_fields: Names to skip. The ``"__top_level__"`` key skips
            parameters of *fn*; any other key names a TypedDict parameter and
            maps to the set of its fields to skip.
        aliases: Extra flag spellings, keyed by parameter name or
            ``"param.field"`` for TypedDict fields.
        help_overrides: Help-text replacements, keyed like *aliases*.
    """
    overrides = overrides or {}  # NOTE(review): accepted but never read
    aliases = aliases or {}
    help_overrides = help_overrides or {}
    skip_fields = skip_fields or {}
    top_level_skip = skip_fields.get("__top_level__", set())

    sig = inspect.signature(fn)
    # Renamed from `help` to avoid shadowing the builtin.
    arg_docs = _parse_args_section_from_doc(inspect.getdoc(fn) or "")
    hints = typing.get_type_hints(fn, include_extras=True)

    for name in sig.parameters:
        # Honor the top-level skip list before any expansion; previously a
        # skipped parameter whose type was a TypedDict was expanded anyway.
        if name in top_level_skip:
            continue

        resolved_type = unwrap_union(hints.get(name))

        # Allow one nested layer of TypedDicts
        if resolved_type and typing_extensions.is_typeddict(resolved_type):
            field_help = typed_dict_field_docs(resolved_type)
            field_hints = typing.get_type_hints(resolved_type, include_extras=True)
            field_skip = skip_fields.get(name, set())

            for field_name, field_type in resolved_type.__annotations__.items():
                if field_name in field_skip:
                    continue
                flag_name = "--" + field_name.replace("_", "-")
                flags = [flag_name] + aliases.get(f"{name}.{field_name}", [])
                help_text = help_overrides.get(f"{name}.{field_name}", field_help.get(field_name))

                _add_flag(parser, flags, field_hints.get(field_name, field_type), help_text)
            continue

        flag_name = "--" + name.replace("_", "-")
        flags = [flag_name] + aliases.get(name, [])
        help_text = help_overrides.get(name, arg_docs.get(name))

        _add_flag(parser, flags, hints.get(name), help_text)

0 commit comments

Comments
 (0)