Llama4 VLM Continuous Batching Support #510
base: main
@@ -826,7 +826,15 @@ def __init__(self, model):
         self.language_model = self.model.language_model
         self.config = self.model.config

-    def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
+    def forward(
+        self,
+        input_ids,
+        vision_embeds,
+        position_ids,
+        image_idx,
+        past_key_values,
+        batch_index: Optional[torch.LongTensor] = None,
+    ):
         inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids)
         selected = input_ids == self.model.config.image_token_index
         indices1 = selected.to(torch.int64).cumsum(1) - 1
@@ -836,7 +844,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
         image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds)
         outputs = self.model.language_model(
-            inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            batch_index=batch_index,
+            use_cache=True,
         )
         next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
         image_idx = torch.where(image_idx < next_idx, next_idx, image_idx)
@@ -883,6 +895,9 @@ def get_specializations(
         ctx_len: int,
         img_size: int,
         kv_offload: bool = False,
+        continuous_batching: bool = False,
+        kv_cache_batch_size: Optional[int] = None,
+        full_batch_size: Optional[int] = None,
         **compiler_options,
     ):
         max_num_tiles = compiler_options.pop("max_num_tiles", None)
@@ -939,28 +954,42 @@ def get_specializations(
                 "img_size": img_size,
             }
         ]
-        lang = [
-            {
-                "batch_size": batch_size,
-                "seq_len": prefill_seq_len,
-                "ctx_len": ctx_len,
-                "max_num_tiles": max_num_tiles,
-                "img_size": img_size,
-                "vision_size": vision_size,
-                "chunk_length": prefill_seq_len,
-                "chunk_ctx_len": chunk_ctx_len,
-            },
-            {
-                "batch_size": batch_size,
-                "seq_len": "1",
-                "ctx_len": ctx_len,
-                "max_num_tiles": max_num_tiles,
-                "img_size": img_size,
-                "vision_size": vision_size,
-                "chunk_length": prefill_seq_len,
-                "chunk_ctx_len": chunk_ctx_len,
-            },
-        ]
+
+        lang_prefill = {
+            "batch_size": 1 if continuous_batching else batch_size,
+            "seq_len": prefill_seq_len,
+            "ctx_len": ctx_len,
+            "max_num_tiles": max_num_tiles,
+            "img_size": img_size,
+            "vision_size": vision_size,
+            "chunk_length": prefill_seq_len,
+            "chunk_ctx_len": chunk_ctx_len,
+        }
+        if continuous_batching:
+            lang_prefill["full_batch_size"] = kv_cache_batch_size
+        else:
+            lang_prefill["batch_size"] = kv_cache_batch_size
+        if full_batch_size:
+            lang_prefill["full_batch_exec_size"] = full_batch_size
+
+        lang_decode = {
+            "batch_size": full_batch_size if continuous_batching else batch_size,
+            "seq_len": 1,
+            "ctx_len": ctx_len,
+            "max_num_tiles": max_num_tiles,
+            "img_size": img_size,
+            "vision_size": vision_size,
+            "chunk_length": prefill_seq_len,
+            "chunk_ctx_len": chunk_ctx_len,
+        }
+        if continuous_batching:
+            lang_decode["full_batch_size"] = kv_cache_batch_size
+        else:
+            lang_decode["batch_size"] = kv_cache_batch_size
+
+        lang = []

Review comment: nit: better to use lang_specializations = [lang_prefill, lang_decode] for better readability.
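A minimal sketch of that suggestion; lang_specializations is the reviewer's proposed name, not code in this PR:

    # Build the list of language specializations in one literal instead of two append() calls
    lang_specializations = [lang_prefill, lang_decode]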
+        lang.append(lang_prefill)
+        lang.append(lang_decode)

         specializations = {}
@@ -969,18 +998,22 @@ def get_specializations(
             specializations["lang"] = lang
             return specializations, compiler_options
         else:
             lang[0].pop("vision_size")

Review comment: better to use a loop for the pop operation rather than performing it for each index.
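A minimal sketch of the loop the reviewer has in mind, assuming lang still holds both specialization dicts at this point:

    # Drop "vision_size" from every language specialization in one pass
    for spec in lang:
        spec.pop("vision_size")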
             lang[1].pop("vision_size")
             return lang, compiler_options

-    def get_onnx_dynamic_axes(self, kv_offload: bool = False):
+    def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False):
         # Define dynamic axes
         vision_dynamic_axes = {}
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["vision_embeds"] = {0: "vision_size"}
+        if continuous_batching:
+            lang_dynamic_axes["batch_index"] = {0: "batch_size"}
         vision_dynamic_axes["pixel_values"] = {0: "max_num_tiles", 2: "img_size", 3: "img_size"}

-        pkv_dynamic_axes = {0: "batch_size"}
+        pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size"}
         for i in range(self.language_model.config.num_hidden_layers):
             # switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers.
             if int((i + 1) % 4 != 0):
@@ -1043,7 +1076,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len):
             past_key_values.append(pkv)
         return past_key_values

-    def get_dummy_inputs(self, kv_offload: bool = False):
+    def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False):
         if vis_cfg := getattr(self.config, "vision_config", None):
             img_size = getattr(vis_cfg, "image_size", 336)
         else:
@@ -1088,10 +1121,14 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
         )
         lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64)

+        bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+        fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
+
         # Add data for KV
         past_key_values = self.get_dummy_pkv_cache(
             config=self.language_model.config,
-            batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+            batch_size=fbs if continuous_batching else bs,
             seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )
@@ -1100,6 +1137,8 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             for kv in ["key", "value"]:
                 lang_inputs["past_key_values"][i].append(torch.zeros(past_key_values[0][0].shape, dtype=torch.float32))

+        if continuous_batching:
+            lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
         inputs = {}
         if kv_offload:
             inputs["vision"] = vision_inputs
@@ -579,6 +579,7 @@ class _QEffAutoModelForImageTextToTextDualQPC:
     def __init__(
         self,
         model: nn.Module,
+        continuous_batching,
         **kwargs,
     ):
         if kwargs.pop("full_batch_size", None):
@@ -588,6 +589,7 @@ def __init__(
         self.model.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)
         self.vision_model = QEffVisionEncoderForTextImageToTextModel(model)
         self.lang_model = QEffCausalLMForTextImageToTextModel(model)
+        self.continuous_batching = continuous_batching
         self.input_shapes, self.output_names = None, None

     @property
@@ -627,8 +629,8 @@ def export(
         export_dir: Optional[str] = None,
         **kwargs,
     ) -> str:
-        inputs = self.model.get_dummy_inputs(kv_offload=True)
-        dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True)
+        inputs = self.model.get_dummy_inputs(kv_offload=True, continuous_batching=self.continuous_batching)
+        dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True, continuous_batching=self.continuous_batching)
         output_names = self.model.get_output_names(kv_offload=True)
         self.vision_model.export(
             inputs["vision"],
@@ -637,6 +639,9 @@ def export(
             export_dir,
         )

+        import ipdb

Review comment: remove this leftover ipdb debugging code.
+
+        ipdb.set_trace()
         self.lang_model.export(inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir)
         return self.onnx_path
@@ -661,14 +666,20 @@ def compile(
         skip_lang: Optional[bool] = False,
         **compiler_options,
     ) -> str:
-        if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]):
-            raise ValueError(
-                f"Expected 'full_batch_size', 'kv_cache_batch_size', 'num_speculative_tokens' to be None but got: "
-                f"full_batch_size={full_batch_size}, kv_cache_batch_size={kv_cache_batch_size}, num_speculative_tokens={num_speculative_tokens}, "
-            )
-
-        if skip_lang and skip_vision:
-            raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False")
+        if skip_lang and skip_vision:
+            raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False")
+
+        if self.continuous_batching and full_batch_size is None:
+            raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")
+
+        if kv_cache_batch_size and not full_batch_size:
+            raise ValueError(
+                "KV caching requires continuous batching. Please set `full_batch_size` and "
+                "enable `continuous_batching=True` in `from_pretrained`."
+            )
+
+        # Infer kv_cache_batch_size if not provided
+        kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size

         output_names = self.model.get_output_names(kv_offload=True)
@@ -678,6 +689,9 @@ def compile(
             ctx_len=ctx_len,
             img_size=img_size,
             kv_offload=True,
+            continuous_batching=self.continuous_batching,
+            kv_cache_batch_size=kv_cache_batch_size,
+            full_batch_size=full_batch_size,
             **compiler_options,
         )
@@ -746,6 +760,8 @@ def compile(
     def generate(
         self,
         inputs: torch.Tensor,
+        tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer] = None,
+        prompts: List[str] = None,
         streamer: Optional[TextStreamer] = None,
         device_ids: List[int] = None,
         runtime_ai100: bool = True,
@@ -763,6 +779,14 @@ def generate(
         """
         if not runtime_ai100:
             raise NotImplementedError("PyTorch execution is not supported yet for this model!")
+        if tokenizer and prompts:
+            return QEfficient.cloud_ai_100_exec_kv(
+                tokenizer,
+                self.lang_model.qpc_path,
+                prompt=prompts,
+                device_id=device_ids,
+                generation_len=generation_len,
+            )

         return self.kv_offload_generate(
             inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len
@@ -1304,15 +1328,21 @@ class QEFFAutoModelForImageTextToText:

     _hf_auto_class = AutoModelForImageTextToText

-    def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs):
+    def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, continuous_batching: bool = False, **kwargs):
         if kv_offload:
-            return _QEffAutoModelForImageTextToTextDualQPC(model, **kwargs)
+            return _QEffAutoModelForImageTextToTextDualQPC(model, continuous_batching, **kwargs)
         else:
             return _QEFFAutoModelForImageTextToTextSingleQPC(model, **kwargs)

     @classmethod
     @with_replaced_quantizers
-    def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs):
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: str,
+        kv_offload: Optional[bool] = None,
+        continuous_batching: bool = False,
+        **kwargs,
+    ):
         """Used to load models supported by transformers.AutoModelForImageTextToText for Cloud AI 100.

         Args:
@@ -1329,12 +1359,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs):
         if kwargs.get("low_cpu_mem_usage", None):
             logger.warning("Updating low_cpu_mem_usage=False")

-        if kwargs.pop("continuous_batching", None):
-            NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+        if continuous_batching and not kv_offload:
+            NotImplementedError("Continuous batching is not supported for kv_offload = False")

         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
+        return cls(
+            model,
+            kv_offload=kv_offload,
+            continuous_batching=continuous_batching,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            **kwargs,
+        )


 MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {"InternVLChatModel": QEFFAutoModelForImageTextToText}
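A hedged end-to-end sketch of how the new flag would be used, based only on the arguments visible in this diff; the model id, sequence lengths, and generation settings are placeholders, and compile() is assumed to accept the same names that get_specializations() receives above:

    from transformers import AutoTokenizer
    from QEfficient import QEFFAutoModelForImageTextToText

    model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"  # placeholder Llama4 VLM checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # kv_offload=True selects the dual-QPC path, the only one wired for continuous batching in this PR
    model = QEFFAutoModelForImageTextToText.from_pretrained(
        model_id, kv_offload=True, continuous_batching=True
    )

    # full_batch_size is mandatory once continuous_batching=True (see the new check in compile())
    model.compile(
        prefill_seq_len=128,
        ctx_len=4096,
        full_batch_size=4,
        kv_cache_batch_size=4,
    )

    # With tokenizer and prompts supplied, generate() routes through QEfficient.cloud_ai_100_exec_kv;
    # inputs is unused on that path, so None is passed as a placeholder
    model.generate(
        inputs=None,
        tokenizer=tokenizer,
        prompts=["Describe the image in one sentence."],
        generation_len=64,
    )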
Review comment: nit: Can we update prefill and decode together on line 985? Do we need two separate if conditions?

Author reply: Yup, not needed. I will do both in one if condition.
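A minimal sketch of that follow-up, with both specializations updated in a single branch (names taken from the diff; the exact surrounding code at line 985 is assumed):

    # One if/else updates prefill and decode together instead of repeating it per dict
    if continuous_batching:
        lang_prefill["full_batch_size"] = kv_cache_batch_size
        lang_decode["full_batch_size"] = kv_cache_batch_size
    else:
        lang_prefill["batch_size"] = kv_cache_batch_size
        lang_decode["batch_size"] = kv_cache_batch_size
    if full_batch_size:
        lang_prefill["full_batch_exec_size"] = full_batch_size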