@Victarry (Contributor) commented on Nov 17, 2025

What does this PR do?

Mirror of Dev PR #2254

Adds support for fake distributed process groups via --fake-process-group. When enabled, all distributed communication operations are skipped.

  • 🚀 This is quite useful for profiling the memory usage and kernel execution time of large-scale distributed training with only one GPU.
  • 💡 Memory snapshots and nsys profiles are also supported (see the sketches under Usage below)!
  • ❗ NCCL memory cost can't be accounted for.

Motivated by https://github.com/Victarry/PyTorch-Memory-Profiler

Usage

Quick startup

Set the target world size and rank to profile, and run the scripts as usual:

export WORLD_SIZE=256 # Target world size you want to profile
export RANK=0 # The specific rank you want to profile
# No torchrun needed; launch as a plain single-process job
python pretrain_gpt.py \
   .... \
   --fake-process-group
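
Because this launches as an ordinary single-process job, you can wrap it directly with external profilers. A minimal Nsight Systems sketch for kernel timing (the output name and trace selection here are illustrative, not part of this PR):

# Kernel-timing profile of the single-GPU fake-distributed run
nsys profile -o fake_pg_trace --trace=cuda,nvtx \
    python pretrain_gpt.py \
    .... \
    --fake-process-group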

Other useful arguments for benchmarking without running on actual data (a combined sketch follows the list):

  • --mock-data to use a mock dataset.
  • --tokenizer-type NullTokenizer --vocab-size 32000 to use a mock tokenizer.
  • --moe-router-force-load-balancing to force router load balancing.
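
Putting these together with the quick-startup flags, a minimal data-free benchmarking invocation could look like the sketch below (the elided flags stand for whatever your model configuration needs):

export WORLD_SIZE=256
export RANK=0
python pretrain_gpt.py \
   .... \
   --mock-data \
   --tokenizer-type NullTokenizer \
   --vocab-size 32000 \
   --moe-router-force-load-balancing \
   --fake-process-group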

Memory Snapshot Dump

Add the arguments below to dump a PyTorch memory snapshot:

--memory-snapshot-path ./memory.pickle \
--record-memory-history

Then you can visualize the snapshot at https://docs.pytorch.org/memory_viz
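
If you prefer rendering locally instead of the web viewer, upstream PyTorch also ships an internal helper that converts the same pickle into a standalone HTML trace; a hedged sketch (torch.cuda._memory_viz is a private PyTorch module and may change between releases):

python -m torch.cuda._memory_viz trace_plot ./memory.pickle -o memory.html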

✈️ E2E Examples

## Mixtral 8x7B ##
export MEGATRON_PATH=.
export PYTHONPATH=$MEGATRON_PATH
export CUDA_DEVICE_MAX_CONNECTIONS=1

# Parallel setting configuration
TP=${TP:-1}
PP=${PP:-4}
EP=${EP:-8}
VPP=${VPP:-8}
MBS=${MBS:-1}
GBS=${GBS:-256}

export WORLD_SIZE=64
export RANK=0

python $MEGATRON_PATH/pretrain_gpt.py \
  --tensor-model-parallel-size $TP \
  --pipeline-model-parallel-size $PP \
  --expert-model-parallel-size $EP \
  --num-virtual-stages-per-pipeline-rank $VPP \
  --micro-batch-size $MBS \
  --global-batch-size $GBS \
  --use-distributed-optimizer \
  --overlap-grad-reduce \
  --overlap-param-gather \
  --use-mcore-models \
  --sequence-parallel \
  --disable-bias-linear \
  --train-samples 32768 \
  --transformer-impl transformer_engine \
  --data-cache-path /tmp/data-cache \
  --mock-data \
  --tokenizer-type NullTokenizer \
  --vocab-size 32000 \
  --split 99,1,0 \
  --no-mmap-bin-files \
  --num-workers 6 \
  --untie-embeddings-and-output-weights \
  --position-embedding-type rope \
  --rotary-percent 1.0 \
  --normalization RMSNorm \
  --swiglu \
  --num-layers 32 \
  --hidden-size 4096 \
  --ffn-hidden-size 14336 \
  --num-attention-heads 32 \
  --group-query-attention \
  --num-query-groups 8 \
  --seq-length 4096 \
  --max-position-embeddings 4096 \
  --make-vocab-size-divisible-by 128 \
  --attention-dropout 0.0 \
  --hidden-dropout 0.0 \
  --clip-grad 1.0 \
  --weight-decay 0.1 \
  --lr-decay-samples 255126953 \
  --lr-warmup-samples 162761 \
  --lr 1.2e-5 \
  --min-lr 1.2e-6 \
  --lr-decay-style cosine \
  --adam-beta1 0.9 \
  --adam-beta2 0.95 \
  --num-experts 8 \
  --moe-router-load-balancing-type aux_loss \
  --moe-router-topk 2 \
  --moe-grouped-gemm \
  --moe-aux-loss-coeff 1e-3 \
  --moe-token-dispatcher-type alltoall \
  --init-method-std 0.008 \
  --log-timers-to-tensorboard \
  --log-memory-to-tensorboard \
  --log-num-zeros-in-grad \
  --log-params-norm \
  --log-validation-ppl-to-tensorboard \
  --log-throughput \
  --log-interval 1 \
  --exit-interval 10 \
  --bf16 \
  --moe-router-force-load-balancing \
  --fake-process-group

## Qwen3-235B ##
export MEGATRON_PATH=.
export PYTHONPATH=$MEGATRON_PATH
export CUDA_DEVICE_MAX_CONNECTIONS=1

TP=${TP:-2}
EP=${EP:-32}
PP=${PP:-8}
VPP=${VPP:-4}
MBS=${MBS:-1}
GBS=${GBS:-2048}

export WORLD_SIZE=256
export RANK=0

python $MEGATRON_PATH/pretrain_gpt.py \
  --distributed-timeout-minutes 60 \
  --tensor-model-parallel-size $TP \
  --pipeline-model-parallel-size $PP \
  --expert-model-parallel-size $EP \
  --num-virtual-stages-per-pipeline-rank $VPP \
  --context-parallel-size 1 \
  --expert-tensor-parallel-size 1 \
  --use-distributed-optimizer \
  --no-create-attention-mask-in-dataloader \
  --use-mcore-models \
  --sequence-parallel \
  --use-flash-attn \
  --disable-bias-linear \
  --micro-batch-size $MBS \
  --global-batch-size $GBS \
  --train-samples 32768 \
  --transformer-impl transformer_engine \
  --data-cache-path /tmp/data-cache \
  --mock-data \
  --tokenizer-type NullTokenizer \
  --vocab-size 32000 \
  --split 99,1,0 \
  --no-mmap-bin-files \
  --num-workers 6 \
  --untie-embeddings-and-output-weights \
  --position-embedding-type rope \
  --rotary-percent 1.0 \
  --rotary-base 1000000 \
  --rotary-seq-len-interpolation-factor 1 \
  --normalization RMSNorm \
  --swiglu \
  --norm-epsilon 1e-06 \
  --num-layers 94 \
  --hidden-size 4096 \
  --ffn-hidden-size 12288 \
  --num-attention-heads 64 \
  --group-query-attention \
  --num-query-groups 4 \
  --qk-layernorm \
  --seq-length 4096 \
  --max-position-embeddings 4096 \
  --make-vocab-size-divisible-by 1187 \
  --attention-dropout 0.0 \
  --hidden-dropout 0.0 \
  --clip-grad 1.0 \
  --weight-decay 0.1 \
  --lr-decay-samples 255126953 \
  --lr-warmup-samples 162761 \
  --lr 1.2e-4 \
  --min-lr 1.2e-5 \
  --lr-decay-style cosine \
  --adam-beta1 0.9 \
  --adam-beta2 0.95 \
  --num-experts 128 \
  --moe-ffn-hidden-size 1536 \
  --moe-router-load-balancing-type aux_loss \
  --moe-router-topk 8 \
  --moe-router-pre-softmax \
  --moe-grouped-gemm \
  --moe-aux-loss-coeff 1e-3 \
  --moe-token-dispatcher-type alltoall \
  --moe-permute-fusion \
  --eval-iters 32 \
  --eval-interval 500 \
  --finetune \
  --auto-detect-ckpt-format \
  --no-ckpt-fully-parallel-save \
  --dist-ckpt-strictness log_all \
  --init-method-std 0.02 \
  --log-timers-to-tensorboard \
  --log-memory-to-tensorboard \
  --log-num-zeros-in-grad \
  --log-params-norm \
  --log-validation-ppl-to-tensorboard \
  --log-throughput \
  --log-interval 1 \
  --bf16 \
  --account-for-embedding-in-pipeline-split \
  --account-for-loss-in-pipeline-split \
  --moe-router-force-load-balancing \
  --exit-interval 5 \
  --fake-process-group \
  --record-memory-history \
  --memory-snapshot-path ./qwen3_235b_TP${TP}_PP${PP}_EP${EP}_VPP${VPP}.pickle

## DeepSeek-V3 ##
export MEGATRON_PATH=.
export PYTHONPATH=$MEGATRON_PATH
export CUDA_DEVICE_MAX_CONNECTIONS=1

# DeepSeek-V3 parallel configuration
TP=${TP:-2}
PP=${PP:-16}
EP=${EP:-64}
CP=${CP:-1}
ETP=${ETP:-1}

PP_LAYOUT="Et*3|(tt|)*29|L"
MBS=${MBS:-1}
GBS=${GBS:-8192}

export WORLD_SIZE=1024
export RANK=0

python $MEGATRON_PATH/pretrain_gpt.py \
  --distributed-timeout-minutes 60 \
  --tensor-model-parallel-size $TP \
  --pipeline-model-parallel-size $PP \
  --expert-model-parallel-size $EP \
  --context-parallel-size $CP \
  --expert-tensor-parallel-size $ETP \
  --use-distributed-optimizer \
  --overlap-grad-reduce \
  --overlap-param-gather \
  --use-mcore-models \
  --sequence-parallel \
  --disable-bias-linear \
  --micro-batch-size $MBS \
  --global-batch-size $GBS \
  --train-samples 524288 \
  --exit-duration-in-mins 220 \
  --no-check-for-nan-in-loss-and-grad \
  --no-rope-fusion \
  --manual-gc \
  --manual-gc-interval 1 \
  --transformer-impl transformer_engine \
  --seq-length 4096 \
  --data-cache-path /tmp/dsv3-data-cache \
  --mock-data \
  --tokenizer-type NullTokenizer \
  --vocab-size 32000 \
  --split 99,1,0 \
  --no-mmap-bin-files \
  --no-create-attention-mask-in-dataloader \
  --num-workers 6 \
  --num-layers 61 \
  --pipeline-model-parallel-layout "$PP_LAYOUT" \
  --moe-layer-freq "([0]*3+[1]*58)" \
  --hidden-size 7168 \
  --ffn-hidden-size 18432 \
  --num-attention-heads 128 \
  --kv-channels 128 \
  --max-position-embeddings 4096 \
  --position-embedding-type rope \
  --rotary-base 10000 \
  --make-vocab-size-divisible-by 3232 \
  --normalization RMSNorm \
  --norm-epsilon 1e-6 \
  --swiglu \
  --untie-embeddings-and-output-weights \
  --multi-latent-attention \
  --attention-dropout 0.0 \
  --hidden-dropout 0.0 \
  --clip-grad 1.0 \
  --weight-decay 0.1 \
  --qk-layernorm \
  --lr-decay-samples 524288 \
  --lr-warmup-samples 8192 \
  --lr-warmup-init 3.9e-7 \
  --lr 3.9e-6 \
  --min-lr 3.9e-7 \
  --lr-decay-style cosine \
  --adam-beta1 0.9 \
  --adam-beta2 0.95 \
  --num-experts 256 \
  --moe-ffn-hidden-size 2048 \
  --moe-shared-expert-intermediate-size 2048 \
  --moe-router-load-balancing-type seq_aux_loss \
  --moe-router-topk 8 \
  --moe-token-dispatcher-type alltoall \
  --moe-router-pre-softmax \
  --moe-grouped-gemm \
  --moe-aux-loss-coeff 1e-4 \
  --moe-router-group-topk 4 \
  --moe-router-num-groups 8 \
  --moe-router-topk-scaling-factor 2.5 \
  --moe-router-score-function sigmoid \
  --moe-router-enable-expert-bias \
  --moe-router-bias-update-rate 1e-3 \
  --moe-router-dtype fp32 \
  --moe-permute-fusion \
  --q-lora-rank 1536 \
  --kv-lora-rank 512 \
  --qk-head-dim 128 \
  --qk-pos-emb-head-dim 64 \
  --v-head-dim 128 \
  --rotary-scaling-factor 40 \
  --mscale 1.0 \
  --mscale-all-dim 1.0 \
  --eval-iters 32 \
  --eval-interval 200 \
  --finetune \
  --auto-detect-ckpt-format \
  --dist-ckpt-strictness log_all \
  --init-method-std 0.02 \
  --log-timers-to-tensorboard \
  --log-memory-to-tensorboard \
  --log-num-zeros-in-grad \
  --log-params-norm \
  --log-validation-ppl-to-tensorboard \
  --log-throughput \
  --log-interval 1 \
  --logging-level 40 \
  --bf16 \
  --exit-interval 5 \
  --recompute-granularity selective \
  --recompute-modules moe_act mlp layernorm mla_up_proj \
  --fake-process-group \
  --record-memory-history \
  --moe-router-force-load-balancing \
  --memory-snapshot-path dsv3_TP${TP}_PP${PP}_EP${EP}_CP${CP}_ETP${ETP}.pickle \
  --fp8-format e4m3 \
  --fp8-recipe blockwise
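
All three example scripts read their parallel settings through ${VAR:-default} fallbacks, so you can sweep configurations without editing them; an illustrative override (the saved-script filename qwen3_235b.sh is hypothetical):

TP=4 PP=8 EP=16 GBS=1024 bash qwen3_235b.sh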

Tested Features

  • Parallelisms
    • PP
    • EP
    • CP
    • DP
    • TP
  • CUDA Graph
  • FP8 computation
  • FP8 param gather
  • 1F1B A2A overlap
  • CPU Offloading

⚠️ For major changes (either in lines of code or in their impact), please make sure to first share and discuss a design doc with the team.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see the Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers' reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either [email protected] or [email protected].

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.


copy-pr-bot bot commented Nov 17, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Victarry self-assigned this Nov 18, 2025
@Victarry added the Expert Review label Nov 18, 2025
@Victarry (Contributor Author) commented:

/ok to test c48164b

@ko3n1g added this to the Core 0.16 milestone Nov 18, 2025
@Victarry force-pushed the denliu/fake_process_group_main branch from 3bf1f30 to ef9c9e7 on November 19, 2025 07:57
@Victarry (Contributor Author) commented:

/ok to test ef9c9e7
