Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
ee69c1b
add training loop config dataclass
maanug-nv Jul 16, 2025
09e044d
move file
maanug-nv Jul 17, 2025
b118b24
remove variable batch size options
maanug-nv Jul 24, 2025
29c1294
replace train iters with train samples
maanug-nv Jul 24, 2025
e9375f9
replace eval iters with eval samples
maanug-nv Jul 24, 2025
1684ede
Revert "remove variable batch size options"
maanug-nv Jul 24, 2025
cc56ab3
update some defaults and docstrings
maanug-nv Jul 31, 2025
ff961f7
first draft of factory
maanug-nv Nov 11, 2025
4892f2f
set metadata for rbs
maanug-nv Nov 11, 2025
c22431f
support excluding keys
maanug-nv Nov 11, 2025
57ae59b
change return object
maanug-nv Nov 14, 2025
cff582f
support multiple arg names and arg name prefixes
maanug-nv Nov 14, 2025
17d745c
remove some auto-generated arguments
maanug-nv Nov 14, 2025
47f1639
support default factories
maanug-nv Nov 15, 2025
46e5684
add iterations to skip to config
maanug-nv Nov 15, 2025
74ad7de
support both iters and samples for training
maanug-nv Nov 15, 2025
79c6245
split validation into separate config
maanug-nv Nov 15, 2025
14b348d
add unit tests
maanug-nv Nov 17, 2025
e6d185a
defer to metadata if present on type check failure
maanug-nv Nov 17, 2025
ed96bfe
revert name changes to val config
maanug-nv Nov 17, 2025
5eb7a83
more unit test coverage
maanug-nv Nov 18, 2025
7fcff86
formatting
maanug-nv Nov 19, 2025
7f8792e
fix recursive call
maanug-nv Nov 19, 2025
5fe366f
fix serializability error
maanug-nv Nov 20, 2025
4664c50
interval default must be int
maanug-nv Nov 20, 2025
8de278f
add rngconfig dataclass
maanug-nv Nov 21, 2025
0fbe468
generate from rngconfig
maanug-nv Nov 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions megatron/training/argument_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import ast
import builtins
import dataclasses
import inspect
import itertools
import textwrap
import types
import typing
from argparse import ArgumentParser, _ArgumentGroup
from dataclasses import Field, fields
from typing import Any, Optional

# TODO: support arg renames

class TypeInferenceError(Exception):
    """Signal that no argparse type could be inferred from a dataclass field annotation.

    ArgumentGroupFactory raises this for annotations it cannot map onto
    `add_argument()` keyword arguments, and catches it so that it can fall back
    to the field's `argparse_meta` metadata when that metadata is present.
    """

class ArgumentGroupFactory:
"""Utility that adds an argument group to an ArgumentParser based on the attributes of a dataclass.

This class can be overriden as needed to support dataclasses
that require some customized or additional handling.

Args:
src_cfg_class: The source dataclass type (not instance) whose fields will be
converted into command-line arguments. Each field's type annotation determines
the argument type, default values become argument defaults, and field-level
docstrings are extracted to populate argument help text.
exclude: Optional list of attribute names from `src_cfg_class` to exclude from
argument generation. Useful for omitting internal fields, computed properties,
or attributes that should be configured through other means. If None, all
dataclass fields will be converted to command-line arguments. Default: None.
"""

def __init__(self, src_cfg_class: type, exclude: Optional[list[str]] = None) -> None:
self.src_cfg_class = src_cfg_class
self.field_docstrings = self._get_field_docstrings(src_cfg_class)
self.exclude = set(exclude) if exclude is not None else set()

def _format_arg_name(self, config_attr_name: str, prefix: Optional[str] = None) -> str:
"""Convert dataclass name into appropriate argparse flag name.

Args:
config_attr_name: dataclass attribute name
prefix: prefix string to add to the dataclass attribute name. e.g. 'no' for bool
settings that are default True. A hyphen is added after the prefix. Default: None
"""
arg_name = config_attr_name
if prefix:
arg_name = prefix + '_' + arg_name
arg_name = "--" + arg_name.replace("_", "-")
return arg_name

def _extract_type(self, config_type: type) -> dict[str, Any]:
"""Determine the type, nargs, and choices settings for this argument.

Args:
config_type: attribute type from dataclass
"""
origin = typing.get_origin(config_type)
type_tuple = typing.get_args(config_type)

# Primitive type
if origin is None:
return {"type": config_type}

if origin is typing.Union:
# Handle Optional and Union
if type_tuple[1] == type(None): # Optional type. First element is value inside Optional[]
return self._extract_type(type_tuple[0])
else:
raise TypeInferenceError(f"Unions not supported by argparse: {config_type}")

elif origin is list:
if len(type_tuple) == 1:
kwargs = self._extract_type(type_tuple[0])
kwargs["nargs"] = "+"
return kwargs
else:
raise TypeInferenceError(f"Multi-type lists not supported by argparse: {config_type}")

elif origin is typing.Literal:
choices_types = [type(choice) for choice in type_tuple]
assert all([t == choices_types[0] for t in choices_types]), "Type of each choice in a Literal type should all be the same."
kwargs = {"type": choices_types[0], "choices": type_tuple}
return kwargs
else:
raise TypeInferenceError(f"Unsupported type: {config_type}")


def _build_argparse_kwargs_from_field(self, attribute: Field) -> dict[str, Any]:
"""Assemble kwargs for add_argument().

Args:
attribute: dataclass attribute
"""
argparse_kwargs = {}
argparse_kwargs["arg_names"] = [self._format_arg_name(attribute.name)]
argparse_kwargs["dest"] = attribute.name
argparse_kwargs["help"] = self.field_docstrings[attribute.name]

# dataclasses specifies that both should not be set
if isinstance(attribute.default, type(dataclasses.MISSING)):
# dataclasses specified default_factory must be a zero-argument callable
argparse_kwargs["default"] = attribute.default_factory()
else:
argparse_kwargs["default"] = attribute.default

attr_argparse_meta = None
if attribute.metadata != {} and "argparse_meta" in attribute.metadata:
# save metadata here, but update at the end so the metadata has highest precedence
attr_argparse_meta = attribute.metadata["argparse_meta"]


# if we cannot infer the argparse type, all of this logic may fail. we try to defer
# to the developer-specified metadata if present
try:
argparse_kwargs.update(self._extract_type(attribute.type))

# use store_true or store_false action for enable/disable flags, which doesn't accept a 'type'
if argparse_kwargs["type"] == bool:
argparse_kwargs["action"] = "store_true" if attribute.default == False else "store_false"
argparse_kwargs.pop("type")

# add '--no-*' and '--disable-*' prefix if this is a store_false argument
if argparse_kwargs["action"] == "store_false":
argparse_kwargs["arg_names"] = [self._format_arg_name(attribute.name, prefix="no"), self._format_arg_name(attribute.name, prefix="disable")]
except TypeInferenceError as e:
if attr_argparse_meta is not None:
print(
f"WARNING: Inferring the appropriate argparse argument type from {self.src_cfg_class} "
f"failed for {attribute.name}: {attribute.type}.\n"
"Deferring to attribute metadata. If the metadata is incomplete, 'parser.add_argument()' may fail.\n"
f"Original failure: {e}"
)
else:
raise e

# metadata provided by field takes precedence
if attr_argparse_meta is not None:
argparse_kwargs.update(attr_argparse_meta)

return argparse_kwargs

def build_group(self, parser: ArgumentParser, title: Optional[str] = None) -> _ArgumentGroup:
"""Entrypoint method that adds the argument group to the parser.

Args:
parser: The parser to add arguments to
title: Title for the argument group
"""
arg_group = parser.add_argument_group(title=title, description=self.src_cfg_class.__doc__)
for attr in fields(self.src_cfg_class):
if attr.name in self.exclude:
continue

add_arg_kwargs = self._build_argparse_kwargs_from_field(attr)

arg_names = add_arg_kwargs.pop("arg_names")
arg_group.add_argument(*arg_names, **add_arg_kwargs)

return arg_group

def _get_field_docstrings(self, src_cfg_class: type) -> dict[str, str]:
"""Extract field-level docstrings from a dataclass by inspecting its AST.

Recurses on parent classes of `src_cfg_class`.

Args:
src_cfg_class: Dataclass to get docstrings from.
"""
source = inspect.getsource(src_cfg_class)
tree = ast.parse(source)
root_node = tree.body[0]

assert isinstance(root_node, ast.ClassDef), "Provided object must be a class."

field_docstrings = {}

# Iterate over body of the dataclass using 2-width sliding window.
# When 'a' is an assignment expression and 'b' is a constant, the window is
# lined up with an attribute-docstring pair. The pair can be saved to our dict.
for a, b in itertools.pairwise(root_node.body):
a_cond = isinstance(a, ast.AnnAssign) and isinstance(a.target, ast.Name)
b_cond = isinstance(b, ast.Expr) and isinstance(b.value, ast.Constant)

if a_cond and b_cond:
# These should be guaranteed by typechecks above, but assert just in case
assert isinstance(a.target.id, str), "Dataclass attribute not in the expected format. Name is not a string."
assert isinstance(b.value.value, str), "Dataclass attribute docstring is not a string."

# Formatting
docstring = inspect.cleandoc(b.value.value)
docstring = ' '.join(docstring.split())

field_docstrings[a.target.id] = docstring

# recurse on parent class
base_classes = src_cfg_class.__bases__
if len(base_classes) > 0:
parent_class = base_classes[0]
if parent_class.__name__ not in builtins.__dict__:
field_docstrings.update(self._get_field_docstrings(base_classes[0]))

return field_docstrings
Loading
Loading