Commit fd27001

Merge branch 'main' into users/nzmora/auto-select-moe-kernel-config-main
2 parents e8c21ce + 15de45d commit fd27001

83 files changed: +141 additions and -12682 deletions

Some content is hidden: large commits have some content hidden by default, so only a subset of the 83 changed files is shown below.

cpp/tests/resources/data/test_model_lora_config.json

Lines changed: 0 additions & 30 deletions

@@ -93,36 +93,6 @@
         ],
         "trtllm_modules_to_hf_modules": {}
     },
-    "auto_parallel_config": {
-        "world_size": 1,
-        "gpus_per_node": 8,
-        "cluster_key": "A100-PCIe-80GB",
-        "cluster_info": null,
-        "sharding_cost_model": "alpha_beta",
-        "comm_cost_model": "alpha_beta",
-        "enable_pipeline_parallelism": false,
-        "enable_shard_unbalanced_shape": false,
-        "enable_shard_dynamic_shape": false,
-        "enable_reduce_scatter": true,
-        "builder_flags": null,
-        "debug_mode": false,
-        "infer_shape": true,
-        "validation_mode": false,
-        "same_buffer_io": {
-            "past_key_value_(\\d+)": "present_key_value_\\1"
-        },
-        "same_spec_io": {},
-        "sharded_io_allowlist": [
-            "past_key_value_\\d+",
-            "present_key_value_\\d*"
-        ],
-        "fast_reduce": true,
-        "fill_weights": false,
-        "parallel_config_cache": null,
-        "profile_cache": null,
-        "dump_path": null,
-        "debug_outputs": []
-    },
     "weight_sparsity": false,
     "weight_streaming": false,
     "plugin_config": {

examples/models/core/llama/README.md

Lines changed: 0 additions & 10 deletions

@@ -132,16 +132,6 @@ trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16_wq \
              --output_dir ./tmp/llama/7B/trt_engines/weight_only/1-gpu/ \
              --gemm_plugin auto

-# Build LLaMA 7B using 2-way auto parallelism (deprecated).
-python convert_checkpoint.py --model_dir ./tmp/llama/7B/ \
-                            --output_dir ./tllm_checkpoint_1gpu_fp16 \
-                            --dtype float16
-
-trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16 \
-            --output_dir ./tmp/llama/7B/trt_engines/fp16/2-gpu/ \
-            --gemm_plugin auto \
-            --auto_parallel 2
-
 # Build LLaMA 7B using 2-way tensor parallelism.
 python convert_checkpoint.py --model_dir ./tmp/llama/7B/ \
                             --output_dir ./tllm_checkpoint_2gpu_tp2 \

tensorrt_llm/__init__.py

Lines changed: 0 additions & 3 deletions

@@ -77,7 +77,6 @@ def _preload_python_lib():
                      mpi_barrier, mpi_comm, mpi_rank, mpi_world_size,
                      set_mpi_comm, str_dtype_to_torch, str_dtype_to_trt,
                      torch_dtype_to_trt)
-from .auto_parallel import AutoParallelConfig, auto_parallel
 from .builder import BuildConfig, Builder, BuilderConfig, build
 from .disaggregated_params import DisaggregatedParams
 from .functional import Tensor, constant
@@ -130,8 +129,6 @@ def _preload_python_lib():
     'Module',
     'functional',
     'models',
-    'auto_parallel',
-    'AutoParallelConfig',
     'quantization',
     'tools',
     'LLM',

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 5 additions & 2 deletions

@@ -391,10 +391,13 @@ def validate_parallel_config(self):
        rank to automatically shard the model. This is just to ensure that other objects in the
        runtime that may read parallel_config can do so.
        """
+
+        # Set tp_size = self.world_size so that _ParallelConfig.world_size will return the
+        # correct value (computed as tp_size * pp_size * cp_size). This does not necessarily
+        # mean that TP will actually be used.
         self._parallel_config = _ParallelConfig(
-            auto_parallel=True, gpus_per_node=self.gpus_per_node
+            tp_size=self.world_size, gpus_per_node=self.gpus_per_node
         )
-        self._parallel_config.world_size = self.world_size
         return self

     @model_validator(mode="after")
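For context, a minimal self-contained sketch of the arithmetic the new comment refers to. ParallelConfigSketch below is a hypothetical stand-in, not the real _ParallelConfig; only the tp_size * pp_size * cp_size relationship comes from the diff, the rest is illustrative:

from dataclasses import dataclass

# Hypothetical stand-in for _ParallelConfig, used only to illustrate why
# passing tp_size=self.world_size keeps the derived world_size correct.
@dataclass
class ParallelConfigSketch:
    tp_size: int = 1
    pp_size: int = 1
    cp_size: int = 1
    gpus_per_node: int = 8

    @property
    def world_size(self) -> int:
        # Same relationship as in the diff comment: tp_size * pp_size * cp_size.
        return self.tp_size * self.pp_size * self.cp_size

# With the auto_parallel flag gone, setting tp_size to the requested world size
# (pp/cp left at 1) reproduces that value, even though the auto-deploy backend
# shards the model per rank rather than running classic tensor parallelism.
cfg = ParallelConfigSketch(tp_size=4)
assert cfg.world_size == 4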

tensorrt_llm/_torch/device_mesh.py

Lines changed: 0 additions & 3 deletions

@@ -74,18 +74,15 @@ def moe_ep_group_pg(self):
     # Access rank
     @property
     def tp_rank(self) -> int:
-        assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode."
         return self.tp_group_pg.rank()

     @property
     def pp_rank(self) -> int:
-        assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode."
         return self.pp_group_pg.rank()

     @property
     def cp_rank(self) -> int:
         # TODO: WIP
-        assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode."
         return self.cp_group_pg.rank()

     # Access group ranks

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 28 additions & 21 deletions

@@ -20,6 +20,7 @@
                               MultimodalRuntimeData)
 from tensorrt_llm.inputs.registry import (create_input_processor,
                                           create_input_processor_with_hash)
+from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.lora_manager import LoraModelConfig
@@ -38,7 +39,6 @@
 from ..expert_statistic import ExpertStatistic
 from ..memory_buffer_utils import with_shared_pool
 from ..metadata import KVCacheParams
-from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader
 from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids
 from ..models.modeling_utils import DecoderModelForCausalLM
 from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer,
@@ -52,7 +52,7 @@
 from ..utils import (get_model_extra_attrs,
                      set_per_request_piecewise_cuda_graph_flag,
                      set_torch_compiling, with_model_extra_attrs)
-from .config import PyTorchConfig
+from .config import PyTorchConfig, _construct_checkpoint_loader
 from .config_utils import is_mla
 from .cuda_graph_runner import CUDAGraphRunner
 from .guided_decoder import CapturableGuidedDecoder
@@ -131,29 +131,36 @@ def __init__(
         *,
         model_path: str,
         pytorch_backend_config: PyTorchConfig,
-        checkpoint_loader: BaseCheckpointLoader,
-        batch_size: int = 8,
-        max_beam_width: int = 1,
-        max_num_tokens: int = 8192,
-        max_seq_len: Optional[int] = None,
         mapping: Optional[Mapping] = None,
         attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
         dist: Optional[MPIDist] = None,
         spec_config: Optional["DecodingBaseConfig"] = None,
-        sparse_attention_config: Optional["SparseAttentionConfig"] = None,
-        lora_config: Optional[LoraConfig] = None,
         is_draft_model: bool = False,
         drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
                                                  torch.nn.Module]] = None,
         model: Optional[torch.nn.Module] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
     ):
+        assert llm_args is not None, "llm_args must be provided for PyTorchModelEngine"
+
         self.forward_pass_callable = None
         self.ub_buffers = None
-        self.batch_size = batch_size
+        (
+            max_beam_width,
+            max_num_tokens,
+            max_seq_len,
+            max_batch_size,
+        ) = llm_args.get_runtime_sizes()
+
+        self.batch_size = max_batch_size
         self.max_num_tokens = max_num_tokens
         self.max_seq_len = max_seq_len
         self.max_beam_width = max_beam_width

+        checkpoint_loader = _construct_checkpoint_loader(
+            llm_args.backend, llm_args.checkpoint_loader,
+            llm_args.checkpoint_format)
+
         self.mapping = mapping
         if mapping.has_pp():
             init_pp_comm(mapping)
@@ -171,7 +178,7 @@ def __init__(
             spec_config.max_total_draft_tokens = 0
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
-        self.sparse_attention_config = sparse_attention_config
+        self.sparse_attention_config = None if is_draft_model else llm_args.sparse_attention_config
         self.enable_spec_decode = self.is_spec_decode
         self.is_draft_model = is_draft_model

@@ -181,13 +188,15 @@ def __init__(
         self.input_processor_with_hash = create_input_processor_with_hash(
             self.input_processor)
         if model is None:
+            lora_config: Optional[
+                LoraConfig] = None if is_draft_model else llm_args.lora_config
             loader = ModelLoader(
                 pytorch_backend_config=pytorch_backend_config,
                 mapping=self.mapping,
                 spec_config=self.spec_config,
                 sparse_attention_config=self.sparse_attention_config,
-                max_num_tokens=max_num_tokens,
-                max_seq_len=max_seq_len,
+                max_num_tokens=self.max_num_tokens,
+                max_seq_len=self.max_seq_len,
                 lora_config=lora_config,
             )
             self.model, moe_load_balancer = loader.load(
@@ -273,29 +282,27 @@ def __init__(

         self.attn_backend = get_attention_backend(
             pytorch_backend_config.attn_backend,
-            sparse_attn_config=sparse_attention_config)
+            sparse_attn_config=self.sparse_attention_config)

         if self.is_spec_decode:
             self.spec_metadata = None
             update_spec_config_from_model_config(self.spec_config,
                                                  self.model.config)
-            max_num_draft_tokens = self.original_max_total_draft_tokens * batch_size
+            max_num_draft_tokens = self.original_max_total_draft_tokens * self.batch_size
             self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
                                                  dtype=torch.int,
                                                  device='cuda')
             self.gather_ids_cuda = torch.empty((self.max_num_tokens, ),
                                                dtype=torch.int,
                                                device='cuda')
-            self.num_accepted_draft_tokens_cuda = torch.empty((batch_size, ),
-                                                              dtype=torch.int,
-                                                              device='cuda')
+            self.num_accepted_draft_tokens_cuda = torch.empty(
+                (self.batch_size, ), dtype=torch.int, device='cuda')
             self.previous_pos_indices_cuda = torch.empty(
                 (self.max_num_tokens, ), dtype=torch.int, device='cuda')
             self.previous_pos_id_offsets_cuda = torch.zeros(
                 (self.max_num_tokens, ), dtype=torch.int, device='cuda')
-            self.previous_kv_lens_offsets_cuda = torch.zeros((batch_size, ),
-                                                             dtype=torch.int,
-                                                             device='cuda')
+            self.previous_kv_lens_offsets_cuda = torch.zeros(
+                (self.batch_size, ), dtype=torch.int, device='cuda')
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
             ) or self.model_is_wrapped
             self.max_draft_len = spec_config.max_draft_len
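To make the constructor change concrete, here is a minimal, self-contained sketch of the contract PyTorchModelEngine now relies on. LlmArgsSketch is a hypothetical stand-in for TorchLlmArgs: only get_runtime_sizes() and the order in which its result is unpacked come from the diff above; the defaults mirror the removed keyword defaults (batch_size=8, max_beam_width=1, max_num_tokens=8192, max_seq_len=None), and everything else is illustrative:

from typing import NamedTuple, Optional

class RuntimeSizes(NamedTuple):
    max_beam_width: int
    max_num_tokens: int
    max_seq_len: Optional[int]
    max_batch_size: int

class LlmArgsSketch:
    """Hypothetical stand-in for TorchLlmArgs, illustrating that one object now
    carries the runtime limits plus the LoRA/sparse-attention/checkpoint settings
    that used to be separate constructor arguments."""

    def __init__(self, max_beam_width: int = 1, max_num_tokens: int = 8192,
                 max_seq_len: Optional[int] = None, max_batch_size: int = 8):
        self._sizes = RuntimeSizes(max_beam_width, max_num_tokens,
                                   max_seq_len, max_batch_size)
        self.lora_config = None
        self.sparse_attention_config = None

    def get_runtime_sizes(self) -> RuntimeSizes:
        return self._sizes

# The engine unpacks the sizes in exactly this order (see the diff above).
args = LlmArgsSketch(max_batch_size=16)
max_beam_width, max_num_tokens, max_seq_len, max_batch_size = args.get_runtime_sizes()
assert max_batch_size == 16

As the diff also shows, draft engines reuse the same llm_args but force lora_config and sparse_attention_config to None through the is_draft_model checks.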

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 4 additions & 23 deletions

@@ -32,7 +32,7 @@
 from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
                     create_py_executor_instance, instantiate_sampler, is_mla,
                     validate_feature_combination)
-from .config import PyTorchConfig, _construct_checkpoint_loader
+from .config import PyTorchConfig
 from .config_utils import is_mla
 from .guided_decoder import CapturableGuidedDecoder, GuidedDecoder
 from .kv_cache_connector import KvCacheConnectorManager
@@ -234,11 +234,6 @@ def create_py_executor(
     mm_encoder_only = llm_args.mm_encoder_only
     enable_chunked_context = llm_args.enable_chunked_prefill

-    assert llm_args.backend == "pytorch", "_construct_checkpoint_loader expects different parameters for autodeploy"
-    checkpoint_loader = _construct_checkpoint_loader(llm_args.backend,
-                                                     llm_args.checkpoint_loader,
-                                                     llm_args.checkpoint_format)
-
     (
         max_beam_width,
         max_num_tokens,
@@ -305,8 +300,6 @@
     has_draft_model_engine = spec_config.spec_dec_mode.has_draft_model()
     has_spec_drafter = spec_config.spec_dec_mode.has_spec_drafter()

-    sparse_attention_config = llm_args.sparse_attention_config
-
     # chunk_unit_size may be changed to 64 when using flash mla
     attn_runtime_features = AttentionRuntimeFeatures(
         chunked_prefill=enable_chunked_context,
@@ -322,17 +315,11 @@
     model_engine = PyTorchModelEngine(
         model_path=checkpoint_dir,
         pytorch_backend_config=pytorch_backend_config,
-        batch_size=max_batch_size,
-        max_beam_width=max_beam_width,
-        max_num_tokens=max_num_tokens,
-        max_seq_len=max_seq_len,
         mapping=mapping,
         attn_runtime_features=attn_runtime_features,
         dist=dist,
         spec_config=spec_config,
-        sparse_attention_config=sparse_attention_config,
-        lora_config=lora_config,
-        checkpoint_loader=checkpoint_loader,
+        llm_args=llm_args,
     )

     validate_feature_combination(llm_args, model_engine,
@@ -369,19 +356,13 @@ def drafting_loop_wrapper(model):
         draft_model_engine = PyTorchModelEngine(
             model_path=spec_config.speculative_model_dir,
             pytorch_backend_config=draft_pytorch_backend_config,
-            batch_size=max_batch_size,
-            max_beam_width=max_beam_width,
-            max_num_tokens=max_num_tokens,
-            # Note: The draft model engine will infer its own max_seq_len.
-            # We'll stop drafting when we hit the max.
-            max_seq_len=max_seq_len,
             mapping=mapping,
             attn_runtime_features=attn_runtime_features,
             dist=dist,
             spec_config=draft_spec_config,
-            checkpoint_loader=checkpoint_loader,
             is_draft_model=True,
             drafting_loop_wrapper=drafting_loop_wrapper,
+            llm_args=llm_args,
         )
         # For DeepseekV3 MTP, we need to set the num_hidden_layers to 1 for the draft model
         if spec_config.spec_dec_mode.is_mtp_eagle():
@@ -574,7 +555,7 @@ def drafting_loop_wrapper(model):
         pytorch_backend_config=pytorch_backend_config,
         speculative_config=spec_config,
         profiling_stage_data=profiling_stage_data,
-        sparse_attention_config=sparse_attention_config,
+        sparse_attention_config=llm_args.sparse_attention_config,
     )
     estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
     with mem_monitor.observe_creation_stage(

tensorrt_llm/_utils.py

Lines changed: 0 additions & 33 deletions

@@ -25,7 +25,6 @@
 import trace
 import weakref
 from contextlib import contextmanager
-from dataclasses import asdict
 from enum import EnumMeta
 from functools import lru_cache, partial, wraps
 from pathlib import Path
@@ -799,38 +798,6 @@ def localtrace(frame, why, arg):
     return wrapper


-class DictConversion:
-
-    @classmethod
-    def from_dict(cls, config: Dict[str, Any]):
-        obj = cls()
-        fields = obj.__dataclass_fields__
-        for key, value in config.items():
-            assert hasattr(obj, key), f"cannot find {key} in {obj}"
-            field_cls = fields[key].type
-            if (isinstance(field_cls, type)
-                    and issubclass(field_cls, DictConversion)
-                    and isinstance(value, dict)):
-                value = field_cls.from_dict(value)
-            setattr(obj, key, value)
-        return obj
-
-    def to_dict(self):
-        return asdict(self)
-
-    @classmethod
-    def from_json_file(cls, file):
-        with open(file) as f:
-            return cls.from_dict(json.load(f))
-
-    def set_defaults(self, **kwargs):
-        for key, default in kwargs.items():
-            value = getattr(self, key)
-            if (value is None
-                    or (isinstance(value, (list, dict)) and len(value) == 0)):
-                setattr(self, key, default)
-
-
 class BaseEnumMeta(EnumMeta):

     def __contains__(cls, item):

tensorrt_llm/auto_parallel/__init__.py

Lines changed: 0 additions & 9 deletions
This file was deleted.
