From 196f349a13b94d31ff6ce7405e1feb3055478274 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Fri, 11 Apr 2025 10:22:33 +0000 Subject: [PATCH 1/8] Onboarding Mistral3.1_24B Signed-off-by: Mohit Soni --- QEfficient/transformers/modeling_utils.py | 12 + .../transformers/models/mistral3/__init__.py | 6 + .../models/mistral3/modeling_mistral3.py | 237 ++++++++++++++++++ .../transformers/models/pytorch_transforms.py | 10 + 4 files changed, 265 insertions(+) create mode 100644 QEfficient/transformers/models/mistral3/__init__.py create mode 100644 QEfficient/transformers/models/mistral3/modeling_mistral3.py diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index 0a0e4d54b..9826f4cff 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -58,6 +58,10 @@ MistralModel, MistralRMSNorm, ) +from transformers.models.mistral3.modeling_mistral3 import ( + Mistral3ForConditionalGeneration, + Mistral3RMSNorm, +) from transformers.models.mixtral.modeling_mixtral import ( MixtralAttention, MixtralDecoderLayer, @@ -70,6 +74,7 @@ from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel from transformers.models.phi.modeling_phi import PhiAttention, PhiForCausalLM, PhiModel from transformers.models.phi3.modeling_phi3 import Phi3Attention, Phi3ForCausalLM, Phi3Model, Phi3RMSNorm +from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2ForCausalLM, Qwen2Model, Qwen2RMSNorm from transformers.models.starcoder2.modeling_starcoder2 import ( Starcoder2Attention, @@ -88,6 +93,7 @@ ) from QEfficient.customop import CustomRMSNormAIC +from QEfficient.transformers.models.mistral3.modeling_mistral3 import QEffMistral3ForConditionalGeneration # Placeholder for all non-transformer models from .models.codegen.modeling_codegen import ( @@ -179,6 +185,7 @@ GPTBigCodeForCausalLM.__name__, MllamaForCausalLM.__name__, WhisperForConditionalGeneration.__name__, + Mistral3ForConditionalGeneration.__name__, ] ) @@ -230,6 +237,9 @@ MistralModel: QEffMistralModel, MistralForCausalLM: QEffMistralForCausalLM, MistralRMSNorm: CustomRMSNormAIC, + # Mistral3 model layers + Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration, + Mistral3RMSNorm: CustomRMSNormAIC, # Mixtral model layers MixtralAttention: QEffMixtralAttention, MixtralDecoderLayer: QeffMixtralDecoderLayer, @@ -246,6 +256,8 @@ PhiAttention: QEffPhiAttention, PhiModel: QEffPhiModel, PhiForCausalLM: QEffPhiForCausalLM, + # Pixtral model layers + PixtralRMSNorm: CustomRMSNormAIC, # Falcon model layers FalconAttention: QEffFalconAttention, FalconForCausalLM: QEffFalconForCausalLM, diff --git a/QEfficient/transformers/models/mistral3/__init__.py b/QEfficient/transformers/models/mistral3/__init__.py new file mode 100644 index 000000000..72ba36c8a --- /dev/null +++ b/QEfficient/transformers/models/mistral3/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py new file mode 100644 index 000000000..a4f77f82b --- /dev/null +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -0,0 +1,237 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration + +from QEfficient.utils import constants +from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config + +BS = 1 +NUM_CHANNEL = 3 +SEQ_LEN = 3072 +CTX_LEN = 4096 + + +class QEFFMistral3EncoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.model.vision_model = self.model.vision_tower + + def forward(self, pixel_values): + image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]]) + image_features = self.model.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=self.model.config.vision_feature_layer, + image_sizes=image_sizes, + ) + return image_features + + +class QEFFMistral3DecoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.config = self.model.config + self.language_model = self.model.language_model + + def forward(self, input_ids, vit_embeds, position_ids, past_key_values): + inputs_embeds = self.model.get_input_embeddings()(input_ids) + vit_embeds = vit_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + mask = input_ids == self.model.config.image_token_index + indices1 = mask.to(torch.int64).cumsum(1) - 1 + indices0 = torch.arange(mask.shape[0]).view(-1, 1) + image_features_expanded = vit_embeds.unsqueeze(0)[indices0, indices1] + inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) + outputs = self.model.language_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + ) + + return outputs.logits, vit_embeds, outputs.past_key_values + + +class QEffMistral3ForConditionalGeneration(Mistral3ForConditionalGeneration): + def get_qeff_vision_encoder(self): + return QEFFMistral3EncoderWrapper(self) + + def get_qeff_language_decoder(self): + return QEFFMistral3DecoderWrapper(self) + + def forward(self, pixel_values, input_ids, position_ids, past_key_values): + inputs_embeds = self.get_input_embeddings()(input_ids) + # Image features + image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]]) + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=self.config.vision_feature_layer, + image_sizes=image_sizes, + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + mask = input_ids == self.config.image_token_index + indices1 = mask.to(torch.int64).cumsum(1) - 1 + indices0 = torch.arange(mask.shape[0]).view(-1, 1) + image_features_expanded = image_features.unsqueeze(0)[indices0, indices1] + inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) + outputs = self.language_model( + 
inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + ) + return outputs.logits, pixel_values, outputs.past_key_values + + def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + inputs_shapes = {} + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + inputs_shapes["vit_embeds"] = ( + constants.MISTRAL3_FEATURE_SIZE, + self.language_model.config.hidden_size, + ) + inputs_shapes["position_ids"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + inputs_shapes["pixel_values"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + constants.MISTRAL3_NUM_CHANNELS, + constants.MISTRAL3_HEIGHT, + constants.MISTRAL3_WIDTH, + ) + + # Define inputs + vision_inputs = {} + lang_inputs = {} + vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32) + lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) + lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32) + lang_inputs["position_ids"] = ( + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) + ) + + # Add data for KV + kv_cache_shape = get_padding_shape_from_config( + config=self.language_model.config, + batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + + lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)] + for i in range(self.language_model.config.num_hidden_layers): + for kv in ["key", "value"]: + lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + + inputs = {} + if kv_offload: + inputs["vision"] = vision_inputs + inputs["lang"] = lang_inputs + else: + lang_inputs.pop("vit_embeds") + inputs = {**vision_inputs, **lang_inputs} + + return inputs + + def get_specializations( + self, + batch_size: int, + prefill_seq_len: int, + ctx_len: int, + img_size: int, + kv_offload: bool = False, + **compiler_options, + ): + prefill_seq_len = prefill_seq_len if prefill_seq_len else SEQ_LEN + ctx_len = ctx_len if ctx_len else CTX_LEN + height = constants.MISTRAL3_HEIGHT + width = constants.MISTRAL3_WIDTH + + vision = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "height": height, + "width": width, + } + ] + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "height": height, + "width": width, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "height": height, + "width": width, + }, + ] + specializations = {} + + if kv_offload: + specializations["vision"] = vision + specializations["lang"] = lang + return specializations, compiler_options + else: + return lang, compiler_options + + def get_onnx_dynamic_axes(self, kv_offload: bool = False): + # Define dynamic axes + num_layers = self.config.text_config.num_hidden_layers + + vision_dynamic_axes = { + "pixel_values": {0: "batch_size", 2: "height", 3: "width"}, + } + lang_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + } + + for i in range(num_layers): + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + dynamic_axes = 
{} + if kv_offload: + dynamic_axes["vision"] = vision_dynamic_axes + dynamic_axes["lang"] = lang_dynamic_axes + else: + dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes} + return dynamic_axes + + def get_output_names(self, kv_offload: bool = False): + vision_output_names = ["vit_embeds"] + lang_output_names = ["logits"] + for i in range(self.language_model.config.num_hidden_layers): + for kv in ["key", "value"]: + lang_output_names.append(f"past_{kv}.{i}_RetainedState") + + output_names = {} + if kv_offload: + lang_output_names.insert(1, "vit_embeds_RetainedState") + output_names["vision"] = vision_output_names + output_names["lang"] = lang_output_names + else: + lang_output_names.insert(1, "pixel_values_RetainedState") + return lang_output_names + return output_names + + def get_inputs_info(self): + return [ + IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "height", "width")), + ] diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index ca74c0ddd..0bcd7c6e6 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -100,6 +100,10 @@ MistralModel, MistralRMSNorm, ) +from transformers.models.mistral3.modeling_mistral3 import ( + Mistral3ForConditionalGeneration, + Mistral3RMSNorm, +) from transformers.models.mixtral.modeling_mixtral import ( MixtralAttention, MixtralDecoderLayer, @@ -129,6 +133,7 @@ Phi3Model, Phi3RMSNorm, ) +from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm from transformers.models.qwen2.modeling_qwen2 import ( Qwen2Attention, Qwen2DecoderLayer, @@ -260,6 +265,7 @@ QEffMistralForCausalLM, QEffMistralModel, ) +from QEfficient.transformers.models.mistral3.modeling_mistral3 import QEffMistral3ForConditionalGeneration from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import ( QEffMixtralAttention, QeffMixtralDecoderLayer, @@ -332,6 +338,7 @@ class CustomOpsTransform(ModuleMappingTransform): LlamaRMSNorm: CustomRMSNormAIC, Llama4TextRMSNorm: CustomRMSNormAIC, MistralRMSNorm: CustomRMSNormAIC, + Mistral3RMSNorm: CustomRMSNormAIC, MixtralRMSNorm: CustomRMSNormAIC, Phi3RMSNorm: CustomRMSNormAIC, Qwen2RMSNorm: CustomRMSNormAIC, @@ -339,6 +346,7 @@ class CustomOpsTransform(ModuleMappingTransform): GraniteRMSNorm: CustomRMSNormAIC, GraniteMoeRMSNorm: CustomRMSNormAIC, Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC, + PixtralRMSNorm: CustomRMSNormAIC, } @@ -426,6 +434,8 @@ class KVCacheTransform(ModuleMappingTransform): MistralDecoderLayer: QEffMistralDecoderLayer, MistralModel: QEffMistralModel, MistralForCausalLM: QEffMistralForCausalLM, + # Mistral3 + Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration, # Mixtral MixtralAttention: QEffMixtralAttention, MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock, From ffbba1aa76ff40e6cd5fbe6a6dd711721cf6297e Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Thu, 3 Jul 2025 08:15:03 +0000 Subject: [PATCH 2/8] Modeling Changes Signed-off-by: Mohit Soni --- .../models/mistral3/modeling_mistral3.py | 168 ++++++++++++++---- .../transformers/models/pytorch_transforms.py | 11 +- 2 files changed, 144 insertions(+), 35 deletions(-) diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py 
b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index a4f77f82b..73e8a004f 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -5,18 +5,91 @@ # # ----------------------------------------------------------------------------- +from typing import Optional, Tuple, Union + import torch import torch.nn as nn import torch.utils.checkpoint +from transformers.modeling_outputs import BaseModelOutput from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration +from transformers.models.pixtral.modeling_pixtral import PixtralVisionModel, position_ids_in_meshgrid from QEfficient.utils import constants from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config -BS = 1 -NUM_CHANNEL = 3 -SEQ_LEN = 3072 -CTX_LEN = 4096 + +def custom_cumsum(tensor): + dim = 0 + result = torch.zeros_like(tensor) + indices = [slice(None)] * tensor.dim() + for i in range(tensor.size(dim)): + indices[dim] = slice(0, i + 1) + result.select(dim, i).copy_(tensor[tuple(indices)].sum(dim)) + return result + + +def qeff_generate_block_attention_mask(patch_embeds_list, tensor): + dtype = tensor.dtype + device = tensor.device + seq_len = tensor.shape[1] + d_min = torch.finfo(dtype).min + causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device) + block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1) + block_end_idx = custom_cumsum(torch.tensor(patch_embeds_list)) + block_start_idx = custom_cumsum(torch.tensor([0] + patch_embeds_list[:-1])) + for start, end in zip(block_start_idx.tolist(), block_end_idx.tolist()): + causal_mask[start:end, start:end] = 0 + causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1) + return causal_mask + + +class QEffPixtralVisionModel(PixtralVisionModel): + def forward( + self, + pixel_values: torch.Tensor, + image_sizes: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + *args, + **kwargs, + ) -> Union[Tuple, BaseModelOutput]: + """ + Returns: + pixel_values: tensor of token features for + all tokens of all images of shape (N_toks, D) + """ + # pass images through initial convolution independently + patch_embeds = self.patch_conv(pixel_values) + patch_embeds_list = [ + embed[..., : (size[0] // self.patch_size), : (size[1] // self.patch_size)] + for embed, size in zip(patch_embeds, image_sizes) + ] + + # flatten to a single sequence + patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0) + patch_embeds = self.ln_pre(patch_embeds) + + # positional embeddings + position_ids = position_ids_in_meshgrid( + patch_embeds_list, max_width=self.config.image_size // self.config.patch_size + ) + position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids) + + attention_mask = qeff_generate_block_attention_mask( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds + ) + + out = self.transformer( + patch_embeds, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + return out class QEFFMistral3EncoderWrapper(nn.Module): @@ -26,7 +99,7 @@ def __init__(self, model): self.model.vision_model = self.model.vision_tower def forward(self, pixel_values): - image_sizes = torch.tensor([[pixel_values.shape[2], 
pixel_values.shape[3]]]) + image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]]).repeat(pixel_values.shape[0], 1) image_features = self.model.get_image_features( pixel_values=pixel_values, vision_feature_layer=self.model.config.vision_feature_layer, @@ -42,21 +115,23 @@ def __init__(self, model): self.config = self.model.config self.language_model = self.model.language_model - def forward(self, input_ids, vit_embeds, position_ids, past_key_values): + def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): inputs_embeds = self.model.get_input_embeddings()(input_ids) - vit_embeds = vit_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.model.config.image_token_index indices1 = mask.to(torch.int64).cumsum(1) - 1 + indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) indices0 = torch.arange(mask.shape[0]).view(-1, 1) - image_features_expanded = vit_embeds.unsqueeze(0)[indices0, indices1] - inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) + image_features_expanded = vision_embeds.unsqueeze(0)[indices0, indices1] + inputs_embeds_1 = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) outputs = self.model.language_model( - inputs_embeds=inputs_embeds, + inputs_embeds=inputs_embeds_1, position_ids=position_ids, past_key_values=past_key_values, ) - - return outputs.logits, vit_embeds, outputs.past_key_values + next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) + return outputs.logits, vision_embeds, image_idx, outputs.past_key_values class QEffMistral3ForConditionalGeneration(Mistral3ForConditionalGeneration): @@ -66,10 +141,9 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEFFMistral3DecoderWrapper(self) - def forward(self, pixel_values, input_ids, position_ids, past_key_values): + def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): inputs_embeds = self.get_input_embeddings()(input_ids) - # Image features - image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]]) + image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]]).repeat(pixel_values.shape[0], 1) image_features = self.get_image_features( pixel_values=pixel_values, vision_feature_layer=self.config.vision_feature_layer, @@ -78,21 +152,31 @@ def forward(self, pixel_values, input_ids, position_ids, past_key_values): image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.config.image_token_index indices1 = mask.to(torch.int64).cumsum(1) - 1 + indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) indices0 = torch.arange(mask.shape[0]).view(-1, 1) image_features_expanded = image_features.unsqueeze(0)[indices0, indices1] - inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) + image_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, ) - return outputs.logits, pixel_values, outputs.past_key_values + next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + image_idx = torch.where(image_idx < 
next_idx, next_idx, image_idx) + + return outputs.logits, pixel_values, image_idx, outputs.past_key_values def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) - inputs_shapes["vit_embeds"] = ( - constants.MISTRAL3_FEATURE_SIZE, + height = self.config.vision_config.image_size + width = self.config.vision_config.image_size + patch_size = self.config.vision_config.patch_size + kernel_size = self.config.spatial_merge_size + vision_size = ((height // patch_size) * (width // patch_size)) // (kernel_size * kernel_size) + inputs_shapes["vision_embeds"] = ( + vision_size, self.language_model.config.hidden_size, ) inputs_shapes["position_ids"] = ( @@ -101,23 +185,24 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): ) inputs_shapes["pixel_values"] = ( constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, - constants.MISTRAL3_NUM_CHANNELS, - constants.MISTRAL3_HEIGHT, - constants.MISTRAL3_WIDTH, + 3, + height, + width, ) - + inputs_shapes["image_idx"] = (1, 1) + inputs_shapes["image_sizes"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 2) # Define inputs vision_inputs = {} lang_inputs = {} vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32) lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) - lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32) + lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32) lang_inputs["position_ids"] = ( torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) - + lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) # Add data for KV kv_cache_shape = get_padding_shape_from_config( config=self.language_model.config, @@ -135,7 +220,7 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): inputs["vision"] = vision_inputs inputs["lang"] = lang_inputs else: - lang_inputs.pop("vit_embeds") + lang_inputs.pop("vision_embeds") inputs = {**vision_inputs, **lang_inputs} return inputs @@ -149,10 +234,17 @@ def get_specializations( kv_offload: bool = False, **compiler_options, ): - prefill_seq_len = prefill_seq_len if prefill_seq_len else SEQ_LEN - ctx_len = ctx_len if ctx_len else CTX_LEN - height = constants.MISTRAL3_HEIGHT - width = constants.MISTRAL3_WIDTH + height = compiler_options.pop("height", None) + width = compiler_options.pop("width", None) + if height is None: + height = self.config.vision_config.image_size + if width is None: + width = self.config.vision_config.image_size + prefill_seq_len = prefill_seq_len if prefill_seq_len else 128 + ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN + patch_size = self.config.vision_config.patch_size + kernel_size = self.config.spatial_merge_size + vision_size = ((height // patch_size) * (width // patch_size)) // (kernel_size * kernel_size) vision = [ { @@ -161,6 +253,7 @@ def get_specializations( "ctx_len": ctx_len, "height": height, "width": width, + "vision_size": vision_size, } ] lang = [ @@ -170,6 +263,7 @@ def get_specializations( "ctx_len": ctx_len, "height": height, "width": width, + "vision_size": vision_size, }, { "batch_size": batch_size, @@ -177,8 +271,10 @@ def get_specializations( "ctx_len": ctx_len, "height": height, "width": width, + 
"vision_size": vision_size, }, ] + specializations = {} if kv_offload: @@ -186,6 +282,9 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + # return vision, compiler_options + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options def get_onnx_dynamic_axes(self, kv_offload: bool = False): @@ -198,6 +297,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, + "vision_embeds": {0: "vision_size"}, } for i in range(num_layers): @@ -209,11 +309,13 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): dynamic_axes["vision"] = vision_dynamic_axes dynamic_axes["lang"] = lang_dynamic_axes else: + lang_dynamic_axes.pop("vision_embeds") dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes} + # dynamic_axes = vision_dynamic_axes return dynamic_axes def get_output_names(self, kv_offload: bool = False): - vision_output_names = ["vit_embeds"] + vision_output_names = ["vision_embeds"] lang_output_names = ["logits"] for i in range(self.language_model.config.num_hidden_layers): for kv in ["key", "value"]: @@ -221,11 +323,13 @@ def get_output_names(self, kv_offload: bool = False): output_names = {} if kv_offload: - lang_output_names.insert(1, "vit_embeds_RetainedState") + lang_output_names.insert(1, "vision_embeds_RetainedState") + lang_output_names.insert(2, "image_idx_output") output_names["vision"] = vision_output_names output_names["lang"] = lang_output_names else: lang_output_names.insert(1, "pixel_values_RetainedState") + lang_output_names.insert(2, "image_idx_output") return lang_output_names return output_names diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 0bcd7c6e6..ba06582af 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -133,7 +133,7 @@ Phi3Model, Phi3RMSNorm, ) -from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm +from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm, PixtralVisionModel from transformers.models.qwen2.modeling_qwen2 import ( Qwen2Attention, Qwen2DecoderLayer, @@ -265,7 +265,10 @@ QEffMistralForCausalLM, QEffMistralModel, ) -from QEfficient.transformers.models.mistral3.modeling_mistral3 import QEffMistral3ForConditionalGeneration +from QEfficient.transformers.models.mistral3.modeling_mistral3 import ( + QEffMistral3ForConditionalGeneration, + QEffPixtralVisionModel, +) from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import ( QEffMixtralAttention, QeffMixtralDecoderLayer, @@ -344,9 +347,9 @@ class CustomOpsTransform(ModuleMappingTransform): Qwen2RMSNorm: CustomRMSNormAIC, MllamaTextRMSNorm: CustomRMSNormAIC, GraniteRMSNorm: CustomRMSNormAIC, + PixtralRMSNorm: CustomRMSNormAIC, GraniteMoeRMSNorm: CustomRMSNormAIC, Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC, - PixtralRMSNorm: CustomRMSNormAIC, } @@ -457,6 +460,8 @@ class KVCacheTransform(ModuleMappingTransform): PhiDecoderLayer: QEffPhiDecoderLayer, PhiModel: QEffPhiModel, PhiForCausalLM: QEffPhiForCausalLM, + # Pixtral + PixtralVisionModel: QEffPixtralVisionModel, # Qwen2 Qwen2Attention: QEffQwen2Attention, Qwen2DecoderLayer: QEffQwen2DecoderLayer, From 3c5a86d35d33d2812fd9f5a716c24e4346f3c282 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Mon, 7 Jul 2025 08:30:19 +0000 Subject: [PATCH 3/8] 
Updating modeling and addding example script Signed-off-by: Mohit Soni --- .../models/mistral3/modeling_mistral3.py | 8 +- examples/mistral3_example.py | 109 ++++++++++++++++++ 2 files changed, 115 insertions(+), 2 deletions(-) create mode 100644 examples/mistral3_example.py diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index 73e8a004f..a3ce9d38f 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -174,7 +174,11 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): width = self.config.vision_config.image_size patch_size = self.config.vision_config.patch_size kernel_size = self.config.spatial_merge_size - vision_size = ((height // patch_size) * (width // patch_size)) // (kernel_size * kernel_size) + vision_size = ( + ((height // patch_size) * (width // patch_size)) + * (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE) + // (kernel_size * kernel_size) + ) inputs_shapes["vision_embeds"] = ( vision_size, self.language_model.config.hidden_size, @@ -244,7 +248,7 @@ def get_specializations( ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN patch_size = self.config.vision_config.patch_size kernel_size = self.config.spatial_merge_size - vision_size = ((height // patch_size) * (width // patch_size)) // (kernel_size * kernel_size) + vision_size = ((height // patch_size) * (width // patch_size)) * (batch_size) // (kernel_size * kernel_size) vision = [ { diff --git a/examples/mistral3_example.py b/examples/mistral3_example.py new file mode 100644 index 000000000..83f170ed8 --- /dev/null +++ b/examples/mistral3_example.py @@ -0,0 +1,109 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +from PIL import Image +from transformers import AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +# Add HuggingFace Token to access the model +HF_TOKEN = "" + + +def run_model( + model_name, + token, + query, + image_url, + kv_offload=False, + prefill_seq_len=128, + ctx_len=4096, + generation_len=128, + img_size=1540, + num_cores=16, + num_devices=4, +): + ## STEP - 1 Load the Processor and Model + + processor = AutoProcessor.from_pretrained(model_name) + + # `kv_offload` is used to compile the model in a 2 QPCs.Currently we are not supporting 1 qpc so the flag false is not allowed. + # The `kv_offload` flag should always be set to True. + # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. + # The outputs of the Vision Encoder are then passed to the Language model via host in this case. 
+ + model = QEFFAutoModelForImageTextToText.from_pretrained(model_name, kv_offload=kv_offload) + + ## STEP - 2 Export & Compile the Model + + model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=False, + ) + + ## STEP - 3 Load and process the inputs for Inference + + # We are resizing the image to (w x h) (1540 x 1540) so that any image can work on the model irrespective of image dimensssions + # we have a fixed size of height 1540 and width 1540 as defined in the config + + image = Image.open(requests.get(image_url, stream=True).raw) + image = image.resize((1540, 1540)) + + messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": query}]}] + input_text = processor.apply_chat_template(messages, add_generation_prompt=True) + inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt") + + ## STEP - 4 Run Inference on the compiled model + + streamer = TextStreamer(processor.tokenizer) + output = model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) + print(output) + + +if __name__ == "__main__": + # Model name and Input parameters + model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" + + # Please add prompt here + query = "Describe the image" + + # Please pass image url or image path .The format of the image should be jpg. + image_url = "https://www.ilankelman.org/stopsigns/australia.jpg" + + # Compilation parameters for the model + kv_offload = True + prefill_seq_len = 128 + ctx_len = 4096 + generation_len = 128 + num_cores = 16 + num_devices = 4 + + run_model( + model_name=model_name, + token=HF_TOKEN, + query=query, + kv_offload=kv_offload, + image_url=image_url, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + generation_len=generation_len, + num_cores=num_cores, + num_devices=num_devices, + ) + + +""" +Expected Response: + + + +""" From dad4315e2bbfc668a088332e2ee3328da5a6ef0a Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Mon, 7 Jul 2025 09:40:22 +0000 Subject: [PATCH 4/8] Adding test file Signed-off-by: Mohit Soni --- examples/mistral3_example.py | 3 ++- .../models/test_image_text_to_text_models.py | 25 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/examples/mistral3_example.py b/examples/mistral3_example.py index 83f170ed8..4fee69dff 100644 --- a/examples/mistral3_example.py +++ b/examples/mistral3_example.py @@ -103,7 +103,8 @@ def run_model( """ Expected Response: - +The image depicts a street scene in what appears to be a Chinatown district. The focal point is a traditional Chinese archway, known as a paifang, which is intricately designed with red columns and ornate details. The archway features Chinese characters at the top, which translate to "Chinatown Gate." +In the foreground, there is a red stop sign mounted on a pole. The street is relatively quiet, with a single dark-colored SUV driving through the archway. On either side of the archway, there are stone lion statues, which are common decorative elements in Chinese architecture and symbolize protection. 
""" diff --git a/tests/transformers/models/test_image_text_to_text_models.py b/tests/transformers/models/test_image_text_to_text_models.py index 54f167281..221725491 100644 --- a/tests/transformers/models/test_image_text_to_text_models.py +++ b/tests/transformers/models/test_image_text_to_text_models.py @@ -110,6 +110,28 @@ "Can you describe the image in detail.", 1, ), + ( + "mistralai/Mistral-Small-3.1-24B-Instruct-2503", + True, + 1, + 128, + 4096, + 1540, + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "Can you describe the image in detail.", + 1, + ), + ( + "mistralai/Mistral-Small-3.1-24B-Instruct-2503", + False, + 1, + 128, + 4096, + 1540, + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "Can you describe the image in detail.", + 1, + ), # ( # "meta-llama/Llama-3.2-11B-Vision-Instruct", # True, @@ -212,6 +234,9 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( n_layer = get_num_layers_vlm(config) image = Image.open(requests.get(img_url, stream=True).raw) + if model_name == "mistralai/Mistral-Small-3.1-24B-Instruct-2503": + image = image.resize((1540, 1540)) + conversation = [ { "role": "user", From d961aa33f93ade65b86c6efe79a0c9cab4dda5df Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Mon, 7 Jul 2025 10:09:51 +0000 Subject: [PATCH 5/8] Updating utils Signed-off-by: Mohit Soni --- QEfficient/transformers/modeling_utils.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index 9826f4cff..0a0e4d54b 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -58,10 +58,6 @@ MistralModel, MistralRMSNorm, ) -from transformers.models.mistral3.modeling_mistral3 import ( - Mistral3ForConditionalGeneration, - Mistral3RMSNorm, -) from transformers.models.mixtral.modeling_mixtral import ( MixtralAttention, MixtralDecoderLayer, @@ -74,7 +70,6 @@ from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel from transformers.models.phi.modeling_phi import PhiAttention, PhiForCausalLM, PhiModel from transformers.models.phi3.modeling_phi3 import Phi3Attention, Phi3ForCausalLM, Phi3Model, Phi3RMSNorm -from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2ForCausalLM, Qwen2Model, Qwen2RMSNorm from transformers.models.starcoder2.modeling_starcoder2 import ( Starcoder2Attention, @@ -93,7 +88,6 @@ ) from QEfficient.customop import CustomRMSNormAIC -from QEfficient.transformers.models.mistral3.modeling_mistral3 import QEffMistral3ForConditionalGeneration # Placeholder for all non-transformer models from .models.codegen.modeling_codegen import ( @@ -185,7 +179,6 @@ GPTBigCodeForCausalLM.__name__, MllamaForCausalLM.__name__, WhisperForConditionalGeneration.__name__, - Mistral3ForConditionalGeneration.__name__, ] ) @@ -237,9 +230,6 @@ MistralModel: QEffMistralModel, MistralForCausalLM: QEffMistralForCausalLM, MistralRMSNorm: CustomRMSNormAIC, - # Mistral3 model layers - Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration, - Mistral3RMSNorm: CustomRMSNormAIC, # Mixtral model layers MixtralAttention: QEffMixtralAttention, MixtralDecoderLayer: QeffMixtralDecoderLayer, @@ -256,8 +246,6 @@ PhiAttention: QEffPhiAttention, PhiModel: QEffPhiModel, PhiForCausalLM: QEffPhiForCausalLM, 
- # Pixtral model layers - PixtralRMSNorm: CustomRMSNormAIC, # Falcon model layers FalconAttention: QEffFalconAttention, FalconForCausalLM: QEffFalconForCausalLM, From 5f3d9bd59741edeae16f2d0f24ca601fb9cc186e Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Tue, 8 Jul 2025 08:47:56 +0000 Subject: [PATCH 6/8] Minor Changes Signed-off-by: Mohit Soni --- .../models/mistral3/modeling_mistral3.py | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index a3ce9d38f..96fc4b653 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -16,7 +16,7 @@ from QEfficient.utils import constants from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config - +from QEfficient.utils.logging_utils import logger def custom_cumsum(tensor): dim = 0 @@ -238,25 +238,24 @@ def get_specializations( kv_offload: bool = False, **compiler_options, ): - height = compiler_options.pop("height", None) - width = compiler_options.pop("width", None) - if height is None: - height = self.config.vision_config.image_size - if width is None: - width = self.config.vision_config.image_size + + if img_size is None and hasattr(self.config.vision_config, "image_size"): + img_size = getattr(self.config.vision_config, "image_size") + elif img_size is None: + img_size = 1540 # FIXME based on Mistral3 Image size + logger.warning("Setting img_size to be 1540, as it was neither passed nor found in vision_config") prefill_seq_len = prefill_seq_len if prefill_seq_len else 128 ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN patch_size = self.config.vision_config.patch_size kernel_size = self.config.spatial_merge_size - vision_size = ((height // patch_size) * (width // patch_size)) * (batch_size) // (kernel_size * kernel_size) + vision_size = ((img_size // patch_size) * (img_size // patch_size)) * (batch_size) // (kernel_size * kernel_size) vision = [ { "batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, - "height": height, - "width": width, + "image_size": img_size, "vision_size": vision_size, } ] @@ -265,16 +264,14 @@ def get_specializations( "batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, - "height": height, - "width": width, + "image_size": img_size, "vision_size": vision_size, }, { "batch_size": batch_size, "seq_len": "1", "ctx_len": ctx_len, - "height": height, - "width": width, + "image_size": img_size, "vision_size": vision_size, }, ] @@ -296,7 +293,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): num_layers = self.config.text_config.num_hidden_layers vision_dynamic_axes = { - "pixel_values": {0: "batch_size", 2: "height", 3: "width"}, + "pixel_values": {0: "batch_size", 2: "image_size", 3: "image_size"}, } lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, @@ -341,5 +338,5 @@ def get_inputs_info(self): return [ IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")), IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")), - IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "height", "width")), + IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "image_size", "image_size")), ] From 7310d9268c7bbf2747c0dc947e3b37bec9e7c258 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Tue, 8 Jul 2025 08:51:24 
+0000 Subject: [PATCH 7/8] Revert "Minor Changes" This reverts commit fed3a56b6231f38d64f564a75ea21085c8283018. Signed-off-by: Mohit Soni --- .../models/mistral3/modeling_mistral3.py | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index 96fc4b653..a3ce9d38f 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -16,7 +16,7 @@ from QEfficient.utils import constants from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config -from QEfficient.utils.logging_utils import logger + def custom_cumsum(tensor): dim = 0 @@ -238,24 +238,25 @@ def get_specializations( kv_offload: bool = False, **compiler_options, ): - - if img_size is None and hasattr(self.config.vision_config, "image_size"): - img_size = getattr(self.config.vision_config, "image_size") - elif img_size is None: - img_size = 1540 # FIXME based on Mistral3 Image size - logger.warning("Setting img_size to be 1540, as it was neither passed nor found in vision_config") + height = compiler_options.pop("height", None) + width = compiler_options.pop("width", None) + if height is None: + height = self.config.vision_config.image_size + if width is None: + width = self.config.vision_config.image_size prefill_seq_len = prefill_seq_len if prefill_seq_len else 128 ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN patch_size = self.config.vision_config.patch_size kernel_size = self.config.spatial_merge_size - vision_size = ((img_size // patch_size) * (img_size // patch_size)) * (batch_size) // (kernel_size * kernel_size) + vision_size = ((height // patch_size) * (width // patch_size)) * (batch_size) // (kernel_size * kernel_size) vision = [ { "batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, - "image_size": img_size, + "height": height, + "width": width, "vision_size": vision_size, } ] @@ -264,14 +265,16 @@ def get_specializations( "batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, - "image_size": img_size, + "height": height, + "width": width, "vision_size": vision_size, }, { "batch_size": batch_size, "seq_len": "1", "ctx_len": ctx_len, - "image_size": img_size, + "height": height, + "width": width, "vision_size": vision_size, }, ] @@ -293,7 +296,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): num_layers = self.config.text_config.num_hidden_layers vision_dynamic_axes = { - "pixel_values": {0: "batch_size", 2: "image_size", 3: "image_size"}, + "pixel_values": {0: "batch_size", 2: "height", 3: "width"}, } lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, @@ -338,5 +341,5 @@ def get_inputs_info(self): return [ IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")), IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")), - IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "image_size", "image_size")), + IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "height", "width")), ] From 2986422d209d990dbec9a827066118028253ed24 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Tue, 8 Jul 2025 08:57:44 +0000 Subject: [PATCH 8/8] Minor Fixes Signed-off-by: Mohit Soni --- .../models/mistral3/modeling_mistral3.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git 
a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index a3ce9d38f..e28869cf5 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -16,6 +16,7 @@ from QEfficient.utils import constants from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config +from QEfficient.utils.logging_utils import logger def custom_cumsum(tensor): @@ -238,25 +239,25 @@ def get_specializations( kv_offload: bool = False, **compiler_options, ): - height = compiler_options.pop("height", None) - width = compiler_options.pop("width", None) - if height is None: - height = self.config.vision_config.image_size - if width is None: - width = self.config.vision_config.image_size + if img_size is None and hasattr(self.config.vision_config, "image_size"): + img_size = getattr(self.config.vision_config, "image_size") + elif img_size is None: + img_size = 1540 # FIXME based on mistral3 Image size + logger.warning("Setting img_size to be 1540, as it was neither passed nor found in vision_config") prefill_seq_len = prefill_seq_len if prefill_seq_len else 128 ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN patch_size = self.config.vision_config.patch_size kernel_size = self.config.spatial_merge_size - vision_size = ((height // patch_size) * (width // patch_size)) * (batch_size) // (kernel_size * kernel_size) + vision_size = ( + ((img_size // patch_size) * (img_size // patch_size)) * (batch_size) // (kernel_size * kernel_size) + ) vision = [ { "batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, - "height": height, - "width": width, + "image_size": img_size, "vision_size": vision_size, } ] @@ -265,16 +266,14 @@ def get_specializations( "batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, - "height": height, - "width": width, + "image_size": img_size, "vision_size": vision_size, }, { "batch_size": batch_size, "seq_len": "1", "ctx_len": ctx_len, - "height": height, - "width": width, + "image_size": img_size, "vision_size": vision_size, }, ] @@ -296,7 +295,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): num_layers = self.config.text_config.num_hidden_layers vision_dynamic_axes = { - "pixel_values": {0: "batch_size", 2: "height", 3: "width"}, + "pixel_values": {0: "batch_size", 2: "image_size", 3: "image_size"}, } lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, @@ -341,5 +340,5 @@ def get_inputs_info(self): return [ IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")), IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")), - IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "height", "width")), + IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "image_size", "image_size")), ]
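Reviewer notes (illustrative sketches, outside the patches above):

The scatter of vision features into the text embedding stream, used by QEFFMistral3DecoderWrapper and the non-offload forward, can be exercised in isolation. The sketch below is a minimal standalone illustration with toy shapes and a made-up image_token_index; it mirrors the mask/cumsum/gather/where pattern and the image_idx bookkeeping from the patches rather than calling the QEfficient classes themselves.

import torch

# Toy dimensions (illustrative only, not the real Mistral3 config values)
batch, seq_len, hidden = 1, 8, 4
image_token_index = 10          # hypothetical id of the image placeholder token
num_image_tokens = 3

input_ids = torch.tensor([[1, 10, 10, 10, 2, 3, 4, 5]])           # three image placeholders
inputs_embeds = torch.zeros(batch, seq_len, hidden)                # stand-in for text embeddings
vision_embeds = torch.arange(num_image_tokens * hidden, dtype=torch.float32).view(num_image_tokens, hidden)
image_idx = torch.zeros((1, 1), dtype=torch.int64)                 # running offset across prefill chunks

mask = input_ids == image_token_index                              # where the image placeholders sit
indices1 = mask.to(torch.int64).cumsum(1) - 1                      # 0..N-1 at image positions, -1 before the first one
indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
indices0 = torch.arange(mask.shape[0]).view(-1, 1)
image_features_expanded = vision_embeds.unsqueeze(0)[indices0, indices1]
merged = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)

# image_idx advances so the next prefill chunk continues from the next unused vision row
next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
image_idx = torch.where(image_idx < next_idx, next_idx, image_idx)

print(merged[0, :5, 0])   # positions 1..3 now carry vision rows 0..2 -> tensor([0., 0., 4., 8., 0.])
print(image_idx)          # tensor([[3]])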
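The vision_size specialization introduced in the patches is just the number of merged image tokens the multi-modal projector emits per batch. The arithmetic below mirrors the formula from get_specializations; the concrete numbers (image_size 1540 as used in examples/mistral3_example.py, patch_size 14, spatial_merge_size 2) are assumed defaults and should be checked against the checkpoint's actual vision_config.

# Illustrative numbers; confirm against the model's vision_config before relying on them.
image_size = 1540          # height == width after the resize in examples/mistral3_example.py
patch_size = 14            # assumed patch size of the Pixtral vision tower
spatial_merge_size = 2     # assumed merge kernel of the projector
batch_size = 1

patches_per_side = image_size // patch_size                                        # 110
vision_size = (patches_per_side * patches_per_side) * batch_size // (spatial_merge_size ** 2)
print(vision_size)                                                                 # 3025 merged image tokens per image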
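qeff_generate_block_attention_mask builds a block-diagonal additive mask so patches belonging to different images cannot attend to each other, and custom_cumsum is an unrolled stand-in for torch.cumsum along dim 0 (presumably for export friendliness; the patches do not state the motivation). The standalone check below copies custom_cumsum from the patch and reproduces the block structure on two toy "images" of 3 and 2 patches.

import torch

def custom_cumsum(tensor):
    # Copied from the patch: cumulative sum along dim 0 without calling torch.cumsum
    dim = 0
    result = torch.zeros_like(tensor)
    indices = [slice(None)] * tensor.dim()
    for i in range(tensor.size(dim)):
        indices[dim] = slice(0, i + 1)
        result.select(dim, i).copy_(tensor[tuple(indices)].sum(dim))
    return result

t = torch.tensor([3, 2, 4])
assert torch.equal(custom_cumsum(t), torch.cumsum(t, 0))                 # tensor([3, 5, 9])

patch_counts = [3, 2]                                                    # patches per toy image
block_end = custom_cumsum(torch.tensor(patch_counts))                    # [3, 5]
block_start = custom_cumsum(torch.tensor([0] + patch_counts[:-1]))       # [0, 3]

seq_len = sum(patch_counts)
mask = torch.full((seq_len, seq_len), float("-inf"))
for start, end in zip(block_start.tolist(), block_end.tolist()):
    mask[start:end, start:end] = 0
print(mask)   # zeros inside the 3x3 and 2x2 blocks, -inf everywhere else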