Add Support for Frequency Penalties in On Device Sampling #523

Draft pull request: wants to merge 17 commits into main.
198 changes: 171 additions & 27 deletions QEfficient/generation/text_generation_inference.py
@@ -10,7 +10,7 @@
from collections import deque
from dataclasses import dataclass
from time import perf_counter
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import transformers
@@ -322,6 +322,9 @@ def cloud_ai_100_exec_kv(
automation=False,
prompt_to_lora_id_mapping: Optional[List[int]] = None,
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
):
"""
This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
@@ -342,6 +345,15 @@ def cloud_ai_100_exec_kv(
:Write_io_dir (str): Path to write the input and output files. ``Defaults to None``.
:automation (bool): If true, it prints input, output, and performance stats. ``Defaults to False``.
:prompt_to_lora_id_mapping (List[int]): Mapping to associate prompts with their respective LoRA adapter.
:include_sampler (bool): Enable/Disable sampling of next tokens.
:return_pdfs (bool): Whether to return probability distributions along with the sampled
next tokens. For a Speculative Decoding Target Language Model, `return_pdfs` is always
True. Otherwise, `return_pdfs` is True for a Speculative Decoding Draft Language Model
and False for a regular model.
:sampling_params (Dict[str, Any]): A dictionary of sampling parameters supported by the QAIC backend.
The dictionary should contain the keys `repetition_penalties`, `frequency_penalties`, `presence_penalties`,
`temperatures`, `top_ks`, `top_ps`, `min_ps`, and `random_numbers`. Each value should be a numpy array
of shape (batch_size, 1).

Returns:
:CloudAI100ExecInfo: Object holding execution output and performance details.
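
For reference, a minimal usage sketch with on-device sampling enabled is shown below. It is an illustration under assumptions, not part of this diff: the tokenizer name and QPC path are placeholders, the non-sampler argument names follow the docstring above, and the exact dtypes the QAIC sampler expects (for example an integer type for `top_ks`) are assumed.

import numpy as np
from transformers import AutoTokenizer

from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

batch_size = 1  # must match the batch size the QPC was compiled for

# One row per sequence; every value is a numpy array of shape (batch_size, 1).
sampling_params = {
    "repetition_penalties": np.full((batch_size, 1), 1.2, dtype=np.float32),
    "frequency_penalties": np.full((batch_size, 1), 0.5, dtype=np.float32),
    "presence_penalties": np.full((batch_size, 1), 0.5, dtype=np.float32),
    "temperatures": np.full((batch_size, 1), 0.8, dtype=np.float32),
    "top_ks": np.full((batch_size, 1), 40, dtype=np.int32),      # integer dtype assumed
    "top_ps": np.full((batch_size, 1), 0.95, dtype=np.float32),
    "min_ps": np.full((batch_size, 1), 0.0, dtype=np.float32),
    "random_numbers": np.random.rand(batch_size, 1).astype(np.float32),  # per-sequence RNG draws
}

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder model
exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,
    qpc_path="/path/to/qpcs/programqpc.bin",  # placeholder path
    prompt=["My name is"],
    generation_len=32,
    include_sampler=True,   # the QPC must be compiled with the on-device sampler
    return_pdfs=False,
    sampling_params=sampling_params,
)
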
@@ -372,6 +384,9 @@ def cloud_ai_100_exec_kv(
write_io_dir=write_io_dir,
full_batch_size=full_batch_size,
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
sampling_params=sampling_params,
)
if full_batch_size is None:
exec_info = [
@@ -411,14 +426,60 @@ def __init__(
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
is_tlm: Optional[int] = None,
include_sampler: bool = False,
return_pdfs: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
) -> None:
self._ctx_len = ctx_len
self._write_io_dir = write_io_dir
self.is_tlm = is_tlm
self.include_sampler = include_sampler
self.return_pdfs = return_pdfs
self.sampling_params = sampling_params

# Load QPC
self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs)

# Validate sampler inputs for On-Device Sampling
sampler_inputs = [
"last_accepted_output_tokens",
"repetition_penalties",
"frequency_penalties",
"presence_penalties",
"temperatures",
"top_ks",
"top_ps",
"min_ps",
"random_numbers",
]
count = 0
for session_input_name in self._session.input_names:
if session_input_name in sampler_inputs:
count += 1
if count == len(sampler_inputs):
self.include_sampler = True
break
if count == 0:
self.include_sampler = False
elif count < len(sampler_inputs):
raise ValueError(
"The provided QPC does not have the required number of inputs to run sampling "
f"on the QAIC device (only {count}/{len(sampler_inputs)} inputs provided). Partial "
"sampling support is not available. Please check the QPC and try again."
)

if include_sampler and not self.include_sampler:
logger.warning_once(
"`include_sampler`=True was passed, but the provided QPC is not compiled "
"to run sampling on the QAIC device. Falling back to the PyTorch backend."
)
elif (include_sampler is None or not include_sampler) and self.include_sampler:
raise ValueError(
"The provided QPC is compiled to run sampling on the QAIC device, "
"but `include_sampler`=True was not passed. Please make sure the argument "
"is specified correctly."
)

# Fetch the variables from the QPC
self._vocab_size = self._fetch_vocab_size() # Fetch Vocab size
self.batch_size, self._prefill_seq_len = self._fetch_batch_size_prefill_seq_len()
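
The all-or-nothing check above can also be performed up front, before constructing the text-generation object. A small sketch follows; the import path and constructor defaults for `QAICInferenceSession` are assumptions, mirroring how the session is created in this diff.

SAMPLER_INPUTS = {
    "last_accepted_output_tokens",
    "repetition_penalties",
    "frequency_penalties",
    "presence_penalties",
    "temperatures",
    "top_ks",
    "top_ps",
    "min_ps",
    "random_numbers",
}


def qpc_exposes_on_device_sampler(qpc_path: str) -> bool:
    """Return True only if the QPC exposes the complete set of sampler inputs."""
    # Import path is an assumption; use wherever QAICInferenceSession lives in this repo.
    from QEfficient.generation.cloud_infra import QAICInferenceSession

    session = QAICInferenceSession(qpc_path)
    present = SAMPLER_INPUTS.intersection(session.input_names)
    if present and present != SAMPLER_INPUTS:
        raise ValueError(
            f"QPC exposes only {len(present)}/{len(SAMPLER_INPUTS)} sampler inputs; "
            "partial on-device sampling is not supported."
        )
    return bool(present)
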
@@ -523,10 +584,17 @@ def _fetch_vocab_size(
Returns:
vocab_size: The vocabulary size fetched from the session's allowed shapes.
"""
key = (
"probs"
if self.include_sampler and self.return_pdfs
else "next_tokens"
if self.include_sampler
else "logits"
)
if self._session.allowed_shapes:
return [x[self._session.binding_index_map["logits"]] for x in self._session.allowed_shapes][0][1][2]
return [x[self._session.binding_index_map[key]] for x in self._session.allowed_shapes][0][1][2]

return self._session.bindings[self._session.binding_index_map["logits"]].dims[2]
return self._session.bindings[self._session.binding_index_map[key]].dims[2]

def _fetch_generation_len(self, generation_len, max_gen_len):
"""
@@ -574,6 +642,22 @@ def prepare_decode_inputs(self):
decode_inputs["position_ids"] = self.decode_pos_ids
if self.batch_index is not None:
decode_inputs["batch_index"] = self.batch_index
if self.include_sampler:
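# Extra inputs expected by the on-device sampler: the tokens accepted in the previous
# step plus the per-sequence sampling tensors. Under continuous batching, rows of the
# (full_batch_size, 1) arrays are gathered by batch index so each sequence keeps its own settings.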
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
for op in [
"repetition_penalties",
"frequency_penalties",
"presence_penalties",
"temperatures",
"top_ks",
"top_ps",
"min_ps",
"random_numbers",
]:
if self.batch_index is not None:
decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()]
else:
decode_inputs[op] = self.sampling_params[op]

if self._prompt_to_lora_id_mapping_decode:
if self.full_batch_size:
@@ -589,21 +673,24 @@ def prepare_decode_inputs(self):

def _fetch_next_token_id(self, outputs):
"""
Fetches the next token ID from the model's output logits.
The method identifies the token with the highest probability using argmax along the last dimension.
Fetches the next token ID from the model's output.

Args:
outputs (dict): A dictionary containing the model's output logits. The key "logits" should map to a numpy array of shape (batch_size, sequence_length, vocab_size) or (batch_size, vocab_size).
outputs (dict): A dictionary containing the model's output.

Returns:
numpy.ndarray: An array of the next token IDs for each sequence in the batch.
"""
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)

# Get output token
next_token_id = logits.argmax(2)
return next_token_id
if self.include_sampler:
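# With on-device sampling the session returns "probs" of shape
# (batch_size, seq_len, vocab_size) and/or "next_tokens" of shape
# (batch_size, seq_len, 1) instead of full "logits".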
if self.return_pdfs:
return outputs["probs"].argmax(2)
else:
return outputs["next_tokens"].reshape(outputs["next_tokens"].shape[0], outputs["next_tokens"].shape[1])
else:
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)
return logits.argmax(2)

def initialize_decode_inputs(self, num_prompts, execution_batch_size, max_gen_length):
"""
@@ -673,6 +760,23 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len):

_ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id)

def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1):
"""
Sets the sizes of the output buffers.

Args:
batch_size (int): The batch size.
sequence_length (int): The sequence length of the output.
"""
if self.include_sampler:
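# With the sampler fused into the QPC, the session emits "next_tokens"
# (and "probs" when return_pdfs is True) instead of a full "logits" tensor.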
if self.return_pdfs:
probs_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32)
self._session.set_buffers({"probs": probs_out_placeholder})
next_tokens_out_placeholder = np.zeros((batch_size, sequence_length, 1), dtype=np.int64)
self._session.set_buffers({"next_tokens": next_tokens_out_placeholder})
else:
logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32)
self._session.set_buffers({"logits": logits_out_placeholder})

def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
"""
Runs prefill for a given prompt and generation length.
@@ -702,9 +806,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
max_gen_len = self._ctx_len - position_ids.max()
generation_len = self._fetch_generation_len(generation_len, max_gen_len)

# Set the prefill logic buffer
logits_out_placeholder = np.zeros((prefill_logit_bs, 1, self._vocab_size), dtype=np.float32)
self._session.set_buffers({"logits": logits_out_placeholder})
# Set the prefill output buffers
self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1)

inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
@@ -714,6 +817,22 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
inputs["batch_index"] = decode_batch_id
if self.is_tlm:
inputs["num_logits_to_keep"] = np.zeros((1, 1))
if self.include_sampler:
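# Prefill takes the same sampler inputs as decode; under continuous batching,
# rows are gathered by decode_batch_id.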
inputs["last_accepted_output_tokens"] = inputs["input_ids"]
for op in [
"repetition_penalties",
"frequency_penalties",
"presence_penalties",
"temperatures",
"top_ks",
"top_ps",
"min_ps",
"random_numbers",
]:
if decode_batch_id is not None:
inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
else:
inputs[op] = self.sampling_params[op]

if self._prompt_to_lora_id_mapping_prefill:
if self.full_batch_size:
@@ -732,6 +851,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
chunk_inputs["position_ids"] = inputs["position_ids"][
:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
]
if self.include_sampler:
chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]
outputs = self._session.run(chunk_inputs)
if self._write_io_dir is not None:
write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False)
@@ -753,11 +874,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):

"""

# Set logits placeholder for decode
logits_out_placeholder = np.zeros(
(self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32
# Set output placeholders for decode
self._set_output_buffers(
batch_size=self.full_batch_size,
sequence_length=self._decode_seq_len,
)
self._session.set_buffers({"logits": logits_out_placeholder})

# Generate flag for tracking progress for each batch ID
current_decode_ongoing = np.full((self.full_batch_size, 1), True)

@@ -775,10 +897,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
outputs = self._session.run(decode_inputs)

# Prepare inputs for next iteration
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)
next_token_id = logits.argmax(2)
next_token_id = self._fetch_next_token_id(outputs)

for decode_batch_id in range(self.full_batch_size):
if (
@@ -800,7 +919,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
self.generated_ids[batch_id_map[decode_batch_id], 0] = new_token_id.squeeze(1)
generated_id_current_index[decode_batch_id] = 1

self._session.set_buffers({"logits": logits_out_placeholder})
self._set_output_buffers(
batch_size=self.full_batch_size,
sequence_length=self._decode_seq_len,
)
decode_pause_time += perf_counter() - start

if self._prompt_to_lora_id_mapping_decode:
@@ -817,6 +939,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = (
next_token_id[decode_batch_id, -1]
)
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

generated_id_current_index[decode_batch_id] += 1

@@ -840,6 +964,11 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
(self.batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32
)
self._session.set_buffers({"logits": logits_out_placeholder})
else:
self._set_output_buffers(
batch_size=self.batch_size,
sequence_length=self._decode_seq_len,
)
finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
num_token = 0
for num_token in range(1, generation_len):
@@ -852,10 +981,12 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
self._write_io_dir = None

# Prepare inputs for next iteration
decode_inputs["input_ids"] = outputs["logits"].argmax(2)
decode_inputs["input_ids"] = self._fetch_next_token_id(outputs)
decode_inputs["position_ids"][:, -1] += 1
self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

if finished_sequences.all():
break
@@ -905,9 +1036,22 @@ def __init__(
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
) -> None:
self._qaic_model = QEffTextGenerationBase(
tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm
tokenizer=tokenizer,
qpc_path=qpc_path,
full_batch_size=full_batch_size,
ctx_len=ctx_len,
device_id=device_id,
enable_debug_logs=enable_debug_logs,
write_io_dir=write_io_dir,
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
sampling_params=sampling_params,
)
self._full_batch_size = self._qaic_model.full_batch_size
self._tokenizer = self._qaic_model.tokenizer
8 changes: 7 additions & 1 deletion QEfficient/transformers/models/modeling_auto.py
@@ -1652,13 +1652,18 @@ def get_sampling_inputs_and_outputs(
dynamic_axes["repetition_penalties"] = {0: "batch_size"}

example_inputs["past_presence_penalty_buffer"] = torch.zeros(
(fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool
(fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.int32
)
dynamic_axes["past_presence_penalty_buffer"] = {
0: "full_batch_size" if self.continuous_batching else "batch_size",
}
output_names.append("past_presence_penalty_buffer_RetainedState")

example_inputs["frequency_penalties"] = (
torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_FREQUENCY_PENALTIES
)
dynamic_axes["frequency_penalties"] = {0: "batch_size"}

example_inputs["presence_penalties"] = (
torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
)
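
The switch of `past_presence_penalty_buffer` to `torch.int32` lets the retained buffer hold per-token occurrence counts rather than a boolean mask, which is what a frequency penalty needs. The on-device kernel itself is not shown in this diff; the sketch below only illustrates the conventional way frequency and presence penalties adjust logits (the OpenAI/vLLM-style formulation) and should be read as an assumption about the intended semantics, not as the exact QAIC implementation.

import torch


def apply_output_penalties(
    logits: torch.Tensor,               # (batch_size, vocab_size)
    output_counts: torch.Tensor,        # (batch_size, vocab_size) int32 counts of generated tokens
    frequency_penalties: torch.Tensor,  # (batch_size, 1)
    presence_penalties: torch.Tensor,   # (batch_size, 1)
) -> torch.Tensor:
    # Frequency penalty scales with how often a token has already been generated;
    # presence penalty applies a flat offset to any token generated at least once.
    logits = logits - frequency_penalties * output_counts
    logits = logits - presence_penalties * (output_counts > 0)
    return logits
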
@@ -1893,6 +1898,7 @@ def generate(
device_id=device_id,
generation_len=generation_len,
is_tlm=self.is_tlm,
**kwargs,
)
else:
raise NotImplementedError("Only AI_100 runtime is supported right now via generate API")
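
With the `**kwargs` pass-through above, the sampler options can be supplied directly to the high-level `generate` call. A minimal sketch, assuming a compiled model object and the same `sampling_params` dictionary shown earlier; argument names other than the three sampler kwargs are assumptions.

# Hypothetical usage; `qeff_model` is a compiled QEFFAutoModelForCausalLM and
# `sampling_params` is the dict of (batch_size, 1) numpy arrays described above.
exec_info = qeff_model.generate(
    tokenizer=tokenizer,
    prompts=["My name is"],
    generation_len=32,
    include_sampler=True,      # forwarded via **kwargs to cloud_ai_100_exec_kv
    return_pdfs=False,
    sampling_params=sampling_params,
)
print(exec_info)  # CloudAI100ExecInfo with generated output and performance stats
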