
Commit 33dbcd7

Support MistralLarge3 model
Signed-off-by: bhsueh <[email protected]>
1 parent 974ad56 commit 33dbcd7

30 files changed: +1672 −143 lines

examples/llm-api/quickstart_advanced.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -23,6 +23,10 @@ def add_llm_args(parser):
                         type=str,
                         nargs="+",
                         help="A single or a list of text prompts.")
+    parser.add_argument('--checkpoint_format',
+                        type=str,
+                        default=None,
+                        help="Model checkpoint format.")
     # Build config
     parser.add_argument("--max_seq_len",
                         type=int,
@@ -237,6 +241,7 @@ def setup_llm(args, **kwargs):
     llm = LLM(
         model=args.model_dir,
         backend='pytorch',
+        checkpoint_format=args.checkpoint_format,
         disable_overlap_scheduler=args.disable_overlap_scheduler,
         kv_cache_config=kv_cache_config,
         attn_backend=args.attention_backend,
```
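For readers who call the LLM API directly rather than through the quickstart script, the new flag simply forwards to the `LLM` constructor. A minimal sketch of equivalent direct usage, assuming the remaining constructor arguments keep their defaults (the checkpoint path below is a placeholder):

```python
from tensorrt_llm import LLM

# Placeholder path; checkpoint_format mirrors the new --checkpoint_format CLI flag.
llm = LLM(
    model="/path/to/mistral_large_3",
    backend="pytorch",
    checkpoint_format="mistral_large_3",
)
```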
Lines changed: 165 additions & 0 deletions
# Mistral Large V3

* Set up the model paths

```bash
export mistral_large_3_model_path=<mistral_large_3_model_path>
export mistral_large_3_eagle_model_path=<mistral_large_3_eagle_model_path>
```

## LLM-only run

* Run Mistral Large V3 with `quickstart_advanced.py`:

```bash
mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickstart_advanced.py \
    --model_dir ${mistral_large_3_model_path} \
    --tp_size 4 \
    --moe_ep_size 4 \
    --max_tokens 100 \
    --checkpoint_format mistral_large_3 \
    --kv_cache_fraction 0.25 \
    --moe_backend TRTLLM # optional
```

* Run Mistral Large V3 with `quickstart_advanced.py` and Eagle3 speculative decoding:

```bash
mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickstart_advanced.py \
    --model_dir ${mistral_large_3_model_path} \
    --tp_size 4 \
    --moe_ep_size 4 \
    --max_tokens 10 \
    --checkpoint_format mistral_large_3 \
    --kv_cache_fraction 0.25 \
    --disable_kv_cache_reuse \
    --spec_decode_algo EAGLE3 \
    --spec_decode_max_draft_len 1 \
    --use_one_model \
    --draft_model_dir ${mistral_large_3_eagle_model_path} \
    --moe_backend TRTLLM \
    --print_iter_log \
    2>&1 | tee debug.log
```

* Launch trtllm-serve and send a request:

```bash
echo "
backend: pytorch
tensor_parallel_size: 4
moe_expert_parallel_size: 4
enable_attention_dp: false
kv_cache_config:
  free_gpu_memory_fraction: 0.25
  enable_block_reuse: true
checkpoint_format: mistral_large_3
" > serve.yml
mpirun -n 1 --allow-run-as-root --oversubscribe python3 -m tensorrt_llm.commands.serve serve \
    ${mistral_large_3_model_path} \
    --host localhost --port 8001 --backend pytorch \
    --extra_llm_api_options serve.yml \
    --tokenizer ${mistral_large_3_model_path} \
    2>&1 | tee serve_debug.log &

curl http://localhost:8001/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "${mistral_large_3_model_path}",
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "top_k": 16
    }'

# The result looks like:
{"id":"cmpl-7e342c1d722d4226a1bf3ed35d762c35","object":"text_completion","created":1764061351,"model":"${mistral_large_3_model_path}","choices":[{"index":0,"text":"The capital of France is **Paris**.\n\nParis is the largest city in France and","token_ids":null,"logprobs":null,"context_logits":null,"finish_reason":"length","stop_reason":null,"disaggregated_params":null,"avg_decoded_tokens_per_iter":1.0}],"usage":{"prompt_tokens":7,"total_tokens":23,"completion_tokens":16,"prompt_tokens_details":{"cached_tokens":1}},"prompt_token_ids":null}
```

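Since the endpoint is OpenAI-compatible, the same completion request can also be issued from Python. A minimal sketch, assuming the server above is running on port 8001 and the third-party `requests` package is installed (the model path is a placeholder and must match what the server was launched with):

```python
import requests

# Placeholder; use the same path that was passed to trtllm-serve.
model_path = "/path/to/mistral_large_3"

resp = requests.post(
    "http://localhost:8001/v1/completions",
    json={
        "model": model_path,
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "top_k": 16,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```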
* Launch trtllm-serve with Eagle3 and send a request:

```bash
echo "
backend: pytorch
tensor_parallel_size: 4
moe_expert_parallel_size: 4
enable_attention_dp: false
kv_cache_config:
  free_gpu_memory_fraction: 0.25
  enable_block_reuse: true
checkpoint_format: mistral_large_3
speculative_config:
  decoding_type: Eagle
  max_draft_len: 1
  speculative_model_dir: ${mistral_large_3_eagle_model_path}
  eagle3_one_model: true
" > serve.yml
mpirun -n 1 --allow-run-as-root --oversubscribe python3 -m tensorrt_llm.commands.serve serve \
    ${mistral_large_3_model_path} \
    --host localhost --port 8001 --backend pytorch \
    --extra_llm_api_options serve.yml \
    --tokenizer ${mistral_large_3_model_path} \
    2>&1 | tee serve_debug.log &

curl http://localhost:8001/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "${mistral_large_3_model_path}",
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "top_k": 16
    }'
```

## How to use the modules

The following sections show how to use the individual Mistral Large V3 modules directly.

```python
from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3ForCausalLM
from tensorrt_llm._torch.models.modeling_mistral import Mistral3VLM
from tensorrt_llm.llmapi.tokenizer import MistralTokenizer
from tensorrt_llm._torch.models.checkpoints.mistral.checkpoint_loader import MistralCheckpointLoader
from tensorrt_llm._torch.models.checkpoints.mistral.weight_mapper import MistralLarge3WeightMapper
from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import MistralConfigLoader
```

### Tokenizer

```python
mtok = MistralTokenizer.from_pretrained(TOKENIZER_DIR)
```

### Config and model instance

```python
config_loader = MistralConfigLoader()
config = config_loader.load(MODEL_DIR)

model = Mistral3VLM(model_config=config)
assert isinstance(model.llm, DeepseekV3ForCausalLM)
```

### Checkpoint loading

```python
weight_mapper = MistralLarge3WeightMapper()
loader = MistralCheckpointLoader(weight_mapper=weight_mapper)

weights_dict = loader.load_weights(MODEL_DIR)
```

### Weight loading

#### E2E

```python
model.load_weights(weights_dict, weight_mapper=weight_mapper)  # target usage
```

#### By module

```python
def _filter_weights(weights, prefix):
    return {
        name[len(prefix):]: weight
        for name, weight in weights.items() if name.startswith(prefix)
    }

llm_weights = weight_mapper.rename_by_params_map(
    params_map=weight_mapper.mistral_llm_mapping,
    weights=_filter_weights(weights_dict, "language_model."))
model.llm.load_weights(llm_weights, weight_mapper=weight_mapper)
```

requirements.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -73,3 +73,4 @@ nvidia-cutlass-dsl==4.3.1; python_version >= "3.10"
 plotly
 numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
 partial_json_parser
+mistral-common
```

tensorrt_llm/_torch/model_config.py

Lines changed: 61 additions & 36 deletions
```diff
@@ -312,21 +312,39 @@ def load_hf_quant_config(hf_quant_config, moe_backend):
     layer_quant_config = None

     # DeepSeek V3 FP8 ckpt
-    if hf_quant_config.get("quant_method") == "fp8" and hf_quant_config.get(
-            "weight_block_size", []):
-        quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
-        if moe_backend == 'TRTLLM':
-            # TODO: This is a hack. Remove after fp8 bmm is integrated.
-            quant_config.exclude_modules = [
-                "*kv_b_proj*", "*k_b_proj*", "*eh_proj"
-            ]
-        else:
-            quant_config.exclude_modules = ["*eh_proj"]
+    if hf_quant_config.get("quant_method") == "fp8":
+        if hf_quant_config.get("weight_block_size", []):
+            quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
+            if moe_backend == 'TRTLLM':
+                # TODO: This is a hack. Remove after fp8 bmm is integrated.
+                quant_config.exclude_modules = [
+                    "*kv_b_proj*", "*k_b_proj*", "*eh_proj"
+                ]
+            else:
+                quant_config.exclude_modules = ["*eh_proj"]
+
+            block_size = hf_quant_config.get("weight_block_size", [])
+            assert tuple(block_size) == (
+                128,
+                128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
+            quant_config.group_size = block_size[0]
+
+        # DeepSeek V3 FP8 per tensor hack
+        elif hf_quant_config.get("activation_scheme", None) == "static":
+            logger.debug(f"Expanding weight scale to mimic DS FP8 recipe")
+            quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
+            if moe_backend == 'TRTLLM':
+                # TODO: This is a hack. Remove after fp8 bmm is integrated.
+                quant_config.exclude_modules = [
+                    "*kv_b_proj*", "*k_b_proj*", "*eh_proj"
+                ]
+            else:
+                quant_config.exclude_modules = ["*eh_proj"]
+
+            block_size = (128, 128)
+            quant_config.group_size = block_size[0]
+        logger.info(f"quant_config: {quant_config}")

-        block_size = hf_quant_config.get("weight_block_size", [])
-        assert tuple(block_size) == (
-            128, 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
-        quant_config.group_size = block_size[0]
     # MXFP4 checkpoints.
     elif hf_quant_config.get("quant_method") == "mxfp4":
         quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
@@ -394,44 +412,51 @@ def override_quant_algo():
     @classmethod
     def from_pretrained(cls,
                         checkpoint_dir: str,
+                        pretrained_hf_config=None,
                         trust_remote_code=False,
                         **kwargs):
         # Use file lock to prevent race conditions when multiple processes
         # try to import/cache the same remote model config file
         with config_file_lock():
             # When handling the case where model_format is TLLM_ENGINE
             # send cyclic requests to the NONE URL.
-            if checkpoint_dir is not None:
+            if checkpoint_dir is not None and pretrained_hf_config is not None:
+                logger.warning(
+                    f"Both checkpoint_dir and pretrained config specified. Using pretrained_config."
+                )
+
+            if pretrained_hf_config is not None:
+                pretrained_config = pretrained_hf_config
+            elif checkpoint_dir is not None:
                 pretrained_config = load_pretrained_config(
                     checkpoint_dir,
                     trust_remote_code=trust_remote_code,
                     **kwargs,
                 )
-                if pretrained_config.architectures[
-                        0] == "DeepseekV32ForCausalLM":
-                    sparse_attention_config = kwargs.get(
-                        'sparse_attention_config')
-                    if sparse_attention_config:
-                        index_n_heads = sparse_attention_config.index_n_heads or pretrained_config.index_n_heads
-                        index_head_dim = sparse_attention_config.index_head_dim or pretrained_config.index_head_dim
-                        index_topk = sparse_attention_config.index_topk or pretrained_config.index_topk
-                        indexer_max_chunk_size = sparse_attention_config.indexer_max_chunk_size
-                    else:
-                        index_n_heads = pretrained_config.index_n_heads
-                        index_head_dim = pretrained_config.index_head_dim
-                        index_topk = pretrained_config.index_topk
-                        indexer_max_chunk_size = None
-                    kwargs[
-                        'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
-                            index_n_heads=index_n_heads,
-                            index_head_dim=index_head_dim,
-                            index_topk=index_topk,
-                            indexer_max_chunk_size=indexer_max_chunk_size)
             else:
                 raise ValueError(
-                    "checkpoint_dir is None. Cannot load model config without a valid checkpoint directory."
+                    "checkpoint_dir is None and pretrained config is not specified. Cannot load model config without a valid checkpoint directory or a pretrained config."
                 )

+            if pretrained_config.architectures[0] == "DeepseekV32ForCausalLM":
+                sparse_attention_config = kwargs.get('sparse_attention_config')
+                if sparse_attention_config:
+                    index_n_heads = sparse_attention_config.index_n_heads or pretrained_config.index_n_heads
+                    index_head_dim = sparse_attention_config.index_head_dim or pretrained_config.index_head_dim
+                    index_topk = sparse_attention_config.index_topk or pretrained_config.index_topk
+                    indexer_max_chunk_size = sparse_attention_config.indexer_max_chunk_size
+                else:
+                    index_n_heads = pretrained_config.index_n_heads
+                    index_head_dim = pretrained_config.index_head_dim
+                    index_topk = pretrained_config.index_topk
+                    indexer_max_chunk_size = None
+                kwargs[
+                    'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
+                        index_n_heads=index_n_heads,
+                        index_head_dim=index_head_dim,
+                        index_topk=index_topk,
+                        indexer_max_chunk_size=indexer_max_chunk_size)
+
         # Get cached file from path or repo id, return None if not exists.
         def cached_file(path_or_repo_id, file_name):
             try:
```
tensorrt_llm/_torch/models/checkpoints/__init__.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -12,9 +12,16 @@
 from .hf.qwen3_next_weight_mapper import Qwen3NextHfWeightMapper
 from .hf.weight_loader import HfWeightLoader
 from .hf.weight_mapper import HfWeightMapper
+from .mistral.checkpoint_loader import (MistralCheckpointLoader,
+                                        MistralLarge3CheckpointLoader)
+from .mistral.config_loader import MistralConfigLoader
+from .mistral.weight_mapper import (MistralLarge3WeightMapper,
+                                    MistralWeightMapper)

 __all__ = [
     "HfConfigLoader", "HfWeightLoader", "HfWeightMapper",
+    "MistralLarge3CheckpointLoader", "MistralCheckpointLoader",
+    "MistralConfigLoader", "MistralWeightMapper", "MistralLarge3WeightMapper",
     "BaseCheckpointLoader", "HfCheckpointLoader", "NemotronHHfWeightMapper",
     "Gemma3HfWeightMapper", "MixtralHfWeightMapper", "Llama4HfWeightMapper",
     "Qwen2MoeHfWeightMapper", "Qwen3MoeHfWeightMapper", "Qwen2VLHfWeightMapper",
```

tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -19,6 +19,8 @@


 @register_checkpoint_weight_loader("HF")
+@register_checkpoint_weight_loader("mistral")
+@register_checkpoint_weight_loader("mistral_large_3")
 class HfWeightLoader(BaseWeightLoader):
     """
     Loads weights from SafeTensors/bin/pth files.
```
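These registrations are what let `--checkpoint_format mistral` or `mistral_large_3` resolve to `HfWeightLoader`. The internals of `register_checkpoint_weight_loader` are not part of this commit; the following is only a sketch of how such a string-keyed decorator registry typically works, with a hypothetical `_WEIGHT_LOADERS` dict:

```python
# Hypothetical sketch; the real register_checkpoint_weight_loader in
# TensorRT-LLM may be implemented differently.
_WEIGHT_LOADERS: dict[str, type] = {}


def register_checkpoint_weight_loader(fmt: str):
    def decorator(cls: type) -> type:
        _WEIGHT_LOADERS[fmt] = cls  # e.g. "mistral_large_3" -> HfWeightLoader
        return cls
    return decorator


@register_checkpoint_weight_loader("HF")
@register_checkpoint_weight_loader("mistral")
@register_checkpoint_weight_loader("mistral_large_3")
class HfWeightLoader:
    ...


assert _WEIGHT_LOADERS["mistral_large_3"] is HfWeightLoader
```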

tensorrt_llm/_torch/models/checkpoints/mistral/__init__.py

Whitespace-only changes.
