diff --git a/megatron/core/safe_globals.py b/megatron/core/safe_globals.py index cc5eb8809e..20d1694f08 100755 --- a/megatron/core/safe_globals.py +++ b/megatron/core/safe_globals.py @@ -3,6 +3,7 @@ from argparse import Namespace from io import BytesIO from pathlib import PosixPath +from signal import Signals from types import SimpleNamespace import torch @@ -30,6 +31,7 @@ RerunMode, RerunState, BytesIO, + Signals, ] diff --git a/megatron/training/argument_utils.py b/megatron/training/argument_utils.py new file mode 100644 index 0000000000..da5dfc0e8d --- /dev/null +++ b/megatron/training/argument_utils.py @@ -0,0 +1,250 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +import dataclasses +import typing +import types +from typing import Any, Optional +from argparse import ArgumentParser, _ArgumentGroup +import inspect +import itertools +import builtins +import ast +import enum +from dataclasses import Field, fields + +# TODO: support arg renames + +class TypeInferenceError(Exception): + """Custom exception type to be conditionally handled by ArgumentGroupFactory.""" + pass + +class ArgumentGroupFactory: + """Utility that adds an argument group to an ArgumentParser based on the attributes of a dataclass. + + This utility uses dataclass metadata including type annotations and docstrings to automatically + infer the type, default, and other argparse keyword arguments. + + You can override or supplement the automatically inferred argparse kwargs for any + dataclass field by providing an "argparse_meta" key in the field's metadata dict. + The value should be a dict of kwargs that will be passed to ArgumentParser.add_argument(). + These metadata kwargs take precedence over the automatically inferred values. + + Example: + @dataclass + class YourConfig: + your_attribute: int | str | None = field( + default=None, + metadata={ + "argparse_meta": { + "arg_names": ["--your-arg-name1", "--your-arg-name2"], + "type": str, + "nargs": "+", + "default": "foo", + } + }, + ) + + In this example, inferring the type automatically would fail, as Unions are + not supported. However the metadata is present, so that takes precedence. + Any keyword arguments to `ArgumentParser.add_argument()` can be included in + the "argparse_meta" dict, as well as "arg_names" for the argument flag name. + + This class can also be used as a base class and extended as needed to support dataclasses + that require some customized or additional handling. + + Args: + src_cfg_class: The source dataclass type (not instance) whose fields will be + converted into command-line arguments. Each field's type annotation determines + the argument type, default values become argument defaults, and field-level + docstrings are extracted to populate argument help text. + exclude: Optional list of attribute names from `src_cfg_class` to exclude from + argument generation. Useful for omitting internal fields, computed properties, + or attributes that should be configured through other means. If None, all + dataclass fields will be converted to command-line arguments. Default: None. + """ + + def __init__(self, src_cfg_class: type, exclude: Optional[list[str]] = None) -> None: + self.src_cfg_class = src_cfg_class + self.field_docstrings = self._get_field_docstrings(src_cfg_class) + self.exclude = set(exclude) if exclude is not None else set() + + def _format_arg_name(self, config_attr_name: str, prefix: Optional[str] = None) -> str: + """Convert dataclass name into appropriate argparse flag name. + + Args: + config_attr_name: dataclass attribute name + prefix: prefix string to add to the dataclass attribute name. e.g. 'no' for bool + settings that are default True. A hyphen is added after the prefix. Default: None + """ + arg_name = config_attr_name + if prefix: + arg_name = prefix + '_' + arg_name + arg_name = "--" + arg_name.replace("_", "-") + return arg_name + + def _get_enum_kwargs(self, config_type: enum.EnumMeta) -> dict[str, Any]: + """Build kwargs for Enums. + + With these settings, the user must provide a valid enum value, e.g. + 'flash', for `AttnBackend.flash`. + """ + def enum_type_handler(cli_arg): + return config_type[cli_arg] + + return {"type": enum_type_handler, "choices": list(config_type)} + + def _extract_type(self, config_type: type) -> dict[str, Any]: + """Determine the type, nargs, and choices settings for this argument. + + Args: + config_type: attribute type from dataclass + """ + origin = typing.get_origin(config_type) + type_tuple = typing.get_args(config_type) + + if isinstance(config_type, type) and issubclass(config_type, enum.Enum): + return self._get_enum_kwargs(config_type) + + # Primitive type + if origin is None: + return {"type": config_type} + + if origin in [types.UnionType, typing.Union]: + # Handle Optional and Union + if type_tuple[1] == type(None): # Optional type. First element is value inside Optional[] + return self._extract_type(type_tuple[0]) + else: + raise TypeInferenceError(f"Unions not supported by argparse: {config_type}") + + elif origin is list: + if len(type_tuple) == 1: + kwargs = self._extract_type(type_tuple[0]) + kwargs["nargs"] = "+" + return kwargs + else: + raise TypeInferenceError(f"Multi-type lists not supported by argparse: {config_type}") + + elif origin is typing.Literal: + choices_types = [type(choice) for choice in type_tuple] + assert all([t == choices_types[0] for t in choices_types]), "Type of each choice in a Literal type should all be the same." + kwargs = {"type": choices_types[0], "choices": type_tuple} + return kwargs + else: + raise TypeInferenceError(f"Unsupported type: {config_type}") + + + def _build_argparse_kwargs_from_field(self, attribute: Field) -> dict[str, Any]: + """Assemble kwargs for add_argument(). + + Args: + attribute: dataclass attribute + """ + argparse_kwargs = {} + argparse_kwargs["arg_names"] = [self._format_arg_name(attribute.name)] + argparse_kwargs["dest"] = attribute.name + argparse_kwargs["help"] = self.field_docstrings[attribute.name] + + # dataclasses specifies that both should not be set + if isinstance(attribute.default, type(dataclasses.MISSING)): + # dataclasses specified default_factory must be a zero-argument callable + argparse_kwargs["default"] = attribute.default_factory() + else: + argparse_kwargs["default"] = attribute.default + + attr_argparse_meta = None + if attribute.metadata != {} and "argparse_meta" in attribute.metadata: + # save metadata here, but update at the end so the metadata has highest precedence + attr_argparse_meta = attribute.metadata["argparse_meta"] + + + # if we cannot infer the argparse type, all of this logic may fail. we try to defer + # to the developer-specified metadata if present + try: + argparse_kwargs.update(self._extract_type(attribute.type)) + + # use store_true or store_false action for enable/disable flags, which doesn't accept a 'type' + if argparse_kwargs["type"] == bool: + argparse_kwargs["action"] = "store_true" if attribute.default == False else "store_false" + argparse_kwargs.pop("type") + + # add '--no-*' and '--disable-*' prefix if this is a store_false argument + if argparse_kwargs["action"] == "store_false": + argparse_kwargs["arg_names"] = [self._format_arg_name(attribute.name, prefix="no"), self._format_arg_name(attribute.name, prefix="disable")] + except TypeInferenceError as e: + if attr_argparse_meta is not None: + print( + f"WARNING: Inferring the appropriate argparse argument type from {self.src_cfg_class} " + f"failed for {attribute.name}: {attribute.type}.\n" + "Deferring to attribute metadata. If the metadata is incomplete, 'parser.add_argument()' may fail.\n" + f"Original failure: {e}" + ) + else: + raise e + + # metadata provided by field takes precedence + if attr_argparse_meta is not None: + argparse_kwargs.update(attr_argparse_meta) + + return argparse_kwargs + + def build_group(self, parser: ArgumentParser, title: Optional[str] = None) -> _ArgumentGroup: + """Entrypoint method that adds the argument group to the parser. + + Args: + parser: The parser to add arguments to + title: Title for the argument group + """ + arg_group = parser.add_argument_group(title=title, description=self.src_cfg_class.__doc__) + for attr in fields(self.src_cfg_class): + if attr.name in self.exclude: + continue + + add_arg_kwargs = self._build_argparse_kwargs_from_field(attr) + + arg_names = add_arg_kwargs.pop("arg_names") + arg_group.add_argument(*arg_names, **add_arg_kwargs) + + return arg_group + + def _get_field_docstrings(self, src_cfg_class: type) -> dict[str, str]: + """Extract field-level docstrings from a dataclass by inspecting its AST. + + Recurses on parent classes of `src_cfg_class`. + + Args: + src_cfg_class: Dataclass to get docstrings from. + """ + source = inspect.getsource(src_cfg_class) + tree = ast.parse(source) + root_node = tree.body[0] + + assert isinstance(root_node, ast.ClassDef), "Provided object must be a class." + + field_docstrings = {} + + # Iterate over body of the dataclass using 2-width sliding window. + # When 'a' is an assignment expression and 'b' is a constant, the window is + # lined up with an attribute-docstring pair. The pair can be saved to our dict. + for a, b in itertools.pairwise(root_node.body): + a_cond = isinstance(a, ast.AnnAssign) and isinstance(a.target, ast.Name) + b_cond = isinstance(b, ast.Expr) and isinstance(b.value, ast.Constant) + + if a_cond and b_cond: + # These should be guaranteed by typechecks above, but assert just in case + assert isinstance(a.target.id, str), "Dataclass attribute not in the expected format. Name is not a string." + assert isinstance(b.value.value, str), "Dataclass attribute docstring is not a string." + + # Formatting + docstring = inspect.cleandoc(b.value.value) + docstring = ' '.join(docstring.split()) + + field_docstrings[a.target.id] = docstring + + # recurse on parent class + base_classes = src_cfg_class.__bases__ + if len(base_classes) > 0: + parent_class = base_classes[0] + if parent_class.__name__ not in builtins.__dict__: + field_docstrings.update(self._get_field_docstrings(base_classes[0])) + + return field_docstrings diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index a4d1a07d83..aa16086854 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -34,7 +34,6 @@ ) from megatron.core.activations import squared_relu from megatron.core.fusions.fused_bias_geglu import quick_gelu -from megatron.training.dist_signal_handler import SIGNAL_MAP from megatron.training.utils import ( get_device_arch_version, update_use_dist_ckpt, @@ -48,6 +47,8 @@ load_quantization_recipe, ) +from megatron.training.argument_utils import ArgumentGroupFactory + def add_megatron_arguments(parser: argparse.ArgumentParser): """"Add Megatron-LM arguments to the given parser.""" @@ -1384,11 +1385,6 @@ def _add_transformer_engine_args(parser): help='Keep the compute param in fp4 (do not use any other intermediate ' 'dtype) and perform the param all-gather in fp4.', dest='fp4_param') - group.add_argument('--te-rng-tracker', action='store_true', default=False, - help='Use the Transformer Engine version of the random number generator. ' - 'Required for CUDA graphs support.') - group.add_argument('--inference-rng-tracker', action='store_true', default=False, - help='Use a random number generator configured for inference.') return parser def _add_inference_args(parser): @@ -1987,41 +1983,14 @@ def _add_rl_args(parser): return parser def _add_training_args(parser): - group = parser.add_argument_group(title='training') + from megatron.training.config import TrainingConfig + + train_factory = ArgumentGroupFactory(TrainingConfig) + group = train_factory.build_group(parser, "training") - group.add_argument('--micro-batch-size', type=int, default=None, - help='Batch size per model instance (local batch size). ' - 'Global batch size is local batch size times data ' - 'parallel size times number of micro batches.') group.add_argument('--batch-size', type=int, default=None, help='Old batch size parameter, do not use. ' 'Use --micro-batch-size instead') - group.add_argument('--global-batch-size', type=int, default=None, - help='Training batch size. If set, it should be a ' - 'multiple of micro-batch-size times data-parallel-size. ' - 'If this value is None, then ' - 'use micro-batch-size * data-parallel-size as the ' - 'global batch size. This choice will result in 1 for ' - 'number of micro-batches.') - group.add_argument('--rampup-batch-size', nargs='*', default=None, - help='Batch size ramp up with the following values:' - ' --rampup-batch-size ' - ' ' - ' ' - 'For example:' - ' --rampup-batch-size 16 8 300000 \\ ' - ' --global-batch-size 1024' - 'will start with global batch size 16 and over ' - ' (1024 - 16) / 8 = 126 intervals will increase' - 'the batch size linearly to 1024. In each interval' - 'we will use approximately 300000 / 126 = 2380 samples.') - group.add_argument('--decrease-batch-size-if-needed', action='store_true', default=False, - help='If set, decrease batch size if microbatch_size * dp_size' - 'does not divide batch_size. Useful for KSO (Keep Soldiering On)' - 'to continue making progress if number of healthy GPUs (and' - 'corresponding dp_size) does not support current batch_size.' - 'Old batch_size will be restored if training is re-started with' - 'dp_size that divides batch_size // microbatch_size.') group.add_argument('--recompute-activations', action='store_true', help='recompute activation to allow for training ' 'with larger models, sequences, and batch sizes.') @@ -2090,8 +2059,6 @@ def _add_training_args(parser): help='Global step to start profiling.') group.add_argument('--profile-step-end', type=int, default=12, help='Global step to stop profiling.') - group.add_argument('--iterations-to-skip', nargs='+', type=int, default=[], - help='List of iterations to skip, empty by default.') group.add_argument('--result-rejected-tracker-filename', type=str, default=None, help='Optional name of file tracking `result_rejected` events.') group.add_argument('--disable-gloo-process-groups', action='store_false', @@ -2134,47 +2101,19 @@ def _add_training_args(parser): group.add_argument('--use-cpu-initialization', action='store_true', default=None, help='If set, initialize weights on the CPU. This eliminates init differences based on tensor parallelism.') - group.add_argument('--empty-unused-memory-level', default=0, type=int, - choices=[0, 1, 2], - help='Call torch.cuda.empty_cache() each iteration ' - '(training and eval), to reduce fragmentation.' - '0=off, 1=moderate, 2=aggressive.') group.add_argument('--deterministic-mode', action='store_true', help='Choose code that has deterministic execution. This usually ' 'means slower execution, but is good for debugging and testing.') - group.add_argument('--check-weight-hash-across-dp-replicas-interval', type=int, default=None, - help='Interval to check weight hashes are same across DP replicas. If not specified, weight hashes not checked.') group.add_argument('--calculate-per-token-loss', action='store_true', help=('Scale cross entropy loss by the number of non-padded tokens in the ' 'global batch, versus the default behavior of assuming all tokens are non-padded.')) - group.add_argument('--train-sync-interval', type=int, default=None, - help='Training CPU-GPU synchronization interval, to ensure that CPU is not running too far ahead of GPU.') # deprecated group.add_argument('--checkpoint-activations', action='store_true', help='Checkpoint activation to allow for training ' 'with larger models, sequences, and batch sizes.') - group.add_argument('--train-iters', type=int, default=None, - help='Total number of iterations to train over all ' - 'training runs. Note that either train-iters or ' - 'train-samples should be provided.') - group.add_argument('--train-samples', type=int, default=None, - help='Total number of samples to train over all ' - 'training runs. Note that either train-iters or ' - 'train-samples should be provided.') group.add_argument('--log-interval', type=int, default=100, help='Report loss and timing interval.') - group.add_argument('--exit-interval', type=int, default=None, - help='Exit the program after the iteration is divisible ' - 'by this value.') - group.add_argument('--exit-duration-in-mins', type=int, default=None, - help='Exit the program after this many minutes.') - group.add_argument('--exit-signal-handler', action='store_true', - help='Dynamically save the checkpoint and shutdown the ' - 'training if signal is received') - group.add_argument('--exit-signal', type=str, default='SIGTERM', - choices=list(SIGNAL_MAP.keys()), - help='Signal to use for exit signal handler. If not specified, defaults to SIGTERM.') group.add_argument('--tensorboard-dir', type=str, default=None, help='Write TensorBoard logs to this directory.') group.add_argument('--no-masked-softmax-fusion', @@ -2262,22 +2201,6 @@ def _add_training_args(parser): '--use-legacy-models to not use core models.') group.add_argument('--use-legacy-models', action='store_true', help='Use the legacy Megatron models, not Megatron-Core models.') - group.add_argument('--manual-gc', action='store_true', - help='Disable the threshold-based default garbage ' - 'collector and trigger the garbage collection manually. ' - 'Manual garbage collection helps to align the timing of ' - 'the collection across ranks which mitigates the impact ' - 'of CPU-associated jitters. When the manual gc is enabled, ' - 'garbage collection is performed only at the start and the ' - 'end of the validation routine by default.') - group.add_argument('--manual-gc-interval', type=int, default=0, - help='Training step interval to trigger manual garbage ' - 'collection. When the value is set to 0, garbage ' - 'collection is not triggered between training steps.') - group.add_argument('--no-manual-gc-eval', action='store_false', - help='When using manual garbage collection, disable ' - 'garbage collection at the start and the end of each ' - 'evaluation run.', dest='manual_gc_eval') group.add_argument('--disable-tp-comm-split-ag', action='store_false', help='Disables the All-Gather overlap with fprop GEMM.', dest='tp_comm_split_ag') @@ -2315,14 +2238,11 @@ def _add_rerun_machine_args(parser): def _add_initialization_args(parser): - group = parser.add_argument_group(title='initialization') - - group.add_argument('--seed', type=int, default=1234, - help='Random seed used for python, numpy, ' - 'pytorch, and cuda.') - group.add_argument('--data-parallel-random-init', action='store_true', - help='Enable random initialization of params ' - 'across data parallel ranks') + from megatron.training.config import RNGConfig + + rng_factory = ArgumentGroupFactory(RNGConfig) + group = rng_factory.build_group(parser, "RNG and initialization") + group.add_argument('--init-method-std', type=float, default=0.02, help='Standard deviation of the zero mean normal ' 'distribution used for weight initialization.') @@ -2768,20 +2688,10 @@ def _add_distributed_args(parser): def _add_validation_args(parser): - group = parser.add_argument_group(title='validation') - - group.add_argument('--full-validation', action='store_true', help='If set, each time validation occurs it uses the full validation dataset(s). This currently only works for GPT datasets!') - group.add_argument('--multiple-validation-sets', action='store_true', help='If set, multiple datasets listed in the validation split are evaluated independently with a separate loss for each dataset in the list. This argument requires that no weights are included in the list') - group.add_argument('--eval-iters', type=int, default=100, - help='Number of iterations to run for evaluation' - 'validation/test for.') - group.add_argument('--eval-interval', type=int, default=1000, - help='Interval between running evaluation on ' - 'validation set.') - group.add_argument("--test-mode", action="store_true", help='Run all real-time test alongside the experiment.') - group.add_argument('--skip-train', action='store_true', - default=False, help='If set, bypass the training loop, ' - 'optionally do evaluation for validation/test, and exit.') + from megatron.training.config import ValidationConfig + + val_factory = ArgumentGroupFactory(ValidationConfig) + group = val_factory.build_group(parser, "validation") return parser diff --git a/megatron/training/config.py b/megatron/training/config.py new file mode 100644 index 0000000000..f75adb0947 --- /dev/null +++ b/megatron/training/config.py @@ -0,0 +1,135 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +from dataclasses import dataclass, field +import signal +from typing import Literal + +@dataclass(kw_only=True) +class TrainingConfig: + """Configuration settings related to the training loop.""" + + micro_batch_size: int | None = None + """Batch size per model instance (local batch size). Global batch size is local batch size times + data parallel size times number of micro batches.""" + + global_batch_size: int | None = None + """Training batch size. If set, it should be a multiple of micro-batch-size times + data-parallel-size. If this value is None, then use micro-batch-size * data-parallel-size + as the global batch size. This choice will result in 1 for number of micro-batches.""" + + rampup_batch_size: list[int] | None = field(default=None, metadata={"argparse_meta": {"nargs": 3}}) + """Batch size ramp up with the following values: , , + + For example: + rampup-batch-size = [16, 8, 300000] + global-batch-size 1024 + will start with global batch size 16 and over (1024 - 16) / 8 = 126 intervals will increase + the batch size linearly to 1024. In each interval we will use approximately + 300000 / 126 = 2380 samples. + """ + + decrease_batch_size_if_needed: bool = False + """If set, decrease batch size if microbatch_size * dp_size does not + divide batch_size. Old batch_size will be restored if training is re-started + with dp_size that divides batch_size // microbatch_size.""" + + empty_unused_memory_level: Literal[0, 1, 2] = 0 + """Call torch.cuda.empty_cache() each iteration (training and eval), to reduce fragmentation. + 0=off, 1=moderate, 2=aggressive. + """ + + check_weight_hash_across_dp_replicas_interval: int | None = None + """Interval to check weight hashes are same across DP replicas. If not specified, weight hashes not checked.""" + + train_sync_interval: int | None = None + """Training CPU-GPU synchronization interval, to ensure that CPU is not running too far ahead of GPU.""" + + train_iters: int | None = None + """Total number of iterations to train over all training runs. + Note that either train_iters or train_samples should be provided. + """ + + train_samples: int | None = None + """Total number of samples to train over all training runs. + Note that either train_iters or train_samples should be provided.""" + + exit_interval: int | None = None + """Exit the program after the iteration is divisible by this value.""" + + exit_duration_in_mins: int | None = None + """Exit the program after this many minutes.""" + + exit_signal_handler: bool = False + """Dynamically save the checkpoint and shutdown the training if SIGTERM is received""" + + exit_signal: signal.Signals = signal.SIGTERM + """Signal for the signal handler to detect.""" + + exit_signal_handler_for_dataloader: bool = False + """Use signal handler for dataloader workers""" + + manual_gc: bool = False + """Disable the threshold-based default garbage collector and trigger the garbage collection + manually. Manual garbage collection helps to align the timing of the collection across ranks + which mitigates the impact of CPU-associated jitters. When the manual gc is enabled, garbage + collection is performed only at the start and the end of the validation routine by default.""" + + manual_gc_interval: int = 0 + """Training step interval to trigger manual garbage collection. Values > 0 will trigger garbage + collections between training steps. + """ + + manual_gc_eval: bool = True + """When using manual garbage collection, this controls garbage collection at the start and the + end of each evaluation run. + """ + + iterations_to_skip: list[int] = field(default_factory=list) + """List of iterations to skip during training, empty by default.""" + + +@dataclass(kw_only=True) +class ValidationConfig: + """Configuration settings related to validation during or after model training.""" + + eval_iters: int | None = 100 + """Number of iterations to run for evaluation. Used for both validation and test. If not set, + evaluation will not run.""" + + eval_interval: int | None = None + """Interval between running evaluation on validation set. If not set, evaluation will not run + during training. + """ + + skip_train: bool = False + """If set, bypass the training loop, perform evaluation for validation/test, and exit.""" + + test_mode: bool = False + """Run all real-time test alongside the experiment.""" + + full_validation: bool = False + """If set, each time validation occurs it uses the full validation dataset(s). This currently only works for GPT datasets!""" + + multiple_validation_sets: bool = False + """If set, multiple datasets listed in the validation split are evaluated independently with a + separate loss for each dataset in the list. This argument requires that no weights are + included in the list. + """ + + +@dataclass(kw_only=True) +class RNGConfig: + """Configuration settings for random number generation.""" + + seed: int = 1234 + """Random seed used for python, numpy, pytorch, and cuda.""" + + te_rng_tracker: bool = False + """Use the Transformer Engine version of the random number generator. + Required for CUDA graphs support.""" + + inference_rng_tracker: bool = False + """Use a random number generator configured for inference.""" + + data_parallel_random_init: bool = False + """Enable random initialization of params across data parallel ranks""" + diff --git a/megatron/training/dist_signal_handler.py b/megatron/training/dist_signal_handler.py index f1f3725c8a..0ecd706fdc 100644 --- a/megatron/training/dist_signal_handler.py +++ b/megatron/training/dist_signal_handler.py @@ -3,13 +3,6 @@ import torch -SIGNAL_MAP = { - 'SIGTERM': signal.SIGTERM, - 'SIGINT': signal.SIGINT, - 'SIGUSR1': signal.SIGUSR1, - 'SIGUSR2': signal.SIGUSR2 -} - def get_world_size(): if torch.distributed.is_available() and torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() @@ -55,8 +48,8 @@ def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None): class DistributedSignalHandler: - def __init__(self, sig: str = 'SIGTERM'): - self.sig = SIGNAL_MAP.get(sig, signal.SIGTERM) + def __init__(self, sig: signal.Signals = signal.SIGTERM): + self.sig = sig def signals_received(self): all_received = all_gather_item( diff --git a/tests/unit_tests/test_argument_utils.py b/tests/unit_tests/test_argument_utils.py new file mode 100644 index 0000000000..e5744c3b07 --- /dev/null +++ b/tests/unit_tests/test_argument_utils.py @@ -0,0 +1,643 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +import signal +from argparse import ArgumentError, ArgumentParser +from dataclasses import dataclass, field +from typing import Callable, Literal, Optional, Union + +import pytest + +from megatron.training.argument_utils import ArgumentGroupFactory, TypeInferenceError + + +@dataclass +class DummyConfig: + """A dummy configuration for testing.""" + + name: str = "default_name" + """Name of the configuration""" + + count: int = 42 + """Number of items""" + + learning_rate: float = 0.001 + """Learning rate for training""" + + enabled: bool = False + """Whether feature is enabled""" + + disabled_feature: bool = True + """Feature that is disabled by default""" + + enum_setting: signal.Signals = signal.SIGTERM + """Setting with enum type to test enum handling""" + + +@dataclass +class ConfigWithOptional: + """Config with optional fields.""" + + required_field: str = "required" + """A required field""" + + optional_field: Optional[int] = None + """An optional integer field""" + + optional_str: Optional[str] = "default" + """An optional string with default""" + + int_new_form: int | None = None + """Optional using new syntax""" + + str_new_form: str | None = "default" + """Optional string using new syntax""" + + +@dataclass +class ConfigWithList: + """Config with list fields.""" + + tags: list[str] = field(default_factory=list) + """List of tags""" + + numbers: list[int] = field(default_factory=lambda: [1, 2, 3]) + """List of numbers with default""" + + +@dataclass +class ConfigWithLiteral: + """Config with Literal types.""" + + mode: Literal["train", "eval", "test"] = "train" + """Operating mode""" + + precision: Literal[16, 32] = 32 + """Precision level""" + + +class TestArgumentGroupFactoryBasic: + """Test basic functionality of ArgumentGroupFactory.""" + + def test_creates_argument_group(self): + """Test that build_group creates an argument group.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig) + + arg_group = factory.build_group(parser, title="Test Group") + + assert arg_group is not None + assert arg_group.title == "Test Group" + assert arg_group.description == DummyConfig.__doc__ + + def test_all_fields_added(self): + """Test that all dataclass fields are added as arguments.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig) + + factory.build_group(parser, title="Test Group") + + # Parse empty args to get all defaults + args = parser.parse_args([]) + + # Check all fields exist + assert hasattr(args, 'name') + assert hasattr(args, 'count') + assert hasattr(args, 'learning_rate') + assert hasattr(args, 'enabled') + assert hasattr(args, 'disabled_feature') + + def test_default_values_preserved(self): + """Test that default values from dataclass are preserved.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig) + + factory.build_group(parser, title="Test Group") + args = parser.parse_args([]) + + assert args.name == "default_name" + assert args.count == 42 + assert args.learning_rate == 0.001 + assert args.enabled == False + assert args.disabled_feature == True + + def test_argument_types(self): + """Test that argument types are correctly inferred.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig) + + factory.build_group(parser, title="Test Group") + + # Parse with actual values + args = parser.parse_args( + ['--name', 'test_name', '--count', '100', '--learning-rate', '0.01'] + ) + + assert isinstance(args.name, str) + assert args.name == 'test_name' + assert isinstance(args.count, int) + assert args.count == 100 + assert isinstance(args.learning_rate, float) + assert args.learning_rate == 0.01 + + def test_boolean_store_true(self): + """Test that boolean fields with default False use store_true.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig) + + factory.build_group(parser, title="Test Group") + + # Without flag, should be False + args = parser.parse_args([]) + assert args.enabled == False + + # With flag, should be True + args = parser.parse_args(['--enabled']) + assert args.enabled == True + + def test_boolean_store_false(self): + """Test that boolean fields with default True use store_false with no- prefix.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig) + + factory.build_group(parser, title="Test Group") + + # Without flag, should be True + args = parser.parse_args([]) + assert args.disabled_feature == True + + # With --no- flag, should be False + args = parser.parse_args(['--no-disabled-feature']) + assert args.disabled_feature == False + + # With --disable- flag, should also be False + args = parser.parse_args(['--disable-disabled-feature']) + assert args.disabled_feature == False + + def test_field_docstrings_as_help(self): + """Test that field docstrings are extracted and used as help text.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig) + + # Check that field_docstrings were extracted + assert 'name' in factory.field_docstrings + assert factory.field_docstrings['name'] == "Name of the configuration" + assert factory.field_docstrings['count'] == "Number of items" + assert factory.field_docstrings['learning_rate'] == "Learning rate for training" + + def test_enum_handling(self): + """Test that enum types are handled correctly.""" + parser = ArgumentParser(exit_on_error=False) + factory = ArgumentGroupFactory(DummyConfig) + + factory.build_group(parser, title="Test Group") + + args = parser.parse_args([]) + assert args.enum_setting == signal.SIGTERM + + # test a different valid enum value + args = parser.parse_args(["--enum-setting", "SIGINT"]) + assert args.enum_setting == signal.SIGINT + + # test an invalid enum value + with pytest.raises(KeyError, match="sigbar"): + parser.parse_args(["--enum-setting", "sigbar"]) + + +class TestArgumentGroupFactoryExclusion: + """Test exclusion functionality.""" + + def test_exclude_single_field(self): + """Test excluding a single field.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig, exclude=['count']) + + factory.build_group(parser, title="Test Group") + args = parser.parse_args([]) + + # Excluded field should not exist + assert hasattr(args, 'name') + assert not hasattr(args, 'count') + assert hasattr(args, 'learning_rate') + + def test_exclude_multiple_fields(self): + """Test excluding multiple fields.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig, exclude=['count', 'learning_rate']) + + factory.build_group(parser, title="Test Group") + args = parser.parse_args([]) + + assert hasattr(args, 'name') + assert not hasattr(args, 'count') + assert not hasattr(args, 'learning_rate') + assert hasattr(args, 'enabled') + + +class TestArgumentGroupFactoryOptional: + """Test handling of Optional types.""" + + def test_optional_fields(self): + """Test that Optional fields are handled correctly.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithOptional) + + factory.build_group(parser, title="Test Group") + + # Default values + args = parser.parse_args([]) + assert args.required_field == "required" + assert args.optional_field is None + assert args.optional_str == "default" + + # Provided values + args = parser.parse_args( + ['--required-field', 'new_value', '--optional-field', '123', '--optional-str', 'custom'] + ) + assert args.required_field == "new_value" + assert args.optional_field == 123 + assert args.optional_str == "custom" + + +class TestArgumentGroupFactoryList: + """Test handling of list types.""" + + def test_list_fields_with_default_factory(self): + """Test that list fields use nargs='+'.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithList) + + factory.build_group(parser, title="Test Group") + + # Default values + args = parser.parse_args([]) + assert args.tags == [] + assert args.numbers == [1, 2, 3] + + # Provided values + args = parser.parse_args(['--tags', 'tag1', 'tag2', 'tag3', '--numbers', '10', '20', '30']) + assert args.tags == ['tag1', 'tag2', 'tag3'] + assert args.numbers == [10, 20, 30] + + +class TestArgumentGroupFactoryLiteral: + """Test handling of Literal types.""" + + def test_literal_fields_have_choices(self): + """Test that Literal types create choice constraints.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithLiteral) + + factory.build_group(parser, title="Test Group") + + # Default values + args = parser.parse_args([]) + assert args.mode == "train" + assert args.precision == 32 + + # Valid choices + args = parser.parse_args(['--mode', 'eval', '--precision', '16']) + assert args.mode == "eval" + assert args.precision == 16 + + def test_literal_fields_reject_invalid_choices(self): + """Test that invalid Literal choices are rejected.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithLiteral) + + factory.build_group(parser, title="Test Group") + + # Invalid choice should raise error + with pytest.raises(SystemExit): + parser.parse_args(['--mode', 'invalid']) + + with pytest.raises(SystemExit): + parser.parse_args(['--precision', '64']) + + +class TestArgumentGroupFactoryHelpers: + """Test helper methods.""" + + def test_format_arg_name_basic(self): + """Test basic argument name formatting.""" + factory = ArgumentGroupFactory(DummyConfig) + + assert factory._format_arg_name("simple") == "--simple" + assert factory._format_arg_name("with_underscore") == "--with-underscore" + assert factory._format_arg_name("multiple_under_scores") == "--multiple-under-scores" + + def test_format_arg_name_with_prefix(self): + """Test argument name formatting with prefix.""" + factory = ArgumentGroupFactory(DummyConfig) + + assert factory._format_arg_name("feature", prefix="no") == "--no-feature" + assert factory._format_arg_name("feature", prefix="disable") == "--disable-feature" + assert factory._format_arg_name("multi_word", prefix="no") == "--no-multi-word" + + def test_extract_type_primitive(self): + """Test type extraction for primitive types.""" + factory = ArgumentGroupFactory(DummyConfig) + + assert factory._extract_type(int) == {"type": int} + assert factory._extract_type(str) == {"type": str} + assert factory._extract_type(float) == {"type": float} + + def test_extract_type_optional(self): + """Test type extraction for Optional types.""" + factory = ArgumentGroupFactory(DummyConfig) + + result = factory._extract_type(Optional[int]) + assert result == {"type": int} + + result = factory._extract_type(Optional[str]) + assert result == {"type": str} + + def test_extract_type_list(self): + """Test type extraction for list types.""" + factory = ArgumentGroupFactory(DummyConfig) + + result = factory._extract_type(list[int]) + assert result == {"type": int, "nargs": "+"} + + result = factory._extract_type(list[str]) + assert result == {"type": str, "nargs": "+"} + + def test_extract_type_literal(self): + """Test type extraction for Literal types.""" + factory = ArgumentGroupFactory(DummyConfig) + + result = factory._extract_type(Literal["a", "b", "c"]) + assert result == {"type": str, "choices": ("a", "b", "c")} + + result = factory._extract_type(Literal[1, 2, 3]) + assert result == {"type": int, "choices": (1, 2, 3)} + + +@dataclass +class ConfigWithArgparseMeta: + """Config with argparse_meta metadata for testing overrides.""" + + custom_help: str = field( + default="default_value", + metadata={"argparse_meta": {"help": "Custom help text from metadata"}}, + ) + """Original help text""" + + custom_type: str = field(default="100", metadata={"argparse_meta": {"type": int}}) + """Field with type override""" + + custom_default: str = field( + default="original_default", metadata={"argparse_meta": {"default": "overridden_default"}} + ) + """Field with default override""" + + custom_choices: str = field( + default="option1", + metadata={"argparse_meta": {"choices": ["option1", "option2", "option3"]}}, + ) + """Field with choices override""" + + custom_dest: str = field( + default="value", metadata={"argparse_meta": {"dest": "renamed_destination"}} + ) + """Field with dest override""" + + custom_action: bool = field( + default=False, + metadata={"argparse_meta": {"action": "store_const", "const": "special_value"}}, + ) + """Field with custom action override""" + + multiple_overrides: int = field( + default=42, + metadata={ + "argparse_meta": { + "type": str, + "help": "Multiple overrides applied", + "default": "999", + "dest": "multi_override_dest", + } + }, + ) + """Field with multiple metadata overrides""" + + nargs_override: str = field(default="single", metadata={"argparse_meta": {"nargs": "?"}}) + """Field with nargs override""" + + +@dataclass +class ConfigWithUnsupportedCallables: + """Config with argparse_meta metadata for testing overrides.""" + + unsupported_type: Optional[Callable] = None + """Cannot take a callable over CLI""" + + unsupported_with_metadata: Optional[Callable] = field( + default=None, metadata={"argparse_meta": {"type": int, "choices": (0, 1, 2)}} + ) + """This argument should be 0, 1, or 2. The appropriate + Callable will be set by some other logic. + """ + + +@dataclass +class ConfigWithUnsupportedUnions: + """Config with argparse_meta metadata for testing overrides.""" + + unsupported_type: Union[int, str] = 0 + """Cannot infer type of a Union""" + + unsupported_with_metadata: Union[int, str] = field( + default=0, metadata={"argparse_meta": {"type": str, "choices": ("foo", "bar")}} + ) + """Metadata should take precedence over the exception caused by Union""" + + +class TestArgumentGroupFactoryArgparseMeta: + """Test argparse_meta metadata override functionality.""" + + def test_help_override(self): + """Test that argparse_meta can override help text.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithArgparseMeta) + + factory.build_group(parser, title="Test Group") + + # Find the action for this argument + for action in parser._actions: + if hasattr(action, 'dest') and action.dest == 'custom_help': + assert action.help == "Custom help text from metadata" + return + + pytest.fail("custom_help argument not found") + + def test_type_override(self): + """Test that argparse_meta can override argument type.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithArgparseMeta) + + factory.build_group(parser, title="Test Group") + + # Parse with integer value (metadata overrides type to int) + args = parser.parse_args(['--custom-type', '42']) + + # Should be parsed as int, not str + assert isinstance(args.custom_type, int) + assert args.custom_type == 42 + + def test_default_override(self): + """Test that argparse_meta can override default value.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithArgparseMeta) + + factory.build_group(parser, title="Test Group") + + # Parse with no arguments + args = parser.parse_args([]) + + # Should use metadata default, not field default + assert args.custom_default == "overridden_default" + + def test_choices_override(self): + """Test that argparse_meta can override choices.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithArgparseMeta) + + factory.build_group(parser, title="Test Group") + + # Valid choice from metadata + args = parser.parse_args(['--custom-choices', 'option2']) + assert args.custom_choices == "option2" + + # Invalid choice should fail + with pytest.raises(SystemExit): + parser.parse_args(['--custom-choices', 'invalid_option']) + + def test_dest_override(self): + """Test that argparse_meta can override destination name.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithArgparseMeta) + + factory.build_group(parser, title="Test Group") + + args = parser.parse_args(['--custom-dest', 'test_value']) + + # Should be stored in renamed destination + assert hasattr(args, 'renamed_destination') + assert args.renamed_destination == "test_value" + + def test_action_override(self): + """Test that argparse_meta can override action.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithArgparseMeta) + + factory.build_group(parser, title="Test Group") + + # With custom action=store_const and const="special_value" + args = parser.parse_args(['--custom-action']) + assert args.custom_action == "special_value" + + # Without flag, should use default + args = parser.parse_args([]) + assert args.custom_action == False + + def test_multiple_overrides(self): + """Test that multiple argparse_meta overrides work together.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithArgparseMeta) + + factory.build_group(parser, title="Test Group") + + # Parse with no arguments to check default override + args = parser.parse_args([]) + + # Check all overrides applied + assert hasattr(args, 'multi_override_dest') + assert args.multi_override_dest == "999" # default override + + # Parse with value to check type override + args = parser.parse_args(['--multiple-overrides', 'text_value']) + assert isinstance(args.multi_override_dest, str) # type override + assert args.multi_override_dest == "text_value" + + # Check help override was applied + for action in parser._actions: + if hasattr(action, 'dest') and action.dest == 'multi_override_dest': + assert action.help == "Multiple overrides applied" + break + + def test_nargs_override(self): + """Test that argparse_meta can override nargs.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithArgparseMeta) + + factory.build_group(parser, title="Test Group") + + # With nargs='?', argument is optional + args = parser.parse_args(['--nargs-override']) + assert args.nargs_override is None # No value provided with '?' + + # With value + args = parser.parse_args(['--nargs-override', 'provided_value']) + assert args.nargs_override == "provided_value" + + # Without flag at all, should use default + args = parser.parse_args([]) + assert args.nargs_override == "single" + + def test_metadata_takes_precedence_over_inference(self): + """Test that metadata has highest precedence over type inference.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithArgparseMeta) + + # Build kwargs for custom_type field which is str but metadata says int + from dataclasses import fields as dc_fields + + for f in dc_fields(ConfigWithArgparseMeta): + if f.name == 'custom_type': + kwargs = factory._build_argparse_kwargs_from_field(f) + # Metadata type should override inferred type + assert kwargs['type'] == int + break + + def test_unhandled_unsupported_callables(self): + """Test that an unsupported type produces a TypInferenceError.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory( + ConfigWithUnsupportedCallables, exclude=["unsupported_with_metadata"] + ) + + with pytest.raises(TypeInferenceError, match="Unsupported type"): + factory.build_group(parser, title="Test Group") + + def test_handled_unsupported_callables(self): + """Test an attribute with an unsupported type that has type info in the metadata.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithUnsupportedCallables, exclude=["unsupported_type"]) + + factory.build_group(parser, title="Test Group") + + args = parser.parse_args(['--unsupported-with-metadata', '0']) + assert args.unsupported_with_metadata == 0 + + def test_unhandled_unsupported_unions(self): + """Test that an unsupported type produces a TypInferenceError.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory( + ConfigWithUnsupportedUnions, exclude=["unsupported_with_metadata"] + ) + + with pytest.raises(TypeInferenceError, match="Unions not supported by argparse"): + factory.build_group(parser, title="Test Group") + + def test_handled_unsupported_unions(self): + """Test an attribute with an unsupported type that has type info in the metadata.""" + parser = ArgumentParser(exit_on_error=False) + factory = ArgumentGroupFactory(ConfigWithUnsupportedUnions, exclude=["unsupported_type"]) + + factory.build_group(parser, title="Test Group") + + args = parser.parse_args(['--unsupported-with-metadata', 'foo']) + assert args.unsupported_with_metadata == 'foo' + + with pytest.raises(ArgumentError, match="invalid choice"): + args = parser.parse_args(['--unsupported-with-metadata', 'baz'])