Enable chunked prefill on aice 1.22 #2070
base: aice/v1.22.0
Changes from all commits
@@ -0,0 +1,50 @@
import os
os.environ["VLLM_SKIP_WARMUP"] = "true"
os.environ['VLLM_CONTIGUOUS_PA'] = 'false'
os.environ['VLLM_MLA_DISABLE_REQUANTIZATION']='1'
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES']='true'
os.environ['PT_HPU_WEIGHT_SHARING']='0'
os.environ['VLLM_MLA_PERFORM_MATRIX_ABSORPTION']='0'
os.environ['VLLM_MTP_PRINT_ACCPET_RATE']='0'
os.environ['PT_HPU_LAZY_MODE']='1'
os.environ['VLLM_DELAYED_SAMPLING']='false'
Review comment: does chunked prefill conflict with delayed sampling?
#os.environ['VLLM_USE_V1']='1'

if __name__ == "__main__":

    from vllm import LLM, SamplingParams

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=128)

    model_name = "/home/HF_models/llama-3-8b"
    llm = LLM(model=model_name,
              trust_remote_code=True,
              enforce_eager=True,
              dtype="bfloat16",
              use_v2_block_manager=True,
              tensor_parallel_size=1,
              max_model_len=1024,
              num_scheduler_steps=1,
              gpu_memory_utilization=0.5,
              max_num_seqs=128,
              enable_chunked_prefill=True,
              max_num_batched_tokens=128,
              seed=2024)
    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@@ -9,7 +9,9 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import habana_frameworks.torch as htorch
import torch
import vllm.envs as envs
import vllm_hpu_extension.kernels as kernels
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.runtime import get_config

@@ -146,7 +148,21 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
    conv_state_indices: Optional[torch.Tensor] = None
    mamba_cache_decode_indices: Optional[torch.Tensor] = None
    mamba_cache_prefill_indices: Optional[torch.Tensor] = None

    decode_slot_mapping: Optional[torch.Tensor] = None
    decode_block_list: Optional[torch.Tensor] = None
    decode_attn_bias: Optional[torch.Tensor] = None
    chunk_prefill_enabled: bool = False


class HPUAttentionData:
    query: torch.Tensor = None
    key: torch.Tensor = None
    value: torch.Tensor = None
    key_cache: torch.Tensor = None
    value_cache: torch.Tensor = None
    batch_size: int = 0
    seq_len: int = 0
    hidden_size: int = 0
    seq_len_kv: int = 0


@dataclass
class HPUMLAMetadata(HPUAttentionMetadata, AttentionMetadata):

@@ -461,6 +477,205 @@ def _maybe_init_alibi_biases(
                dtype=self.alibi_slopes.dtype,
            )

    def preprocess_forward(self, query: torch.Tensor, key: torch.Tensor,
                           value: torch.Tensor, kv_cache: torch.Tensor,
                           attn_metadata: HPUAttentionMetadata,
                           is_prefill: bool) -> HPUAttentionData:
        attn_data: HPUAttentionData = HPUAttentionData()
Review comment: it would be good to add a description of the preprocess_forward API, including its purpose, arguments, and return values.
        seq_len = 1
        slot_mapping = attn_metadata.decode_slot_mapping.flatten(
        ) if attn_metadata.decode_slot_mapping is not None else None
        batch_size = attn_metadata.num_decode_tokens
        if is_prefill:
            seq_len = attn_metadata.num_prefill_tokens //\
                attn_metadata.num_prefills
            slot_mapping = attn_metadata.slot_mapping.flatten(
            ) if attn_metadata.slot_mapping is not None else None
            batch_size = attn_metadata.num_prefills
        # Convert Flat inputs into 2D Inputs
Review comment: wrong comment? should be 3D input, i.e. [batch_size, seq_len, hidden_size]
        hidden_size = query.shape[-1]
        query = query.reshape(batch_size, seq_len, hidden_size)

        hidden_size = key.shape[-1]
        key = key.reshape(batch_size, seq_len, hidden_size)

        hidden_size = value.shape[-1]
        value = value.reshape(batch_size, seq_len, hidden_size)

        # Insert key and value to kv cache
        attn_data.batch_size, attn_data.seq_len, attn_data.hidden_size\
            = query.shape
        _, attn_data.seq_len_kv, _ = key.shape
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = HPUPagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.

            attn_data.key_cache = self.k_cache(key,
                                               key_cache,
                                               slot_mapping)
            attn_data.value_cache = self.v_cache(value,
                                                 value_cache,
                                                 slot_mapping)
        attn_data.key = key
        attn_data.value = value
        attn_data.query = query
        return attn_data
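As requested in the review comment above, here is a possible description for preprocess_forward. It is a sketch inferred from the diff, not the author's wording; argument and return descriptions are my reading of the code.

# Hypothetical docstring sketch for preprocess_forward (not part of the PR).
def preprocess_forward(self, query, key, value, kv_cache, attn_metadata, is_prefill):
    """Reshape flat token inputs and populate the KV cache for one phase.

    Reshapes the flat [num_tokens, heads * head_size] query/key/value tensors
    into [batch_size, seq_len, hidden_size] views, choosing prefill or decode
    batch/sequence sizes based on ``is_prefill``, writes key/value into
    ``kv_cache`` via the matching slot mapping, and records the shapes.

    Args:
        query/key/value: flat tensors of shape [num_tokens, heads * head_size].
        kv_cache: paged KV cache tensor; may be None during memory profiling.
        attn_metadata: HPUAttentionMetadata carrying prefill/decode token
            counts and slot mappings.
        is_prefill: True to process the prefill token slice, False for decode.

    Returns:
        HPUAttentionData holding the reshaped query/key/value, the key/value
        cache views, and the batch/sequence/hidden dimensions.
    """
    ...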
    def forward_chunked_prefill(
Review comment: where does the chunking happen? I don't see the for loop in this function.
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.
Review comment: wrong comment?

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0

        if (self.attn_type != AttentionType.DECODER):
            raise NotImplementedError("Chunked Prefill Enabled"
                                      "only for Decoder")
        prompt_output: torch.Tensor = None
        decode_output: torch.Tensor = None
        prefill_batch_size = 0
        prefill_seq_len = 0
        prefill_hidden_size = 0
        decode_batch_size = 0
        decode_seq_len = 0
        decode_hidden_size = 0
        if attn_metadata.num_prefills > 0:
            attn_data = self.preprocess_forward(
                query[:attn_metadata.num_prefill_tokens],
Review comment: do we need to put KV cache entries for all the tokens, or are only the chunked tokens enough?
                key[:attn_metadata.num_prefill_tokens],
                value[:attn_metadata.num_prefill_tokens], kv_cache,
                attn_metadata, True)
            # Prompt run.
            prefill_batch_size = attn_data.batch_size
            prefill_seq_len = attn_data.seq_len
            prefill_hidden_size = attn_data.hidden_size
            query_shape = (prefill_batch_size, prefill_seq_len, self.num_heads,
                           self.head_size)
            kv_shape = (prefill_batch_size, attn_data.seq_len_kv,
                        self.num_kv_heads, self.head_size)

            if attn_metadata is None or attn_metadata.block_list is None:

                block_list = attn_metadata.block_list if attn_metadata \
                    and attn_metadata.block_list is not None else None

                common_args = self.common_attention_args(block_list, attn_data.key_cache,
                                                         attn_data.value_cache,
                                                         attn_metadata.block_size)
                attn_bias = attn_metadata.attn_bias
                position_bias = None

                out = ops.prompt_attention(
Review comment: please improve the code format.
                    impl=self.prefill_impl,
                    query=attn_data.query.view(query_shape),
                    key=attn_data.key.view(kv_shape),
                    value=attn_data.value.view(kv_shape),
                    is_causal=True,
                    position_bias=position_bias,
                    valid_seq_lengths=attn_metadata.seq_lens_tensor,
                    **common_args)

            else:
                # TODO: enable FusedSDPA
                block_list = attn_metadata.block_list if attn_metadata \
                    and attn_metadata.block_list is not None else None

                common_args = self.common_attention_args(block_list, attn_data.key_cache,
                                                         attn_data.value_cache,
                                                         attn_metadata.block_size)
                attn_bias = attn_metadata.attn_bias
                position_bias = None

                if envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
                    assert prefill_batch_size == 1, (
                        "Only batch size 1 is supported for chunked prefill "
                        "with dynamic block list."
                    )
                    key_attn = attn_data.key.view(kv_shape)
                    value_attn = attn_data.value.view(kv_shape)
                    common_args['need_context'] = True
                else:
                    key_attn = self.k_cache.fetch_from_cache(
                        attn_data.key_cache.unflatten(0, (-1, attn_metadata.block_size)),
                        attn_metadata.block_list).view(kv_shape)
                    value_attn = self.v_cache.fetch_from_cache(
                        attn_data.value_cache.unflatten(0, (-1, attn_metadata.block_size)),
                        attn_metadata.block_list).view(kv_shape)

                out = ops.prompt_attention(
                    impl=self.prefill_impl,
                    query=attn_data.query.view(query_shape),
                    key=key_attn,
                    value=value_attn,
                    is_causal=False,
                    attn_bias=attn_bias,
                    position_bias=position_bias,
                    **common_args)

            prompt_output = out.reshape(prefill_batch_size, prefill_seq_len,
                                        prefill_hidden_size)
            htorch.core.mark_step()
        if attn_metadata.num_decode_tokens > 0:
            # Decoding run.
            attn_data = self.preprocess_forward(
                query[attn_metadata.num_prefill_tokens:],
                key[attn_metadata.num_prefill_tokens:],
                value[attn_metadata.num_prefill_tokens:], kv_cache,
                attn_metadata, False)
            decode_batch_size = attn_data.batch_size
            decode_seq_len = attn_data.seq_len
            decode_hidden_size = attn_data.hidden_size
            decode_output = HPUPagedAttention.forward_decode(
                query=attn_data.query.view(attn_data.batch_size, attn_data.seq_len, attn_data.hidden_size),
                block_mapping=attn_metadata.block_mapping,
                block_bias=attn_metadata.decode_attn_bias,
                block_groups=attn_metadata.block_groups,
                position_bias=None,
                **self.common_attention_args(attn_metadata.decode_block_list, attn_data.key_cache,
                                             attn_data.value_cache,
                                             attn_metadata.block_size))
            htorch.core.mark_step()
        # Reshape the output tensor.
        if decode_output is None:
            prompt_output = prompt_output.view(
                prefill_batch_size, prefill_seq_len, prefill_hidden_size)
            return prompt_output
        elif prompt_output is None:
            return decode_output.view(decode_batch_size * decode_seq_len,
                                      decode_hidden_size)
        else:
            prompt_output = prompt_output.view(
                prefill_batch_size * prefill_seq_len, prefill_hidden_size)
            decode_output = decode_output.view(
                decode_batch_size * decode_seq_len, decode_hidden_size)
            output = torch.cat((prompt_output, decode_output))
            htorch.core.mark_step()
            return output

    def forward(
        self,
        layer: AttentionLayer,

@@ -483,6 +698,16 @@ def forward(
            shape = [num_tokens, num_heads * head_size]
        """
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        if attn_metadata.chunk_prefill_enabled:
            return self.forward_chunked_prefill(
                layer=layer,
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                attn_metadata=attn_metadata,
                output=output,
            )
        if self.attn_type == AttentionType.ENCODER_DECODER:
            return self.forward_encoder_decoder(
                query=query,

@@ -815,3 +1040,4 @@ def _make_decode_alibi_bias(
    per_head_bias.mul_(alibi_slopes[None, :, None])

    return per_head_bias

@@ -127,6 +127,7 @@
    VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
    VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
    VLLM_SLEEP_WHEN_IDLE: bool = False
    VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT: bool = False

def get_default_cache_root():

@@ -872,6 +873,10 @@ def get_vllm_port() -> Optional[int]:
    # latency penalty when a request eventually comes.
    "VLLM_SLEEP_WHEN_IDLE":
    lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))),

    # Use chunked prefill with dynamic input shapes for HPU backend.
Review comment: what's the meaning of VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT? when should this be set?
    "VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT":
    lambda: bool(int(os.getenv("VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT", "0"))),
}

# --8<-- [end:env-vars-definition]
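For context, a minimal sketch of how this flag might be exercised. This is illustrative only, not from the PR: the model path is a placeholder, and the flag's effect (routing prefill through the dynamic block-list path, which asserts a prefill batch size of 1) is my reading of the diff above rather than documented behavior.

import os
# Assumption: set before vLLM is imported, mirroring the example script in this PR.
os.environ['VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT'] = '1'

from vllm import LLM, SamplingParams

llm = LLM(model="/path/to/model",        # placeholder path
          enable_chunked_prefill=True,
          max_num_batched_tokens=128)    # per-step token budget, as in the example above
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)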
Review comment: do we need the env vars below for the aice/v1.22.0 branch?
os.environ['VLLM_MLA_DISABLE_REQUANTIZATION']='1'
os.environ['VLLM_MLA_PERFORM_MATRIX_ABSORPTION']='0'
os.environ['VLLM_MTP_PRINT_ACCPET_RATE']='0'