@@ -1372,11 +1372,6 @@ def _add_transformer_engine_args(parser):
13721372 help = 'Keep the compute param in fp4 (do not use any other intermediate '
13731373 'dtype) and perform the param all-gather in fp4.' ,
13741374 dest = 'fp4_param' )
1375- group .add_argument ('--te-rng-tracker' , action = 'store_true' , default = False ,
1376- help = 'Use the Transformer Engine version of the random number generator. '
1377- 'Required for CUDA graphs support.' )
1378- group .add_argument ('--inference-rng-tracker' , action = 'store_true' , default = False ,
1379- help = 'Use a random number generator configured for inference.' )
13801375 return parser
13811376
13821377def _add_inference_args (parser ):
@@ -2224,14 +2219,11 @@ def _add_rerun_machine_args(parser):
22242219
22252220
22262221def _add_initialization_args (parser ):
2227- group = parser .add_argument_group (title = 'initialization' )
2228-
2229- group .add_argument ('--seed' , type = int , default = 1234 ,
2230- help = 'Random seed used for python, numpy, '
2231- 'pytorch, and cuda.' )
2232- group .add_argument ('--data-parallel-random-init' , action = 'store_true' ,
2233- help = 'Enable random initialization of params '
2234- 'across data parallel ranks' )
2222+ from megatron .training .config import RNGConfig
2223+
2224+ rng_factory = ArgumentGroupFactory (RNGConfig )
2225+ group = rng_factory .build_group (parser , "RNG and initialization" )
2226+
22352227 group .add_argument ('--init-method-std' , type = float , default = 0.02 ,
22362228 help = 'Standard deviation of the zero mean normal '
22372229 'distribution used for weight initialization.' )
0 commit comments