
Conversation

@alpha-baby (Contributor) commented Nov 14, 2025

Why this fix?

Model: DeepSeek V3.1

I'm running a large model with a parallelism configuration of TP=1 and PP=8.

My GPUs have plenty of memory, but max_total_num_tokens comes out very small, which is not what I expected.

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.148.08             Driver Version: 570.148.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L20X                    On  |   00000000:2E:00.0 Off |                    0 |
| N/A   29C    P0            121W /  700W |   57184MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA L20X                    On  |   00000000:58:00.0 Off |                    0 |
| N/A   27C    P0            116W /  700W |   84898MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA L20X                    On  |   00000000:60:00.0 Off |                    0 |
| N/A   28C    P0            112W /  700W |   84898MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA L20X                    On  |   00000000:67:00.0 Off |                    0 |
| N/A   28C    P0            115W /  700W |   84898MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA L20X                    On  |   00000000:9B:00.0 Off |                    0 |
| N/A   27C    P0            113W /  700W |   84898MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA L20X                    On  |   00000000:BB:00.0 Off |                    0 |
| N/A   28C    P0            112W /  700W |   84898MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA L20X                    On  |   00000000:CA:00.0 Off |                    0 |
| N/A   27C    P0            115W /  700W |   84898MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA L20X                    On  |   00000000:DA:00.0 Off |                    0 |
| N/A   28C    P0            116W /  700W |  139996MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A          506726      C   sglang::scheduler_PP0                 57174MiB |
|    1   N/A  N/A          506727      C   sglang::scheduler_PP1                 84888MiB |
|    2   N/A  N/A          506728      C   sglang::scheduler_PP2                 84888MiB |
|    3   N/A  N/A          506729      C   sglang::scheduler_PP3                 84888MiB |
|    4   N/A  N/A          506730      C   sglang::scheduler_PP4                 84888MiB |
|    5   N/A  N/A          506731      C   sglang::scheduler_PP5                 84888MiB |
|    6   N/A  N/A          506732      C   sglang::scheduler_PP6                 84888MiB |
|    7   N/A  N/A          506733      C   sglang::scheduler_PP7                 13998... |
+-----------------------------------------------------------------------------------------+

I analyzed the issue and found the root cause: the layers are distributed very unevenly across the PP ranks, so the per-rank memory usage differs a lot; a quick back-of-the-envelope calculation is shown below.
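
For reference, a sketch of the pre-fix split (the review discussion below notes that DeepSeek has 61 hidden layers and that the old code assigned all remainder layers to the last partition):

num_hidden_layers, pp_size = 61, 8
per_rank = num_hidden_layers // pp_size                     # 7 layers each on ranks 0..6
last_rank = num_hidden_layers - per_rank * (pp_size - 1)    # 12 layers on rank 7
print([per_rank] * (pp_size - 1) + [last_rank])             # [7, 7, 7, 7, 7, 7, 7, 12]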

This PR fixes two things:

  • Distribute layers evenly across PP ranks instead of piling the remainder onto the last rank.
  • Initialize ParallelLMHead only on the last PP rank (see the sketch below).
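
A minimal sketch of the second point, assuming sglang-style helpers; the import paths, class name, and constructor arguments below are assumptions for illustration, not the exact diff:

from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.utils import PPMissingLayer
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

class DeepseekV2ForCausalLMSketch:  # hypothetical name, for illustration only
    def __init__(self, config):
        if get_pp_group().is_last_rank:
            # Only the last pipeline stage computes logits, so only it
            # allocates the (large) vocab-parallel LM head.
            self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        else:
            # Earlier stages hold a lightweight placeholder so the module
            # tree stays consistent without allocating lm_head weights.
            self.lm_head = PPMissingLayer()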

Thanks! Co-author: @XucSh

Thanks! Co-tester: @whybeyoung

Co-authored-by: Xuchun Shang <[email protected]>
@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @alpha-baby, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refines the implementation of pipeline parallelism (PP) by addressing two key areas. It introduces a more balanced strategy for distributing model layers across PP partitions, ensuring that extra layers are spread out rather than concentrated on a single partition. Additionally, it optimizes the initialization of the language model head, ensuring that this final output layer is only instantiated on the last stage of the pipeline, thereby improving resource efficiency.

Highlights

  • Improved Layer Distribution for Pipeline Parallelism: The get_pp_indices utility function has been updated to distribute remaining layers more evenly across the first few pipeline parallel partitions when the total number of layers is not perfectly divisible. Previously, all remaining layers were assigned to the last partition.
  • Optimized LM Head Initialization for DeepseekV2: The ParallelLMHead for the DeepseekV2 model is now conditionally initialized. It will only be created on the last pipeline parallel rank, while other ranks will receive a PPMissingLayer placeholder. This optimizes memory usage and aligns with standard pipeline parallelism practices.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces two key improvements for pipeline parallelism. First, it corrects the uneven distribution of model layers across pipeline stages by ensuring any remainder layers are distributed one by one to the initial stages, leading to a more balanced workload. Second, it optimizes memory usage by initializing the lm_head only on the last pipeline stage, using a placeholder for other stages. The changes are logical and well-implemented. I have one suggestion to refactor the layer distribution logic for improved clarity and maintainability.

Comment on lines +86 to +98
base_layers = num_hidden_layers // pp_size
remainder = num_hidden_layers % pp_size
# Distribute the extra layers to the first 'remainder' partitions
if pp_rank < remainder:
    # This partition gets one extra layer
    start_layer = pp_rank * (base_layers + 1)
    end_layer = start_layer + (base_layers + 1)
else:
    # This partition gets only base layers
    start_layer = (
        remainder * (base_layers + 1) + (pp_rank - remainder) * base_layers
    )
    end_layer = start_layer + base_layers

Severity: medium

The logic for calculating start_layer and end_layer is correct, but it can be simplified for better readability and maintainability. You can calculate start_layer with a single, more intuitive formula and then determine end_layer based on whether the current rank receives an extra layer.

Suggested change
-base_layers = num_hidden_layers // pp_size
-remainder = num_hidden_layers % pp_size
-# Distribute the extra layers to the first 'remainder' partitions
-if pp_rank < remainder:
-    # This partition gets one extra layer
-    start_layer = pp_rank * (base_layers + 1)
-    end_layer = start_layer + (base_layers + 1)
-else:
-    # This partition gets only base layers
-    start_layer = (
-        remainder * (base_layers + 1) + (pp_rank - remainder) * base_layers
-    )
-    end_layer = start_layer + base_layers
+base_layers = num_hidden_layers // pp_size
+remainder = num_hidden_layers % pp_size
+# Each rank `i` is assigned `base_layers + (1 if i < remainder else 0)` layers.
+# The start layer for `pp_rank` is the sum of layers for all previous ranks.
+start_layer = pp_rank * base_layers + min(pp_rank, remainder)
+if pp_rank < remainder:
+    # This partition gets an extra layer.
+    end_layer = start_layer + base_layers + 1
+else:
+    # This partition gets base layers.
+    end_layer = start_layer + base_layers
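
A quick standalone sanity check of the suggested logic (this mirrors the snippet above and is not the repository's actual get_pp_indices):

def pp_indices(num_hidden_layers: int, pp_rank: int, pp_size: int) -> tuple[int, int]:
    # Mirrors the suggested formula: extra layers go to the first `remainder` ranks.
    base_layers = num_hidden_layers // pp_size
    remainder = num_hidden_layers % pp_size
    start_layer = pp_rank * base_layers + min(pp_rank, remainder)
    end_layer = start_layer + base_layers + (1 if pp_rank < remainder else 0)
    return start_layer, end_layer

# 61 DeepSeek hidden layers over 8 pipeline stages.
sizes = [pp_indices(61, r, 8)[1] - pp_indices(61, r, 8)[0] for r in range(8)]
print(sizes)  # [8, 8, 8, 8, 8, 7, 7, 7]
assert sum(sizes) == 61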

@XucSh XucSh added the run-ci label Nov 14, 2025
@whybeyoung (Collaborator) commented

LGTM

@ShangmingCai ShangmingCai self-assigned this Nov 17, 2025
@ShangmingCai (Collaborator) left a comment

Logic LGTM. This helps address the unbalanced weight distribution when num_hidden_layers is not divisible by pp_size (for example, DeepSeek's 61 layers with a PP size of 8: the original implementation yields 7, 7, 7, 7, 7, 7, 7, 12, while the new one yields 8, 8, 8, 8, 8, 7, 7, 7). Good enough as a bugfix when SGLANG_PP_LAYER_PARTITION is not set.

nit: I wonder whether there would be a performance difference if the extra layers were placed on the later ranks instead.

@alpha-baby (Contributor, author) replied to the nit above:

From a theoretical standpoint: if every layer is identical, performance should be the same regardless of whether the extra layers are placed on the first or the last ranks.

@ShangmingCai ShangmingCai requested a review from ch-wan as a code owner November 17, 2025 07:56
@ShangmingCai (Collaborator) commented

The B200 and GB200 CI jobs were flaky when this PR was created (fixed this morning); the other tests have passed: https://github.com/sgl-project/sglang/actions/runs/19390457057/job/55561045148

Since this PR also modifies some DeepSeek code, let me merge main into it to check that B200 and GB200 pass, just to be safe. I will merge it once these two tests pass.

@ShangmingCai (Collaborator) commented

Pretty sure the DeepSeek test would pass, but b200, gb200, and ut-1-gpu(1) are failing too often because of checkpoint issues. Let me force-merge this PR.

@ShangmingCai ShangmingCai merged commit ac406d4 into sgl-project:main Nov 17, 2025
274 of 308 checks passed