diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py
index 1aafb1ba2..072c4053e 100644
--- a/QEfficient/base/modeling_qeff.py
+++ b/QEfficient/base/modeling_qeff.py
@@ -5,9 +5,8 @@
#
# ----------------------------------------------------------------------------
-import hashlib
+import copy
import inspect
-import json
import logging
import shutil
import subprocess
@@ -23,8 +22,16 @@
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
-from QEfficient.utils import constants, dump_qconfig
-from QEfficient.utils.cache import QEFF_HOME, to_hashable
+from QEfficient.utils import (
+ constants,
+ create_json,
+ dump_qconfig,
+ filter_and_create_export_hash,
+ generate_mdp_partition_config,
+ hash_compile_params,
+ load_json,
+)
+from QEfficient.utils.cache import QEFF_HOME
logger = logging.getLogger(__name__)
@@ -46,12 +53,18 @@ class QEFFBaseModel(ABC):
def _transform_names(cls) -> List[str]:
return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]
- def __init__(self, model: torch.nn.Module) -> None:
+ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
self.model = model
+ self.hash_params = self.create_model_params(**kwargs)
+
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
self.qpc_session: Optional[QAICInferenceSession] = None
+ self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)
+ # Record the model architecture (if available) so export-path construction can rely on it.
+ model_architecture = getattr(self.model.config, "architectures", None)
+ self.model_architecture = model_architecture[0] if isinstance(model_architecture, list) else None
# Apply the transformations
any_transformed = False
@@ -64,13 +77,16 @@ def __init__(self, model: torch.nn.Module) -> None:
else:
logger.info(f"Pytorch transforms applied to model: {self.model_name}")
- @property
- @abstractmethod
- def model_name(self) -> str: ...
+ def create_model_params(self, **kwargs) -> Dict:
+ model_params = copy.deepcopy(kwargs)
+ model_params["config"] = self.model.config.to_diff_dict()
+ model_params["peft_config"] = getattr(self.model, "active_peft_config", None)
+ model_params["applied_transform_names"] = self._transform_names()
+ return model_params
@property
@abstractmethod
- def model_hash(self) -> str: ...
+ def model_name(self) -> str: ...
@abstractmethod
def export(self, export_dir: Optional[str] = None) -> Path:
@@ -135,8 +151,17 @@ def _export(
:onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class.
:export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model.
"""
- export_dir = Path(export_dir or (QEFF_HOME / self.model_name))
- export_dir = export_dir.with_name(export_dir.name + "-" + self.model_hash)
+ parent_dir = self.model_architecture or self.model_name
+ export_dir = Path(export_dir or (QEFF_HOME / parent_dir / self.model_name))
+ export_hash, filtered_hash_params = filter_and_create_export_hash(
+ model_params=self.hash_params,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_kwargs=export_kwargs,
+ onnx_transform_kwargs=onnx_transform_kwargs,
+ )
+ self.export_hash = export_hash
+ export_dir = export_dir.with_name(export_dir.name + "-" + export_hash)
onnx_path = export_dir / f"{self.model_name}.onnx"
if onnx_path.is_file():
self.onnx_path = onnx_path
@@ -211,6 +236,11 @@ def _export(
finally:
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
+ # Dump JSON file with hashed parameters
+ hashed_params_export_path = export_dir / "hashed_export_params.json"
+ create_json(hashed_params_export_path, filtered_hash_params)
+ logger.info("Hashed parameters exported successfully.")
+
self.onnx_path = onnx_path
return onnx_path
@@ -241,12 +271,10 @@ def _compile(
:mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing.
:num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model.
:enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
- :qnn_config (str): Path of QNN Config parameters file. Any extra parameters for QNN compilation can be passed via this file. ``Defaults to None.``
- :compiler_options: Pass any compiler option as input.
- Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:
+ :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
+ :compiler_options: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:
- aic_num_cores=16 -> -aic-num-cores=16
- convert_to_fp16=True -> -convert-to-fp16
- For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
"""
if onnx_path is None and self.onnx_path is None:
self.export()
@@ -258,19 +286,14 @@ def _compile(
raise FileNotFoundError(f"ONNX file not found at: {onnx_path}")
if enable_qnn:
- if compiler_options:
- logger.warning(
- f"Extra arguments to QNN compilation are supported only via qnn_config file. Ignoring {compiler_options}"
- )
-
self.qpc_path = qnn_compile(
onnx_path=onnx_path,
qpc_base_path=compile_dir,
specializations=specializations,
custom_io=custom_io,
device_group=list(range(mdp_ts_num_devices)),
- num_cores=compiler_options.get("aic_num_cores", 16),
- mxfp6=compiler_options.get("mxfp6_matmul", False),
+ num_cores=compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES),
+ mxfp6=compiler_options.get("mxfp6_matmul", constants.DEFAULT_AIC_MXPF6_MATMUL),
mxint8=mxint8_kv_cache,
qnn_config=qnn_config,
)
@@ -278,8 +301,8 @@ def _compile(
return self.qpc_path
command = constants.COMPILER + [f"-m={onnx_path}"]
- if mdp_ts_json_path := compiler_options.pop("mdp_ts_json_path", None):
- mdp_ts_num_devices = None
+
+ if mdp_ts_json_path := compiler_options.pop("mdp_load_partition_config", None):
command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")
for key, value in compiler_options.items():
@@ -289,24 +312,30 @@ def _compile(
command.append(option)
continue
command.append(f"{option}={value}")
- compile_hash = hashlib.sha256(to_hashable(command))
-
- if specializations is not None:
- compile_hash.update(to_hashable(specializations))
-
- if custom_io is not None:
- compile_hash.update(to_hashable(custom_io))
-
- if num_speculative_tokens:
- compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens}))
- # Hash num_devices too, since default value would always be 1.
- compile_hash.update(to_hashable(mdp_ts_num_devices))
- # Check if already compiled
- compile_hash = compile_hash.hexdigest()[:16]
+ # Load the user-provided MDP partition config, or generate one when num_devices > 1
+ if mdp_ts_json_path is not None:
+ mdp_ts_json = load_json(str(mdp_ts_json_path))
+ elif mdp_ts_num_devices > 1:
+ mdp_ts_json = generate_mdp_partition_config(
+ mdp_ts_num_devices, compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES)
+ )
+ else:
+ mdp_ts_json = None
+
+ compile_hash, hashed_params = hash_compile_params(
+ command=command,
+ specializations=specializations,
+ custom_io=custom_io,
+ mdp_ts_num_devices=mdp_ts_num_devices,
+ mdp_ts_json=mdp_ts_json,
+ num_speculative_tokens=num_speculative_tokens,
+ )
compile_dir = qpc_path.with_name(qpc_path.name + "-" + compile_hash)
+
qpc_path = compile_dir / "qpc"
qpc_path.mkdir(parents=True, exist_ok=True)
+
if qpc_path.is_dir():
if (qpc_path / "programqpc.bin").is_file():
self.qpc_path = qpc_path
@@ -314,15 +343,19 @@ def _compile(
# Probably compilation failure last time, delete directory to start over
shutil.rmtree(qpc_path)
+ # Write the generated MDP partition config and pass it to the compiler, unless the user already supplied one
+ if mdp_ts_json is not None and not mdp_ts_json_path:
+ mdp_ts_json_path = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
+ create_json(str(mdp_ts_json_path), mdp_ts_json)
+ command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")
+
# Write specializations.json file
if specializations is not None:
specializations_json = compile_dir / "specializations.json"
- with open(specializations_json, "w") as fp:
- json.dump(
- {"specializations": [{k: str(v) for k, v in spec.items()} for spec in specializations]},
- fp,
- indent=4,
- )
+ specializations_data = {
+ "specializations": [{k: str(v) for k, v in spec.items()} for spec in specializations]
+ }
+ create_json(str(specializations_json), specializations_data)
command.append(f"-network-specialization-config={specializations_json}")
# Write custom_io.yaml file
@@ -333,30 +366,11 @@ def _compile(
fp.write(f" - IOName: {io_name}\n Precision: {dtype}\n\n")
command.append(f"-custom-IO-list-file={custom_io_yaml}")
- # Write mdp_config.json file
- if not mdp_ts_json_path and mdp_ts_num_devices > 1:
- num_cores = compiler_options.get("aic_num_cores", 16)
- mdp_ts_json = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
- with open(mdp_ts_json, "w") as fp:
- json.dump(
- {
- "connections": [{"devices": list(range(mdp_ts_num_devices)), "type": "p2p"}],
- "partitions": [
- {
- "name": "Partition0",
- "devices": [{"deviceId": d, "numCores": num_cores} for d in range(mdp_ts_num_devices)],
- }
- ],
- },
- fp,
- indent=4,
- )
- command.append(f"-mdp-load-partition-config={mdp_ts_json}")
-
command.append(f"-aic-binary-dir={qpc_path}")
logger.info(f"Running compiler: {' '.join(command)}")
try:
subprocess.run(command, capture_output=True, check=True)
+
except subprocess.CalledProcessError as e:
raise RuntimeError(
"\n".join(
@@ -370,6 +384,10 @@ def _compile(
)
)
+ # Dump JSON file with hashed parameters
+ hashed_compile_params_path = compile_dir / "hashed_compile_params.json"
+ create_json(hashed_compile_params_path, hashed_params)
+ logger.info("Hashed compile parameters dumped successfully.")
self.qpc_path = qpc_path
return qpc_path
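
For context on the refactor above: the inline MDP tensor-slicing JSON that `_compile` used to write by hand is now produced by `generate_mdp_partition_config` and serialized via `create_json`. A minimal sketch of what that helper plausibly returns, reconstructed from the removed inline code (the real implementation lives in `QEfficient/utils` and may differ):

```python
# Hypothetical reconstruction of generate_mdp_partition_config, based on the
# JSON structure the old inline code wrote; not the library's actual source.
def generate_mdp_partition_config(num_devices: int, num_cores: int) -> dict:
    return {
        "connections": [{"devices": list(range(num_devices)), "type": "p2p"}],
        "partitions": [
            {
                "name": "Partition0",
                "devices": [{"deviceId": d, "numCores": num_cores} for d in range(num_devices)],
            }
        ],
    }
```
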
diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py
index 1e0dc48bc..8f8acd64f 100644
--- a/QEfficient/cloud/finetune.py
+++ b/QEfficient/cloud/finetune.py
@@ -5,9 +5,10 @@
#
# -----------------------------------------------------------------------------
+import logging
import random
import warnings
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union
import numpy as np
import torch
@@ -17,7 +18,7 @@
import torch.utils.data
from peft import PeftModel, get_peft_model
from torch.optim.lr_scheduler import StepLR
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.utils.config_utils import (
@@ -26,18 +27,23 @@
update_config,
)
from QEfficient.finetune.utils.dataset_utils import get_dataloader
+from QEfficient.finetune.utils.helper import Task_Mode
+from QEfficient.finetune.utils.logging_utils import logger
from QEfficient.finetune.utils.parser import get_finetune_parser
-from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
-from QEfficient.utils._utils import login_and_download_hf_lm
+from QEfficient.finetune.utils.train_utils import (
+ get_longest_seq_length,
+ print_model_size,
+ print_trainable_parameters,
+ train,
+)
+from QEfficient.utils._utils import hf_download
# Try importing QAIC-specific module, proceed without it if unavailable
try:
import torch_qaic # noqa: F401
except ImportError as e:
- print(f"Warning: {e}. Proceeding without QAIC modules.")
-
+ logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.", logging.WARNING)
-from transformers import AutoModelForSequenceClassification
# Suppress all warnings
warnings.filterwarnings("ignore")
@@ -85,14 +91,13 @@ def setup_seeds(seed: int) -> None:
def load_model_and_tokenizer(
- train_config: TrainConfig, dataset_config: Any, peft_config_file: str, **kwargs
+ train_config: TrainConfig, dataset_config: Any, **kwargs
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
"""Load the pre-trained model and tokenizer from Hugging Face.
Args:
config (TrainConfig): Training configuration object containing model and tokenizer names.
dataset_config (Any): A dataclass object representing dataset configuration.
- peft_config_file (str): Path to PEFT config file used for PEFT finetuning.
kwargs: Additional arguments to override PEFT config.
Returns:
@@ -106,8 +111,9 @@ def load_model_and_tokenizer(
- Resizes model embeddings if tokenizer vocab size exceeds model embedding size.
- Sets pad_token_id to eos_token_id if not defined in the tokenizer.
"""
- pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
- if train_config.task_type == "seq_classification":
+ logger.log_rank_zero(f"Loading HuggingFace model for {train_config.model_name}")
+ pretrained_model_path = hf_download(train_config.model_name)
+ if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
model = AutoModelForSequenceClassification.from_pretrained(
pretrained_model_path,
num_labels=dataset_config.num_labels,
@@ -116,7 +122,7 @@ def load_model_and_tokenizer(
)
if not hasattr(model, "base_model_prefix"):
- raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
+ logger.raise_error("The given Hugging Face model does not have a 'base_model_prefix' attribute.", RuntimeError)
for param in getattr(model, model.base_model_prefix).parameters():
param.requires_grad = False
@@ -141,11 +147,10 @@ def load_model_and_tokenizer(
# If there is a mismatch between tokenizer vocab size and embedding matrix,
# throw a warning and then expand the embedding matrix
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
- print("WARNING: Resizing embedding matrix to match tokenizer vocab size.")
+ logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.", logging.WARNING)
model.resize_token_embeddings(len(tokenizer))
- # FIXME (Meet): Cover below line inside the logger once it is implemented.
- print_model_size(model, train_config)
+ print_model_size(model)
# Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model.
# Because, both makes model.is_gradient_checkpointing = True which is used in peft library to
@@ -157,23 +162,21 @@ def load_model_and_tokenizer(
if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
else:
- raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
+ logger.raise_error(
+ "The given model does not support gradient checkpointing. Please disable gradient checkpointing and try again.", RuntimeError
+ )
- model = apply_peft(model, train_config, peft_config_file, **kwargs)
+ model = apply_peft(model, train_config, **kwargs)
return model, tokenizer
-def apply_peft(
- model: AutoModel, train_config: TrainConfig, peft_config_file: Dict, **kwargs
-) -> Union[AutoModel, PeftModel]:
+def apply_peft(model: AutoModel, train_config: TrainConfig, **kwargs) -> Union[AutoModel, PeftModel]:
"""Apply Parameter-Efficient Fine-Tuning (PEFT) to the model if enabled.
Args:
model (AutoModel): Huggingface model.
train_config (TrainConfig): Training configuration object.
- peft_config_file (str, optional): Path to YAML/JSON file containing
- PEFT (LoRA) config. Defaults to None.
kwargs: Additional arguments to override PEFT config params.
Returns:
@@ -190,9 +193,9 @@ def apply_peft(
peft_config = model.peft_config
# Generate the peft config and start fine-tuning from original model
else:
- peft_config = generate_peft_config(train_config, peft_config_file, **kwargs)
+ peft_config = generate_peft_config(train_config, **kwargs)
model = get_peft_model(model, peft_config)
- model.print_trainable_parameters()
+ print_trainable_parameters(model)
return model
@@ -217,7 +220,7 @@ def setup_dataloaders(
- Length of longest sequence in the dataset.
Raises:
- ValueError: If validation is enabled but the validation set is too small.
+ RuntimeError: If validation is enabled but the validation set is too small.
Notes:
- Applies a custom data collator if provided by get_custom_data_collator.
@@ -225,17 +228,18 @@ def setup_dataloaders(
"""
train_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="train")
- print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
+ logger.log_rank_zero(f"Number of Training Set Batches loaded = {len(train_dataloader)}")
eval_dataloader = None
if train_config.run_validation:
eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="val")
if len(eval_dataloader) == 0:
- raise ValueError(
- f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
+ logger.raise_error(
+ f"The eval set is too small for the dataloader to load even one batch. Please increase the size of the eval set. ({len(eval_dataloader)=})",
+ ValueError,
)
else:
- print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+ logger.log_rank_zero(f"Number of Validation Set Batches loaded = {len(eval_dataloader)}")
longest_seq_length, _ = get_longest_seq_length(
torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
@@ -246,12 +250,11 @@ def setup_dataloaders(
return train_dataloader, eval_dataloader, longest_seq_length
-def main(peft_config_file: str = None, **kwargs) -> None:
+def main(**kwargs) -> None:
"""
Fine-tune a model on QAIC hardware with configurable training and LoRA parameters.
Args:
- peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config. Defaults to None.
kwargs: Additional arguments to override TrainConfig.
Example:
@@ -274,18 +277,19 @@ def main(peft_config_file: str = None, **kwargs) -> None:
dataset_config = generate_dataset_config(train_config.dataset)
update_config(dataset_config, **kwargs)
+ logger.prepare_for_logs(train_config.output_dir, train_config.dump_logs, train_config.log_level)
+
setup_distributed_training(train_config)
setup_seeds(train_config.seed)
- model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, peft_config_file, **kwargs)
+ model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, **kwargs)
# Create DataLoaders for the training and validation dataset
train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer)
- print(
+ logger.log_rank_zero(
f"The longest sequence length in the train data is {longest_seq_length}, "
f"passed context length is {train_config.context_length} and overall model's context length is "
f"{model.config.max_position_embeddings}"
)
-
model.to(train_config.device)
optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
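
With `peft_config_file` folded into `TrainConfig`, callers now pass it like any other override to `main(**kwargs)` instead of as a dedicated argument. A hedged usage sketch (the YAML file name is a placeholder):

```python
# Illustrative only: the PEFT config file path rides along with the other
# TrainConfig overrides; update_config() assigns any matching field.
from QEfficient.cloud.finetune import main

main(
    model_name="meta-llama/Llama-3.2-1B",
    task_mode="generation",
    peft_config_file="lora_config.yaml",  # consumed via TrainConfig.peft_config_file
    output_dir="training_results",
)
```
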
diff --git a/QEfficient/compile/qnn_compiler.py b/QEfficient/compile/qnn_compiler.py
index 0f862b972..e2ec20364 100644
--- a/QEfficient/compile/qnn_compiler.py
+++ b/QEfficient/compile/qnn_compiler.py
@@ -12,12 +12,12 @@
from typing import Dict, List, Optional
from QEfficient.utils._utils import create_json, execute_command, load_json
-from QEfficient.utils.cache import to_hashable
from QEfficient.utils.constants import QnnConstants
from QEfficient.utils.generate_qnn_network_specialization_config import (
generate_data_format_config,
generate_qnn_specialization,
)
+from QEfficient.utils.hash_utils import to_hashable
from QEfficient.utils.logging_utils import logger
diff --git a/QEfficient/exporter/export_hf_to_cloud_ai_100.py b/QEfficient/exporter/export_hf_to_cloud_ai_100.py
index 1a0a04fc3..b769680ef 100644
--- a/QEfficient/exporter/export_hf_to_cloud_ai_100.py
+++ b/QEfficient/exporter/export_hf_to_cloud_ai_100.py
@@ -129,7 +129,6 @@ def export_bertstyle_model_to_onnx(model_name, model, tokenizer, onnx_dir_path,
)
# Generate inputFiles
- # todo(ochougul):rename to bert_style_input_list.txt
input_list_file = os.path.join(onnx_dir_path, "input_list.txt")
generate_input_files(
input_files_path=os.path.join(onnx_dir_path, "inputFiles"),
diff --git a/QEfficient/exporter/export_utils.py b/QEfficient/exporter/export_utils.py
index 11bb1e7bb..f86a0f254 100644
--- a/QEfficient/exporter/export_utils.py
+++ b/QEfficient/exporter/export_utils.py
@@ -218,8 +218,6 @@ def fix_onnx_fp16(
:str: Updated base name of exported ONNX model.
"""
model = onnx.load(os.path.join(gen_models_path, f"{model_base_name}.onnx"))
- # TODO: Remove this `fix_onnx_fp16` function and replace with this transform
- # as we're not utilizing the validations done in this function
model, fp16_fix = FP16ClipTransform.apply(model, onnx_base_dir=gen_models_path)
if fp16_fix:
@@ -256,8 +254,6 @@ def fix_onnx_fp16(
if ort_outputs is not None:
for oname, orto, ortof in zip(output_names, ort_outputs, ort_outputs_fixed):
fix_diff = np.abs(orto.astype(np.float32) - ortof.astype(np.float32)).max()
- # TODO: need to the debug this
- # info(oname, fix_diff)
close_outputs.append(fix_diff < 1e-5)
else:
info("No constants out of FP16 range")
diff --git a/QEfficient/finetune/configs/dataset_config.py b/QEfficient/finetune/configs/dataset_config.py
index b4ec1de3f..1f4fe094b 100644
--- a/QEfficient/finetune/configs/dataset_config.py
+++ b/QEfficient/finetune/configs/dataset_config.py
@@ -8,13 +8,6 @@
from dataclasses import dataclass
-@dataclass
-class samsum_dataset:
- dataset: str = "samsum_dataset"
- train_split: str = "train"
- test_split: str = "validation"
-
-
@dataclass
class grammar_dataset:
dataset: str = "grammar_dataset"
diff --git a/QEfficient/finetune/configs/training.py b/QEfficient/finetune/configs/training.py
index deac537bc..2d91f8403 100644
--- a/QEfficient/finetune/configs/training.py
+++ b/QEfficient/finetune/configs/training.py
@@ -5,8 +5,11 @@
#
# -----------------------------------------------------------------------------
+import logging
from dataclasses import dataclass
+from QEfficient.finetune.utils.helper import Batching_Strategy, Device, Peft_Method, Task_Mode
+
# Configuration Classes
@dataclass
@@ -33,12 +36,13 @@ class TrainConfig:
weight_decay (float): Weight decay for optimizer (default: 0.0).
gamma (float): Learning rate decay factor (default: 0.85).
seed (int): Random seed for reproducibility (default: 42).
- dataset (str): Dataset name for training (default: "samsum_dataset").
- task_type (str): Type of task for which the finetuning is to be done. Options: "generation" and "seq_classification". (default: "generation")
+ dataset (str): Dataset name for training (default: "alpaca_dataset").
+ task_mode (str): Task mode for the finetuning. Options: "generation" and "seq_classification". (default: "generation")
use_peft (bool): Whether to use PEFT (default: True).
peft_method (str): Parameter-efficient fine-tuning method (default: "lora").
- from_peft_checkpoint (str): Path to PEFT checkpoint (default: "").
- output_dir (str): Directory to save outputs (default: "meta-llama-samsum").
+ peft_config_file (str): Path to YAML/JSON file containing PEFT (LoRA) config. (default: None)
+ from_peft_checkpoint (str): Path to PEFT checkpoint (default: None).
+ output_dir (str): Directory to save outputs (default: "training_results").
save_model (bool): Save the trained model (default: True).
save_metrics (bool): Save training metrics (default: True).
intermediate_step_save (int): Steps between intermediate saves (default: 1000).
@@ -48,8 +52,9 @@ class TrainConfig:
convergence_loss (float): Loss threshold for convergence (default: 1e-4).
use_profiler (bool): Enable profiling (default: False).
enable_ddp (bool): Enable distributed data parallel (default: False).
- dump_root_dir (str): Directory for mismatch dumps (default: "meta-llama-samsum-mismatches/step_").
opByOpVerifier (bool): Enable operation-by-operation verification (default: False).
+ dump_logs (bool): Whether to dump logs to a file under output_dir (default: True).
+ log_level (str): Logging level (default: logging.INFO).
"""
model_name: str = "meta-llama/Llama-3.2-1B"
@@ -65,22 +70,23 @@ class TrainConfig:
num_epochs: int = 1
max_train_step: int = 0
max_eval_step: int = 0
- device: str = "qaic"
+ device: str = Device.QAIC.value
num_workers_dataloader: int = 1
lr: float = 3e-4
weight_decay: float = 0.0
gamma: float = 0.85 # multiplicatively decay the learning rate by gamma after each epoch
seed: int = 42
dataset: str = "alpaca_dataset"
- task_type: str = "generation" # "generation" / "seq_classification"
+ task_mode: str = Task_Mode.GENERATION.value # "generation" / "seq_classification"
use_peft: bool = True # use parameter efficient finetuning
- peft_method: str = "lora"
- from_peft_checkpoint: str = "" # if not empty and peft_method='lora', will load the peft checkpoint and resume the fine-tuning on that checkpoint
+ peft_method: str = Peft_Method.LORA.value
+ peft_config_file: str = None
+ from_peft_checkpoint: str = None # if not empty and peft_method='lora', will load the peft checkpoint and resume the fine-tuning on that checkpoint
output_dir: str = "training_results"
save_model: bool = True
save_metrics: bool = True # saves training metrics to a json file for later plotting
intermediate_step_save: int = 1000
- batching_strategy: str = "packing"
+ batching_strategy: str = Batching_Strategy.PADDING.value
enable_ddp: bool = False
enable_sorting_for_ddp: bool = True
convergence_counter: int = 5 # its value should be >= 1, stop fine tuning when loss <= convergence_loss (defined below) for #convergence_counter steps
@@ -94,5 +100,7 @@ class TrainConfig:
use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
# profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
- dump_root_dir: str = "mismatches/step_"
opByOpVerifier: bool = False
+
+ dump_logs: bool = True
+ log_level: str = logging.INFO
diff --git a/QEfficient/finetune/data/sampler.py b/QEfficient/finetune/data/sampler.py
index 1a4115419..60f789cbc 100644
--- a/QEfficient/finetune/data/sampler.py
+++ b/QEfficient/finetune/data/sampler.py
@@ -4,11 +4,9 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
-
import random
from itertools import islice
-import numpy as np
import torch
@@ -22,14 +20,14 @@ def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool
self.batch_size = batch_size
self.drop_last = drop_last
self.shuffle = shuffle
+ self.data_source = data_source
def __iter__(self):
- ids = np.argsort(self.lengths, kind="mergesort")
+ ids = list(range(len(self.data_source)))
if self.drop_last:
ids = ids[: len(ids) // self.batch_size * self.batch_size]
batches = [ids[i : i + self.batch_size] for i in range(0, len(ids), self.batch_size)]
-
if self.shuffle:
random.shuffle(batches)
@@ -45,11 +43,17 @@ def __len__(self):
class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
def __init__(
- self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0
+ self,
+ data_source,
+ batch_size: int,
+ num_replicas: int,
+ rank: int,
+ shuffle: bool = True,
+ seed: int = 0,
) -> None:
random.seed(seed)
self.batch_sampler = LengthBasedBatchSampler(
- data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
+ data_source, batch_size=batch_size, drop_last=False, shuffle=shuffle
)
self.num_replicas = num_replicas
self.rank = rank
diff --git a/QEfficient/finetune/dataset/alpaca_dataset.py b/QEfficient/finetune/dataset/alpaca_dataset.py
index aecc0d2cc..c6ddb6ce1 100644
--- a/QEfficient/finetune/dataset/alpaca_dataset.py
+++ b/QEfficient/finetune/dataset/alpaca_dataset.py
@@ -11,6 +11,8 @@
import torch
from torch.utils.data import Dataset
+from QEfficient.finetune.utils.logging_utils import logger
+
PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
@@ -27,7 +29,13 @@
class InstructionDataset(Dataset):
def __init__(self, dataset_config, tokenizer, partition="train", context_length=None):
- self.ann = json.load(open(dataset_config.data_path))
+ try:
+ self.ann = json.load(open(dataset_config.data_path))
+ except FileNotFoundError:
+ logger.raise_error(
+ "Loading of alpaca dataset failed! Please use (wget -c https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/refs/heads/main/alpaca_data.json -P dataset/) to download the alpaca dataset.",
+ FileNotFoundError,
+ )
# Use 5% of the dataset for evaluation
eval_length = int(len(self.ann) / 20)
if partition == "train":
diff --git a/QEfficient/finetune/dataset/custom_dataset.py b/QEfficient/finetune/dataset/custom_dataset.py
index 6d9baf90d..4a1f500e3 100644
--- a/QEfficient/finetune/dataset/custom_dataset.py
+++ b/QEfficient/finetune/dataset/custom_dataset.py
@@ -8,6 +8,8 @@
import importlib
from pathlib import Path
+from QEfficient.finetune.utils.logging_utils import logger
+
def load_module_from_py_file(py_file: str) -> object:
"""
@@ -30,20 +32,22 @@ def get_custom_dataset(dataset_config, tokenizer, split: str, context_length=Non
module_path, func_name = dataset_config.file, "get_custom_dataset"
if not module_path.endswith(".py"):
- raise ValueError(f"Dataset file {module_path} is not a .py file.")
+ logger.raise_error(f"Dataset file {module_path} is not a .py file.", ValueError)
module_path = Path(module_path)
if not module_path.is_file():
- raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+ logger.raise_error(
+ f"Dataset py file {module_path.as_posix()} does not exist or is not a file.", FileNotFoundError
+ )
module = load_module_from_py_file(module_path.as_posix())
try:
return getattr(module, func_name)(dataset_config, tokenizer, split, context_length)
- except AttributeError as e:
- print(
- f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()})."
+ except AttributeError:
+ logger.raise_error(
+ f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).",
+ AttributeError,
)
- raise e
def get_data_collator(dataset_processer, dataset_config):
@@ -53,16 +57,20 @@ def get_data_collator(dataset_processer, dataset_config):
module_path, func_name = dataset_config.file, "get_data_collator"
if not module_path.endswith(".py"):
- raise ValueError(f"Dataset file {module_path} is not a .py file.")
+ logger.raise_error(f"Dataset file {module_path} is not a .py file.", ValueError)
module_path = Path(module_path)
if not module_path.is_file():
- raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+ logger.raise_error(
+ f"Dataset py file {module_path.as_posix()} does not exist or is not a file.", FileNotFoundError
+ )
module = load_module_from_py_file(module_path.as_posix())
try:
return getattr(module, func_name)(dataset_processer)
except AttributeError:
- print(f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()}).")
- print("Using the default data_collator instead.")
+ logger.log_rank_zero(
+ f"Cannot find the custom data_collator in the dataset.py file ({module_path.as_posix()})."
+ )
+ logger.log_rank_zero("Using the default data_collator instead.")
return None
diff --git a/QEfficient/finetune/dataset/dataset_config.py b/QEfficient/finetune/dataset/dataset_config.py
index 63f4cf5f2..2e477be77 100644
--- a/QEfficient/finetune/dataset/dataset_config.py
+++ b/QEfficient/finetune/dataset/dataset_config.py
@@ -21,14 +21,10 @@
from QEfficient.finetune.dataset.imdb_dataset import (
get_preprocessed_imdb as get_imdb_dataset,
)
-from QEfficient.finetune.dataset.samsum_dataset import (
- get_preprocessed_samsum as get_samsum_dataset,
-)
DATASET_PREPROC = {
"alpaca_dataset": partial(get_alpaca_dataset),
"grammar_dataset": get_grammar_dataset,
- "samsum_dataset": get_samsum_dataset,
"gsm8k_dataset": get_gsm8k_dataset,
"custom_dataset": get_custom_dataset,
"imdb_dataset": get_imdb_dataset,
diff --git a/QEfficient/finetune/dataset/grammar_dataset.py b/QEfficient/finetune/dataset/grammar_dataset.py
index 43ee39158..e40c01e97 100644
--- a/QEfficient/finetune/dataset/grammar_dataset.py
+++ b/QEfficient/finetune/dataset/grammar_dataset.py
@@ -10,6 +10,8 @@
from datasets import load_dataset
from torch.utils.data import Dataset
+from QEfficient.finetune.utils.logging_utils import logger
+
class grammar(Dataset):
def __init__(self, tokenizer, csv_name=None, context_length=None):
@@ -19,11 +21,11 @@ def __init__(self, tokenizer, csv_name=None, context_length=None):
data_files={"train": [csv_name]}, # "eval": "grammar_validation.csv"},
delimiter=",",
)
- except Exception as e:
- print(
- "Loading of grammar dataset failed! Please see [here](https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset."
+ except FileNotFoundError:
+ logger.raise_error(
+ "Loading of grammar dataset failed! Please check (https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset.",
+ FileNotFoundError,
)
- raise e
self.context_length = context_length
self.tokenizer = tokenizer
@@ -36,7 +38,7 @@ def convert_to_features(self, example_batch):
# Create prompt and tokenize contexts and questions
if self.print_text:
- print("Input Text: ", self.clean_text(example_batch["text"]))
+ logger.log_rank_zero("Input Text: " + self.clean_text(example_batch["text"]))
input_ = example_batch["input"]
target_ = example_batch["target"]
@@ -71,9 +73,6 @@ def get_dataset(dataset_config, tokenizer, csv_name=None, context_length=None):
"""cover function for handling loading the working dataset"""
"""dataset loading"""
currPath = Path.cwd() / "datasets_grammar" / "grammar_train.csv"
- print(f"Loading dataset {currPath}")
- csv_name = str(currPath)
- print(csv_name)
- dataset = grammar(tokenizer=tokenizer, csv_name=csv_name, context_length=context_length)
+ dataset = grammar(tokenizer=tokenizer, csv_name=str(currPath), context_length=context_length)
return dataset
diff --git a/QEfficient/finetune/dataset/samsum_dataset.py b/QEfficient/finetune/dataset/samsum_dataset.py
deleted file mode 100644
index 67726d731..000000000
--- a/QEfficient/finetune/dataset/samsum_dataset.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# -----------------------------------------------------------------------------
-#
-# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
-# SPDX-License-Identifier: BSD-3-Clause
-#
-# -----------------------------------------------------------------------------
-
-import datasets
-
-
-def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
- dataset = datasets.load_dataset("Samsung/samsum", split=split, trust_remote_code=True)
-
- prompt = "Summarize this dialog:\n{dialog}\n---\nSummary:\n"
-
- def apply_prompt_template(sample):
- return {
- "prompt": prompt.format(dialog=sample["dialogue"]),
- "summary": sample["summary"],
- }
-
- dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
-
- def tokenize_add_label(sample):
- prompt = tokenizer.encode(
- tokenizer.bos_token + sample["prompt"],
- add_special_tokens=False,
- max_length=context_length,
- pad_to_max_length=True,
- )
- summary = tokenizer.encode(
- sample["summary"] + tokenizer.eos_token,
- add_special_tokens=False,
- max_length=context_length,
- pad_to_max_length=True,
- )
-
- sample = {
- "input_ids": prompt + summary,
- "attention_mask": [1] * (len(prompt) + len(summary)),
- "labels": [-100] * len(prompt) + summary,
- }
-
- return sample
-
- dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
-
- return dataset
diff --git a/QEfficient/finetune/eval.py b/QEfficient/finetune/eval.py
index c0d29d38b..72407a91e 100644
--- a/QEfficient/finetune/eval.py
+++ b/QEfficient/finetune/eval.py
@@ -19,13 +19,14 @@
from utils.train_utils import evaluation, print_model_size
from QEfficient.finetune.configs.training import TrainConfig
+from QEfficient.finetune.utils.logging_utils import logger
try:
import torch_qaic # noqa: F401
device = "qaic:0"
except ImportError as e:
- print(f"Warning: {e}. Moving ahead without these qaic modules.")
+ logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Suppress all warnings
@@ -77,25 +78,20 @@ def main(**kwargs):
# If there is a mismatch between tokenizer vocab size and embedding matrix,
# throw a warning and then expand the embedding matrix
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
- print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+ logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.")
model.resize_token_embeddings(len(tokenizer))
- print_model_size(model, train_config)
+ print_model_size(model)
if train_config.run_validation:
- # TODO: vbaddi enable packing later in entire infra.
- # if train_config.batching_strategy == "packing":
- # dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
-
eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="test")
-
- print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
if len(eval_dataloader) == 0:
- raise ValueError(
- f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
+ logger.raise_error(
+ f"The eval set is too small for the dataloader to load even one batch. Please increase the size of the eval set. ({len(eval_dataloader)=})",
+ ValueError,
)
else:
- print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+ logger.log_rank_zero(f"Number of Validation Set Batches loaded = {len(eval_dataloader)}")
model.to(device)
_ = evaluation(model, train_config, eval_dataloader, None, tokenizer, device)
diff --git a/QEfficient/finetune/utils/config_utils.py b/QEfficient/finetune/utils/config_utils.py
index bdc3c0429..64f17fecb 100644
--- a/QEfficient/finetune/utils/config_utils.py
+++ b/QEfficient/finetune/utils/config_utils.py
@@ -18,6 +18,8 @@
from QEfficient.finetune.configs.peft_config import LoraConfig
from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC
+from QEfficient.finetune.utils.helper import Peft_Method
+from QEfficient.finetune.utils.logging_utils import logger
def update_config(config, **kwargs):
@@ -43,19 +45,19 @@ def update_config(config, **kwargs):
if hasattr(config, param_name):
setattr(config, param_name, v)
else:
- raise ValueError(f"Config '{config_name}' does not have parameter: '{param_name}'")
+ logger.raise_error(
+ f"Config '{config_name}' does not have parameter: '{param_name}'", ValueError
+ )
else:
config_type = type(config).__name__
- # FIXME (Meet): Once logger is available put this in debug level.
- print(f"[WARNING]: Unknown parameter '{k}' for config type '{config_type}'")
+ logger.debug(f"Unknown parameter '{k}' for config type '{config_type}'")
-def generate_peft_config(train_config: TrainConfig, peft_config_file: str = None, **kwargs) -> Any:
+def generate_peft_config(train_config: TrainConfig, **kwargs) -> Any:
"""Generate a PEFT-compatible configuration from a custom config based on peft_method.
Args:
train_config (TrainConfig): Training configuration with peft_method.
- custom_config: Custom configuration object (e.g., LoraConfig).
Returns:
Any: A PEFT-specific configuration object (e.g., PeftLoraConfig).
@@ -63,14 +65,14 @@ def generate_peft_config(train_config: TrainConfig, peft_config_file: str = None
Raises:
RuntimeError: If the peft_method is not supported.
"""
- if peft_config_file:
- peft_config_data = load_config_file(peft_config_file)
- validate_config(peft_config_data, config_type="lora")
+ if train_config.peft_config_file:
+ peft_config_data = load_config_file(train_config.peft_config_file)
+ validate_config(peft_config_data, config_type=Peft_Method.LORA)
peft_config = PeftLoraConfig(**peft_config_data)
else:
- config_map = {"lora": (LoraConfig, PeftLoraConfig)}
+ config_map = {Peft_Method.LORA: (LoraConfig, PeftLoraConfig)}
if train_config.peft_method not in config_map:
- raise RuntimeError(f"Peft config not found: {train_config.peft_method}")
+ logger.raise_error(f"Peft config not found: {train_config.peft_method}", RuntimeError)
config_cls, peft_config_cls = config_map[train_config.peft_method]
if config_cls is None:
@@ -103,7 +105,7 @@ def generate_dataset_config(dataset_name: str) -> Any:
return dataset_config
-def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> None:
+def validate_config(config_data: Dict[str, Any], config_type: str = Peft_Method.LORA) -> None:
"""Validate the provided YAML/JSON configuration for required fields and types.
Args:
@@ -118,8 +120,8 @@ def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> N
- Validates required fields for LoraConfig: r, lora_alpha, target_modules.
- Ensures types match expected values (int, float, list, etc.).
"""
- if config_type.lower() != "lora":
- raise ValueError(f"Unsupported config_type: {config_type}. Only 'lora' is supported.")
+ if config_type.lower() != Peft_Method.LORA:
+ logger.raise_error(f"Unsupported config_type: {config_type}. Only 'lora' is supported.", ValueError)
required_fields = {
"r": int,
@@ -136,26 +138,28 @@ def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> N
# Check for missing required fields
missing_fields = [field for field in required_fields if field not in config_data]
if missing_fields:
- raise ValueError(f"Missing required fields in {config_type} config: {missing_fields}")
+ logger.raise_error(f"Missing required fields in {config_type} config: {missing_fields}", ValueError)
# Validate types of required fields
for field, expected_type in required_fields.items():
if not isinstance(config_data[field], expected_type):
- raise ValueError(
+ logger.raise_error(
f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, "
- f"got {type(config_data[field]).__name__}"
+ f"got {type(config_data[field]).__name__}",
+ ValueError,
)
# Validate target_modules contains strings
if not all(isinstance(mod, str) for mod in config_data["target_modules"]):
- raise ValueError("All elements in 'target_modules' must be strings")
+ logger.raise_error("All elements in 'target_modules' must be strings", ValueError)
# Validate types of optional fields if present
for field, expected_type in optional_fields.items():
if field in config_data and not isinstance(config_data[field], expected_type):
- raise ValueError(
+ logger.raise_error(
f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, "
- f"got {type(config_data[field]).__name__}"
+ f"got {type(config_data[field]).__name__}",
+ ValueError,
)
@@ -173,7 +177,7 @@ def load_config_file(config_path: str) -> Dict[str, Any]:
ValueError: If the file format is unsupported.
"""
if not os.path.exists(config_path):
- raise FileNotFoundError(f"Config file not found: {config_path}")
+ logger.raise_error(f"Config file not found: {config_path}", FileNotFoundError)
with open(config_path, "r") as f:
if config_path.endswith(".yaml") or config_path.endswith(".yml"):
@@ -181,4 +185,4 @@ def load_config_file(config_path: str) -> Dict[str, Any]:
elif config_path.endswith(".json"):
return json.load(f)
else:
- raise ValueError("Unsupported config file format. Use .yaml, .yml, or .json")
+ logger.raise_error("Unsupported config file format. Use .yaml, .yml, or .json", ValueError)
diff --git a/QEfficient/finetune/utils/dataset_utils.py b/QEfficient/finetune/utils/dataset_utils.py
index 42d0aae71..aacff2bb5 100644
--- a/QEfficient/finetune/utils/dataset_utils.py
+++ b/QEfficient/finetune/utils/dataset_utils.py
@@ -4,20 +4,22 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
-
+import datasets
import torch
import torch.distributed as dist
from transformers.data import DataCollatorForSeq2Seq
from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
from QEfficient.finetune.dataset.dataset_config import DATALOADER_COLLATE_FUNC, DATASET_PREPROC
+from QEfficient.finetune.utils.helper import get_num_ddp_devices
+from QEfficient.finetune.utils.logging_utils import logger
def get_preprocessed_dataset(
tokenizer, dataset_config, split: str = "train", context_length: int = None
) -> torch.utils.data.Dataset:
if dataset_config.dataset not in DATASET_PREPROC:
- raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented")
+ logger.raise_error(f"{dataset_config.dataset} is not (yet) implemented", NotImplementedError)
def get_split():
return dataset_config.train_split if split == "train" else dataset_config.test_split
@@ -38,8 +40,9 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, split):
if train_config.enable_ddp:
if train_config.enable_sorting_for_ddp:
if train_config.context_length:
- raise ValueError(
- "Sorting cannot be done with padding, Please disable sorting or pass context_length as None to disable padding"
+ logger.raise_error(
+ "Sorting cannot be done with padding, Please disable sorting or pass context_length as None to disable padding",
+ ValueError,
)
else:
kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
@@ -54,25 +57,56 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, split):
dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False
)
kwargs["batch_size"] = batch_size
- kwargs["drop_last"] = True
+ kwargs["drop_last"] = False
else:
kwargs["batch_size"] = batch_size
- kwargs["drop_last"] = True
+ kwargs["drop_last"] = False
kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
return kwargs
+def padding_dataset(train_config, dataset, batch_size):
+ if train_config.enable_ddp and train_config.enable_sorting_for_ddp:
+ if isinstance(dataset, datasets.Dataset):
+ # Hugging Face Dataset transformation
+ dataset = dataset.map(lambda x: {"input_length": len(x["input_ids"])})
+ dataset = dataset.sort("input_length")
+
+ else:
+ dataset = sorted(dataset, key=lambda x: len(x["input_ids"]))
+
+ dummy_row = next(iter(dataset))
+ dummy_row["labels"] = torch.tensor([-100] * len(dummy_row["labels"]))
+ num_replicas = get_num_ddp_devices()
+ remainder = len(dataset) % (num_replicas * batch_size)
+ # Pad only when the dataset length is not already a multiple of (num_replicas * batch_size)
+ padding_size = ((num_replicas * batch_size) - remainder) % (num_replicas * batch_size)
+ if padding_size == 0:
+ return dataset
+
+ dummy_data = [dummy_row.copy() for _ in range(padding_size)]
+ dummy_dataset = datasets.Dataset.from_list(dummy_data)
+ if isinstance(dataset, datasets.Dataset):
+ combined_dataset = datasets.concatenate_datasets([dataset, dummy_dataset])
+ else:
+ combined_dataset = dataset + list(dummy_dataset)
+ return combined_dataset
+
+
def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
dataset = get_preprocessed_dataset(tokenizer, dataset_config, split, context_length=train_config.context_length)
+
+ batch_size = train_config.train_batch_size if split == "train" else train_config.val_batch_size
+ dataset = padding_dataset(train_config, dataset, batch_size)
+
dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)
# FIXME (Meet): Add custom data collator registration from the outside by the user.
custom_data_collator = get_custom_data_collator(tokenizer, dataset_config)
+
if custom_data_collator:
print("custom_data_collator is used")
dl_kwargs["collate_fn"] = custom_data_collator
- print(f"length of dataset_{split}", len(dataset))
+ logger.log_rank_zero(f"Length of {split} dataset is {len(dataset)}")
# Create data loader
dataloader = torch.utils.data.DataLoader(
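
To make the padding logic above concrete, here is a worked example of the size arithmetic in `padding_dataset` (the numbers are illustrative). The appended dummy rows carry labels of -100 so they do not contribute to the loss:

```python
# Illustrative arithmetic for padding_dataset: pad the dataset so every DDP rank
# sees the same number of full batches.
num_replicas, batch_size, dataset_len = 2, 4, 103
remainder = dataset_len % (num_replicas * batch_size)        # 103 % 8 == 7
padding_size = (num_replicas * batch_size) - remainder       # 8 - 7 == 1 dummy row appended
assert (dataset_len + padding_size) % (num_replicas * batch_size) == 0
```
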
diff --git a/QEfficient/finetune/utils/helper.py b/QEfficient/finetune/utils/helper.py
index fcc44fec8..378238a94 100644
--- a/QEfficient/finetune/utils/helper.py
+++ b/QEfficient/finetune/utils/helper.py
@@ -4,8 +4,76 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
+import os
+from contextlib import nullcontext
+from enum import Enum
-TASK_TYPE = ["generation", "seq_classification"]
-PEFT_METHOD = ["lora"]
-DEVICE = ["qaic", "cpu", "cuda"]
-BATCHING_STRATEGY = ["padding", "packing"]
+import torch
+
+try:
+ import torch_qaic.debug as qaic_debug # noqa: F401
+except ImportError as e:
+ print(f"Warning: {e}. Moving ahead without these qaic modules.")
+
+
+class Batching_Strategy(str, Enum):
+ PADDING = "padding"
+ PACKING = "packing"
+
+
+class Device(str, Enum):
+ QAIC = "qaic"
+ CPU = "cpu"
+ CUDA = "cuda"
+
+
+class Peft_Method(str, Enum):
+ LORA = "lora"
+
+
+class Task_Mode(str, Enum):
+ GENERATION = "generation"
+ SEQ_CLASSIFICATION = "seq_classification"
+
+
+def enum_names(enum_cls):
+ return [member.value for member in enum_cls]
+
+
+def is_rank_zero():
+ return int(os.getenv("LOCAL_RANK", 0)) == 0
+
+
+def get_num_ddp_devices():
+ return int(os.getenv("WORLD_SIZE", 1))
+
+
+def get_autocast_ctx(use_autocast, device_type, dtype=torch.float16):
+ return torch.autocast(device_type=device_type, dtype=dtype) if use_autocast else nullcontext()
+
+
+def get_op_verifier_ctx(
+ use_op_by_op_verifier,
+ train_device,
+ dump_dir,
+ step,
+ ref_device="cpu",
+ ref_dtype=torch.float32,
+ atol=1e-1,
+ rtol=1e-5,
+ use_ref_output_on_mismatch=True,
+):
+ if not use_op_by_op_verifier:
+ return nullcontext()
+
+ filter_config = qaic_debug.DispatchFilterConfig.default(train_device)
+ dump_dir = dump_dir + "/mismatches/step_" + str(step)
+ return qaic_debug.OpByOpVerifierMode(
+ ref_device=ref_device,
+ ref_dtype=ref_dtype,
+ atol=atol,
+ rtol=rtol,
+ use_ref_output_on_mismatch=use_ref_output_on_mismatch,
+ filter_config=filter_config,
+ dump_root_dir=dump_dir,
+ )
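
The enums above subclass `str`, which is what lets the existing string-typed config fields and argparse choices keep working. A small self-contained check, mirroring only the definitions shown above:

```python
from enum import Enum


class Task_Mode(str, Enum):  # mirrors the definition in helper.py above
    GENERATION = "generation"
    SEQ_CLASSIFICATION = "seq_classification"


# str-valued members compare equal to plain strings, so a TrainConfig field that
# defaults to Task_Mode.GENERATION.value still satisfies equality checks later on.
assert Task_Mode.GENERATION == "generation"
assert [m.value for m in Task_Mode] == ["generation", "seq_classification"]
```
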
diff --git a/QEfficient/finetune/utils/logging_utils.py b/QEfficient/finetune/utils/logging_utils.py
new file mode 100644
index 000000000..15a67223f
--- /dev/null
+++ b/QEfficient/finetune/utils/logging_utils.py
@@ -0,0 +1,54 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import logging
+import os
+from datetime import datetime
+
+from QEfficient.finetune.utils.helper import is_rank_zero
+
+
+class FTLogger:
+ def __init__(self):
+ self.logger = logging.getLogger("QEfficient")
+ if not getattr(self.logger, "_custom_methods_added", False):
+ self._bind_custom_methods()
+ self.logger._custom_methods_added = True # Prevent adding handlers/methods twice
+
+ def _bind_custom_methods(self):
+ def raise_error(message, errortype=RuntimeError):
+ self.logger.error(message)
+ raise errortype(message)
+
+ def log_rank_zero(msg: str, level: int = logging.INFO):
+ if is_rank_zero():
+ self.logger.log(level, msg, stacklevel=2)
+
+ def prepare_for_logs(output_path, dump_logs=False, level=logging.INFO):
+ self.logger.setLevel(level)
+ if dump_logs:
+ logs_path = os.path.join(output_path, "logs")
+ if not os.path.exists(logs_path):
+ os.makedirs(logs_path, exist_ok=True)
+ file_name = f"log-file-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + ".txt"
+ log_file = os.path.join(logs_path, file_name)
+
+ fh = logging.FileHandler(log_file)
+ fh.setLevel(level)
+ formatter = logging.Formatter("%(levelname)s - %(name)s - %(message)s")
+ fh.setFormatter(formatter)
+ self.logger.addHandler(fh)
+
+ self.logger.raise_error = raise_error
+ self.logger.log_rank_zero = log_rank_zero
+ self.logger.prepare_for_logs = prepare_for_logs
+
+ def get_logger(self):
+ return self.logger
+
+
+logger = FTLogger().get_logger()
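
A short usage sketch of the custom methods bound above (paths and messages are illustrative):

```python
import logging

from QEfficient.finetune.utils.logging_utils import logger

# Optionally add a file handler under <output_dir>/logs/ before training starts.
logger.prepare_for_logs("training_results", dump_logs=True, level=logging.INFO)

# Emitted only on rank 0 (LOCAL_RANK == 0); an optional level can be supplied.
logger.log_rank_zero("Number of Training Set Batches loaded = 42")
logger.log_rank_zero("Resizing the embedding matrix.", logging.WARNING)

# Logs the message, then raises the given exception type.
logger.raise_error("Config file not found: config.yaml", FileNotFoundError)
```
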
diff --git a/QEfficient/finetune/utils/parser.py b/QEfficient/finetune/utils/parser.py
index 39ce5f969..8e606fb0b 100644
--- a/QEfficient/finetune/utils/parser.py
+++ b/QEfficient/finetune/utils/parser.py
@@ -6,9 +6,10 @@
# -----------------------------------------------------------------------------
import argparse
+import logging
from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC
-from QEfficient.finetune.utils.helper import BATCHING_STRATEGY, DEVICE, PEFT_METHOD, TASK_TYPE
+from QEfficient.finetune.utils.helper import Batching_Strategy, Device, Peft_Method, Task_Mode, enum_names
def str2bool(v):
@@ -110,7 +111,14 @@ def get_finetune_parser():
default=0,
help="Maximum evaluation steps, unlimited if 0",
)
- parser.add_argument("--device", required=False, type=str, default="qaic", choices=DEVICE, help="Device to train on")
+ parser.add_argument(
+ "--device",
+ required=False,
+ type=str,
+ default=Device.QAIC.value,
+ choices=enum_names(Device),
+ help="Device to train on",
+ )
parser.add_argument(
"--num_workers_dataloader",
"--num-workers-dataloader",
@@ -140,12 +148,12 @@ def get_finetune_parser():
help="Dataset name to be used for finetuning (default: %(default)s)",
)
parser.add_argument(
- "--task_type",
- "--task-type",
+ "--task_mode",
+ "--task-mode",
required=False,
type=str,
- default="generation",
- choices=TASK_TYPE,
+ default=Task_Mode.GENERATION.value,
+ choices=enum_names(Task_Mode),
help="Task used for finetuning. Use 'generation' for decoder based models and 'seq_classification' for encoder based models.",
)
parser.add_argument(
@@ -162,8 +170,8 @@ def get_finetune_parser():
"--peft-method",
required=False,
type=str,
- default="lora",
- choices=PEFT_METHOD,
+ default=Peft_Method.LORA.value,
+ choices=enum_names(Peft_Method),
help="Parameter efficient finetuning technique to be used. Currently only 'lora' is supported.",
)
parser.add_argument(
@@ -213,8 +221,8 @@ def get_finetune_parser():
"--batching-strategy",
required=False,
type=str,
- default="padding",
- choices=BATCHING_STRATEGY,
+ default=Batching_Strategy.PADDING.value,
+ choices=enum_names(Batching_Strategy),
help="Strategy for making batches of data points. Packing groups data points into batches by minimizing unnecessary empty spaces. Padding adds extra values (often zeros) to batch sequences so they align in size. Currently only padding is supported which is by default.",
)
parser.add_argument(
@@ -255,17 +263,28 @@ def get_finetune_parser():
help="Enable distributed data parallel training. This will load the replicas of model on given number of devices and train the model. This should be used using torchrun interface. Please check docs for exact usage.",
)
parser.add_argument(
- "--dump_root_dir",
- "--dump-root-dir",
+ "--opByOpVerifier",
+ action="store_true",
+ help=argparse.SUPPRESS,
+ # This is for debugging purpose only.
+ # Enables operation-by-operation verification w.r.t reference device(cpu).
+ # It is a context manager interface that captures and verifies each operator against reference device.
+ # In case results of test & reference do not match under given tolerances, a standalone unittest is generated at output_dir/mismatches.
+ )
+ parser.add_argument(
+ "--log_level",
+ "--log-level",
required=False,
type=str,
- default="mismatches/step_",
- help="Directory for mismatch dumps by opByOpVerifier",
+ default=logging.INFO,
+ help="Logging level (e.g. DEBUG, INFO, WARNING)",
)
parser.add_argument(
- "--opByOpVerifier",
- action="store_true",
- help="Enable operation-by-operation verification w.r.t reference device(cpu). It is a context manager interface that captures and verifies each operator against reference device. In case results of test & reference do not match under given tolerances, a standalone unittest is generated at dump_root_dir.",
+ "--peft_config_file",
+ "--peft-config-file",
+ type=str,
+ default=None,
+ help="Path to YAML/JSON file containing PEFT (LoRA) config.",
)
return parser
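
A hedged sketch of how the reworked flags parse, assuming the remaining arguments all have defaults as in the file today (flag values below are illustrative):

```python
from QEfficient.finetune.utils.parser import get_finetune_parser

# Hypothetical invocation exercising the renamed/added flags from this diff.
args = get_finetune_parser().parse_args(
    [
        "--task-mode", "seq_classification",
        "--peft-config-file", "lora_config.yaml",
        "--device", "qaic",
    ]
)
assert args.task_mode == "seq_classification"
assert args.peft_config_file == "lora_config.yaml"
```
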
diff --git a/QEfficient/finetune/utils/plot_metrics.py b/QEfficient/finetune/utils/plot_metrics.py
index 416ec3cdf..1e22bc6a8 100644
--- a/QEfficient/finetune/utils/plot_metrics.py
+++ b/QEfficient/finetune/utils/plot_metrics.py
@@ -11,6 +11,8 @@
import matplotlib.pyplot as plt
+from QEfficient.finetune.utils.logging_utils import logger
+
def plot_metric(data, metric_name, x_label, y_label, title, colors):
plt.figure(figsize=(7, 6))
@@ -67,14 +69,14 @@ def plot_metrics_by_step(data, metric_name, x_label, y_label, colors):
def plot_metrics(file_path):
if not os.path.exists(file_path):
- print(f"File {file_path} does not exist.")
+ logger.raise_error(f"File {file_path} does not exist.", FileNotFoundError)
return
with open(file_path, "r") as f:
try:
data = json.load(f)
- except json.JSONDecodeError:
- print("Invalid JSON file.")
+ except json.JSONDecodeError as e:
+            logger.raise_error(f"Invalid JSON file: {e}", ValueError)
return
directory = os.path.dirname(file_path)
diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index 9f9f06917..93bca856e 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -8,8 +8,8 @@
import json
import os
import time
-from contextlib import nullcontext
from datetime import datetime
+from functools import partial
from typing import Dict, List, Tuple
import torch
@@ -19,6 +19,8 @@
from tqdm import tqdm
from QEfficient.finetune.configs.training import TrainConfig
+from QEfficient.finetune.utils.helper import Task_Mode, get_autocast_ctx, get_op_verifier_ctx, is_rank_zero
+from QEfficient.finetune.utils.logging_utils import logger
try:
import torch_qaic # noqa: F401
@@ -27,7 +29,7 @@
import torch_qaic.utils as qaic_utils # noqa: F401
from torch.qaic.amp import GradScaler as QAicGradScaler
except ImportError as e:
- print(f"Warning: {e}. Moving ahead without these qaic modules.")
+ logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.")
from torch.amp import GradScaler
@@ -83,11 +85,9 @@ def train(
max_steps_reached = False # Flag to indicate max training steps reached
tensorboard_updates = None
- if train_config.enable_ddp:
- if local_rank == 0:
- tensorboard_updates = SummaryWriter()
- else:
- tensorboard_updates = SummaryWriter()
+ if is_rank_zero():
+ tensorboard_log_dir = train_config.output_dir + "/runs/" + f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
+ tensorboard_updates = SummaryWriter(log_dir=tensorboard_log_dir)
device_type = torch.device(device).type
@@ -103,37 +103,34 @@ def train(
dist.broadcast(loss_0_counter, src=0)
acc_helper = None
- if train_config.task_type == "seq_classification":
+ if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
if train_config.enable_ddp:
num_classes = model.module.classifier.out_features
else:
num_classes = model.classifier.out_features
acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
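+    # Build the autocast and op-by-op verifier contexts once up front; both helpers are expected
+    # to return no-op contexts when the corresponding option is disabled in train_config.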
+ autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
+ op_verifier_ctx = partial(get_op_verifier_ctx, train_config.opByOpVerifier, device, train_config.output_dir)
+
# Start the training loop
for epoch in range(train_config.num_epochs):
if loss_0_counter.item() == train_config.convergence_counter:
- if train_config.enable_ddp:
- print(
- f"Not proceeding with epoch {epoch + 1} on device {local_rank} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
- )
- break
- else:
- print(
- f"Not proceeding with epoch {epoch + 1} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
- )
- break
+ logger.log_rank_zero(
+ f"Skipping epoch {epoch + 1} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
+ )
+ break
if train_config.use_peft and train_config.from_peft_checkpoint:
intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
if epoch < intermediate_epoch:
- print(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
+ logger.log_rank_zero(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
# to bring the count of train_step in sync with where it left off
total_train_steps += len(train_dataloader)
continue
- print(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
- print(f"train_config.max_train_step: {train_config.max_train_step}")
+ logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
+ logger.log_rank_zero(f"train_config.max_train_step: {train_config.max_train_step}")
# stop when the maximum number of training steps is reached
if max_steps_reached:
break
@@ -151,7 +148,7 @@ def train(
# enable profile for qaic
qaic_profile.start_profiling(device, 1) if train_config.use_profiler else None
-
+ num_dummy_samples = 0
for step, batch in enumerate(train_dataloader):
# resume training from a particular checkpoint, assuming the dataset is not shuffled
if train_config.use_peft and train_config.from_peft_checkpoint:
@@ -160,8 +157,8 @@ def train(
# to bring the count of train_step in sync with where it left off
if epoch == intermediate_epoch and step == 0:
total_train_steps += intermediate_step
- print(
- f"skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for them."
+ logger.log_rank_zero(
+                        f"Skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for them."
)
if epoch == intermediate_epoch and step < intermediate_step:
total_train_steps += 1
@@ -174,75 +171,75 @@ def train(
break
batch = {k: v.to(device) for k, v in batch.items()} # move the batch elements to qaic device
- with (
- torch.autocast(device_type=device_type, dtype=torch.float16)
- if train_config.use_autocast
- else nullcontext()
- ):
- # an additional condition can be put here to avoid opByOpVerifier getting triggered for each step
- if train_config.opByOpVerifier:
- with qaic_debug.OpByOpVerifierMode(
- ref_device="cpu",
- ref_dtype=torch.float32,
- # adjust atol & rtol this as required
- atol=1e-1,
- use_ref_output_on_mismatch=True,
- filter_config=qaic_debug.DispatchFilterConfig.default(device),
- dump_root_dir=train_config.dump_root_dir + str(step),
- ) as verifier:
- model_outputs = model(**batch)
- loss = model_outputs.loss # Forward call
- if train_config.task_type == "seq_classification":
- logits = model_outputs.logits
- labels = batch["labels"][:, 0]
- preds = torch.nn.functional.softmax(logits, dim=-1)
- acc_helper.forward(preds, labels)
- print("Mismatches detected:", verifier.get_perop_mismatch_count())
+ is_optimizer_step = (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(
+ train_dataloader
+ ) - 1
+ if train_config.enable_ddp:
+ # Below block derived from : https://github.com/karpathy/nanoGPT/blob/93a43d9a5c22450bbf06e78da2cb6eeef084b717/train.py#L293
+ # in DDP training we only need to sync gradients at the last micro step.
+ # the official way to do this is with model.no_sync() context manager, but
+ # using too many context managers may bloat the code and forces us to repeat code
+ # looking at the source of that context manager, it just toggles this variable
+ model.require_backward_grad_sync = is_optimizer_step
+
+ with autocast_ctx, op_verifier_ctx(step) as verifier:
+ model_outputs = model(**batch)
+ loss = model_outputs.loss # Forward call
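+                    # A batch made up entirely of dummy/padded samples (all labels == -100) yields a NaN
+                    # loss; zero it and count the dummy samples so epoch-level loss averaging can exclude
+                    # them. Batches that are only partially dummy rescale the loss instead.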
+ if (batch["labels"] != -100).sum() == 0:
+ loss = loss.nan_to_num(nan=0.0)
+ num_dummy_samples += train_config.train_batch_size
else:
- model_outputs = model(**batch)
- loss = model_outputs.loss # Forward call
- if train_config.task_type == "seq_classification":
- logits = model_outputs.logits
- labels = batch["labels"][:, 0]
- preds = torch.nn.functional.softmax(logits, dim=-1)
- acc_helper.forward(preds, labels)
+ num_dummy_samples_per_batch = (
+ (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
+ )
+ if num_dummy_samples_per_batch > 0:
+ num_dummy_samples += num_dummy_samples_per_batch
+ loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
+
+ if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
+ logits = model_outputs.logits
+ labels = batch["labels"][:, 0]
+ preds = torch.nn.functional.softmax(logits, dim=-1)
+ acc_helper.forward(preds, labels)
+ if train_config.opByOpVerifier:
+                    logger.info(f"Mismatches detected: {verifier.get_perop_mismatch_count()}")
total_loss += loss.detach().float()
- # Accumalate gradients
- loss = loss / train_config.gradient_accumulation_steps
- if train_config.enable_ddp:
- if local_rank == 0:
- if loss <= train_config.convergence_loss:
- loss_0_counter += 1
- else:
- loss_0_counter = torch.tensor([0]).to(device)
- dist.broadcast(loss_0_counter, src=0)
- else:
+ if is_rank_zero():
if loss <= train_config.convergence_loss:
loss_0_counter += 1
else:
loss_0_counter = torch.tensor([0]).to(device)
-
if train_config.enable_ddp:
- if local_rank == 0:
- tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
- else:
+ dist.broadcast(loss_0_counter, src=0)
+ if is_rank_zero():
tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
if train_config.save_metrics:
train_step_loss.append(loss.detach().float().item())
- if train_config.task_type == "seq_classification":
+ if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
step_metric_val = float(acc_helper.compute())
else:
step_metric_val = float(torch.exp(loss.detach().float()))
train_step_metric.append(step_metric_val)
+            # Accumulate gradients
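+            # The final accumulation window of an epoch may hold fewer micro-batches than
+            # gradient_accumulation_steps; normalize the loss by the actual window size computed below.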
+ complete_accum_steps = (
+ len(train_dataloader) - len(train_dataloader) % train_config.gradient_accumulation_steps
+ )
+ if step < complete_accum_steps:
+ num_samples_in_cur_update = train_config.gradient_accumulation_steps
+ else:
+ num_samples_in_cur_update = len(train_dataloader) % train_config.gradient_accumulation_steps
+
+ loss = loss / num_samples_in_cur_update
+
if train_config.grad_scaler:
scaler.scale(loss).backward() # backward pass
else:
loss.backward() # backward pass
- if (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+ if is_optimizer_step:
if train_config.grad_scaler:
scaler.step(optimizer)
scaler.update()
@@ -277,18 +274,11 @@ def train(
val_step_metric,
val_metric,
)
- if train_config.enable_ddp:
- if loss_0_counter.item() == train_config.convergence_counter:
- print(
- f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps. Hence, stopping the fine tuning on device {local_rank}."
- )
- break
- else:
- if loss_0_counter.item() == train_config.convergence_counter:
- print(
- f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps. Hence, stopping the fine tuning."
- )
- break
+ if loss_0_counter.item() == train_config.convergence_counter:
+ logger.log_rank_zero(
+                    f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps. Hence, stopping the fine tuning."
+ )
+ break
pbar.close()
epoch_end_time = time.perf_counter() - epoch_start_time
@@ -296,16 +286,31 @@ def train(
if loss_0_counter.item() == train_config.convergence_counter:
if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
- train_epoch_loss = total_loss / (step - intermediate_step)
+ train_epoch_loss = (
+ 0.0
+ if total_loss == 0.0
+ else total_loss / (step - intermediate_step - num_dummy_samples / train_config.train_batch_size)
+ )
else:
- train_epoch_loss = total_loss / step
+ train_epoch_loss = (
+ 0.0
+ if total_loss == 0.0
+ else total_loss / (step + 1 - num_dummy_samples / train_config.train_batch_size)
+ )
else:
if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
- train_epoch_loss = total_loss / (len(train_dataloader) - intermediate_step)
+ train_epoch_loss = (
+ 0.0
+ if total_loss == 0.0
+ else total_loss / (step - intermediate_step - (num_dummy_samples / train_config.train_batch_size))
+ )
else:
- train_epoch_loss = total_loss / len(train_dataloader)
-
- if train_config.task_type == "seq_classification":
+ train_epoch_loss = (
+ 0.0
+ if total_loss == 0.0
+ else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size))
+ )
+ if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
metric_val = acc_helper.compute()
acc_helper.reset()
else:
@@ -318,18 +323,10 @@ def train(
lr_scheduler.step()
if train_config.run_validation:
- if train_config.enable_ddp:
- dist.barrier()
- eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
- model, train_config, eval_dataloader, device
- )
- if local_rank == 0:
- tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
-
- else:
- eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
- model, train_config, eval_dataloader, device
- )
+ eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
+ model, train_config, eval_dataloader, device
+ )
+ if is_rank_zero():
tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
if train_config.save_metrics:
@@ -347,15 +344,15 @@ def train(
if train_config.run_validation:
if eval_epoch_loss < best_val_loss:
best_val_loss = eval_epoch_loss
- print(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
+ logger.log_rank_zero(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
val_loss.append(float(eval_epoch_loss))
val_metric.append(float(eval_metric))
- if train_config.task_type == "seq_classification":
- print(
+ if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
+ logger.log_rank_zero(
f"Epoch {epoch + 1}: train_acc={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
)
else:
- print(
+ logger.log_rank_zero(
f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
)
@@ -389,7 +386,6 @@ def train(
results["avg_checkpoint_time"] = avg_checkpoint_time
if train_config.save_metrics:
results["metrics_filename"] = metrics_filename
-
return results
@@ -403,9 +399,12 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
Returns: eval_epoch_loss, eval_metric, eval_step_loss, eval_step_metric
"""
+ if train_config.enable_ddp:
+ dist.barrier()
+
model.eval()
- if train_config.task_type == "seq_classification":
+ if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
if train_config.enable_ddp:
num_classes = model.module.classifier.out_features
else:
@@ -421,6 +420,8 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
eval_loss = 0.0 # Initialize evaluation loss
device_type = torch.device(device).type
+ num_dummy_samples = 0
+ autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
# stop when the maximum number of eval steps is reached
if train_config.max_eval_step > 0 and step > train_config.max_eval_step:
@@ -431,15 +432,22 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
# Ensure no gradients are computed for this scope to save memory
with torch.no_grad():
# Forward pass and compute loss
- with (
- torch.autocast(device_type=device_type, dtype=torch.float16)
- if train_config.use_autocast
- else nullcontext()
- ):
+ with autocast_ctx:
outputs = model(**batch)
loss = outputs.loss
- if train_config.task_type == "seq_classification":
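+                # Mirror the training loop: zero out the NaN loss from fully-dummy batches and count
+                # dummy samples so the epoch-level eval loss can exclude them.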
+ if (batch["labels"] != -100).sum() == 0:
+ loss = loss.nan_to_num(nan=0.0)
+                    num_dummy_samples += train_config.val_batch_size
+ else:
+ num_dummy_samples_per_batch = (
+ (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
+ )
+ if num_dummy_samples_per_batch > 0:
+ num_dummy_samples += num_dummy_samples_per_batch
+ loss = loss * train_config.val_batch_size / num_dummy_samples_per_batch
+
+ if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
logits = outputs.logits
labels = batch["labels"][:, 0]
preds = torch.nn.functional.softmax(logits, dim=-1)
@@ -453,16 +461,17 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
val_step_metric.append(metric_val)
eval_loss += loss.detach().float()
-
# Compute average loss and metric
- eval_epoch_loss = eval_loss / len(eval_dataloader)
- if train_config.task_type == "seq_classification":
+ eval_epoch_loss = (
+ 0.0 if eval_loss == 0.0 else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size)
+ )
+ if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
eval_metric = acc_helper.compute()
else:
eval_metric = torch.exp(eval_epoch_loss)
# Print evaluation metrics
- print(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
+ logger.log_rank_zero(f"{eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric
@@ -475,18 +484,28 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
return longest_seq_length, longest_seq_ix
-def print_model_size(model, config) -> None:
+def print_model_size(model) -> None:
"""
Print model name, the number of trainable parameters and initialization time.
Args:
- model: The PyTorch model.
- model_name (str): Name of the model.
+ model: PyTorch model.
"""
-
- print(f"--> Model {config.model_name}")
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
- print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
+ logger.log_rank_zero(f"Model has {total_params / 1e6} Million params.")
+
+
+def print_trainable_parameters(model) -> None:
+ """
+    Print the number of trainable parameters, the total number of parameters, and the percentage of trainable parameters.
+
+ Args:
+ model: The PyTorch model.
+ """
+ trainable_params, all_param = model.get_nb_trainable_parameters()
+ logger.log_rank_zero(
+ f"Trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}"
+ )
def save_to_json(
diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py
index a9690aa51..fd7ef03ff 100755
--- a/QEfficient/generation/text_generation_inference.py
+++ b/QEfficient/generation/text_generation_inference.py
@@ -60,7 +60,7 @@ def __repr__(self):
return f"Average Prefill time a.k.a TTFT is= {round(self.perf_metrics.prefill_time, 2)} sec\
\nDecode is= {round(self.perf_metrics.decode_perf * self.batch_size, 2)} tokens/sec\
\nTotal is= {round(self.perf_metrics.total_perf * self.batch_size, 2)} tokens/sec\
- \nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)} tokens/sec"
+ \nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)} sec"
@dataclass
diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py
index f475ad4ad..f1532ad1b 100644
--- a/QEfficient/peft/auto.py
+++ b/QEfficient/peft/auto.py
@@ -27,7 +27,7 @@
from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform
from QEfficient.utils import constants
from QEfficient.utils._utils import get_padding_shape_from_config
-from QEfficient.utils.cache import to_hashable
+from QEfficient.utils.hash_utils import to_hashable
logger = logging.getLogger(__name__)
diff --git a/QEfficient/peft/lora/auto.py b/QEfficient/peft/lora/auto.py
index 1e83c18c9..14cadf997 100644
--- a/QEfficient/peft/lora/auto.py
+++ b/QEfficient/peft/lora/auto.py
@@ -18,7 +18,7 @@
from QEfficient import QEFFAutoModelForCausalLM
from QEfficient.peft.lora.pytorch_transforms import LoraModelInputsTransform, TargetModulesTransform
from QEfficient.utils import constants, get_padding_shape_from_config
-from QEfficient.utils.cache import to_hashable
+from QEfficient.utils.hash_utils import to_hashable
from QEfficient.utils.logging_utils import logger
diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py
index 7162b856a..16767fbe2 100644
--- a/QEfficient/transformers/cache_utils.py
+++ b/QEfficient/transformers/cache_utils.py
@@ -288,7 +288,6 @@ def from_legacy_cache(
class QEffHybridCache(HybridCache):
def __init__(self, config, batch_size, max_cache_len):
super().__init__(config, batch_size, max_cache_len=max_cache_len)
- # breakpoint()
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
@@ -327,7 +326,6 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
backward compatibility."""
legacy_cache = ()
- # breakpoint()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache
diff --git a/QEfficient/transformers/models/codegen/modeling_codegen.py b/QEfficient/transformers/models/codegen/modeling_codegen.py
index 09400c51e..e0f6b5196 100644
--- a/QEfficient/transformers/models/codegen/modeling_codegen.py
+++ b/QEfficient/transformers/models/codegen/modeling_codegen.py
@@ -85,7 +85,6 @@ def forward(
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
qkv = self.qkv_proj(hidden_states)
- # TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
mp_num = 4
qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))
diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py
index 593d17f1b..9dca5f050 100644
--- a/QEfficient/transformers/models/falcon/modeling_falcon.py
+++ b/QEfficient/transformers/models/falcon/modeling_falcon.py
@@ -183,7 +183,11 @@ def forward(
):
residual = hidden_states
- attention_layernorm_out = self.input_layernorm(hidden_states)
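+        # Falcon's new_decoder_architecture variants apply separate layernorms (ln_attn / ln_mlp)
+        # to the attention and MLP inputs instead of a single input_layernorm.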
+ if self.config.new_decoder_architecture:
+ attention_layernorm_out = self.ln_attn(hidden_states)
+ mlp_layernorm_out = self.ln_mlp(hidden_states)
+ else:
+ attention_layernorm_out = self.input_layernorm(hidden_states)
# Self attention.
attn_outputs = self.self_attention(
diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
index bda5959a7..9e9544b7e 100644
--- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py
+++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
@@ -238,9 +238,9 @@ def forward(
)
kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
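+        # Size the rotary-embedding cache from the model's configured max_position_embeddings
+        # instead of a hard-coded constant.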
if self.is_sliding:
- cos, sin = self.rotary_emb_local(value_states, seq_len=constants.GEMMA3_MAX_POSITION_EMBEDDINGS)
+ cos, sin = self.rotary_emb_local(value_states, seq_len=self.config.max_position_embeddings)
else:
- cos, sin = self.rotary_emb(value_states, seq_len=constants.GEMMA3_MAX_POSITION_EMBEDDINGS)
+ cos, sin = self.rotary_emb(value_states, seq_len=self.config.max_position_embeddings)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
@@ -687,7 +687,6 @@ def get_specializations(
"mm_tokens_per_image": mm_tokens_per_image,
},
]
-
specializations = {}
if kv_offload:
diff --git a/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index d4a322a56..5dd9362ee 100644
--- a/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -29,8 +29,6 @@
# Fused kernels
# Use separate functions for each case because conditionals prevent kernel fusion.
-# TODO: Could have better fused kernels depending on scaling, dropout and head mask.
-# Is it doable without writing 32 functions?
@torch.jit.script
def upcast_masked_softmax(
x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py
index 13f0eae7c..b6fb9fd38 100644
--- a/QEfficient/transformers/models/internvl/modeling_internvl.py
+++ b/QEfficient/transformers/models/internvl/modeling_internvl.py
@@ -66,7 +66,6 @@ def get_specializations(
kv_offload: bool = False,
**compiler_options,
):
- # TODO: check if this should be named num_patches or something else
num_patches = compiler_options.pop("num_patches", None)
if num_patches is None:
logger.warning(
diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py
index 6b30c7804..d46aa345d 100644
--- a/QEfficient/transformers/models/llama4/modeling_llama4.py
+++ b/QEfficient/transformers/models/llama4/modeling_llama4.py
@@ -32,7 +32,7 @@
repeat_kv,
)
-from QEfficient.transformers.cache_utils import QEffHybridChunkedCache
+from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils import constants
from QEfficient.utils._utils import IOInfo
@@ -312,8 +312,10 @@ def __init__(self, config: Llama4TextConfig, device=None):
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
# self.max_seq_len_cached = config.max_position_embeddings
- # TODO: vbaddi Shouldn't for rope, the max posision_embeddings be original embeddings for rope,
- # chunk size 8192 always? and Revisit when >8K Chunked attention is enabled.
+        # TODO: max_seq_len_cached should be determined before export and the model should be exported with that parameter.
+ logger.warning(
+            f"max_seq_len_cached is set to {constants.LLAMA4_MAX_POSITION_EMBEDDINGS}; this is the maximum sequence length supported by the model."
+ )
self.max_seq_len_cached = constants.LLAMA4_MAX_POSITION_EMBEDDINGS
# Get inverse frequency and scaling function (handles yarn/etc)
@@ -636,7 +638,7 @@ def forward(
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
- past_key_values = QEffHybridChunkedCache.from_legacy_cache(self.config, past_key_values)
+ past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -883,7 +885,6 @@ def get_specializations(
kv_offload: bool = False,
**compiler_options,
):
- # TODO: check if this should be named num_patches or something else
max_num_tiles = compiler_options.pop("max_num_tiles", None)
if max_num_tiles is None:
logger.warning(
@@ -901,6 +902,13 @@ def get_specializations(
else constants.LLAMA4_ATTENTION_CHUNK_SIZE
),
)
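+        # max_seq_len_cached is fixed to LLAMA4_MAX_POSITION_EMBEDDINGS at export time, so reject
+        # prefill/context lengths that exceed it.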
+ if (
+ prefill_seq_len > constants.LLAMA4_MAX_POSITION_EMBEDDINGS
+ or ctx_len > constants.LLAMA4_MAX_POSITION_EMBEDDINGS
+ ):
+ raise ValueError(
+                f"max_seq_len_cached is set to {constants.LLAMA4_MAX_POSITION_EMBEDDINGS}; prefill_seq_len ({prefill_seq_len}) and ctx_len ({ctx_len}) must not exceed it."
+ )
if img_size is None and hasattr(self.config.vision_config, "image_size"):
img_size = getattr(self.config.vision_config, "image_size")
diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
index f6cf2de49..7b96aefcc 100644
--- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
+++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
@@ -371,8 +371,8 @@ def forward(
hidden_states = orig_hidden_states[torch.arange(orig_hidden_states.shape[0]).reshape(-1, 1), last_pos_id, :]
causal_mask = causal_mask[torch.arange(orig_hidden_states.shape[0]).reshape(-1, 1), :, last_pos_id, :]
else:
- hidden_states = orig_hidden_states[torch.arange(bsz), last_pos_id, :]
- causal_mask = causal_mask[torch.arange(bsz), :, last_pos_id, :]
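+            # reshape(-1, 1) turns the batch indices into a column vector so they broadcast against
+            # last_pos_id when gathering the per-batch positions.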
+ hidden_states = orig_hidden_states[torch.arange(bsz).reshape(-1, 1), last_pos_id, :]
+ causal_mask = causal_mask[torch.arange(bsz).reshape(-1, 1), :, last_pos_id, :]
hidden_states, next_decoder_cache = self._run_swiftkv_layers(
hidden_states, position_ids, past_key_values, causal_mask, batch_index
diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py
index 338d141f8..23434fc18 100755
--- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py
+++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py
@@ -123,7 +123,6 @@ def __init__(self, model):
def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
inputs_embeds = self.model.get_input_embeddings()(input_ids)
image_features = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
- # breakpoint()
mask = input_ids == self.config.image_token_index
indices1 = mask.to(torch.int64).cumsum(1) - 1
indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 6bff10f5a..e487f8860 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -5,7 +5,6 @@
#
# ----------------------------------------------------------------------------
-import hashlib
import warnings
from pathlib import Path
from time import perf_counter
@@ -56,7 +55,6 @@
constants,
get_padding_shape_from_config,
)
-from QEfficient.utils.cache import to_hashable
from QEfficient.utils.logging_utils import logger
@@ -75,7 +73,7 @@ def __init__(self, model: nn.Module, **kwargs) -> None:
):
raise AssertionError("Please use `from_pretrained` method to load quantized models")
- super().__init__(model)
+ super().__init__(model, **kwargs)
def __repr__(self) -> str:
return self.__class__.__name__ + "\n" + self.model.__repr__()
@@ -164,15 +162,14 @@ class QEFFAutoModel(QEFFTransformersBase):
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
def __init__(self, model: nn.Module, pooling=None, **kwargs):
- super().__init__(model)
+ super().__init__(model, **kwargs)
# Make Embedding specific transforms like appending pooling
if pooling:
self.model, _ = PoolingTransform.apply(self.model, pooling)
self.model.base_model.config.use_cache = True
-
- self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)
+ self.hash_params["qeff_auto_class"] = self.__class__.__name__
@classmethod
@with_replaced_quantizers
@@ -225,29 +222,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **k
kv_offload = kwargs.pop("kv_offload", None)
if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
- model, kv_offload=kv_offload
+ model, kv_offload=kv_offload, **kwargs
)
return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, pooling=pooling, **kwargs)
- @property
- def model_hash(self) -> str:
- # NOTE: model_config.to_diff_dict() has "_name_or_path" attribute which is the model card name or path.
- # Using same card name will result in same hash. But, using a relative path for one run and
- # absolute path for another run will result in different hash.
- # The added complexity to resolve different paths to same location is not worth pursuing.
- # Instead, advise the user to always provide same relative paths or absolute paths for local models.
-
- # Compute the hash with: model_config, transforms
- mhash = hashlib.sha256()
- mhash.update(to_hashable(self.model.config.to_diff_dict()))
- mhash.update(to_hashable(self._transform_names()))
-
- mhash.update(to_hashable(self.pretrained_model_name_or_path))
-
- mhash = mhash.hexdigest()[:16]
- return mhash
-
@property
def get_model_config(self) -> dict:
return self.model.config.__dict__
@@ -448,9 +427,10 @@ class QEffVisionEncoderForTextImageToTextModel(QEFFBaseModel):
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
- def __init__(self, model: nn.modules):
- super().__init__(model)
+ def __init__(self, model: nn.modules, **kwargs):
+ super().__init__(model, **kwargs)
self.model = model.get_qeff_vision_encoder()
+ self.hash_params["qeff_auto_class"] = self.__class__.__name__
def export(self, inputs, output_names, dynamic_axes, export_dir=None):
return self._export(inputs, output_names, dynamic_axes, export_dir)
@@ -479,20 +459,6 @@ def compile(
**compiler_options,
)
- @property
- def model_hash(self) -> str:
- # Compute the hash with: model_config, continuous_batching, transforms
- mhash = hashlib.sha256()
- mhash.update(to_hashable(self.model.model.config.to_diff_dict()))
- mhash.update(to_hashable(self._transform_names()))
- mhash.update(to_hashable({"QEffVisionEncoderForTextImageToTextModel": True}))
- if hasattr(self.model, "model"):
- mhash.update(to_hashable(self.model.model.pretrained_model_name_or_path))
- else:
- mhash.update(to_hashable(self.model.pretrained_model_name_or_path))
- mhash = mhash.hexdigest()[:16]
- return mhash
-
@property
def model_name(self) -> str:
mname = self.model.__class__.__name__
@@ -516,9 +482,10 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
- def __init__(self, model):
- super().__init__(model)
+ def __init__(self, model, **kwargs):
+ super().__init__(model, **kwargs)
self.model = model.get_qeff_language_decoder()
+ self.hash_params["qeff_auto_class"] = self.__class__.__name__
def export(self, inputs, output_names, dynamic_axes, export_dir=None):
return self._export(inputs, output_names, dynamic_axes, export_dir)
@@ -547,20 +514,6 @@ def compile(
**compiler_options,
)
- @property
- def model_hash(self) -> str:
- # Compute the hash with: model_config, continuous_batching, transforms
- mhash = hashlib.sha256()
- mhash.update(to_hashable(self.model.config.to_diff_dict()))
- mhash.update(to_hashable(self._transform_names()))
- mhash.update(to_hashable({"QEffCausalLMForTextImageToTextModel": True}))
- if hasattr(self.model, "model"):
- mhash.update(to_hashable(self.model.model.pretrained_model_name_or_path))
- else:
- mhash.update(to_hashable(self.model.pretrained_model_name_or_path))
- mhash = mhash.hexdigest()[:16]
- return mhash
-
@property
def model_name(self) -> str:
mname = self.model.__class__.__name__
@@ -585,9 +538,8 @@ def __init__(
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
self.model = model
self.config = model.config
- self.model.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)
- self.vision_model = QEffVisionEncoderForTextImageToTextModel(model)
- self.lang_model = QEffCausalLMForTextImageToTextModel(model)
+ self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
+ self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
self.input_shapes, self.output_names = None, None
@property
@@ -821,7 +773,7 @@ def kv_offload_generate(
inputs["input_ids"],
(0, padded_len - input_ids_length),
"constant",
- 1,
+ pad_token_id,
)
inputs["attention_mask"] = torch.nn.functional.pad(
inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0
@@ -950,7 +902,7 @@ def __init__(
):
if kwargs.pop("full_batch_size", None):
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
- super().__init__(model)
+ super().__init__(model, **kwargs)
# to handle internvl models
if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"):
@@ -959,7 +911,7 @@ def __init__(
self.model.config.vision_config.use_flash_attn = "false"
else:
self.model.config.text_config.use_cache = True
- self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)
+ self.hash_params["qeff_auto_class"] = self.__class__.__name__
@classmethod
def from_pretrained(
@@ -1139,7 +1091,7 @@ def cloud_ai_100_generate(
inputs["input_ids"],
(0, padded_len - input_ids_length),
"constant",
- 1,
+ pad_token_id,
)
inputs["attention_mask"] = torch.nn.functional.pad(
inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0
@@ -1212,16 +1164,6 @@ def cloud_ai_100_generate(
),
)
- @property
- def model_hash(self) -> str:
- mhash = hashlib.sha256()
- mhash.update(to_hashable(self.model.config.to_diff_dict()))
- mhash.update(to_hashable(self._transform_names()))
- mhash.update(to_hashable({"QEFFAutoModelForImageTextToText1QPC": True}))
- mhash.update(to_hashable(self.pretrained_model_name_or_path))
- mhash = mhash.hexdigest()[:16]
- return mhash
-
@property
def model_name(self) -> str:
mname = self.model.__class__.__name__
@@ -1409,17 +1351,17 @@ def __init__(
"Please use `from_pretrained` method to load quantized models, might give unexpected results"
)
- super().__init__(model)
# Set use_cache=True to get KV values as output during ONNX export
- self.model.config.use_cache = True
+ model.config.use_cache = True
+
+ super().__init__(model, **kwargs)
+
self.num_layers = model.config.num_hidden_layers
self.continuous_batching = continuous_batching
self.model.qaic_config = qaic_config
-
self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs)
self.is_tlm = transformed
- self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)
-
+ self.hash_params["qeff_auto_class"] = self.__class__.__name__
# ---Sampling---
# Note: SamplerTransform should be applied after all other transforms
# are done. The role of the sampler is to just add nodes at the output of the
@@ -1507,7 +1449,7 @@ def from_pretrained(
if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
- model, kv_offload=kv_offload
+ model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
)
return cls(
model,
@@ -1517,19 +1459,6 @@ def from_pretrained(
**kwargs,
)
- @property
- def model_hash(self) -> str:
- # Compute the hash with: model_config, continuous_batching, transforms
- mhash = hashlib.sha256()
- mhash.update(to_hashable(self.model.config.to_diff_dict()))
- mhash.update(to_hashable({"continuous_batching": self.continuous_batching}))
- mhash.update(to_hashable({"is_tlm": self.is_tlm}))
- mhash.update(to_hashable({"qaic_config": self.model.qaic_config}))
- mhash.update(to_hashable(self._transform_names()))
- mhash.update(to_hashable(self.pretrained_model_name_or_path))
- mhash = mhash.hexdigest()[:16]
- return mhash
-
@property
def get_model_config(self) -> dict:
return self.model.config.__dict__
@@ -1970,26 +1899,10 @@ def __init__(self, model: nn.Module, **kwargs):
if not (model_class_name.endswith("ForConditionalGeneration")):
raise TypeError(f"Required pytorch module with ForConditionalGeneration, got {model_class_name}")
- super().__init__(model)
- self.model.config.use_cache = True
+ model.config.use_cache = True
+ super().__init__(model, **kwargs)
self.num_layers = model.config.num_hidden_layers
- self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)
-
- @property
- def model_hash(self) -> str:
- # NOTE: model_config.to_diff_dict() has "_name_or_path" attribute which is the model card name or path.
- # Using same card name will result in same hash. But, using a relative path for one run and
- # absolute path for another run will result in different hash.
- # The added complexity to resolve different paths to same location is not worth pursuing.
- # Instead, advise the user to always provide same relative paths or absolute paths for local models.
-
- # Compute the hash with: model_config, transforms
- mhash = hashlib.sha256()
- mhash.update(to_hashable(self.model.config.to_diff_dict()))
- mhash.update(to_hashable(self._transform_names()))
- mhash.update(to_hashable(self.pretrained_model_name_or_path))
- mhash = mhash.hexdigest()[:16]
- return mhash
+ self.hash_params["qeff_auto_class"] = self.__class__.__name__
@property
def get_model_config(self) -> dict:
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index 42807753d..ca74c0ddd 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -503,6 +503,7 @@ class SpDTransform:
@classmethod
def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]:
transformed = False
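+        # Temporarily remove pretrained_model_name_or_path from kwargs (restored before returning)
+        # so it is not forwarded to the speculative-decoding internals.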
+ pretrained_model_name_or_path_temp = kwargs.pop("pretrained_model_name_or_path", None)
if qaic_config is None or (speculative_model_type := qaic_config.get("speculative_model_type")) is None:
return model, transformed
elif speculative_model_type not in (
@@ -524,6 +525,7 @@ def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -
raise NotImplementedError(
f"model class {model_class} does not yet support returning multiple logits to keep."
)
+ kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path_temp
return model, transformed
diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py
index 7fc132b17..d9333cde2 100755
--- a/QEfficient/utils/__init__.py
+++ b/QEfficient/utils/__init__.py
@@ -11,8 +11,11 @@
)
from QEfficient.utils._utils import ( # noqa: F401
check_and_assign_cache_dir,
+ create_json,
custom_format_warning,
dump_qconfig,
+ filter_and_create_export_hash,
+ generate_mdp_partition_config,
get_num_layers_from_config,
get_num_layers_vlm,
get_onnx_dir_name,
@@ -21,10 +24,13 @@
get_qpc_dir_path,
get_sliding_window_layers,
get_sliding_window_shapes,
+ hash_compile_params,
hf_download,
load_hf_processor,
load_hf_tokenizer,
+ load_json,
login_and_download_hf_lm,
+ make_serializable,
onnx_exists,
padding_check_and_fix,
qpc_exists,
diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py
index 106647bc0..381880bce 100644
--- a/QEfficient/utils/_utils.py
+++ b/QEfficient/utils/_utils.py
@@ -25,7 +25,8 @@
PreTrainedTokenizerFast,
)
-from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants, QnnConstants
+from QEfficient.utils.constants import KWARGS_EXCLUSION_LIST, QEFF_MODELS_DIR, Constants, QnnConstants
+from QEfficient.utils.hash_utils import hash_dict_params
from QEfficient.utils.logging_utils import logger
@@ -563,11 +564,37 @@ def create_json(file_path: str, json_data: object):
"""
try:
with open(file_path, "w") as file:
- json.dump(json_data, file, indent=4)
+ json.dump(make_serializable(json_data), file, indent=4)
except Exception as e:
print(f"Failed to create JSON File {file_path}: {e}")
+def generate_mdp_partition_config(num_devices: int, num_cores: int) -> dict:
+    """
+    Generates an MDP (multi-device partition) configuration dictionary.
+
+    Args:
+        num_devices (int): Number of devices.
+        num_cores (int): Number of cores per device.
+
+    Returns:
+        dict: The MDP partition configuration.
+ """
+
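+    # Illustrative example: generate_mdp_partition_config(2, 16) returns
+    #   {"connections": [{"devices": [0, 1], "type": "p2p"}],
+    #    "partitions": [{"name": "Partition0",
+    #                    "devices": [{"deviceId": 0, "numCores": 16}, {"deviceId": 1, "numCores": 16}]}]}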
+ mdp_config = {
+ "connections": [{"devices": list(range(num_devices)), "type": "p2p"}],
+ "partitions": [
+ {
+ "name": "Partition0",
+ "devices": [{"deviceId": d, "numCores": num_cores} for d in range(num_devices)],
+ }
+ ],
+ }
+
+ return mdp_config
+
+
def model_swap(func):
def wrapper(*args, **kwargs):
if "model" in kwargs and kwargs["model"] is not None:
@@ -580,6 +607,19 @@ def wrapper(*args, **kwargs):
return wrapper
+# Ensure input obj is JSON serializable
+def make_serializable(obj):
+ if isinstance(obj, (int, float, str, bool, type(None))):
+ return obj
+ elif isinstance(obj, (list, tuple)):
+ return [make_serializable(item) for item in obj]
+ elif isinstance(obj, dict):
+ return {key: make_serializable(value) for key, value in obj.items()}
+ elif hasattr(obj, "__dict__"):
+ return make_serializable(vars(obj))
+ return str(obj)
+
+
@dataclass
class IOInfo:
name: str
@@ -667,18 +707,6 @@ def create_and_dump_qconfigs(
specializations_file_path = str(os.path.join(os.path.dirname(qpc_path), "specializations.json"))
compile_dir = str(os.path.dirname(qpc_path))
- # Ensure all objects in the configs dictionary are JSON serializable
- def make_serializable(obj):
- if isinstance(obj, (int, float, str, bool, type(None))):
- return obj
- elif isinstance(obj, (list, tuple)):
- return [make_serializable(item) for item in obj]
- elif isinstance(obj, dict):
- return {key: make_serializable(value) for key, value in obj.items()}
- elif hasattr(obj, "__dict__"):
- return make_serializable(vars(obj))
- return str(obj)
-
qconfigs = {
"huggingface_config": make_serializable(huggingface_config),
"qpc_config": {
@@ -723,6 +751,41 @@ def make_serializable(obj):
create_json(qconfig_file_path, qconfigs)
+def filter_and_create_export_hash(**kwargs):
+ """
+    Prepares the model parameters relevant to export and returns the export-directory hash along with the filtered parameters.
+    """
+    # Drop load-time kwargs that do not affect the exported graph (see KWARGS_EXCLUSION_LIST).
+ filtered_params = kwargs["model_params"]
+ filtered_params = {k: v for k, v in filtered_params.items() if k not in KWARGS_EXCLUSION_LIST}
+
+ export_params = {}
+ export_params["output_names"] = kwargs.get("output_names")
+ export_params["dynamic_axes"] = kwargs.get("dynamic_axes")
+
+ filtered_params["export_params"] = export_params
+
+ export_kwargs = kwargs.get("export_kwargs")
+ if export_kwargs:
+ filtered_params.update(export_kwargs)
+
+ onnx_transform_kwargs = kwargs.get("onnx_transform_kwargs")
+ if onnx_transform_kwargs:
+ filtered_params.update(onnx_transform_kwargs)
+ if filtered_params.get("peft_config") is not None:
+ filtered_params["peft_config"] = filtered_params["peft_config"].to_dict()
+
+ return hash_dict_params(filtered_params), filtered_params
+
+
+def hash_compile_params(**kwargs):
+ """
+    Creates the hash used to name the qpc directory from the given compile parameters and returns it along with a copy of those parameters.
+ """
+
+ return hash_dict_params(kwargs.copy()), kwargs.copy()
+
+
def filter_kwargs(func, kwargs):
"""
Filter a dictionary of keyword arguments to only include the valid arguments of a function.
diff --git a/QEfficient/utils/cache.py b/QEfficient/utils/cache.py
index b484a583b..a5d1ed7c9 100644
--- a/QEfficient/utils/cache.py
+++ b/QEfficient/utils/cache.py
@@ -5,7 +5,6 @@
#
# ----------------------------------------------------------------------------
-import json
import os
from pathlib import Path
@@ -16,26 +15,3 @@
QEFF_HOME = Path(os.environ["XDG_CACHE_HOME"]) / "qeff_models"
else:
QEFF_HOME = Path("~/.cache/qeff_models").expanduser()
-
-
-def json_serializable(obj):
- if isinstance(obj, set):
- return sorted(obj)
- raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
-
-
-def to_hashable(obj) -> bytes:
- """
- Converts obj to bytes such that same object will result in same hash
- """
- return json.dumps(
- obj,
- skipkeys=False,
- ensure_ascii=True,
- check_circular=True,
- allow_nan=False,
- indent=None,
- separators=(",", ":"),
- default=json_serializable,
- sort_keys=True,
- ).encode()
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index 526b01683..cb6e73303 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -25,6 +25,51 @@
ONNX_EXPORT_IMAGE_DEPTH = 3
ONNX_EXPORT_CTX_LEN = 1024
+# Compiler defaults
+DEFAULT_AIC_NUM_CORES = 16
+DEFAULT_AIC_MXPF6_MATMUL = False
+
+# Hashing defaults
+HASH_HEXDIGEST_STR_LEN = 16
+# Load-time kwargs that do not affect the exported graph and are therefore excluded from the export hash.
+# (An inclusion list could be used instead; exclusion keeps any unlisted kwargs part of the hash by default.)
+KWARGS_EXCLUSION_LIST = [
+ "from_tf",
+ "from_flax",
+ "proxies",
+ "output_loading_info",
+ "use_auth_token",
+ "_from_pipeline",
+ "_from_auto",
+ "torch_dtype",
+ "device_map",
+ "max_memory",
+ "offload_folder",
+ "offload_state_dict",
+ "offload_buffers",
+ "load_in_8bit",
+ "load_in_4bit",
+ "quantization_config",
+ "subfolder",
+ "variant",
+ "generation_config",
+ "tp_plan",
+ "tp_size",
+ "device_mesh",
+ "trust_remote_code",
+ "use_kernels",
+ "resume_download",
+ "cache_dir",
+ "mirror",
+ "_fast_init",
+ "low_cpu_mem_usage",
+ "ignore_mismatched_sizes",
+ "force_download",
+ "local_files_only",
+ "token",
+ "use_safetensors",
+ "weights_only",
+]
+
# Store the qeff_models inside the ~/.cache directory or over-ride with an env variable.
def get_models_dir():
diff --git a/QEfficient/utils/generate_qnn_network_specialization_config.py b/QEfficient/utils/generate_qnn_network_specialization_config.py
index 14d83efda..eca8e1873 100644
--- a/QEfficient/utils/generate_qnn_network_specialization_config.py
+++ b/QEfficient/utils/generate_qnn_network_specialization_config.py
@@ -166,8 +166,8 @@ def generate_data_format_config(
for output in onnx_model.graph.output:
if "past_key" in output.name or "past_value" in output.name:
kv_nodes.append(output.name)
- kv_overrides = {}
+ kv_overrides = {}
kv_overrides["graphs"] = [
{
"graph_name": model_dlc_name + "_configuration_1",
diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py
new file mode 100644
index 000000000..06020cf16
--- /dev/null
+++ b/QEfficient/utils/hash_utils.py
@@ -0,0 +1,43 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import hashlib
+import json
+from typing import Dict
+
+from QEfficient.utils.constants import HASH_HEXDIGEST_STR_LEN
+
+
+def json_serializable(obj):
+ if isinstance(obj, set):
+ return sorted(obj)
+ raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
+
+
+def to_hashable(obj) -> bytes:
+ """
+ Converts obj to bytes such that same object will result in same hash
+ """
+ return json.dumps(
+ obj,
+ skipkeys=False,
+ ensure_ascii=True,
+ check_circular=True,
+ allow_nan=False,
+ indent=None,
+ separators=(",", ":"),
+ default=json_serializable,
+ sort_keys=True,
+ ).encode()
+
+
+def hash_dict_params(dict_items: Dict, hash_string_size: int = HASH_HEXDIGEST_STR_LEN):
+ """
+    Takes a dictionary of items and returns a truncated SHA256 hex digest string.
+ """
+ mhash = hashlib.sha256(to_hashable(dict_items))
+ return mhash.hexdigest()[:hash_string_size]
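+
+
+# Illustrative usage: key order does not affect the hash, e.g.
+#   hash_dict_params({"a": 1, "b": 2}) == hash_dict_params({"b": 2, "a": 1})
+# while any change in keys or values is expected to yield a different 16-character hex digest.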
diff --git a/README.md b/README.md
index 9149864df..85d0a18d1 100644
--- a/README.md
+++ b/README.md
@@ -6,8 +6,24 @@
---
*Latest news* :fire:
+
- [06/2025] Added support for Llama4 Multi-Model [meta-llama/Llama-4-Scout-17B-16E-Instruct](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct)
+- [06/2025] Added support for the Gemma3 multi-modal model [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it)
- [06/2025] Added support of model `hpcai-tech/grok-1` [hpcai-tech/grok-1](https://huggingface.co/hpcai-tech/grok-1)
+- [06/2025] Added support for efficient sentence embedding, flexible/custom pooling configuration, and compilation with multiple sequence lengths; see [Embedding model](https://github.com/quic/efficient-transformers/pull/424).
+
+
+More
+
+- [04/2025] Added support for [Granite Vision models](https://huggingface.co/collections/ibm-granite/granite-vision-models-67b3bd4ff90c915ba4cd2800)
+- [04/2025] Added support for [Granite MOE models](https://huggingface.co/ibm-granite/granite-3.0-1b-a400m-base)
+- [04/2025] Support for [SpD, multiprojection heads](https://quic.github.io/efficient-transformers/source/quick_start.html#draft-based-speculative-decoding). Implemented post-attention hidden size projections to speculate tokens ahead of the base model.
+- [04/2025] [QNN Compilation support](https://github.com/quic/efficient-transformers/pull/374) for AutoModel classes. QNN compilation capabilities for multi-models, embedding models and causal models.
+- [04/2025] Added support for separate prefill and decode compilation for encoder (vision) and language models. This feature will be utilized for [disaggregated serving](https://github.com/quic/efficient-transformers/pull/365).
+- [04/2025] SwiftKV support for both [continuous and non-continuous batching execution](https://github.com/quic/efficient-transformers/pull/367).
+- [04/2025] Support for [GGUF model execution](https://github.com/quic/efficient-transformers/pull/368) (without quantized weights)
+- [04/2025] Enabled FP8 model support on [replicate_kv_heads script](https://github.com/quic/efficient-transformers/tree/main/scripts/replicate_kv_head)
+- [04/2025] Added support for [gradient checkpointing](https://github.com/quic/efficient-transformers/pull/338) in the finetuning script
- [04/2025] Added support of model `ibm-granite/granite-vision-3.2-2b`[ibm-granite/granite-vision-3.2-2b](https://huggingface.co/ibm-granite/granite-vision-3.2-2b)
- [03/2025] Added support for swiftkv model [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct)
- [02/2025] [VLMs support](https://github.com/quic/efficient-transformers/pull/267) added for the models [InternVL-1B](https://huggingface.co/OpenGVLab/InternVL2_5-1B), [Llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf) and [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
@@ -17,10 +33,6 @@
- [11/2024] [finite adapters support](https://github.com/quic/efficient-transformers/pull/153) allows mixed adapter usage for peft models.
- [11/2024] [Speculative decoding TLM](https://github.com/quic/efficient-transformers/pull/119) QEFFAutoModelForCausalLM model can be compiled for returning more than 1 logits during decode for TLM.
- [11/2024] Added support for [Meta-Llama-3.3-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct), [Meta-Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) and [Meta-Llama-3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B)
-
-
-More
-- [04/2025] [Granite 3.0 and 3.1 Language MOE Models] (https://huggingface.co/ibm-granite/granite-3.0-1b-a400m-base)
- [09/2024] [AWQ](https://arxiv.org/abs/2306.00978)/[GPTQ](https://arxiv.org/abs/2210.17323) 4-bit quantized models are supported
- [09/2024] Now we support [PEFT](https://huggingface.co/docs/peft/index) models
- [01/2025] Added support for [Ibm-Granite] (https://huggingface.co/ibm-granite/granite-3.1-8b-instruct)
diff --git a/docs/source/finetune.md b/docs/source/finetune.md
index 6899a4880..be8dfde00 100644
--- a/docs/source/finetune.md
+++ b/docs/source/finetune.md
@@ -84,7 +84,7 @@ To run fine tuning for any user specific dataset, prepare the dataset using the
3. Inside the newly created efficient-transformers/dataset/custom_dataset.py, define a function named 'get_custom_dataset'.
4. get_custom_dataset() should have following 4 parameters: dataset_config, tokenizer, split, context_length.
5. Inside get_custom_dataset(), user needs to apply prompt and tokenize the dataset accordingly. Please refer the below template on how to define get_custom_dataset().
-6. For examples, please refer python files present in [dataset](https://github.com/quic/efficient-transformers/tree/main/QEfficient/finetune/dataset). In case of Samsum dataset, get_preprocessed_samsum() of efficient-transformers/QEfficient/finetune/dataset/samsum_dataset.py is called.
+6. For examples, please refer to the Python files present in [dataset](https://github.com/quic/efficient-transformers/tree/main/QEfficient/finetune/dataset).
7. In [dataset_config.py](https://github.com/quic/efficient-transformers/blob/main/QEfficient/finetune/configs/dataset_config.py), for custom_dataset class, pass the appropriate value for train_split and test_split. As an alternative, these values can be passed as command line arguments as well with the finetune command. For example "--train_split train".
8. While running fine tuning, pass argument "-–dataset custom_dataset" to finetune on custom dataset.
diff --git a/docs/source/introduction.md b/docs/source/introduction.md
index d842b40c4..7a2e3fd02 100644
--- a/docs/source/introduction.md
+++ b/docs/source/introduction.md
@@ -23,19 +23,35 @@ For other models, there is comprehensive documentation to inspire upon the chang
***Latest news*** :
- [coming soon] Support for more popular [models](models_coming_soon)
+- [06/2025] Added support for the Llama4 multi-modal model [meta-llama/Llama-4-Scout-17B-16E-Instruct](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct)
+- [06/2025] Added support for the Gemma3 multi-modal model [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it)
+- [06/2025] Added support for the model [hpcai-tech/grok-1](https://huggingface.co/hpcai-tech/grok-1)
+- [06/2025] Added support for efficient sentence embedding, flexible/custom pooling configuration, and compilation with multiple sequence lengths; see [Embedding model](https://github.com/quic/efficient-transformers/pull/424).
+
+
+More
+
+- [04/2025] Added support for [Granite Vision models](https://huggingface.co/collections/ibm-granite/granite-vision-models-67b3bd4ff90c915ba4cd2800)
+- [04/2025] Added support for [Granite MOE models](https://huggingface.co/ibm-granite/granite-3.0-1b-a400m-base)
+- [04/2025] Support for [SpD, multiprojection heads](https://quic.github.io/efficient-transformers/source/quick_start.html#draft-based-speculative-decoding). Implemented post-attention hidden size projections to speculate tokens ahead of the base model.
+- [04/2025] [QNN Compilation support](https://github.com/quic/efficient-transformers/pull/374) for AutoModel classes. QNN compilation capabilities for multi-models, embedding models and causal models.
+- [04/2025] Added support for separate prefill and decode compilation for encoder (vision) and language models. This feature will be utilized for [disaggregated serving](https://github.com/quic/efficient-transformers/pull/365).
+- [04/2025] Added support for both [continuous and non-continuous batching execution](https://github.com/quic/efficient-transformers/pull/367) in SwiftKV.
+- [04/2025] Support for [GGUF model execution](https://github.com/quic/efficient-transformers/pull/368) (without quantized weights)
+- [04/2025] Enabled FP8 model support on [replicate_kv_heads script](https://github.com/quic/efficient-transformers/tree/main/scripts/replicate_kv_head)
+- [04/2025] Added support for [gradient checkpointing](https://github.com/quic/efficient-transformers/pull/338) in the finetuning script
+- [03/2025] Added support for the SwiftKV model [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct)
+- [02/2025] [VLMs support](https://github.com/quic/efficient-transformers/pull/267) added for the models [InternVL-1B](https://huggingface.co/OpenGVLab/InternVL2_5-1B), [Llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf) and [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
- [01/2025] [FP8 models support](https://huggingface.co/collections/neuralmagic/fp8-llms-for-vllm-666742ed2b78b7ac8df13127) Added support for inference of FP8 models.
-- [01/2025] Added support for [Ibm-Granite](https://huggingface.co/ibm-granite/granite-3.1-8b-instruct)
+- [01/2025] Added support for [Ibm-Granite](https://huggingface.co/ibm-granite/granite-3.1-8b-instruct)
- [11/2024] [finite adapters support](https://github.com/quic/efficient-transformers/pull/153) allows mixed adapter usage for peft models.
- [11/2024] [Speculative decoding TLM](https://github.com/quic/efficient-transformers/pull/119) QEFFAutoModelForCausalLM model can be compiled for returning more than 1 logits during decode for TLM.
- [11/2024] Added support for [Meta-Llama-3.3-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct), [Meta-Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) and [Meta-Llama-3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B)
-- [09/2024] [AWQ](https://arxiv.org/abs/2306.00978)/[GPTQ](https://arxiv.org/abs/2210.17323) 4-bit quantized models are supported
+- [09/2024] [AWQ](https://arxiv.org/abs/2306.00978)/[GPTQ](https://arxiv.org/abs/2210.17323) 4-bit quantized models are supported
- [09/2024] Now we support [PEFT](https://huggingface.co/docs/peft/index) models
-
-More
-
-- [01/2025] Added support for [Ibm-Granite](https://huggingface.co/ibm-granite/granite-3.1-8b-instruct)
-- [01/2025] Added support for [Ibm-Granite-Guardian](https://huggingface.co/ibm-granite/granite-guardian-3.1-8b)
+- [01/2025] Added support for [Ibm-Granite](https://huggingface.co/ibm-granite/granite-3.1-8b-instruct)
+- [01/2025] Added support for [Ibm-Granite-Guardian](https://huggingface.co/ibm-granite/granite-guardian-3.1-8b)
- [09/2024] Added support for [Gemma-2-Family](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
- [09/2024] Added support for [CodeGemma-Family](https://huggingface.co/collections/google/codegemma-release-66152ac7b683e2667abdee11)
- [09/2024] Added support for [Gemma-Family](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b)
diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md
index abab4cfc3..233fb491a 100644
--- a/docs/source/quick_start.md
+++ b/docs/source/quick_start.md
@@ -14,8 +14,15 @@ To achieve this, we have 2 levels of APIs, with different levels of abstraction.
| Feature | Impact |
| --- | --- |
| Context Length Specializations (upcoming) | Increases the maximum context length that models can handle, allowing for better performance on tasks requiring long sequences of text. |
-| Swift KV [Snowflake/Llama-3.1-SwiftKV-8B-Instruct] | Reduces computational overhead during inference by optimizing key-value pair processing, leading to improved throughput. |
| Block Attention (in progress) | Reduces inference latency and computational cost by dividing context into blocks and reusing key-value states, particularly useful in RAG. |
+| Sentence embedding, flexible pooling configuration, and compilation with multiple sequence lengths | Supports standard and custom pooling with AI 100 acceleration for efficient sentence embeddings. Models can be compiled with one or multiple sequence lengths, and the optimal graph is selected automatically at runtime. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/embedding_model.py) for more **details**; a usage sketch also follows this table.|
+| [SpD, multiprojection heads](https://quic.github.io/efficient-transformers/source/quick_start.html#draft-based-speculative-decoding) | Implemented post-attention hidden size projections to speculate tokens ahead of the base model. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/multiprojs_spd_inference.py) for more **details**.|
+| [QNN Compilation support](https://github.com/quic/efficient-transformers/pull/374) | Enables QNN compilation for AutoModel classes, covering multimodal, embedding, and causal models.|
+| [Disaggregated serving](https://github.com/quic/efficient-transformers/pull/365) | Adds support for separate prefill and decode compilation for encoder (vision) and language models.|
+| [GGUF model execution](https://github.com/quic/efficient-transformers/pull/368) | Supports GGUF model execution (without quantized weights). Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/basic_gguf_models.py) for more **details**. |
+| Replication of KV | Enables FP8 model support in the [replicate_kv_heads script](https://github.com/quic/efficient-transformers/tree/main/scripts/replicate_kv_head).|
+| [gradient checkpointing](https://github.com/quic/efficient-transformers/pull/338) | Supports gradient checkpointing in the finetuning script.|
+| Swift KV [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct) | Reduces computational overhead during inference by optimizing key-value pair processing, leading to improved throughput. Supports both [continuous and non-continuous batching execution](https://github.com/quic/efficient-transformers/pull/367) in SwiftKV. |
| [Vision Language Model](QEFFAutoModelForImageTextToText) | Provides support for the AutoModelForImageTextToText class from the transformers library, enabling advanced vision-language tasks. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/image_text_to_text_inference.py) for more **details**. |
| [Speech Sequence to Sequence Model](QEFFAutoModelForSpeechSeq2Seq) | Provides support for the QEFFAutoModelForSpeechSeq2Seq Facilitates speech-to-text sequence models. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/speech_to_text/run_whisper_speech_to_text.py) for more **details**. |
| Support for FP8 Execution | Enables execution with FP8 precision, significantly improving performance and reducing memory usage for computational tasks. |
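For the sentence-embedding row above, a minimal usage sketch follows. It is only a sketch: the `pooling` and `seq_len` keyword names and the checkpoint are assumptions inferred from the feature description, and the linked sample script remains the authoritative reference.

```python
# Hedged sketch of sentence-embedding usage on Cloud AI 100. The `pooling` and
# `seq_len` keyword names and the checkpoint below are illustrative assumptions;
# see examples/embedding_model.py for the supported options.
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModel

model_name = "sentence-transformers/all-MiniLM-L6-v2"  # assumed example checkpoint
qeff_model = QEFFAutoModel.from_pretrained(model_name, pooling="mean")

# Compile once for several sequence lengths; the optimal graph is picked at runtime.
qeff_model.compile(num_cores=16, seq_len=[32, 64])

tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Sentence embeddings on Cloud AI 100.", return_tensors="pt")
embeddings = qeff_model.generate(inputs)
print(embeddings)
```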
@@ -87,7 +94,7 @@ python -m QEfficient.cloud.execute --model_name gpt2 --qpc_path qeff_models/gpt2
You can run the finetune with set of predefined existing datasets on QAIC using the eager pipeline
```bash
-python -m QEfficient.cloud.finetune --device qaic:0 --use-peft --output_dir ./meta-sam --num_epochs 2 --context_length 256
+python -m QEfficient.cloud.finetune --device qaic:0 --use-peft --output_dir ./meta-sam --num_epochs 2 --context_length 256
```
For more details on finetune, checkout the subsection.
@@ -131,6 +138,28 @@ Users can compile a model with QNN SDK by following the steps below:
* Enabled QNN by passing enable_qnn flag, add --enable_qnn in the cli command.
* An optional config file can be passed to override the default parameters.
+**Default Parameters**
+
+QNN Converter Stage:
+
+ "--float_bias_bitwidth 32 --float_bitwidth 16 --preserve_io_datatype --onnx_skip_simplification --target_backend AIC"
+
+QNN Context Binary Stage:
+
+ LOG_LEVEL = "error"
+ COMPILER_COMPILATION_TARGET = "hardware"
+ COMPILER_CONVERT_TO_FP16 = True
+ COMPILER_DO_DDR_TO_MULTICAST = True
+ COMPILER_HARDWARE_VERSION = "2.0"
+ COMPILER_PERF_WARNINGS = False
+ COMPILER_PRINT_DDR_STATS = False
+ COMPILER_PRINT_PERF_METRICS = False
+ COMPILER_RETAINED_STATE = True
+ COMPILER_STAT_LEVEL = 10
+ COMPILER_STATS_BATCH_SIZE = 1
+ COMPILER_TIME_PASSES = False
+
+
**CLI Inference Command**
Without QNN Config
diff --git a/docs/source/validate.md b/docs/source/validate.md
index b12db2287..5c3ce2b24 100644
--- a/docs/source/validate.md
+++ b/docs/source/validate.md
@@ -17,6 +17,8 @@
| **GPT2LMHeadModel** | GPT-2 | [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) | ✔️ |
| **GraniteForCausalLM** | Granite 3.1 | [ibm-granite/granite-3.1-8b-instruct](https://huggingface.co/ibm-granite/granite-3.1-8b-instruct)
[ibm-granite/granite-guardian-3.1-8b](https://huggingface.co/ibm-granite/granite-guardian-3.1-8b) | ✔️ |
| | Granite 20B | [ibm-granite/granite-20b-code-base-8k](https://huggingface.co/ibm-granite/granite-20b-code-base-8k)
[ibm-granite/granite-20b-code-instruct-8k](https://huggingface.co/ibm-granite/granite-20b-code-instruct-8k) | ✔️ |
+| **GraniteMoeForCausalLM** | Granite 3.0 | [ibm-granite/granite-3.0-1b-a400m-base](https://huggingface.co/ibm-granite/granite-3.0-1b-a400m-base) | ✔️ |
+| | Granite 3.1 | [ibm-granite/granite-3.1-1b-a400m-base](https://huggingface.co/ibm-granite/granite-3.1-1b-a400m-base) | ✔️ |
| **InternVLChatModel** | Intern-VL | [OpenGVLab/InternVL2_5-1B](https://huggingface.co/OpenGVLab/InternVL2_5-1B) | |
| **LlamaForCausalLM** | CodeLlama | [codellama/CodeLlama-7b-hf](https://huggingface.co/codellama/CodeLlama-7b-hf)
[codellama/CodeLlama-13b-hf](https://huggingface.co/codellama/CodeLlama-13b-hf)
[codellama/CodeLlama-34b-hf](https://huggingface.co/codellama/CodeLlama-34b-hf) | ✔️ |
| | DeepSeek-R1-Distill-Llama | [deepseek-ai/DeepSeek-R1-Distill-Llama-70B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B) | ✔️ |
@@ -57,12 +59,14 @@
### Vision-Language Models (Text + Image Generation)
**QEff Auto Class:** `QEFFAutoModelForImageTextToText`
-| Architecture | Model Family | Representative Models |
-|-----------------------------|--------------|----------------------------------------|
-| **LlavaForConditionalGeneration** | LLaVA-1.5 | [llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf) |
-| **MllamaForConditionalGeneration** | Llama 3.2 | [meta-llama/Llama-3.2-11B-Vision Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
[meta-llama/Llama-3.2-90B-Vision](https://huggingface.co/meta-llama/Llama-3.2-90B-Vision) |
-|**LlavaNextForConditionalGeneration** | Granite Vision | [ibm-granite/granite-vision-3.2-2b](https://huggingface.co/ibm-granite/granite-vision-3.2-2b)
-|**Llama4ForConditionalGeneration** | Llama-4-Scout | [Llama-4-Scout-17B-16E-Instruct](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct)
+| Architecture | Model Family | Representative Models | CB Support | Single Qpc Support | Dual Qpc Support |
+|-----------------------------|--------------|----------------------------------------------------------------------------------------|------------|--------------------|------------------|
+| **LlavaForConditionalGeneration** | LLaVA-1.5 | [llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf) | ✕ | ✔️ | ✔️ |
+| **MllamaForConditionalGeneration** | Llama 3.2 | [meta-llama/Llama-3.2-11B-Vision Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
[meta-llama/Llama-3.2-90B-Vision](https://huggingface.co/meta-llama/Llama-3.2-90B-Vision) | ✕ | ✔️ | ✔️ |
+|**LlavaNextForConditionalGeneration** | Granite Vision | [ibm-granite/granite-vision-3.2-2b](https://huggingface.co/ibm-granite/granite-vision-3.2-2b) | ✕ | ✕ | ✔️ |
+|**Llama4ForConditionalGeneration** | Llama-4-Scout | [Llama-4-Scout-17B-16E-Instruct](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct) | ✕ | ✔️ | ✔️ |
+|**Gemma3ForConditionalGeneration** | Gemma3 | [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it)| ✕ | ✔️ | ✔️ |
+
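The new single/dual QPC columns map to the `kv_offload` flag used elsewhere in this change; a brief, hedged sketch of how that choice is made is shown below (the checkpoint is illustrative, and the pattern mirrors the comment in examples/gemma3_example/gemma3_mm.py later in this diff).

```python
# Hedged sketch: choosing single- vs dual-QPC execution for a vision-language model.
# kv_offload=False -> single QPC, kv_offload=True -> dual QPC. Pick any model from
# the table above that has a check mark in the corresponding support column.
from QEfficient import QEFFAutoModelForImageTextToText

model_id = "llava-hf/llava-1.5-7b-hf"  # illustrative checkpoint

single_qpc_model = QEFFAutoModelForImageTextToText.from_pretrained(
    model_id, attn_implementation="eager", kv_offload=False
)
dual_qpc_model = QEFFAutoModelForImageTextToText.from_pretrained(
    model_id, attn_implementation="eager", kv_offload=True
)
```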
### Audio Models
(Automatic Speech Recognition) - Transcription Task
**QEff Auto Class:** `QEFFAutoModelForSpeechSeq2Seq`
@@ -76,6 +80,8 @@
| Architecture | Model Family | Representative Models |
|-------------------------|--------------|--------------------------------------------|
+| **Qwen3MoeForCausalLM** |Qwen3| [Qwen/Qwen3-MoE-15B-A2B]() |
+| **Mistral3ForConditionalGeneration**|Mistral 3.1| [mistralai/Mistral-Small-3.1-24B-Base-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503) |
| **BaichuanForCausalLM** | Baichuan2 | [baichuan-inc/Baichuan2-7B-Base](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base) |
| **CohereForCausalLM** | Command-R | [CohereForAI/c4ai-command-r-v01](https://huggingface.co/CohereForAI/c4ai-command-r-v01) |
| **DbrxForCausalLM** | DBRX | [databricks/dbrx-base](https://huggingface.co/databricks/dbrx-base) |
\ No newline at end of file
diff --git a/examples/gemma3_example/fp32_mm.yaml b/examples/gemma3_example/fp32_mm.yaml
index 3414f2a54..28e7485fa 100755
--- a/examples/gemma3_example/fp32_mm.yaml
+++ b/examples/gemma3_example/fp32_mm.yaml
@@ -370,7 +370,7 @@ FP32NodeInstanceNames:
- /language_model/model/layers.4/self_attn/Mul_6_output_0
- /language_model/model/layers.4/self_attn/Mul_7_output_0
- /language_model/model/layers.4/self_attn/Mul_8_output_0
- - /language_model/model/layers.4/self_attn/Mul_9_output_0 [274/1312]
+ - /language_model/model/layers.4/self_attn/Mul_9_output_0
- /language_model/model/layers.5/self_attn/Mul_output_0
- /language_model/model/layers.5/self_attn/Mul_1_output_0
- /language_model/model/layers.5/self_attn/Mul_2_output_0
@@ -415,7 +415,7 @@ FP32NodeInstanceNames:
- /language_model/model/layers.9/self_attn/Mul_1_output_0
- /language_model/model/layers.9/self_attn/Mul_2_output_0
- /language_model/model/layers.9/self_attn/Mul_3_output_0
- - /language_model/model/layers.9/self_attn/Mul_4_output_0 [229/1312]
+ - /language_model/model/layers.9/self_attn/Mul_4_output_0
- /language_model/model/layers.9/self_attn/Mul_5_output_0
- /language_model/model/layers.9/self_attn/Mul_6_output_0
- /language_model/model/layers.9/self_attn/Mul_7_output_0
diff --git a/examples/gemma3_example/gemma3_mm.py b/examples/gemma3_example/gemma3_mm.py
index 717049d13..f48d2d307 100644
--- a/examples/gemma3_example/gemma3_mm.py
+++ b/examples/gemma3_example/gemma3_mm.py
@@ -7,7 +7,7 @@
import torch
import transformers
-from transformers import AutoConfig, AutoModelForImageTextToText, AutoProcessor, TextStreamer
+from transformers import AutoConfig, AutoProcessor
from QEfficient import QEFFAutoModelForImageTextToText
@@ -16,12 +16,14 @@
# For Testing Purpose Only
config.text_config.num_hidden_layers = 1
config.vision_config.num_hidden_layers = 2
-
-model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager", config=config)
-model.eval()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_id)
-qeff_model = QEFFAutoModelForImageTextToText(model, kv_offload=True)
+
+# Pass HF_TOKEN if the model is gated.
+# For the single QPC approach use kv_offload=False; for the dual QPC approach use kv_offload=True.
+qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
+ model_id, config=config, attn_implementation="eager", kv_offload=True
+)
### Use skip_vision=True to run only text, or False to include vision ###
skip_vision = True
@@ -59,9 +61,7 @@
return_tensors="pt",
)
- streamer = TextStreamer(tokenizer)
- output = qeff_model.generate(inputs=inputs, device_ids=[0], generation_len=100)
- print(output.generated_ids)
+ output = qeff_model.generate(inputs=inputs, generation_len=100)
print(tokenizer.batch_decode(output.generated_ids))
print(output)
@@ -72,7 +72,7 @@
ctx_len=3072,
img_size=896,
num_cores=16,
- num_devices=8,
+ num_devices=1,
mxfp6_matmul=False,
mxint8_kv_cache=False,
aic_enable_depth_first=True,
@@ -103,9 +103,6 @@
return_tensors="pt",
)
inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
- streamer = TextStreamer(tokenizer)
- output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3], generation_len=100)
- print(output.generated_ids)
+ output = qeff_model.generate(inputs=inputs, generation_len=100)
print(tokenizer.batch_decode(output.generated_ids))
print(output)
- print()
diff --git a/pyproject.toml b/pyproject.toml
index 334dfc34c..479736c22 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,11 +28,11 @@ dependencies = [
"multidict==6.0.4",
"urllib3<2",
"sentencepiece==0.2.0",
- "onnx==1.16.0",
- "onnxruntime==1.16.3",
+ "onnx==1.18.0",
+ "onnxruntime==1.22",
"numpy==1.26.4",
- "protobuf==3.20.2",
- "onnxscript==0.1.0.dev20240327",
+ "protobuf==6.31.0",
+ "onnxscript==0.2.5",
"pillow===10.4.0",
"sympy",
"tensorboard",
diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile
index e6a69d5fb..103c04b73 100644
--- a/scripts/Jenkinsfile
+++ b/scripts/Jenkinsfile
@@ -25,6 +25,7 @@ pipeline {
pip install junitparser pytest-xdist &&
pip install librosa==0.10.2 soundfile==0.13.1 && #packages needed to load example for whisper testing
pip install --extra-index-url https://download.pytorch.org/whl/cpu timm==1.0.14 torchvision==0.19.1+cpu einops==0.8.1 && #packages to load VLMs
+ pip install /opt/qti-aic/integrations/torch_qaic/py310/torch_qaic-0.1.0-cp310-cp310-linux_x86_64.whl && # For finetuning tests
rm -rf QEfficient"
'''
}
@@ -41,7 +42,7 @@ pipeline {
mkdir -p $PWD/Non_cli_qaic &&
export TOKENIZERS_PARALLELISM=false &&
export QEFF_HOME=$PWD/Non_cli_qaic &&
- pytest tests -m '(not cli) and (not on_qaic)' --ignore tests/vllm -n auto --junitxml=tests/tests_log1.xml &&
+ pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm -n auto --junitxml=tests/tests_log1.xml &&
junitparser merge tests/tests_log1.xml tests/tests_log.xml &&
deactivate"
'''
@@ -58,7 +59,7 @@ pipeline {
mkdir -p $PWD/Non_qaic &&
export TOKENIZERS_PARALLELISM=false &&
export QEFF_HOME=$PWD/Non_qaic &&
- pytest tests -m '(not cli) and (on_qaic) and (not multimodal) and (not qnn)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log2.xml &&
+ pytest tests -m '(not cli) and (on_qaic) and (not multimodal) and (not qnn) and (not finetune)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log2.xml &&
junitparser merge tests/tests_log2.xml tests/tests_log.xml &&
deactivate"
'''
@@ -77,14 +78,14 @@ pipeline {
mkdir -p $PWD/Non_cli_qaic_multimodal &&
export TOKENIZERS_PARALLELISM=false &&
export QEFF_HOME=$PWD/Non_cli_qaic_multimodal &&
- pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log6.xml &&
+ pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn) and (not finetune)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log6.xml &&
junitparser merge tests/tests_log6.xml tests/tests_log.xml &&
deactivate"
'''
}
}
}
- stage('CLI Tests') {
+ stage('Inference Tests') {
steps {
timeout(time: 60, unit: 'MINUTES') {
sh '''
@@ -96,7 +97,7 @@ pipeline {
mkdir -p $PWD/cli &&
export TOKENIZERS_PARALLELISM=false &&
export QEFF_HOME=$PWD/cli &&
- pytest tests -m '(cli and not qnn)' --ignore tests/vllm --junitxml=tests/tests_log3.xml &&
+ pytest tests -m '(cli and not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log3.xml &&
junitparser merge tests/tests_log3.xml tests/tests_log.xml &&
deactivate"
'''
@@ -125,7 +126,7 @@ pipeline {
mkdir -p $PWD/Qnn_cli &&
export TOKENIZERS_PARALLELISM=false &&
export QEFF_HOME=$PWD/Qnn_cli &&
- pytest tests -m '(cli and qnn)' --ignore tests/vllm --junitxml=tests/tests_log4.xml &&
+ pytest tests -m '(cli and qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log4.xml &&
junitparser merge tests/tests_log4.xml tests/tests_log.xml &&
deactivate"
'''
@@ -144,7 +145,7 @@ pipeline {
mkdir -p $PWD/Qnn_non_cli &&
export TOKENIZERS_PARALLELISM=false &&
export QEFF_HOME=$PWD/Qnn_non_cli &&
- pytest tests -m '(not cli) and (qnn) and (on_qaic) and (not multimodal)' --ignore tests/vllm --junitxml=tests/tests_log5.xml &&
+ pytest tests -m '(not cli) and (qnn) and (on_qaic) and (not multimodal) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log5.xml &&
junitparser merge tests/tests_log5.xml tests/tests_log.xml &&
deactivate"
'''
@@ -170,6 +171,23 @@ pipeline {
}
}
}
+ stage('Finetune CLI Tests') {
+ steps {
+ timeout(time: 5, unit: 'MINUTES') {
+ sh '''
+ sudo docker exec ${BUILD_TAG} bash -c "
+ cd /efficient-transformers &&
+ . preflight_qeff/bin/activate &&
+ mkdir -p $PWD/cli_qaic_finetuning &&
+ export TOKENIZERS_PARALLELISM=false &&
+ export QEFF_HOME=$PWD/cli_qaic_finetuning &&
+ pytest tests -m '(cli) and (on_qaic) and (not qnn) and (not multimodal) and (finetune)' --ignore tests/vllm --junitxml=tests/tests_log_finetune.xml &&
+ junitparser merge tests/tests_log_finetune.xml tests/tests_log.xml &&
+ deactivate"
+ '''
+ }
+ }
+ }
}
post {
diff --git a/scripts/finetune/run_ft_model.py b/scripts/finetune/run_ft_model.py
index e04c18f7f..0a98bdc5c 100644
--- a/scripts/finetune/run_ft_model.py
+++ b/scripts/finetune/run_ft_model.py
@@ -40,17 +40,13 @@
if not tokenizer.pad_token_id:
tokenizer.pad_token_id = tokenizer.eos_token_id
-eval_prompt = """
- Summarize this dialog:
- A: Hi Tom, are you busy tomorrow’s afternoon?
- B: I’m pretty sure I am. What’s up?
- A: Can you go with me to the animal shelter?.
- B: What do you want to do?
- A: I want to get a puppy for my son.
- B: That will make him so happy.
- ---
- Summary:
- """
+# This prompt template is specific to the Alpaca dataset; change it according to your dataset.
+eval_prompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+### Instruction:
+Give three tips for staying healthy.
+
+### Response:"""
model_input = tokenizer(eval_prompt, return_tensors="pt")
@@ -66,11 +62,8 @@
)
)
-trained_weights_path = os.path.join(train_config.output_dir, "trained_weights")
-list_paths = [d for d in os.listdir(trained_weights_path) if os.path.isdir(os.path.join(trained_weights_path, d))]
-max_index = max([int(path[5:]) for path in list_paths])
-
-save_dir = os.path.join(trained_weights_path, "step_" + str(max_index))
+# Path of the final checkpoint saved after the last training epoch
+save_dir = os.path.join(train_config.output_dir, "complete_epoch_" + str(train_config.num_epochs))
# Load PEFT model on CPU
model = AutoPeftModelForCausalLM.from_pretrained(save_dir)
diff --git a/tests/finetune/test_finetune.py b/tests/finetune/test_finetune.py
index dbff66fd4..b376234e5 100644
--- a/tests/finetune/test_finetune.py
+++ b/tests/finetune/test_finetune.py
@@ -7,9 +7,11 @@
import os
import shutil
+from pathlib import Path
import numpy as np
import pytest
+import requests
import torch.optim as optim
from torch.utils.data import DataLoader
@@ -17,61 +19,126 @@
import QEfficient.cloud.finetune
from QEfficient.cloud.finetune import main as finetune
+alpaca_json_path = Path.cwd() / "alpaca_data.json"
+
def clean_up(path):
- if os.path.exists(path):
+ if os.path.isdir(path):
shutil.rmtree(path)
+ if os.path.isfile(path):
+ os.remove(path)
+
+
+def download_alpaca():
+ alpaca_url = "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/refs/heads/main/alpaca_data.json"
+ response = requests.get(alpaca_url)
+
+ with open(alpaca_json_path, "wb") as f:
+ f.write(response.content)
configs = [
pytest.param(
"meta-llama/Llama-3.2-1B", # model_name
+ "generation", # task_type
10, # max_eval_step
20, # max_train_step
+ "gsm8k_dataset", # dataset_name
+ None, # data_path
1, # intermediate_step_save
None, # context_length
True, # run_validation
True, # use_peft
"qaic", # device
- id="llama_config", # config name
- )
+ 1.5427961, # expected_train_loss
+ 4.6776514, # expected_train_metric
+ 1.2898713, # expected_eval_loss
+ 3.6323189, # expected_eval_metric
+ id="llama_config_gsm8k", # config name
+ ),
+ pytest.param(
+ "meta-llama/Llama-3.2-1B", # model_name
+ "generation", # task_type
+ 10, # max_eval_step
+ 20, # max_train_step
+ "alpaca_dataset", # dataset_name
+ alpaca_json_path, # data_path
+ 1, # intermediate_step_save
+ None, # context_length
+ True, # run_validation
+ True, # use_peft
+ "qaic", # device
+ 1.4348667, # expected_train_loss
+ 4.1990857, # expected_train_metric
+ 1.5941212, # expected_eval_loss
+ 4.9239997, # expected_eval_metric
+ id="llama_config_alpaca", # config name
+ ),
+ pytest.param(
+ "google-bert/bert-base-uncased", # model_name
+ "seq_classification", # task_type
+ 10, # max_eval_step
+ 20, # max_train_step
+ "imdb_dataset", # dataset_name
+ None, # data_path
+ 1, # intermediate_step_save
+ None, # context_length
+ True, # run_validation
+ False, # use_peft
+ "qaic", # device
+ 0.63060283, # expected_train_loss
+ 0.55554199, # expected_train_metric
+ 0.61503016, # expected_eval_loss
+ 0.70825195, # expected_eval_metric
+ id="bert_config_imdb", # config name
+ ),
]
-@pytest.mark.skip(reason="Currently CI is broken. Once it is fixed we will enable this test.")
+@pytest.mark.skip(reason="Skipped until it is clear why different val_step_loss values are observed across runs on the existing code (even without PR #478 changes).")
@pytest.mark.cli
@pytest.mark.on_qaic
@pytest.mark.finetune
@pytest.mark.parametrize(
- "model_name,max_eval_step,max_train_step,intermediate_step_save,context_length,run_validation,use_peft,device",
+ "model_name,task_type,max_eval_step,max_train_step,dataset_name,data_path,intermediate_step_save,context_length,run_validation,use_peft,device,expected_train_loss,expected_train_metric,expected_eval_loss,expected_eval_metric",
configs,
)
-def test_finetune(
+def test_finetune_llama(
model_name,
+ task_type,
max_eval_step,
max_train_step,
+ dataset_name,
+ data_path,
intermediate_step_save,
context_length,
run_validation,
use_peft,
device,
+ expected_train_loss,
+ expected_train_metric,
+ expected_eval_loss,
+ expected_eval_metric,
mocker,
):
train_config_spy = mocker.spy(QEfficient.cloud.finetune, "TrainConfig")
generate_dataset_config_spy = mocker.spy(QEfficient.cloud.finetune, "generate_dataset_config")
generate_peft_config_spy = mocker.spy(QEfficient.cloud.finetune, "generate_peft_config")
- get_dataloader_kwargs_spy = mocker.spy(QEfficient.cloud.finetune, "get_dataloader_kwargs")
+ get_dataloader_kwargs_spy = mocker.spy(QEfficient.finetune.utils.dataset_utils, "get_dataloader_kwargs")
update_config_spy = mocker.spy(QEfficient.cloud.finetune, "update_config")
- get_custom_data_collator_spy = mocker.spy(QEfficient.cloud.finetune, "get_custom_data_collator")
- get_preprocessed_dataset_spy = mocker.spy(QEfficient.cloud.finetune, "get_preprocessed_dataset")
+ get_custom_data_collator_spy = mocker.spy(QEfficient.finetune.utils.dataset_utils, "get_custom_data_collator")
+ get_preprocessed_dataset_spy = mocker.spy(QEfficient.finetune.utils.dataset_utils, "get_preprocessed_dataset")
get_longest_seq_length_spy = mocker.spy(QEfficient.cloud.finetune, "get_longest_seq_length")
print_model_size_spy = mocker.spy(QEfficient.cloud.finetune, "print_model_size")
train_spy = mocker.spy(QEfficient.cloud.finetune, "train")
kwargs = {
"model_name": model_name,
+ "task_type": task_type,
"max_eval_step": max_eval_step,
"max_train_step": max_train_step,
+ "dataset": dataset_name,
+ "data_path": data_path,
"intermediate_step_save": intermediate_step_save,
"context_length": context_length,
"run_validation": run_validation,
@@ -79,22 +146,27 @@ def test_finetune(
"device": device,
}
+ if dataset_name == "alpaca_dataset":
+ download_alpaca()
+
results = finetune(**kwargs)
- assert np.allclose(results["avg_train_loss"], 0.00232327, atol=1e-5), "Train loss is not matching."
- assert np.allclose(results["avg_train_metric"], 1.002326, atol=1e-5), "Train metric is not matching."
- assert np.allclose(results["avg_eval_loss"], 0.0206124, atol=1e-5), "Eval loss is not matching."
- assert np.allclose(results["avg_eval_metric"], 1.020826, atol=1e-5), "Eval metric is not matching."
+
+ assert np.allclose(results["avg_train_loss"], expected_train_loss, atol=1e-3), "Train loss is not matching."
+ assert np.allclose(results["avg_train_metric"], expected_train_metric, atol=1e-3), "Train metric is not matching."
+ assert np.allclose(results["avg_eval_loss"], expected_eval_loss, atol=1e-3), "Eval loss is not matching."
+ assert np.allclose(results["avg_eval_metric"], expected_eval_metric, atol=1e-3), "Eval metric is not matching."
assert results["avg_epoch_time"] < 60, "Training should complete within 60 seconds."
train_config_spy.assert_called_once()
generate_dataset_config_spy.assert_called_once()
- generate_peft_config_spy.assert_called_once()
- get_custom_data_collator_spy.assert_called_once()
+ if task_type == "generation":
+ generate_peft_config_spy.assert_called_once()
get_longest_seq_length_spy.assert_called_once()
print_model_size_spy.assert_called_once()
train_spy.assert_called_once()
assert update_config_spy.call_count == 2
+ assert get_custom_data_collator_spy.call_count == 2
assert get_dataloader_kwargs_spy.call_count == 2
assert get_preprocessed_dataset_spy.call_count == 2
@@ -123,12 +195,19 @@ def test_finetune(
f"{train_config.gradient_accumulation_steps} which is gradient accumulation steps."
)
- saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/adapter_model.safetensors")
+ if use_peft:
+ saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/adapter_model.safetensors")
+ else:
+ saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/model.safetensors")
assert os.path.isfile(saved_file)
clean_up(train_config.output_dir)
clean_up("runs")
+ clean_up("qaic-dumps")
clean_up(train_config.dump_root_dir)
+ if dataset_name == "alpaca_dataset":
+ clean_up(alpaca_json_path)
+
# TODO (Meet): Add separate tests for BERT FT and Llama FT
diff --git a/tests/peft/lora/test_lora_model.py b/tests/peft/lora/test_lora_model.py
index 18fb0f5dd..4a82c8b0f 100644
--- a/tests/peft/lora/test_lora_model.py
+++ b/tests/peft/lora/test_lora_model.py
@@ -21,14 +21,24 @@
configs = [
pytest.param(
AutoConfig.for_model(
- "llama", num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2, hidden_size=128
+ "llama",
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ hidden_size=128,
+ architectures=["LlamaForCausalLM"],
),
LoraConfig(target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM", lora_alpha=8),
id="llama-2l-4h-2kvh-128d-qv",
),
pytest.param(
AutoConfig.for_model(
- "mistral", num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2, hidden_size=128
+ "mistral",
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ hidden_size=128,
+ architectures=["MistralForCausalLM"],
),
LoraConfig(target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM", lora_alpha=6),
id="mistral-2l-4h-128d-qv",
@@ -113,6 +123,7 @@ def test_auto_lora_model_for_causal_lm_init_from_unsupported_model(base_model_na
# test model hash
+@pytest.mark.skip(reason="Different adapter names will create different hashes so we'll skip this test.")
def test_auto_lora_model_for_causal_lm_hash():
base_config_0, adapter_config_0 = configs[0].values
base_config_1, adapter_config_1 = configs[1].values
@@ -124,7 +135,7 @@ def test_auto_lora_model_for_causal_lm_hash():
qeff_model_0.load_adapter(
"dummy_id", "adapter_1", adapter_config=adapter_config_1, adapter_weight={"weights": np.ones((3, 3))}
)
- model_hash_0_0 = qeff_model_0.model_hash
+ model_hash_0_0 = qeff_model_0.export_hash
qeff_model_1 = create_lora_base_model(base_config_1)
qeff_model_1.load_adapter(
@@ -133,7 +144,7 @@ def test_auto_lora_model_for_causal_lm_hash():
qeff_model_1.load_adapter(
"dummy_id", "adapter_1", adapter_config=adapter_config_1, adapter_weight={"weights": np.ones((3, 3))}
)
- model_hash_1_0 = qeff_model_1.model_hash
+ model_hash_1_0 = qeff_model_1.export_hash
qeff_model_0_1 = create_lora_base_model(base_config_0)
qeff_model_0_1.load_adapter(
@@ -142,7 +153,7 @@ def test_auto_lora_model_for_causal_lm_hash():
qeff_model_0_1.load_adapter(
"dummy_id", "adapter_1", adapter_config=adapter_config_1, adapter_weight={"weights": np.ones((3, 3))}
)
- model_hash_0_1_0 = qeff_model_0_1.model_hash
+ model_hash_0_1_0 = qeff_model_0_1.export_hash
# check if same model, same adapter config, same adapter weight, result in same hash
assert model_hash_0_1_0 == model_hash_0_0
@@ -156,7 +167,7 @@ def test_auto_lora_model_for_causal_lm_hash():
qeff_model_0_1.load_adapter(
"dummy_id", "adapter_1", adapter_config=adapter_config_1, adapter_weight={"weights": np.random.randn(3, 3)}
)
- model_hash_0_1_1 = qeff_model_0_1.model_hash
+ model_hash_0_1_1 = qeff_model_0_1.export_hash
assert model_hash_0_1_1 != model_hash_0_0
# check base model configs difference result in different hash
@@ -171,7 +182,7 @@ def test_auto_lora_model_for_causal_lm_hash():
qeff_model_1.load_adapter(
"dummy_id", "adapter_0", adapter_config=adapter_config_0, adapter_weight={"weights": np.ones((3, 3))}
)
- model_hash_1_1 = qeff_model_1.model_hash
+ model_hash_1_1 = qeff_model_1.export_hash
assert model_hash_1_1 != model_hash_1_0
# check if same adapter name, but different config, result in different hash
@@ -179,7 +190,7 @@ def test_auto_lora_model_for_causal_lm_hash():
qeff_model_0.load_adapter(
"dummy_id", "adapter_1", adapter_config=adapter_config_0, adapter_weight={"weights": np.ones((3, 3))}
)
- model_hash_0_1 = qeff_model_0.model_hash
+ model_hash_0_1 = qeff_model_0.export_hash
assert model_hash_0_1 != model_hash_0_0
@@ -213,7 +224,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate(
qeff_model.export(export_dir=tmp_path)
end = perf_counter()
export_time_0 = end - start
- model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.model_hash)
+ model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.export_hash)
assert model_path.is_dir()
assert Path(qeff_model.onnx_path).is_file()
diff --git a/tests/peft/test_peft_model.py b/tests/peft/test_peft_model.py
index 0458eb521..eca77a988 100644
--- a/tests/peft/test_peft_model.py
+++ b/tests/peft/test_peft_model.py
@@ -20,14 +20,24 @@
configs = [
pytest.param(
AutoConfig.for_model(
- "llama", num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2, hidden_size=128
+ "llama",
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ hidden_size=128,
+ architectures=["LlamaForCausalLM"],
),
LoraConfig(target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM"),
id="llama-2l-4h-2kvh-128d-qv",
),
pytest.param(
AutoConfig.for_model(
- "mistral", num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2, hidden_size=128
+ "mistral",
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ hidden_size=128,
+ architectures=["MistralForCausalLM"],
),
LoraConfig(target_modules=["q_proj", "k_proj", "v_proj"], task_type="CAUSAL_LM"),
id="mistral-2l-4h-128d-qkv",
@@ -83,6 +93,9 @@ def test_auto_peft_model_for_causal_lm_from_pretrained(base_config, adapter_conf
QEffAutoPeftModelForCausalLM.from_pretrained(adapter_path / adapter_name, full_batch_size=4)
+# This test isn't required anymore as different adapter names should generate different hashes. We'll
+# phase out this test in some time.
+@pytest.mark.skip(reason="Different adapter names will create different hashes so we'll skip this test.")
def test_auto_peft_model_for_causal_lm_hash():
base_config_0, adapter_config_0 = configs[0].values
base_config_1, adapter_config_1 = configs[1].values
@@ -129,7 +142,7 @@ def test_auto_peft_model_for_causal_lm_export(base_config, adapter_config, tmp_p
qeff_model.export(tmp_path)
end = perf_counter()
export_time_0 = end - start
- model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.model_hash)
+ model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.export_hash)
assert model_path.is_dir()
assert qeff_model.onnx_path.is_file()
diff --git a/tests/peft/test_peft_onnx_transforms.py b/tests/peft/test_peft_onnx_transforms.py
index f8521deb1..0248dae3b 100644
--- a/tests/peft/test_peft_onnx_transforms.py
+++ b/tests/peft/test_peft_onnx_transforms.py
@@ -46,6 +46,7 @@ def test_adapter_weights_to_inputs_transform():
out_onnx, transformed = AdapterWeightsToInputsTransform.apply(test_onnx, adapter_name=adapter_name)
assert transformed
+
assert (
onnx.printer.to_text(out_onnx)
== textwrap.dedent("""
@@ -53,11 +54,11 @@ def test_adapter_weights_to_inputs_transform():
ir_version: 8,
opset_import: ["" : 17]
>
- test_adapter_weights (float[n,32] input, float[32,32] layer1.weight, float[32,32] layer2.weight) => (float[n,32] output, float[32,32] layer1.weight_RetainedState, float[32,32] layer2.weight_RetainedState) {
- layer1output = MatMul (input, layer1.weight)
- output = MatMul (layer1output, layer2.weight)
- layer1.weight_RetainedState = Identity (layer1.weight)
- layer2.weight_RetainedState = Identity (layer2.weight)
+ test_adapter_weights (float[n,32] input, float[32,32] "layer1.weight", float[32,32] "layer2.weight") => (float[n,32] output, float[32,32] "layer1.weight_RetainedState", float[32,32] "layer2.weight_RetainedState") {
+ layer1output = MatMul (input, "layer1.weight")
+ output = MatMul (layer1output, "layer2.weight")
+ ["layer1.weight_identity"] "layer1.weight_RetainedState" = Identity ("layer1.weight")
+ ["layer2.weight_identity"] "layer2.weight_RetainedState" = Identity ("layer2.weight")
}
""").strip()
)
diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py
index 3195c4828..642030c9f 100644
--- a/tests/transformers/models/test_causal_lm_models.py
+++ b/tests/transformers/models/test_causal_lm_models.py
@@ -26,38 +26,38 @@
test_models_qaic = [
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"gpt2",
- "Salesforce/codegen-350M-mono",
- "microsoft/Phi-3-mini-4k-instruct",
- "tiiuae/falcon-7b",
+ # "Salesforce/codegen-350M-mono",
+ # "microsoft/Phi-3-mini-4k-instruct",
+ # "tiiuae/falcon-7b",
"Qwen/Qwen2-0.5B",
- "bigcode/starcoder2-3b",
- "Felladrin/Minueza-32M-Base",
- "wtang06/mpt-125m-c4",
- "hakurei/gpt-j-random-tinier",
- "mistralai/Mixtral-8x7B-Instruct-v0.1",
+ # "bigcode/starcoder2-3b",
+ # "Felladrin/Minueza-32M-Base",
+ # "wtang06/mpt-125m-c4",
+ # "hakurei/gpt-j-random-tinier",
+ # "mistralai/Mixtral-8x7B-Instruct-v0.1",
"meta-llama/Llama-3.2-1B",
- "unsloth/gemma-2b",
- "unsloth/gemma-2-2b",
- "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", # AWQ model
- "TheBloke/Llama-2-7B-GPTQ", # GPTQ model
- "ibm-granite/granite-20b-code-base",
- # "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic", # naive-quantized compressed-tensor FP8 model per-channel weight, per-token activations
- "neuralmagic/Llama-3.2-3B-Instruct-FP8", # float quantized compressed-tensor per tensor both weight and activations
+ # "unsloth/gemma-2b",
+ # "unsloth/gemma-2-2b",
+ # "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", # AWQ model
+ # "TheBloke/Llama-2-7B-GPTQ", # GPTQ model
+ # "ibm-granite/granite-20b-code-base",
+ # # "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic", # naive-quantized compressed-tensor FP8 model per-channel weight, per-token activations
+ # "neuralmagic/Llama-3.2-3B-Instruct-FP8", # float quantized compressed-tensor per tensor both weight and activations
"neuralmagic/Qwen2-0.5B-Instruct-FP8", # fp8 quant method, static, with lm head ignored
- "ibm-granite/granite-3.1-2b-instruct",
- "ibm-granite/granite-guardian-3.1-2b",
- "hpcai-tech/grok-1",
+ # "ibm-granite/granite-3.1-2b-instruct",
+ # "ibm-granite/granite-guardian-3.1-2b",
+ # "hpcai-tech/grok-1",
]
test_models_qnn = [
- "mistralai/Mixtral-8x7B-Instruct-v0.1",
- "meta-llama/Llama-3.2-1B",
- "unsloth/gemma-2b",
- "ibm-granite/granite-guardian-3.1-2b",
+ # "mistralai/Mixtral-8x7B-Instruct-v0.1",
+ # "meta-llama/Llama-3.2-1B",
+ # "unsloth/gemma-2b",
+ # "ibm-granite/granite-guardian-3.1-2b",
]
spd_test_models = [
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+ # "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"Qwen/Qwen2-0.5B",
]
diff --git a/tests/transformers/models/test_image_text_to_text_models.py b/tests/transformers/models/test_image_text_to_text_models.py
index c31491442..94e723326 100644
--- a/tests/transformers/models/test_image_text_to_text_models.py
+++ b/tests/transformers/models/test_image_text_to_text_models.py
@@ -66,29 +66,28 @@
"What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
1,
),
- (
- "meta-llama/Llama-4-Scout-17B-16E-Instruct",
- True,
- 1,
- 128,
- 3072,
- 336,
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
- "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
- 4,
- ),
- (
- "meta-llama/Llama-4-Scout-17B-16E-Instruct",
- False,
- 1,
- 128,
- 3072,
- 336,
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
- "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
- 4,
- ),
- # FIX: Accuracy in AIC
+ # (
+ # "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+ # True,
+ # 1,
+ # 128,
+ # 3072,
+ # 336,
+ # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
+ # "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
+ # 4,
+ # ),
+ # (
+ # "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+ # False,
+ # 1,
+ # 128,
+ # 3072,
+ # 336,
+ # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
+ # "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
+ # 4,
+ # ),
# (
# "google/gemma-3-4b-it",
# True,
@@ -98,7 +97,7 @@
# 896,
# "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
# "Can you describe the image in detail.",
- # 6,
+ # 1,
# ),
# (
# "google/gemma-3-4b-it",
@@ -109,7 +108,7 @@
# 896,
# "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
# "Can you describe the image in detail.",
- # 6,
+ # 1,
# ),
# (
# "meta-llama/Llama-3.2-11B-Vision-Instruct",
diff --git a/tests/transformers/test_causal_lm.py b/tests/transformers/test_causal_lm.py
index d0250d899..95289498e 100644
--- a/tests/transformers/test_causal_lm.py
+++ b/tests/transformers/test_causal_lm.py
@@ -14,6 +14,8 @@
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
+from QEfficient.utils import constants, get_padding_shape_from_config
+from QEfficient.utils.hash_utils import hash_dict_params
configs = [
# name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params
@@ -88,53 +90,109 @@ def test_causal_lm_pretrained(config, cb, tmp_path):
@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
@pytest.mark.parametrize("config", configs, ids=config_ids)
-def test_causal_lm_hash(config, cb):
- hash_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash
- hash_0_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash
+def test_causal_lm_export_and_hash(config, cb, tmp_path):
+ model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb)
+ model_0_0.export(tmp_path)
+ model_path = tmp_path.with_name(tmp_path.name + "-" + model_0_0.export_hash)
+ assert model_path.is_dir()
+ assert model_0_0.onnx_path.is_file()
+ assert model_0_0.onnx_path.relative_to(model_path).parts == (model_0_0.model_name + ".onnx",)
+
+ # Check if the KV-cache inputs and outputs are created
+ onnx_model = onnx.load(model_0_0.onnx_path, load_external_data=False)
+ retained_output_names = {
+ x.name[: -len("_RetainedState")] for x in onnx_model.graph.output if x.name.endswith("_RetainedState")
+ }
+ retained_output_names.issubset({x.name for x in onnx_model.graph.input})
+
+ # Check if there is no re-export
+ start = perf_counter()
+ model_0_0.export(tmp_path)
+ end = perf_counter()
+ export_time = end - start
+ assert export_time < 2.0
+
+ # Check if hashing is happening properly
+ model_0_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb)
+ model_0_1.export(tmp_path)
+ hash_0_0 = model_0_0.export_hash
+ hash_0_1 = model_0_1.export_hash
assert hash_0_0 == hash_0_1
cfg1 = copy.deepcopy(config)
cfg1.num_hidden_layers -= 1
- hash_1_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg1, **model_kwargs), cb).model_hash
+ model_1_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg1, **model_kwargs), cb)
+ model_1_0.export(tmp_path)
+ hash_1_0 = model_1_0.export_hash
cfg2 = copy.deepcopy(config)
cfg2.num_hidden_layers -= 1
- hash_1_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg2, **model_kwargs), cb).model_hash
+ model_1_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg2, **model_kwargs), cb)
+ model_1_1.export(tmp_path)
+ hash_1_1 = model_1_1.export_hash
assert hash_1_0 == hash_1_1
assert hash_0_0 != hash_1_0
if cb:
- hash_0_no_cb = QEFFAutoModelForCausalLM(
- AutoModelForCausalLM.from_config(config, **model_kwargs), False
- ).model_hash
+ model_0_no_cb = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), False)
+ model_0_no_cb.export(tmp_path)
+ hash_0_no_cb = model_0_no_cb.export_hash
assert hash_0_0 != hash_0_no_cb
@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
@pytest.mark.parametrize("config", configs, ids=config_ids)
-def test_causal_lm_export(config, cb, tmp_path):
+def test_causal_lm_hash_creation(config, cb, tmp_path):
model = AutoModelForCausalLM.from_config(config, **model_kwargs)
qeff_model = QEFFAutoModelForCausalLM(model, cb)
qeff_model.export(tmp_path)
- model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.model_hash)
- assert model_path.is_dir()
- assert qeff_model.onnx_path.is_file()
- assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",)
-
- # Check if the KV-cache inputs and outputs are created
- onnx_model = onnx.load(qeff_model.onnx_path, load_external_data=False)
- retained_output_names = {
- x.name[: -len("_RetainedState")] for x in onnx_model.graph.output if x.name.endswith("_RetainedState")
+ hash_params = {}
+ hash_params["config"] = qeff_model.model.config.to_diff_dict()
+ hash_params["peft_config"] = None
+ hash_params["applied_transform_names"] = qeff_model._transform_names()
+ hash_params["qeff_auto_class"] = qeff_model.__class__.__name__
+
+ # Create parameters separately for hash creation
+
+ bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+ seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
+ fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
+ kv_cache_shape = get_padding_shape_from_config(
+ qeff_model.model.config, fbs if qeff_model.continuous_batching else bs, seq_len
+ )
+ dynamic_axes = {
+ "input_ids": {0: "batch_size", 1: "seq_len"},
+ "position_ids": {0: "batch_size", 1: "seq_len"},
}
- retained_output_names.issubset({x.name for x in onnx_model.graph.input})
-
- # Check if there is no re-export
- start = perf_counter()
- qeff_model.export(tmp_path)
- end = perf_counter()
- export_time = end - start
- assert export_time < 2.0
+ if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d
+ pkv_dynamic_axes = {
+ 0: "full_batch_size" if qeff_model.continuous_batching else "batch_size",
+ 1: "ctx_len",
+ }
+ else: # pkv is 4d
+ pkv_dynamic_axes = {
+ 0: "full_batch_size" if qeff_model.continuous_batching else "batch_size",
+ 2: "ctx_len",
+ }
+ output_names = []
+ output_names.append("logits")
+
+ for i in range(qeff_model.num_layers):
+ for kv in ["key", "value"]:
+ dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
+ output_names.append(f"past_{kv}.{i}_RetainedState")
+
+ if qeff_model.continuous_batching:
+ dynamic_axes["batch_index"] = {0: "batch_size"}
+
+ export_params = {}
+ export_params["output_names"] = output_names
+ export_params["dynamic_axes"] = dynamic_axes
+ hash_params["export_params"] = export_params
+ manual_hash = hash_dict_params(hash_params)
+
+ assert manual_hash == qeff_model.export_hash
@pytest.fixture
@@ -153,8 +211,7 @@ def test_causal_lm_compile(config, cb, tmp_cache):
compile_params["full_batch_size"] = 32
compile_params["batch_size"] = 8
qeff_model.compile(**compile_params)
- model_path = tmp_cache / (qeff_model.model_name + "-" + qeff_model.model_hash)
-
+ model_path = tmp_cache / qeff_model.model_name / (qeff_model.model_name + "-" + qeff_model.export_hash)
# Check if ONNX is exported properly
assert model_path.is_dir()
assert qeff_model.onnx_path.is_file()
@@ -163,7 +220,7 @@ def test_causal_lm_compile(config, cb, tmp_cache):
# Check if QPC is compiled properly
assert qeff_model.qpc_path.is_dir()
assert (qeff_model.qpc_path / "programqpc.bin").is_file()
- assert qeff_model.qpc_path.relative_to(tmp_cache).parts[0] == qeff_model.model_name + "-" + qeff_model.model_hash
+ assert qeff_model.qpc_path.relative_to(tmp_cache).parts[1] == qeff_model.model_name + "-" + qeff_model.export_hash
# Check if there is no re-compilation
start = perf_counter()
diff --git a/tests/transformers/test_speech_seq2seq.py b/tests/transformers/test_speech_seq2seq.py
index 4d731c2b4..10f7ce709 100644
--- a/tests/transformers/test_speech_seq2seq.py
+++ b/tests/transformers/test_speech_seq2seq.py
@@ -14,6 +14,7 @@
from transformers import AutoConfig, AutoModel, AutoModelForSpeechSeq2Seq
from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForSpeechSeq2Seq
+from QEfficient.utils.hash_utils import hash_dict_params
configs = [
# name, max_source_positions, num_hidden_layers, num_attention_heads, hidden_size, encoder_ffn_dim, vocab_size, additional_params
@@ -72,45 +73,70 @@ def test_seq2seq_pretrained(config, tmp_path):
@pytest.mark.parametrize("config", configs, ids=config_ids)
-def test_seq2seq_hash(config):
- hash_0_0 = QEFFAutoModelForSpeechSeq2Seq(AutoModelForSpeechSeq2Seq.from_config(config, **model_kwargs)).model_hash
- hash_0_1 = QEFFAutoModelForSpeechSeq2Seq(AutoModelForSpeechSeq2Seq.from_config(config, **model_kwargs)).model_hash
+def test_seq2seq_export_and_hash(config, tmp_path):
+ model_0_0 = QEFFAutoModelForSpeechSeq2Seq(AutoModelForSpeechSeq2Seq.from_config(config, **model_kwargs))
+ model_0_0.export(tmp_path)
+ model_path = tmp_path.with_name(tmp_path.name + "-" + model_0_0.export_hash)
+ assert model_path.is_dir()
+ assert model_0_0.onnx_path.is_file()
+ assert model_0_0.onnx_path.relative_to(model_path).parts == (model_0_0.model_name + ".onnx",)
+
+ # Check if the KV-cache inputs and outputs are created
+ onnx_model = onnx.load(model_0_0.onnx_path, load_external_data=False)
+ retained_output_names = {
+ x.name[: -len("_RetainedState")] for x in onnx_model.graph.output if x.name.endswith("_RetainedState")
+ }
+ retained_output_names.issubset({x.name for x in onnx_model.graph.input})
+
+ # Check if there is no re-export
+ start = perf_counter()
+ model_0_0.export(tmp_path)
+ end = perf_counter()
+ export_time = end - start
+ assert export_time < 2.0
+
+ # Check if the hashing is happening properly.
+ hash_0_0 = model_0_0.export_hash
+ model_0_1 = QEFFAutoModelForSpeechSeq2Seq(AutoModelForSpeechSeq2Seq.from_config(config, **model_kwargs))
+ model_0_1.export(tmp_path)
+ hash_0_1 = model_0_1.export_hash
assert hash_0_0 == hash_0_1
cfg1 = copy.deepcopy(config)
- cfg1.num_hidden_layers -= 1
- hash_1_0 = QEFFAutoModelForSpeechSeq2Seq(AutoModelForSpeechSeq2Seq.from_config(cfg1, **model_kwargs)).model_hash
+ cfg1.num_hidden_layers += 1
+ model_1_0 = QEFFAutoModelForSpeechSeq2Seq(AutoModelForSpeechSeq2Seq.from_config(cfg1, **model_kwargs))
+ model_1_0.export(tmp_path)
+ hash_1_0 = model_1_0.export_hash
+
cfg2 = copy.deepcopy(config)
- cfg2.num_hidden_layers -= 1
- hash_1_1 = QEFFAutoModelForSpeechSeq2Seq(AutoModelForSpeechSeq2Seq.from_config(cfg2, **model_kwargs)).model_hash
+ cfg2.num_hidden_layers += 1
+ model_1_1 = QEFFAutoModelForSpeechSeq2Seq(AutoModelForSpeechSeq2Seq.from_config(cfg2, **model_kwargs))
+ model_1_1.export(tmp_path)
+ hash_1_1 = model_1_1.export_hash
+
assert hash_1_0 == hash_1_1
assert hash_0_0 != hash_1_0
@pytest.mark.parametrize("config", configs, ids=config_ids)
-def test_seq2seq_export(config, tmp_path):
+def test_seq2seq_hash_creation(config, tmp_path):
model = AutoModelForSpeechSeq2Seq.from_config(config, **model_kwargs)
qeff_model = QEFFAutoModelForSpeechSeq2Seq(model)
qeff_model.export(tmp_path)
- model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.model_hash)
- assert model_path.is_dir()
- assert qeff_model.onnx_path.is_file()
- assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",)
+ hash_params = {}
+ hash_params["config"] = qeff_model.model.config.to_diff_dict()
+ hash_params["peft_config"] = None
+ hash_params["applied_transform_names"] = qeff_model._transform_names()
+ hash_params["qeff_auto_class"] = qeff_model.__class__.__name__
- # Check if the KV-cache inputs and outputs are created
- onnx_model = onnx.load(qeff_model.onnx_path, load_external_data=False)
- retained_output_names = {
- x.name[: -len("_RetainedState")] for x in onnx_model.graph.output if x.name.endswith("_RetainedState")
- }
- retained_output_names.issubset({x.name for x in onnx_model.graph.input})
+ export_params = {}
+ export_params["output_names"] = qeff_model.model.get_output_names()
+ export_params["dynamic_axes"] = qeff_model.model.get_onnx_dynamic_axes()
+ hash_params["export_params"] = export_params
+ manual_hash = hash_dict_params(hash_params)
- # Check if there is no re-export
- start = perf_counter()
- qeff_model.export(tmp_path)
- end = perf_counter()
- export_time = end - start
- assert export_time < 2.0
+ assert manual_hash == qeff_model.export_hash
@pytest.fixture
@@ -125,7 +151,7 @@ def test_causal_lm_compile(config, tmp_cache):
model = AutoModelForSpeechSeq2Seq.from_config(config, **model_kwargs)
qeff_model = QEFFAutoModelForSpeechSeq2Seq(model)
qeff_model.compile()
- model_path = tmp_cache / (qeff_model.model_name + "-" + qeff_model.model_hash)
+ model_path = tmp_cache / qeff_model.model_name / (qeff_model.model_name + "-" + qeff_model.export_hash)
# Check if ONNX is exported properly
assert model_path.is_dir()
@@ -135,7 +161,7 @@ def test_causal_lm_compile(config, tmp_cache):
# Check if QPC is compiled properly
assert qeff_model.qpc_path.is_dir()
assert (qeff_model.qpc_path / "programqpc.bin").is_file()
- assert qeff_model.qpc_path.relative_to(tmp_cache).parts[0] == qeff_model.model_name + "-" + qeff_model.model_hash
+ assert qeff_model.qpc_path.relative_to(tmp_cache).parts[1] == qeff_model.model_name + "-" + qeff_model.export_hash
# Check if there is no re-compilation
start = perf_counter()
diff --git a/tests/utils/test_cache.py b/tests/utils/test_cache.py
index b60dfe04a..b91126afa 100644
--- a/tests/utils/test_cache.py
+++ b/tests/utils/test_cache.py
@@ -9,7 +9,7 @@
import pytest
-from QEfficient.utils.cache import to_hashable
+from QEfficient.utils.hash_utils import to_hashable
def get_random_string(length: int) -> str:
diff --git a/tests/utils/test_hash_utils.py b/tests/utils/test_hash_utils.py
new file mode 100644
index 000000000..fefa73973
--- /dev/null
+++ b/tests/utils/test_hash_utils.py
@@ -0,0 +1,99 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import json
+import random
+
+import pytest
+
+from QEfficient.utils.constants import HASH_HEXDIGEST_STR_LEN
+from QEfficient.utils.hash_utils import hash_dict_params, json_serializable, to_hashable
+
+
+def get_random_string(length: int) -> str:
+ return "".join([chr(random.randint(0x20, 0x7E)) for _ in range(length)])
+
+
+def test_to_hashable_dict():
+ dct = {get_random_string(i): i for i in range(5)}
+ dct = dict(sorted(dct.items()))
+ hash1 = to_hashable(dct)
+
+ dct = dict(reversed(dct.items()))
+ hash2 = to_hashable(dct)
+
+ assert hash1 == hash2
+
+
+def test_to_hashable_set():
+ assert to_hashable(set(range(4))) == to_hashable(set(range(4 - 1, -1, -1)))
+
+
+@pytest.mark.parametrize("value", [float("nan"), float("inf"), -float("inf")])
+def test_to_hashable_float_nan(value):
+ with pytest.raises(ValueError):
+ to_hashable(value)
+
+
+def test_json_serializable():
+ # Test with a set
+ assert json_serializable({1, 2, 3}) == [1, 2, 3]
+ # Test with an unsupported type: json_serializable itself should raise TypeError
+ with pytest.raises(TypeError):
+     json_serializable(object())
+
+
+def test_to_hashable():
+ # Test with a simple dictionary
+ obj = {"key": "value"}
+ expected = json.dumps(
+ obj,
+ skipkeys=False,
+ ensure_ascii=True,
+ check_circular=True,
+ allow_nan=False,
+ indent=None,
+ separators=(",", ":"),
+ default=json_serializable,
+ sort_keys=True,
+ ).encode()
+ assert to_hashable(obj) == expected
+
+ # Test with a dictionary containing a set
+ obj_with_set = {"key": {1, 2, 3}}
+ expected_with_set = json.dumps(
+ obj_with_set,
+ skipkeys=False,
+ ensure_ascii=True,
+ check_circular=True,
+ allow_nan=False,
+ indent=None,
+ separators=(",", ":"),
+ default=json_serializable,
+ sort_keys=True,
+ ).encode()
+ assert to_hashable(obj_with_set) == expected_with_set
+
+
+def test_hash_dict_params():
+ # Test with a simple dictionary
+ dict_items = {"key": "value"}
+ hash_result = hash_dict_params(dict_items)
+ assert len(hash_result) == HASH_HEXDIGEST_STR_LEN
+ assert isinstance(hash_result, str)
+
+ # Test with a dictionary containing a set
+ dict_items_with_set = {"key": {1, 2, 3}}
+ hash_result_with_set = hash_dict_params(dict_items_with_set)
+ assert len(hash_result_with_set) == HASH_HEXDIGEST_STR_LEN
+ assert isinstance(hash_result_with_set, str)
+
+ # Test with a custom hash string size
+ custom_hash_size = 10
+ hash_result_custom_size = hash_dict_params(dict_items, custom_hash_size)
+ assert len(hash_result_custom_size) == custom_hash_size
+ assert isinstance(hash_result_custom_size, str)
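For orientation, the helpers exercised above can be sketched as follows. This is a reconstruction inferred from the assertions in this test file, not the actual `QEfficient/utils/hash_utils.py` implementation; in particular the SHA-256 choice and the default truncation length are assumptions.

```python
# Hedged reconstruction of the helpers exercised by tests/utils/test_hash_utils.py,
# inferred purely from the assertions above; the real implementations may differ.
import hashlib
import json

HASH_HEXDIGEST_STR_LEN = 16  # assumed default truncation length


def json_serializable(obj):
    # Sets are not JSON-serializable; emit them as sorted lists so ordering is stable.
    if isinstance(obj, set):
        return sorted(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def to_hashable(obj) -> bytes:
    # Canonical JSON encoding: sorted keys, compact separators, NaN/Inf rejected.
    return json.dumps(
        obj,
        skipkeys=False,
        ensure_ascii=True,
        check_circular=True,
        allow_nan=False,
        indent=None,
        separators=(",", ":"),
        default=json_serializable,
        sort_keys=True,
    ).encode()


def hash_dict_params(dict_items: dict, hash_string_size: int = HASH_HEXDIGEST_STR_LEN) -> str:
    # Assumption: SHA-256 over the canonical encoding, truncated to the requested length.
    return hashlib.sha256(to_hashable(dict_items)).hexdigest()[:hash_string_size]
```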