23
23
from QEfficient .base .pytorch_transforms import PytorchTransform
24
24
from QEfficient .compile .qnn_compiler import compile as qnn_compile
25
25
from QEfficient .generation .cloud_infer import QAICInferenceSession
26
- from QEfficient .utils import constants , dump_qconfig , make_serializable
27
- from QEfficient .utils .cache import QEFF_HOME , hash_dict_params
26
+ from QEfficient .utils import (
27
+ constants ,
28
+ create_json ,
29
+ dump_qconfig ,
30
+ filter_and_hash_compile_params ,
31
+ filter_and_hash_export_params ,
32
+ )
33
+ from QEfficient .utils .cache import QEFF_HOME
28
34
29
35
logger = logging .getLogger (__name__ )
30
36
@@ -46,15 +52,18 @@ class QEFFBaseModel(ABC):
46
52
def _transform_names (cls ) -> List [str ]:
47
53
return [x .__name__ for x in cls ._pytorch_transforms + cls ._onnx_transforms ]
48
54
55
+ def create_model_params (self , ** kwargs ) -> Dict :
56
+ model_params = copy .deepcopy (kwargs )
57
+
58
+ model_params ["config" ] = self .model .config .to_diff_dict ()
59
+ model_params ["_transform_names" ] = self ._transform_names ()
60
+ # TODO: Add keywords list to filter out params that are not needed for hashing
61
+ return model_params
62
+
49
63
def __init__ (self , model : torch .nn .Module , ** kwargs ) -> None :
50
64
super ().__init__ ()
51
65
self .model = model
52
-
53
- # Store Model parameters to Calculate Hash for caching
54
- self .model_params = {}
55
- self .model_params = copy .deepcopy (kwargs )
56
- self .model_params ["config" ] = self .model .config .to_diff_dict ()
57
- self .model_params ["_transform_names" ] = self ._transform_names ()
66
+ self .model_params = self .create_model_params (** kwargs )
58
67
59
68
if hasattr (self .model .config , "architectures" ):
60
69
self .model_architecture = self .model .config .architectures [0 ]
@@ -121,6 +130,7 @@ def compile(self, *args, **kwargs) -> Path:
121
130
:str: Path of the compiled ``qpc`` package.
122
131
"""
123
132
133
+ # @dump_model_params
124
134
def _export (
125
135
self ,
126
136
example_inputs : Dict [str , torch .Tensor ],
@@ -141,19 +151,17 @@ def _export(
141
151
:onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class.
142
152
:export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model.
143
153
"""
144
- export_params = {}
145
- export_params ["output_names" ] = output_names
146
- export_params ["dynamic_axes" ] = dynamic_axes
147
-
148
- self .model_params ["export_params" ] = export_params
149
-
150
- self .model_params .update (export_kwargs ) if export_kwargs is not None else None
151
- self .model_params .update (onnx_transform_kwargs ) if export_kwargs is not None else None
152
154
153
155
export_dir = Path (export_dir or (QEFF_HOME / self .model_architecture / self .model_name ))
156
+ export_hash , hashed_params = filter_and_hash_export_params (
157
+ model_params = copy .deepcopy (self .model_params ),
158
+ output_names = output_names ,
159
+ dynamic_axes = dynamic_axes ,
160
+ export_kwargs = export_kwargs ,
161
+ onnx_transform_kwargs = onnx_transform_kwargs ,
162
+ export_dir = export_dir ,
163
+ )
154
164
155
- export_hash = hash_dict_params (self .model_params )
156
- export_hash = export_hash .hexdigest ()[:16 ]
157
165
export_dir = export_dir .with_name (export_dir .name + "-" + export_hash )
158
166
onnx_path = export_dir / f"{ self .model_name } .onnx"
159
167
if onnx_path .is_file ():
@@ -221,20 +229,6 @@ def _export(
221
229
onnx .save (model , onnx_path )
222
230
logger .info ("Transformed onnx saved" )
223
231
224
- # Dumping model paramters in a JSON file after successful ONNX export
225
- model_params_json = export_dir / "model_params.json"
226
- with open (model_params_json , "w" ) as fp :
227
- json .dump (
228
- {
229
- "model_params" : {
230
- k : make_serializable (self .model_params [k ]) for k in sorted (self .model_params .keys ())
231
- }
232
- },
233
- fp ,
234
- indent = 4 ,
235
- )
236
- logger .info ("Parameters used for export hash dumped in a JSON file successfully" )
237
-
238
232
except Exception as e :
239
233
logger .error (f"ONNX export (or) ONNXTransforms failed: { e } " )
240
234
@@ -243,6 +237,11 @@ def _export(
243
237
finally :
244
238
shutil .rmtree (tmp_onnx_dir , ignore_errors = True )
245
239
240
+ # Dump JSON file with hashed parameters
241
+ hashed_params_export_path = export_dir / "hashed_model_params.json"
242
+ create_json (hashed_params_export_path , hashed_params )
243
+ logger .info ("Hashed parameters exported successfully." )
244
+
246
245
self .onnx_path = onnx_path
247
246
return onnx_path
248
247
@@ -281,8 +280,6 @@ def _compile(
281
280
if onnx_path is None and self .onnx_path is None :
282
281
self .export ()
283
282
284
- self .compile_params = {}
285
-
286
283
onnx_path = Path (onnx_path or self .onnx_path )
287
284
compile_dir = Path (compile_dir or onnx_path .parent )
288
285
qpc_path = compile_dir / "qpc"
@@ -317,23 +314,13 @@ def _compile(
317
314
continue
318
315
command .append (f"{ option } ={ value } " )
319
316
320
- self .compile_params ["command" ] = command
321
-
322
- if specializations is not None :
323
- self .compile_params .update ({"specializations" : specializations })
324
-
325
- if custom_io is not None :
326
- self .compile_params .update ({"custom_io" : custom_io })
327
-
328
- if num_speculative_tokens :
329
- self .compile_params .update ({"num_speculative_tokens" : num_speculative_tokens })
330
-
331
- if mdp_ts_num_devices is not None :
332
- self .compile_params .update ({"mdp_ts_num_devices" : mdp_ts_num_devices })
333
-
334
- # Check if already compiled
335
- compile_hash = hash_dict_params (self .compile_params )
336
- compile_hash = compile_hash .hexdigest ()[:16 ]
317
+ compile_hash , hashed_params = filter_and_hash_compile_params (
318
+ command = command ,
319
+ specializations = specializations ,
320
+ custom_io = custom_io ,
321
+ mdp_ts_num_devices = mdp_ts_num_devices ,
322
+ num_speculative_tokens = num_speculative_tokens ,
323
+ )
337
324
compile_dir = qpc_path .with_name (qpc_path .name + "-" + compile_hash )
338
325
339
326
qpc_path = compile_dir / "qpc"
@@ -389,18 +376,6 @@ def _compile(
389
376
try :
390
377
subprocess .run (command , capture_output = True , check = True )
391
378
392
- # Dumping compile paramters in a JSON file after successful QPC compilation
393
- compile_params_json = compile_dir / "compile_params.json"
394
- with open (compile_params_json , "w" ) as fp :
395
- json .dump (
396
- {
397
- "compile_params" : {
398
- k : make_serializable (self .compile_params [k ]) for k in sorted (self .compile_params .keys ())
399
- }
400
- },
401
- fp ,
402
- indent = 4 ,
403
- )
404
379
except subprocess .CalledProcessError as e :
405
380
raise RuntimeError (
406
381
"\n " .join (
@@ -414,6 +389,10 @@ def _compile(
414
389
)
415
390
)
416
391
392
+ # Dump JSON file with hashed parameters
393
+ hashed_compile_params_path = compile_dir / "hashed_compile_params.json"
394
+ create_json (hashed_compile_params_path , hashed_params )
395
+ logger .info ("Hashed parameters exported successfully." )
417
396
self .qpc_path = qpc_path
418
397
419
398
return qpc_path
0 commit comments