Merged

25 commits
27b4756
Add layerwise-nvtx and dummy register model function
kyleliang-nv Oct 17, 2025
271db95
Add pyt_hooks.py
kyleliang-nv Oct 17, 2025
5211d15
Cleanup pyt_hooks
kyleliang-nv Oct 17, 2025
4108f0a
Enable pyt_hooks
kyleliang-nv Oct 17, 2025
e699e00
Guard laywerwise nvtx
kyleliang-nv Oct 17, 2025
c63456b
Code cleanup
kyleliang-nv Oct 17, 2025
22fd4b9
Rename pyt_hooks.py to pytorch_hooks.py
kyleliang-nv Oct 17, 2025
79150f2
Add doc to describle how /start_profile and /end_profile works
kyleliang-nv Oct 17, 2025
13e5faf
Add doc for layerwise nvtx profiling
kyleliang-nv Oct 17, 2025
874ebae
Improve wording of start profile
kyleliang-nv Oct 17, 2025
5484dda
Fix nvtx prefix string
kyleliang-nv Oct 17, 2025
873325e
Fix bad merge
kyleliang-nv Oct 20, 2025
dae27d7
Rename arg to enable-layerwise-nvtx-marker
kyleliang-nv Oct 20, 2025
4a45355
Move printing prefill log from before to after run_batch
kyleliang-nv Oct 20, 2025
2dff0f5
Add unittest for /start_profile + nsys profile
kyleliang-nv Oct 21, 2025
22f7989
Cleanup debug and comments in pytorch_hooks.py
kyleliang-nv Oct 21, 2025
410b91b
Apply suggestions from code review
kyleliang-nv Oct 21, 2025
74939bc
Small fix/cleanup
kyleliang-nv Oct 21, 2025
f3d023e
Revert "Move printing prefill log from before to after run_batch"
kyleliang-nv Nov 5, 2025
80de452
Add test_start_profiler into 1-gpu nightly test
kyleliang-nv Nov 15, 2025
080ed50
Update server arg md
kyleliang-nv Nov 15, 2025
d14c259
Rename pytorch_hooks.py to nvtx_pytorch_hooks.py
kyleliang-nv Nov 15, 2025
a0b9030
Remove LSTM and ReflectionPad in nvtx hooks
kyleliang-nv Nov 15, 2025
fc99921
Merge branch 'main' into feature/layerwise_nvtx
Fridge003 Nov 15, 2025
32944bd
Update test/srt/run_suite.py
Fridge003 Nov 15, 2025
1 change: 1 addition & 0 deletions docs/advanced_features/server_arguments.md
@@ -394,6 +394,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--enable-return-hidden-states` | Enable returning hidden states with responses. | `False` | bool flag (set to enable) |
| `--scheduler-recv-interval` | The interval to poll requests in scheduler. Can be set to >1 to reduce the overhead of this. | `1` | Type: int |
| `--numa-node` | Sets the numa node for the subprocesses. i-th element corresponds to i-th subprocess. | `None` | List[int] |
| `--enable-layerwise-nvtx-marker` | Enable layerwise NVTX profiling annotations for the model. This adds NVTX markers to every layer for detailed per-layer performance analysis with Nsight Systems. | `False` | bool flag (set to enable) |
| `--enable-attn-tp-input-scattered` | Allow input of attention to be scattered when only using tensor parallelism, to reduce the computational load of operations such as qkv latent. | `False` | bool flag (set to enable) |

## Debug tensor dumps
184 changes: 184 additions & 0 deletions docs/developer_guide/benchmark_and_profiling.md
@@ -122,6 +122,88 @@ You can also combine the above operations into a single command
python3 -m sglang.test.send_one --profile
```

### Profile a server with HTTP API endpoints

SGLang provides HTTP API endpoints to control profiling on a running server. This allows you to start and stop profiling programmatically, which is useful for capturing specific workload patterns.

#### Using `/start_profile` endpoint

The `/start_profile` endpoint starts profiling on the server. You can control when profiling begins and how long it runs using the following parameters:

**Basic usage:**

```bash
# Start profiling immediately for 10 steps
curl -X POST http://127.0.0.1:30000/start_profile \
-H "Content-Type: application/json" \
-d '{
"num_steps": 10
}'
```
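
The same request can be issued programmatically. A minimal Python sketch using the `requests` library (assuming the default server address and port):

```python
import requests

# Start a 10-step profile on a locally running server; mirrors the
# curl example above. Adjust the host/port to match your deployment.
resp = requests.post(
    "http://127.0.0.1:30000/start_profile",
    json={"num_steps": 10},
)
resp.raise_for_status()
print(resp.text)
```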

**Parameters:**

- `output_dir` (optional): Directory where profile traces are saved. Defaults to the `SGLANG_TORCH_PROFILER_DIR` environment variable if set, otherwise `/tmp`
- `num_steps` (optional): Number of steps to profile. If not specified, profiling continues until manually stopped with `/end_profile`
- `start_step` (optional): Step number at which to start profiling (inclusive). Useful for skipping warmup iterations
- `activities` (optional): List of activities to profile, e.g., `["CPU", "GPU"]`. Default is `["CPU", "GPU"]`
- `merge_profiles` (optional): Whether to merge distributed traces. Default is `false`

**Note on step ranges:** Profiling starts at `start_step` (inclusive) and continues for `num_steps` iterations. For example, with `start_step=3` and `num_steps=10`, profiling captures steps 3, 4, 5, 6, 7, 8, 9, 10, 11, and 12 (10 steps total, starting from step 3).

**Advanced usage with `start_step`:**

```bash
# Wait 5 steps (warmup), then profile for 10 steps
curl -X POST http://127.0.0.1:30000/start_profile \
-H "Content-Type: application/json" \
-d '{
"output_dir": "/tmp/profiles",
"start_step": 5,
"num_steps": 10,
"activities": ["CPU", "GPU"]
}'
```

**Continuous profiling (manual stop):**

```bash
# Start profiling without num_steps; you must stop it manually with /end_profile
curl -X POST http://127.0.0.1:30000/start_profile
```

#### Using `/end_profile` endpoint

The `/end_profile` endpoint stops an ongoing profiling session and saves the trace file.

```bash
# Stop profiling and save traces
curl -X POST http://127.0.0.1:30000/end_profile
```

This is only needed when you start profiling without specifying `num_steps`. If `num_steps` is specified, profiling will automatically stop after that many steps.
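
Putting the two endpoints together, a manual session (no `num_steps`) can be scripted end to end; the same flow appears as a curl-based workflow below. A minimal Python sketch (assuming the default server address; `run_workload` is a hypothetical placeholder for your load generator):

```python
import requests

BASE_URL = "http://127.0.0.1:30000"

def run_workload():
    # Hypothetical placeholder: send generation requests to the server here.
    pass

# Start profiling with no num_steps, run the workload, then stop explicitly.
requests.post(f"{BASE_URL}/start_profile", json={"output_dir": "/tmp/profiles"}).raise_for_status()
try:
    run_workload()
finally:
    requests.post(f"{BASE_URL}/end_profile").raise_for_status()
```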

#### Example workflow

```bash
# Terminal 1: Start the server
export SGLANG_TORCH_PROFILER_DIR=/tmp/profiles
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct

# Terminal 2: Start continuous profiling
curl -X POST http://127.0.0.1:30000/start_profile \
-H "Content-Type: application/json" \
-d '{
"start_step": 3
}'

# Terminal 3: Send requests to generate load
python -m sglang.bench_serving --backend sglang --num-prompts 100

# Terminal 2: Stop profiling when done
curl -X POST http://127.0.0.1:30000/end_profile
```

### Profiler Trace Merger for Distributed Traces

SGLang now supports automatic merging of profiling traces from distributed setups with multiple parallelism types (TP, DP, PP, EP). This feature is particularly useful for analyzing performance across distributed runs.
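
Merged traces can also be requested when starting a profile, via the `merge_profiles` parameter documented above. A minimal sketch (assuming the default server address):

```python
import requests

# Ask the server to merge per-rank traces (TP/DP/PP/EP) into one timeline.
requests.post(
    "http://127.0.0.1:30000/start_profile",
    json={"num_steps": 10, "merge_profiles": True},
).raise_for_status()
```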
@@ -259,6 +341,108 @@ Additionally, if you want to locate the SGLang Python source code through the cu
# some critical code
```

### Layerwise NVTX Profiling with Nsight Systems

SGLang provides built-in layerwise NVTX annotations that can be combined with the CUDA Profiler for detailed per-layer profiling in Nsight Systems. This is particularly useful for identifying performance bottlenecks at the layer level.

#### Using `--enable-layerwise-nvtx-marker` with Nsight Systems and `/start_profile`

The `--enable-layerwise-nvtx-marker` flag automatically adds NVTX markers to every layer in the model, which lets Nsight Systems attribute GPU activity to individual layers.

**Method 1: Using `/start_profile` with CUDA_PROFILER (for programmatic control)**

This method allows you to control exactly when profiling starts/stops via HTTP API while Nsight Systems is running.

1. Launch the server with layerwise NVTX enabled under Nsight Systems:

```bash
# Terminal 1: Start server with nsys and capture-range option
nsys profile --trace-fork-before-exec=true \
--cuda-graph-trace=node \
--capture-range=cudaProfilerApi \
--capture-range-end=stop \
-o layerwise_profile \
python -m sglang.launch_server \
--model-path meta-llama/Llama-3.1-8B-Instruct \
--enable-layerwise-nvtx-marker \
--disable-cuda-graph
```

Note: NVTX markers are not emitted for kernel launches captured by CUDA graphs. Use `--disable-cuda-graph` to ensure all layerwise NVTX markers are emitted in the trace.

2. In another terminal, control profiling via `/start_profile` with `CUDA_PROFILER` activity:

```bash
# Terminal 2: Wait for server to be ready, then start CUDA profiling
# Wait 3 steps for warmup, then profile for 10 steps
curl -X POST http://127.0.0.1:30000/start_profile \
-H "Content-Type: application/json" \
-d '{
"start_step": 3,
"num_steps": 10,
"activities": ["CUDA_PROFILER"]
}'
```

3. Send requests to generate load:

```bash
# Terminal 3: Generate workload
python -m sglang.bench_serving --backend sglang --num-prompts 100
```

4. Profiling will automatically stop after 10 steps (due to `num_steps: 10`). If you hadn't specified `num_steps`, you would need to manually stop it:

```bash
# Terminal 2: Only needed if num_steps was not specified
curl -X POST http://127.0.0.1:30000/end_profile
```

The `--capture-range=cudaProfilerApi` option tells Nsight Systems to only capture data between `cudaProfilerStart()` and `cudaProfilerStop()` calls (triggered by `/start_profile` and `/end_profile`), reducing overhead and file size. The `start_step` parameter skips the first 3 steps to avoid capturing warmup overhead.
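
Conceptually, the `CUDA_PROFILER` activity drives the CUDA profiler API. In PyTorch terms it corresponds roughly to the following (an illustrative sketch, not necessarily SGLang's exact code path):

```python
import torch

# cudaProfilerStart()/cudaProfilerStop() delimit the capture range that
# `nsys --capture-range=cudaProfilerApi` records.
torch.cuda.profiler.start()  # triggered by /start_profile (after start_step)
# ... the num_steps iterations to be captured run here ...
torch.cuda.profiler.stop()   # fired automatically after num_steps, or by /end_profile
```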

**Method 2: Simpler approach without `/start_profile` API**

For simpler use cases where you don't need fine-grained control over when profiling starts and stops, you can run the server itself under Nsight Systems and capture the entire workload. Note that the NVTX markers are emitted by the server process, so `nsys` must wrap the server rather than the benchmarking client:

```bash
# Terminal 1: Start the server under nsys with layerwise NVTX
# Note: --disable-cuda-graph ensures all NVTX markers are emitted
nsys profile --trace-fork-before-exec=true \
--cuda-graph-trace=node \
-o layerwise_profile \
python -m sglang.launch_server \
--model-path meta-llama/Llama-3.1-8B-Instruct \
--enable-layerwise-nvtx-marker \
--disable-cuda-graph

# Terminal 2: Generate load
python -m sglang.bench_serving --backend sglang --num-prompts 10
```

This approach profiles the entire server run; the report is written when the server process exits (for example, after you stop it with Ctrl+C). The layerwise NVTX markers will be visible in the Nsight Systems timeline.

**Viewing the profiling results:**

Open the generated report with Nsight Systems (recent versions of `nsys` produce `.nsys-rep`; older versions produce `.qdrep`):

```bash
nsys-ui layerwise_profile.nsys-rep
```

In the Nsight Systems GUI, you'll see:
- **NVTX ranges**: Each layer appears as a labeled range in the timeline with detailed information in the marker metadata
- **CUDA kernels**: All GPU kernels are shown alongside the layer annotations
- **Layer hierarchy**: The full module path (e.g., `meta-llama/Llama-3.1-8B-Instruct.model.layers.0.self_attn.qkv_proj`) helps identify specific layers. The prefix is the full model path passed to `--model-path`.
- **Tensor shapes**: Input/output dimensions and parameter shapes are included in the NVTX marker data

**Benefits of layerwise NVTX profiling:**

- **Granular visibility**: See exactly which layers are taking the most time
- **Memory tracking**: Identify layers with large memory allocations
- **Bottleneck identification**: Quickly locate inefficient operations
- **Communication overhead**: In multi-GPU setups, see per-layer communication costs
- **Development debugging**: Validate that model architecture changes have the expected performance impact

## Other tips

1. You can benchmark a model using dummy weights by only providing the config.json file. This allows for quick testing of model variants without training. To do so, add `--load-format dummy` to the above commands and then you only need a correct `config.json` under the checkpoint folder.
6 changes: 6 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -149,6 +149,7 @@
slow_rank_detector,
xpu_has_xmx_support,
)
from sglang.srt.utils.nvtx_pytorch_hooks import PytHooks
from sglang.srt.utils.offloader import (
create_offloader_from_server_args,
get_offloader,
@@ -772,6 +773,11 @@ def load_model(self):

get_offloader().post_init()

# Register model for layerwise NVTX profiling if enabled
if self.server_args.enable_layerwise_nvtx_marker:
self.pyt_hooks = PytHooks()
self.pyt_hooks.register_hooks(self.model, module_prefix="model")

if self.server_args.kv_cache_dtype == "fp8_e4m3":
if self.server_args.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
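
For context, the snippet above is where the layerwise markers get attached to the model. A minimal, hypothetical sketch of how such hooks can be implemented with PyTorch module hooks (the actual `PytHooks` in `nvtx_pytorch_hooks.py` may differ; for example, it also records tensor shapes in the marker metadata):

```python
import torch
import torch.cuda.nvtx as nvtx

class PytHooks:
    """Sketch: push/pop an NVTX range around every module's forward pass."""

    def __init__(self):
        self.handles = []

    def register_hooks(self, model: torch.nn.Module, module_prefix: str = "model"):
        for name, module in model.named_modules():
            qualified = f"{module_prefix}.{name}" if name else module_prefix

            # Bind the qualified name at definition time to avoid late binding.
            def pre_hook(mod, inputs, _name=qualified):
                nvtx.range_push(_name)

            def post_hook(mod, inputs, output, _name=qualified):
                nvtx.range_pop()

            self.handles.append(module.register_forward_pre_hook(pre_hook))
            self.handles.append(module.register_forward_hook(post_hook))

    def remove_hooks(self):
        for handle in self.handles:
            handle.remove()
        self.handles.clear()
```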
6 changes: 6 additions & 0 deletions python/sglang/srt/server_args.py
@@ -479,6 +479,7 @@ class ServerArgs:
disable_cuda_graph_padding: bool = False
enable_profile_cuda_graph: bool = False
enable_cudagraph_gc: bool = False
enable_layerwise_nvtx_marker: bool = False
enable_nccl_nvls: bool = False
enable_symm_mem: bool = False
disable_flashinfer_cutlass_moe_fp4_allgather: bool = False
@@ -3240,6 +3241,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Enable garbage collection during CUDA graph capture. If disabled (default), GC is frozen during capture to speed up the process.",
)
parser.add_argument(
"--enable-layerwise-nvtx-marker",
action="store_true",
help="Enable layerwise NVTX profiling annotations for the model.",
)
parser.add_argument(
"--enable-nccl-nvls",
action="store_true",