
Commit 742d1dd

Merge branch 'main' into user/nzmora/add_mem_logs
2 parents 90eeb01 + cf8a1d2 commit 742d1dd

File tree

21 files changed: +494 -214 lines changed

examples/disaggregated/README.md

Lines changed: 34 additions & 2 deletions
@@ -204,7 +204,39 @@ srun -A <account> -p <partition> -t <time> \
 Additionally, we offer a fully executable script—please refer to [Disaggregated SLURM Scripts](./slurm/simple_example/).
 
 
-## Dynamic scaling (Prototype)
+## Dynamic scaling
+
+### Service discovery method
+
+The disaggregated server also supports dynamic service discovery and auto-scaling of context/generation servers. To enable this, set the `disagg_cluster` section in the configurations of both the context/generation servers and the disaggregated server. In this case, the context/generation servers must be started with an extra command-line option `--server-role=[context|generation]`, and the `context_servers`/`generation_servers` sections must be removed from the disaggregated server's config. You can simplify the context/generation servers' configs by passing only `--disagg_cluster_uri=<disagg_cluster_uri>` on the command line (the disaggregated server's config must still contain the `disagg_cluster` section). Omitted fields fall back to the defaults shown below.
+
+```yaml
+disagg_cluster:
+  cluster_uri: <your_cluster_uri>
+  cluster_name: ""
+  minimal_instances:
+    context_servers: 1
+    generation_servers: 1
+  heartbeat_interval_sec: 5
+  inactive_interval_sec: 10
+```
+- `cluster_uri`: the HTTP address of the disaggregated server, e.g. `http://<your-disagg-server-host>:<your-disagg-server-port>`, or a pre-configured Etcd server address such as `etcd://<your-etcd-host>:2379`.
+- `cluster_name`: optional namespace to isolate multiple disagg clusters in Etcd.
+- `minimal_instances`: the equivalent of `num_instances` in the auto-scaling concept; the disaggregated server rejects requests when
+the number of active context/generation servers is below the corresponding threshold.
+- `heartbeat_interval_sec`: frequency at which context/generation servers send heartbeats to the disaggregated server.
+- `inactive_interval_sec`: a server is marked inactive if no heartbeat is received within this interval (set it higher than the heartbeat interval).
+
+Note that the disaggregated server and all context/generation servers should use the same `disagg_cluster` configuration values; otherwise the disaggregated server may not be able to track keep-alive heartbeats or detect inactivity of the other servers properly. If the `disagg_cluster` section is specified, the `context_servers`/`generation_servers` sections must not be present in the disaggregated server's config.
+
+Additionally, we offer a fully executable script—please refer to [Disaggregated SLURM Scripts](./slurm/service_discovery_example/).
+
+#### Dynamically adding servers
+
+To add servers dynamically, start more context/generation workers with the same `disagg_cluster` configuration; the disaggregated server will discover the new servers and dispatch requests to them automatically. If a context/generation server becomes inactive, the disaggregated server will also detect this and stop routing requests to it.
+
+
+### Metadata server method (Prototype)
 
 Currently, trtllm supports dynamic addition and removal of servers by leveraging ETCD. To enable this feature, you should start the context and generation servers with an additional flag ```--metadata_server_config_file``` and ```--server_role```.
 Before launching the context and generation servers, you should first start the ETCD server. By default, the ETCD server listens for client requests at ```localhost:2379```.
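For illustration, a minimal single-node sketch of the service-discovery method added in the hunk above. The flags mirror the SLURM example introduced by this commit; the hostnames, ports, and config file names are placeholders rather than values from the diff.

```bash
# Illustrative sketch only; flags are taken from the SLURM example in this
# commit, while hostnames, ports, and file names are placeholders.
DISAGG_CLUSTER_URI="http://localhost:8000"

# Disaggregated server: its YAML carries the disagg_cluster section and no
# context_servers/generation_servers sections.
trtllm-serve disaggregated -c disagg_config.yaml &

# Workers register themselves via --disagg_cluster_uri and --server-role,
# so their configs need no disagg_cluster section.
trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host 0.0.0.0 --port 8001 \
    --extra_llm_api_options ctx_extra-llm-api-config.yaml \
    --disagg_cluster_uri ${DISAGG_CLUSTER_URI} --server-role context &

trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --host 0.0.0.0 --port 8002 \
    --extra_llm_api_options gen_extra-llm-api-config.yaml \
    --disagg_cluster_uri ${DISAGG_CLUSTER_URI} --server-role generation &
```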
@@ -240,7 +272,7 @@ refersh_interval: 10.0
 
 The ```hostname``` and ```port``` must match those used when starting the ETCD server. The ```health_check_timeout``` parameter specifies how long a server will be considered dead if no healthy response is received. By default, trtllm will perform two checks before marking a server as dead. The ```refresh_interval``` parameter determines how often the latest server list is fetched from the ETCD server.
 
-### Dynamically adding servers
+#### Dynamically adding servers
 
 Users can add servers by directly launching them with trtllm-serve. For example, you can start an additional generation server as follows:
 
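As a hedged sketch of the metadata-server flow (separate from the README's own example, which follows in the unchanged lines): the flags `--metadata_server_config_file` and `--server_role` and the config fields come from the text above, while the file name, role spellings, model, ports, and the `health_check_timeout` value are assumptions for illustration.

```bash
# Illustrative sketch only; names and values below are assumed.
etcd &   # by default etcd listens for client requests at localhost:2379

# Metadata server config with the fields described above.
cat > metadata_config.yaml << EOL
hostname: localhost
port: 2379
health_check_timeout: 5.0
refresh_interval: 10.0
EOL

trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host 0.0.0.0 --port 8001 \
    --metadata_server_config_file metadata_config.yaml --server_role context &
trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host 0.0.0.0 --port 8002 \
    --metadata_server_config_file metadata_config.yaml --server_role generation &
```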
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+#!/bin/bash
+#SBATCH --partition=${partition}
+#SBATCH --account=${account}
+#SBATCH --job-name=${job_name}
+#SBATCH --time=02:00:00
+
+container_image="${container_image:-}"
+mount_paths="${mount_paths:-}"
+work_path="${work_path:-}"
+enable_etcd="${enable_etcd:-0}"
+disagg_port="8000"
+ctx_port="8001"
+gen_port="8002"
+
+# use the first node as the disaggregated server node
+disagg_server_node=$(head -n 1 <(scontrol show hostnames $SLURM_JOB_NODELIST))
+
+if [[ "$enable_etcd" == "1" ]]; then
+    # you can optionally launch an etcd server; the container image must have etcd installed
+    disagg_cluster_uri="etcd://${disagg_server_node}:2379"
+    srun --container-image=${container_image} \
+        --container-mounts=${mount_paths} \
+        -w $disagg_server_node -N 1 --ntasks-per-node=1 \
+        --mpi=pmix \
+        bash -c "etcd" &
+    sleep 5 # wait for etcd to start
+else
+    # or use the disaggregated server's http address as the built-in service discovery server
+    disagg_cluster_uri="http://${disagg_server_node}:${disagg_port}"
+fi
+
+cat >${work_path}/disagg_config.yaml << EOL
+hostname: localhost
+port: ${disagg_port}
+backend: pytorch
+disagg_cluster:
+  cluster_uri: ${disagg_cluster_uri}
+  cluster_name: example_cluster
+EOL
+
+cat >${work_path}/ctx_extra-llm-api-config.yaml << EOL
+disable_overlap_scheduler: True
+cache_transceiver_config:
+  backend: UCX
+  max_tokens_in_buffer: 2048
+EOL
+
+cat >${work_path}/gen_extra-llm-api-config.yaml << EOL
+cache_transceiver_config:
+  backend: UCX
+  max_tokens_in_buffer: 2048
+EOL
+
+# Launch a proxy without any context/generation servers.
+srun --container-image=${container_image} \
+    --container-mounts=${mount_paths} \
+    -w $disagg_server_node -N 1 --ntasks-per-node=1 \
+    --mpi=pmix \
+    bash -c "trtllm-llmapi-launch trtllm-serve disaggregated -c ${work_path}/disagg_config.yaml" &
+
+# Launch a context server with `tp_size=8` using two 4-GPU nodes; it registers itself through disagg_cluster_uri.
+srun --container-image=${container_image} \
+    --container-mounts=${mount_paths} \
+    -N 2 --ntasks-per-node=4 \
+    --mpi=pmix \
+    bash -c "trtllm-llmapi-launch trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tp_size 8 --host 0.0.0.0 --port ${ctx_port} --extra_llm_api_options ${work_path}/ctx_extra-llm-api-config.yaml --disagg_cluster_uri ${disagg_cluster_uri} --server-role context" &
+
+# Launch a generation server with `tp_size=4` using one 4-GPU node.
+srun --container-image=${container_image} \
+    --container-mounts=${mount_paths} \
+    -N 1 --ntasks-per-node=4 \
+    --mpi=pmix \
+    bash -c "trtllm-llmapi-launch trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tp_size 4 --host 0.0.0.0 --port ${gen_port} --extra_llm_api_options ${work_path}/gen_extra-llm-api-config.yaml --disagg_cluster_uri ${disagg_cluster_uri} --server-role generation" &
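For reference, a sketch of how this batch script might be submitted; the script file name is a placeholder, and the exported variables are the ones read at the top of the script. Since all three `srun` steps are backgrounded, the batch script should also be kept alive (for example with a trailing `wait`) for as long as the servers are needed.

```bash
# Hypothetical submission; the script name and values are placeholders.
sbatch --partition=<partition> --account=<account> --job-name=disagg_service_discovery \
    --export=ALL,container_image=<image>,mount_paths=<host_dir:container_dir>,work_path=<shared_dir>,enable_etcd=0 \
    service_discovery_example.sbatch
```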

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 2 additions & 79 deletions
@@ -2,10 +2,8 @@
 import os
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.nn as nn
-from PIL import Image
 from torch.nn import functional as F
 from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
                           PreTrainedModel)
@@ -31,7 +29,6 @@
                         ExtraProcessedInputs, InputProcessor,
                         MultimodalPlaceholderMetadata,
                         MultimodalPlaceholderPlacement, TextPrompt,
-                        default_multimodal_input_loader,
                         register_input_processor)
 from ...logger import logger
 from ...sampling_params import SamplingParams
@@ -95,6 +92,8 @@ def __init__(self,
                 model_config: PretrainedConfig,
                 tokenizer: AutoTokenizer,
                 trust_remote_code: bool = True):
+
+        super().__init__()
         self.model_config = model_config
         self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
             model_path)
@@ -284,81 +283,6 @@ def get_rope_index(
             mrope_position_deltas, device=input_ids.device).unsqueeze(1)
         return position_ids, mrope_position_deltas
 
-    def get_dummy_text(self, input_seq_len: int) -> str:
-        ids = np.random.randint(
-            low=0,
-            high=int(
-                self.model_config.vocab_size), # high is exclusive in NumPy
-            size=input_seq_len,
-        ).tolist()
-        return self.tokenizer.decode(ids, skip_special_tokens=True)
-
-    def get_dummy_image(self, max_width: int, max_height: int):
-        image = Image.new("RGB", (max_width, max_height), color=255)
-        return image
-
-    def get_dummy_prompt(self, input_seq_len: int):
-        text = ""
-        # we use the max resolution as starting point
-        img_max_dim = 3584
-        image = self.get_dummy_image(max_width=img_max_dim,
-                                     max_height=img_max_dim)
-
-        test_mm_prompt = default_multimodal_input_loader(
-            tokenizer=self.tokenizer,
-            model_dir=self.model_path,
-            model_type=self.model_config.model_type,
-            modality="image",
-            prompts=[text],
-            media=[[image]],
-            image_data_format="pt")[0]
-
-        prompt_token_ids_single_img, _ = self(test_mm_prompt, None)
-
-        # if the max img resolution results in a number of tokens greater then
-        # input_seq_len, we keep lowering the resolution such as to find the
-        # max resolution such as it does not exceed the input_seq_len
-        while len(prompt_token_ids_single_img) > input_seq_len:
-            # reduce img resolution
-            img_max_dim = img_max_dim >> 1
-
-            image = self.get_dummy_image(max_width=img_max_dim,
-                                         max_height=img_max_dim)
-
-            test_mm_prompt = default_multimodal_input_loader(
-                tokenizer=self.tokenizer,
-                model_dir=self.model_path,
-                model_type=self.model_config.model_type,
-                modality="image",
-                prompts=[text],
-                media=[[image]],
-                image_data_format="pt")[0]
-
-            prompt_token_ids_single_img, _ = self(test_mm_prompt, None)
-
-        len_prompt_tokens_ids = len(prompt_token_ids_single_img)
-        # There are corner cases where if we strictly try to generate a text based
-        # on how many tokens we need to complete the input_seq_len, the output of
-        # default_multimodal_input_loader may give more tokens then the input_seq_len and this
-        # can lead to errors.
-        # That is why we try to clip the variable text_token_left to a lower threshold
-        # but close enough to the actual input_seq_len
-        text_generation_perc_threshold = 0.95
-        text_token_left = int((input_seq_len - len_prompt_tokens_ids) *
-                              text_generation_perc_threshold)
-
-        if text_token_left > 0:
-            text = self.get_dummy_text(text_token_left)
-
-        return default_multimodal_input_loader(
-            tokenizer=self.tokenizer,
-            model_dir=self.model_path,
-            model_type=self.model_config.model_type,
-            modality="image",
-            prompts=[text],
-            media=[[image]],
-            image_data_format="pt")[0]
-
     def _preprocess(self, text: dict[str, any], mm_data: dict[str, any],
                     mm_processor_kwargs: Dict[str, Any]):
         images = mm_data.get("image")
@@ -1018,7 +942,6 @@ def forward(
 
             mm_embeds = find_input_mm_embeds(
                 mm_embeds, multimodal_params[:num_context_requests])
-
        if not self.model_config.pretrained_config.disable_fuse_rope:
            mrope_config = self.prepare_mrope_config(
                multimodal_params, num_context_requests)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 67 additions & 23 deletions
@@ -6,11 +6,12 @@
 
 from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
 from tensorrt_llm._torch.distributed.moe_alltoall import MoeAlltoAll
+from tensorrt_llm.logger import logger
 
 from ...distributed import allgather
 from ...model_config import ModelConfig
 from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor, ceil_div
-from .interface import MoE
+from .interface import AlltoallMethodType, MoE
 
 # isort: off
 from .quantization import (
@@ -140,28 +141,44 @@ def __init__(
         self.has_been_profiled_min_latency = False
 
         # TODO: AlltoAll code is largely duplicated with WideEPMoE. Consider refactor and reuse in the future.
+        self.alltoall_method_type = self.select_alltoall_method_type()
+        logger.info_once(
+            f"{self.__class__.__name__} selects alltoall_method_type {self.alltoall_method_type!r}",
+            key="alltoall_method_type")
         self.alltoall_workspace = None
         self.alltoall_prepare_workspace = None
+        self.use_low_precision_combine = False
         if self.enable_alltoall:
-            if self.moe_alltoall_backend == "mnnvllatency":
-                MnnvlMemory.initialize()
-                self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
-                    model_config.mapping)
-                self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
-                    model_config.mapping)
-            elif self.moe_alltoall_backend == "mnnvlthroughput":
-                workspace_mb = int(
-                    os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "512"))
-                self.moe_a2a = MoeAlltoAll(
-                    mapping=self.mapping,
-                    max_num_tokens_per_rank=model_config.max_num_tokens,
-                    top_k=self.routing_method.experts_per_token,
-                    num_experts=self.num_experts,
-                    workspace_size_per_rank=workspace_mb * 1024 * 1024,
+            self.use_low_precision_combine = model_config.use_low_precision_moe_combine
+
+            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+                if self.moe_alltoall_backend == "mnnvllatency":
+                    MnnvlMemory.initialize()
+                    self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
+                        model_config.mapping)
+                    self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
+                        model_config.mapping)
+                elif self.moe_alltoall_backend == "mnnvlthroughput":
+                    workspace_mb = int(
+                        os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "512"))
+                    self.moe_a2a = MoeAlltoAll(
+                        mapping=self.mapping,
+                        max_num_tokens_per_rank=model_config.max_num_tokens,
+                        top_k=self.routing_method.experts_per_token,
+                        num_experts=self.num_experts,
+                        workspace_size_per_rank=workspace_mb * 1024 * 1024,
+                    )
+                else:
+                    raise ValueError(
+                        f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
+                    )
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+                raise NotImplementedError(
+                    "DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet"
                 )
             else:
-                raise ValueError(
-                    f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
+                raise NotImplementedError(
+                    f"Not available alltoall method type: {self.alltoall_method_type!r}"
                 )
 
         # If True, the router weight will be multiplied on the input rather than at the end of FC2
@@ -204,13 +221,38 @@ def has_int8_woq_per_channel(self):
         return self.quant_config.layer_quant_mode.is_int8_weight_only(
         ) and not self.quant_config.layer_quant_mode.has_per_group_scaling()
 
+    def select_alltoall_method_type(self) -> AlltoallMethodType:
+        all2all_method_type = os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD")
+        if all2all_method_type is not None:
+            if AlltoallMethodType[all2all_method_type] in [
+                    AlltoallMethodType.DeepEP,
+                    AlltoallMethodType.DeepEPLowLatency
+            ]:
+                raise NotImplementedError(
+                    "DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet"
+                )
+            return AlltoallMethodType[all2all_method_type]
+
+        if not self.mapping.enable_attention_dp:
+            return AlltoallMethodType.NotEnabled
+
+        if self.mapping.tp_size == 1:
+            return AlltoallMethodType.NotEnabled
+
+        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
+            return AlltoallMethodType.NotEnabled
+
+        if not (self.mapping.moe_ep_size > self.routing_method.experts_per_token
+                and MnnvlMemory.supports_mnnvl()):
+            return AlltoallMethodType.NotEnabled
+
+        return AlltoallMethodType.MNNVL
+
     @cached_property
     def enable_alltoall(self):
-        return (self.mapping.moe_ep_size > self.routing_method.experts_per_token
-                and self.mapping.enable_attention_dp
-                and self.mapping.tp_size > 1
-                and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
-                and MnnvlMemory.supports_mnnvl())
+        """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
+        """
+        return self.alltoall_method_type != AlltoallMethodType.NotEnabled
 
     @cached_property
     def moe_alltoall_backend(self):
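The selection logic above is controlled by environment variables that already appear in this file; a quick, hedged reference sketch (the values shown are examples, except the 512 MiB workspace default read in `__init__`):

```bash
# Force a specific all-to-all method; DeepEP variants raise NotImplementedError
# for CutlassFusedMoE, per select_alltoall_method_type above.
export TRTLLM_FORCE_ALLTOALL_METHOD=MNNVL

# Disable MoE all-to-all even when attention DP, EP size, and MNNVL support
# would otherwise enable it.
export TRTLLM_MOE_DISABLE_ALLTOALLV=1

# Per-rank workspace size in MiB for the "mnnvlthroughput" backend (default 512).
export TRTLLM_MOE_A2A_WORKSPACE_MB=512
```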
@@ -510,6 +552,8 @@ def forward_chunk(
                     ep_rank=self.ep_rank,
                     ep_size=self.ep_size,
                     top_k=top_k,
+                    use_low_precision_combine=self.
+                    use_low_precision_combine,
                     token_count=token_count)
             elif self.moe_alltoall_backend == "mnnvlthroughput":
                 hidden = final_hidden_states.shape[-1]
