Skip to content

Commit 12bcd66

Browse files
committed
[DEV] Add support of fake distributed process group (NVIDIA#2254)
1 parent c4ba666 commit 12bcd66

File tree

4 files changed

+19
-3
lines changed

4 files changed

+19
-3
lines changed

megatron/training/arguments.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,11 @@ def validate_args(args, defaults={}):
11031103
assert not args.distrib_optim_fully_reshardable_mem_efficient, \
11041104
'--distrib-optim-fully-reshardable-mem-efficient requires --enable-gloo-process-groups'
11051105

1106+
if args.fake_process_group:
1107+
# Disable nan check for fake process group
1108+
args.check_for_nan_in_loss_and_grad = False
1109+
# Disable gloo process groups for fake process group
1110+
args.enable_gloo_process_groups = False
11061111

11071112
# Checkpointing
11081113
if args.ckpt_fully_parallel_save_deprecated and args.rank == 0:
@@ -2746,6 +2751,10 @@ def _add_distributed_args(parser):
27462751
"and must be consistent across all ranks.")
27472752
group.add_argument('--replication-factor', default=2, type=int,
27482753
help="Number of machines storing the replica of a given rank's data.")
2754+
group.add_argument('--fake-process-group', action='store_true', default=False,
2755+
help='If set, initialize with a fake distributed process group; all distributed communication operations will be skipped. \
2756+
This is quite useful for profiling memory usage of distributed training with just one GPU. \
2757+
Set WORLD_SIZE and RANK to the appropriate values for the target distributed scale.')
27492758
return parser
27502759

27512760

megatron/training/initialize.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,11 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s
346346
'rank': args.rank,
347347
'timeout': timedelta(minutes=args.distributed_timeout_minutes),
348348
}
349+
if args.fake_process_group:
350+
from torch.testing._internal.distributed.fake_pg import FakeStore
351+
store = FakeStore()
352+
init_process_group_kwargs['backend'] = 'fake'
353+
init_process_group_kwargs['store'] = store
349354

350355
torch.distributed.init_process_group(**init_process_group_kwargs)
351356
inprocess_restart.maybe_force_nccl_backend_init(device_id)

megatron/training/training.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1619,7 +1619,7 @@ def training_log(
16191619
mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict
16201620
)
16211621
if iteration % args.log_interval == 0:
1622-
if args.record_memory_history and is_last_rank():
1622+
if args.record_memory_history and (is_last_rank() or torch.distributed.get_backend() == 'fake'):
16231623
snapshot = torch.cuda.memory._snapshot()
16241624
from pickle import dump
16251625

@@ -1700,7 +1700,9 @@ def training_log(
17001700
num_microbatches = get_num_microbatches()
17011701
report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
17021702
report_memory(f'(after {iteration} iterations)')
1703-
report_memory_flag = False
1703+
if iteration > 1:
1704+
# Make sure the memory after the second iteration is reported to include optimizer state memory.
1705+
report_memory_flag = False
17041706
# Write timers to wandb, don't reset the counts
17051707
if args.log_timers_to_tensorboard:
17061708
timers.write(timers_to_log, writer, iteration, normalizer=args.log_interval, reset=False)

megatron/training/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def is_last_rank():
410410

411411
def print_rank_last(message):
412412
"""If distributed is initialized, print only on last rank."""
413-
if torch.distributed.is_initialized():
413+
if torch.distributed.is_initialized() and torch.distributed.get_backend() != 'fake':
414414
if is_last_rank():
415415
print(message, flush=True)
416416
else:

0 commit comments

Comments
 (0)