diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py
index 43b5e85f9..9b82398d7 100644
--- a/modelopt/torch/export/unified_export_megatron.py
+++ b/modelopt/torch/export/unified_export_megatron.py
@@ -39,6 +39,7 @@
 from .model_config import (
     KV_CACHE_FP8,
+    KV_CACHE_NVFP4,
     QUANTIZATION_FP8,
     QUANTIZATION_FP8_PB_REAL,
     QUANTIZATION_FP8_PB_WO,
@@ -285,7 +286,6 @@ def save_pretrained(
         quantization_format = self._get_quantization_format(self.model)
         quantization = None
-        kv_cache_quantization = None
 
         if quantization_format in (
             QUANTIZATION_FP8_PB_REAL,
@@ -297,6 +297,11 @@
         elif quantization_format == QUANTIZATION_NVFP4:
             quantization = "NVFP4"
 
+        kv_cache_quantization = None
+        kv_cache_dtype = get_kv_cache_dtype(self.model)
+        if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4):
+            # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM
+            kv_cache_quantization = kv_cache_dtype
         # We use the last PP rank and the 1st EP rank to write the config because
         # medusa_heads and eagle_module only exist in the last stage.
         if is_last_stage_main_rank:
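
For context (not part of the patch), here is a minimal standalone sketch of the KV-cache resolution the last hunk introduces. It assumes `get_kv_cache_dtype` returns either one of the `KV_CACHE_*` constants from `model_config` or `None`; `resolve_kv_cache_quantization` and the placeholder constant values are illustrative stand-ins, not names or values from the actual module.

```python
# Placeholder stand-ins for the constants imported from .model_config;
# the real string values in modelopt may differ.
KV_CACHE_FP8 = "FP8"
KV_CACHE_NVFP4 = "NVFP4"


def resolve_kv_cache_quantization(kv_cache_dtype):
    """Return the KV-cache quantization to record in the exported config.

    Mirrors the added hunk: FP8 KV cache is consumable by vLLM and NVFP4 by
    TensorRT-LLM; any other dtype is treated as an unquantized KV cache.
    """
    if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4):
        return kv_cache_dtype
    return None


assert resolve_kv_cache_quantization(KV_CACHE_NVFP4) == "NVFP4"
assert resolve_kv_cache_quantization(None) is None  # no KV-cache quantizer found
```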