25 commits
ee69c1b
add training loop config dataclass
maanug-nv Jul 16, 2025
09e044d
move file
maanug-nv Jul 17, 2025
b118b24
remove variable batch size options
maanug-nv Jul 24, 2025
29c1294
replace train iters with train samples
maanug-nv Jul 24, 2025
e9375f9
replace eval iters with eval samples
maanug-nv Jul 24, 2025
1684ede
Revert "remove variable batch size options"
maanug-nv Jul 24, 2025
cc56ab3
update some defaults and docstrings
maanug-nv Jul 31, 2025
ff961f7
first draft of factory
maanug-nv Nov 11, 2025
4892f2f
set metadata for rbs
maanug-nv Nov 11, 2025
c22431f
support excluding keys
maanug-nv Nov 11, 2025
57ae59b
change return object
maanug-nv Nov 14, 2025
cff582f
support multiple arg names and arg name prefixes
maanug-nv Nov 14, 2025
17d745c
remove some auto-generated arguments
maanug-nv Nov 14, 2025
47f1639
support default factories
maanug-nv Nov 15, 2025
46e5684
add iterations to skip to config
maanug-nv Nov 15, 2025
74ad7de
support both iters and samples for training
maanug-nv Nov 15, 2025
79c6245
split validation into separate config
maanug-nv Nov 15, 2025
14b348d
add unit tests
maanug-nv Nov 17, 2025
e6d185a
defer to metadata if present on type check failure
maanug-nv Nov 17, 2025
ed96bfe
revert name changes to val config
maanug-nv Nov 17, 2025
5eb7a83
more unit test coverage
maanug-nv Nov 18, 2025
7fcff86
formatting
maanug-nv Nov 19, 2025
7f8792e
fix recursive call
maanug-nv Nov 19, 2025
5fe366f
fix serializability error
maanug-nv Nov 20, 2025
4664c50
interval default must be int
maanug-nv Nov 20, 2025
206 changes: 206 additions & 0 deletions megatron/training/argument_utils.py
@@ -0,0 +1,206 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import dataclasses
import typing
from typing import Any, Optional
from argparse import ArgumentParser, _ArgumentGroup
import inspect
import itertools
import builtins
import ast
from dataclasses import Field, fields

# TODO: support arg renames

class TypeInferenceError(Exception):
"""Custom exception type to be conditionally handled by ArgumentGroupFactory."""
pass

class ArgumentGroupFactory:
"""Utility that adds an argument group to an ArgumentParser based on the attributes of a dataclass.

This class can be overridden as needed to support dataclasses
that require customized or additional handling.

Args:
src_cfg_class: The source dataclass type (not instance) whose fields will be
converted into command-line arguments. Each field's type annotation determines
the argument type, default values become argument defaults, and field-level
docstrings are extracted to populate argument help text.
exclude: Optional list of attribute names from `src_cfg_class` to exclude from
argument generation. Useful for omitting internal fields, computed properties,
or attributes that should be configured through other means. If None, all
dataclass fields will be converted to command-line arguments. Default: None.
"""

def __init__(self, src_cfg_class: type, exclude: Optional[list[str]] = None) -> None:
self.src_cfg_class = src_cfg_class
self.field_docstrings = self._get_field_docstrings(src_cfg_class)
self.exclude = set(exclude) if exclude is not None else set()

def _format_arg_name(self, config_attr_name: str, prefix: Optional[str] = None) -> str:
"""Convert dataclass name into appropriate argparse flag name.

Args:
config_attr_name: dataclass attribute name
prefix: prefix string to prepend to the dataclass attribute name, e.g. 'no' for bool
settings that default to True. A hyphen is added after the prefix. Default: None
"""
arg_name = config_attr_name
if prefix:
arg_name = prefix + '_' + arg_name
arg_name = "--" + arg_name.replace("_", "-")
return arg_name

def _extract_type(self, config_type: type) -> dict[str, Any]:
"""Determine the type, nargs, and choices settings for this argument.

Args:
config_type: attribute type from dataclass
"""
origin = typing.get_origin(config_type)
type_tuple = typing.get_args(config_type)

# Primitive type
if origin is None:
return {"type": config_type}

if origin is typing.Union:
# Handle Optional and Union
if type_tuple[1] == type(None): # Optional type. First element is value inside Optional[]
return self._extract_type(type_tuple[0])
else:
raise TypeInferenceError(f"Unions not supported by argparse: {config_type}")

elif origin is list:
if len(type_tuple) == 1:
kwargs = self._extract_type(type_tuple[0])
kwargs["nargs"] = "+"
return kwargs
else:
raise TypeInferenceError(f"Multi-type lists not supported by argparse: {config_type}")

elif origin is typing.Literal:
choices_types = [type(choice) for choice in type_tuple]
assert all([t == choices_types[0] for t in choices_types]), "Type of each choice in a Literal type should all be the same."
kwargs = {"type": choices_types[0], "choices": type_tuple}
return kwargs
else:
raise TypeInferenceError(f"Unsupported type: {config_type}")


def _build_argparse_kwargs_from_field(self, attribute: Field) -> dict[str, Any]:
"""Assemble kwargs for add_argument().

Args:
attribute: dataclass attribute
"""
argparse_kwargs = {}
argparse_kwargs["arg_names"] = [self._format_arg_name(attribute.name)]
argparse_kwargs["dest"] = attribute.name
argparse_kwargs["help"] = self.field_docstrings[attribute.name]

# dataclasses forbids setting both default and default_factory, so if default is
# missing we fall back to the field's default_factory (a zero-argument callable)
if attribute.default is dataclasses.MISSING:
argparse_kwargs["default"] = attribute.default_factory()
else:
argparse_kwargs["default"] = attribute.default

attr_argparse_meta = None
if attribute.metadata != {} and "argparse_meta" in attribute.metadata:
# save metadata here, but update at the end so the metadata has highest precedence
attr_argparse_meta = attribute.metadata["argparse_meta"]


# if we cannot infer the argparse type, all of this logic may fail. we try to defer
# to the developer-specified metadata if present
try:
argparse_kwargs.update(self._extract_type(attribute.type))

# use store_true or store_false action for enable/disable flags, which doesn't accept a 'type'
if argparse_kwargs["type"] == bool:
argparse_kwargs["action"] = "store_true" if attribute.default == False else "store_false"
argparse_kwargs.pop("type")

# add '--no-*' and '--disable-*' prefix if this is a store_false argument
if argparse_kwargs["action"] == "store_false":
argparse_kwargs["arg_names"] = [self._format_arg_name(attribute.name, prefix="no"), self._format_arg_name(attribute.name, prefix="disable")]
except TypeInferenceError as e:
if attr_argparse_meta is not None:
print(
f"WARNING: Inferring the appropriate argparse argument type from {self.src_cfg_class} "
f"failed for {attribute.name}: {attribute.type}.\n"
"Deferring to attribute metadata. If the metadata is incomplete, 'parser.add_argument()' may fail.\n"
f"Original failure: {e}"
)
else:
raise e

# metadata provided by field takes precedence
if attr_argparse_meta is not None:
argparse_kwargs.update(attr_argparse_meta)

return argparse_kwargs

def build_group(self, parser: ArgumentParser, title: Optional[str] = None) -> _ArgumentGroup:
"""Entrypoint method that adds the argument group to the parser.

Args:
parser: The parser to add arguments to
title: Title for the argument group
"""
arg_group = parser.add_argument_group(title=title, description=self.src_cfg_class.__doc__)
for attr in fields(self.src_cfg_class):
if attr.name in self.exclude:
continue

add_arg_kwargs = self._build_argparse_kwargs_from_field(attr)

arg_names = add_arg_kwargs.pop("arg_names")
arg_group.add_argument(*arg_names, **add_arg_kwargs)

return arg_group

def _get_field_docstrings(self, src_cfg_class: type) -> dict[str, str]:
"""Extract field-level docstrings from a dataclass by inspecting its AST.

Recurses on the first (non-builtin) parent class of `src_cfg_class`.

Args:
src_cfg_class: Dataclass to get docstrings from.
"""
source = inspect.getsource(src_cfg_class)
tree = ast.parse(source)
root_node = tree.body[0]

assert isinstance(root_node, ast.ClassDef), "Provided object must be a class."

field_docstrings = {}

# Iterate over body of the dataclass using 2-width sliding window.
# When 'a' is an assignment expression and 'b' is a constant, the window is
# lined up with an attribute-docstring pair. The pair can be saved to our dict.
for a, b in itertools.pairwise(root_node.body):
a_cond = isinstance(a, ast.AnnAssign) and isinstance(a.target, ast.Name)
b_cond = isinstance(b, ast.Expr) and isinstance(b.value, ast.Constant)

if a_cond and b_cond:
# These should be guaranteed by typechecks above, but assert just in case
assert isinstance(a.target.id, str), "Dataclass attribute not in the expected format. Name is not a string."
assert isinstance(b.value.value, str), "Dataclass attribute docstring is not a string."

# Formatting
docstring = inspect.cleandoc(b.value.value)
docstring = ' '.join(docstring.split())

field_docstrings[a.target.id] = docstring

# recurse on parent class
base_classes = src_cfg_class.__bases__
if len(base_classes) > 0:
parent_class = base_classes[0]
if parent_class.__name__ not in builtins.__dict__:
field_docstrings.update(self._get_field_docstrings(base_classes[0]))

return field_docstrings
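For reference, a minimal usage sketch of the factory follows. It is hypothetical and not part of this PR's diff: DemoConfig and its fields are invented for illustration, it assumes megatron.training.argument_utils is importable as added above, and it must run as a script because _get_field_docstrings reads the class source via inspect.getsource. It exercises field-level docstrings, an Optional field, a Literal field, and a default-True bool that the factory exposes as --no-*/--disable-* flags.

# Hypothetical usage sketch; DemoConfig is illustrative only.
from argparse import ArgumentParser
from dataclasses import dataclass
from typing import Literal, Optional

from megatron.training.argument_utils import ArgumentGroupFactory


@dataclass
class DemoConfig:
    """Demo settings used only for this sketch."""

    micro_batch_size: Optional[int] = None
    """Batch size per model instance (local batch size)."""

    lr_decay_style: Literal["linear", "cosine"] = "cosine"
    """Learning rate decay schedule."""

    log_progress: bool = True
    """Write progress logs; default True, so the factory emits --no-/--disable- flags."""


parser = ArgumentParser()
ArgumentGroupFactory(DemoConfig).build_group(parser, title="demo")

args = parser.parse_args(["--micro-batch-size", "8", "--no-log-progress"])
# With the factory logic above this prints: 8 cosine False
print(args.micro_batch_size, args.lr_decay_style, args.log_progress)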
101 changes: 10 additions & 91 deletions megatron/training/arguments.py
@@ -48,6 +48,8 @@
load_quantization_recipe,
)

from megatron.training.argument_utils import ArgumentGroupFactory

def add_megatron_arguments(parser: argparse.ArgumentParser):
""""Add Megatron-LM arguments to the given parser."""

@@ -1967,41 +1969,14 @@ def _add_rl_args(parser):
return parser

def _add_training_args(parser):
group = parser.add_argument_group(title='training')
from megatron.training.config import TrainingConfig

train_factory = ArgumentGroupFactory(TrainingConfig)
group = train_factory.build_group(parser, "training")

group.add_argument('--micro-batch-size', type=int, default=None,
help='Batch size per model instance (local batch size). '
'Global batch size is local batch size times data '
'parallel size times number of micro batches.')
group.add_argument('--batch-size', type=int, default=None,
help='Old batch size parameter, do not use. '
'Use --micro-batch-size instead')
group.add_argument('--global-batch-size', type=int, default=None,
help='Training batch size. If set, it should be a '
'multiple of micro-batch-size times data-parallel-size. '
'If this value is None, then '
'use micro-batch-size * data-parallel-size as the '
'global batch size. This choice will result in 1 for '
'number of micro-batches.')
group.add_argument('--rampup-batch-size', nargs='*', default=None,
help='Batch size ramp up with the following values:'
' --rampup-batch-size <start batch size> '
' <batch size incerement> '
' <ramp-up samples> '
'For example:'
' --rampup-batch-size 16 8 300000 \\ '
' --global-batch-size 1024'
'will start with global batch size 16 and over '
' (1024 - 16) / 8 = 126 intervals will increase'
'the batch size linearly to 1024. In each interval'
'we will use approximately 300000 / 126 = 2380 samples.')
group.add_argument('--decrease-batch-size-if-needed', action='store_true', default=False,
help='If set, decrease batch size if microbatch_size * dp_size'
'does not divide batch_size. Useful for KSO (Keep Soldiering On)'
'to continue making progress if number of healthy GPUs (and'
'corresponding dp_size) does not support current batch_size.'
'Old batch_size will be restored if training is re-started with'
'dp_size that divides batch_size // microbatch_size.')
group.add_argument('--recompute-activations', action='store_true',
help='recompute activation to allow for training '
'with larger models, sequences, and batch sizes.')
@@ -2070,8 +2045,6 @@ def _add_training_args(parser):
help='Global step to start profiling.')
group.add_argument('--profile-step-end', type=int, default=12,
help='Global step to stop profiling.')
group.add_argument('--iterations-to-skip', nargs='+', type=int, default=[],
help='List of iterations to skip, empty by default.')
group.add_argument('--result-rejected-tracker-filename', type=str, default=None,
help='Optional name of file tracking `result_rejected` events.')
group.add_argument('--disable-gloo-process-groups', action='store_false',
@@ -2114,47 +2087,19 @@ def _add_training_args(parser):
group.add_argument('--use-cpu-initialization', action='store_true',
default=None,
help='If set, initialize weights on the CPU. This eliminates init differences based on tensor parallelism.')
group.add_argument('--empty-unused-memory-level', default=0, type=int,
choices=[0, 1, 2],
help='Call torch.cuda.empty_cache() each iteration '
'(training and eval), to reduce fragmentation.'
'0=off, 1=moderate, 2=aggressive.')
group.add_argument('--deterministic-mode', action='store_true',
help='Choose code that has deterministic execution. This usually '
'means slower execution, but is good for debugging and testing.')
group.add_argument('--check-weight-hash-across-dp-replicas-interval', type=int, default=None,
help='Interval to check weight hashes are same across DP replicas. If not specified, weight hashes not checked.')
group.add_argument('--calculate-per-token-loss', action='store_true',
help=('Scale cross entropy loss by the number of non-padded tokens in the '
'global batch, versus the default behavior of assuming all tokens are non-padded.'))
group.add_argument('--train-sync-interval', type=int, default=None,
help='Training CPU-GPU synchronization interval, to ensure that CPU is not running too far ahead of GPU.')

# deprecated
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--train-samples', type=int, default=None,
help='Total number of samples to train over all '
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--log-interval', type=int, default=100,
help='Report loss and timing interval.')
group.add_argument('--exit-interval', type=int, default=None,
help='Exit the program after the iteration is divisible '
'by this value.')
group.add_argument('--exit-duration-in-mins', type=int, default=None,
help='Exit the program after this many minutes.')
group.add_argument('--exit-signal-handler', action='store_true',
help='Dynamically save the checkpoint and shutdown the '
'training if signal is received')
group.add_argument('--exit-signal', type=str, default='SIGTERM',
choices=list(SIGNAL_MAP.keys()),
help='Signal to use for exit signal handler. If not specified, defaults to SIGTERM.')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--no-masked-softmax-fusion',
@@ -2242,22 +2187,6 @@ def _add_training_args(parser):
'--use-legacy-models to not use core models.')
group.add_argument('--use-legacy-models', action='store_true',
help='Use the legacy Megatron models, not Megatron-Core models.')
group.add_argument('--manual-gc', action='store_true',
help='Disable the threshold-based default garbage '
'collector and trigger the garbage collection manually. '
'Manual garbage collection helps to align the timing of '
'the collection across ranks which mitigates the impact '
'of CPU-associated jitters. When the manual gc is enabled, '
'garbage collection is performed only at the start and the '
'end of the validation routine by default.')
group.add_argument('--manual-gc-interval', type=int, default=0,
help='Training step interval to trigger manual garbage '
'collection. When the value is set to 0, garbage '
'collection is not triggered between training steps.')
group.add_argument('--no-manual-gc-eval', action='store_false',
help='When using manual garbage collection, disable '
'garbage collection at the start and the end of each '
'evaluation run.', dest='manual_gc_eval')
group.add_argument('--disable-tp-comm-split-ag', action='store_false',
help='Disables the All-Gather overlap with fprop GEMM.',
dest='tp_comm_split_ag')
@@ -2748,20 +2677,10 @@ def _add_distributed_args(parser):


def _add_validation_args(parser):
group = parser.add_argument_group(title='validation')

group.add_argument('--full-validation', action='store_true', help='If set, each time validation occurs it uses the full validation dataset(s). This currently only works for GPT datasets!')
group.add_argument('--multiple-validation-sets', action='store_true', help='If set, multiple datasets listed in the validation split are evaluated independently with a separate loss for each dataset in the list. This argument requires that no weights are included in the list')
group.add_argument('--eval-iters', type=int, default=100,
help='Number of iterations to run for evaluation'
'validation/test for.')
group.add_argument('--eval-interval', type=int, default=1000,
help='Interval between running evaluation on '
'validation set.')
group.add_argument("--test-mode", action="store_true", help='Run all real-time test alongside the experiment.')
group.add_argument('--skip-train', action='store_true',
default=False, help='If set, bypass the training loop, '
'optionally do evaluation for validation/test, and exit.')
from megatron.training.config import ValidationConfig

val_factory = ArgumentGroupFactory(ValidationConfig)
group = val_factory.build_group(parser, "validation")

return parser

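The "argparse_meta" field metadata consulted in _build_argparse_kwargs_from_field serves two roles: it overrides whatever was inferred, and it is the fallback when _extract_type raises TypeInferenceError. The fragment below is an illustrative sketch only; the actual TrainingConfig field definitions are not shown in this diff, so the field name and types here are assumptions.

# Hypothetical config fragment; the real TrainingConfig is not shown in this PR.
from dataclasses import dataclass, field
from typing import Optional, Union


@dataclass
class RampupConfig:
    """Hypothetical fragment showing the 'argparse_meta' escape hatch."""

    rampup_batch_size: Optional[list[Union[int, str]]] = field(
        default=None,
        metadata={
            "argparse_meta": {
                # Merged into the add_argument() kwargs with highest precedence, and
                # used as the fallback when _extract_type raises TypeInferenceError
                # (as it does here, since Union element types are unsupported).
                "nargs": "*",
                "default": None,
            }
        },
    )
    """Batch size ramp up: <start batch size> <batch size increment> <ramp-up samples>."""

With this metadata, ArgumentGroupFactory(RampupConfig).build_group(parser, "batch rampup") registers --rampup-batch-size with nargs="*" instead of failing on the Union element type.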