 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import pytest
-import torch
+import json
 from copy import deepcopy
 from functools import partial
-import modelopt.torch.quantization as mtq
-from modelopt.torch.export.unified_export_hf import export_hf_checkpoint
-from modelopt.torch.export.unified_export_megatron import export_mcore_gpt_to_hf
-from _test_utils.torch.transformers_models import create_tiny_llama_dir
+
+import pytest
+import torch
+from _test_utils.import_helper import skip_if_no_megatron
 from _test_utils.torch.distributed.utils import spawn_multiprocess_job
 from _test_utils.torch.megatron.models import get_mcore_gpt_model
-from _test_utils.import_helper import skip_if_no_megatron
+from _test_utils.torch.transformers_models import create_tiny_llama_dir
 from transformers import AutoModelForCausalLM
 
-import os
-import json
+import modelopt.torch.quantization as mtq
+from modelopt.torch.export.unified_export_hf import export_hf_checkpoint
+from modelopt.torch.export.unified_export_megatron import export_mcore_gpt_to_hf
 
 skip_if_no_megatron(apex_or_te_required=True)
 
+
 @pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG])
 def test_hf_vllm_export(tmp_path, quant_cfg):
     """Test HuggingFace model export for vLLM with fake quantization.
-
+
     This test verifies:
     1. Model weights match before and after export
     2. quant_amax.pth file is created, huggingface config file does not exist
     3. Amax values are correctly extracted and saved in quant_amax.pth file
     """
-
+
     # Create a tiny LLaMA model for testing
     tiny_model_dir = create_tiny_llama_dir(tmp_path, with_tokenizer=True, num_hidden_layers=2)
-
+
     # Load the model
     model = AutoModelForCausalLM.from_pretrained(tiny_model_dir)
     model = model.cuda()
     model.eval()
-
+
     # Quantize the model
     def forward_loop(model):
         input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).cuda()
         with torch.no_grad():
             model(input_ids)
-
+
     model = mtq.quantize(model, quant_cfg, forward_loop)
-
+
     model_state_dict = deepcopy(model.state_dict())
 
     # Export directory
     export_dir = tmp_path / "vllm_export"
     export_dir.mkdir(exist_ok=True)
-
+
     # Export for vLLM
     export_hf_checkpoint(model, export_dir=export_dir, export_vllm_fq_weights_qstate=True)
 
     # check if quant_amax.pth file exists
     quant_amax_file = export_dir / "quant_amax.pth"
     assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}"
-
+
     # make sure hf_quant_config.json file does not exist
     hf_quant_config_file = export_dir / "hf_quant_config.json"
-    assert not hf_quant_config_file.exists(), f"hf_quant_config.json file should not be created in {export_dir}"
+    assert not hf_quant_config_file.exists(), (
+        f"hf_quant_config.json file should not be created in {export_dir}"
+    )
 
     # check weights match before and after export
     model_after = AutoModelForCausalLM.from_pretrained(export_dir)
     model_after = model_after.cuda()
     model_after.eval()
     model_after_state_dict = model_after.state_dict()
     amax_state_dict = {}
-    for key in model_state_dict.keys():
+    for key, param in model_state_dict.items():
         if key.endswith("_amax"):
-            amax_state_dict[key] = model_state_dict[key]
+            amax_state_dict[key] = param
             continue
-
-        assert torch.allclose(model_state_dict[key], model_after_state_dict[key], atol=1e-6), (
+
+        assert torch.allclose(param, model_after_state_dict[key], atol=1e-6), (
             f"Weight mismatch for {key}: "
-            f"before shape={model_state_dict[key].shape}, after shape={model_after_state_dict[key].shape}, "
-            f"max diff={torch.abs(model_state_dict[key] - model_after_state_dict[key]).max()}"
+            f"before shape={param.shape}, after shape={model_after_state_dict[key].shape}, "
+            f"max diff={torch.abs(param - model_after_state_dict[key]).max()}"
         )
 
     # Verify amax values are correct
     amax_dict = torch.load(quant_amax_file)
     assert len(amax_dict) > 0, "amax_dict should not be empty"
-    assert amax_dict.keys() == amax_state_dict.keys(), f"amax keys mismatch between before and after export"
+    assert amax_dict.keys() == amax_state_dict.keys(), (
+        "amax keys mismatch between before and after export"
+    )
 
 
 def _test_mcore_vllm_export(tmp_path, quant_cfg, rank, size):
-    """Test megatron-core model export for vLLM with fake quantization.
-
-    """
+    """Test megatron-core model export for vLLM with fake quantization."""
     # Create a tiny mcore GPT model
     num_layers = 2
     hidden_size = 64
@@ -109,7 +112,7 @@ def _test_mcore_vllm_export(tmp_path, quant_cfg, rank, size):
     ffn_hidden_size = 128
     max_sequence_length = 32
     vocab_size = 64
-
+
     model = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
         pipeline_model_parallel_size=1,
@@ -126,7 +129,7 @@ def _test_mcore_vllm_export(tmp_path, quant_cfg, rank, size):
         transformer_impl="modelopt",
     ).cuda()
     model.eval()
-
+
     # Quantize the model
     def forward_loop(model):
         batch_size = 1
@@ -138,11 +141,8 @@ def forward_loop(model):
         attention_mask = attention_mask < 0.5  # Convert to boolean mask
         with torch.no_grad():
             model(input_ids, position_ids, attention_mask)
-
-    model = mtq.quantize(model, quant_cfg, forward_loop)
-
-    model_state_dict = deepcopy(model.state_dict())
 
+    model = mtq.quantize(model, quant_cfg, forward_loop)
     # Create HF config for export
     pretrained_config = {
         "architectures": ["LlamaForCausalLM"],
@@ -156,14 +156,14 @@ def forward_loop(model):
156156 "num_key_value_heads" : num_query_groups ,
157157 "torch_dtype" : "bfloat16" ,
158158 }
159-
159+
160160 with open (tmp_path / "config.json" , "w" ) as f :
161161 json .dump (pretrained_config , f )
162162
163163 # Export directory
164164 export_dir = tmp_path / "vllm_export"
165165 export_dir .mkdir (exist_ok = True )
166-
166+
167167 # Export for vLLM
168168 export_mcore_gpt_to_hf (
169169 model ,
@@ -176,10 +176,12 @@ def forward_loop(model):
     # check if quant_amax.pth file exists
     quant_amax_file = export_dir / "quant_amax.pth"
     assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}"
-
+
     # make sure hf_quant_config.json file does not exist
     hf_quant_config_file = export_dir / "hf_quant_config.json"
-    assert not hf_quant_config_file.exists(), f"hf_quant_config.json file should not be created in {export_dir}"
+    assert not hf_quant_config_file.exists(), (
+        f"hf_quant_config.json file should not be created in {export_dir}"
+    )
 
 
 @pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG])
@@ -190,5 +192,3 @@ def test_mcore_vllm_export(tmp_path, quant_cfg):
         job=partial(_test_mcore_vllm_export, tmp_path, quant_cfg),
         backend="nccl",
     )
-
-