Commit ff0da82

Enable warmup for chunked prefill
Signed-off-by: jkyu <[email protected]>
1 parent: 64fdbc6

File tree:
  vllm/attention/backends/hpu_attn.py
  vllm/worker/hpu_model_runner.py

2 files changed: 204 additions, 14 deletions

vllm/attention/backends/hpu_attn.py

Lines changed: 4 additions & 1 deletion
@@ -611,7 +611,10 @@ def forward_chunked_prefill(
         position_bias = None

         if envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-            assert prefill_batch_size == 1, "Only batch size 1 is supported for chunked prefill with dynamic block list."
+            assert prefill_batch_size == 1, (
+                "Only batch size 1 is supported for chunked prefill with "
+                "dynamic block list."
+            )
             key_attn = attn_data.key.view(kv_shape)
             value_attn = attn_data.value.view(kv_shape)
             common_args['need_context'] = True

vllm/worker/hpu_model_runner.py

Lines changed: 200 additions & 13 deletions
@@ -680,7 +680,9 @@ def _update_metadata_chunked_prefill(self,
                                          attn_metadata.num_prefills)
         attn_bias = None
         if envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-            assert batch_size == 1, "Chunked prefill with dynamic block_list only supports batch_size=1"
+            assert batch_size == 1, (
+                "Chunked prefill with dynamic block_list only supports bs=1"
+            )
         for i in range(batch_size):
             single_attn_bias = self._set_attn_bias_chunked(
                 int(seq_len), context_lens_t[i], query_lens_t[i], device,
@@ -1864,9 +1866,9 @@ def _prepare_prompt(

             computed_block_nums = seq_group_metadata.computed_block_nums
             if (self.scheduler_config is not None
+                    and self.scheduler_config is not None
                     and self.scheduler_config.chunked_prefill_enabled
-                    and not (computed_block_nums is None
-                             or computed_block_nums == [])):
+                    and self.cache_config.enable_prefix_caching):
                 raise RuntimeError(
                     "chunked prefill cannot be used with prefix caching "
                     "now.")
@@ -1896,7 +1898,10 @@ def _prepare_prompt(
                 # Prefill has chunked before.
                 block_table = seq_group_metadata.block_tables[seq_id]
                 if envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-                    assert context_len % self.block_size == 0, "context len must be multiple of block size in dynamic chunked prefill mode"
+                    assert context_len % self.block_size == 0, (
+                        "context len must be multiple of block size in "
+                        "dynamic chunked prefill mode"
+                    )
                     prefix_blocks = context_len // self.block_size
                     prefix_block_tables.append(block_table[:prefix_blocks])
                 else:
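
For intuition, a minimal sketch of the arithmetic behind the dynamic chunked-prefill branch above; the block_size and block-table values are hypothetical, not taken from the diff.

# Illustration only: values are made up.
block_size = 128
context_len = 384                     # previously computed tokens; must be a block multiple
block_table = [7, 12, 3, 9, 21]       # physical block IDs assigned to this sequence

assert context_len % block_size == 0, "context len must be a multiple of block size"
prefix_blocks = context_len // block_size          # -> 3
prefix_block_table = block_table[:prefix_blocks]   # -> [7, 12, 3]
print(prefix_blocks, prefix_block_table)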
@@ -2030,11 +2035,16 @@ def _prepare_prompt(
                                       for _ in range(batch_size_padding))

         real_num_seqs = len(query_lens)
-        if self.scheduler_config.chunked_prefill_enabled and envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-            assert target_query_len <= self.max_num_batched_tokens, f"{target_query_len=} exceeds {self.max_num_batched_tokens=} for chunked prefill"
+        bs = len(seq_group_metadata_list)
+        if (self.scheduler_config.chunked_prefill_enabled
+                and envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT):
+            assert target_query_len <= self.max_num_batched_tokens, (
+                f"{target_query_len=} exceeds "
+                f"{self.max_num_batched_tokens=} for chunked prefill"
+            )
+
             max_prompt_len = self.max_num_batched_tokens
         else:
-            bs = len(seq_group_metadata_list)
             if bs > 1 and self.use_merged_prefill:
                 bs = 1
             max_prompt_len = max(
@@ -2069,7 +2079,9 @@ def _prepare_prompt(
         if self.vllm_config.cache_config.enable_prefix_caching:
             assert self.scheduler_config.max_num_prefill_seqs == 1
             assert bs == 1, (
-                "Prefix caching or chunked prefill with multiple sequences is not supported yet.")
+                "Prefix caching or chunked prefill with multiple sequences "
+                "is not supported yet."
+            )
             # prefix caching or chunked prefill

             max_num_block = max(len(bt) for bt in prefix_block_tables)
@@ -2079,9 +2091,10 @@ def _prepare_prompt(
                 ([_PAD_BLOCK_ID] * (max_num_block - len(bt)))
                 for bt in prefix_block_tables))

-            if self.scheduler_config.chunked_prefill_enabled and not envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-                if max_prompt_len < max_num_block * self.block_size:
-                    max_prompt_len = max_num_block * self.block_size
+            if (self.scheduler_config.chunked_prefill_enabled
+                    and not envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT
+                    and max_prompt_len < max_num_block * self.block_size):
+                max_prompt_len = max_num_block * self.block_size
             pad_len = len(prefix_block_list)
             prefix_block_list = pad_list(prefix_block_list, pad_len,
                                          _PAD_BLOCK_ID)
@@ -2765,6 +2778,16 @@ def prepare_input_tensors(
                 prefill_reqs.append(seq_group_meta)
             else:
                 decode_reqs.append(seq_group_meta)
+        if self.scheduler_config.enable_chunked_prefill and len(decode_reqs) != 0:
+            decode_reqs, real_decode_batch_size, decode_batch_size_padded = (
+                self._add_dummy_seq(decode_reqs, False, align_worker))
+            seq_group_metadata_list=[]
+            if len(prefill_reqs) != 0:
+                for req in prefill_reqs:
+                    seq_group_metadata_list.append(req)
+            for req in decode_reqs:
+                seq_group_metadata_list.append(req)
+            batch_size_padded = len(seq_group_metadata_list)

         # Prepare input tensors.
         (
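
As a rough, self-contained illustration of the batch assembly added above (not the runner's real API): decode requests are first padded to a bucketed batch size with dummy entries, then appended after the prefill requests so a single mixed batch is formed. The padding helper and string "requests" below are stand-ins.

from typing import List, Tuple

def pad_to_bucket(reqs: List[str], buckets: List[int]) -> Tuple[List[str], int, int]:
    # Stand-in for _add_dummy_seq: pad with dummies up to the next bucketed batch size.
    real_bs = len(reqs)
    padded_bs = next(b for b in sorted(buckets) if b >= real_bs)
    return reqs + ["dummy"] * (padded_bs - real_bs), real_bs, padded_bs

prefill_reqs = ["prefill_0"]
decode_reqs = ["decode_0", "decode_1", "decode_2"]

decode_reqs, real_decode_bs, decode_bs_padded = pad_to_bucket(decode_reqs, [1, 2, 4, 8])
seq_group_metadata_list = list(prefill_reqs) + list(decode_reqs)
batch_size_padded = len(seq_group_metadata_list)
print(seq_group_metadata_list)   # ['prefill_0', 'decode_0', 'decode_1', 'decode_2', 'dummy']
print(real_decode_bs, decode_bs_padded, batch_size_padded)   # 3 4 5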
@@ -3571,6 +3594,145 @@ def warmup_scenario(self,
         if not is_dummy_run:
             gc.collect()

+    def warmup_scenario_mix(self,
+                            batch_size,
+                            seq_len,
+                            ctx,
+                            is_prompt,
+                            kv_caches,
+                            is_pt_profiler_run=False,
+                            is_lora_profile_run=False,
+                            temperature=0,
+                            img_args=None,
+                            num_iters=3,
+                            align_worker=False,
+                            is_dummy_run=False) -> None:
+        phase = 'mix'
+        use_graphs = is_dummy_run or self._use_graphs(batch_size, seq_len)
+        buckets = self.bucketing_manager.decode_buckets
+        num_candidates = len(buckets)
+        for idx, (decode_bs, _, decode_ctx) in enumerate(reversed(buckets)):
+            scenario_name = ("warmup_"
+                             f"{phase}_"
+                             f"prefill_bs{batch_size}_"
+                             f"prefill_seq{seq_len}_"
+                             f"prefill_ctx{ctx}_"
+                             f"decode_bs{decode_bs}_"
+                             f"decode_ctx{decode_ctx}_"
+                             f"graphs{'T' if use_graphs else 'F'}")
+
+            self.log_warmup(f"Graph/{'mix'}/{'decode'}", idx, num_candidates, decode_bs, 1,
+                            decode_ctx)
+            dummy_lora_requests: List[LoRARequest] = []
+            dummy_lora_requests_per_seq: List[LoRARequest] = []
+            if self.lora_config and is_lora_profile_run:
+                assert self.lora_manager is not None
+                with self.lora_manager.dummy_lora_cache():
+                    for idx in range(self.lora_config.max_loras):
+                        lora_id = idx + 1
+                        dummy_lora_request = LoRARequest(
+                            lora_name=f"warmup_{lora_id}",
+                            lora_int_id=lora_id,
+                            lora_local_path="/not/a/real/path",
+                        )
+                        self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                         rank=LORA_WARMUP_RANK)
+                        dummy_lora_requests.append(dummy_lora_request)
+                    dummy_lora_requests_per_seq = [
+                        dummy_lora_requests[idx % len(dummy_lora_requests)]
+                        for idx in range(batch_size)
+                    ]
+            self.profiler.start('internal', scenario_name)
+            times = num_iters if use_graphs or is_pt_profiler_run else 1
+            seqs=[]
+            seqs_prefill = self.create_dummy_seq_group_metadata(
+                0,
+                seq_len + ctx * self.block_size,
+                True,
+                lora_request=dummy_lora_requests_per_seq[i]
+                if dummy_lora_requests_per_seq else None,
+                img_args=img_args,
+                temperature=temperature,
+                ctx=ctx)
+
+            seqs.append(seqs_prefill)
+            blocks = [decode_ctx // decode_bs for _ in range(decode_bs)]
+            blocks[0] += decode_ctx % decode_bs
+            for i, b in enumerate(blocks):
+                seqs_decode = self.create_dummy_seq_group_metadata(
+                    i,
+                    b * self.block_size - 1,
+                    False,
+                    lora_request=dummy_lora_requests_per_seq[i]
+                    if dummy_lora_requests_per_seq else None,
+                    temperature=temperature,
+                    ctx=decode_ctx)
+                seqs.append(seqs_decode)
+
+            if not is_dummy_run:
+                torch.hpu.synchronize()
+            profiler = None
+            if is_pt_profiler_run and self.is_driver_worker:
+                profiler = setup_profiler()
+                profiler.start()
+            for time_index in range(times):
+                inputs = self.prepare_model_input_align_worker(
+                    seqs, align_worker=align_worker)
+                # Chendi: Necessary fix for warmup with TP>1
+                if time_index == 0:
+                    if self.is_driver_worker:
+                        broadcast_tensor_dict(
+                            {"input_tokens": inputs.input_tokens}, src=0)
+                    else:
+                        broadcast_tensor_dict(src=0)
+                if self._is_fla_model():
+                    self.add_fla_dummy_data(inputs)
+                if is_prompt or self.is_single_step:
+                    intermediate_tensors = None
+                    if not get_pp_group().is_first_rank:
+                        intermediate_tensors = \
+                            self.model.make_empty_intermediate_tensors(
+                                batch_size=batch_size,
+                                context_size=seq_len if is_prompt else 1,
+                                dtype=self.model_config.dtype,
+                                device=self.device)
+                    self.execute_model(inputs,
+                                       kv_caches,
+                                       intermediate_tensors=intermediate_tensors,
+                                       warmup_mode=True,
+                                       ctx_blocks=ctx,
+                                       is_dummy_run=is_dummy_run,
+                                       is_pt_profiler_run=is_pt_profiler_run)
+                else:  # decode with multi-step
+                    inputs = dataclasses.replace(inputs,
+                                                 is_first_multi_step=True,
+                                                 is_last_step=False)
+                    self.execute_model(inputs,
+                                       kv_caches,
+                                       warmup_mode=True,
+                                       num_steps=2,
+                                       seqs=seqs,
+                                       ctx_blocks=ctx)
+                    inputs = dataclasses.replace(inputs,
+                                                 is_first_multi_step=False,
+                                                 is_last_step=True)
+                    self.execute_model(inputs,
+                                       kv_caches,
+                                       warmup_mode=True,
+                                       num_steps=2,
+                                       seqs=seqs,
+                                       ctx_blocks=ctx)
+                if not is_dummy_run:
+                    torch.hpu.synchronize()
+                if profiler:
+                    profiler.step()
+            if profiler:
+                profiler.stop()
+            self.profiler.end()
+            if not is_dummy_run:
+                gc.collect()
+
+
     def remove_all_loras(self):
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
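
A small standalone sketch (hypothetical numbers) of how warmup_scenario_mix above spreads a decode-context bucket across the decode batch and sizes the dummy decode sequences; the remainder is folded into the first sequence.

decode_bs, decode_ctx, block_size = 4, 10, 128   # hypothetical bucket values

blocks = [decode_ctx // decode_bs for _ in range(decode_bs)]   # [2, 2, 2, 2]
blocks[0] += decode_ctx % decode_bs                            # [4, 2, 2, 2]
assert sum(blocks) == decode_ctx

# Each dummy decode sequence is one token short of filling its blocks,
# mirroring b * self.block_size - 1 in the method above.
decode_seq_lens = [b * block_size - 1 for b in blocks]
print(blocks)            # [4, 2, 2, 2]
print(decode_seq_lens)   # [511, 255, 255, 255]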
@@ -3665,6 +3827,30 @@ def warmup_graphs(self,
             total_mem += used_mem
             total_batch_seq += batch_seq

+        if self.scheduler_config.chunked_prefill_enabled and is_prompt:
+            for idx, (batch_size, query_len, ctx) in enumerate(reversed(buckets)):
+                # Graph memory usage is proportional to seq dimension in a batch
+                phase = f"Graph/{'mix'}/{'prompt'}"
+                seq_len = query_len + ctx * self.block_size
+                batch_seq = batch_size * seq_len
+                self.log_warmup(phase, idx, num_candidates, batch_size, query_len,
+                                ctx)
+                with HabanaMemoryProfiler() as mem_prof:
+                    self.warmup_scenario_mix(
+                        batch_size,
+                        query_len,
+                        ctx,
+                        is_prompt,
+                        kv_caches,
+                        temperature=1.0
+                        if batch_size not in warmed_random_sampler_bs else 0,
+                    )
+                warmed_random_sampler_bs.add(batch_size)
+                used_mem = align_workers(mem_prof.consumed_device_memory,
+                                         torch.distributed.ReduceOp.MAX)
+                total_mem += used_mem
+                total_batch_seq += batch_seq
+
         if is_prompt and self.is_mm_run():
             #For multimodal total_batch_seq and total_mem, we store it in the
             #attribute for now.
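
A toy walk-through (assumed prompt buckets) of the bookkeeping in the new warmup_graphs branch: the effective sequence length folds the context blocks back into tokens before computing the batch-times-seq product used for memory accounting.

block_size = 128
prompt_buckets = [(1, 512, 0), (1, 512, 4), (2, 1024, 8)]   # (batch_size, query_len, ctx); assumed values

for batch_size, query_len, ctx in reversed(prompt_buckets):
    seq_len = query_len + ctx * block_size
    batch_seq = batch_size * seq_len
    print(f"bs={batch_size} query_len={query_len} ctx={ctx} "
          f"-> seq_len={seq_len} batch_seq={batch_seq}")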
@@ -4110,9 +4296,10 @@ def _phase(self, attn_metadata):
     def _check_config(self, batch_size, seq_len, ctx, attn_metadata,
                       warmup_mode):
         is_prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
+        is_chunked_prefill = self.vllm_config.scheduler_config.enable_chunked_prefill
         cfg: Optional[tuple] = None
         assert cfg is None, "Configs changed between 2D and 3D"
-        if is_prefix_caching:
+        if is_prefix_caching or is_chunked_prefill:
             phase = self._phase(attn_metadata)
             num_blocks = self._num_blocks(attn_metadata)
             cfg = (batch_size, seq_len, num_blocks, phase)
@@ -4124,7 +4311,7 @@ def _check_config(self, batch_size, seq_len, ctx, attn_metadata,
         if not seen and not warmup_mode:
             logger.warning("Configuration: %s was not warmed-up!",
                            (phase.value, batch_size, seq_len,
-                            num_blocks) if is_prefix_caching else
+                            num_blocks) if is_prefix_caching or is_chunked_prefill else
                            (phase, batch_size, seq_len))

     def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
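
Finally, a simplified sketch of the warmed-up-configuration check that the last two hunks extend to chunked prefill. The seen_configs set and the plain string phase are stand-ins for the runner's internals; the point is only that chunked prefill now uses the 3D key (including num_blocks) and triggers the same warning path as prefix caching.

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)
seen_configs: set = set()

def check_config(batch_size, seq_len, num_blocks, phase,
                 is_prefix_caching, is_chunked_prefill, warmup_mode):
    # Prefix caching and chunked prefill key the bucket on num_blocks as well.
    if is_prefix_caching or is_chunked_prefill:
        cfg = (batch_size, seq_len, num_blocks, phase)
    else:
        cfg = (batch_size, seq_len, phase)
    seen = cfg in seen_configs
    seen_configs.add(cfg)
    if not seen and not warmup_mode:
        logger.warning("Configuration: %s was not warmed-up!", cfg)

check_config(1, 1024, 8, "mix", False, True, warmup_mode=True)    # recorded during warmup
check_config(1, 1024, 8, "mix", False, True, warmup_mode=False)   # already seen: no warning
check_config(2, 2048, 16, "mix", False, True, warmup_mode=False)  # warns: not warmed up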
