From 8417d8f14694d64c765c0d2a09ca4007c958ba9b Mon Sep 17 00:00:00 2001 From: quic-sanising Date: Wed, 18 Jun 2025 13:25:31 -0500 Subject: [PATCH 01/14] Add sampler transform test Signed-off-by: quic-sanising --- tests/transformers/sampler/test_sampler.py | 119 +++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 tests/transformers/sampler/test_sampler.py diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py new file mode 100644 index 000000000..10a325754 --- /dev/null +++ b/tests/transformers/sampler/test_sampler.py @@ -0,0 +1,119 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import List + +import pytest + +from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils.constants import Constants + +configs = [ + pytest.param( + "meta-llama/Llama-3.1-8B", # model + Constants.INPUT_STR, # prompts + 32, # prefill_seq_len + 256, # ctx_len + 4, # full_batch_size + 1, # num_devices + 16, # num_cores + 1, # spec_length + id="Llama-3.1-8B_32_256_4_1_16_1", + ), + pytest.param( + "meta-llama/Llama-3.1-8B", # model + Constants.INPUT_STR, # prompts + 32, # prefill_seq_len + 256, # ctx_len + 4, # full_batch_size + 4, # num_devices + 16, # num_cores + 1, # spec_length + id="Llama-3.1-8B_32_256_4_4_16_1", + ), +] + + +@pytest.mark.on_qaic +@pytest.mark.parametrize( + "model, prompts, prefill_seq_len, ctx_len, full_batch_size, num_devices, num_cores, spec_length", + configs, +) +def test_sampler_transform( + model: str, + prompts: List[str], + prefill_seq_len: int, + ctx_len: int, + full_batch_size: int, + num_devices: int, + num_cores: int, + spec_length: int, +): + # Export and compile QEfficient models + qaic_config = { + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 512, + } + model_w_sampler = AutoModelForCausalLM.from_pretrained(model, continuous_batching=True, qaic_config=qaic_config) + model_wo_sampler = AutoModelForCausalLM.from_pretrained(model, continuous_batching=True, qaic_config=None) + model_w_sampler_qpc_path: str = model_w_sampler.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=num_devices, + num_cores=num_cores, + num_speculative_tokens=spec_length - 1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + model_wo_sampler_qpc_path: str = model_wo_sampler.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=num_devices, + num_cores=num_cores, + num_speculative_tokens=spec_length - 1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + + # Init qaic session + model_w_sampler_session = QAICInferenceSession(model_w_sampler_qpc_path) + model_wo_sampler_session = QAICInferenceSession(model_wo_sampler_qpc_path) + + # Skip inputs/outputs buffers + model_w_sampler_session.skip_buffers(set([x for x in model_w_sampler_session.input_names if x.startswith("past_")])) + model_w_sampler_session.skip_buffers( + set([x for x in model_w_sampler_session.output_names if x.endswith("_RetainedState")]) + ) + model_wo_sampler_session.skip_buffers( + set([x for x in model_wo_sampler_session.input_names if x.startswith("past_")]) + ) + 
model_wo_sampler_session.skip_buffers( + set([x for x in model_wo_sampler_session.output_names if x.endswith("_RetainedState")]) + ) + + # Validate sampler inputs + sampler_inputs = [ + "last_accepted_output_tokens", + "repetition_penalties", + "presence_penalties", + "temperatures", + "top_ks", + "top_ps", + "min_ps", + "random_numbers", + ] + for input_name in sampler_inputs: + assert ( + input_name in model_w_sampler_session.input_names + ), f"Sampler input {input_name} not found in QPC compiled with Sampler" + assert ( + input_name not in model_wo_sampler_session.input_names + ), f"Sampler input {input_name} found in QPC compiled without Sampler" From 067f9b5d2375c8994212f96fe9a514b719f385c2 Mon Sep 17 00:00:00 2001 From: sanising Date: Mon, 30 Jun 2025 18:34:23 -0500 Subject: [PATCH 02/14] Add example script Signed-off-by: sanising --- examples/on_device_sampling.py | 246 +++++++++++++++++++++++++++++++++ 1 file changed, 246 insertions(+) create mode 100644 examples/on_device_sampling.py diff --git a/examples/on_device_sampling.py b/examples/on_device_sampling.py new file mode 100644 index 000000000..ebe82fbe1 --- /dev/null +++ b/examples/on_device_sampling.py @@ -0,0 +1,246 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +import argparse +import re + +import numpy as np + +from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM +from QEfficient.utils import load_hf_tokenizer + + +def main(args, **kwargs): + print(args.__dict__) + + # Get sampling inputs + include_sampler = None + return_pdfs = None + max_top_k_ids = None + sampling_params = None + if args.override_qaic_config is not None: + include_sampler = args.override_qaic_config.get("aic_include_sampler", None) == "true" + if include_sampler is not None: + return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true" + max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512)) + sampling_params = { + "repetition_penalties": np.array(args.repetition_penalties, dtype=np.float32).reshape(-1, 1), + "presence_penalties": np.array(args.presence_penalties, dtype=np.float32).reshape(-1, 1), + # "frequency_penalties": np.array(args.frequency_penalties, dtype=np.float32).reshape(-1, 1), + "temperatures": np.array(args.temperatures, dtype=np.float32).reshape(-1, 1), + "top_ks": np.array(args.top_ks, dtype=np.int32).reshape(-1, 1), + "top_ps": np.array(args.top_ps, dtype=np.float32).reshape(-1, 1), + "min_ps": np.array(args.min_ps, dtype=np.float32).reshape(-1, 1), + "random_numbers": np.array(args.random_numbers, dtype=np.float32).reshape(-1, 1), + } + + # Load model with On Device Sampler enabled + qeff_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=args.model_name, + full_batch_size=args.full_batch_size, + qaic_config={ + k: v + for k, v in { + "include_sampler": include_sampler, + "return_pdfs": return_pdfs, + "max_top_k_ids": max_top_k_ids, + }.items() + if v is not None + }, + ) + print(f"{args.model_name} optimized for AI 100 \n", qeff_model) + + # Compile the model for inference + generated_qpc_path = qeff_model.compile( + prefill_seq_len=args.prompt_len, + ctx_len=args.ctx_len, + batch_size=args.batch_size, + full_batch_size=args.full_batch_size, + num_cores=args.num_cores, + num_devices=(0 if args.device_group is None else 
len(args.device_group)), + mxfp6_matmul=args.mxfp6, + mxint8_kv_cache=args.mxint8, + num_speculative_tokens=0, + **kwargs, + ) + print(f"Generated QPC file path: {generated_qpc_path}") + + # Generate texts from prompts + if not args.prompt: + args.prompt = [ + "Hi", + ] * args.full_batch_size + qeff_model.generate( + tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=args.model_name), + prompts=args.prompt, + prompts_txt_file_path=args.prompts_txt_file_path, + device_id=args.device_group, + generation_len=args.generation_len, + include_sampler=include_sampler, + return_pdfs=return_pdfs, + sampling_params=sampling_params, + ) + + +if __name__ == "__main__": + """ + Example usage: + + python3.10 examples/on_device_sampling.py \ + --model-name 'meta-llama/Llama-3.1-8B' \ + --prompt-len 128 \ + --ctx-len 256 \ + --generation-len 20 \ + --full-batch-size 2 \ + --device-group [0,1,2,3] \ + --num-cores 16 \ + --mxint8-kv-cache \ + --mxfp6-matmul \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \ + --repetition-penalties 1.9,1.0 \ + --presence-penalties 0.8,0.11 \ + --temperatures 0.67,0.52 \ + --top-ks 54720,23095 \ + --top-ps 0.89,0.56 \ + --min-ps 0.6,0.71 \ + --random-numbers 0.26,0.87 + """ + + parser = argparse.ArgumentParser(description="Run QEfficient model with On Device Sampling") + parser.add_argument( + "--model-name", "--model_name", required=True, default="meta-llama/Llama-3.1-8B", help="HF Model card name/id" + ) + parser.add_argument("--batch-size", "--batch_size", type=int, default=1, help="Batch size for text generation") + parser.add_argument( + "--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation." + ) + parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.") + parser.add_argument( + "--mxfp6", + "--mxfp6_matmul", + "--mxfp6-matmul", + action="store_true", + help="Compress constant MatMul weights to MXFP6 E2M3, default is no compression", + ) + parser.add_argument( + "--mxint8", + "--mxint8_kv_cache", + "--mxint8-kv-cache", + action="store_true", + help="Compress Present/Past KV to MXINT8 using CustomIO config, default is False", + ) + parser.add_argument( + "--num_cores", "--num-cores", type=int, required=True, help="Number of cores to compile on Cloud AI 100" + ) + parser.add_argument( + "--device_group", + "--device-group", + type=lambda device_ids: [int(x) for x in device_ids.strip("[]").split(",")], + help="Cloud AI 100 device ids (comma-separated) e.g. 
[0,1]", + ) + parser.add_argument( + "--prompt", + type=lambda prompt: prompt.split("|"), + help="Input prompt, if executing for batch size>1, pass input prompts in single string but separate with pipe (|) symbol", + ) + parser.add_argument( + "--prompts_txt_file_path", + "--prompts-txt-file-path", + type=str, + help="File path for taking input prompts from txt file, sample prompts.txt file present in examples folder", + ) + parser.add_argument("--generation_len", "--generation-len", type=int, help="Number of tokens to generate") + + parser.add_argument( + "--full_batch_size", + "--full-batch-size", + type=int, + default=None, + help="Set full batch size to enable continuous batching mode, default is None", + ) + parser.add_argument( + "--override-qaic-config", + type=lambda configs: { + str(value[0]): value[1] if len(value) > 1 else True + for value in (re.split(r"[:=]", config.strip()) for config in re.split(r"[ ]+", configs.strip())) + }, + default=None, + help="override or set qaic device configuration.", + ) + + # ---On Device Sampling--- + sampling_group = parser.add_argument_group("Sampling parameters") + sampling_group.add_argument( + "--repetition-penalties", + type=lambda data: [float(x) for x in data.split(",")], + default=None, + help="Comma-separated list of floating point values where each value is a sampling " + "parameter that penalizes new tokens based on whether they appear in the prompt and the " + "generated text so far. Values > 1 encourage the model to use new tokens, while values < 1 " + "encourage the model to repeat tokens.", + ) + sampling_group.add_argument( + "--presence-penalties", + type=lambda data: [float(x) for x in data.split(",")], + default=None, + help="Comma-separated list of floating point values where each value is a sampling " + "parameter that penalizes new tokens based on whether they appear in the generated text " + "so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the " + "model to repeat tokens.", + ) + sampling_group.add_argument( + "--temperatures", + type=lambda data: [float(x) for x in data.split(",")], + default=None, + help="Comma-separated list of floating point values where each value is a sampling " + "parameter that controls the randomness of the sampling. Lower values make the model more " + "deterministic, while higher values make the model more random. Zero means greedy sampling.", + ) + sampling_group.add_argument( + "--top-ks", + type=lambda data: [int(x) for x in data.split(",")], + default=None, + help="Comma-separated list of integer values where each value is a sampling parameter that " + "controls the number of top tokens to consider. Set to -1 to consider all tokens.", + ) + sampling_group.add_argument( + "--top-ps", + type=lambda data: [float(x) for x in data.split(",")], + default=None, + help="Comma-separated list of floating point values where each value is a sampling " + "parameter that controls the cumulative probability of the top tokens to consider. Must be " + "in (0, 1]. Set to 1.0 to consider all tokens.", + ) + sampling_group.add_argument( + "--min-ps", + type=lambda data: [float(x) for x in data.split(",")], + default=None, + help="Comma-separated list of floating point values where each value is a sampling " + "parameter that represents the minumum probability for a token to be considered, relative " + "to the probability of the most likely token. Must be in [0, 1]. 
Set to 0.0 to disable " + "this.", + ) + sampling_group.add_argument( + "--random-numbers", + type=lambda data: [float(x) for x in data.split(",")], + default=None, + help="Comma-separated list of floating point values where each value is a sampling " + "parameter that represents the random seeds to use for random sampling. Must be in [-1, 1].", + ) + args, compiler_options = parser.parse_known_args() + + compiler_options_dict = {} + for i in range(0, len(compiler_options)): + if compiler_options[i].startswith("--"): + key = compiler_options[i].lstrip("-").replace("-", "_") + value = ( + compiler_options[i + 1] + if i + 1 < len(compiler_options) and not compiler_options[i + 1].startswith("-") + else True + ) + compiler_options_dict[key] = value + + main(args, **compiler_options_dict) From 931860fbee2c2df9b6d0ddfb9a3cb7b1c2eca645 Mon Sep 17 00:00:00 2001 From: sanising Date: Mon, 30 Jun 2025 18:35:01 -0500 Subject: [PATCH 03/14] Update docs Signed-off-by: sanising --- docs/source/quick_start.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index abab4cfc3..67256e83b 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -19,7 +19,8 @@ To achieve this, we have 2 levels of APIs, with different levels of abstraction. | [Vision Language Model](QEFFAutoModelForImageTextToText) | Provides support for the AutoModelForImageTextToText class from the transformers library, enabling advanced vision-language tasks. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/image_text_to_text_inference.py) for more **details**. | | [Speech Sequence to Sequence Model](QEFFAutoModelForSpeechSeq2Seq) | Provides support for the QEFFAutoModelForSpeechSeq2Seq Facilitates speech-to-text sequence models. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/speech_to_text/run_whisper_speech_to_text.py) for more **details**. | | Support for FP8 Execution | Enables execution with FP8 precision, significantly improving performance and reducing memory usage for computational tasks. | -| Prefill caching | Enhances inference speed by caching key-value pairs for shared prefixes, reducing redundant computations and improving efficiency. | +| Prefix caching | Enhances inference speed by caching key-value pairs for shared prefixes, reducing redundant computations and improving efficiency. | +| On Device Sampling | Enables sampling operations to be executed directly on the QAIC device rather than the host CPU for QEffForCausalLM models. This enhancement significantly reduces host-device communication overhead and improves inference throughput and scalability. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/on_device_sampling.py) for more **details**. | |Prompt-Lookup Decoding | Speeds up text generation by using overlapping parts of the input prompt and the generated text, making the process faster without losing quality. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/pld_spd_inference.py) for more **details**.| | [PEFT LoRA support](QEffAutoPeftModelForCausalLM) | Enables parameter-efficient fine-tuning using low-rank adaptation techniques, reducing the computational and memory requirements for fine-tuning large models. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/peft_models.py) for more **details**. 
| | [QNN support](#qnn-compilation) | Enables compilation using QNN SDK, making Qeff adaptable for various backends in the future. | From 79b6c956ef77dfbc8486c985f7754aead9869150 Mon Sep 17 00:00:00 2001 From: sanising Date: Mon, 30 Jun 2025 18:40:36 -0500 Subject: [PATCH 04/14] Enable On Device Sampling for _continuous_batching_execution() Signed-off-by: sanising --- .../generation/text_generation_inference.py | 193 +++++++++++++++--- .../transformers/models/modeling_auto.py | 1 + 2 files changed, 168 insertions(+), 26 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index a9690aa51..2732ca0c0 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -10,7 +10,7 @@ from collections import deque from dataclasses import dataclass from time import perf_counter -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import transformers @@ -322,6 +322,9 @@ def cloud_ai_100_exec_kv( automation=False, prompt_to_lora_id_mapping: Optional[List[int]] = None, is_tlm: bool = False, + include_sampler: bool = False, + return_pdfs: bool = False, + sampling_params: Optional[Dict[str, Any]] = None, ): """ This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards. @@ -342,6 +345,11 @@ def cloud_ai_100_exec_kv( :Write_io_dir (str): Path to write the input and output files. ``Defaults to None``. :automation (bool): If true, it prints input, output, and performance stats. ``Defaults to False``. :prompt_to_lora_id_mapping (List[int]): Mapping to associate prompts with their respective LoRA adapter. + :include_sampler (bool): Enable/Disable sampling of next tokens. + :return_pdfs (bool): Return probability distributions along with sampled + next tokens. For Speculative Decoding Target Language Model, + `return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative + Decoding Draft Language Model and `return_pdfs`=False for regular model. Returns: :CloudAI100ExecInfo: Object holding execution output and performance details. 
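As a quick orientation, the following minimal sketch (not itself part of the patch series) shows how a caller might exercise the new `include_sampler`, `return_pdfs`, and `sampling_params` arguments end-to-end through the high-level API used elsewhere in this series. The model card, compile settings, and all sampling values are placeholders borrowed from the example script, each array is shaped (batch_size, 1) the same way the example script builds them, and a Cloud AI 100 device is assumed to be available.

import numpy as np

from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM
from QEfficient.utils import load_hf_tokenizer

model_name = "meta-llama/Llama-3.1-8B"  # placeholder model card
full_batch_size = 2

# Compile a QPC with the On Device Sampler baked in.
qeff_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    continuous_batching=True,
    qaic_config={"include_sampler": True, "return_pdfs": False, "max_top_k_ids": 512},
)
qeff_model.compile(
    prefill_seq_len=32,
    ctx_len=256,
    full_batch_size=full_batch_size,
    num_cores=16,
)

# One row per sequence in the batch; every array is shaped (batch_size, 1).
sampling_params = {
    "repetition_penalties": np.full((full_batch_size, 1), 1.9, dtype=np.float32),
    "presence_penalties": np.full((full_batch_size, 1), 0.8, dtype=np.float32),
    "temperatures": np.full((full_batch_size, 1), 0.67, dtype=np.float32),
    "top_ks": np.full((full_batch_size, 1), 512, dtype=np.int32),
    "top_ps": np.full((full_batch_size, 1), 0.89, dtype=np.float32),
    "min_ps": np.full((full_batch_size, 1), 0.6, dtype=np.float32),
    "random_numbers": np.full((full_batch_size, 1), 0.26, dtype=np.float32),
}

qeff_model.generate(
    tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=model_name),
    prompts=["Hi"] * full_batch_size,
    generation_len=20,
    include_sampler=True,
    return_pdfs=False,
    sampling_params=sampling_params,
)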
@@ -372,6 +380,9 @@ def cloud_ai_100_exec_kv( write_io_dir=write_io_dir, full_batch_size=full_batch_size, is_tlm=is_tlm, + include_sampler=include_sampler, + return_pdfs=return_pdfs, + sampling_params=sampling_params, ) if full_batch_size is None: exec_info = [ @@ -411,14 +422,59 @@ def __init__( enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, is_tlm: Optional[int] = None, + include_sampler: bool = False, + return_pdfs: bool = False, + sampling_params: Optional[Dict[str, Any]] = None, ) -> None: self._ctx_len = ctx_len self._write_io_dir = write_io_dir self.is_tlm = is_tlm + self.include_sampler = include_sampler + self.return_pdfs = return_pdfs + self.sampling_params = sampling_params # Load QPC self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs) + # Validate sampler inputs for On-Device Sampling + sampler_inputs = [ + "last_accepted_output_tokens", + "repetition_penalties", + "presence_penalties", + "temperatures", + "top_ks", + "top_ps", + "min_ps", + "random_numbers", + ] + count = 0 + for session_input_name in self._session.input_names: + if session_input_name in sampler_inputs: + count += 1 + if count == len(sampler_inputs): + break + if count == 0: + self.include_sampler = False + elif count < len(sampler_inputs): + raise ValueError( + "The provided QPC does not have the required number of inputs to run sampling " + f"on the QAIC device (only {count}/{len(sampler_inputs)} inputs provided). Partial " + "sampling support is not available. Please check the QPC and try again." + ) + else: # count == len(sampler_inputs) + self.include_sampler = True + if include_sampler and not self.include_sampler: + logger.warning_once( + "User entered `include_sampler`=True. But the provided QPC is not compiled " + "to run sampling on the QAIC device. Falling back to the PyTorch backend." + ) + elif (include_sampler is None or not include_sampler) and self.include_sampler: + raise ValueError( + "The provided QPC is compiled to run sampling on the QAIC device. " + "But the user did not enter `include_sampler`=True. Please make sure the input " + "is specified correctly." + ) + # Fetch the variables from the QPC self._vocab_size = self._fetch_vocab_size() # Fetch Vocab size self.batch_size, self._prefill_seq_len = self._fetch_batch_size_prefill_seq_len() @@ -523,10 +579,17 @@ def _fetch_vocab_size( Returns: vocab_size: The vocabulary size fetched from the session's allowed shapes. 
""" + key = ( + "probs" + if self.include_sampler and self.return_pdfs + else "next_tokens" + if self.include_sampler + else "logits" + ) if self._session.allowed_shapes: - return [x[self._session.binding_index_map["logits"]] for x in self._session.allowed_shapes][0][1][2] + return [x[self._session.binding_index_map[key]] for x in self._session.allowed_shapes][0][1][2] - return self._session.bindings[self._session.binding_index_map["logits"]].dims[2] + return self._session.bindings[self._session.binding_index_map[key]].dims[2] def _fetch_generation_len(self, generation_len, max_gen_len): """ @@ -574,6 +637,21 @@ def prepare_decode_inputs(self): decode_inputs["position_ids"] = self.decode_pos_ids if self.batch_index is not None: decode_inputs["batch_index"] = self.batch_index + if self.include_sampler: + decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] + for op in [ + "repetition_penalties", + "presence_penalties", + "temperatures", + "top_ks", + "top_ps", + "min_ps", + "random_numbers", + ]: + if self.batch_index is not None: + decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()] + else: + decode_inputs[op] = self.sampling_params[op] if self._prompt_to_lora_id_mapping_decode: if self.full_batch_size: @@ -589,21 +667,24 @@ def prepare_decode_inputs(self): def _fetch_next_token_id(self, outputs): """ - Fetches the next token ID from the model's output logits. - The method identifies the token with the highest probability using argmax along the last dimension. + Fetches the next token ID from the model's output. + Args: - outputs (dict): A dictionary containing the model's output logits. The key "logits" should map to a numpy array of shape (batch_size, sequence_length, vocab_size) or (batch_size, vocab_size). + outputs (dict): A dictionary containing the model's output. Returns: numpy.ndarray: An array of the next token IDs for each sequence in the batch. """ - logits = outputs["logits"] - if len(logits.shape) == 2: - logits = np.expand_dims(logits, 1) - - # Get output token - next_token_id = logits.argmax(2) - return next_token_id + if self.include_sampler: + if self.return_pdfs: + return outputs["probs"].argmax(2) + else: + return outputs["next_tokens"].reshape(outputs["next_tokens"].shape[0], outputs["next_tokens"].shape[1]) + else: + logits = outputs["logits"] + if len(logits.shape) == 2: + logits = np.expand_dims(logits, 1) + return logits.argmax(2) def initialize_decode_inputs(self, num_prompts, execution_batch_size, max_gen_length): """ @@ -673,6 +754,23 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len): _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) + def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1): + """ + Sets the sizes of the output buffers. + + Args: + batch_size (int): The batch size. 
+ """ + if self.include_sampler: + if self.return_pdfs: + probs_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32) + self._session.set_buffers({"probs": probs_out_placeholder}) + next_tokens_out_placeholder = np.zeros((batch_size, sequence_length, 1), dtype=np.int64) + self._session.set_buffers({"next_tokens": next_tokens_out_placeholder}) + else: + logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32) + self._session.set_buffers({"logits": logits_out_placeholder}) + def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): """ Runs prefill for a given prompt and generation length. @@ -702,9 +800,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i max_gen_len = self._ctx_len - position_ids.max() generation_len = self._fetch_generation_len(generation_len, max_gen_len) - # Set the prefill logic buffer - logits_out_placeholder = np.zeros((prefill_logit_bs, 1, self._vocab_size), dtype=np.float32) - self._session.set_buffers({"logits": logits_out_placeholder}) + # Set the prefill output buffers + self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) @@ -714,6 +811,21 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i inputs["batch_index"] = decode_batch_id if self.is_tlm: inputs["num_logits_to_keep"] = np.zeros((1, 1)) + if self.include_sampler: + inputs["last_accepted_output_tokens"] = inputs["input_ids"] + for op in [ + "repetition_penalties", + "presence_penalties", + "temperatures", + "top_ks", + "top_ps", + "min_ps", + "random_numbers", + ]: + if decode_batch_id is not None: + inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] + else: + inputs[op] = self.sampling_params[op] if self._prompt_to_lora_id_mapping_prefill: if self.full_batch_size: @@ -732,6 +844,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i chunk_inputs["position_ids"] = inputs["position_ids"][ :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len ] + if self.include_sampler: + chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] outputs = self._session.run(chunk_inputs) if self._write_io_dir is not None: write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) @@ -753,11 +867,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): """ - # Set logits placeholder for decode - logits_out_placeholder = np.zeros( - (self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32 + # Set output placeholders for decode + self._set_output_buffers( + batch_size=self.full_batch_size, + sequence_length=self._decode_seq_len, ) - self._session.set_buffers({"logits": logits_out_placeholder}) + # Generate flag for tracking progress for each batch ID current_decode_ongoing = np.full((self.full_batch_size, 1), True) @@ -775,10 +890,18 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): outputs = self._session.run(decode_inputs) # Prepare inputs for next iteration - logits = outputs["logits"] - if len(logits.shape) == 2: - logits = np.expand_dims(logits, 1) - next_token_id = logits.argmax(2) + if self.include_sampler: + if self.return_pdfs: + next_token_id = 
outputs["probs"].argmax(2) + else: + next_token_id = outputs["next_tokens"].reshape( + outputs["next_tokens"].shape[0], outputs["next_tokens"].shape[1] + ) + else: # Perform Greedy Sampling on Host + logits = outputs["logits"] + if len(logits.shape) == 2: + logits = np.expand_dims(logits, 1) + next_token_id = logits.argmax(2) for decode_batch_id in range(self.full_batch_size): if ( @@ -800,7 +923,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): self.generated_ids[batch_id_map[decode_batch_id], 0] = new_token_id.squeeze(1) generated_id_current_index[decode_batch_id] = 1 - self._session.set_buffers({"logits": logits_out_placeholder}) + self._set_output_buffers( + batch_size=self.full_batch_size, + sequence_length=self._decode_seq_len, + ) decode_pause_time += perf_counter() - start if self._prompt_to_lora_id_mapping_decode: @@ -817,6 +943,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = ( next_token_id[decode_batch_id, -1] ) + if self.include_sampler: + decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] generated_id_current_index[decode_batch_id] += 1 @@ -905,9 +1033,22 @@ def __init__( enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, is_tlm: bool = False, + include_sampler: bool = False, + return_pdfs: bool = False, + sampling_params: Optional[Dict[str, Any]] = None, ) -> None: self._qaic_model = QEffTextGenerationBase( - tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm + tokenizer=tokenizer, + qpc_path=qpc_path, + full_batch_size=full_batch_size, + ctx_len=ctx_len, + device_id=device_id, + enable_debug_logs=enable_debug_logs, + write_io_dir=write_io_dir, + is_tlm=is_tlm, + include_sampler=include_sampler, + return_pdfs=return_pdfs, + sampling_params=sampling_params, ) self._full_batch_size = self._qaic_model.full_batch_size self._tokenizer = self._qaic_model.tokenizer diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 2f3ee3dc0..479000bff 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1893,6 +1893,7 @@ def generate( device_id=device_id, generation_len=generation_len, is_tlm=self.is_tlm, + **kwargs, ) else: raise NotImplementedError("Only AI_100 runtime is supported right now via generate API") From 75eac30a8ddfa261833f6e4db5ae5f5ae5d03a4f Mon Sep 17 00:00:00 2001 From: sanising Date: Mon, 30 Jun 2025 19:10:32 -0500 Subject: [PATCH 05/14] Disable On Device Sampling for _regular_model_execution() Signed-off-by: sanising --- .../generation/text_generation_inference.py | 22 ++++++++----------- examples/on_device_sampling.py | 5 +++++ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 2732ca0c0..1b452f25d 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -890,18 +890,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): outputs = self._session.run(decode_inputs) # Prepare inputs for next iteration - if self.include_sampler: - if self.return_pdfs: - next_token_id = outputs["probs"].argmax(2) - else: - next_token_id = outputs["next_tokens"].reshape( - outputs["next_tokens"].shape[0], 
outputs["next_tokens"].shape[1] - ) - else: # Perform Greedy Sampling on Host - logits = outputs["logits"] - if len(logits.shape) == 2: - logits = np.expand_dims(logits, 1) - next_token_id = logits.argmax(2) + next_token_id = self._fetch_next_token_id(outputs) for decode_batch_id in range(self.full_batch_size): if ( @@ -968,6 +957,11 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform (self.batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32 ) self._session.set_buffers({"logits": logits_out_placeholder}) + else: + self._set_output_buffers( + batch_size=self.batch_size, + sequence_length=self._decode_seq_len, + ) finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id num_token = 0 for num_token in range(1, generation_len): @@ -980,10 +974,12 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform self._write_io_dir = None # Prepare inputs for next iteration - decode_inputs["input_ids"] = outputs["logits"].argmax(2) + decode_inputs["input_ids"] = self._fetch_next_token_id(outputs) decode_inputs["position_ids"][:, -1] += 1 self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1] finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id + if self.include_sampler: + decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] if finished_sequences.all(): break diff --git a/examples/on_device_sampling.py b/examples/on_device_sampling.py index ebe82fbe1..ba93d0749 100644 --- a/examples/on_device_sampling.py +++ b/examples/on_device_sampling.py @@ -16,6 +16,11 @@ def main(args, **kwargs): print(args.__dict__) + if args.full_batch_size is None: + raise ValueError( + "On Device Sampling is only supported with Continuous Batching. Please specify --full-batch-size." 
+ ) + # Get sampling inputs include_sampler = None return_pdfs = None From eb6e2ebef95f96046e7a8eaa9ea8753620850a70 Mon Sep 17 00:00:00 2001 From: sanising Date: Tue, 1 Jul 2025 18:41:04 -0500 Subject: [PATCH 06/14] Use same sampling parameters for each sequence in a batch Signed-off-by: sanising --- examples/on_device_sampling.py | 95 ++++++++++++++++++---------------- 1 file changed, 50 insertions(+), 45 deletions(-) diff --git a/examples/on_device_sampling.py b/examples/on_device_sampling.py index ba93d0749..cb540c3ae 100644 --- a/examples/on_device_sampling.py +++ b/examples/on_device_sampling.py @@ -32,14 +32,24 @@ def main(args, **kwargs): return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true" max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512)) sampling_params = { - "repetition_penalties": np.array(args.repetition_penalties, dtype=np.float32).reshape(-1, 1), - "presence_penalties": np.array(args.presence_penalties, dtype=np.float32).reshape(-1, 1), - # "frequency_penalties": np.array(args.frequency_penalties, dtype=np.float32).reshape(-1, 1), - "temperatures": np.array(args.temperatures, dtype=np.float32).reshape(-1, 1), - "top_ks": np.array(args.top_ks, dtype=np.int32).reshape(-1, 1), - "top_ps": np.array(args.top_ps, dtype=np.float32).reshape(-1, 1), - "min_ps": np.array(args.min_ps, dtype=np.float32).reshape(-1, 1), - "random_numbers": np.array(args.random_numbers, dtype=np.float32).reshape(-1, 1), + "repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32) + .repeat(args.full_batch_size) + .reshape(-1, 1), + "presence_penalties": np.array(args.presence_penalty, dtype=np.float32) + .repeat(args.full_batch_size) + .reshape(-1, 1), + # "frequency_penalties": np.array(args.frequency_penalty, dtype=np.float32) + # .repeat(args.full_batch_size) + # .reshape(-1, 1), + "temperatures": np.array(args.temperature, dtype=np.float32) + .repeat(args.full_batch_size) + .reshape(-1, 1), + "top_ks": np.array(args.top_k, dtype=np.int32).repeat(args.full_batch_size).reshape(-1, 1), + "top_ps": np.array(args.top_p, dtype=np.float32).repeat(args.full_batch_size).reshape(-1, 1), + "min_ps": np.array(args.min_p, dtype=np.float32).repeat(args.full_batch_size).reshape(-1, 1), + "random_numbers": np.array(args.random_number, dtype=np.float32) + .repeat(args.full_batch_size) + .reshape(-1, 1), } # Load model with On Device Sampler enabled @@ -179,61 +189,56 @@ def main(args, **kwargs): # ---On Device Sampling--- sampling_group = parser.add_argument_group("Sampling parameters") sampling_group.add_argument( - "--repetition-penalties", - type=lambda data: [float(x) for x in data.split(",")], + "--repetition-penalty", + type=float, default=None, - help="Comma-separated list of floating point values where each value is a sampling " - "parameter that penalizes new tokens based on whether they appear in the prompt and the " - "generated text so far. Values > 1 encourage the model to use new tokens, while values < 1 " - "encourage the model to repeat tokens.", + help="Sampling parameter that penalizes new tokens based on whether they appear in the " + "prompt and the generated text so far. 
Values > 1 encourage the model to use new tokens, " + "while values < 1 encourage the model to repeat tokens.", ) sampling_group.add_argument( - "--presence-penalties", - type=lambda data: [float(x) for x in data.split(",")], + "--presence-penalty", + type=float, default=None, - help="Comma-separated list of floating point values where each value is a sampling " - "parameter that penalizes new tokens based on whether they appear in the generated text " - "so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the " - "model to repeat tokens.", + help="Sampling parameter that penalizes new tokens based on whether they appear in the " + "generated text so far. Values > 0 encourage the model to use new tokens, while values < " + "0 encourage the model to repeat tokens.", ) sampling_group.add_argument( - "--temperatures", - type=lambda data: [float(x) for x in data.split(",")], - default=None, - help="Comma-separated list of floating point values where each value is a sampling " - "parameter that controls the randomness of the sampling. Lower values make the model more " - "deterministic, while higher values make the model more random. Zero means greedy sampling.", + "--temperature", + type=float, + default=0.0, + help="Sampling parameter that controls the randomness of the sampling. Lower" + "values make the model more deterministic, while higher values make" + "the model more random. Zero means greedy sampling.", ) sampling_group.add_argument( - "--top-ks", - type=lambda data: [int(x) for x in data.split(",")], + "--top-k", + type=int, default=None, - help="Comma-separated list of integer values where each value is a sampling parameter that " - "controls the number of top tokens to consider. Set to -1 to consider all tokens.", + help="Sampling parameter that controls the number of top tokens to consider. Set to -1 " + "to consider all tokens.", ) sampling_group.add_argument( - "--top-ps", - type=lambda data: [float(x) for x in data.split(",")], + "--top-p", + type=float, default=None, - help="Comma-separated list of floating point values where each value is a sampling " - "parameter that controls the cumulative probability of the top tokens to consider. Must be " - "in (0, 1]. Set to 1.0 to consider all tokens.", + help="Sampling parameter that controls the cumulative probability of the top tokens to " + "consider. Must be in (0, 1]. Set to 1.0 to consider all tokens.", ) sampling_group.add_argument( - "--min-ps", - type=lambda data: [float(x) for x in data.split(",")], + "--min-p", + type=float, default=None, - help="Comma-separated list of floating point values where each value is a sampling " - "parameter that represents the minumum probability for a token to be considered, relative " - "to the probability of the most likely token. Must be in [0, 1]. Set to 0.0 to disable " - "this.", + help="Sampling parameter that represents the minumum probability for a token to be " + "considered, relative to the probability of the most likely token. Must be in [0, 1]. " + "Set to 0.0 to disable this.", ) sampling_group.add_argument( - "--random-numbers", - type=lambda data: [float(x) for x in data.split(",")], + "--random-number", + type=float, default=None, - help="Comma-separated list of floating point values where each value is a sampling " - "parameter that represents the random seeds to use for random sampling. Must be in [-1, 1].", + help="Sampling parameter that represents the random seed to use for random sampling. 
" "Must be in [-1, 1].", ) args, compiler_options = parser.parse_known_args() From 48b35e39efdbf1227c80c2c9f0c89e46906c54ba Mon Sep 17 00:00:00 2001 From: sanising Date: Wed, 2 Jul 2025 18:43:37 -0500 Subject: [PATCH 07/14] Enable On Device Sampling for _regular_model_execution() Signed-off-by: sanising --- .../generation/text_generation_inference.py | 4 + QEfficient/transformers/sampler/sampler.py | 3 + examples/on_device_sampling.py | 98 +++++++++++-------- 3 files changed, 62 insertions(+), 43 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 1b452f25d..2e53fa275 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -350,6 +350,10 @@ def cloud_ai_100_exec_kv( next tokens. For Speculative Decoding Target Language Model, `return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative Decoding Draft Language Model and `return_pdfs`=False for regular model. + sampling_params (Dict[str, Any]): A dictionary of sampling parameters supported by the QAIC backend. + The dictionary should contain the following keys: + `repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`, + `min_ps`, and `random_numbers`. Each value should be a numpy array of shape (batch_size, 1). Returns: :CloudAI100ExecInfo: Object holding execution output and performance details. diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 6bcabf29a..96846e712 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -193,6 +193,9 @@ def sampler_forward( batch_size, spec_length, vocab_size = logits.shape logits = logits.reshape(-1, vocab_size) # Reshape tensor to 2D + if batch_index is None: # Regular model execution + batch_index = torch.arange(batch_size).view(-1, 1) + batch_index_reshaped = batch_index.view(-1) # Prefill past_repetition_penalty_buffer_prefill, past_presence_penalty_buffer_prefill = prefill_path( diff --git a/examples/on_device_sampling.py b/examples/on_device_sampling.py index cb540c3ae..09b99b681 100644 --- a/examples/on_device_sampling.py +++ b/examples/on_device_sampling.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import argparse import re +from pprint import pprint import numpy as np @@ -14,57 +15,48 @@ def main(args, **kwargs): - print(args.__dict__) - - if args.full_batch_size is None: - raise ValueError( - "On Device Sampling is only supported with Continuous Batching. Please specify --full-batch-size." 
- ) + pprint(args.__dict__) # Get sampling inputs include_sampler = None return_pdfs = None max_top_k_ids = None sampling_params = None + bs = args.full_batch_size if args.full_batch_size is not None else args.batch_size if args.override_qaic_config is not None: include_sampler = args.override_qaic_config.get("aic_include_sampler", None) == "true" if include_sampler is not None: return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true" max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512)) sampling_params = { - "repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32) - .repeat(args.full_batch_size) - .reshape(-1, 1), - "presence_penalties": np.array(args.presence_penalty, dtype=np.float32) - .repeat(args.full_batch_size) - .reshape(-1, 1), - # "frequency_penalties": np.array(args.frequency_penalty, dtype=np.float32) - # .repeat(args.full_batch_size) - # .reshape(-1, 1), - "temperatures": np.array(args.temperature, dtype=np.float32) - .repeat(args.full_batch_size) - .reshape(-1, 1), - "top_ks": np.array(args.top_k, dtype=np.int32).repeat(args.full_batch_size).reshape(-1, 1), - "top_ps": np.array(args.top_p, dtype=np.float32).repeat(args.full_batch_size).reshape(-1, 1), - "min_ps": np.array(args.min_p, dtype=np.float32).repeat(args.full_batch_size).reshape(-1, 1), - "random_numbers": np.array(args.random_number, dtype=np.float32) - .repeat(args.full_batch_size) - .reshape(-1, 1), + "repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), + "presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), + # "frequency_penalties": np.array(args.frequency_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), + "temperatures": np.array(args.temperature, dtype=np.float32).repeat(bs).reshape(-1, 1), + "top_ks": np.array(args.top_k, dtype=np.int32).repeat(bs).reshape(-1, 1), + "top_ps": np.array(args.top_p, dtype=np.float32).repeat(bs).reshape(-1, 1), + "min_ps": np.array(args.min_p, dtype=np.float32).repeat(bs).reshape(-1, 1), + "random_numbers": np.array(args.random_number, dtype=np.float32).repeat(bs).reshape(-1, 1), } + qaic_config = { + k: v + for k, v in { + "include_sampler": include_sampler, + "return_pdfs": return_pdfs, + "max_top_k_ids": max_top_k_ids, + }.items() + if v is not None + } + print("qaic_config:") + pprint(qaic_config) + print("sampling_params:") + pprint(sampling_params) # Load model with On Device Sampler enabled qeff_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path=args.model_name, - full_batch_size=args.full_batch_size, - qaic_config={ - k: v - for k, v in { - "include_sampler": include_sampler, - "return_pdfs": return_pdfs, - "max_top_k_ids": max_top_k_ids, - }.items() - if v is not None - }, + continuous_batching=args.full_batch_size is not None, + qaic_config=qaic_config, ) print(f"{args.model_name} optimized for AI 100 \n", qeff_model) @@ -87,7 +79,7 @@ def main(args, **kwargs): if not args.prompt: args.prompt = [ "Hi", - ] * args.full_batch_size + ] * bs qeff_model.generate( tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=args.model_name), prompts=args.prompt, @@ -103,7 +95,7 @@ def main(args, **kwargs): if __name__ == "__main__": """ Example usage: - + 1. 
For continuous batching: python3.10 examples/on_device_sampling.py \ --model-name 'meta-llama/Llama-3.1-8B' \ --prompt-len 128 \ @@ -115,13 +107,33 @@ def main(args, **kwargs): --mxint8-kv-cache \ --mxfp6-matmul \ --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \ - --repetition-penalties 1.9,1.0 \ - --presence-penalties 0.8,0.11 \ - --temperatures 0.67,0.52 \ - --top-ks 54720,23095 \ - --top-ps 0.89,0.56 \ - --min-ps 0.6,0.71 \ - --random-numbers 0.26,0.87 + --repetition-penalty 1.9 \ + --presence-penalty 0.8 \ + --temperature 0.67 \ + --top-k 54720 \ + --top-p 0.89 \ + --min-p 0.6 \ + --random-number 0.26 + + 2. For non-continuous batching: + python3.10 examples/on_device_sampling.py \ + --model-name 'meta-llama/Llama-3.1-8B' \ + --prompt-len 128 \ + --ctx-len 256 \ + --generation-len 20 \ + --batch-size 2 \ + --device-group [0,1,2,3] \ + --num-cores 16 \ + --mxint8-kv-cache \ + --mxfp6-matmul \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \ + --repetition-penalty 1.9 \ + --presence-penalty 0.8 \ + --temperature 0.67 \ + --top-k 54720 \ + --top-p 0.89 \ + --min-p 0.6 \ + --random-number 0.26 """ parser = argparse.ArgumentParser(description="Run QEfficient model with On Device Sampling") From c83a631d64f89e7314acbdf99ef8ad0f2e8e4c42 Mon Sep 17 00:00:00 2001 From: sanising Date: Wed, 2 Jul 2025 19:15:49 -0500 Subject: [PATCH 08/14] Add test for greedy sampling Signed-off-by: sanising --- tests/transformers/sampler/test_sampler.py | 252 ++++++++++++++++++++- 1 file changed, 240 insertions(+), 12 deletions(-) diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 10a325754..83474b204 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -7,10 +7,12 @@ from typing import List +import numpy as np import pytest from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils import load_hf_tokenizer from QEfficient.utils.constants import Constants configs = [ @@ -19,29 +21,47 @@ Constants.INPUT_STR, # prompts 32, # prefill_seq_len 256, # ctx_len + 20, # generation_len 4, # full_batch_size 1, # num_devices + [0], # device_group 16, # num_cores 1, # spec_length + 1.9, # repetition_penalty + 0.8, # presence_penalty + 0.67, # temperature + 54720, # top_k + 0.89, # top_p + 0.6, # min_p + 0.26, # random_number id="Llama-3.1-8B_32_256_4_1_16_1", ), - pytest.param( - "meta-llama/Llama-3.1-8B", # model - Constants.INPUT_STR, # prompts - 32, # prefill_seq_len - 256, # ctx_len - 4, # full_batch_size - 4, # num_devices - 16, # num_cores - 1, # spec_length - id="Llama-3.1-8B_32_256_4_4_16_1", - ), + # pytest.param( + # "meta-llama/Llama-3.1-8B", + # Constants.INPUT_STR, + # 32, + # 256, + # 20, + # 4, + # 4, + # [0, 1, 2, 3], + # 16, + # 1, + # 1.9, + # 0.8, + # 0.67, + # 54720, + # 0.89, + # 0.6, + # 0.26, + # id="Llama-3.1-8B_32_256_4_4_16_1", + # ), ] @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, full_batch_size, num_devices, num_cores, spec_length", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, num_devices, device_group, num_cores, spec_length, repetition_penalty, presence_penalty, temperature, top_k, top_p, min_p, random_number", configs, ) def test_sampler_transform( @@ -49,11 +69,25 @@ def test_sampler_transform( prompts: List[str], 
prefill_seq_len: int, ctx_len: int, + generation_len: int, full_batch_size: int, num_devices: int, + device_group: List[int], num_cores: int, spec_length: int, + repetition_penalty: float, + presence_penalty: float, + temperature: float, + top_k: int, + top_p: float, + min_p: float, + random_number: float, ): + """ + Test if `SamplerTransform` adds nodes at the output of a `QEffForCausalLM model` to enable the + sampling of next tokens at the device (instead of the host) and returns the + next tokens and/or probability distributions. + """ # Export and compile QEfficient models qaic_config = { "include_sampler": True, @@ -117,3 +151,197 @@ def test_sampler_transform( assert ( input_name not in model_wo_sampler_session.input_names ), f"Sampler input {input_name} found in QPC compiled without Sampler" + + +@pytest.mark.on_qaic +@pytest.mark.parametrize( + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, num_devices, device_group, num_cores, spec_length, repetition_penalty, presence_penalty, temperature, top_k, top_p, min_p, random_number", + configs, +) +def test_greedy_sampling( + model: str, + prompts: List[str], + prefill_seq_len: int, + ctx_len: int, + generation_len: int, + full_batch_size: int, + num_devices: int, + device_group: List[int], + num_cores: int, + spec_length: int, + repetition_penalty: float, + presence_penalty: float, + temperature: float, + top_k: int, + top_p: float, + min_p: float, + random_number: float, +): + """ + Test greedy sampling with QPC compiled with and without On Device Sampling. + """ + # Export and compile QEfficient models + qaic_config = { + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 512, + } + model_w_sampler = AutoModelForCausalLM.from_pretrained(model, continuous_batching=True, qaic_config=qaic_config) + model_wo_sampler = AutoModelForCausalLM.from_pretrained(model, continuous_batching=True, qaic_config=None) + model_w_sampler.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=num_devices, + num_cores=num_cores, + num_speculative_tokens=spec_length - 1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + model_wo_sampler.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=num_devices, + num_cores=num_cores, + num_speculative_tokens=spec_length - 1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + + # Generate texts from prompts + model_w_sampler_exec_info = model_w_sampler.generate( + tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=model), + prompts=prompts, + device_id=device_group, + generation_len=generation_len, + include_sampler=True, + return_pdfs=False, + sampling_params={ + "repetition_penalties": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "presence_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + # "frequency_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "temperatures": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), + "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + }, + ) + model_wo_sampler_exec_info = model_wo_sampler.generate( + 
tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=model), + prompts=prompts, + device_id=device_group, + generation_len=generation_len, + include_sampler=False, + return_pdfs=False, + sampling_params=None, + ) + + # Compare generated texts and ids + assert ( + model_w_sampler_exec_info.generated_texts == model_wo_sampler_exec_info.generated_texts + ), "Generated texts do not match" + assert ( + model_w_sampler_exec_info.generated_ids == model_wo_sampler_exec_info.generated_ids + ), "Generated ids do not match" + + +@pytest.mark.on_qaic +@pytest.mark.parametrize( + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, num_devices, device_group, num_cores, spec_length, repetition_penalty, presence_penalty, temperature, top_k, top_p, min_p, random_number", + configs, +) +def test_random_sampling( + model: str, + prompts: List[str], + prefill_seq_len: int, + ctx_len: int, + generation_len: int, + full_batch_size: int, + num_devices: int, + device_group: List[int], + num_cores: int, + spec_length: int, + repetition_penalty: float, + presence_penalty: float, + temperature: float, + top_k: int, + top_p: float, + min_p: float, + random_number: float, +): + """ + Test random sampling with QPC compiled with and without On Device Sampling. + """ + # Export and compile QEfficient models + qaic_config = { + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 512, + } + model_w_sampler = AutoModelForCausalLM.from_pretrained(model, continuous_batching=True, qaic_config=qaic_config) + model_wo_sampler = AutoModelForCausalLM.from_pretrained(model, continuous_batching=True, qaic_config=None) + model_w_sampler.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=num_devices, + num_cores=num_cores, + num_speculative_tokens=spec_length - 1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + model_wo_sampler.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=num_devices, + num_cores=num_cores, + num_speculative_tokens=spec_length - 1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + + # Generate texts from prompts + model_w_sampler_exec_info = model_w_sampler.generate( + tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=model), + prompts=prompts, + device_id=device_group, + generation_len=generation_len, + include_sampler=True, + return_pdfs=False, + sampling_params={ + "repetition_penalties": np.array(repetition_penalty, dtype=np.float32) + .repeat(full_batch_size) + .reshape(-1, 1), + "presence_penalties": np.array(presence_penalty, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + # "frequency_penalties": np.array(frequency_penalty, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "temperatures": np.array(temperature, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "top_ks": np.array(top_k, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), + "top_ps": np.array(top_p, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "min_ps": np.array(min_p, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.array(random_number, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + }, + ) + model_wo_sampler_exec_info = model_wo_sampler.generate( + tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=model), + prompts=prompts, + device_id=device_group, + generation_len=generation_len, + include_sampler=False, + return_pdfs=False, + sampling_params=None, + ) + + # 
Compare generated texts + golden_texts = { + "w_sampler": [""] * full_batch_size, + "wo_sampler": [""] * full_batch_size, + } + assert ( + model_w_sampler_exec_info.generated_texts == golden_texts["w_sampler"] + ), "Sampler generated texts do not match" + assert ( + model_wo_sampler_exec_info.generated_texts == golden_texts["wo_sampler"] + ), "Without sampler generated texts do not match" From f698a2486efcd3b224a39bc0db80b130049b373a Mon Sep 17 00:00:00 2001 From: sanising Date: Thu, 3 Jul 2025 13:44:25 -0500 Subject: [PATCH 09/14] Add test for random sampling Signed-off-by: sanising --- tests/transformers/sampler/test_sampler.py | 205 ++++++++++----------- 1 file changed, 101 insertions(+), 104 deletions(-) diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 83474b204..8ab3cb90a 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -14,54 +14,24 @@ from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.utils import load_hf_tokenizer from QEfficient.utils.constants import Constants +from QEfficient.utils.device_utils import get_available_device_id configs = [ pytest.param( "meta-llama/Llama-3.1-8B", # model - Constants.INPUT_STR, # prompts + Constants.INPUT_STR * 4, # prompts 32, # prefill_seq_len 256, # ctx_len 20, # generation_len 4, # full_batch_size - 1, # num_devices - [0], # device_group - 16, # num_cores 1, # spec_length - 1.9, # repetition_penalty - 0.8, # presence_penalty - 0.67, # temperature - 54720, # top_k - 0.89, # top_p - 0.6, # min_p - 0.26, # random_number - id="Llama-3.1-8B_32_256_4_1_16_1", ), - # pytest.param( - # "meta-llama/Llama-3.1-8B", - # Constants.INPUT_STR, - # 32, - # 256, - # 20, - # 4, - # 4, - # [0, 1, 2, 3], - # 16, - # 1, - # 1.9, - # 0.8, - # 0.67, - # 54720, - # 0.89, - # 0.6, - # 0.26, - # id="Llama-3.1-8B_32_256_4_4_16_1", - # ), ] @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, num_devices, device_group, num_cores, spec_length, repetition_penalty, presence_penalty, temperature, top_k, top_p, min_p, random_number", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", configs, ) def test_sampler_transform( @@ -71,17 +41,7 @@ def test_sampler_transform( ctx_len: int, generation_len: int, full_batch_size: int, - num_devices: int, - device_group: List[int], - num_cores: int, spec_length: int, - repetition_penalty: float, - presence_penalty: float, - temperature: float, - top_k: int, - top_p: float, - min_p: float, - random_number: float, ): """ Test if `SamplerTransform` adds nodes at the output of a `QEffForCausalLM model` to enable the @@ -100,8 +60,8 @@ def test_sampler_transform( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, - num_devices=num_devices, - num_cores=num_cores, + num_devices=1, + num_cores=16, num_speculative_tokens=spec_length - 1, mxint8_kv_cache=True, mxfp6_matmul=True, @@ -110,8 +70,8 @@ def test_sampler_transform( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, - num_devices=num_devices, - num_cores=num_cores, + num_devices=1, + num_cores=16, num_speculative_tokens=spec_length - 1, mxint8_kv_cache=True, mxfp6_matmul=True, @@ -155,7 +115,7 @@ def test_sampler_transform( @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, num_devices, 
device_group, num_cores, spec_length, repetition_penalty, presence_penalty, temperature, top_k, top_p, min_p, random_number", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", configs, ) def test_greedy_sampling( @@ -165,17 +125,7 @@ def test_greedy_sampling( ctx_len: int, generation_len: int, full_batch_size: int, - num_devices: int, - device_group: List[int], - num_cores: int, spec_length: int, - repetition_penalty: float, - presence_penalty: float, - temperature: float, - top_k: int, - top_p: float, - min_p: float, - random_number: float, ): """ Test greedy sampling with QPC compiled with and without On Device Sampling. @@ -192,8 +142,8 @@ def test_greedy_sampling( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, - num_devices=num_devices, - num_cores=num_cores, + num_devices=1, + num_cores=16, num_speculative_tokens=spec_length - 1, mxint8_kv_cache=True, mxfp6_matmul=True, @@ -202,18 +152,19 @@ def test_greedy_sampling( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, - num_devices=num_devices, - num_cores=num_cores, + num_devices=1, + num_cores=16, num_speculative_tokens=spec_length - 1, mxint8_kv_cache=True, mxfp6_matmul=True, ) # Generate texts from prompts + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model) model_w_sampler_exec_info = model_w_sampler.generate( - tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=model), + tokenizer=tokenizer, prompts=prompts, - device_id=device_group, + device_id=get_available_device_id(), generation_len=generation_len, include_sampler=True, return_pdfs=False, @@ -229,9 +180,9 @@ def test_greedy_sampling( }, ) model_wo_sampler_exec_info = model_wo_sampler.generate( - tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=model), + tokenizer=tokenizer, prompts=prompts, - device_id=device_group, + device_id=get_available_device_id(), generation_len=generation_len, include_sampler=False, return_pdfs=False, @@ -244,12 +195,12 @@ def test_greedy_sampling( ), "Generated texts do not match" assert ( model_w_sampler_exec_info.generated_ids == model_wo_sampler_exec_info.generated_ids - ), "Generated ids do not match" + ).all(), "Generated ids do not match" @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, num_devices, device_group, num_cores, spec_length, repetition_penalty, presence_penalty, temperature, top_k, top_p, min_p, random_number", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", configs, ) def test_random_sampling( @@ -259,17 +210,7 @@ def test_random_sampling( ctx_len: int, generation_len: int, full_batch_size: int, - num_devices: int, - device_group: List[int], - num_cores: int, spec_length: int, - repetition_penalty: float, - presence_penalty: float, - temperature: float, - top_k: int, - top_p: float, - min_p: float, - random_number: float, ): """ Test random sampling with QPC compiled with and without On Device Sampling. 
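# A minimal illustration of the sampling-parameter layout used by the test hunks
# above and below: every knob is passed to `generate` as a (full_batch_size, 1)
# column vector built with `np.array(x, dtype=...).repeat(full_batch_size).reshape(-1, 1)`.
# Sketch only, using the hard-coded values from this patch; the helper name
# `make_sampling_params` is hypothetical and not a QEfficient API.
import numpy as np

def make_sampling_params(full_batch_size: int) -> dict:
    scalars = {
        "repetition_penalties": (1.9, np.float32),
        "presence_penalties": (0.8, np.float32),
        "temperatures": (0.67, np.float32),
        "top_ks": (54720, np.int32),
        "top_ps": (0.89, np.float32),
        "min_ps": (0.6, np.float32),
        "random_numbers": (0.26, np.float32),
    }
    return {
        name: np.array(value, dtype=dtype).repeat(full_batch_size).reshape(-1, 1)
        for name, (value, dtype) in scalars.items()
    }

# e.g. make_sampling_params(4)["top_ks"].shape == (4, 1)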
@@ -286,8 +227,8 @@ def test_random_sampling( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, - num_devices=num_devices, - num_cores=num_cores, + num_devices=1, + num_cores=16, num_speculative_tokens=spec_length - 1, mxint8_kv_cache=True, mxfp6_matmul=True, @@ -296,38 +237,37 @@ def test_random_sampling( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, - num_devices=num_devices, - num_cores=num_cores, + num_devices=1, + num_cores=16, num_speculative_tokens=spec_length - 1, mxint8_kv_cache=True, mxfp6_matmul=True, ) # Generate texts from prompts + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model) model_w_sampler_exec_info = model_w_sampler.generate( - tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=model), + tokenizer=tokenizer, prompts=prompts, - device_id=device_group, + device_id=get_available_device_id(), generation_len=generation_len, include_sampler=True, return_pdfs=False, sampling_params={ - "repetition_penalties": np.array(repetition_penalty, dtype=np.float32) - .repeat(full_batch_size) - .reshape(-1, 1), - "presence_penalties": np.array(presence_penalty, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - # "frequency_penalties": np.array(frequency_penalty, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "temperatures": np.array(temperature, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "top_ks": np.array(top_k, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), - "top_ps": np.array(top_p, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "min_ps": np.array(min_p, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "random_numbers": np.array(random_number, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "repetition_penalties": np.array(1.9, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "presence_penalties": np.array(0.8, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + # "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "temperatures": np.array(0.67, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "top_ks": np.array(54720, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), + "top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.array(0.26, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), }, ) model_wo_sampler_exec_info = model_wo_sampler.generate( - tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=model), + tokenizer=tokenizer, prompts=prompts, - device_id=device_group, + device_id=get_available_device_id(), generation_len=generation_len, include_sampler=False, return_pdfs=False, @@ -336,12 +276,69 @@ def test_random_sampling( # Compare generated texts golden_texts = { - "w_sampler": [""] * full_batch_size, - "wo_sampler": [""] * full_batch_size, + "w_sampler": " Kelsey and I am a 20 year old college student. My major in school right now,", + "wo_sampler": " Kaitlyn and I am a 20 year old college student. 
I am a junior at the", } - assert ( - model_w_sampler_exec_info.generated_texts == golden_texts["w_sampler"] - ), "Sampler generated texts do not match" - assert ( - model_wo_sampler_exec_info.generated_texts == golden_texts["wo_sampler"] - ), "Without sampler generated texts do not match" + golden_ids = { + "w_sampler": [ + [ + 735, + 93567, + 323, + 358, + 1097, + 264, + 220, + 508, + 1060, + 2362, + 7926, + 5575, + 13, + 3092, + 3682, + 304, + 2978, + 1314, + 1457, + 11, + ] + ], + "wo_sampler": [ + [ + 735, + 1339, + 18499, + 323, + 358, + 1097, + 264, + 220, + 508, + 1060, + 2362, + 7926, + 5575, + 13, + 358, + 1097, + 264, + 27144, + 520, + 279, + ] + ], + } + for i in range(full_batch_size): + assert ( + tokenizer.decode(model_w_sampler_exec_info.generated_ids[i][:generation_len]) == golden_texts["w_sampler"] + ), "Sampler generated texts does not match" + assert ( + model_w_sampler_exec_info.generated_ids[i][:generation_len] == golden_ids["w_sampler"] + ).all(), "Sampler generated ids do not match" + assert ( + tokenizer.decode(model_wo_sampler_exec_info.generated_ids[i][:generation_len]) == golden_texts["wo_sampler"] + ), "Without sampler generated texts does not match" + assert ( + model_wo_sampler_exec_info.generated_ids[i][:generation_len] == golden_ids["wo_sampler"] + ).all(), "Without sampler generated ids do not match" From 7b34a07769501c47c218fe5db3c8970efc848593 Mon Sep 17 00:00:00 2001 From: sanising Date: Thu, 3 Jul 2025 14:06:00 -0500 Subject: [PATCH 10/14] Remove else block Signed-off-by: sanising --- QEfficient/generation/text_generation_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 2e53fa275..6bbbf9169 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -456,6 +456,7 @@ def __init__( if session_input_name in sampler_inputs: count += 1 if count == len(sampler_inputs): + self.include_sampler = True break if count == 0: self.include_sampler = False @@ -465,8 +466,7 @@ def __init__( f"on the QAIC device (only {count}/{len(sampler_inputs)} inputs provided). Partial " "sampling support is not available. Please check the QPC and try again." ) - else: # count == len(sampler_inputs) - self.include_sampler = True + if include_sampler and not self.include_sampler: logger.warning_once( "User entered `include_sampler`=True. But the provided QPC is not compiled " From 0ee201ab7c8325cd796c248f182154c7ecf1bd25 Mon Sep 17 00:00:00 2001 From: sanising Date: Thu, 3 Jul 2025 14:16:12 -0500 Subject: [PATCH 11/14] Reformat code Signed-off-by: sanising --- examples/on_device_sampling.py | 5 ++- tests/transformers/sampler/test_sampler.py | 36 +++++++++++----------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/examples/on_device_sampling.py b/examples/on_device_sampling.py index 09b99b681..00d8c2430 100644 --- a/examples/on_device_sampling.py +++ b/examples/on_device_sampling.py @@ -228,8 +228,7 @@ def main(args, **kwargs): "--top-k", type=int, default=None, - help="Sampling parameter that controls the number of top tokens to consider. Set to -1 " - "to consider all tokens.", + help="Sampling parameter that controls the number of top tokens to consider. 
Set to -1 to consider all tokens.", ) sampling_group.add_argument( "--top-p", @@ -250,7 +249,7 @@ def main(args, **kwargs): "--random-number", type=float, default=None, - help="Sampling parameter that represents the random seed to use for random sampling. " "Must be in [-1, 1].", + help="Sampling parameter that represents the random seed to use for random sampling. Must be in [-1, 1].", ) args, compiler_options = parser.parse_known_args() diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 8ab3cb90a..95f00bbf3 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -105,12 +105,12 @@ def test_sampler_transform( "random_numbers", ] for input_name in sampler_inputs: - assert ( - input_name in model_w_sampler_session.input_names - ), f"Sampler input {input_name} not found in QPC compiled with Sampler" - assert ( - input_name not in model_wo_sampler_session.input_names - ), f"Sampler input {input_name} found in QPC compiled without Sampler" + assert input_name in model_w_sampler_session.input_names, ( + f"Sampler input {input_name} not found in QPC compiled with Sampler" + ) + assert input_name not in model_wo_sampler_session.input_names, ( + f"Sampler input {input_name} found in QPC compiled without Sampler" + ) @pytest.mark.on_qaic @@ -190,12 +190,12 @@ def test_greedy_sampling( ) # Compare generated texts and ids - assert ( - model_w_sampler_exec_info.generated_texts == model_wo_sampler_exec_info.generated_texts - ), "Generated texts do not match" - assert ( - model_w_sampler_exec_info.generated_ids == model_wo_sampler_exec_info.generated_ids - ).all(), "Generated ids do not match" + assert model_w_sampler_exec_info.generated_texts == model_wo_sampler_exec_info.generated_texts, ( + "Generated texts do not match" + ) + assert (model_w_sampler_exec_info.generated_ids == model_wo_sampler_exec_info.generated_ids).all(), ( + "Generated ids do not match" + ) @pytest.mark.on_qaic @@ -333,12 +333,12 @@ def test_random_sampling( assert ( tokenizer.decode(model_w_sampler_exec_info.generated_ids[i][:generation_len]) == golden_texts["w_sampler"] ), "Sampler generated texts does not match" - assert ( - model_w_sampler_exec_info.generated_ids[i][:generation_len] == golden_ids["w_sampler"] - ).all(), "Sampler generated ids do not match" + assert (model_w_sampler_exec_info.generated_ids[i][:generation_len] == golden_ids["w_sampler"]).all(), ( + "Sampler generated ids do not match" + ) assert ( tokenizer.decode(model_wo_sampler_exec_info.generated_ids[i][:generation_len]) == golden_texts["wo_sampler"] ), "Without sampler generated texts does not match" - assert ( - model_wo_sampler_exec_info.generated_ids[i][:generation_len] == golden_ids["wo_sampler"] - ).all(), "Without sampler generated ids do not match" + assert (model_wo_sampler_exec_info.generated_ids[i][:generation_len] == golden_ids["wo_sampler"]).all(), ( + "Without sampler generated ids do not match" + ) From df4aadd6c62b4fe378dc265e7f21276d470a3a83 Mon Sep 17 00:00:00 2001 From: quic-sanising Date: Thu, 24 Jul 2025 14:31:03 -0500 Subject: [PATCH 12/14] Add new sampling param: frequency_penalties Signed-off-by: quic-sanising --- .../transformers/models/modeling_auto.py | 7 +++- QEfficient/transformers/sampler/sampler.py | 42 +++++++++++++------ QEfficient/utils/constants.py | 1 + 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py 
index 2f3ee3dc0..3df96f399 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1652,13 +1652,18 @@ def get_sampling_inputs_and_outputs( dynamic_axes["repetition_penalties"] = {0: "batch_size"} example_inputs["past_presence_penalty_buffer"] = torch.zeros( - (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool + (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.int32 ) dynamic_axes["past_presence_penalty_buffer"] = { 0: "full_batch_size" if self.continuous_batching else "batch_size", } output_names.append("past_presence_penalty_buffer_RetainedState") + example_inputs["frequency_penalties"] = ( + torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_FREQUENCY_PENALTIES + ) + dynamic_axes["frequency_penalties"] = {0: "batch_size"} + example_inputs["presence_penalties"] = ( torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES ) diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 6bcabf29a..753fbc788 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -12,7 +12,7 @@ from transformers.cache_utils import Cache from transformers.modeling_outputs import ModelOutput -from QEfficient.customop import CtxScatterFuncCB3D +from QEfficient.customop import CtxGatherFuncCB3D, CtxScatterFuncCB3D from QEfficient.utils.constants import Constants @@ -80,20 +80,23 @@ def decode_path( ) # Update retained states - scatter_values = torch.ones(last_accepted_output_tokens.shape, dtype=torch.bool) past_repetition_penalty_buffer = CtxScatterFuncCB3D.apply( past_repetition_penalty_buffer, batch_index, last_accepted_output_tokens, - scatter_values, + torch.ones(last_accepted_output_tokens.shape, dtype=torch.bool), + ) + gather_values = CtxGatherFuncCB3D.apply( + past_presence_penalty_buffer, + batch_index, + last_accepted_output_tokens, ) past_presence_penalty_buffer = CtxScatterFuncCB3D.apply( past_presence_penalty_buffer, batch_index, last_accepted_output_tokens, - scatter_values, + gather_values + 1, ) - # TODO: For frequency retain state, first gather and then scatter return past_repetition_penalty_buffer, past_presence_penalty_buffer @@ -116,6 +119,7 @@ def sampler_forward( past_repetition_penalty_buffer: Optional[torch.Tensor] = None, repetition_penalties: Optional[torch.Tensor] = None, past_presence_penalty_buffer: Optional[torch.Tensor] = None, + frequency_penalties: Optional[torch.Tensor] = None, presence_penalties: Optional[torch.Tensor] = None, temperatures: Optional[torch.Tensor] = None, top_ks: Optional[torch.Tensor] = None, @@ -141,8 +145,13 @@ def sampler_forward( new tokens, while values < 1 encourage the model to repeat tokens. past_presence_penalty_buffer (`torch.Tensor`, *optional*): - RetainedState buffer used as a mask to apply presence penalty to the output - generated so far. + RetainedState buffer used as a mask to apply frequency and presence penalties to + the output generated so far. + + frequency_penalties (`torch.Tensor`, *optional*): + Sampling parameter that penalizes new tokens based on their frequency in the + generated text so far. Values > 0 encourage the model to use new tokens, while + values < 0 encourage the model to repeat tokens. 
presence_penalties (`torch.Tensor`, *optional*): Sampling parameter that penalizes new tokens based on whether they appear in the @@ -240,17 +249,24 @@ def sampler_forward( repetition_penalties_mask = torch.where(past_repetition_penalty_buffer_selected, repetition_penalties, 1.0) logits *= repetition_penalties_mask ** (-torch.sign(logits)) + if (frequency_penalties != 0.0).any() or (presence_penalties != 0.0).any(): + past_presence_penalty_buffer_selected = past_presence_penalty_buffer[batch_index_reshaped].repeat( + spec_length, 1 + ) # (batch_size * spec_length, vocab_size) + + # Frequency Penalty + if (frequency_penalties != 0.0).any(): + frequency_penalties = frequency_penalties.repeat( + spec_length, 1 + ) # (batch_size, 1) -> (batch_size * spec_length, 1) + logits -= frequency_penalties * past_presence_penalty_buffer_selected + # Presence Penalty if (presence_penalties != 0.0).any(): presence_penalties = presence_penalties.repeat( spec_length, 1 ) # (batch_size, 1) -> (batch_size * spec_length, 1) - past_presence_penalty_buffer_selected = past_presence_penalty_buffer[batch_index_reshaped].repeat( - spec_length, 1 - ) # (batch_size * spec_length, vocab_size) - logits -= presence_penalties * past_presence_penalty_buffer_selected - - # TODO: Frequency Penalty + logits -= presence_penalties * (past_presence_penalty_buffer_selected > 0) # Temperature Scaling temperatures = temperatures.repeat(spec_length, 1) # (batch_size, 1) -> (batch_size * spec_length, 1) diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 5e855094c..84163333a 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -61,6 +61,7 @@ def get_models_dir(): QEFF_MODELS_DIR = get_models_dir() ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES = 0.5 +ONNX_EXPORT_EXAMPLE_FREQUENCY_PENALTIES = 0.5 ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES = 0.5 ONNX_EXPORT_EXAMPLE_TEMPERATURES = 0.80 ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512 From c3f827cbaa0d69e5823eb35a2d606bf0a96a141c Mon Sep 17 00:00:00 2001 From: quic-sanising Date: Thu, 24 Jul 2025 14:49:00 -0500 Subject: [PATCH 13/14] Enable frequency penalties Signed-off-by: quic-sanising --- QEfficient/generation/text_generation_inference.py | 5 ++++- examples/on_device_sampling.py | 12 +++++++++++- tests/transformers/sampler/test_sampler.py | 5 +++-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index d31378fe4..8f0927947 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -352,7 +352,7 @@ def cloud_ai_100_exec_kv( Decoding Draft Language Model and `return_pdfs`=False for regular model. sampling_params (Dict[str, Any]): A dictionary of sampling parameters supported by the QAIC backend. The dictionary should contain the following keys: - `repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`, + `repetition_penalties`, `frequency_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`, `min_ps`, and `random_numbers`. Each value should be a numpy array of shape (batch_size, 1). 
Returns: @@ -444,6 +444,7 @@ def __init__( sampler_inputs = [ "last_accepted_output_tokens", "repetition_penalties", + "frequency_penalties", "presence_penalties", "temperatures", "top_ks", @@ -645,6 +646,7 @@ def prepare_decode_inputs(self): decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] for op in [ "repetition_penalties", + "frequency_penalties", "presence_penalties", "temperatures", "top_ks", @@ -819,6 +821,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i inputs["last_accepted_output_tokens"] = inputs["input_ids"] for op in [ "repetition_penalties", + "frequency_penalties", "presence_penalties", "temperatures", "top_ks", diff --git a/examples/on_device_sampling.py b/examples/on_device_sampling.py index 00d8c2430..8431d5a83 100644 --- a/examples/on_device_sampling.py +++ b/examples/on_device_sampling.py @@ -30,8 +30,8 @@ def main(args, **kwargs): max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512)) sampling_params = { "repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), + "frequency_penalties": np.array(args.frequency_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), "presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), - # "frequency_penalties": np.array(args.frequency_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), "temperatures": np.array(args.temperature, dtype=np.float32).repeat(bs).reshape(-1, 1), "top_ks": np.array(args.top_k, dtype=np.int32).repeat(bs).reshape(-1, 1), "top_ps": np.array(args.top_p, dtype=np.float32).repeat(bs).reshape(-1, 1), @@ -108,6 +108,7 @@ def main(args, **kwargs): --mxfp6-matmul \ --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \ --repetition-penalty 1.9 \ + --frequency-penalty 0.8 \ --presence-penalty 0.8 \ --temperature 0.67 \ --top-k 54720 \ @@ -128,6 +129,7 @@ def main(args, **kwargs): --mxfp6-matmul \ --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \ --repetition-penalty 1.9 \ + --frequency-penalty 0.8 \ --presence-penalty 0.8 \ --temperature 0.67 \ --top-k 54720 \ @@ -208,6 +210,14 @@ def main(args, **kwargs): "prompt and the generated text so far. Values > 1 encourage the model to use new tokens, " "while values < 1 encourage the model to repeat tokens.", ) + sampling_group.add_argument( + "--frequency-penalty", + type=float, + default=None, + help="Sampling parameter that penalizes new tokens based on their frequency in the " + "generated text so far. 
Values > 0 encourage the model to use new tokens, while values < " + "0 encourage the model to repeat tokens.", + ) sampling_group.add_argument( "--presence-penalty", type=float, diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 95f00bbf3..70525c3bd 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -97,6 +97,7 @@ def test_sampler_transform( sampler_inputs = [ "last_accepted_output_tokens", "repetition_penalties", + "frequency_penalties", "presence_penalties", "temperatures", "top_ks", @@ -170,8 +171,8 @@ def test_greedy_sampling( return_pdfs=False, sampling_params={ "repetition_penalties": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "frequency_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "presence_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - # "frequency_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "temperatures": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), @@ -255,8 +256,8 @@ def test_random_sampling( return_pdfs=False, sampling_params={ "repetition_penalties": np.array(1.9, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "frequency_penalties": np.array(0.8, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "presence_penalties": np.array(0.8, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - # "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "temperatures": np.array(0.67, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "top_ks": np.array(54720, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), "top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), From 574b8ad60167406024915d50ac051c246e8aa80c Mon Sep 17 00:00:00 2001 From: quic-sanising Date: Thu, 24 Jul 2025 16:59:43 -0500 Subject: [PATCH 14/14] Remove CtxGatherFuncCB3D as it does not support 2D ctx_indices Signed-off-by: quic-sanising --- QEfficient/transformers/sampler/sampler.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index df45dc6f0..9689a67c9 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -12,7 +12,7 @@ from transformers.cache_utils import Cache from transformers.modeling_outputs import ModelOutput -from QEfficient.customop import CtxGatherFuncCB3D, CtxScatterFuncCB3D +from QEfficient.customop import CtxScatterFuncCB3D from QEfficient.utils.constants import Constants @@ -86,11 +86,7 @@ def decode_path( last_accepted_output_tokens, torch.ones(last_accepted_output_tokens.shape, dtype=torch.bool), ) - gather_values = CtxGatherFuncCB3D.apply( - past_presence_penalty_buffer, - batch_index, - last_accepted_output_tokens, - ) + gather_values = past_presence_penalty_buffer[batch_index, last_accepted_output_tokens] past_presence_penalty_buffer = CtxScatterFuncCB3D.apply( past_presence_penalty_buffer, batch_index,
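Taken together, the sampler patches above (12-14) turn `past_presence_penalty_buffer` from a boolean mask into an int32 count buffer so that frequency and presence penalties can share one retained state, while the repetition penalty keeps its boolean mask. Below is a minimal, self-contained sketch of the penalty math those hunks implement, on a toy batch and vocabulary in plain PyTorch; tensor names mirror the diff, but this is an illustration of the idea rather than the QEfficient kernel (the continuous-batching indexing via `CtxScatterFuncCB3D` and the spec-length repeat are omitted).

```python
import torch

batch_size, vocab_size = 2, 8
logits = torch.randn(batch_size, vocab_size)

# Retained states, as in the hunks above: a boolean mask of tokens seen so far
# and an int32 count of how many times each token has been generated.
past_repetition_penalty_buffer = torch.zeros(batch_size, vocab_size, dtype=torch.bool)
past_presence_penalty_buffer = torch.zeros(batch_size, vocab_size, dtype=torch.int32)

# Pretend token 3 was generated twice and token 5 once for every sequence.
past_repetition_penalty_buffer[:, [3, 5]] = True
past_presence_penalty_buffer[:, 3] = 2
past_presence_penalty_buffer[:, 5] = 1

repetition_penalties = torch.full((batch_size, 1), 1.9)
frequency_penalties = torch.full((batch_size, 1), 0.8)
presence_penalties = torch.full((batch_size, 1), 0.8)

# Repetition penalty: push already-seen tokens toward lower likelihood,
# dividing positive logits and multiplying negative ones.
repetition_penalties_mask = torch.where(
    past_repetition_penalty_buffer, repetition_penalties, torch.tensor(1.0)
)
logits = logits * repetition_penalties_mask ** (-torch.sign(logits))

# Frequency penalty: subtract in proportion to how often a token was generated.
logits = logits - frequency_penalties * past_presence_penalty_buffer

# Presence penalty: flat subtraction for any token generated at least once.
logits = logits - presence_penalties * (past_presence_penalty_buffer > 0)
```

Storing counts instead of booleans lets one buffer drive both penalties: the raw counts weight the frequency term and `buffer > 0` recovers the old presence mask, which is why the decode path in these patches gathers the current counts, adds one, and scatters them back.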