
Commit adb1742

Add support for fake distributed process groups.
1 parent 7020e1f commit adb1742

4 files changed: +15 -2 lines changed


megatron/training/arguments.py

Lines changed: 6 additions & 0 deletions
@@ -1129,6 +1129,8 @@ def validate_args(args, defaults={}):
         assert not args.distrib_optim_fully_reshardable_mem_efficient, \
             '--distrib-optim-fully-reshardable-mem-efficient requires -enable-gloo-process-groups'
 
+    if args.fake_process_group:
+        assert not args.enable_gloo_process_groups, "Fake distributed group requires disabling Gloo process groups."
 
     # Checkpointing
     if args.ckpt_fully_parallel_save_deprecated and args.rank == 0:
@@ -2834,6 +2836,10 @@ def _add_distributed_args(parser):
                             "and must be consistent across all ranks.")
     group.add_argument('--replication-factor', default=2, type=int,
                        help="Number of machines storing the replica of a given rank's data.")
+    group.add_argument('--fake-process-group', action='store_true', default=False,
+                       help='If set, initialize with a fake distributed process group; all distributed communication operations will be skipped. \
+                       This is useful for profiling the memory usage of distributed training with just one GPU. \
+                       Set WORLD_SIZE and RANK to the values of the target distributed scale.')
     return parser
 
 
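A hedged usage sketch, not part of this commit: the help text above says to fake the distributed scale through WORLD_SIZE and RANK, so a single-GPU "64-GPU" profiling run could be set up roughly as follows (the world size of 64 and the pretrain_gpt.py entry point are illustrative assumptions).

# Illustrative only: pretend to be rank 0 of a 64-GPU job on one GPU by
# setting the environment variables the new help text refers to.
import os

os.environ["WORLD_SIZE"] = "64"   # target distributed scale being simulated (assumed value)
os.environ["RANK"] = "0"          # the rank this single process plays

# The training entry point would then be launched with the new flag, e.g.
#   python pretrain_gpt.py --fake-process-group <other training args>
# Gloo process groups must also be disabled to satisfy the new assert in validate_args.
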
megatron/training/initialize.py

Lines changed: 5 additions & 0 deletions
@@ -346,6 +346,11 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s
         'rank': args.rank,
         'timeout': timedelta(minutes=args.distributed_timeout_minutes),
     }
+    if args.fake_process_group:
+        from torch.testing._internal.distributed.fake_pg import FakeStore
+        store = FakeStore()
+        init_process_group_kwargs['backend'] = 'fake'
+        init_process_group_kwargs['store'] = store
 
     torch.distributed.init_process_group(**init_process_group_kwargs)
     inprocess_restart.maybe_force_nccl_backend_init(device_id)

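To make the hunk above concrete, here is a minimal self-contained sketch of the mechanism it relies on (an assumption about the PyTorch testing utility, not code from this commit): importing FakeStore registers the 'fake' backend, after which init_process_group can report an arbitrary world size while every collective is skipped.

import torch
import torch.distributed as dist
# Importing FakeStore also registers the 'fake' backend with torch.distributed.
from torch.testing._internal.distributed.fake_pg import FakeStore

store = FakeStore()
dist.init_process_group(backend='fake', store=store, rank=0, world_size=128)

print(dist.get_world_size())  # reports 128 even though only one process exists

# Collectives return immediately without exchanging any data, which is what
# makes single-GPU memory profiling of a "large" job possible.
t = torch.ones(4)
dist.all_reduce(t)

dist.destroy_process_group()

Because only one real process exists under the fake backend, last-rank bookkeeping is not meaningful, which is why print_rank_last in utils.py below now bypasses the last-rank check when the backend is 'fake'.
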
megatron/training/training.py

Lines changed: 3 additions & 1 deletion
@@ -1773,7 +1773,9 @@ def training_log(
         num_microbatches = get_num_microbatches()
         report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
         report_memory(f'(after {iteration} iterations)')
-        report_memory_flag = False
+        if iteration > 1:
+            # Make sure the memory reported after the second iteration includes optimizer state memory.
+            report_memory_flag = False
     # Write timers to wandb, don't reset the counts
     if args.log_timers_to_tensorboard:
         timers.write(timers_to_log, writer, iteration, normalizer=args.log_interval, reset=False)

megatron/training/utils.py

Lines changed: 1 addition & 1 deletion
@@ -414,7 +414,7 @@ def is_last_rank():
 
 def print_rank_last(message):
     """If distributed is initialized, print only on last rank."""
-    if torch.distributed.is_initialized():
+    if torch.distributed.is_initialized() and torch.distributed.get_backend() != 'fake':
         if is_last_rank():
             print(message, flush=True)
     else:
