Commit dc9a38d

[DEV] Add support of fake distributed process group (#2254)
1 parent 2782acf · commit dc9a38d

File tree

megatron/training/arguments.py
megatron/training/initialize.py
megatron/training/training.py
megatron/training/utils.py

4 files changed: 19 additions, 3 deletions


megatron/training/arguments.py

Lines changed: 9 additions & 0 deletions
@@ -1142,6 +1142,11 @@ def validate_args(args, defaults={}):
         assert not args.distrib_optim_fully_reshardable_mem_efficient, \
             '--distrib-optim-fully-reshardable-mem-efficient requires -enable-gloo-process-groups'
 
+    if args.fake_process_group:
+        # Disable nan check for fake process group
+        args.check_for_nan_in_loss_and_grad = False
+        # Disable gloo process groups for fake process group
+        args.enable_gloo_process_groups = False
 
     # Checkpointing
     if args.ckpt_fully_parallel_save_deprecated and args.rank == 0:

@@ -2869,6 +2874,10 @@ def _add_distributed_args(parser):
                        "and must be consistent across all ranks.")
     group.add_argument('--replication-factor', default=2, type=int,
                        help="Number of machines storing the replica of a given rank's data.")
+    group.add_argument('--fake-process-group', action='store_true', default=False,
+                       help='If set, initialize with fake distributed process group and all distributed communication operations will be skipped. \
+                       This is quite useful for profiling memory usage of distributed training with just one GPU. \
+                       Setting WORLD_SIZE and RANK to the specific values for target distribtued scale.')
     return parser
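As the help text above says, the flag is meant to be paired with WORLD_SIZE and RANK set to the scale being simulated. A minimal launch sketch, not part of the commit; the entry point pretrain_gpt.py, the port, the parallelism sizes, and the placeholder flag list are illustrative assumptions:

# Hypothetical single-GPU launch that pretends to be rank 0 of a 64-rank job.
import os
import subprocess

env = dict(os.environ)
env.update({
    'WORLD_SIZE': '64',          # target distributed scale to fake
    'RANK': '0',                 # the one real process plays rank 0
    'MASTER_ADDR': 'localhost',  # likely unused once a FakeStore is supplied; set defensively
    'MASTER_PORT': '29500',
})

cmd = [
    'python', 'pretrain_gpt.py',
    '--fake-process-group',
    '--tensor-model-parallel-size', '8',
    '--pipeline-model-parallel-size', '4',
    # ... plus the usual model, data, and logging arguments ...
]
subprocess.run(cmd, env=env, check=True)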

megatron/training/initialize.py

Lines changed: 5 additions & 0 deletions
@@ -346,6 +346,11 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s
             'rank': args.rank,
             'timeout': timedelta(minutes=args.distributed_timeout_minutes),
         }
+        if args.fake_process_group:
+            from torch.testing._internal.distributed.fake_pg import FakeStore
+            store = FakeStore()
+            init_process_group_kwargs['backend'] = 'fake'
+            init_process_group_kwargs['store'] = store
 
         torch.distributed.init_process_group(**init_process_group_kwargs)
         inprocess_restart.maybe_force_nccl_backend_init(device_id)
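This change is the core of the feature: a FakeStore plus the 'fake' backend make torch.distributed report an arbitrary world size while communication is skipped. A standalone sketch of the same mechanism outside Megatron, assuming a PyTorch build that ships torch.testing._internal.distributed.fake_pg; the world size of 128 is arbitrary:

import torch
import torch.distributed as dist
# Importing fake_pg also registers the 'fake' backend with torch.distributed.
from torch.testing._internal.distributed.fake_pg import FakeStore

store = FakeStore()
# One real process pretends to be rank 0 of a 128-rank job.
dist.init_process_group(backend='fake', store=store, rank=0, world_size=128)

print(dist.get_world_size())  # 128, even though only one process is running
t = torch.ones(4)
dist.all_reduce(t)            # completes immediately; no data is exchanged
dist.destroy_process_group()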

megatron/training/training.py

Lines changed: 4 additions & 2 deletions
@@ -1699,7 +1699,7 @@ def training_log(
             mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict
         )
     if iteration % args.log_interval == 0:
-        if args.record_memory_history and is_last_rank():
+        if args.record_memory_history and (is_last_rank() or torch.distributed.get_backend() == 'fake'):
             snapshot = torch.cuda.memory._snapshot()
             from pickle import dump
 

@@ -1788,7 +1788,9 @@ def training_log(
         num_microbatches = get_num_microbatches()
         report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
         report_memory(f'(after {iteration} iterations)')
-        report_memory_flag = False
+        if iteration > 1:
+            # Make sure the memory after the second iteration is reported to include optimizer state memory.
+            report_memory_flag = False
     # Write timers to wandb, don't reset the counts
     if args.log_timers_to_tensorboard:
         timers.write(timers_to_log, writer, iteration, normalizer=args.log_interval, reset=False)
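The relaxed is_last_rank() guard matters because a fake-backend run has one real process that is usually not the faked last rank, yet it is the process that should dump the memory snapshot. For reference, a rough sketch of the torch.cuda memory-snapshot flow this branch relies on; the output path, the max_entries value, and where recording is started are assumptions, only torch.cuda.memory._snapshot() appears in the diff:

from pickle import dump

import torch

# Start recording allocator events (Megatron is assumed to enable this
# earlier when record_memory_history is set; max_entries is illustrative).
torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run a few training iterations on the GPU ...

# Same call used in training_log above; dump and inspect the pickle at
# https://pytorch.org/memory_viz
snapshot = torch.cuda.memory._snapshot()
with open('memory_snapshot.pickle', 'wb') as f:
    dump(snapshot, f)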

megatron/training/utils.py

Lines changed: 1 addition & 1 deletion
@@ -425,7 +425,7 @@ def is_last_rank():
 
 def print_rank_last(message):
     """If distributed is initialized, print only on last rank."""
-    if torch.distributed.is_initialized():
+    if torch.distributed.is_initialized() and torch.distributed.get_backend() != 'fake':
         if is_last_rank():
             print(message, flush=True)
     else:
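Why the extra backend check is needed: under the fake backend the single real process typically reports a rank other than the faked last rank, so the is_last_rank() path would print nothing. A rough sketch of the helper's behavior after the change, with the else branch (not shown in the diff) assumed to print unconditionally:

import torch.distributed as dist

def is_last_rank():
    # True only on the final rank of the (possibly faked) world.
    return dist.get_rank() == (dist.get_world_size() - 1)

def print_rank_last(message):
    """If distributed is initialized, print only on last rank."""
    if dist.is_initialized() and dist.get_backend() != 'fake':
        if is_last_rank():
            print(message, flush=True)
    else:
        # Not initialized, or running under the fake backend:
        # assumed to print from the current (only) process.
        print(message, flush=True)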
