
Commit 0330a1c

jenchen13 authored and kevalmorabia97 committed
Fix hf_quant_config with kv cache type (#557)
Update hf_quant_config with the correct kv cache type for FP8 and NVFP4.

Signed-off-by: jenchen13 <[email protected]>
Signed-off-by: Jennifer Chen <[email protected]>
1 parent ba23cf4 commit 0330a1c

File tree

1 file changed (+6, -1 lines)


modelopt/torch/export/unified_export_megatron.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -39,6 +39,7 @@

 from .model_config import (
     KV_CACHE_FP8,
+    KV_CACHE_NVFP4,
     QUANTIZATION_FP8,
     QUANTIZATION_FP8_PB_REAL,
     QUANTIZATION_FP8_PB_WO,
@@ -285,7 +286,6 @@ def save_pretrained(
         quantization_format = self._get_quantization_format(self.model)

         quantization = None
-        kv_cache_quantization = None

         if quantization_format in (
             QUANTIZATION_FP8_PB_REAL,
@@ -297,6 +297,11 @@ def save_pretrained(
         elif quantization_format == QUANTIZATION_NVFP4:
             quantization = "NVFP4"

+        kv_cache_quantization = None
+        kv_cache_dtype = get_kv_cache_dtype(self.model)
+        if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4):
+            # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM
+            kv_cache_quantization = kv_cache_dtype
         # We use the last PP rank and the 1st EP rank to write the config because
         # medusa_heads and eagle_module only exist in the last stage.
         if is_last_stage_main_rank:
```
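The change is small but worth spelling out: the exporter now initializes `kv_cache_quantization` right next to the `get_kv_cache_dtype` check, so the value written into the exported quantization config reflects the detected KV cache dtype whenever it is FP8 or NVFP4 (FP8 KV cache is consumed by vLLM, NVFP4 by TensorRT-LLM). The standalone sketch below mirrors that logic; the constant string values, the `resolve_kv_cache_quantization` helper, and the `quant_algo`/`kv_cache_quant_algo` keys are illustrative assumptions rather than modelopt's confirmed API; only `get_kv_cache_dtype`, `KV_CACHE_FP8`, and `KV_CACHE_NVFP4` appear in the actual diff.

```python
# Standalone sketch (not modelopt code): local stand-ins mirror the
# constants imported from modelopt/torch/export/model_config.py.
KV_CACHE_FP8 = "FP8"      # assumed string values, for illustration only
KV_CACHE_NVFP4 = "NVFP4"


def resolve_kv_cache_quantization(kv_cache_dtype):
    """Return the KV cache quantization to record, or None if unquantized."""
    kv_cache_quantization = None
    if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4):
        # FP8 KV cache is supported in vLLM; NVFP4 in TensorRT-LLM.
        kv_cache_quantization = kv_cache_dtype
    return kv_cache_quantization


# Hypothetical shape of the exported quantization config, for illustration.
hf_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": resolve_kv_cache_quantization("FP8"),
    }
}
print(hf_quant_config)
# {'quantization': {'quant_algo': 'NVFP4', 'kv_cache_quant_algo': 'FP8'}}
```

Any dtype outside the supported pair leaves the recorded value as `None`, which matches the patched behavior of only recording KV cache quantization for formats the downstream runtimes understand.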
