From ee69c1bdccfd6d360b5e304c6a5687815b2e54e2 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Wed, 16 Jul 2025 16:57:36 -0700 Subject: [PATCH 01/27] add training loop config dataclass Signed-off-by: Maanu Grover --- megatron/core/training/config.py | 91 ++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 megatron/core/training/config.py diff --git a/megatron/core/training/config.py b/megatron/core/training/config.py new file mode 100644 index 0000000000..b147e191e9 --- /dev/null +++ b/megatron/core/training/config.py @@ -0,0 +1,91 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +from dataclasses import dataclass +import signal +from typing import Optional, Literal + +@dataclass(kw_only=True) +class TrainingConfig: + """Configuration settings related to the training loop and validation.""" + + # ---------------- Training config. ---------------- + + micro_batch_size: Optional[int] = None + """Batch size per model instance (local batch size). Global batch size is local batch size times + data parallel size times number of micro batches.""" + + global_batch_size: Optional[int] = None + """Training batch size. If set, it should be a multiple of micro-batch-size times + data-parallel-size. If this value is None, then use micro-batch-size * data-parallel-size + as the global batch size. This choice will result in 1 for number of micro-batches.""" + + rampup_batch_size: Optional[list[int]] = None + """Batch size ramp up with the following values: , , + + For example: + rampup-batch-size = [16, 8, 300000] + global-batch-size 1024 + will start with global batch size 16 and over (1024 - 16) / 8 = 126 intervals will increase + the batch size linearly to 1024. In each interval we will use approximately + 300000 / 126 = 2380 samples. + """ + + decrease_batch_size_if_needed: bool = False + """If set, decrease batch size if microbatch_size * dp_size does not + divide batch_size. 
Old batch_size will be restored if training is re-started + with dp_size that divides batch_size // microbatch_size.""" + + empty_unused_memory_level: Literal[0, 1, 2] = 0 + """Call torch.cuda.empty_cache() each iteration (training and eval), to reduce fragmentation. + 0=off, 1=moderate, 2=aggressive. + """ + + check_weight_hash_across_dp_replicas_interval: Optional[int] = None + """Interval to check weight hashes are same across DP replicas. If not specified, weight hashes not checked.""" + + train_sync_interval: Optional[int] = None + """Training CPU-GPU synchronization interval, to ensure that CPU is not running too far ahead of GPU.""" + + train_iters: Optional[int] = None + """Total number of iterations to train over all training runs.""" + + exit_interval: Optional[int] = None + """Exit the program after the iteration is divisible by this value.""" + + exit_duration_in_mins: Optional[int] = None + """Exit the program after this many minutes.""" + + exit_signal_handler: bool = False + """Dynamically save the checkpoint and shutdown the training if SIGTERM is received""" + + exit_signal: int = signal.SIGTERM + """Signal for the signal handler to detect.""" + + exit_signal_handler_for_dataloader: bool = False + """Use signal handler for dataloader workers""" + + manual_gc: bool = False + """Disable the threshold-based default garbage collector and trigger the garbage collection + manually. Manual garbage collection helps to align the timing of the collection across ranks + which mitigates the impact of CPU-associated jitters. When the manual gc is enabled, garbage + collection is performed only at the start and the end of the validation routine by default.""" + + manual_gc_interval: int = 0 + """Training step interval to trigger manual garbage collection. + When the value is set to 0, garbage collection is not triggered between training steps. 
+ """ + + manual_gc_eval: bool = True + """When using manual garbage collection, + disable garbage collection at the start and the end of each evaluation run. + """ + + # ---------------- Validation config. ---------------- + + eval_iters: int = 100 + """Number of iterations to run for evaluation validation/test for.""" + + eval_interval: Optional[int] = 1000 + """Interval between running evaluation on validation set.""" + + skip_train: bool = False + """If set, bypass the training loop, optionally do evaluation for validation/test, and exit.""" From 09e044d4537d70b75aa14c2bac7aefd7dc1e6587 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Thu, 17 Jul 2025 14:07:01 -0700 Subject: [PATCH 02/27] move file --- megatron/{core => }/training/config.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename megatron/{core => }/training/config.py (100%) diff --git a/megatron/core/training/config.py b/megatron/training/config.py similarity index 100% rename from megatron/core/training/config.py rename to megatron/training/config.py From b118b244a0fb70f77d0eaff84f0c80d05a0f0f5f Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Thu, 24 Jul 2025 11:04:30 -0700 Subject: [PATCH 03/27] remove variable batch size options Signed-off-by: Maanu Grover --- megatron/training/config.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/megatron/training/config.py b/megatron/training/config.py index b147e191e9..4c17ea0a65 100644 --- a/megatron/training/config.py +++ b/megatron/training/config.py @@ -18,22 +18,6 @@ class TrainingConfig: data-parallel-size. If this value is None, then use micro-batch-size * data-parallel-size as the global batch size. 
This choice will result in 1 for number of micro-batches.""" - rampup_batch_size: Optional[list[int]] = None - """Batch size ramp up with the following values: , , - - For example: - rampup-batch-size = [16, 8, 300000] - global-batch-size 1024 - will start with global batch size 16 and over (1024 - 16) / 8 = 126 intervals will increase - the batch size linearly to 1024. In each interval we will use approximately - 300000 / 126 = 2380 samples. - """ - - decrease_batch_size_if_needed: bool = False - """If set, decrease batch size if microbatch_size * dp_size does not - divide batch_size. Old batch_size will be restored if training is re-started - with dp_size that divides batch_size // microbatch_size.""" - empty_unused_memory_level: Literal[0, 1, 2] = 0 """Call torch.cuda.empty_cache() each iteration (training and eval), to reduce fragmentation. 0=off, 1=moderate, 2=aggressive. From 29c129453e16900dd2e135ac9ecb53dde3a4a47a Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Thu, 24 Jul 2025 11:08:17 -0700 Subject: [PATCH 04/27] replace train iters with train samples Signed-off-by: Maanu Grover --- megatron/training/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/training/config.py b/megatron/training/config.py index 4c17ea0a65..6a7b3a8680 100644 --- a/megatron/training/config.py +++ b/megatron/training/config.py @@ -29,8 +29,8 @@ class TrainingConfig: train_sync_interval: Optional[int] = None """Training CPU-GPU synchronization interval, to ensure that CPU is not running too far ahead of GPU.""" - train_iters: Optional[int] = None - """Total number of iterations to train over all training runs.""" + train_samples: Optional[int] = None + """Total number of samples to train over all training runs.""" exit_interval: Optional[int] = None """Exit the program after the iteration is divisible by this value.""" From e9375f977310ad5757a7356f406e3607b891c051 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Thu, 24 Jul 2025 12:11:02 
-0700 Subject: [PATCH 05/27] replace eval iters with eval samples Signed-off-by: Maanu Grover --- megatron/training/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/training/config.py b/megatron/training/config.py index 6a7b3a8680..853a04da57 100644 --- a/megatron/training/config.py +++ b/megatron/training/config.py @@ -65,8 +65,8 @@ class TrainingConfig: # ---------------- Validation config. ---------------- - eval_iters: int = 100 - """Number of iterations to run for evaluation validation/test for.""" + eval_samples: int = 0 + """Number of samples to run for evaluation. Used for both validation and test.""" eval_interval: Optional[int] = 1000 """Interval between running evaluation on validation set.""" From 1684ede627cefbdae07ba29a69ac937a573df95f Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Thu, 24 Jul 2025 12:12:14 -0700 Subject: [PATCH 06/27] Revert "remove variable batch size options" This reverts commit 1e160b4ce1884baa571939771d522bd9ede44c3f. --- megatron/training/config.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/megatron/training/config.py b/megatron/training/config.py index 853a04da57..77b9e9304f 100644 --- a/megatron/training/config.py +++ b/megatron/training/config.py @@ -18,6 +18,22 @@ class TrainingConfig: data-parallel-size. If this value is None, then use micro-batch-size * data-parallel-size as the global batch size. This choice will result in 1 for number of micro-batches.""" + rampup_batch_size: Optional[list[int]] = None + """Batch size ramp up with the following values: , , + + For example: + rampup-batch-size = [16, 8, 300000] + global-batch-size 1024 + will start with global batch size 16 and over (1024 - 16) / 8 = 126 intervals will increase + the batch size linearly to 1024. In each interval we will use approximately + 300000 / 126 = 2380 samples. 
+ """ + + decrease_batch_size_if_needed: bool = False + """If set, decrease batch size if microbatch_size * dp_size does not + divide batch_size. Old batch_size will be restored if training is re-started + with dp_size that divides batch_size // microbatch_size.""" + empty_unused_memory_level: Literal[0, 1, 2] = 0 """Call torch.cuda.empty_cache() each iteration (training and eval), to reduce fragmentation. 0=off, 1=moderate, 2=aggressive. From cc56ab3002e06e2bfb699fdd1fa4de46e0a37bc5 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Thu, 31 Jul 2025 15:18:20 -0700 Subject: [PATCH 07/27] update some defaults and docstrings Signed-off-by: Maanu Grover --- megatron/training/config.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/megatron/training/config.py b/megatron/training/config.py index 77b9e9304f..48036826c0 100644 --- a/megatron/training/config.py +++ b/megatron/training/config.py @@ -69,23 +69,27 @@ class TrainingConfig: which mitigates the impact of CPU-associated jitters. When the manual gc is enabled, garbage collection is performed only at the start and the end of the validation routine by default.""" - manual_gc_interval: int = 0 - """Training step interval to trigger manual garbage collection. - When the value is set to 0, garbage collection is not triggered between training steps. + manual_gc_interval: Optional[int] = None + """Training step interval to trigger manual garbage collection. Values > 0 will trigger garbage + collections between training steps. """ manual_gc_eval: bool = True - """When using manual garbage collection, - disable garbage collection at the start and the end of each evaluation run. + """When using manual garbage collection, this controls garbage collection at the start and the + end of each evaluation run. """ # ---------------- Validation config. ---------------- - eval_samples: int = 0 - """Number of samples to run for evaluation. 
Used for both validation and test.""" + eval_samples: Optional[int] = None + """Number of samples to run for evaluation. Used for both validation and test. If not set, + evaluation will not run. + """ - eval_interval: Optional[int] = 1000 - """Interval between running evaluation on validation set.""" + eval_interval: Optional[int] = None + """Interval between running evaluation on validation set. If not set, evaluation will not run + during training. + """ skip_train: bool = False """If set, bypass the training loop, optionally do evaluation for validation/test, and exit.""" From ff961f764d4399ea68adb4661a6db7fa205a4fe1 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Mon, 10 Nov 2025 17:15:47 -0800 Subject: [PATCH 08/27] first draft of factory Signed-off-by: Maanu Grover --- megatron/training/argument_utils.py | 158 ++++++++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 megatron/training/argument_utils.py diff --git a/megatron/training/argument_utils.py b/megatron/training/argument_utils.py new file mode 100644 index 0000000000..e824aa27d7 --- /dev/null +++ b/megatron/training/argument_utils.py @@ -0,0 +1,158 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +import typing +from typing import Any, Optional +from argparse import ArgumentParser +import inspect +import itertools +import builtins +import ast +from dataclasses import Field, fields + +# TODO: support include/exclude keys +# TODO: support arg renames, bool name invert +# TODO: if metadata handles types, ignore exceptions from _extract_type() + +class ArgumentGroupFactory: + """Utility that adds an argument group to an ArgumentParser based on the attributes of a dataclass. + + This class can be overriden as needed to support dataclasses + that require some customized or additional handling. + + Args: + src_cfg_class: The target dataclass to build arguments from. 
+ """ + + def __init__(self, src_cfg_class: type) -> None: + self.src_cfg_class = src_cfg_class + self.field_docstrings = self._get_field_docstrings(src_cfg_class) + + def _format_arg_name(self, config_attr_name: str) -> str: + """Convert dataclass name into appropriate argparse flag name. + + Args: + config_attr_name: dataclass attribute name + """ + arg_name = "--" + config_attr_name.replace("_", "-") + return arg_name + + def _extract_type(self, config_type: type) -> dict[str, Any]: + """Determine the type, nargs, and choices settings for this argument. + + Args: + config_type: attribute type from dataclass + """ + origin = typing.get_origin(config_type) + type_tuple = typing.get_args(config_type) + + # Primitive type + if origin is None: + return {"type": config_type} + + if origin is typing.Union: + # Handle Optional and Union + if type_tuple[1] == type(None): # Optional type. First element is value inside Optional[] + return self._extract_type(type_tuple[0]) + else: + raise TypeError(f"Unions not supported by argparse: {config_type}") + + elif origin is list: + if len(type_tuple) == 1: + kwargs = self._extract_type(type_tuple[0]) + kwargs["nargs"] = "+" + return kwargs + else: + raise TypeError(f"Multi-type lists not supported by argparse: {config_type}") + + elif origin is typing.Literal: + choices_types = [type(choice) for choice in type_tuple] + assert all([t == choices_types[0] for t in choices_types]), "Type of each choice in a Literal type should all be the same." + kwargs = {"type": choices_types[0], "choices": type_tuple} + return kwargs + else: + raise TypeError(f"Unsupported type: {config_type}") + + + def _build_argparse_kwargs_from_field(self, attribute: Field) -> dict[str, Any]: + """Assemble kwargs for add_argument(). 
+ + Args: + attribute: dataclass attribute + """ + argparse_kwargs = {} + argparse_kwargs["arg_name"] = self._format_arg_name(attribute.name) + argparse_kwargs["dest"] = attribute.name + argparse_kwargs["default"] = attribute.default + argparse_kwargs["help"] = self.field_docstrings[attribute.name] + + argparse_kwargs.update(self._extract_type(attribute.type)) + + # use store_true or store_false action for enable/disable flags, which doesn't accept a 'type' + if argparse_kwargs["type"] == bool: + argparse_kwargs["action"] = "store_true" if attribute.default == False else "store_false" + argparse_kwargs.pop("type") + + # metadata provided by field takes precedence + if attribute.metadata != {} and "argparse_meta" in attribute.metadata: + argparse_kwargs.update(attribute.metadata["argparse_meta"]) + + return argparse_kwargs + + def build_group(self, parser: ArgumentParser, title: Optional[str] = None) -> ArgumentParser: + """Entrypoint method that adds the argument group to the parser. + + Args: + parser: The parser to add arguments to + title: Title for the argument group + """ + arg_group = parser.add_argument_group(title=title, description=self.src_cfg_class.__doc__) + for attr in fields(self.src_cfg_class): + add_arg_kwargs = self._build_argparse_kwargs_from_field(attr) + + arg_name = add_arg_kwargs.pop("arg_name") + arg_group.add_argument(arg_name, **add_arg_kwargs) + + return parser + + def _get_field_docstrings(self, src_cfg_class: type) -> dict[str, str]: + """Extract field-level docstrings from a dataclass by inspecting its AST. + + Recurses on parent classes of `src_cfg_class`. + + Args: + src_cfg_class: Dataclass to get docstrings from. + """ + source = inspect.getsource(src_cfg_class) + tree = ast.parse(source) + root_node = tree.body[0] + + assert isinstance(root_node, ast.ClassDef), "Provided object must be a class." + + field_docstrings = {} + + # Iterate over body of the dataclass using 2-width sliding window. 
+ # When 'a' is an assignment expression and 'b' is a constant, the window is + # lined up with an attribute-docstring pair. The pair can be saved to our dict. + for a, b in itertools.pairwise(root_node.body): + a_cond = isinstance(a, ast.AnnAssign) and isinstance(a.target, ast.Name) + b_cond = isinstance(b, ast.Expr) and isinstance(b.value, ast.Constant) + + if a_cond and b_cond: + # These should be guaranteed by typechecks above, but assert just in case + assert isinstance(a.target.id, str), "Dataclass attribute not in the expected format. Name is not a string." + assert isinstance(b.value.value, str), "Dataclass attribute docstring is not a string." + + # Formatting + docstring = inspect.cleandoc(b.value.value) + docstring = ' '.join(docstring.split()) + + field_docstrings[a.target.id] = docstring + + # recurse on parent class + base_classes = src_cfg_class.__bases__ + if len(base_classes) > 0: + parent_class = base_classes[0] + if parent_class.__name__ not in builtins.__dict__: + field_docstrings.update(get_field_docstrings(base_classes[0])) + + return field_docstrings From 4892f2f58941435f443203f55b216bd6942720cf Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Mon, 10 Nov 2025 17:16:00 -0800 Subject: [PATCH 09/27] set metadata for rbs Signed-off-by: Maanu Grover --- megatron/training/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/training/config.py b/megatron/training/config.py index 48036826c0..dae2b1453f 100644 --- a/megatron/training/config.py +++ b/megatron/training/config.py @@ -1,5 +1,5 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -from dataclasses import dataclass +from dataclasses import dataclass, field import signal from typing import Optional, Literal @@ -18,7 +18,7 @@ class TrainingConfig: data-parallel-size. If this value is None, then use micro-batch-size * data-parallel-size as the global batch size. 
This choice will result in 1 for number of micro-batches.""" - rampup_batch_size: Optional[list[int]] = None + rampup_batch_size: Optional[list[int]] = field(default=None, metadata={"argparse_meta": {"nargs": 3}}) """Batch size ramp up with the following values: , , For example: From c22431ffe280dc9fe93a0fe8b39e6ebae4b69756 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Tue, 11 Nov 2025 13:16:41 -0800 Subject: [PATCH 10/27] support excluding keys Signed-off-by: Maanu Grover --- megatron/training/argument_utils.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/megatron/training/argument_utils.py b/megatron/training/argument_utils.py index e824aa27d7..7a4c157546 100644 --- a/megatron/training/argument_utils.py +++ b/megatron/training/argument_utils.py @@ -9,7 +9,6 @@ import ast from dataclasses import Field, fields -# TODO: support include/exclude keys # TODO: support arg renames, bool name invert # TODO: if metadata handles types, ignore exceptions from _extract_type() @@ -20,12 +19,20 @@ class ArgumentGroupFactory: that require some customized or additional handling. Args: - src_cfg_class: The target dataclass to build arguments from. + src_cfg_class: The source dataclass type (not instance) whose fields will be + converted into command-line arguments. Each field's type annotation determines + the argument type, default values become argument defaults, and field-level + docstrings are extracted to populate argument help text. + exclude: Optional list of attribute names from `src_cfg_class` to exclude from + argument generation. Useful for omitting internal fields, computed properties, + or attributes that should be configured through other means. If None, all + dataclass fields will be converted to command-line arguments. Default: None. 
""" - def __init__(self, src_cfg_class: type) -> None: + def __init__(self, src_cfg_class: type, exclude: Optional[list[str]] = None) -> None: self.src_cfg_class = src_cfg_class self.field_docstrings = self._get_field_docstrings(src_cfg_class) + self.exclude = set(exclude) if exclude is not None else set() def _format_arg_name(self, config_attr_name: str) -> str: """Convert dataclass name into appropriate argparse flag name. @@ -107,6 +114,9 @@ def build_group(self, parser: ArgumentParser, title: Optional[str] = None) -> Ar """ arg_group = parser.add_argument_group(title=title, description=self.src_cfg_class.__doc__) for attr in fields(self.src_cfg_class): + if attr.name in self.exclude: + continue + add_arg_kwargs = self._build_argparse_kwargs_from_field(attr) arg_name = add_arg_kwargs.pop("arg_name") From 57ae59b7fafb74d5f7f42833156e5bef2f289c71 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 14 Nov 2025 14:36:28 -0800 Subject: [PATCH 11/27] change return object Signed-off-by: Maanu Grover --- megatron/training/argument_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/megatron/training/argument_utils.py b/megatron/training/argument_utils.py index 7a4c157546..e85b76399f 100644 --- a/megatron/training/argument_utils.py +++ b/megatron/training/argument_utils.py @@ -2,7 +2,7 @@ import typing from typing import Any, Optional -from argparse import ArgumentParser +from argparse import ArgumentParser, _ArgumentGroup import inspect import itertools import builtins @@ -105,7 +105,7 @@ def _build_argparse_kwargs_from_field(self, attribute: Field) -> dict[str, Any]: return argparse_kwargs - def build_group(self, parser: ArgumentParser, title: Optional[str] = None) -> ArgumentParser: + def build_group(self, parser: ArgumentParser, title: Optional[str] = None) -> _ArgumentGroup: """Entrypoint method that adds the argument group to the parser. 
Args: @@ -122,7 +122,7 @@ def build_group(self, parser: ArgumentParser, title: Optional[str] = None) -> Ar arg_name = add_arg_kwargs.pop("arg_name") arg_group.add_argument(arg_name, **add_arg_kwargs) - return parser + return arg_group def _get_field_docstrings(self, src_cfg_class: type) -> dict[str, str]: """Extract field-level docstrings from a dataclass by inspecting its AST. From cff582f71c3ebc2f5603bf954a50e2b855f5c000 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 14 Nov 2025 15:26:29 -0800 Subject: [PATCH 12/27] support multiple arg names and arg name prefixes Signed-off-by: Maanu Grover --- megatron/training/argument_utils.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/megatron/training/argument_utils.py b/megatron/training/argument_utils.py index e85b76399f..88fe77185a 100644 --- a/megatron/training/argument_utils.py +++ b/megatron/training/argument_utils.py @@ -9,7 +9,7 @@ import ast from dataclasses import Field, fields -# TODO: support arg renames, bool name invert +# TODO: support arg renames # TODO: if metadata handles types, ignore exceptions from _extract_type() class ArgumentGroupFactory: @@ -34,13 +34,18 @@ def __init__(self, src_cfg_class: type, exclude: Optional[list[str]] = None) -> self.field_docstrings = self._get_field_docstrings(src_cfg_class) self.exclude = set(exclude) if exclude is not None else set() - def _format_arg_name(self, config_attr_name: str) -> str: + def _format_arg_name(self, config_attr_name: str, prefix: Optional[str] = None) -> str: """Convert dataclass name into appropriate argparse flag name. Args: config_attr_name: dataclass attribute name + prefix: prefix string to add to the dataclass attribute name. e.g. 'no' for bool + settings that are default True. A hyphen is added after the prefix. 
Default: None """ - arg_name = "--" + config_attr_name.replace("_", "-") + arg_name = config_attr_name + if prefix: + arg_name = prefix + '_' + arg_name + arg_name = "--" + arg_name.replace("_", "-") return arg_name def _extract_type(self, config_type: type) -> dict[str, Any]: @@ -87,7 +92,7 @@ def _build_argparse_kwargs_from_field(self, attribute: Field) -> dict[str, Any]: attribute: dataclass attribute """ argparse_kwargs = {} - argparse_kwargs["arg_name"] = self._format_arg_name(attribute.name) + argparse_kwargs["arg_names"] = [self._format_arg_name(attribute.name)] argparse_kwargs["dest"] = attribute.name argparse_kwargs["default"] = attribute.default argparse_kwargs["help"] = self.field_docstrings[attribute.name] @@ -99,6 +104,10 @@ def _build_argparse_kwargs_from_field(self, attribute: Field) -> dict[str, Any]: argparse_kwargs["action"] = "store_true" if attribute.default == False else "store_false" argparse_kwargs.pop("type") + # add '--no-*' and '--disable-*' prefix if this is a store_false argument + if argparse_kwargs["action"] == "store_false": + argparse_kwargs["arg_names"] = [self._format_arg_name(attribute.name, prefix="no"), self._format_arg_name(attribute.name, prefix="disable")] + # metadata provided by field takes precedence if attribute.metadata != {} and "argparse_meta" in attribute.metadata: argparse_kwargs.update(attribute.metadata["argparse_meta"]) @@ -119,8 +128,8 @@ def build_group(self, parser: ArgumentParser, title: Optional[str] = None) -> _A add_arg_kwargs = self._build_argparse_kwargs_from_field(attr) - arg_name = add_arg_kwargs.pop("arg_name") - arg_group.add_argument(arg_name, **add_arg_kwargs) + arg_names = add_arg_kwargs.pop("arg_names") + arg_group.add_argument(*arg_names, **add_arg_kwargs) return arg_group From 17d745c235190b3aaee82c1229d66eb91813eb2f Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 14 Nov 2025 15:36:47 -0800 Subject: [PATCH 13/27] remove some auto-generated arguments Signed-off-by: Maanu Grover --- 
megatron/training/arguments.py | 73 +++------------------------------- 1 file changed, 6 insertions(+), 67 deletions(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index a726ae1def..d94ca43119 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -48,6 +48,8 @@ load_quantization_recipe, ) +from megatron.training.argument_utils import ArgumentGroupFactory + def add_megatron_arguments(parser: argparse.ArgumentParser): """"Add Megatron-LM arguments to the given parser.""" @@ -1967,41 +1969,14 @@ def _add_rl_args(parser): return parser def _add_training_args(parser): - group = parser.add_argument_group(title='training') + from megatron.training.config import TrainingConfig + + train_factory = ArgumentGroupFactory(TrainingConfig) + group = train_factory.build_group(parser, "training") - group.add_argument('--micro-batch-size', type=int, default=None, - help='Batch size per model instance (local batch size). ' - 'Global batch size is local batch size times data ' - 'parallel size times number of micro batches.') group.add_argument('--batch-size', type=int, default=None, help='Old batch size parameter, do not use. ' 'Use --micro-batch-size instead') - group.add_argument('--global-batch-size', type=int, default=None, - help='Training batch size. If set, it should be a ' - 'multiple of micro-batch-size times data-parallel-size. ' - 'If this value is None, then ' - 'use micro-batch-size * data-parallel-size as the ' - 'global batch size. This choice will result in 1 for ' - 'number of micro-batches.') - group.add_argument('--rampup-batch-size', nargs='*', default=None, - help='Batch size ramp up with the following values:' - ' --rampup-batch-size ' - ' ' - ' ' - 'For example:' - ' --rampup-batch-size 16 8 300000 \\ ' - ' --global-batch-size 1024' - 'will start with global batch size 16 and over ' - ' (1024 - 16) / 8 = 126 intervals will increase' - 'the batch size linearly to 1024. 
In each interval' - 'we will use approximately 300000 / 126 = 2380 samples.') - group.add_argument('--decrease-batch-size-if-needed', action='store_true', default=False, - help='If set, decrease batch size if microbatch_size * dp_size' - 'does not divide batch_size. Useful for KSO (Keep Soldiering On)' - 'to continue making progress if number of healthy GPUs (and' - 'corresponding dp_size) does not support current batch_size.' - 'Old batch_size will be restored if training is re-started with' - 'dp_size that divides batch_size // microbatch_size.') group.add_argument('--recompute-activations', action='store_true', help='recompute activation to allow for training ' 'with larger models, sequences, and batch sizes.') @@ -2114,21 +2089,12 @@ def _add_training_args(parser): group.add_argument('--use-cpu-initialization', action='store_true', default=None, help='If set, initialize weights on the CPU. This eliminates init differences based on tensor parallelism.') - group.add_argument('--empty-unused-memory-level', default=0, type=int, - choices=[0, 1, 2], - help='Call torch.cuda.empty_cache() each iteration ' - '(training and eval), to reduce fragmentation.' - '0=off, 1=moderate, 2=aggressive.') group.add_argument('--deterministic-mode', action='store_true', help='Choose code that has deterministic execution. This usually ' 'means slower execution, but is good for debugging and testing.') - group.add_argument('--check-weight-hash-across-dp-replicas-interval', type=int, default=None, - help='Interval to check weight hashes are same across DP replicas. 
If not specified, weight hashes not checked.') group.add_argument('--calculate-per-token-loss', action='store_true', help=('Scale cross entropy loss by the number of non-padded tokens in the ' 'global batch, versus the default behavior of assuming all tokens are non-padded.')) - group.add_argument('--train-sync-interval', type=int, default=None, - help='Training CPU-GPU synchronization interval, to ensure that CPU is not running too far ahead of GPU.') # deprecated group.add_argument('--checkpoint-activations', action='store_true', @@ -2144,17 +2110,6 @@ def _add_training_args(parser): 'train-samples should be provided.') group.add_argument('--log-interval', type=int, default=100, help='Report loss and timing interval.') - group.add_argument('--exit-interval', type=int, default=None, - help='Exit the program after the iteration is divisible ' - 'by this value.') - group.add_argument('--exit-duration-in-mins', type=int, default=None, - help='Exit the program after this many minutes.') - group.add_argument('--exit-signal-handler', action='store_true', - help='Dynamically save the checkpoint and shutdown the ' - 'training if signal is received') - group.add_argument('--exit-signal', type=str, default='SIGTERM', - choices=list(SIGNAL_MAP.keys()), - help='Signal to use for exit signal handler. If not specified, defaults to SIGTERM.') group.add_argument('--tensorboard-dir', type=str, default=None, help='Write TensorBoard logs to this directory.') group.add_argument('--no-masked-softmax-fusion', @@ -2242,22 +2197,6 @@ def _add_training_args(parser): '--use-legacy-models to not use core models.') group.add_argument('--use-legacy-models', action='store_true', help='Use the legacy Megatron models, not Megatron-Core models.') - group.add_argument('--manual-gc', action='store_true', - help='Disable the threshold-based default garbage ' - 'collector and trigger the garbage collection manually. 
' - 'Manual garbage collection helps to align the timing of ' - 'the collection across ranks which mitigates the impact ' - 'of CPU-associated jitters. When the manual gc is enabled, ' - 'garbage collection is performed only at the start and the ' - 'end of the validation routine by default.') - group.add_argument('--manual-gc-interval', type=int, default=0, - help='Training step interval to trigger manual garbage ' - 'collection. When the value is set to 0, garbage ' - 'collection is not triggered between training steps.') - group.add_argument('--no-manual-gc-eval', action='store_false', - help='When using manual garbage collection, disable ' - 'garbage collection at the start and the end of each ' - 'evaluation run.', dest='manual_gc_eval') group.add_argument('--disable-tp-comm-split-ag', action='store_false', help='Disables the All-Gather overlap with fprop GEMM.', dest='tp_comm_split_ag') From 47f1639da60dc9835accd511f318ad617213dfdf Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 14 Nov 2025 16:05:39 -0800 Subject: [PATCH 14/27] support default factories Signed-off-by: Maanu Grover --- megatron/training/argument_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/megatron/training/argument_utils.py b/megatron/training/argument_utils.py index 88fe77185a..fa15914ab0 100644 --- a/megatron/training/argument_utils.py +++ b/megatron/training/argument_utils.py @@ -1,5 +1,6 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+import dataclasses import typing from typing import Any, Optional from argparse import ArgumentParser, _ArgumentGroup @@ -94,9 +95,15 @@ def _build_argparse_kwargs_from_field(self, attribute: Field) -> dict[str, Any]: argparse_kwargs = {} argparse_kwargs["arg_names"] = [self._format_arg_name(attribute.name)] argparse_kwargs["dest"] = attribute.name - argparse_kwargs["default"] = attribute.default argparse_kwargs["help"] = self.field_docstrings[attribute.name] + # dataclasses specifies that both should not be set + if isinstance(attribute.default, type(dataclasses.MISSING)): + # dataclasses specified default_factory must be a zero-argument callable + argparse_kwargs["default"] = attribute.default_factory() + else: + argparse_kwargs["default"] = attribute.default + argparse_kwargs.update(self._extract_type(attribute.type)) # use store_true or store_false action for enable/disable flags, which doesn't accept a 'type' From 46e568475fadcc87ef4fc62c171b2d8c426ee0a6 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 14 Nov 2025 16:06:11 -0800 Subject: [PATCH 15/27] add iterations to skip to config Signed-off-by: Maanu Grover --- megatron/training/arguments.py | 2 -- megatron/training/config.py | 4 ++++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index d94ca43119..7f80ccdbcf 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2045,8 +2045,6 @@ def _add_training_args(parser): help='Global step to start profiling.') group.add_argument('--profile-step-end', type=int, default=12, help='Global step to stop profiling.') - group.add_argument('--iterations-to-skip', nargs='+', type=int, default=[], - help='List of iterations to skip, empty by default.') group.add_argument('--result-rejected-tracker-filename', type=str, default=None, help='Optional name of file tracking `result_rejected` events.') group.add_argument('--disable-gloo-process-groups', 
action='store_false', diff --git a/megatron/training/config.py b/megatron/training/config.py index dae2b1453f..1d44999f1c 100644 --- a/megatron/training/config.py +++ b/megatron/training/config.py @@ -79,6 +79,10 @@ class TrainingConfig: end of each evaluation run. """ + iterations_to_skip: list[int] = field(default_factory=list) + """List of iterations to skip during training, empty by default.""" + + # ---------------- Validation config. ---------------- eval_samples: Optional[int] = None From 74ad7dee47f32879dbbffdb7662fbab6777157c0 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 14 Nov 2025 16:06:38 -0800 Subject: [PATCH 16/27] support both iters and samples for training Signed-off-by: Maanu Grover --- megatron/training/arguments.py | 8 -------- megatron/training/config.py | 8 +++++++- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 7f80ccdbcf..10b2bfb99c 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2098,14 +2098,6 @@ def _add_training_args(parser): group.add_argument('--checkpoint-activations', action='store_true', help='Checkpoint activation to allow for training ' 'with larger models, sequences, and batch sizes.') - group.add_argument('--train-iters', type=int, default=None, - help='Total number of iterations to train over all ' - 'training runs. Note that either train-iters or ' - 'train-samples should be provided.') - group.add_argument('--train-samples', type=int, default=None, - help='Total number of samples to train over all ' - 'training runs. 
Note that either train-iters or ' - 'train-samples should be provided.') group.add_argument('--log-interval', type=int, default=100, help='Report loss and timing interval.') group.add_argument('--tensorboard-dir', type=str, default=None, diff --git a/megatron/training/config.py b/megatron/training/config.py index 1d44999f1c..3c62d3cb7a 100644 --- a/megatron/training/config.py +++ b/megatron/training/config.py @@ -45,8 +45,14 @@ class TrainingConfig: train_sync_interval: Optional[int] = None """Training CPU-GPU synchronization interval, to ensure that CPU is not running too far ahead of GPU.""" + train_iters: Optional[int] = None + """Total number of iterations to train over all training runs. + Note that either train_iters or train_samples should be provided. + """ + train_samples: Optional[int] = None - """Total number of samples to train over all training runs.""" + """Total number of samples to train over all training runs. + Note that either train_iters or train_samples should be provided.""" exit_interval: Optional[int] = None """Exit the program after the iteration is divisible by this value.""" From 79c62455161d16bbbc0fa3635758e6b440e0907a Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 14 Nov 2025 17:52:43 -0800 Subject: [PATCH 17/27] split validation into separate config Signed-off-by: Maanu Grover --- megatron/training/arguments.py | 18 ++++-------------- megatron/training/config.py | 30 ++++++++++++++++++++---------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 10b2bfb99c..f32c7c6b40 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2677,20 +2677,10 @@ def _add_distributed_args(parser): def _add_validation_args(parser): - group = parser.add_argument_group(title='validation') - - group.add_argument('--full-validation', action='store_true', help='If set, each time validation occurs it uses the full validation dataset(s). 
This currently only works for GPT datasets!') - group.add_argument('--multiple-validation-sets', action='store_true', help='If set, multiple datasets listed in the validation split are evaluated independently with a separate loss for each dataset in the list. This argument requires that no weights are included in the list') - group.add_argument('--eval-iters', type=int, default=100, - help='Number of iterations to run for evaluation' - 'validation/test for.') - group.add_argument('--eval-interval', type=int, default=1000, - help='Interval between running evaluation on ' - 'validation set.') - group.add_argument("--test-mode", action="store_true", help='Run all real-time test alongside the experiment.') - group.add_argument('--skip-train', action='store_true', - default=False, help='If set, bypass the training loop, ' - 'optionally do evaluation for validation/test, and exit.') + from megatron.training.config import ValidationConfig + + val_factory = ArgumentGroupFactory(ValidationConfig) + group = val_factory.build_group(parser, "validation") return parser diff --git a/megatron/training/config.py b/megatron/training/config.py index 3c62d3cb7a..92454b5c40 100644 --- a/megatron/training/config.py +++ b/megatron/training/config.py @@ -5,9 +5,7 @@ @dataclass(kw_only=True) class TrainingConfig: - """Configuration settings related to the training loop and validation.""" - - # ---------------- Training config. ---------------- + """Configuration settings related to the training loop.""" micro_batch_size: Optional[int] = None """Batch size per model instance (local batch size). Global batch size is local batch size times @@ -89,17 +87,29 @@ class TrainingConfig: """List of iterations to skip during training, empty by default.""" - # ---------------- Validation config. 
---------------- +@dataclass(kw_only=True) +class ValidationConfig: + """Configuration settings related to validation during or after model training.""" - eval_samples: Optional[int] = None - """Number of samples to run for evaluation. Used for both validation and test. If not set, - evaluation will not run. - """ + val_iters: Optional[int] = field(default=100, metadata={"argparse_meta": {"arg_names": ["--eval-iters", "--val-iters"], "dest": "eval_iters"}}) + """Number of iterations to run validation/test for.""" - eval_interval: Optional[int] = None + val_interval: Optional[int] = field(default=None, metadata={"argparse_meta": {"arg_names": ["--eval-interval", "--val-interval"], "dest": "eval_interval"}}) """Interval between running evaluation on validation set. If not set, evaluation will not run during training. """ skip_train: bool = False - """If set, bypass the training loop, optionally do evaluation for validation/test, and exit.""" + """If set, bypass the training loop, perform evaluation for validation/test, and exit.""" + + test_mode: bool = False + """Run all real-time test alongside the experiment.""" + + full_validation: bool = False + """If set, each time validation occurs it uses the full validation dataset(s). This currently only works for GPT datasets!""" + + multiple_validation_sets: bool = False + """If set, multiple datasets listed in the validation split are evaluated independently with a + separate loss for each dataset in the list. This argument requires that no weights are + included in the list. 
+ """ From 14b348dc30c813ba83dfa2845b36865d1171c19a Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Mon, 17 Nov 2025 13:00:04 -0800 Subject: [PATCH 18/27] add unit tests Signed-off-by: Maanu Grover --- tests/unit_tests/test_argument_utils.py | 350 ++++++++++++++++++++++++ 1 file changed, 350 insertions(+) create mode 100644 tests/unit_tests/test_argument_utils.py diff --git a/tests/unit_tests/test_argument_utils.py b/tests/unit_tests/test_argument_utils.py new file mode 100644 index 0000000000..55376189fc --- /dev/null +++ b/tests/unit_tests/test_argument_utils.py @@ -0,0 +1,350 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +import pytest +from argparse import ArgumentParser +from dataclasses import dataclass, field +from typing import Optional, Literal +from megatron.training.argument_utils import ArgumentGroupFactory + + +@dataclass +class DummyConfig: + """A dummy configuration for testing.""" + + name: str = "default_name" + """Name of the configuration""" + + count: int = 42 + """Number of items""" + + learning_rate: float = 0.001 + """Learning rate for training""" + + enabled: bool = False + """Whether feature is enabled""" + + disabled_feature: bool = True + """Feature that is disabled by default""" + + +@dataclass +class ConfigWithOptional: + """Config with optional fields.""" + + required_field: str = "required" + """A required field""" + + optional_field: Optional[int] = None + """An optional integer field""" + + optional_str: Optional[str] = "default" + """An optional string with default""" + + +@dataclass +class ConfigWithList: + """Config with list fields.""" + + tags: list[str] = field(default_factory=list) + """List of tags""" + + numbers: list[int] = field(default_factory=lambda: [1, 2, 3]) + """List of numbers with default""" + + +@dataclass +class ConfigWithLiteral: + """Config with Literal types.""" + + mode: Literal["train", "eval", "test"] = "train" + """Operating mode""" + + precision: Literal[16, 32] = 32 + 
"""Precision level""" + + +class TestArgumentGroupFactoryBasic: + """Test basic functionality of ArgumentGroupFactory.""" + + def test_creates_argument_group(self): + """Test that build_group creates an argument group.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig) + + arg_group = factory.build_group(parser, title="Test Group") + + assert arg_group is not None + assert arg_group.title == "Test Group" + assert arg_group.description == DummyConfig.__doc__ + + def test_all_fields_added(self): + """Test that all dataclass fields are added as arguments.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig) + + factory.build_group(parser, title="Test Group") + + # Parse empty args to get all defaults + args = parser.parse_args([]) + + # Check all fields exist + assert hasattr(args, 'name') + assert hasattr(args, 'count') + assert hasattr(args, 'learning_rate') + assert hasattr(args, 'enabled') + assert hasattr(args, 'disabled_feature') + + def test_default_values_preserved(self): + """Test that default values from dataclass are preserved.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig) + + factory.build_group(parser, title="Test Group") + args = parser.parse_args([]) + + assert args.name == "default_name" + assert args.count == 42 + assert args.learning_rate == 0.001 + assert args.enabled == False + assert args.disabled_feature == True + + def test_argument_types(self): + """Test that argument types are correctly inferred.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig) + + factory.build_group(parser, title="Test Group") + + # Parse with actual values + args = parser.parse_args([ + '--name', 'test_name', + '--count', '100', + '--learning-rate', '0.01' + ]) + + assert isinstance(args.name, str) + assert args.name == 'test_name' + assert isinstance(args.count, int) + assert args.count == 100 + assert isinstance(args.learning_rate, float) + assert 
args.learning_rate == 0.01 + + def test_boolean_store_true(self): + """Test that boolean fields with default False use store_true.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig) + + factory.build_group(parser, title="Test Group") + + # Without flag, should be False + args = parser.parse_args([]) + assert args.enabled == False + + # With flag, should be True + args = parser.parse_args(['--enabled']) + assert args.enabled == True + + def test_boolean_store_false(self): + """Test that boolean fields with default True use store_false with no- prefix.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig) + + factory.build_group(parser, title="Test Group") + + # Without flag, should be True + args = parser.parse_args([]) + assert args.disabled_feature == True + + # With --no- flag, should be False + args = parser.parse_args(['--no-disabled-feature']) + assert args.disabled_feature == False + + # With --disable- flag, should also be False + args = parser.parse_args(['--disable-disabled-feature']) + assert args.disabled_feature == False + + def test_field_docstrings_as_help(self): + """Test that field docstrings are extracted and used as help text.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig) + + # Check that field_docstrings were extracted + assert 'name' in factory.field_docstrings + assert factory.field_docstrings['name'] == "Name of the configuration" + assert factory.field_docstrings['count'] == "Number of items" + assert factory.field_docstrings['learning_rate'] == "Learning rate for training" + + +class TestArgumentGroupFactoryExclusion: + """Test exclusion functionality.""" + + def test_exclude_single_field(self): + """Test excluding a single field.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig, exclude=['count']) + + factory.build_group(parser, title="Test Group") + args = parser.parse_args([]) + + # Excluded field should not exist + assert 
hasattr(args, 'name') + assert not hasattr(args, 'count') + assert hasattr(args, 'learning_rate') + + def test_exclude_multiple_fields(self): + """Test excluding multiple fields.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(DummyConfig, exclude=['count', 'learning_rate']) + + factory.build_group(parser, title="Test Group") + args = parser.parse_args([]) + + assert hasattr(args, 'name') + assert not hasattr(args, 'count') + assert not hasattr(args, 'learning_rate') + assert hasattr(args, 'enabled') + + +class TestArgumentGroupFactoryOptional: + """Test handling of Optional types.""" + + def test_optional_fields(self): + """Test that Optional fields are handled correctly.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithOptional) + + factory.build_group(parser, title="Test Group") + + # Default values + args = parser.parse_args([]) + assert args.required_field == "required" + assert args.optional_field is None + assert args.optional_str == "default" + + # Provided values + args = parser.parse_args([ + '--required-field', 'new_value', + '--optional-field', '123', + '--optional-str', 'custom' + ]) + assert args.required_field == "new_value" + assert args.optional_field == 123 + assert args.optional_str == "custom" + + +class TestArgumentGroupFactoryList: + """Test handling of list types.""" + + def test_list_fields_with_default_factory(self): + """Test that list fields use nargs='+'.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithList) + + factory.build_group(parser, title="Test Group") + + # Default values + args = parser.parse_args([]) + assert args.tags == [] + assert args.numbers == [1, 2, 3] + + # Provided values + args = parser.parse_args([ + '--tags', 'tag1', 'tag2', 'tag3', + '--numbers', '10', '20', '30' + ]) + assert args.tags == ['tag1', 'tag2', 'tag3'] + assert args.numbers == [10, 20, 30] + + +class TestArgumentGroupFactoryLiteral: + """Test handling of Literal types.""" + + def 
test_literal_fields_have_choices(self): + """Test that Literal types create choice constraints.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithLiteral) + + factory.build_group(parser, title="Test Group") + + # Default values + args = parser.parse_args([]) + assert args.mode == "train" + assert args.precision == 32 + + # Valid choices + args = parser.parse_args(['--mode', 'eval', '--precision', '16']) + assert args.mode == "eval" + assert args.precision == 16 + + def test_literal_fields_reject_invalid_choices(self): + """Test that invalid Literal choices are rejected.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithLiteral) + + factory.build_group(parser, title="Test Group") + + # Invalid choice should raise error + with pytest.raises(SystemExit): + parser.parse_args(['--mode', 'invalid']) + + with pytest.raises(SystemExit): + parser.parse_args(['--precision', '64']) + + +class TestArgumentGroupFactoryHelpers: + """Test helper methods.""" + + def test_format_arg_name_basic(self): + """Test basic argument name formatting.""" + factory = ArgumentGroupFactory(DummyConfig) + + assert factory._format_arg_name("simple") == "--simple" + assert factory._format_arg_name("with_underscore") == "--with-underscore" + assert factory._format_arg_name("multiple_under_scores") == "--multiple-under-scores" + + def test_format_arg_name_with_prefix(self): + """Test argument name formatting with prefix.""" + factory = ArgumentGroupFactory(DummyConfig) + + assert factory._format_arg_name("feature", prefix="no") == "--no-feature" + assert factory._format_arg_name("feature", prefix="disable") == "--disable-feature" + assert factory._format_arg_name("multi_word", prefix="no") == "--no-multi-word" + + def test_extract_type_primitive(self): + """Test type extraction for primitive types.""" + factory = ArgumentGroupFactory(DummyConfig) + + assert factory._extract_type(int) == {"type": int} + assert factory._extract_type(str) == {"type": str} 
+ assert factory._extract_type(float) == {"type": float} + + def test_extract_type_optional(self): + """Test type extraction for Optional types.""" + factory = ArgumentGroupFactory(DummyConfig) + + result = factory._extract_type(Optional[int]) + assert result == {"type": int} + + result = factory._extract_type(Optional[str]) + assert result == {"type": str} + + def test_extract_type_list(self): + """Test type extraction for list types.""" + factory = ArgumentGroupFactory(DummyConfig) + + result = factory._extract_type(list[int]) + assert result == {"type": int, "nargs": "+"} + + result = factory._extract_type(list[str]) + assert result == {"type": str, "nargs": "+"} + + def test_extract_type_literal(self): + """Test type extraction for Literal types.""" + factory = ArgumentGroupFactory(DummyConfig) + + result = factory._extract_type(Literal["a", "b", "c"]) + assert result == {"type": str, "choices": ("a", "b", "c")} + + result = factory._extract_type(Literal[1, 2, 3]) + assert result == {"type": int, "choices": (1, 2, 3)} + From e6d185a08e9c3e355bc24ac0a97e5be6bba118f7 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Mon, 17 Nov 2025 13:22:06 -0800 Subject: [PATCH 19/27] defer to metadata if present on type check failure Signed-off-by: Maanu Grover --- megatron/training/argument_utils.py | 54 ++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/megatron/training/argument_utils.py b/megatron/training/argument_utils.py index fa15914ab0..7a163e6f59 100644 --- a/megatron/training/argument_utils.py +++ b/megatron/training/argument_utils.py @@ -11,7 +11,10 @@ from dataclasses import Field, fields # TODO: support arg renames -# TODO: if metadata handles types, ignore exceptions from _extract_type() + +class TypeInferenceError(Exception): + """Custom exception type to be conditionally handled by ArgumentGroupFactory.""" + pass class ArgumentGroupFactory: """Utility that adds an argument group to an ArgumentParser based on the 
attributes of a dataclass. @@ -67,7 +70,7 @@ def _extract_type(self, config_type: type) -> dict[str, Any]: if type_tuple[1] == type(None): # Optional type. First element is value inside Optional[] return self._extract_type(type_tuple[0]) else: - raise TypeError(f"Unions not supported by argparse: {config_type}") + raise TypeInferenceError(f"Unions not supported by argparse: {config_type}") elif origin is list: if len(type_tuple) == 1: @@ -75,7 +78,7 @@ def _extract_type(self, config_type: type) -> dict[str, Any]: kwargs["nargs"] = "+" return kwargs else: - raise TypeError(f"Multi-type lists not supported by argparse: {config_type}") + raise TypeInferenceError(f"Multi-type lists not supported by argparse: {config_type}") elif origin is typing.Literal: choices_types = [type(choice) for choice in type_tuple] @@ -83,7 +86,7 @@ def _extract_type(self, config_type: type) -> dict[str, Any]: kwargs = {"type": choices_types[0], "choices": type_tuple} return kwargs else: - raise TypeError(f"Unsupported type: {config_type}") + raise TypeInferenceError(f"Unsupported type: {config_type}") def _build_argparse_kwargs_from_field(self, attribute: Field) -> dict[str, Any]: @@ -104,20 +107,39 @@ def _build_argparse_kwargs_from_field(self, attribute: Field) -> dict[str, Any]: else: argparse_kwargs["default"] = attribute.default - argparse_kwargs.update(self._extract_type(attribute.type)) - - # use store_true or store_false action for enable/disable flags, which doesn't accept a 'type' - if argparse_kwargs["type"] == bool: - argparse_kwargs["action"] = "store_true" if attribute.default == False else "store_false" - argparse_kwargs.pop("type") - - # add '--no-*' and '--disable-*' prefix if this is a store_false argument - if argparse_kwargs["action"] == "store_false": - argparse_kwargs["arg_names"] = [self._format_arg_name(attribute.name, prefix="no"), self._format_arg_name(attribute.name, prefix="disable")] + attr_argparse_meta = None + if attribute.metadata != {} and "argparse_meta" 
in attribute.metadata: + # save metadata here, but update at the end so the metadata has highest precedence + attr_argparse_meta = attribute.metadata["argparse_meta"] + + + # if we cannot infer the argparse type, all of this logic may fail. we try to defer + # to the developer-specified metadata if present + try: + argparse_kwargs.update(self._extract_type(attribute.type)) + + # use store_true or store_false action for enable/disable flags, which doesn't accept a 'type' + if argparse_kwargs["type"] == bool: + argparse_kwargs["action"] = "store_true" if attribute.default == False else "store_false" + argparse_kwargs.pop("type") + + # add '--no-*' and '--disable-*' prefix if this is a store_false argument + if argparse_kwargs["action"] == "store_false": + argparse_kwargs["arg_names"] = [self._format_arg_name(attribute.name, prefix="no"), self._format_arg_name(attribute.name, prefix="disable")] + except TypeInferenceError as e: + if attr_argparse_meta is not None: + print( + f"WARNING: Inferring the appropriate argparse argument type from {self.src_cfg_class} " + f"failed for {attribute.name}: {attribute.type}.\n" + "Deferring to attribute metadata. 
If the metadata is incomplete, 'parser.add_argument()' may fail.\n" + f"Original failure: {e}" + ) + else: + raise e # metadata provided by field takes precedence - if attribute.metadata != {} and "argparse_meta" in attribute.metadata: - argparse_kwargs.update(attribute.metadata["argparse_meta"]) + if attr_argparse_meta is not None: + argparse_kwargs.update(attr_argparse_meta) return argparse_kwargs From ed96bfe1ec28a2dc7a12eb4d303c7c7a1005b4f0 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Mon, 17 Nov 2025 15:38:30 -0800 Subject: [PATCH 20/27] revert name changes to val config Signed-off-by: Maanu Grover --- megatron/training/config.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/megatron/training/config.py b/megatron/training/config.py index 92454b5c40..5495b6dd0c 100644 --- a/megatron/training/config.py +++ b/megatron/training/config.py @@ -91,10 +91,11 @@ class TrainingConfig: class ValidationConfig: """Configuration settings related to validation during or after model training.""" - val_iters: Optional[int] = field(default=100, metadata={"argparse_meta": {"arg_names": ["--eval-iters", "--val-iters"], "dest": "eval_iters"}}) - """Number of iterations to run validation/test for.""" + eval_iters: Optional[int] = 100 + """Number of iterations to run for evaluation. Used for both validation and test. If not set, + evaluation will not run.""" - val_interval: Optional[int] = field(default=None, metadata={"argparse_meta": {"arg_names": ["--eval-interval", "--val-interval"], "dest": "eval_interval"}}) + eval_interval: Optional[int] = None """Interval between running evaluation on validation set. If not set, evaluation will not run during training. 
""" From 5eb7a83501d5bac3a36b2dddc2ba26455f9ce772 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Mon, 17 Nov 2025 16:55:40 -0800 Subject: [PATCH 21/27] more unit test coverage Signed-off-by: Maanu Grover --- tests/unit_tests/test_argument_utils.py | 239 +++++++++++++++++++++++- 1 file changed, 237 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/test_argument_utils.py b/tests/unit_tests/test_argument_utils.py index 55376189fc..e9b6270d25 100644 --- a/tests/unit_tests/test_argument_utils.py +++ b/tests/unit_tests/test_argument_utils.py @@ -3,8 +3,8 @@ import pytest from argparse import ArgumentParser from dataclasses import dataclass, field -from typing import Optional, Literal -from megatron.training.argument_utils import ArgumentGroupFactory +from typing import Callable, Optional, Literal +from megatron.training.argument_utils import ArgumentGroupFactory, TypeInferenceError @dataclass @@ -348,3 +348,238 @@ def test_extract_type_literal(self): result = factory._extract_type(Literal[1, 2, 3]) assert result == {"type": int, "choices": (1, 2, 3)} + +@dataclass +class ConfigWithArgparseMeta: + """Config with argparse_meta metadata for testing overrides.""" + + custom_help: str = field( + default="default_value", + metadata={"argparse_meta": {"help": "Custom help text from metadata"}} + ) + """Original help text""" + + custom_type: str = field( + default="100", + metadata={"argparse_meta": {"type": int}} + ) + """Field with type override""" + + custom_default: str = field( + default="original_default", + metadata={"argparse_meta": {"default": "overridden_default"}} + ) + """Field with default override""" + + custom_choices: str = field( + default="option1", + metadata={"argparse_meta": {"choices": ["option1", "option2", "option3"]}} + ) + """Field with choices override""" + + custom_dest: str = field( + default="value", + metadata={"argparse_meta": {"dest": "renamed_destination"}} + ) + """Field with dest override""" + + custom_action: bool = field( + 
default=False,
+        metadata={"argparse_meta": {"action": "store_const", "const": "special_value"}}
+    )
+    """Field with custom action override"""
+
+    multiple_overrides: int = field(
+        default=42,
+        metadata={"argparse_meta": {
+            "type": str,
+            "help": "Multiple overrides applied",
+            "default": "999",
+            "dest": "multi_override_dest"
+        }}
+    )
+    """Field with multiple metadata overrides"""
+
+    nargs_override: str = field(
+        default="single",
+        metadata={"argparse_meta": {"nargs": "?"}}
+    )
+    """Field with nargs override"""
+
+
+@dataclass
+class ConfigWithUnsupportedTypes:
+    """Config with field types argparse cannot infer, for testing the metadata fallback."""
+
+    unsupported_type: Optional[Callable] = None
+    """Cannot take a callable over CLI"""
+
+    unsupported_with_metadata: Optional[Callable] = field(default=None, metadata={"argparse_meta": {"type": int, "choices": (0, 1, 2)}})
+    """This argument should be 0, 1, or 2. The appropriate
+    Callable will be set by some other logic.
+    """
+
+
+class TestArgumentGroupFactoryArgparseMeta:
+    """Test argparse_meta metadata override functionality."""
+
+    def test_help_override(self):
+        """Test that argparse_meta can override help text."""
+        parser = ArgumentParser()
+        factory = ArgumentGroupFactory(ConfigWithArgparseMeta)
+
+        factory.build_group(parser, title="Test Group")
+
+        # Find the action for this argument
+        for action in parser._actions:
+            if hasattr(action, 'dest') and action.dest == 'custom_help':
+                assert action.help == "Custom help text from metadata"
+                return
+
+        pytest.fail("custom_help argument not found")
+
+    def test_type_override(self):
+        """Test that argparse_meta can override argument type."""
+        parser = ArgumentParser()
+        factory = ArgumentGroupFactory(ConfigWithArgparseMeta)
+
+        factory.build_group(parser, title="Test Group")
+
+        # Parse with integer value (metadata overrides type to int)
+        args = parser.parse_args(['--custom-type', '42'])
+
+        # Should be parsed as int, not str
+        assert isinstance(args.custom_type, int)
+        assert 
args.custom_type == 42 + + def test_default_override(self): + """Test that argparse_meta can override default value.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithArgparseMeta) + + factory.build_group(parser, title="Test Group") + + # Parse with no arguments + args = parser.parse_args([]) + + # Should use metadata default, not field default + assert args.custom_default == "overridden_default" + + def test_choices_override(self): + """Test that argparse_meta can override choices.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithArgparseMeta) + + factory.build_group(parser, title="Test Group") + + # Valid choice from metadata + args = parser.parse_args(['--custom-choices', 'option2']) + assert args.custom_choices == "option2" + + # Invalid choice should fail + with pytest.raises(SystemExit): + parser.parse_args(['--custom-choices', 'invalid_option']) + + def test_dest_override(self): + """Test that argparse_meta can override destination name.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithArgparseMeta) + + factory.build_group(parser, title="Test Group") + + args = parser.parse_args(['--custom-dest', 'test_value']) + + # Should be stored in renamed destination + assert hasattr(args, 'renamed_destination') + assert args.renamed_destination == "test_value" + + def test_action_override(self): + """Test that argparse_meta can override action.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithArgparseMeta) + + factory.build_group(parser, title="Test Group") + + # With custom action=store_const and const="special_value" + args = parser.parse_args(['--custom-action']) + assert args.custom_action == "special_value" + + # Without flag, should use default + args = parser.parse_args([]) + assert args.custom_action == False + + def test_multiple_overrides(self): + """Test that multiple argparse_meta overrides work together.""" + parser = ArgumentParser() + factory = 
ArgumentGroupFactory(ConfigWithArgparseMeta) + + factory.build_group(parser, title="Test Group") + + # Parse with no arguments to check default override + args = parser.parse_args([]) + + # Check all overrides applied + assert hasattr(args, 'multi_override_dest') + assert args.multi_override_dest == "999" # default override + + # Parse with value to check type override + args = parser.parse_args(['--multiple-overrides', 'text_value']) + assert isinstance(args.multi_override_dest, str) # type override + assert args.multi_override_dest == "text_value" + + # Check help override was applied + for action in parser._actions: + if hasattr(action, 'dest') and action.dest == 'multi_override_dest': + assert action.help == "Multiple overrides applied" + break + + def test_nargs_override(self): + """Test that argparse_meta can override nargs.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithArgparseMeta) + + factory.build_group(parser, title="Test Group") + + # With nargs='?', argument is optional + args = parser.parse_args(['--nargs-override']) + assert args.nargs_override is None # No value provided with '?' 
+ + # With value + args = parser.parse_args(['--nargs-override', 'provided_value']) + assert args.nargs_override == "provided_value" + + # Without flag at all, should use default + args = parser.parse_args([]) + assert args.nargs_override == "single" + + def test_metadata_takes_precedence_over_inference(self): + """Test that metadata has highest precedence over type inference.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithArgparseMeta) + + # Build kwargs for custom_type field which is str but metadata says int + from dataclasses import fields as dc_fields + for f in dc_fields(ConfigWithArgparseMeta): + if f.name == 'custom_type': + kwargs = factory._build_argparse_kwargs_from_field(f) + # Metadata type should override inferred type + assert kwargs['type'] == int + break + + def test_unhandled_unsupported_type(self): + """Test that an unsupported type produces a TypInferenceError.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithUnsupportedTypes, exclude=["unsupported_with_metadata"]) + + with pytest.raises(TypeInferenceError, match="Unsupported type"): + factory.build_group(parser, title="Test Group") + + def test_handled_unsupported_type(self): + """Test an attribute with an unsupported type that has type info in the metadata.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory(ConfigWithUnsupportedTypes, exclude=["unsupported_type"]) + + factory.build_group(parser, title="Test Group") + + args = parser.parse_args(['--unsupported-with-metadata', '0']) + assert args.unsupported_with_metadata == 0 From 7fcff86a7a83d842fd623c34f22983eacdc34042 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Wed, 19 Nov 2025 12:15:08 -0600 Subject: [PATCH 22/27] formatting Signed-off-by: Maanu Grover --- tests/unit_tests/test_argument_utils.py | 314 ++++++++++++------------ 1 file changed, 154 insertions(+), 160 deletions(-) diff --git a/tests/unit_tests/test_argument_utils.py b/tests/unit_tests/test_argument_utils.py 
index e9b6270d25..eedb7abaec 100644 --- a/tests/unit_tests/test_argument_utils.py +++ b/tests/unit_tests/test_argument_utils.py @@ -1,28 +1,30 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -import pytest from argparse import ArgumentParser from dataclasses import dataclass, field -from typing import Callable, Optional, Literal +from typing import Callable, Literal, Optional + +import pytest + from megatron.training.argument_utils import ArgumentGroupFactory, TypeInferenceError @dataclass class DummyConfig: """A dummy configuration for testing.""" - + name: str = "default_name" """Name of the configuration""" - + count: int = 42 """Number of items""" - + learning_rate: float = 0.001 """Learning rate for training""" - + enabled: bool = False """Whether feature is enabled""" - + disabled_feature: bool = True """Feature that is disabled by default""" @@ -30,13 +32,13 @@ class DummyConfig: @dataclass class ConfigWithOptional: """Config with optional fields.""" - + required_field: str = "required" """A required field""" - + optional_field: Optional[int] = None """An optional integer field""" - + optional_str: Optional[str] = "default" """An optional string with default""" @@ -44,10 +46,10 @@ class ConfigWithOptional: @dataclass class ConfigWithList: """Config with list fields.""" - + tags: list[str] = field(default_factory=list) """List of tags""" - + numbers: list[int] = field(default_factory=lambda: [1, 2, 3]) """List of numbers with default""" @@ -55,119 +57,117 @@ class ConfigWithList: @dataclass class ConfigWithLiteral: """Config with Literal types.""" - + mode: Literal["train", "eval", "test"] = "train" """Operating mode""" - + precision: Literal[16, 32] = 32 """Precision level""" class TestArgumentGroupFactoryBasic: """Test basic functionality of ArgumentGroupFactory.""" - + def test_creates_argument_group(self): """Test that build_group creates an argument group.""" parser = ArgumentParser() factory = ArgumentGroupFactory(DummyConfig) - + 
arg_group = factory.build_group(parser, title="Test Group") - + assert arg_group is not None assert arg_group.title == "Test Group" assert arg_group.description == DummyConfig.__doc__ - + def test_all_fields_added(self): """Test that all dataclass fields are added as arguments.""" parser = ArgumentParser() factory = ArgumentGroupFactory(DummyConfig) - + factory.build_group(parser, title="Test Group") - + # Parse empty args to get all defaults args = parser.parse_args([]) - + # Check all fields exist assert hasattr(args, 'name') assert hasattr(args, 'count') assert hasattr(args, 'learning_rate') assert hasattr(args, 'enabled') assert hasattr(args, 'disabled_feature') - + def test_default_values_preserved(self): """Test that default values from dataclass are preserved.""" parser = ArgumentParser() factory = ArgumentGroupFactory(DummyConfig) - + factory.build_group(parser, title="Test Group") args = parser.parse_args([]) - + assert args.name == "default_name" assert args.count == 42 assert args.learning_rate == 0.001 assert args.enabled == False assert args.disabled_feature == True - + def test_argument_types(self): """Test that argument types are correctly inferred.""" parser = ArgumentParser() factory = ArgumentGroupFactory(DummyConfig) - + factory.build_group(parser, title="Test Group") - + # Parse with actual values - args = parser.parse_args([ - '--name', 'test_name', - '--count', '100', - '--learning-rate', '0.01' - ]) - + args = parser.parse_args( + ['--name', 'test_name', '--count', '100', '--learning-rate', '0.01'] + ) + assert isinstance(args.name, str) assert args.name == 'test_name' assert isinstance(args.count, int) assert args.count == 100 assert isinstance(args.learning_rate, float) assert args.learning_rate == 0.01 - + def test_boolean_store_true(self): """Test that boolean fields with default False use store_true.""" parser = ArgumentParser() factory = ArgumentGroupFactory(DummyConfig) - + factory.build_group(parser, title="Test Group") - + # Without 
flag, should be False args = parser.parse_args([]) assert args.enabled == False - + # With flag, should be True args = parser.parse_args(['--enabled']) assert args.enabled == True - + def test_boolean_store_false(self): """Test that boolean fields with default True use store_false with no- prefix.""" parser = ArgumentParser() factory = ArgumentGroupFactory(DummyConfig) - + factory.build_group(parser, title="Test Group") - + # Without flag, should be True args = parser.parse_args([]) assert args.disabled_feature == True - + # With --no- flag, should be False args = parser.parse_args(['--no-disabled-feature']) assert args.disabled_feature == False - + # With --disable- flag, should also be False args = parser.parse_args(['--disable-disabled-feature']) assert args.disabled_feature == False - + def test_field_docstrings_as_help(self): """Test that field docstrings are extracted and used as help text.""" parser = ArgumentParser() factory = ArgumentGroupFactory(DummyConfig) - + # Check that field_docstrings were extracted assert 'name' in factory.field_docstrings assert factory.field_docstrings['name'] == "Name of the configuration" @@ -177,56 +177,54 @@ def test_field_docstrings_as_help(self): class TestArgumentGroupFactoryExclusion: """Test exclusion functionality.""" - + def test_exclude_single_field(self): """Test excluding a single field.""" parser = ArgumentParser() factory = ArgumentGroupFactory(DummyConfig, exclude=['count']) - + factory.build_group(parser, title="Test Group") args = parser.parse_args([]) - + # Excluded field should not exist assert hasattr(args, 'name') assert not hasattr(args, 'count') assert hasattr(args, 'learning_rate') - + def test_exclude_multiple_fields(self): """Test excluding multiple fields.""" parser = ArgumentParser() factory = ArgumentGroupFactory(DummyConfig, exclude=['count', 'learning_rate']) - + factory.build_group(parser, title="Test Group") args = parser.parse_args([]) - + assert hasattr(args, 'name') assert not hasattr(args, 
'count') assert not hasattr(args, 'learning_rate') assert hasattr(args, 'enabled') - + class TestArgumentGroupFactoryOptional: """Test handling of Optional types.""" - + def test_optional_fields(self): """Test that Optional fields are handled correctly.""" parser = ArgumentParser() factory = ArgumentGroupFactory(ConfigWithOptional) - + factory.build_group(parser, title="Test Group") - + # Default values args = parser.parse_args([]) assert args.required_field == "required" assert args.optional_field is None assert args.optional_str == "default" - + # Provided values - args = parser.parse_args([ - '--required-field', 'new_value', - '--optional-field', '123', - '--optional-str', 'custom' - ]) + args = parser.parse_args( + ['--required-field', 'new_value', '--optional-field', '123', '--optional-str', 'custom'] + ) assert args.required_field == "new_value" assert args.optional_field == 123 assert args.optional_str == "custom" @@ -234,117 +232,114 @@ def test_optional_fields(self): class TestArgumentGroupFactoryList: """Test handling of list types.""" - + def test_list_fields_with_default_factory(self): """Test that list fields use nargs='+'.""" parser = ArgumentParser() factory = ArgumentGroupFactory(ConfigWithList) - + factory.build_group(parser, title="Test Group") - + # Default values args = parser.parse_args([]) assert args.tags == [] assert args.numbers == [1, 2, 3] - + # Provided values - args = parser.parse_args([ - '--tags', 'tag1', 'tag2', 'tag3', - '--numbers', '10', '20', '30' - ]) + args = parser.parse_args(['--tags', 'tag1', 'tag2', 'tag3', '--numbers', '10', '20', '30']) assert args.tags == ['tag1', 'tag2', 'tag3'] assert args.numbers == [10, 20, 30] class TestArgumentGroupFactoryLiteral: """Test handling of Literal types.""" - + def test_literal_fields_have_choices(self): """Test that Literal types create choice constraints.""" parser = ArgumentParser() factory = ArgumentGroupFactory(ConfigWithLiteral) - + factory.build_group(parser, title="Test Group") - 
+ # Default values args = parser.parse_args([]) assert args.mode == "train" assert args.precision == 32 - + # Valid choices args = parser.parse_args(['--mode', 'eval', '--precision', '16']) assert args.mode == "eval" assert args.precision == 16 - + def test_literal_fields_reject_invalid_choices(self): """Test that invalid Literal choices are rejected.""" parser = ArgumentParser() factory = ArgumentGroupFactory(ConfigWithLiteral) - + factory.build_group(parser, title="Test Group") - + # Invalid choice should raise error with pytest.raises(SystemExit): parser.parse_args(['--mode', 'invalid']) - + with pytest.raises(SystemExit): parser.parse_args(['--precision', '64']) class TestArgumentGroupFactoryHelpers: """Test helper methods.""" - + def test_format_arg_name_basic(self): """Test basic argument name formatting.""" factory = ArgumentGroupFactory(DummyConfig) - + assert factory._format_arg_name("simple") == "--simple" assert factory._format_arg_name("with_underscore") == "--with-underscore" assert factory._format_arg_name("multiple_under_scores") == "--multiple-under-scores" - + def test_format_arg_name_with_prefix(self): """Test argument name formatting with prefix.""" factory = ArgumentGroupFactory(DummyConfig) - + assert factory._format_arg_name("feature", prefix="no") == "--no-feature" assert factory._format_arg_name("feature", prefix="disable") == "--disable-feature" assert factory._format_arg_name("multi_word", prefix="no") == "--no-multi-word" - + def test_extract_type_primitive(self): """Test type extraction for primitive types.""" factory = ArgumentGroupFactory(DummyConfig) - + assert factory._extract_type(int) == {"type": int} assert factory._extract_type(str) == {"type": str} assert factory._extract_type(float) == {"type": float} - + def test_extract_type_optional(self): """Test type extraction for Optional types.""" factory = ArgumentGroupFactory(DummyConfig) - + result = factory._extract_type(Optional[int]) assert result == {"type": int} - + result = 
factory._extract_type(Optional[str]) assert result == {"type": str} - + def test_extract_type_list(self): """Test type extraction for list types.""" factory = ArgumentGroupFactory(DummyConfig) - + result = factory._extract_type(list[int]) assert result == {"type": int, "nargs": "+"} - + result = factory._extract_type(list[str]) assert result == {"type": str, "nargs": "+"} - + def test_extract_type_literal(self): """Test type extraction for Literal types.""" factory = ArgumentGroupFactory(DummyConfig) - + result = factory._extract_type(Literal["a", "b", "c"]) assert result == {"type": str, "choices": ("a", "b", "c")} - + result = factory._extract_type(Literal[1, 2, 3]) assert result == {"type": int, "choices": (1, 2, 3)} @@ -352,69 +347,65 @@ def test_extract_type_literal(self): @dataclass class ConfigWithArgparseMeta: """Config with argparse_meta metadata for testing overrides.""" - + custom_help: str = field( default="default_value", - metadata={"argparse_meta": {"help": "Custom help text from metadata"}} + metadata={"argparse_meta": {"help": "Custom help text from metadata"}}, ) """Original help text""" - - custom_type: str = field( - default="100", - metadata={"argparse_meta": {"type": int}} - ) + + custom_type: str = field(default="100", metadata={"argparse_meta": {"type": int}}) """Field with type override""" - + custom_default: str = field( - default="original_default", - metadata={"argparse_meta": {"default": "overridden_default"}} + default="original_default", metadata={"argparse_meta": {"default": "overridden_default"}} ) """Field with default override""" - + custom_choices: str = field( default="option1", - metadata={"argparse_meta": {"choices": ["option1", "option2", "option3"]}} + metadata={"argparse_meta": {"choices": ["option1", "option2", "option3"]}}, ) """Field with choices override""" - + custom_dest: str = field( - default="value", - metadata={"argparse_meta": {"dest": "renamed_destination"}} + default="value", metadata={"argparse_meta": {"dest": 
"renamed_destination"}} ) """Field with dest override""" - + custom_action: bool = field( default=False, - metadata={"argparse_meta": {"action": "store_const", "const": "special_value"}} + metadata={"argparse_meta": {"action": "store_const", "const": "special_value"}}, ) """Field with custom action override""" - + multiple_overrides: int = field( default=42, - metadata={"argparse_meta": { - "type": str, - "help": "Multiple overrides applied", - "default": "999", - "dest": "multi_override_dest" - }} + metadata={ + "argparse_meta": { + "type": str, + "help": "Multiple overrides applied", + "default": "999", + "dest": "multi_override_dest", + } + }, ) """Field with multiple metadata overrides""" - - nargs_override: str = field( - default="single", - metadata={"argparse_meta": {"nargs": "?"}} - ) + + nargs_override: str = field(default="single", metadata={"argparse_meta": {"nargs": "?"}}) """Field with nargs override""" @dataclass class ConfigWithUnsupportedTypes: """Config with argparse_meta metadata for testing overrides.""" - + unsupported_type: Optional[Callable] = None """Cannot take a callable over CLI""" - unsupported_with_metadata: Optional[Callable] = field(default=None, metadata={"argparse_meta": {"type": int, "choices": (0, 1, 2)}}) + unsupported_with_metadata: Optional[Callable] = field( + default=None, metadata={"argparse_meta": {"type": int, "choices": (0, 1, 2)}} + ) """This argument should be 0, 1, or 2. The appropriate Callable will be set by some other logic. 
""" @@ -422,143 +413,144 @@ class ConfigWithUnsupportedTypes: class TestArgumentGroupFactoryArgparseMeta: """Test argparse_meta metadata override functionality.""" - + def test_help_override(self): """Test that argparse_meta can override help text.""" parser = ArgumentParser() factory = ArgumentGroupFactory(ConfigWithArgparseMeta) - + factory.build_group(parser, title="Test Group") - + # Find the action for this argument for action in parser._actions: if hasattr(action, 'dest') and action.dest == 'custom_help': assert action.help == "Custom help text from metadata" return - + pytest.fail("custom_help argument not found") - + def test_type_override(self): """Test that argparse_meta can override argument type.""" parser = ArgumentParser() factory = ArgumentGroupFactory(ConfigWithArgparseMeta) - + factory.build_group(parser, title="Test Group") - + # Parse with integer value (metadata overrides type to int) args = parser.parse_args(['--custom-type', '42']) - + # Should be parsed as int, not str assert isinstance(args.custom_type, int) assert args.custom_type == 42 - + def test_default_override(self): """Test that argparse_meta can override default value.""" parser = ArgumentParser() factory = ArgumentGroupFactory(ConfigWithArgparseMeta) - + factory.build_group(parser, title="Test Group") - + # Parse with no arguments args = parser.parse_args([]) - + # Should use metadata default, not field default assert args.custom_default == "overridden_default" - + def test_choices_override(self): """Test that argparse_meta can override choices.""" parser = ArgumentParser() factory = ArgumentGroupFactory(ConfigWithArgparseMeta) - + factory.build_group(parser, title="Test Group") - + # Valid choice from metadata args = parser.parse_args(['--custom-choices', 'option2']) assert args.custom_choices == "option2" - + # Invalid choice should fail with pytest.raises(SystemExit): parser.parse_args(['--custom-choices', 'invalid_option']) - + def test_dest_override(self): """Test that 
argparse_meta can override destination name.""" parser = ArgumentParser() factory = ArgumentGroupFactory(ConfigWithArgparseMeta) - + factory.build_group(parser, title="Test Group") - + args = parser.parse_args(['--custom-dest', 'test_value']) - + # Should be stored in renamed destination assert hasattr(args, 'renamed_destination') assert args.renamed_destination == "test_value" - + def test_action_override(self): """Test that argparse_meta can override action.""" parser = ArgumentParser() factory = ArgumentGroupFactory(ConfigWithArgparseMeta) - + factory.build_group(parser, title="Test Group") - + # With custom action=store_const and const="special_value" args = parser.parse_args(['--custom-action']) assert args.custom_action == "special_value" - + # Without flag, should use default args = parser.parse_args([]) assert args.custom_action == False - + def test_multiple_overrides(self): """Test that multiple argparse_meta overrides work together.""" parser = ArgumentParser() factory = ArgumentGroupFactory(ConfigWithArgparseMeta) - + factory.build_group(parser, title="Test Group") - + # Parse with no arguments to check default override args = parser.parse_args([]) - + # Check all overrides applied assert hasattr(args, 'multi_override_dest') assert args.multi_override_dest == "999" # default override - + # Parse with value to check type override args = parser.parse_args(['--multiple-overrides', 'text_value']) assert isinstance(args.multi_override_dest, str) # type override assert args.multi_override_dest == "text_value" - + # Check help override was applied for action in parser._actions: if hasattr(action, 'dest') and action.dest == 'multi_override_dest': assert action.help == "Multiple overrides applied" break - + def test_nargs_override(self): """Test that argparse_meta can override nargs.""" parser = ArgumentParser() factory = ArgumentGroupFactory(ConfigWithArgparseMeta) - + factory.build_group(parser, title="Test Group") - + # With nargs='?', argument is optional 
args = parser.parse_args(['--nargs-override']) assert args.nargs_override is None # No value provided with '?' - + # With value args = parser.parse_args(['--nargs-override', 'provided_value']) assert args.nargs_override == "provided_value" - + # Without flag at all, should use default args = parser.parse_args([]) assert args.nargs_override == "single" - + def test_metadata_takes_precedence_over_inference(self): """Test that metadata has highest precedence over type inference.""" parser = ArgumentParser() factory = ArgumentGroupFactory(ConfigWithArgparseMeta) - + # Build kwargs for custom_type field which is str but metadata says int from dataclasses import fields as dc_fields + for f in dc_fields(ConfigWithArgparseMeta): if f.name == 'custom_type': kwargs = factory._build_argparse_kwargs_from_field(f) @@ -569,7 +561,9 @@ def test_metadata_takes_precedence_over_inference(self): def test_unhandled_unsupported_type(self): """Test that an unsupported type produces a TypInferenceError.""" parser = ArgumentParser() - factory = ArgumentGroupFactory(ConfigWithUnsupportedTypes, exclude=["unsupported_with_metadata"]) + factory = ArgumentGroupFactory( + ConfigWithUnsupportedTypes, exclude=["unsupported_with_metadata"] + ) with pytest.raises(TypeInferenceError, match="Unsupported type"): factory.build_group(parser, title="Test Group") From 7f8792ebbdc905636163f2e5abccf17401fcb8f0 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Wed, 19 Nov 2025 12:19:31 -0600 Subject: [PATCH 23/27] fix recursive call Signed-off-by: Maanu Grover --- megatron/training/argument_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/training/argument_utils.py b/megatron/training/argument_utils.py index 7a163e6f59..2b6922204d 100644 --- a/megatron/training/argument_utils.py +++ b/megatron/training/argument_utils.py @@ -201,6 +201,6 @@ def _get_field_docstrings(self, src_cfg_class: type) -> dict[str, str]: if len(base_classes) > 0: parent_class = base_classes[0] if 
parent_class.__name__ not in builtins.__dict__: - field_docstrings.update(get_field_docstrings(base_classes[0])) + field_docstrings.update(self._get_field_docstrings(base_classes[0])) return field_docstrings From 5fe366f0408bd79fa54a89eb7594b810722c1f54 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Thu, 20 Nov 2025 14:13:04 -0600 Subject: [PATCH 24/27] fix serializability error Signed-off-by: Maanu Grover --- megatron/training/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/training/config.py b/megatron/training/config.py index 5495b6dd0c..db0909b840 100644 --- a/megatron/training/config.py +++ b/megatron/training/config.py @@ -61,7 +61,7 @@ class TrainingConfig: exit_signal_handler: bool = False """Dynamically save the checkpoint and shutdown the training if SIGTERM is received""" - exit_signal: int = signal.SIGTERM + exit_signal: int = int(signal.SIGTERM) """Signal for the signal handler to detect.""" exit_signal_handler_for_dataloader: bool = False From 4664c50956efb282371ac50e54cab720ebf1e045 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Thu, 20 Nov 2025 14:19:17 -0600 Subject: [PATCH 25/27] interval default must be int Signed-off-by: Maanu Grover --- megatron/training/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/training/config.py b/megatron/training/config.py index db0909b840..08a03220ca 100644 --- a/megatron/training/config.py +++ b/megatron/training/config.py @@ -73,7 +73,7 @@ class TrainingConfig: which mitigates the impact of CPU-associated jitters. When the manual gc is enabled, garbage collection is performed only at the start and the end of the validation routine by default.""" - manual_gc_interval: Optional[int] = None + manual_gc_interval: int = 0 """Training step interval to trigger manual garbage collection. Values > 0 will trigger garbage collections between training steps. 
""" From a8dee1c431a396947afa64be5b18b12d38974aee Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Sat, 22 Nov 2025 20:26:11 -0600 Subject: [PATCH 26/27] update optional typehint syntax Signed-off-by: Maanu Grover --- megatron/training/config.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/megatron/training/config.py b/megatron/training/config.py index 08a03220ca..59834e7d3f 100644 --- a/megatron/training/config.py +++ b/megatron/training/config.py @@ -1,22 +1,22 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. from dataclasses import dataclass, field import signal -from typing import Optional, Literal +from typing import Literal @dataclass(kw_only=True) class TrainingConfig: """Configuration settings related to the training loop.""" - micro_batch_size: Optional[int] = None + micro_batch_size: int | None = None """Batch size per model instance (local batch size). Global batch size is local batch size times data parallel size times number of micro batches.""" - global_batch_size: Optional[int] = None + global_batch_size: int | None = None """Training batch size. If set, it should be a multiple of micro-batch-size times data-parallel-size. If this value is None, then use micro-batch-size * data-parallel-size as the global batch size. This choice will result in 1 for number of micro-batches.""" - rampup_batch_size: Optional[list[int]] = field(default=None, metadata={"argparse_meta": {"nargs": 3}}) + rampup_batch_size: list[int] | None = field(default=None, metadata={"argparse_meta": {"nargs": 3}}) """Batch size ramp up with the following values: , , For example: @@ -37,25 +37,25 @@ class TrainingConfig: 0=off, 1=moderate, 2=aggressive. """ - check_weight_hash_across_dp_replicas_interval: Optional[int] = None + check_weight_hash_across_dp_replicas_interval: int | None = None """Interval to check weight hashes are same across DP replicas. 
If not specified, weight hashes not checked.""" - train_sync_interval: Optional[int] = None + train_sync_interval: int | None = None """Training CPU-GPU synchronization interval, to ensure that CPU is not running too far ahead of GPU.""" - train_iters: Optional[int] = None + train_iters: int | None = None """Total number of iterations to train over all training runs. Note that either train_iters or train_samples should be provided. """ - train_samples: Optional[int] = None + train_samples: int | None = None """Total number of samples to train over all training runs. Note that either train_iters or train_samples should be provided.""" - exit_interval: Optional[int] = None + exit_interval: int | None = None """Exit the program after the iteration is divisible by this value.""" - exit_duration_in_mins: Optional[int] = None + exit_duration_in_mins: int | None = None """Exit the program after this many minutes.""" exit_signal_handler: bool = False @@ -91,11 +91,11 @@ class TrainingConfig: class ValidationConfig: """Configuration settings related to validation during or after model training.""" - eval_iters: Optional[int] = 100 + eval_iters: int | None = 100 """Number of iterations to run for evaluation. Used for both validation and test. If not set, evaluation will not run.""" - eval_interval: Optional[int] = None + eval_interval: int | None = None """Interval between running evaluation on validation set. If not set, evaluation will not run during training. 
""" From 50576d9f80c55865cd11c089e5e3cc07bf0fc48a Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Sat, 22 Nov 2025 21:01:50 -0600 Subject: [PATCH 27/27] add test case for unsupported union Signed-off-by: Maanu Grover --- tests/unit_tests/test_argument_utils.py | 50 +++++++++++++++++++++---- 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/tests/unit_tests/test_argument_utils.py b/tests/unit_tests/test_argument_utils.py index eedb7abaec..050f3a65d7 100644 --- a/tests/unit_tests/test_argument_utils.py +++ b/tests/unit_tests/test_argument_utils.py @@ -1,8 +1,8 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -from argparse import ArgumentParser +from argparse import ArgumentParser, ArgumentError from dataclasses import dataclass, field -from typing import Callable, Literal, Optional +from typing import Callable, Literal, Optional, Union import pytest @@ -397,7 +397,7 @@ class ConfigWithArgparseMeta: @dataclass -class ConfigWithUnsupportedTypes: +class ConfigWithUnsupportedCallables: """Config with argparse_meta metadata for testing overrides.""" unsupported_type: Optional[Callable] = None @@ -411,6 +411,19 @@ class ConfigWithUnsupportedTypes: """ +@dataclass +class ConfigWithUnsupportedUnions: + """Config with argparse_meta metadata for testing overrides.""" + + unsupported_type: Union[int, str] = 0 + """Cannot infer type of a Union""" + + unsupported_with_metadata: Union[int, str] = field( + default=0, metadata={"argparse_meta": {"type": str, "choices": ("foo", "bar")}} + ) + """Metadata should take precedence over the exception caused by Union""" + + class TestArgumentGroupFactoryArgparseMeta: """Test argparse_meta metadata override functionality.""" @@ -558,22 +571,45 @@ def test_metadata_takes_precedence_over_inference(self): assert kwargs['type'] == int break - def test_unhandled_unsupported_type(self): + def test_unhandled_unsupported_callables(self): """Test that an unsupported type produces a TypInferenceError.""" parser = 
ArgumentParser() factory = ArgumentGroupFactory( - ConfigWithUnsupportedTypes, exclude=["unsupported_with_metadata"] + ConfigWithUnsupportedCallables, exclude=["unsupported_with_metadata"] ) with pytest.raises(TypeInferenceError, match="Unsupported type"): factory.build_group(parser, title="Test Group") - def test_handled_unsupported_type(self): + def test_handled_unsupported_callables(self): """Test an attribute with an unsupported type that has type info in the metadata.""" parser = ArgumentParser() - factory = ArgumentGroupFactory(ConfigWithUnsupportedTypes, exclude=["unsupported_type"]) + factory = ArgumentGroupFactory(ConfigWithUnsupportedCallables, exclude=["unsupported_type"]) factory.build_group(parser, title="Test Group") args = parser.parse_args(['--unsupported-with-metadata', '0']) assert args.unsupported_with_metadata == 0 + + def test_unhandled_unsupported_unions(self): + """Test that an unsupported type produces a TypInferenceError.""" + parser = ArgumentParser() + factory = ArgumentGroupFactory( + ConfigWithUnsupportedUnions, exclude=["unsupported_with_metadata"] + ) + + with pytest.raises(TypeInferenceError, match="Unions not supported by argparse"): + factory.build_group(parser, title="Test Group") + + def test_handled_unsupported_unions(self): + """Test an attribute with an unsupported type that has type info in the metadata.""" + parser = ArgumentParser(exit_on_error=False) + factory = ArgumentGroupFactory(ConfigWithUnsupportedUnions, exclude=["unsupported_type"]) + + factory.build_group(parser, title="Test Group") + + args = parser.parse_args(['--unsupported-with-metadata', 'foo']) + assert args.unsupported_with_metadata == 'foo' + + with pytest.raises(ArgumentError, match="invalid choice"): + args = parser.parse_args(['--unsupported-with-metadata', 'baz'])