
Commit baa60de

fix finding bucket for context length (#2118)
- Works with HabanaAI/vllm-hpu-extension#385 to enable a padding ratio limit for context-length bucketing, reducing the number of buckets.
- Truncate the context length based on the bucketing in the APC block manager.
- Add an assertion that `max_num_prefill_seqs == 1` when APC is enabled.

Signed-off-by: Youlei Yang <[email protected]>
1 parent 3bd00ce commit baa60de

3 files changed (+39, -5 lines)
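The padding ratio limit itself lives in HabanaAI/vllm-hpu-extension#385 and is not part of this diff. As a rough sketch of the idea only: if consecutive context-length buckets are allowed to grow by at most a configured ratio, padding a context up to the next bucket wastes a bounded fraction of it while the total bucket count stays small. The names below (make_ctx_block_buckets, max_padding_ratio) are made up for illustration and are not the extension's API.

import math

def make_ctx_block_buckets(max_blocks: int,
                           max_padding_ratio: float = 0.25) -> list[int]:
    # Illustrative only: consecutive buckets grow by at most
    # max_padding_ratio, so padding a context length up to the next
    # bucket wastes a bounded fraction of it, while the bucket count
    # stays roughly logarithmic in max_blocks.
    buckets = [1]
    while buckets[-1] < max_blocks:
        nxt = max(buckets[-1] + 1,
                  math.floor(buckets[-1] * (1 + max_padding_ratio)))
        buckets.append(min(nxt, max_blocks))
    return buckets

# 128 possible context-block counts collapse to 22 buckets at a 25% limit.
print(len(make_ctx_block_buckets(128)), make_ctx_block_buckets(128)[:8])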

vllm/config.py

Lines changed: 8 additions & 0 deletions

@@ -4586,6 +4586,14 @@ def __post_init__(self):
                 "but the scheduler is configured to publish them."
                 "Modify KVEventsConfig.enable_kv_cache_events"
                 "to True to enable.")
+        if (current_platform.is_hpu()
+                and self.cache_config.enable_prefix_caching
+                and self.scheduler_config.max_num_prefill_seqs is not None
+                and self.scheduler_config.max_num_prefill_seqs > 1):
+            logger.warning(
+                "Prefix caching with bs > 1 is not supported on HPU."
+                " Setting max_num_prefill_seqs to 1.")
+            self.scheduler_config.max_num_prefill_seqs = 1
         current_platform.check_and_update_config(self)
 
         if not self.instance_id:
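For readability, the new guard can also be read as a standalone rule; the sketch below mirrors the added __post_init__ branch (the helper name clamp_prefill_seqs_for_apc is made up for illustration).

import logging

logger = logging.getLogger("vllm.config")

def clamp_prefill_seqs_for_apc(is_hpu: bool,
                               enable_prefix_caching: bool,
                               max_num_prefill_seqs: int | None) -> int | None:
    # Mirrors the added __post_init__ branch: with automatic prefix
    # caching on HPU, a prefill batch size above 1 is clamped to 1.
    if (is_hpu and enable_prefix_caching
            and max_num_prefill_seqs is not None
            and max_num_prefill_seqs > 1):
        logger.warning("Prefix caching with bs > 1 is not supported on HPU."
                       " Setting max_num_prefill_seqs to 1.")
        return 1
    return max_num_prefill_seqs

assert clamp_prefill_seqs_for_apc(True, True, 4) == 1
assert clamp_prefill_seqs_for_apc(True, False, 4) == 4
assert clamp_prefill_seqs_for_apc(False, True, 4) == 4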

vllm/core/block/prefix_caching_block.py

Lines changed: 14 additions & 0 deletions

@@ -15,6 +15,7 @@
                                     NaiveBlockAllocator)
 from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.sequence import Sequence
 
 PrefixHash = int

@@ -1075,8 +1076,21 @@ def get_num_cached_tokens(self, seq: Sequence) -> int:
         # This is O(logN), where N is the number of blocks.
         num_cached_blocks = len(
             self._allocator.find_cached_blocks_prefix(block_hashes))
+        if current_platform.is_hpu(
+        ) and num_cached_blocks > 0 and seq.is_prefill():
+            from vllm_hpu_extension.bucketing.common import (
+                get_bucketing_manager)
+            hpu_bucketing_manager = get_bucketing_manager()
+            seq_len = seq.get_len() - num_cached_blocks * self._block_size
+            _, _, bkt_cached_blocks = hpu_bucketing_manager.find_prompt_bucket(
+                1, seq_len, num_cached_blocks, False)
+            logger.debug("HPU bucketing adjusted cached blocks from %d to %d",
+                         num_cached_blocks, bkt_cached_blocks)
+            num_cached_blocks = bkt_cached_blocks
         num_cached_tokens = num_cached_blocks * self._block_size
         self._seq_id_to_num_tokens_computed[seq.seq_id] = num_cached_tokens
+        self._seq_id_to_blocks_hashes[
+            seq.seq_id] = block_hashes[:num_cached_blocks]
         return num_cached_tokens
 
     def remove_seq(self, seq_id: int) -> None:
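The net effect of the HPU branch in get_num_cached_tokens is that the reported prefix-cache hit is adjusted to a context bucket before being converted to tokens, so the cached-token count matches what the bucketed HPU prompt shape will actually reuse. The toy sketch below uses a made-up round-down rule over a fixed bucket list in place of the real vllm_hpu_extension find_prompt_bucket call.

def adjusted_cached_tokens(num_cached_blocks: int,
                           block_size: int,
                           ctx_buckets: list[int]) -> int:
    # Toy stand-in for find_prompt_bucket: truncate the cached-block
    # count to the largest context bucket that does not exceed it, so
    # the reported cached tokens line up with the bucketed context length.
    bkt_cached_blocks = max(
        (b for b in ctx_buckets if b <= num_cached_blocks), default=0)
    return bkt_cached_blocks * block_size

# 11 cached blocks with buckets {4, 8, 16} and block_size=128:
# only 8 blocks (1024 tokens) count as cached, the rest are recomputed.
print(adjusted_cached_tokens(11, 128, [4, 8, 16]))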

vllm/worker/hpu_model_runner.py

Lines changed: 17 additions & 5 deletions

@@ -1963,6 +1963,9 @@ def _prepare_prompt(
 
         if any(context_lens):
             assert not self.scheduler_config.chunked_prefill_enabled
+            assert self.scheduler_config.max_num_prefill_seqs == 1
+            assert bs == 1, (
+                "Prefix caching with multiple sequences is not supported yet.")
             # prefix caching
 
             max_num_block = max(len(bt) for bt in prefix_block_tables)

@@ -2836,9 +2839,8 @@ def prepare_model_input_align_worker(
         """
         with self.profiler.record_event('internal', 'prepare_input_tensors'):
             assert seq_group_metadata_list is not None
-            if self.profiler.enabled:
-                self.profiler_counter_helper.capture_seq_group_metadata_stats(
-                    seq_group_metadata_list=seq_group_metadata_list)
+            self.profiler_counter_helper.capture_seq_group_metadata_stats(
+                seq_group_metadata_list=seq_group_metadata_list)
             model_input, sampling_metadata = self.prepare_input_tensors(
                 seq_group_metadata_list, finished_requests_ids, align_worker)
             assert model_input.attn_metadata is not None

@@ -4055,7 +4057,7 @@ def execute_model(
         warmup_mode=False,
         previous_hidden_states: Optional[torch.Tensor] = None,
         seqs=None,
-        ctx_blocks: int = 1,
+        ctx_blocks: int = 0,
         is_dummy_run: bool = False,
         is_pt_profiler_run: bool = False,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:

@@ -4144,6 +4146,9 @@ def execute_model(
             if not warmup_mode:
                 ctx_blocks = seq_len
             seq_len = 1
+        elif attn_metadata.block_list is not None:
+            if not warmup_mode:
+                ctx_blocks = attn_metadata.block_list.shape[-1]
 
         if self._is_fla_model():
             use_graphs = not is_prompt

@@ -4289,8 +4294,15 @@ def try_revert_dummy_output_tokens():
                     attn_metadata,
                     kv_caches=kv_caches
                 )
+                real_seq_lens = model_input.seq_lens
+                real_seq_lens = real_seq_lens if real_seq_lens else \
+                    self.profiler_counter_helper.real_seq_lens
+                real_query_lens = model_input.query_lens
+                real_query_lens = real_query_lens if real_query_lens else \
+                    self.profiler_counter_helper.prompt_seq_lens
                 profiler_args = {
-                    'real_seq_len': model_input.seq_lens,
+                    'real_seq_lens': real_seq_lens,
+                    'real_query_lens': real_query_lens,
                     'real_batch_size': real_batch_size
                 }
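The profiler-args change is essentially a fallback: prefer the sequence and query lengths recorded on model_input and only use the profiler counter helper's captured values when they are missing. A minimal sketch with stand-in values (resolve_lens is made up for illustration):

from typing import Optional

def resolve_lens(model_input_lens: Optional[list],
                 helper_lens: list) -> list:
    # Same pattern as the diff: use model_input's lengths when present,
    # otherwise the profiler counter helper's captured values.
    return model_input_lens if model_input_lens else helper_lens

profiler_args = {
    'real_seq_lens': resolve_lens(None, [512, 384]),        # falls back
    'real_query_lens': resolve_lens([32, 32], [512, 384]),  # from model_input
    'real_batch_size': 2,
}
print(profiler_args)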
