9 changes: 9 additions & 0 deletions megatron/training/arguments.py
@@ -1142,6 +1142,11 @@ def validate_args(args, defaults={}):
assert not args.distrib_optim_fully_reshardable_mem_efficient, \
'--distrib-optim-fully-reshardable-mem-efficient requires -enable-gloo-process-groups'

if args.fake_process_group:
# Disable nan check for fake process group
Contributor:
Print a warning saying you are overriding these flags?

Contributor Author:
Good suggestions. Updated in #2280.

args.check_for_nan_in_loss_and_grad = False
# Disable gloo process groups for fake process group
args.enable_gloo_process_groups = False

# Checkpointing
if args.ckpt_fully_parallel_save_deprecated and args.rank == 0:
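For illustration, a hedged sketch of the kind of warning the reviewer suggests; this is hypothetical and not necessarily how #2280 implements it:

```python
# Hypothetical sketch only; the actual follow-up is in #2280 and may differ.
import warnings

if args.fake_process_group:
    if args.check_for_nan_in_loss_and_grad or args.enable_gloo_process_groups:
        warnings.warn(
            '--fake-process-group is set: disabling the NaN check in '
            'loss/grad and Gloo process groups.'
        )
    args.check_for_nan_in_loss_and_grad = False
    args.enable_gloo_process_groups = False
```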
@@ -2866,6 +2871,10 @@ def _add_distributed_args(parser):
"and must be consistent across all ranks.")
group.add_argument('--replication-factor', default=2, type=int,
help="Number of machines storing the replica of a given rank's data.")
group.add_argument('--fake-process-group', action='store_true', default=False,
help='If set, initialize with a fake distributed process group; all distributed communication operations are skipped. \
This is useful for profiling the memory usage of distributed training on a single GPU. \
Set WORLD_SIZE and RANK to the values of the target distributed scale.')
return parser
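As the help string notes, the single process pretends to run at the target scale; a minimal sketch of the environment setup (values are illustrative):

```python
# Illustrative only: simulate a 64-rank job from a single process/GPU.
import os

os.environ['WORLD_SIZE'] = '64'  # total number of ranks being simulated
os.environ['RANK'] = '0'         # the rank this single real process impersonates
# Then pass --fake-process-group on the command line so that initialization
# picks the fake backend instead of NCCL/Gloo.
```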


5 changes: 5 additions & 0 deletions megatron/training/initialize.py
@@ -346,6 +346,11 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s
'rank': args.rank,
'timeout': timedelta(minutes=args.distributed_timeout_minutes),
}
if args.fake_process_group:
from torch.testing._internal.distributed.fake_pg import FakeStore
Contributor:
What PyTorch version introduced this?

Contributor Author:
Version check added in #2280.

store = FakeStore()
init_process_group_kwargs['backend'] = 'fake'
init_process_group_kwargs['store'] = store

torch.distributed.init_process_group(**init_process_group_kwargs)
inprocess_restart.maybe_force_nccl_backend_init(device_id)
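For context, a minimal standalone sketch of how the fake backend behaves. It assumes a PyTorch build that ships torch.testing._internal.distributed.fake_pg, which is an internal, unversioned module (hence the reviewer's version question), so its behaviour may vary:

```python
# Standalone sketch, not Megatron code; fake_pg is a private PyTorch API.
import torch
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore  # importing registers the 'fake' backend

store = FakeStore()
# One real process pretends to be rank 3 of a 128-rank job.
dist.init_process_group(backend='fake', store=store, rank=3, world_size=128)

print(dist.get_rank(), dist.get_world_size())  # 3 128
t = torch.ones(4)
dist.all_reduce(t)  # returns immediately; no communication is performed
dist.destroy_process_group()
```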
6 changes: 4 additions & 2 deletions megatron/training/training.py
@@ -1699,7 +1699,7 @@ def training_log(
mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict
)
if iteration % args.log_interval == 0:
- if args.record_memory_history and is_last_rank():
+ if args.record_memory_history and (is_last_rank() or torch.distributed.get_backend() == 'fake'):
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
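For reference, a hedged sketch of the record-then-dump flow this branch relies on; the file name and max_entries value are illustrative, and _record_memory_history / _snapshot are private torch.cuda.memory APIs whose signatures vary across versions:

```python
# Illustrative sketch of CUDA memory-history snapshotting; not Megatron code.
import pickle
import torch

torch.cuda.memory._record_memory_history(max_entries=100000)  # start recording allocator events
# ... run a few training iterations ...
snapshot = torch.cuda.memory._snapshot()                      # grab the recorded history
with open('memory_snapshot.pickle', 'wb') as f:
    pickle.dump(snapshot, f)                                  # inspect later with pytorch.org/memory_viz
```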

@@ -1788,7 +1788,9 @@ def training_log(
num_microbatches = get_num_microbatches()
report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
report_memory(f'(after {iteration} iterations)')
- report_memory_flag = False
+ if iteration > 1:
+     # Keep reporting through the second iteration so that the report includes optimizer state memory.
+     report_memory_flag = False
# Write timers to wandb, don't reset the counts
if args.log_timers_to_tensorboard:
timers.write(timers_to_log, writer, iteration, normalizer=args.log_interval, reset=False)
2 changes: 1 addition & 1 deletion megatron/training/utils.py
@@ -425,7 +425,7 @@ def is_last_rank():

def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
- if torch.distributed.is_initialized():
+ if torch.distributed.is_initialized() and torch.distributed.get_backend() != 'fake':
if is_last_rank():
print(message, flush=True)
else:
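For reference, a sketch of the resulting behaviour, assuming the elided else branch simply prints the message and that is_last_rank() compares the current rank against world_size - 1. Under the fake backend the single real process may impersonate any rank, so the last-rank check is bypassed and the message is always printed:

```python
# Sketch of print_rank_last after this change; the else body is elided in the
# diff and is assumed to print unconditionally.
def print_rank_last(message):
    """If distributed is initialized, print only on last rank."""
    if torch.distributed.is_initialized() and torch.distributed.get_backend() != 'fake':
        if is_last_rank():
            print(message, flush=True)
    else:
        # Not initialized, or running under the fake backend: print directly,
        # since the impersonated rank is usually not the last one.
        print(message, flush=True)
```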