From 3b71e8af48089bec94fd79633d610066400ba5c6 Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Wed, 26 Nov 2025 08:32:51 +0000 Subject: [PATCH 1/3] Added mllama cb initial commit Signed-off-by: Asmita Goswami --- QEfficient/generation/embedding_handler.py | 6 +- QEfficient/generation/vlm_generation.py | 4 +- .../models/mllama/modeling_mllama.py | 123 ++++++++++++------ .../models/mllama/continuous_batching.py | 108 +++++++++++++++ 4 files changed, 197 insertions(+), 44 deletions(-) create mode 100644 examples/image_text_to_text/models/mllama/continuous_batching.py diff --git a/QEfficient/generation/embedding_handler.py b/QEfficient/generation/embedding_handler.py index 76da7afc2..788386554 100644 --- a/QEfficient/generation/embedding_handler.py +++ b/QEfficient/generation/embedding_handler.py @@ -273,7 +273,7 @@ def setup_vision_buffers(self): if "vision_embeds" in output_name: buffers[output_name] = np.zeros(shape, dtype=np.float16) else: - buffers[output_name] = np.zeros(shape, dtype=np.float32) + buffers[output_name] = np.zeros(shape, dtype=np.float16) self._vision_session.set_buffers(buffers) @@ -359,7 +359,9 @@ def get_processed_inputs( else: lang_inputs["position_ids"] = np.where(lang_inputs.pop("attention_mask"), np.arange(padded_len), -1) - lang_inputs["image_idx"] = np.array([[0]]) + not_mllama = "mllama" != self._qeff_model.model.config.model_type + if not_mllama: + lang_inputs["image_idx"] = np.array([[0]]) return lang_inputs, vision_outputs, num_chunks diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 5eb91d142..4ce35502b 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -652,7 +652,9 @@ def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation prefill_logit_bs=1, ) - self._session.skip_buffers(vision_outputs.keys()) + not_mllama = "mllama" != self.qeff_model.model.config.model_type + if not_mllama: + self._session.skip_buffers(vision_outputs.keys()) # Calculate position_ids for decode position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index a3cb4273d..7973ac140 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -421,6 +421,12 @@ def forward( past_key_value.layers[self.layer_idx].keys, past_key_value.layers[self.layer_idx].values, ) + # if key_states.size(0) != bsz: + # if bsz < key_states.size(0): + # # Slice down + # key_states = key_states[:bsz] + # value_states = value_states[:bsz] + else: raise ValueError( "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" 
@@ -439,6 +445,7 @@ def forward( scaling=self.scaling, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + # breakpoint() attn_output = self.o_proj(attn_output) return attn_output, attn_weights @@ -674,6 +681,7 @@ def forward( position_ids=position_ids, past_key_value=past_key_values, comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, ) @@ -793,6 +801,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, @@ -840,6 +849,7 @@ def forward( full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, use_cache=use_cache, inputs_embeds=inputs_embeds, cache_position=cache_position, @@ -891,6 +901,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, @@ -901,8 +912,9 @@ def forward( logits = self.lm_head(hidden_states).float() return logits, image_idx, outputs.past_key_values, pixel_values - def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False,): BS = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + FBS = constants.ONNX_EXPORT_EXAMPLE_FBS SEQ_LEN = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN CTX_LEN = constants.ONNX_EXPORT_CTX_LEN @@ -960,8 +972,8 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl 1, num_key_value_heads, image_tokens_len, head_dim ) else: - lang_inputs["past_key_values"].layers[i].keys = torch.zeros(1, num_key_value_heads, CTX_LEN, head_dim) - lang_inputs["past_key_values"].layers[i].values = torch.zeros(1, num_key_value_heads, CTX_LEN, head_dim) + lang_inputs["past_key_values"].layers[i].keys = torch.zeros(FBS if continuous_batching else BS, num_key_value_heads, CTX_LEN, head_dim) + lang_inputs["past_key_values"].layers[i].values = torch.zeros(FBS if continuous_batching else BS, num_key_value_heads, CTX_LEN, head_dim) lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1) @@ -969,6 +981,9 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl if comp_ctx_lengths is not None: lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(BS).view(BS, 1) + inputs = {} if kv_offload: @@ -988,6 +1003,9 @@ def get_specializations( comp_ctx_lengths_prefill: Optional[List[int]] = None, comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): vis_cfg = self.config.vision_config @@ -1006,47 +1024,67 @@ def get_specializations( lang = [] for i in range(0, len(comp_ctx_lengths_prefill)): - lang.append( - { - "batch_size": batch_size, - "seq_len": 
prefill_seq_len, - "ctx_len": ctx_len, - "comp_ctx_lengths": comp_ctx_lengths_prefill[i], - "max_num_images": max_num_images, - "img_size": img_size, - } - ) - - # Remaining elements use comp_ctx_lengths[1:] in a loop - for i in range(0, len(comp_ctx_lengths_decode)): - lang.append( - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "comp_ctx_lengths": comp_ctx_lengths_decode[i], - "max_num_images": max_num_images, - "img_size": img_size, - } - ) - - else: - lang = [ - { - "batch_size": batch_size, + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], "max_num_images": max_num_images, "img_size": img_size, - }, - { - "batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + lang.append(lang_prefill) + + # Remaining elements use comp_ctx_lengths[1:] in a loop + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": "1", "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], "max_num_images": max_num_images, "img_size": img_size, - }, - ] + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang.append(lang_decode) + + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -1057,7 +1095,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False): txt_cfg = self.config.get_text_config() num_hidden_layers = txt_cfg.num_hidden_layers cross_attention_layers = txt_cfg.cross_attention_layers @@ -1074,13 +1112,16 @@ def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv "cross_attention_mask": {0: "batch_size", 1: "seq_len", 2: "max_num_images"}, } + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} + for i in range(num_hidden_layers): if i in cross_attention_layers: - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size"} + lang_dynamic_axes[f"past_key.{i}"] = {0: "full_batch_size" if continuous_batching else "batch_size"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "full_batch_size" if continuous_batching else "batch_size"} else: 
- lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_key.{i}"] = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"} if comp_ctx_lengths is not None: lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} diff --git a/examples/image_text_to_text/models/mllama/continuous_batching.py b/examples/image_text_to_text/models/mllama/continuous_batching.py new file mode 100644 index 000000000..8500e9640 --- /dev/null +++ b/examples/image_text_to_text/models/mllama/continuous_batching.py @@ -0,0 +1,108 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import transformers +from transformers import AutoConfig, AutoProcessor + +from QEfficient import QEFFAutoModelForImageTextToText + + +def set_num_layers(config, n_layer=1): + ## -1 indicates use all the layers of the model. + if n_layer == -1: + return config + elif hasattr(config, "model_type") and "mllama" in config.model_type: + config.text_config.num_hidden_layers = n_layer + config.text_config.cross_attention_layers = [ + x for x in config.text_config.cross_attention_layers if x < n_layer + ] + elif hasattr(config, "text_config"): + config.text_config.num_hidden_layers = n_layer + config.vision_config.num_hidden_layers = n_layer + elif hasattr(config, "llm_config"): + config.llm_config.num_hidden_layers = n_layer + config.vision_config.num_hidden_layers = n_layer + else: + config.num_hidden_layers = n_layer + return config + + +model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config = set_num_layers(config, n_layer=7) + +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +continious_batching = True +if continious_batching: + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, + ) + + qeff_model.compile( + prefill_seq_len=32, + ctx_len=512, + img_size=560, + num_cores=16, + num_devices=4, + batch_size=1, + full_batch_size=4, + mxfp6_matmul=False, + ) +else: + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + ) + + qeff_model.compile( + prefill_seq_len=32, + ctx_len=512, + img_size=560, + num_cores=16, + num_devices=4, + batch_size=1, + mxfp6_matmul=False, + # mxint8_kv_cache=True, + # aic_enable_depth_first=True, + # mos=1, + ) + +image_urls = [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", +] + +prompts = [ + "Can you describe the image in 
detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +exec_info = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + device_ids=[0, 1, 2, 3], + generation_len=10, +) + +print("Generated texts:", exec_info.generated_texts) +print("Generated IDs:", exec_info.generated_ids) +print(exec_info) From 660c13d6b3b69b78fe56adcc275bae35107b3eb0 Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Wed, 26 Nov 2025 08:35:12 +0000 Subject: [PATCH 2/3] Ruff check Signed-off-by: Asmita Goswami --- .../models/mllama/modeling_mllama.py | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 7973ac140..e34e6d00b 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -912,7 +912,12 @@ def forward( logits = self.lm_head(hidden_states).float() return logits, image_idx, outputs.past_key_values, pixel_values - def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False,): + def get_dummy_inputs( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + continuous_batching: bool = False, + ): BS = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE FBS = constants.ONNX_EXPORT_EXAMPLE_FBS SEQ_LEN = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN @@ -972,8 +977,12 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl 1, num_key_value_heads, image_tokens_len, head_dim ) else: - lang_inputs["past_key_values"].layers[i].keys = torch.zeros(FBS if continuous_batching else BS, num_key_value_heads, CTX_LEN, head_dim) - lang_inputs["past_key_values"].layers[i].values = torch.zeros(FBS if continuous_batching else BS, num_key_value_heads, CTX_LEN, head_dim) + lang_inputs["past_key_values"].layers[i].keys = torch.zeros( + FBS if continuous_batching else BS, num_key_value_heads, CTX_LEN, head_dim + ) + lang_inputs["past_key_values"].layers[i].values = torch.zeros( + FBS if continuous_batching else BS, num_key_value_heads, CTX_LEN, head_dim + ) lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1) @@ -1095,7 +1104,9 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): txt_cfg = self.config.get_text_config() num_hidden_layers = txt_cfg.num_hidden_layers cross_attention_layers = txt_cfg.cross_attention_layers @@ -1120,8 +1131,14 @@ def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv lang_dynamic_axes[f"past_key.{i}"] = {0: "full_batch_size" if continuous_batching else "batch_size"} lang_dynamic_axes[f"past_value.{i}"] = {0: "full_batch_size" if continuous_batching else "batch_size"} else: - lang_dynamic_axes[f"past_key.{i}"] = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "full_batch_size" if continuous_batching else "batch_size", 
2: "ctx_len"} + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } if comp_ctx_lengths is not None: lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} From 52cb698d5d4a706789d5334d0cd1be84e03dc033 Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Wed, 26 Nov 2025 08:44:17 +0000 Subject: [PATCH 3/3] Updated cross_atten_layer BS Signed-off-by: Asmita Goswami --- QEfficient/transformers/models/mllama/modeling_mllama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index e34e6d00b..572f41d4e 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -971,10 +971,10 @@ def get_dummy_inputs( idx = cross_attention_layers.index(i) assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" lang_inputs["past_key_values"].layers[i].keys = torch.zeros( - 1, num_key_value_heads, image_tokens_len, head_dim + FBS if continuous_batching else BS, num_key_value_heads, image_tokens_len, head_dim ) lang_inputs["past_key_values"].layers[i].values = torch.zeros( - 1, num_key_value_heads, image_tokens_len, head_dim + FBS if continuous_batching else BS, num_key_value_heads, image_tokens_len, head_dim ) else: lang_inputs["past_key_values"].layers[i].keys = torch.zeros(