Skip to content

Commit 99ff668

Browse files
committed
Modifications made based on Rishin's suggestion. WIP
Signed-off-by: Dhiraj Kumar Sah <[email protected]>
1 parent 47673cf commit 99ff668

File tree

6 files changed

+152
-65
lines changed

6 files changed

+152
-65
lines changed

QEfficient/base/modeling_qeff.py

Lines changed: 42 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,14 @@
2323
from QEfficient.base.pytorch_transforms import PytorchTransform
2424
from QEfficient.compile.qnn_compiler import compile as qnn_compile
2525
from QEfficient.generation.cloud_infer import QAICInferenceSession
26-
from QEfficient.utils import constants, dump_qconfig, make_serializable
27-
from QEfficient.utils.cache import QEFF_HOME, hash_dict_params
26+
from QEfficient.utils import (
27+
constants,
28+
create_json,
29+
dump_qconfig,
30+
filter_and_hash_compile_params,
31+
filter_and_hash_export_params,
32+
)
33+
from QEfficient.utils.cache import QEFF_HOME
2834

2935
logger = logging.getLogger(__name__)
3036

@@ -46,15 +52,18 @@ class QEFFBaseModel(ABC):
4652
def _transform_names(cls) -> List[str]:
4753
return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]
4854

55+
def create_model_params(self, **kwargs) -> Dict:
    """
    Assemble the parameter dictionary that is later used for hashing.

    The keyword arguments are deep-copied so the caller's objects are never
    shared with the returned dictionary. Two entries are then set (overwriting
    same-named keyword arguments, if any):

    :config: result of ``self.model.config.to_diff_dict()``.
    :_transform_names: result of ``self._transform_names()``.

    Returns:
        :Dict: Copied keyword arguments plus the ``config`` and ``_transform_names`` entries.
    """
    params = copy.deepcopy(kwargs)
    params.update(
        config=self.model.config.to_diff_dict(),
        _transform_names=self._transform_names(),
    )
    # TODO: Add keywords list to filter out params that are not needed for hashing
    return params
62+
4963
def __init__(self, model: torch.nn.Module, **kwargs) -> None:
5064
super().__init__()
5165
self.model = model
52-
53-
# Store Model parameters to Calculate Hash for caching
54-
self.model_params = {}
55-
self.model_params = copy.deepcopy(kwargs)
56-
self.model_params["config"] = self.model.config.to_diff_dict()
57-
self.model_params["_transform_names"] = self._transform_names()
66+
self.model_params = self.create_model_params(**kwargs)
5867

5968
if hasattr(self.model.config, "architectures"):
6069
self.model_architecture = self.model.config.architectures[0]
@@ -121,6 +130,7 @@ def compile(self, *args, **kwargs) -> Path:
121130
:str: Path of the compiled ``qpc`` package.
122131
"""
123132

133+
# @dump_model_params
124134
def _export(
125135
self,
126136
example_inputs: Dict[str, torch.Tensor],
@@ -141,19 +151,17 @@ def _export(
141151
:onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class.
142152
:export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model.
143153
"""
144-
export_params = {}
145-
export_params["output_names"] = output_names
146-
export_params["dynamic_axes"] = dynamic_axes
147-
148-
self.model_params["export_params"] = export_params
149-
150-
self.model_params.update(export_kwargs) if export_kwargs is not None else None
151-
self.model_params.update(onnx_transform_kwargs) if export_kwargs is not None else None
152154

153155
export_dir = Path(export_dir or (QEFF_HOME / self.model_architecture / self.model_name))
156+
export_hash, hashed_params = filter_and_hash_export_params(
157+
model_params=copy.deepcopy(self.model_params),
158+
output_names=output_names,
159+
dynamic_axes=dynamic_axes,
160+
export_kwargs=export_kwargs,
161+
onnx_transform_kwargs=onnx_transform_kwargs,
162+
export_dir=export_dir,
163+
)
154164

155-
export_hash = hash_dict_params(self.model_params)
156-
export_hash = export_hash.hexdigest()[:16]
157165
export_dir = export_dir.with_name(export_dir.name + "-" + export_hash)
158166
onnx_path = export_dir / f"{self.model_name}.onnx"
159167
if onnx_path.is_file():
@@ -221,20 +229,6 @@ def _export(
221229
onnx.save(model, onnx_path)
222230
logger.info("Transformed onnx saved")
223231

224-
# Dumping model paramters in a JSON file after successful ONNX export
225-
model_params_json = export_dir / "model_params.json"
226-
with open(model_params_json, "w") as fp:
227-
json.dump(
228-
{
229-
"model_params": {
230-
k: make_serializable(self.model_params[k]) for k in sorted(self.model_params.keys())
231-
}
232-
},
233-
fp,
234-
indent=4,
235-
)
236-
logger.info("Parameters used for export hash dumped in a JSON file successfully")
237-
238232
except Exception as e:
239233
logger.error(f"ONNX export (or) ONNXTransforms failed: {e}")
240234

@@ -243,6 +237,11 @@ def _export(
243237
finally:
244238
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
245239

240+
# Dump JSON file with hashed parameters
241+
hashed_params_export_path = export_dir / "hashed_model_params.json"
242+
create_json(hashed_params_export_path, hashed_params)
243+
logger.info("Hashed parameters exported successfully.")
244+
246245
self.onnx_path = onnx_path
247246
return onnx_path
248247

@@ -281,8 +280,6 @@ def _compile(
281280
if onnx_path is None and self.onnx_path is None:
282281
self.export()
283282

284-
self.compile_params = {}
285-
286283
onnx_path = Path(onnx_path or self.onnx_path)
287284
compile_dir = Path(compile_dir or onnx_path.parent)
288285
qpc_path = compile_dir / "qpc"
@@ -317,23 +314,13 @@ def _compile(
317314
continue
318315
command.append(f"{option}={value}")
319316

320-
self.compile_params["command"] = command
321-
322-
if specializations is not None:
323-
self.compile_params.update({"specializations": specializations})
324-
325-
if custom_io is not None:
326-
self.compile_params.update({"custom_io": custom_io})
327-
328-
if num_speculative_tokens:
329-
self.compile_params.update({"num_speculative_tokens": num_speculative_tokens})
330-
331-
if mdp_ts_num_devices is not None:
332-
self.compile_params.update({"mdp_ts_num_devices": mdp_ts_num_devices})
333-
334-
# Check if already compiled
335-
compile_hash = hash_dict_params(self.compile_params)
336-
compile_hash = compile_hash.hexdigest()[:16]
317+
compile_hash, hashed_params = filter_and_hash_compile_params(
318+
command=command,
319+
specializations=specializations,
320+
custom_io=custom_io,
321+
mdp_ts_num_devices=mdp_ts_num_devices,
322+
num_speculative_tokens=num_speculative_tokens,
323+
)
337324
compile_dir = qpc_path.with_name(qpc_path.name + "-" + compile_hash)
338325

339326
qpc_path = compile_dir / "qpc"
@@ -389,18 +376,6 @@ def _compile(
389376
try:
390377
subprocess.run(command, capture_output=True, check=True)
391378

392-
# Dumping compile paramters in a JSON file after successful QPC compilation
393-
compile_params_json = compile_dir / "compile_params.json"
394-
with open(compile_params_json, "w") as fp:
395-
json.dump(
396-
{
397-
"compile_params": {
398-
k: make_serializable(self.compile_params[k]) for k in sorted(self.compile_params.keys())
399-
}
400-
},
401-
fp,
402-
indent=4,
403-
)
404379
except subprocess.CalledProcessError as e:
405380
raise RuntimeError(
406381
"\n".join(
@@ -414,6 +389,10 @@ def _compile(
414389
)
415390
)
416391

392+
# Dump JSON file with hashed parameters
393+
hashed_compile_params_path = compile_dir / "hashed_compile_params.json"
394+
create_json(hashed_compile_params_path, hashed_params)
395+
logger.info("Hashed parameters exported successfully.")
417396
self.qpc_path = qpc_path
418397

419398
return qpc_path

QEfficient/transformers/models/modeling_auto.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ def auto_correct_inputs(self, inputs):
130130
return {k: v for k, v in inputs.items() if k in [iinfo.name for iinfo in inputs_info]}
131131

132132

133+
class NoInitMeta(type):
    """
    Metaclass that blocks direct instantiation of the classes using it.

    Calling such a class (``SomeClass(...)``) goes through this ``__call__``,
    which always raises, pointing the user to ``from_pretrained`` instead.
    """

    def __call__(cls, *args, **kwargs):
        message = "Use `from_pretrained` to create an instance."
        raise RuntimeError(message)
136+
137+
133138
class QEFFAutoModel(QEFFTransformersBase):
134139
"""
135140
The QEFFAutoModel class is designed for manipulating any transformer model from the HuggingFace hub.
@@ -911,6 +916,7 @@ def __init__(
911916
self.model.config.vision_config.use_flash_attn = "false"
912917
else:
913918
self.model.config.text_config.use_cache = True
919+
self.model_params["qeff_class"] = self.__class__.__name__
914920

915921
@classmethod
916922
def from_pretrained(
@@ -934,6 +940,10 @@ def from_pretrained(
934940
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs)
935941

936942
return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
943+
# # Bypass __call__ and manually initialize
944+
# instance = object.__new__(cls)
945+
# instance.__init__(model, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
946+
# return instance
937947

938948
def export(
939949
self,
@@ -1175,6 +1185,7 @@ def get_model_config(self) -> dict:
11751185
return self.model.config.__dict__
11761186

11771187

1188+
# class QEFFAutoModelForImageTextToText(metaclass=NoInitMeta):
11781189
class QEFFAutoModelForImageTextToText:
11791190
"""
11801191
The QEFFAutoModelForImageTextToText class is used to work with multimodal language models from the HuggingFace hub.
@@ -1277,10 +1288,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona
12771288
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
12781289
return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
12791290

1291+
# # Bypass __call__ and manually initialize
1292+
# instance = object.__new__(cls)
1293+
# instance.__init__(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
1294+
# return instance
1295+
12801296

12811297
MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {"InternVLChatModel": QEFFAutoModelForImageTextToText}
12821298

12831299

1300+
# class QEFFAutoModelForCausalLM(QEFFBaseModel, metaclass=NoInitMeta):
12841301
class QEFFAutoModelForCausalLM(QEFFBaseModel):
12851302
"""
12861303
The QEFF class is designed for manipulating any causal language model from the HuggingFace hub.

QEfficient/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@
1111
)
1212
from QEfficient.utils._utils import ( # noqa: F401
1313
check_and_assign_cache_dir,
14+
create_json,
1415
custom_format_warning,
16+
dump_model_params,
1517
dump_qconfig,
18+
filter_and_hash_compile_params,
19+
filter_and_hash_export_params,
1620
get_num_layers_from_config,
1721
get_num_layers_vlm,
1822
get_onnx_dir_name,

QEfficient/utils/_utils.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import subprocess
1212
import xml.etree.ElementTree as ET
1313
from dataclasses import dataclass
14+
from pathlib import Path
1415
from typing import Any, Dict, List, Optional, Tuple, Union
1516

1617
import requests
@@ -25,6 +26,7 @@
2526
PreTrainedTokenizerFast,
2627
)
2728

29+
from QEfficient.utils.cache import QEFF_HOME, hash_dict_params
2830
from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants, QnnConstants
2931
from QEfficient.utils.logging_utils import logger
3032

@@ -630,6 +632,43 @@ def wrapper(self, *args, **kwargs):
630632
return wrapper
631633

632634

635+
def dump_model_params(func):
    """
    Decorator that dumps the parameters used for the export hash before running ``func``.

    The wrapped method's arguments are bound to their parameter names, every
    argument except ``self`` and ``example_inputs`` is forwarded to
    ``filter_and_hash_export_params`` together with a deep copy of
    ``self.model_params``, and the resulting dictionary is written to
    ``hashed_model_params.json`` inside ``<export_dir>-<hash>``.

    The dump is best-effort: any failure while hashing or writing is logged and
    the wrapped method is still called with the original arguments.

    ``Args``:
        :func: Method to wrap. The instance it is bound to must provide
            ``model_params``, ``model_architecture`` and ``model_name``.

    Returns:
        The wrapping function; it returns whatever ``func`` returns.
    """
    import copy
    import functools

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            # Bind args to their parameter names so positional and defaulted
            # arguments (including ``export_dir``) are found by name.
            bound_args = inspect.signature(func).bind(self, *args, **kwargs)
            bound_args.apply_defaults()

            # 'self' is not a hash input and 'example_inputs' is excluded from hashing.
            hash_inputs = {
                name: value
                for name, value in bound_args.arguments.items()
                if name not in ("self", "example_inputs")
            }

            export_dir = Path(hash_inputs.get("export_dir") or (QEFF_HOME / self.model_architecture / self.model_name))

            # Work on a copy so the instance's model_params are left untouched,
            # and reuse the hash computed by the helper instead of re-hashing.
            export_hash, hashed_params = filter_and_hash_export_params(
                model_params=copy.deepcopy(self.model_params),
                **hash_inputs,
            )
            export_dir = export_dir.with_name(export_dir.name + "-" + export_hash)
            export_dir.mkdir(parents=True, exist_ok=True)

            create_json(export_dir / "hashed_model_params.json", hashed_params)

            logger.info("Parameters used for export hash dumped in a JSON file successfully")
        except Exception as e:
            logger.error(f"An unexpected error occurred while dumping the hashed model params: {e}")

        return func(self, *args, **kwargs)

    return wrapper
670+
671+
633672
def get_qaic_sdk_version(qaic_sdk_xml_path: str) -> Optional[str]:
634673
"""
635674
Extracts the QAIC SDK version from the given SDK XML file.
@@ -724,6 +763,50 @@ def create_and_dump_qconfigs(
724763
create_json(qconfig_file_path, qconfigs)
725764

726765

766+
def filter_and_hash_export_params(**kwargs):
    """
    Prepare the parameters that identify an export and hash them for the export directory.

    Keyword arguments read:
        :model_params (dict): Base model parameters. Required. The caller's dictionary
            is not modified; a shallow copy is extended instead.
        :output_names: Stored under ``export_params["output_names"]`` (``None`` if absent).
        :dynamic_axes: Stored under ``export_params["dynamic_axes"]`` (``None`` if absent).
        :export_kwargs (dict, optional): Merged into the top level of the result when non-empty.
        :onnx_transform_kwargs (dict, optional): Merged into the top level of the result when non-empty.

    Any other keyword argument (for example ``export_dir``) is ignored.

    Returns:
        :tuple: ``(export_hash, filtered_params)`` where ``export_hash`` is the value returned by
            ``hash_dict_params`` for ``filtered_params``.

    Raises:
        :KeyError: If ``model_params`` is not supplied.
    """
    # Copy first: only top-level keys are added below, so a shallow copy is
    # enough to keep the caller's dictionary unchanged.
    filtered_params = dict(kwargs["model_params"])

    filtered_params["export_params"] = {
        "output_names": kwargs.get("output_names"),
        "dynamic_axes": kwargs.get("dynamic_axes"),
    }

    # Later sources win on key clashes, matching the original merge order.
    for extra_key in ("export_kwargs", "onnx_transform_kwargs"):
        extra_params = kwargs.get(extra_key)
        if extra_params:
            filtered_params.update(extra_params)

    return hash_dict_params(filtered_params), filtered_params
786+
787+
788+
def filter_and_hash_compile_params(**kwargs):
    """
    Collect the compile parameters that identify a qpc directory and hash them.

    ``command`` is required and always included. ``specializations``, ``custom_io``,
    ``num_speculative_tokens`` and ``mdp_ts_num_devices`` are included, in that order,
    only when supplied with a truthy value.

    Returns:
        :tuple: ``(compile_hash, filtered_params)`` where ``compile_hash`` is the value returned by
            ``hash_dict_params`` for ``filtered_params``.
    """
    optional_keys = ("specializations", "custom_io", "num_speculative_tokens", "mdp_ts_num_devices")

    filtered_params = {"command": kwargs["command"]}
    filtered_params.update({key: kwargs[key] for key in optional_keys if kwargs.get(key, None)})

    return hash_dict_params(filtered_params), filtered_params
808+
809+
727810
def filter_kwargs(func, kwargs):
728811
"""
729812
Filter a dictionary of keyword arguments to only include the valid arguments of a function.

QEfficient/utils/cache.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from pathlib import Path
1212
from typing import Dict
1313

14+
from QEfficient.utils.constants import HASH_HEXDIGEST_STR_LEN
15+
1416
QEFF_HOME: Path = None
1517
if "QEFF_HOME" in os.environ:
1618
QEFF_HOME = Path(os.environ["QEFF_HOME"])
@@ -43,9 +45,9 @@ def to_hashable(obj) -> bytes:
4345
).encode()
4446

4547

46-
def hash_dict_params(dict_items: Dict):
48+
def hash_dict_params(dict_items: Dict, hash_string_size: int = HASH_HEXDIGEST_STR_LEN):
4749
"""
4850
Takes a dictionary of items and returns the first ``hash_string_size`` characters of its SHA256 hex digest (a string)
4951
"""
5052
mhash = hashlib.sha256(to_hashable(dict_items))
51-
return mhash
53+
return mhash.hexdigest()[:hash_string_size]

QEfficient/utils/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
ONNX_EXPORT_IMAGE_DEPTH = 3
2626
ONNX_EXPORT_CTX_LEN = 1024
2727

28+
HASH_HEXDIGEST_STR_LEN = 16
29+
2830

2931
# Store the qeff_models inside the ~/.cache directory or over-ride with an env variable.
3032
def get_models_dir():

0 commit comments

Comments
 (0)