diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py index 1c620ad7d..a4d2c9207 100644 --- a/QEfficient/cloud/infer.py +++ b/QEfficient/cloud/infer.py @@ -102,6 +102,7 @@ def main( full_batch_size: Optional[int] = None, prompt_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths: Optional[List[int]] = None, generation_len: Optional[int] = None, mxfp6: bool = False, mxint8: bool = False, @@ -165,6 +166,7 @@ def main( cache_dir=cache_dir, hf_token=hf_token, full_batch_size=full_batch_size, + comp_ctx_lengths=comp_ctx_lengths, local_model_dir=local_model_dir, trust_remote_code=trust_remote_code, ) @@ -260,6 +262,12 @@ def main( "--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation." ) parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.") + parser.add_argument( + "--comp_ctx_lengths", + "--comp_ctx_lengths", + type=lambda comp_ctx_lengths: [int(x) for x in comp_ctx_lengths.strip("[]").split(",")], + help="Compute Context length for text generation (comma-separated) e.g. [512,1024,2048] ", + ) parser.add_argument( "--mxfp6", "--mxfp6_matmul", diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index c4f5a7bbd..269ccb0be 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) -def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: - ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0])) +def CtxGather( + data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32 +) -> onnxscript.FLOAT: + # Create a shape tensor based on comp_ctx_len + shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0) + + # Directly use the shape tensor without validation + ctx_indices = ops.Expand(ctx_indices, shape_tensor) ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) return ops.GatherND(data, ctx_indices, batch_dims=2) @@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function): """ @staticmethod - def forward(data: torch.Tensor, ctx_indices: torch.Tensor): + def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) head_indices = torch.arange(data.shape[1]).view(1, -1, 1) return data[batch_indices, head_indices, ctx_indices] @@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs): pass @staticmethod - def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: - return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data) + def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value: + return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data) diff --git a/QEfficient/customop/ctx_scatter_gather_cb.py b/QEfficient/customop/ctx_scatter_gather_cb.py index 75d9a12ef..cc9693716 100644 --- a/QEfficient/customop/ctx_scatter_gather_cb.py +++ b/QEfficient/customop/ctx_scatter_gather_cb.py @@ -97,16 +97,20 @@ def symbolic( @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGatherCB( - data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32 + data: onnxscript.FLOAT, 
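# --- Illustrative sketch (not part of this patch): eager-mode equivalent of the CCL-limited gather ---
# The helper below is a hypothetical, standalone PyTorch rendering of what the CtxGather /
# CtxGatherCB changes above amount to once a compute-context-length (comp_ctx_len) is threaded
# through: only the first comp_ctx_len cache positions are gathered, and positions beyond the
# current position_ids limit are redirected to index 0 so they can be masked out afterwards.
# Tensor and function names here are illustrative, not taken from the repository.
import torch

def ccl_gather_sketch(cache: torch.Tensor, position_ids: torch.Tensor, comp_ctx_len: int):
    # cache: [batch, heads, ctx_len, head_dim]; position_ids: [batch, seq_len]
    ctx_indices = torch.arange(cache.shape[2])[None, None, :comp_ctx_len]
    gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
    invalid_mask = ctx_indices > gather_limit
    ctx_indices = torch.where(invalid_mask, torch.zeros_like(ctx_indices), ctx_indices)
    batch_idx = torch.arange(cache.shape[0]).view(-1, 1, 1)
    head_idx = torch.arange(cache.shape[1]).view(1, -1, 1)
    # Result keeps comp_ctx_len (not ctx_len) along the sequence axis: [batch, heads, comp_ctx_len, head_dim]
    return cache[batch_idx, head_idx, ctx_indices], invalid_mask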
batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32 ) -> onnxscript.FLOAT: batch_size = ops.Gather(ops.Shape(batch_index), [0]) num_heads = ops.Gather(ops.Shape(data), [1]) - ctx_len = ops.Gather(ops.Shape(data), [2]) + # using compute-context-length (CCL) instead of context-length to do gather process based on CCL and later do attention computations based on CCL as well. + ctx_len = ops.Reshape(comp_ctx_len, [1]) # Expanded shape to create indices zero = ops.Constant(value_ints=[0]) one = ops.Constant(value_ints=[1]) - exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0) + # exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0) + exp_shape = ops.Concat( + ops.Reshape(batch_size, [1]), ops.Reshape(num_heads, [1]), ops.Reshape(ctx_len, [1]), one, axis=0 + ) # Create indices batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape) @@ -119,7 +123,7 @@ def CtxGatherCB( class CtxGatherFuncCB(torch.autograd.Function): @staticmethod - def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor): + def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): batch_indices = batch_index.view(-1, 1, 1) head_indices = torch.arange(data.shape[1]).view(1, -1, 1) return data[batch_indices, head_indices, ctx_indices] @@ -129,8 +133,10 @@ def setup_context(ctx, inputs, outputs): pass @staticmethod - def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value: - return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data) + def symbolic( + g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int + ) -> torch.Value: + return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data) @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index a9690aa51..d8420afdc 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -316,6 +316,7 @@ def cloud_ai_100_exec_kv( prompts_txt_file_path: Optional[str] = None, device_id: Optional[List[int]] = None, generation_len: Optional[int] = None, + comp_ctx_lengths: Optional[List[int]] = None, enable_debug_logs: bool = False, stream: bool = True, write_io_dir: Optional[str] = None, @@ -368,6 +369,7 @@ def cloud_ai_100_exec_kv( qpc_path=qpc_path, device_id=device_id, ctx_len=ctx_len, + comp_ctx_lengths=comp_ctx_lengths, enable_debug_logs=enable_debug_logs, write_io_dir=write_io_dir, full_batch_size=full_batch_size, @@ -407,12 +409,14 @@ def __init__( qpc_path: str, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, + comp_ctx_lengths: Optional[List[int]] = None, device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, is_tlm: Optional[int] = None, ) -> None: self._ctx_len = ctx_len + self.comp_ctx_lengths = comp_ctx_lengths self._write_io_dir = write_io_dir self.is_tlm = is_tlm @@ -724,6 +728,11 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)] inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) + if self.comp_ctx_lengths is not None: + 
inputs["comp_ctx_lengths"] = np.random.rand(self.comp_ctx_lengths[0]) + buffers = {"comp_ctx_len_out": np.zeros(1)} + self._session.set_buffers(buffers) + for i in range(num_chunks): chunk_inputs = inputs.copy() chunk_inputs["input_ids"] = inputs["input_ids"][ @@ -741,6 +750,18 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i generation_len, ) + def initialize_ccl(self, decode_inputs): + max_ccl_id = len(self.comp_ctx_lengths) - 1 + max_position_id = np.max(decode_inputs["position_ids"]) + ccl_id = 1 + for i in range(1, len(self.comp_ctx_lengths)): + if max_position_id < self.comp_ctx_lengths[i]: + ccl_id = i + break + buffers = {"comp_ctx_len_out": np.zeros(1)} + + return buffers, ccl_id, max_ccl_id + def run_continuous_batching_decode(self, prompt_queue, generation_len): """ Runs continuous batching decode for the given prompt queue and generation length. @@ -771,6 +792,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): # Prepare decode inputs inputs. decode_inputs = self.prepare_decode_inputs() + if self.comp_ctx_lengths is not None: + list_of_comp_ctx_lengths = [np.zeros(length) for length in self.comp_ctx_lengths] + buffers, ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs) + decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id] + self._session.set_buffers(buffers) + while prompt_queue or current_decode_ongoing.any(): outputs = self._session.run(decode_inputs) @@ -808,6 +835,19 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): batch_id_map[decode_batch_id] ] + if self.comp_ctx_lengths is not None: + ###Recalculate ccl_id based on position ids### + # Determine the maximum value of position_ids across all batch elements + max_position_id = np.max(decode_inputs["position_ids"]) + + # Update ccl_id and comp_ctx_lengths based on the maximum position id + ccl_id = 1 + for i in range(1, len(self.comp_ctx_lengths)): + if max_position_id < self.comp_ctx_lengths[i]: + ccl_id = i + break + decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id] + else: current_decode_ongoing[decode_batch_id] = False else: @@ -818,6 +858,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): next_token_id[decode_batch_id, -1] ) + if self.comp_ctx_lengths is not None: + # Update ccl_id and comp_ctx_lengths based on the maximum position id + if decode_inputs["position_ids"][decode_batch_id, -1] >= self.comp_ctx_lengths[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id] + generated_id_current_index[decode_batch_id] += 1 return decode_pause_time @@ -842,7 +888,21 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform self._session.set_buffers({"logits": logits_out_placeholder}) finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id num_token = 0 + + if self.comp_ctx_lengths is not None: + list_of_comp_ctx_lengths = [np.zeros(length) for length in self.comp_ctx_lengths] + buffers, ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs) + decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id] + self._session.set_buffers(buffers) + + cache_index = np.max(decode_inputs["position_ids"]) for num_token in range(1, generation_len): + if self.comp_ctx_lengths is not None: + if cache_index >= self.comp_ctx_lengths[ccl_id] - 1: + # if cache_index >= self.comp_ctx_lengths[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + 
decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id] + if streamer: streamer.put(decode_inputs["input_ids"][0]) outputs = self._session.run(decode_inputs) @@ -854,6 +914,7 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform # Prepare inputs for next iteration decode_inputs["input_ids"] = outputs["logits"].argmax(2) decode_inputs["position_ids"][:, -1] += 1 + cache_index += 1 self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1] finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id @@ -901,17 +962,27 @@ def __init__( qpc_path: str, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, + comp_ctx_lengths: Optional[List[int]] = None, device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, is_tlm: bool = False, ) -> None: self._qaic_model = QEffTextGenerationBase( - tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm + tokenizer, + qpc_path, + full_batch_size, + ctx_len, + comp_ctx_lengths, + device_id, + enable_debug_logs, + write_io_dir, + is_tlm, ) self._full_batch_size = self._qaic_model.full_batch_size self._tokenizer = self._qaic_model.tokenizer self._ctx_len = ctx_len + self.comp_ctx_lengths = comp_ctx_lengths self._perf_metrics = None self._prompt_queue = None self._text_streamer = None diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 16767fbe2..e159ec69d 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -91,6 +91,8 @@ def read_only(self, layer_idx, cache_kwargs): k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) + comp_ctx_len = cache_kwargs.get("CCL") + ctx_len = k_out.shape[2] ctx_indices = torch.arange(ctx_len)[None, None, ...] 
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) @@ -101,15 +103,19 @@ def read_only(self, layer_idx, cache_kwargs): else: invalid_idx_value = 0 + ctx_indices = ctx_indices[:, :, :comp_ctx_len] + invalid_mask = ctx_indices > gather_limit + + invalid_mask = invalid_mask[:, :, :comp_ctx_len] + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + k_out = CtxGatherFuncCB.apply(self.key_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len) + v_out = CtxGatherFuncCB.apply(self.value_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) - + k_out = CtxGatherFunc.apply(self.key_cache[layer_idx], ctx_indices, comp_ctx_len) + v_out = CtxGatherFunc.apply(self.value_cache[layer_idx], ctx_indices, comp_ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out @@ -144,6 +150,7 @@ def update( else: position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs + comp_ctx_len = cache_kwargs.get("CCL") # Scatter if batch_index is not None: @@ -163,26 +170,29 @@ def update( self.value_cache[layer_idx], position_ids, value_states ) - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] - # Gather - ctx_len = k_out.shape[2] + ctx_len = self.key_cache[layer_idx].shape[2] ctx_indices = torch.arange(ctx_len)[None, None, ...] gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit if torch.onnx.is_in_onnx_export(): invalid_idx_value = torch.iinfo(torch.int32).max else: invalid_idx_value = 0 + ctx_indices = ctx_indices[:, :, :comp_ctx_len] + invalid_mask = ctx_indices > gather_limit + + invalid_mask = invalid_mask[:, :, :comp_ctx_len] + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + k_out = CtxGatherFuncCB.apply(self.key_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len) + v_out = CtxGatherFuncCB.apply(self.value_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) + k_out = CtxGatherFunc.apply(self.key_cache[layer_idx], ctx_indices, comp_ctx_len) + v_out = CtxGatherFunc.apply(self.value_cache[layer_idx], ctx_indices, comp_ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index bd5e85d84..031e443a7 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union import torch @@ -29,6 +30,16 @@ from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +@dataclass +class 
QEffBaseModelOutputWithPast(BaseModelOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + +@dataclass +class QEffCausalLMOutputWithPast(CausalLMOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + class QEffGemmaRotaryEmbedding(GemmaRotaryEmbedding): """ Copied from GemmaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py @@ -135,6 +146,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -153,8 +165,17 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -187,6 +208,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -219,6 +241,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -257,6 +280,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -318,6 +342,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -338,11 +363,13 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() - output = BaseModelOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + output = QEffBaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, + comp_ctx_len_out=comp_ctx_len_out, ) return output if return_dict else output.to_tuple() @@ -360,6 +387,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, 
batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -383,6 +411,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -400,10 +429,12 @@ def forward( logits = self.lm_head(hidden_states).float() logits = logits.float() - return CausalLMOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + return QEffCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + comp_ctx_len_out=comp_ctx_len_out, ) diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index fa0b3cc49..99d0e17da 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union import torch @@ -32,6 +33,16 @@ from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +@dataclass +class QEffBaseModelOutputWithPast(BaseModelOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + +@dataclass +class QEffCausalLMOutputWithPast(CausalLMOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + class QEffGemma2RotaryEmbedding(Gemma2RotaryEmbedding): """ Copied from Gemma2RotaryEmbedding: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py @@ -141,6 +152,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -159,8 +171,17 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -193,6 +214,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -225,6 +247,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, 
batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -265,6 +288,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -337,6 +361,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -358,11 +383,13 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() - output = BaseModelOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + output = QEffBaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, + comp_ctx_len_out=comp_ctx_len_out, ) return output if return_dict else output.to_tuple() @@ -380,6 +407,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -403,6 +431,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -423,10 +452,12 @@ def forward( logits = torch.tanh(logits) logits = logits * self.config.final_logit_softcapping - return CausalLMOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + return QEffCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + comp_ctx_len_out=comp_ctx_len_out, ) diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index af4ebfc92..36bc2145c 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union import torch @@ -28,6 +29,16 @@ from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +@dataclass +class QEffBaseModelOutputWithPast(BaseModelOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + +@dataclass +class QEffCausalLMOutputWithPast(CausalLMOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + class QEffGraniteRotaryEmbedding(GraniteRotaryEmbedding): """ Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -126,6 +137,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = 
None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -144,8 +156,16 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -170,6 +190,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -225,6 +246,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -246,11 +268,13 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() if use_cache else None - output = BaseModelOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + output = QEffBaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, + comp_ctx_len_out=comp_ctx_len_out, ) return output if return_dict else output.to_tuple() @@ -266,6 +290,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -318,6 +343,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -335,10 +361,12 @@ def forward( logits = self.lm_head(hidden_states) logits = logits.float() - return CausalLMOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + return QEffCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + comp_ctx_len_out=comp_ctx_len_out, ) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 0cccd7fcf..95872fa6c 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -5,6 +5,7 @@ # # 
----------------------------------------------------------------------------- +from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union import torch @@ -29,6 +30,16 @@ from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +@dataclass +class QEffBaseModelOutputWithPast(BaseModelOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + +@dataclass +class QEffCausalLMOutputWithPast(CausalLMOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding): """ Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -130,6 +141,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -154,8 +166,16 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -188,6 +208,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -204,6 +225,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -241,6 +263,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -294,6 +317,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -315,11 +339,13 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() - output = BaseModelOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + output = QEffBaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, + comp_ctx_len_out=comp_ctx_len_out, ) return output if return_dict else 
output.to_tuple() @@ -337,6 +363,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -360,6 +387,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -377,10 +405,12 @@ def forward( logits = self.lm_head(hidden_states) logits = logits.float() - return CausalLMOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + return QEffCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + comp_ctx_len_out=comp_ctx_len_out, ) diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 59c19baa2..eb06f599b 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -7,6 +7,7 @@ """PyTorch Mistral model.""" +from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union import torch @@ -33,6 +34,16 @@ from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +@dataclass +class QEffBaseModelOutputWithPast(BaseModelOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + +@dataclass +class QEffCausalLMOutputWithPast(CausalLMOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + class QEffMistralRotaryEmbedding(MistralRotaryEmbedding): """ Copied from MistralForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -139,6 +150,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -163,8 +175,16 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -197,6 +217,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, 
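# --- Illustrative sketch (not part of this patch): attention over a CCL-limited KV window ---
# A hypothetical, minimal rendering of what the per-model attention changes amount to at decode
# time: the causal mask is sliced to the active compute-context-length and attention scores are
# computed against only the first comp_ctx_len gathered key/value positions instead of the full
# ctx_len. It assumes a boolean mask where True marks disallowed positions; names are illustrative.
import torch

def ccl_attention_sketch(query, key_ccl, value_ccl, causal_mask, comp_ctx_len):
    # query: [batch, heads, q_len, head_dim]; key_ccl / value_ccl: [batch, heads, comp_ctx_len, head_dim]
    # causal_mask: [batch, 1, q_len, ctx_len], sliced here to the CCL window
    mask = causal_mask[:, :, :, :comp_ctx_len]
    scores = torch.matmul(query, key_ccl.transpose(2, 3)) / (query.shape[-1] ** 0.5)
    scores = torch.where(mask, torch.tensor(-10000.0, dtype=scores.dtype), scores)
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, value_ccl)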
@@ -227,6 +248,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -265,6 +287,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -329,6 +352,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -350,11 +374,13 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() - output = BaseModelOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + output = QEffBaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, + comp_ctx_len_out=comp_ctx_len_out, ) return output if return_dict else output.to_tuple() @@ -372,6 +398,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -395,6 +422,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -412,10 +440,12 @@ def forward( logits = self.lm_head(hidden_states) logits = logits.float() - return CausalLMOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + return QEffCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + comp_ctx_len_out=comp_ctx_len_out, ) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 2f3ee3dc0..a6d65e23d 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1427,6 +1427,8 @@ def __init__( if self.is_tlm: self.model.qaic_config["return_pdfs"] = True + self.comp_ctx_lengths = kwargs.pop("comp_ctx_lengths", None) + @property def model_name(self) -> str: mname = self.model.__class__.__name__ @@ -1497,6 +1499,8 @@ def from_pretrained( kv_offload = kwargs.pop("kv_offload", None) + comp_ctx_lengths = kwargs.pop("comp_ctx_lengths", None) + kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) if qaic_config is not None: @@ -1513,6 +1517,7 @@ def from_pretrained( continuous_batching=continuous_batching, qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, + comp_ctx_lengths=comp_ctx_lengths, **kwargs, ) @@ -1558,6 +1563,10 @@ def 
export(self, export_dir: Optional[str] = None) -> str: "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, } + if self.comp_ctx_lengths is not None: + example_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d pkv_dynamic_axes = { 0: "full_batch_size" if self.continuous_batching else "batch_size", @@ -1688,6 +1697,7 @@ def build_prefill_specialization( self, prefill_seq_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths: Optional[int] = None, batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, @@ -1698,6 +1708,9 @@ def build_prefill_specialization( "ctx_len": ctx_len, "num_logits_to_keep": 1 if self.is_tlm else None, } + if comp_ctx_lengths is not None: + spec["comp_ctx_lengths"] = comp_ctx_lengths + if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size else: @@ -1710,6 +1723,7 @@ def build_decode_specialization( self, prefill_seq_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths: Optional[int] = None, batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, @@ -1723,6 +1737,8 @@ def build_decode_specialization( "ctx_len": ctx_len, "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None, } + if comp_ctx_lengths is not None: + spec["comp_ctx_lengths"] = comp_ctx_lengths if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size @@ -1811,26 +1827,45 @@ def compile( # --- Specializations --- specializations = [] if prefill_only is None or prefill_only or prefill_seq_len == 1: + ctx_for_specialization = self.comp_ctx_lengths[0] if self.comp_ctx_lengths is not None else None + specializations.append( self.build_prefill_specialization( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, + comp_ctx_lengths=ctx_for_specialization, batch_size=batch_size, kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, ) ) if prefill_only is None or not prefill_only: - decode_spec = self.build_decode_specialization( - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - batch_size=batch_size, - kv_cache_batch_size=kv_cache_batch_size, - full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens, - ) - if decode_spec: - specializations.append(decode_spec) + if self.comp_ctx_lengths is not None: + # Adding elements from self.comp_ctx_lengths to decode_specialization + for i in range(1, len(self.comp_ctx_lengths)): + decode_spec = self.build_decode_specialization( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths=self.comp_ctx_lengths[i], + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, + ) + if decode_spec: + specializations.append(decode_spec) + + else: + decode_spec = self.build_decode_specialization( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, + ) + if decode_spec: + specializations.append(decode_spec) # --- Compilation --- kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" @@ -1890,6 +1925,7 @@ def generate( tokenizer, self.qpc_path, prompt=prompts, + comp_ctx_lengths=self.comp_ctx_lengths, device_id=device_id, 
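# --- Hypothetical illustration (not emitted by this patch): specializations built for CCL compilation ---
# For comp_ctx_lengths=[256, 512, 1024] with ctx_len=1024 and prefill_seq_len=128, the compile logic
# above produces one prefill specialization pinned to the first CCL and one decode specialization per
# remaining CCL value. Key names are abbreviated and batch / full-batch fields are omitted here.
specializations_sketch = [
    {"seq_len": 128, "ctx_len": 1024, "comp_ctx_lengths": 256},  # prefill, smallest CCL
    {"seq_len": 1, "ctx_len": 1024, "comp_ctx_lengths": 512},    # decode, CCL bucket 1
    {"seq_len": 1, "ctx_len": 1024, "comp_ctx_lengths": 1024},   # decode, CCL bucket 2
]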
generation_len=generation_len, is_tlm=self.is_tlm, diff --git a/QEfficient/transformers/models/phi/modeling_phi.py b/QEfficient/transformers/models/phi/modeling_phi.py index e08dfa528..57fc27953 100644 --- a/QEfficient/transformers/models/phi/modeling_phi.py +++ b/QEfficient/transformers/models/phi/modeling_phi.py @@ -7,6 +7,7 @@ """PyTorch Phi model.""" +from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union import torch @@ -26,6 +27,16 @@ from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +@dataclass +class QEffBaseModelOutputWithPast(BaseModelOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + +@dataclass +class QEffCausalLMOutputWithPast(CausalLMOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -64,6 +75,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -101,8 +113,16 @@ def forward( key_states = torch.cat((key_rot, key_pass), dim=-1) if past_key_value is not None: - # Update the cache_kwargs with position_ids for Cloud AI 100 - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -137,6 +157,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, @@ -178,6 +199,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -210,6 +232,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -271,6 +294,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -290,11 +314,14 @@ def forward( all_hidden_states += (hidden_states,) if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() if use_cache else None - output = BaseModelOutputWithPast( + + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is 
not None else None + output = QEffBaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, + comp_ctx_len_out=comp_ctx_len_out, ) return output if return_dict else output.to_tuple() @@ -313,6 +340,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -367,6 +395,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, @@ -381,10 +410,12 @@ def forward( logits = self.lm_head(hidden_states) logits = logits.float() - return CausalLMOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + return QEffCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + comp_ctx_len_out=comp_ctx_len_out, ) diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index 3a54a1e83..de888aad6 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -7,6 +7,7 @@ """PyTorch Phi-3 model.""" +from dataclasses import dataclass from typing import Callable, Optional, Tuple, Union import torch @@ -29,6 +30,16 @@ from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +@dataclass +class QEffBaseModelOutputWithPast(BaseModelOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + +@dataclass +class QEffCausalLMOutputWithPast(CausalLMOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + class QEffPhi3RotaryEmbedding(Phi3RotaryEmbedding): """ Copied from Phi3ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -137,6 +148,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, position_ids=Optional[torch.Tensor], past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -159,6 +171,8 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin, @@ -166,6 +180,7 @@ def forward( "cache_position": cache_position, "batch_index": batch_index, "position_ids": position_ids, + "CCL": attention_mask.shape[-1], } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -199,6 +214,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = 
None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -240,6 +256,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -274,6 +291,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -329,6 +347,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -349,11 +368,13 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() if use_cache else None - output = BaseModelOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + output = QEffBaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, + comp_ctx_len_out=comp_ctx_len_out, ) return output if return_dict else output.to_tuple() @@ -373,6 +394,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -420,6 +442,7 @@ def forward( batch_index=batch_index, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, @@ -435,10 +458,12 @@ def forward( logits = self.lm_head(hidden_states) logits = logits.float() - return CausalLMOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + return QEffCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + comp_ctx_len_out=comp_ctx_len_out, ) diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 67c71b32c..6ee4d71c0 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -7,6 +7,7 @@ """PyTorch Qwen2 model.""" +from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union import torch @@ -32,6 +33,16 @@ from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +@dataclass +class QEffBaseModelOutputWithPast(BaseModelOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + +@dataclass +class QEffCausalLMOutputWithPast(CausalLMOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + # Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but 
keeping it following transformers ideology class QEffQwen2RotaryEmbedding(Qwen2RotaryEmbedding): """ @@ -149,6 +160,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -167,8 +179,16 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -202,6 +222,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -237,6 +258,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -276,6 +298,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -334,6 +357,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -354,11 +378,13 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() - output = BaseModelOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + output = QEffBaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, + comp_ctx_len_out=comp_ctx_len_out, ) return output if return_dict else output.to_tuple() @@ -377,6 +403,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -400,6 +427,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, 
            use_cache=use_cache,
@@ -415,10 +443,12 @@ def forward(
         logits = self.lm_head(hidden_states)
         logits = logits.float()
 
-        return CausalLMOutputWithPast(
+        comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None
+        return QEffCausalLMOutputWithPast(
             loss=None,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            comp_ctx_len_out=comp_ctx_len_out,
         )
diff --git a/examples/compute_context_length.py b/examples/compute_context_length.py
new file mode 100644
index 000000000..3711591d2
--- /dev/null
+++ b/examples/compute_context_length.py
@@ -0,0 +1,47 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+## In this example, you can run a model with static or continuous batching using different Compute-Context-Length (CCL) inputs. ##
+
+from transformers import AutoTokenizer
+
+from QEfficient import QEFFAutoModelForCausalLM
+
+## The optional comp_ctx_lengths variable accepts a list of compute context lengths. If comp_ctx_lengths=None, the model runs with the default context length. ##
+## - The first value in the list is the compute context length used during prefill. ##
+## - During decode, the compute context length is chosen from the list based on the position id / cache index: generation starts from the smallest value that covers the input prompt length and switches to the next value whenever the cache index passes the current compute context length. ##
+comp_ctx_lengths = [256, 512, 1024]  # None
+
+model_name = "meta-llama/Llama-3.2-1B-Instruct"
+model = QEFFAutoModelForCausalLM.from_pretrained(
+    model_name, continuous_batching=True, comp_ctx_lengths=comp_ctx_lengths
+)
+# model = QEFFAutoModelForCausalLM.from_pretrained(model_name, comp_ctx_lengths=comp_ctx_lengths)
+
+# Compile the model for either continuous or static batching. For continuous batching, full_batch_size is required.
+model.compile(
+    prefill_seq_len=128,
+    ctx_len=1024,
+    num_cores=16,
+    num_devices=1,
+    full_batch_size=1,
+    mxfp6_matmul=True,
+    mxint8_kv_cache=True,
+)
+# model.compile(prefill_seq_len=128, ctx_len=1024, num_cores=16, num_devices=1,batch_size=4,mxfp6_matmul=True,mxint8_kv_cache=True)
+
+# Create the tokenizer and run model.generate with the input prompts. The comp_ctx_lengths list is used during decoding to apply the most efficient compute context length.
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model.generate(
+    prompts=[
+        "What are some healthy foods to include in a balanced diet?",
+        "What is a nutritious meal that can keep you energized throughout the day?",
+        "What are some fun and relaxing activities to do over the weekend?",
+        "What's your favorite hobby?",
+    ],
+    tokenizer=tokenizer,
+)
diff --git a/tests/transformers/test_compute_context_length.py b/tests/transformers/test_compute_context_length.py
new file mode 100644
index 000000000..f3994692c
--- /dev/null
+++ b/tests/transformers/test_compute_context_length.py
@@ -0,0 +1,176 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import copy +import os +from time import perf_counter + +import onnx +import pytest +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + +configs = [ + # name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params + ("gpt2", 256, 2, 4, 128, 512, 127, {}), + ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + ("falcon", 256, 2, 4, 128, 512, 127, {}), + ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mpt", 256, 2, 4, 128, 512, 127, {}), + ("phi", 256, 2, 4, 128, 512, 127, {}), + ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), + ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("starcoder2", 256, 2, 4, 128, 512, 127, {}), + ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), +] + +configs = [ + AutoConfig.for_model( + model_name, + max_position_embeddings=max_position_embeddings, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + vocab_size=vocab_size, + **additional_params, + ) + for ( + model_name, + max_position_embeddings, + num_hidden_layers, + num_attention_heads, + hidden_size, + intermediate_size, + vocab_size, + additional_params, + ) in configs +] +config_ids = [x.model_type for x in configs] + +model_kwargs = {"attn_implementation": "eager"} + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +def test_causal_lm_unsupported(cb): + model = AutoModelForCausalLM.from_config(AutoConfig.for_model("opt")) + with pytest.warns(): + QEFFAutoModelForCausalLM(model, cb) + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_init(config, cb): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + qeff_model = QEFFAutoModelForCausalLM(model, cb) + with pytest.raises(TypeError): + QEFFAutoModelForCausalLM(AutoModel.from_config(config, **model_kwargs), cb) + assert qeff_model.model.__class__.__name__.startswith("QEff") + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_pretrained(config, cb, tmp_path): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + model.save_pretrained(tmp_path) + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained(tmp_path, cb) + assert qeff_model.model.__class__.__name__.startswith("QEff") + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_hash(config, cb): + hash_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash + hash_0_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash + + assert hash_0_0 == hash_0_1 + + cfg1 = copy.deepcopy(config) + cfg1.num_hidden_layers -= 1 + hash_1_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg1, **model_kwargs), 
cb).model_hash
+    cfg2 = copy.deepcopy(config)
+    cfg2.num_hidden_layers -= 1
+    hash_1_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg2, **model_kwargs), cb).model_hash
+    assert hash_1_0 == hash_1_1
+
+    assert hash_0_0 != hash_1_0
+
+    if cb:
+        hash_0_no_cb = QEFFAutoModelForCausalLM(
+            AutoModelForCausalLM.from_config(config, **model_kwargs), False
+        ).model_hash
+        assert hash_0_0 != hash_0_no_cb
+
+
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("config", configs, ids=config_ids)
+def test_causal_lm_export(config, cb, tmp_path):
+    model = AutoModelForCausalLM.from_config(config, **model_kwargs)
+    comp_ctx_lengths = [512, 1024, 2048]
+    qeff_model = QEFFAutoModelForCausalLM(model, cb, comp_ctx_lengths=comp_ctx_lengths)
+    qeff_model.export(tmp_path)
+    model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.model_hash)
+    assert model_path.is_dir()
+    assert qeff_model.onnx_path.is_file()
+    assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",)
+
+    # Check if the KV-cache inputs and outputs are created
+    onnx_model = onnx.load(qeff_model.onnx_path, load_external_data=False)
+    retained_output_names = {
+        x.name[: -len("_RetainedState")] for x in onnx_model.graph.output if x.name.endswith("_RetainedState")
+    }
+    assert retained_output_names.issubset({x.name for x in onnx_model.graph.input})
+
+    # Check if there is no re-export
+    start = perf_counter()
+    qeff_model.export(tmp_path)
+    end = perf_counter()
+    export_time = end - start
+    assert export_time < 2.0
+
+
+@pytest.fixture
+def tmp_cache(tmp_path, monkeypatch):
+    monkeypatch.setattr("QEfficient.base.modeling_qeff.QEFF_HOME", tmp_path)
+    yield tmp_path
+
+
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("config", configs, ids=config_ids)
+def test_causal_lm_compile(config, cb, tmp_cache):
+    model = AutoModelForCausalLM.from_config(config, **model_kwargs)
+    comp_ctx_lengths = [8, 12, 16]
+    qeff_model = QEFFAutoModelForCausalLM(model, cb, comp_ctx_lengths=comp_ctx_lengths)
+    compile_params = {"prefill_seq_len": 8, "ctx_len": 16}
+    if cb:
+        compile_params["full_batch_size"] = 32
+        compile_params["batch_size"] = 8
+    qeff_model.compile(**compile_params)
+    model_path = tmp_cache / (qeff_model.model_name + "-" + qeff_model.model_hash)
+
+    # Check if ONNX is exported properly
+    assert model_path.is_dir()
+    assert qeff_model.onnx_path.is_file()
+    assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",)
+
+    # Check if QPC is compiled properly
+    assert qeff_model.qpc_path.is_dir()
+    assert (qeff_model.qpc_path / "programqpc.bin").is_file()
+    assert qeff_model.qpc_path.relative_to(tmp_cache).parts[0] == qeff_model.model_name + "-" + qeff_model.model_hash
+
+    # Check if there is no re-compilation
+    start = perf_counter()
+    qeff_model.compile(**compile_params)
+    end = perf_counter()
+    compile_time = end - start
+    assert compile_time < 2.0
+    assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json"))
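
Note on the CCL bucketing described in examples/compute_context_length.py: the actual selection logic lives in the text-generation runtime and is not shown in full in this diff. The sketch below is a minimal, standalone illustration of the behavior the example's comments describe (start from the smallest CCL that covers the prompt, move to the next CCL once the cache index passes the current one); the helper name select_ccl_index and the loop structure are illustrative assumptions, not library API.

# ccl_bucket_sketch.py -- illustrative only, not QEfficient internals
from typing import List


def select_ccl_index(comp_ctx_lengths: List[int], cache_index: int) -> int:
    """Return the index of the smallest CCL that still covers cache_index."""
    for i, ccl in enumerate(comp_ctx_lengths):
        if cache_index < ccl:
            return i
    return len(comp_ctx_lengths) - 1  # fall back to the largest bucket


if __name__ == "__main__":
    comp_ctx_lengths = [256, 512, 1024]
    prompt_len, ctx_len = 300, 1024
    ccl_idx = select_ccl_index(comp_ctx_lengths, prompt_len)  # start from the bucket covering the prompt
    for cache_index in range(prompt_len, ctx_len):
        # cache index passed the current CCL -> switch to the next bucket
        if cache_index >= comp_ctx_lengths[ccl_idx] and ccl_idx < len(comp_ctx_lengths) - 1:
            ccl_idx += 1
    print("final CCL bucket:", comp_ctx_lengths[ccl_idx])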
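
The correctness argument behind slicing the causal mask to comp_ctx_lengths and gathering only the first CCL cache positions is that every KV slot at or beyond the current cache index is already masked out, so restricting attention to the first ccl positions cannot change the result as long as the cache index stays below the active CCL. The snippet below is a small self-contained check of that invariant under assumed toy shapes; it is not the library's attention kernel.

# ccl_mask_equivalence_sketch.py -- illustrative check, not library code
import torch

torch.manual_seed(0)
ctx_len, ccl, cache_index, head_dim = 1024, 512, 300, 64

q = torch.randn(1, 1, 1, head_dim)        # single decode-step query
k = torch.randn(1, 1, ctx_len, head_dim)  # full KV cache
v = torch.randn(1, 1, ctx_len, head_dim)
mask = torch.zeros(1, 1, 1, ctx_len)
mask[..., cache_index:] = torch.finfo(torch.float32).min  # only filled cache slots are visible


def attn(q, k, v, mask):
    scores = (q @ k.transpose(-1, -2)) / head_dim**0.5 + mask
    return torch.softmax(scores, dim=-1) @ v


full = attn(q, k, v, mask)
ccl_only = attn(q, k[:, :, :ccl], v[:, :, :ccl], mask[..., :ccl])
assert torch.allclose(full, ccl_only, atol=1e-6)
print("CCL-restricted attention matches full-context attention")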