Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions examples/offline_inference/basic/chunked_prefill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os
# Environment flags for this example. They are applied at module import time,
# i.e. before the vLLM import that happens later in this script.
os.environ.update({
    "VLLM_SKIP_WARMUP": "true",
    "VLLM_CONTIGUOUS_PA": "false",
    "VLLM_MLA_DISABLE_REQUANTIZATION": "1",
    "PT_HPU_ENABLE_LAZY_COLLECTIVES": "true",
})
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the env vars below for the aice/v1.22.0 branch?
os.environ['VLLM_MLA_DISABLE_REQUANTIZATION']='1'
os.environ['VLLM_MLA_PERFORM_MATRIX_ABSORPTION']='0'
os.environ['VLLM_MTP_PRINT_ACCPET_RATE']='0'

# Remaining environment flags for this example, applied in one call.
os.environ.update({
    "PT_HPU_WEIGHT_SHARING": "0",
    "VLLM_MLA_PERFORM_MATRIX_ABSORPTION": "0",
    "VLLM_MTP_PRINT_ACCPET_RATE": "0",
    "PT_HPU_LAZY_MODE": "1",
    "VLLM_DELAYED_SAMPLING": "false",
})
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does chunked prefill conflict with delayed sampling?

#os.environ['VLLM_USE_V1']='1'


if __name__ == "__main__":
    # vLLM is imported inside the guard rather than at module top level;
    # presumably so that environment variables configured earlier in the
    # script are already in place when vLLM loads -- confirm before moving.
    from vllm import LLM, SamplingParams

    # Prompts fed to the engine in a single generate() call.
    example_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Sampling configuration shared by every prompt.
    decoding_params = SamplingParams(temperature=0.0, max_tokens=128)

    # Engine configuration, collected in one place. Chunked prefill is
    # switched on via enable_chunked_prefill, with max_num_batched_tokens
    # set to 128.
    engine_kwargs = dict(
        model="/home/HF_models/llama-3-8b",
        trust_remote_code=True,
        enforce_eager=True,
        dtype="bfloat16",
        use_v2_block_manager=True,
        tensor_parallel_size=1,
        max_model_len=1024,
        num_scheduler_steps=1,
        gpu_memory_utilization=0.5,
        max_num_seqs=128,
        enable_chunked_prefill=True,
        max_num_batched_tokens=128,
        seed=2024,
    )
    engine = LLM(**engine_kwargs)

    # generate() returns one RequestOutput per prompt; each carries the
    # prompt, the generated text, and other information.
    for request_output in engine.generate(example_prompts, decoding_params):
        prompt_text = request_output.prompt
        completion = request_output.outputs[0].text
        print(f"Prompt: {prompt_text!r}, Generated text: {completion!r}")

228 changes: 227 additions & 1 deletion vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import habana_frameworks.torch as htorch
import torch
import vllm.envs as envs
import vllm_hpu_extension.kernels as kernels
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.runtime import get_config
Expand Down Expand Up @@ -146,7 +148,21 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
conv_state_indices: Optional[torch.Tensor] = None
mamba_cache_decode_indices: Optional[torch.Tensor] = None
mamba_cache_prefill_indices: Optional[torch.Tensor] = None

decode_slot_mapping: Optional[torch.Tensor] = None
decode_block_list: Optional[torch.Tensor] = None
decode_attn_bias: Optional[torch.Tensor] = None
chunk_prefill_enabled: bool = False

# eq=False keeps the identity-based __eq__/__hash__ of the original plain
# class; field-wise equality is not meaningful for tensor-valued fields.
@dataclass(eq=False)
class HPUAttentionData:
    """Container for the tensors and shape info of one attention pass.

    Instances are created empty and then filled in attribute by attribute.
    Every tensor field defaults to ``None`` and every size field to ``0``.
    """

    # Query/key/value tensors for the current pass.
    query: Optional[torch.Tensor] = None
    key: Optional[torch.Tensor] = None
    value: Optional[torch.Tensor] = None
    # Key/value cache tensors; they stay None when no KV cache is supplied.
    key_cache: Optional[torch.Tensor] = None
    value_cache: Optional[torch.Tensor] = None
    # Dimensions of the query.
    batch_size: int = 0
    seq_len: int = 0
    hidden_size: int = 0
    # Sequence length of the key tensor.
    seq_len_kv: int = 0

@dataclass
class HPUMLAMetadata(HPUAttentionMetadata, AttentionMetadata):
Expand Down Expand Up @@ -461,6 +477,205 @@ def _maybe_init_alibi_biases(
dtype=self.alibi_slopes.dtype,
)

def preprocess_forward(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
is_prefill: bool) -> HPUAttentionData:
attn_data: HPUAttentionData = HPUAttentionData()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to add a description of the preprocess_forward API, including its purpose, arguments, and return values.

seq_len = 1
slot_mapping = attn_metadata.decode_slot_mapping.flatten(
) if attn_metadata.decode_slot_mapping is not None else None
batch_size = attn_metadata.num_decode_tokens
if is_prefill:
seq_len = attn_metadata.num_prefill_tokens //\
attn_metadata.num_prefills
slot_mapping = attn_metadata.slot_mapping.flatten(
) if attn_metadata.slot_mapping is not None else None
batch_size = attn_metadata.num_prefills
# Convert flat inputs into 3D inputs: [batch_size, seq_len, hidden_size]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong comment? It should say 3D inputs, i.e. [batch_size, seq_len, hidden_size].

hidden_size = query.shape[-1]
query = query.reshape(batch_size, seq_len, hidden_size)

hidden_size = key.shape[-1]
key = key.reshape(batch_size, seq_len, hidden_size)

hidden_size = value.shape[-1]
value = value.reshape(batch_size, seq_len, hidden_size)

# Insert key and value to kv cache
attn_data.batch_size, attn_data.seq_len, attn_data.hidden_size\
= query.shape
_, attn_data.seq_len_kv, _ = key.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.

attn_data.key_cache = self.k_cache(key,
key_cache,
slot_mapping)
attn_data.value_cache = self.v_cache(value,
value_cache,
slot_mapping)
attn_data.key = key
attn_data.value = value
attn_data.query = query
return attn_data

def forward_chunked_prefill(
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does the chunking happen? I don't see a for loop in this function.

self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass for chunked prefill: prompt attention plus paged decode.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wrong comment?


Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0

if (self.attn_type != AttentionType.DECODER):
raise NotImplementedError("Chunked Prefill Enabled"
"only for Decoder")
prompt_output: torch.Tensor = None
decode_output: torch.Tensor = None
prefill_batch_size = 0
prefill_seq_len = 0
prefill_hidden_size = 0
decode_batch_size = 0
decode_seq_len = 0
decode_hidden_size = 0
if attn_metadata.num_prefills > 0:
attn_data = self.preprocess_forward(
query[:attn_metadata.num_prefill_tokens],
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to put all the tokens into the KV cache, or are only the chunked tokens enough?

key[:attn_metadata.num_prefill_tokens],
value[:attn_metadata.num_prefill_tokens], kv_cache,
attn_metadata, True)
# Prompt run.
prefill_batch_size = attn_data.batch_size
prefill_seq_len = attn_data.seq_len
prefill_hidden_size = attn_data.hidden_size
query_shape = (prefill_batch_size, prefill_seq_len, self.num_heads,
self.head_size)
kv_shape = (prefill_batch_size, attn_data.seq_len_kv,
self.num_kv_heads, self.head_size)

if attn_metadata is None or attn_metadata.block_list is None:

block_list = attn_metadata.block_list if attn_metadata \
and attn_metadata.block_list is not None else None

common_args = self.common_attention_args(block_list, attn_data.key_cache,
attn_data.value_cache,
attn_metadata.block_size)
attn_bias = attn_metadata.attn_bias
position_bias = None

out = ops.prompt_attention(
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please improve the code formatting.

impl=self.prefill_impl,
query=attn_data.query.view(query_shape),
key=attn_data.key.view(kv_shape),
value=attn_data.value.view(kv_shape),
is_causal=True,
position_bias=position_bias,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
**common_args)

else:
# TODO: enable FusedSDPA
block_list = attn_metadata.block_list if attn_metadata \
and attn_metadata.block_list is not None else None

common_args = self.common_attention_args(block_list, attn_data.key_cache,
attn_data.value_cache,
attn_metadata.block_size)
attn_bias = attn_metadata.attn_bias
position_bias = None

if envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
assert prefill_batch_size == 1, (
"Only batch size 1 is supported for chunked prefill "
"with dynamic block list."
)
key_attn = attn_data.key.view(kv_shape)
value_attn = attn_data.value.view(kv_shape)
common_args['need_context'] = True
else:
key_attn = self.k_cache.fetch_from_cache(
attn_data.key_cache.unflatten(0, (-1, attn_metadata.block_size)),
attn_metadata.block_list).view(kv_shape)
value_attn = self.v_cache.fetch_from_cache(
attn_data.value_cache.unflatten(0, (-1, attn_metadata.block_size)),
attn_metadata.block_list).view(kv_shape)

out = ops.prompt_attention(
impl=self.prefill_impl,
query=attn_data.query.view(query_shape),

key=key_attn,
value=value_attn,
is_causal=False,
attn_bias=attn_bias,
position_bias=position_bias,
**common_args)

prompt_output = out.reshape(prefill_batch_size, prefill_seq_len,
prefill_hidden_size)
htorch.core.mark_step()
if attn_metadata.num_decode_tokens > 0:
# Decoding run.
attn_data = self.preprocess_forward(
query[attn_metadata.num_prefill_tokens:],
key[attn_metadata.num_prefill_tokens:],
value[attn_metadata.num_prefill_tokens:], kv_cache,
attn_metadata, False)
decode_batch_size = attn_data.batch_size
decode_seq_len = attn_data.seq_len
decode_hidden_size = attn_data.hidden_size
decode_output = HPUPagedAttention.forward_decode(
query=attn_data.query.view(attn_data.batch_size, attn_data.seq_len, attn_data.hidden_size),
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.decode_attn_bias,
block_groups=attn_metadata.block_groups,
position_bias=None,
**self.common_attention_args(attn_metadata.decode_block_list, attn_data.key_cache,
attn_data.value_cache,
attn_metadata.block_size))
htorch.core.mark_step()
# Reshape the output tensor.
if decode_output is None:
prompt_output = prompt_output.view(
prefill_batch_size , prefill_seq_len, prefill_hidden_size)
return prompt_output
elif prompt_output is None:
return decode_output.view(decode_batch_size * decode_seq_len,
decode_hidden_size)
else:
prompt_output = prompt_output.view(
prefill_batch_size * prefill_seq_len, prefill_hidden_size)
decode_output = decode_output.view(
decode_batch_size * decode_seq_len, decode_hidden_size)
output = torch.cat((prompt_output, decode_output))
htorch.core.mark_step()
return output
def forward(
self,
layer: AttentionLayer,
Expand All @@ -483,6 +698,16 @@ def forward(
shape = [num_tokens, num_heads * head_size]
"""
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
if attn_metadata.chunk_prefill_enabled:
return self.forward_chunked_prefill(
layer=layer,
query=query,
key=key,
value=value,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
output=output,
)
if self.attn_type == AttentionType.ENCODER_DECODER:
return self.forward_encoder_decoder(
query=query,
Expand Down Expand Up @@ -815,3 +1040,4 @@ def _make_decode_alibi_bias(
per_head_bias.mul_(alibi_slopes[None, :, None])

return per_head_bias

2 changes: 2 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,8 @@ def create_engine_config(
if speculative_config is None \
else speculative_config.num_lookahead_slots

if self.enable_chunked_prefill:
self.use_padding_aware_scheduling = False
scheduler_config = SchedulerConfig(
runner_type=model_config.runner_type,
max_num_batched_tokens=self.max_num_batched_tokens,
Expand Down
5 changes: 5 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
VLLM_SLEEP_WHEN_IDLE: bool = False
VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT: bool = False


def get_default_cache_root():
Expand Down Expand Up @@ -872,6 +873,10 @@ def get_vllm_port() -> Optional[int]:
# latency penalty when a request eventually comes.
"VLLM_SLEEP_WHEN_IDLE":
lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))),

# Use chunked prefill with dynamic input shapes for HPU backend.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the meaning of VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT? When should it be set?

"VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT":
lambda: bool(int(os.getenv("VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT", "0"))),
}

# --8<-- [end:env-vars-definition]
Expand Down
Loading