Commit b818a91

[https://nvbugs/5540752][fix] Support quantized Phi4 MM models (NVIDIA#8190)
Signed-off-by: Pamela <[email protected]>
1 parent 18c7a52 commit b818a91

File tree

14 files changed: 289 additions & 68 deletions

tensorrt_llm/_torch/models/modeling_phi3.py

Lines changed: 59 additions & 22 deletions

@@ -217,26 +217,46 @@ def filter_weights(prefix: str, weights: dict):
                 if "self_attn.qkv_proj" in name:
                     # The weights need to be split correctly before sharding to support tp_size >1.
                     qkv_weight = module_weights['weight'][:]
-                    q_weight = qkv_weight[:hidden_size, :]
-                    k_weight = qkv_weight[hidden_size:hidden_size +
-                                          num_kv_heads * head_dim, :]
-                    v_weight = qkv_weight[hidden_size +
-                                          num_kv_heads * head_dim:, :]
+                    qk_split_index = hidden_size
+                    kv_split_index = hidden_size + num_kv_heads * head_dim
+
+                    q_dict = {'weight': qkv_weight[:qk_split_index, :]}
+                    k_dict = {
+                        'weight':
+                        qkv_weight[qk_split_index:kv_split_index, :]
+                    }
+                    v_dict = {'weight': qkv_weight[kv_split_index:, :]}
 
                     # Get the scale factor for the fused QKV projection
                     qkv_scale = module_weights.get('weight_scale', None)
 
-                    q_dict = {'weight': q_weight}
-                    if qkv_scale is not None:
-                        q_dict['weight_scale'] = qkv_scale
-
-                    k_dict = {'weight': k_weight}
                     if qkv_scale is not None:
-                        k_dict['weight_scale'] = qkv_scale  # Use same scale
-
-                    v_dict = {'weight': v_weight}
-                    if qkv_scale is not None:
-                        v_dict['weight_scale'] = qkv_scale  # Use same scale
+                        if qkv_scale.shape and qkv_scale.shape[
+                                0] == qkv_weight.shape[0]:
+                            q_dict[
+                                'weight_scale'] = qkv_scale[:
+                                                            qk_split_index, :]
+                            k_dict['weight_scale'] = qkv_scale[
+                                qk_split_index:kv_split_index, :]
+                            v_dict['weight_scale'] = qkv_scale[
+                                kv_split_index:, :]
+                        else:  # use same scale
+                            q_dict['weight_scale'] = qkv_scale
+                            k_dict['weight_scale'] = qkv_scale
+                            v_dict['weight_scale'] = qkv_scale
+
+                    input_scale = module_weights.get('input_scale', None)
+                    if input_scale is not None:
+                        q_dict['input_scale'] = input_scale
+                        k_dict['input_scale'] = input_scale
+                        v_dict['input_scale'] = input_scale
+
+                    weight_scale_2 = module_weights.get(
+                        'weight_scale_2', None)
+                    if weight_scale_2 is not None:
+                        q_dict['weight_scale_2'] = weight_scale_2
+                        k_dict['weight_scale_2'] = weight_scale_2
+                        v_dict['weight_scale_2'] = weight_scale_2
 
                     module.load_weights(weights=[q_dict, k_dict, v_dict])
                 elif "mlp.gate_up_proj" in name:

@@ -246,16 +266,33 @@ def filter_weights(prefix: str, weights: dict):
                     gate_weight = gate_up_weight[:intermediate_size, :]
                     up_weight = gate_up_weight[intermediate_size:, :]
 
-                    # Get the scale factors if they exist
-                    gate_up_scale = module_weights.get('weight_scale', None)
-
                     gate_dict = {'weight': gate_weight}
-                    if gate_up_scale is not None:
-                        gate_dict['weight_scale'] = gate_up_scale
-
                     up_dict = {'weight': up_weight}
+
+                    # Get the scale factors if they exist
+                    gate_up_scale = module_weights.get('weight_scale', None)
                     if gate_up_scale is not None:
-                        up_dict['weight_scale'] = gate_up_scale
+                        if gate_up_scale.shape and gate_up_scale.shape[
+                                0] == gate_up_weight.shape[0]:
+                            gate_dict[
+                                'weight_scale'] = gate_up_scale[:
+                                                                intermediate_size, :]
+                            up_dict['weight_scale'] = gate_up_scale[
+                                intermediate_size:, :]
+                        else:  # use same scale
+                            gate_dict['weight_scale'] = gate_up_scale
+                            up_dict['weight_scale'] = gate_up_scale
+
+                    input_scale = module_weights.get('input_scale', None)
+                    if input_scale is not None:
+                        gate_dict['input_scale'] = input_scale
+                        up_dict['input_scale'] = input_scale
+
+                    weight_scale_2 = module_weights.get(
+                        'weight_scale_2', None)
+                    if weight_scale_2 is not None:
+                        gate_dict['weight_scale_2'] = weight_scale_2
+                        up_dict['weight_scale_2'] = weight_scale_2
 
                     module.load_weights(weights=[gate_dict, up_dict])
                 else:
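
The scale.shape and scale.shape[0] == weight.shape[0] guard in both hunks is what tells the two layouts apart: a per-tensor scale (as in FP8 checkpoints) is a 0-dim tensor whose empty torch.Size([]) is falsy, while per-channel or block scales carry the weight's leading dimension. A quick illustrative check with toy tensors, not checkpoint data:

import torch

weight = torch.randn(16, 8)
per_tensor_scale = torch.tensor(0.02)   # 0-dim: bool(torch.Size([])) is False
per_channel_scale = torch.rand(16, 1)   # leading dim matches weight.shape[0]

for scale in (per_tensor_scale, per_channel_scale):
    if scale.shape and scale.shape[0] == weight.shape[0]:
        print("per-channel: slice along dim 0 together with the weight")
    else:
        print("per-tensor: share the same scale across all splits")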

tensorrt_llm/_torch/models/modeling_phi4mm.py

Lines changed: 11 additions & 8 deletions

@@ -88,10 +88,10 @@ def _load_phi4mm_classes(local_path):
     # Add parent folder to sys.path to enable relative import.
     original_sys_path = sys.path.copy()
     package_folder = Path(local_path)
+    package_name = package_folder.name
     parent_folder = str(package_folder.parent)
     if parent_folder not in sys.path:
         sys.path.insert(0, parent_folder)
-
     try:
         # Import Phi4MMConfig from configuration_phi4mm.py.
         config_path = os.path.join(local_path, 'configuration_phi4mm.py')

@@ -111,8 +111,7 @@ def _load_phi4mm_classes(local_path):
     # `Phi-4-multimodal-instruct` as the package name to avoid relative import errors.
     # `hf_modeling_phi4mm` as the module name to avoid name conflicts.
     spec = importlib.util.spec_from_file_location(
-        "Phi-4-multimodal-instruct.hf_modeling_phi4mm",
-        modeling_phi4mm_path)
+        f"{package_name}.hf_modeling_phi4mm", modeling_phi4mm_path)
     hf_modeling_phi4mm = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(hf_modeling_phi4mm)
     Phi4MMAudioEmbedding = hf_modeling_phi4mm.Phi4MMAudioEmbedding
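
The hard-coded "Phi-4-multimodal-instruct" prefix broke as soon as the checkpoint folder carried a different name, e.g. an -FP8 or -FP4 suffix. Below is a minimal sketch of the generalized pattern with hypothetical names (load_module_from, my_module); like the surrounding code, it assumes the checkpoint's parent folder has already been put on sys.path so relative imports inside the loaded module can resolve:

import importlib.util
from pathlib import Path

def load_module_from(package_folder: str, file_name: str):
    # Derive the package prefix from the actual folder name,
    # e.g. "Phi-4-multimodal-instruct-FP8".
    package_name = Path(package_folder).name
    module_path = Path(package_folder) / file_name
    spec = importlib.util.spec_from_file_location(
        f"{package_name}.my_module", module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module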

@@ -989,12 +988,16 @@ def load_weights(self, weights):
         weights = {k: v for k, v in weights.items() if '.lora_' not in k}
         # Rename base layer weights.
         updated_weights = {}
+        base_layer_weight_names = [
+            'weight', 'input_scale', 'weight_scale', 'weight_scale_2'
+        ]
         for k in weights.keys():
-            if 'base_layer.weight' in k:
-                new_k = k.replace('base_layer.weight', 'weight')
-                updated_weights[new_k] = weights[k]
-            else:
-                updated_weights[k] = weights[k]
+            new_k = k
+            for weight_name in base_layer_weight_names:
+                if f'base_layer.{weight_name}' in k:
+                    new_k = k.replace(f'base_layer.{weight_name}', weight_name)
+                    break
+            updated_weights[new_k] = weights[k]
         weights = updated_weights
         self.llm.load_weights(weights)
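
Quantized checkpoints store scale tensors under the same base_layer. prefix that LoRA wraps around the base weights, so the rename now covers all four tensor names rather than base_layer.weight alone. A toy run of the same loop with made-up keys:

base_layer_weight_names = [
    'weight', 'input_scale', 'weight_scale', 'weight_scale_2'
]

weights = {
    'llm.layers.0.qkv_proj.base_layer.weight': 'W',
    'llm.layers.0.qkv_proj.base_layer.weight_scale': 'S',
    'llm.layers.0.qkv_proj.weight_scale_2': 'unchanged',
}

updated_weights = {}
for k, v in weights.items():
    new_k = k
    for weight_name in base_layer_weight_names:
        # Note: 'base_layer.weight' also matches 'base_layer.weight_scale',
        # but the replacement preserves the suffix, so the result is still
        # the correct 'weight_scale' key.
        if f'base_layer.{weight_name}' in k:
            new_k = k.replace(f'base_layer.{weight_name}', weight_name)
            break
    updated_weights[new_k] = v

# -> {'llm.layers.0.qkv_proj.weight': 'W',
#     'llm.layers.0.qkv_proj.weight_scale': 'S',
#     'llm.layers.0.qkv_proj.weight_scale_2': 'unchanged'}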

tensorrt_llm/inputs/utils.py

Lines changed: 3 additions & 3 deletions

@@ -580,10 +580,10 @@ def convert_to_conversation_message(
                 # Check if mdata is a MultimodalData
                 if isinstance(mdata,
                               dict) and "modality" in mdata and "data" in mdata:
-                    modality = mdata["modality"]
+                    mdata_modality = mdata["modality"]
                     if modality == "multiple_image":
-                        modality = "image"
-                        mm_data_tracker.add_data(modality, mdata["data"])
+                        mdata_modality = "image"
+                    mm_data_tracker.add_data(mdata_modality, mdata["data"])
                 else:
                     # Add embeddings to the tracker for placeholder handling
                     mm_data_tracker.add_data(mdata["modality"],

The per-item modality now lives in its own variable, so the enclosing modality value (which drives the "multiple_image" check) is no longer clobbered, and each item is registered with the tracker unconditionally instead of only in the "multiple_image" branch.

tests/integration/defs/accuracy/references/gsm8k.yaml

Lines changed: 4 additions & 0 deletions

@@ -187,6 +187,10 @@ mistralai/Mistral-Small-3.1-24B-Instruct-2503:
     accuracy: 89.23
 microsoft/Phi-4-multimodal-instruct:
   - accuracy: 81.19
+  - quant_algo: FP8
+    accuracy: 80.82
+  - quant_algo: NVFP4
+    accuracy: 69.33
 microsoft/Phi-4-multimodal-instruct-long-rope:
   - accuracy: 75.85
 microsoft/Phi-4-mini-instruct:

tests/integration/defs/accuracy/references/mmlu.yaml

Lines changed: 4 additions & 0 deletions

@@ -294,6 +294,10 @@ mistralai/Ministral-8B-Instruct-2410:
     accuracy: 65.96
 microsoft/Phi-4-multimodal-instruct:
   - accuracy: 69.69
+  - quant_algo: FP8
+    accuracy: 68.86
+  - quant_algo: NVFP4
+    accuracy: 64.04
 microsoft/Phi-4-multimodal-instruct-long-rope:
   - accuracy: 65.98
 microsoft/phi-4:

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 18 additions & 0 deletions

@@ -3314,6 +3314,24 @@ def test_auto_dtype_long_rope(self):
         task = GSM8K(model_name)
         task.evaluate(llm)
 
+    @skip_pre_blackwell
+    def test_fp4(self):
+        model_path = f"{self.MODEL_PATH}-FP4"
+        with LLM(model_path, max_seq_len=4096) as llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    @skip_pre_hopper
+    def test_fp8(self):
+        model_path = f"{self.MODEL_PATH}-FP8"
+        with LLM(model_path, max_seq_len=4096) as llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
     @skip_pre_hopper
     @pytest.mark.skip_less_device_memory(80000)
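
Outside the accuracy harness, the same quantized checkpoints can be exercised directly through the LLM API, along the lines of the tests above. A minimal sketch; the local checkpoint path is an assumption, and max_seq_len mirrors the test setting:

from tensorrt_llm import LLM, SamplingParams

# Hypothetical local path to a quantized Phi-4-multimodal checkpoint.
model_path = "multimodals/Phi-4-multimodal-instruct-FP8"

with LLM(model_path, max_seq_len=4096) as llm:
    outputs = llm.generate(["Q: What is 12 * 7? A:"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)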

tests/integration/defs/perf/test_perf.py

Lines changed: 16 additions & 0 deletions

@@ -127,6 +127,14 @@
     "phi_4_multimodal_instruct": "multimodals/Phi-4-multimodal-instruct",
     "phi_4_multimodal_instruct_image": "multimodals/Phi-4-multimodal-instruct",
     "phi_4_multimodal_instruct_audio": "multimodals/Phi-4-multimodal-instruct",
+    "phi_4_multimodal_instruct_fp4_image":
+    "multimodals/Phi-4-multimodal-instruct-FP4",
+    "phi_4_multimodal_instruct_fp4_audio":
+    "multimodals/Phi-4-multimodal-instruct-FP4",
+    "phi_4_multimodal_instruct_fp8_image":
+    "multimodals/Phi-4-multimodal-instruct-FP8",
+    "phi_4_multimodal_instruct_fp8_audio":
+    "multimodals/Phi-4-multimodal-instruct-FP8",
     "bielik_11b_v2.2_instruct": "Bielik-11B-v2.2-Instruct",
     "bielik_11b_v2.2_instruct_fp8": "Bielik-11B-v2.2-Instruct-FP8",
     "mistral_small_v3.1_24b": "Mistral-Small-3.1-24B-Instruct-2503",
@@ -177,6 +185,14 @@
     "multimodals/Phi-4-multimodal-instruct/vision-lora",
     "phi_4_multimodal_instruct_audio":
     "multimodals/Phi-4-multimodal-instruct/speech-lora",
+    "phi_4_multimodal_instruct_fp4_image":
+    "multimodals/Phi-4-multimodal-instruct-FP4/vision-lora",
+    "phi_4_multimodal_instruct_fp4_audio":
+    "multimodals/Phi-4-multimodal-instruct-FP4/speech-lora",
+    "phi_4_multimodal_instruct_fp8_image":
+    "multimodals/Phi-4-multimodal-instruct-FP8/vision-lora",
+    "phi_4_multimodal_instruct_fp8_audio":
+    "multimodals/Phi-4-multimodal-instruct-FP8/speech-lora",
 }
 
 TIMING_CACHE_DIR = os.environ.get("TIMING_CACHE_DIR", "")
