@@ -680,7 +680,9 @@ def _update_metadata_chunked_prefill(self,
                 attn_metadata.num_prefills)
         attn_bias = None
         if envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-            assert batch_size == 1, "Chunked prefill with dynamic block_list only supports batch_size=1"
+            assert batch_size == 1, (
+                "Chunked prefill with dynamic block_list only supports bs=1"
+            )
         for i in range(batch_size):
             single_attn_bias = self._set_attn_bias_chunked(
                 int(seq_len), context_lens_t[i], query_lens_t[i], device,
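
For context on the hunk above: under VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT each sequence gets a chunked-prefill attention bias in which the new query chunk sees the whole cached context and is causal within itself. A simplified standalone sketch of that mask shape follows; it is not the actual _set_attn_bias_chunked implementation, and the dtype handling is an assumption.

import torch

def chunked_prefill_bias(context_len: int, query_len: int,
                         dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    total_len = context_len + query_len
    # Query tokens may attend to every cached context token...
    allowed = torch.zeros(query_len, total_len, dtype=torch.bool)
    allowed[:, :context_len] = True
    # ...and causally to the tokens of their own chunk.
    allowed[:, context_len:] = torch.tril(torch.ones(query_len, query_len)).bool()
    bias = torch.zeros(query_len, total_len, dtype=dtype)
    return bias.masked_fill(~allowed, float('-inf'))

# chunked_prefill_bias(context_len=4, query_len=3) -> shape (3, 7)
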
@@ -1864,9 +1866,9 @@ def _prepare_prompt(
 
             computed_block_nums = seq_group_metadata.computed_block_nums
             if (self.scheduler_config is not None
+                    and self.scheduler_config is not None
                     and self.scheduler_config.chunked_prefill_enabled
-                    and not (computed_block_nums is None
-                             or computed_block_nums == [])):
+                    and self.cache_config.enable_prefix_caching):
                 raise RuntimeError(
                     "chunked prefill cannot be used with prefix caching "
                     "now.")
@@ -1896,7 +1898,10 @@ def _prepare_prompt(
                 # Prefill has chunked before.
                 block_table = seq_group_metadata.block_tables[seq_id]
                 if envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-                    assert context_len % self.block_size == 0, "context len must be multiple of block size in dynamic chunked prefill mode"
+                    assert context_len % self.block_size == 0, (
+                        "context len must be multiple of block size in "
+                        "dynamic chunked prefill mode"
+                    )
                     prefix_blocks = context_len // self.block_size
                     prefix_block_tables.append(block_table[:prefix_blocks])
                 else:
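
The assertion above is what lets the dynamic path reuse whole KV-cache blocks: when context_len divides evenly by block_size, the cached prefix maps onto a prefix of the block table. A minimal sketch with hypothetical values:

block_size = 128
context_len = 512                        # already-computed prefix; must divide evenly
block_table = [7, 12, 3, 9, 21, 30]      # hypothetical physical block ids

assert context_len % block_size == 0
prefix_blocks = context_len // block_size         # 4
prefix_block_table = block_table[:prefix_blocks]  # [7, 12, 3, 9]
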
@@ -2030,11 +2035,16 @@ def _prepare_prompt(
                                    for _ in range(batch_size_padding))
 
         real_num_seqs = len(query_lens)
-        if self.scheduler_config.chunked_prefill_enabled and envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-            assert target_query_len <= self.max_num_batched_tokens, f"{target_query_len=} exceeds {self.max_num_batched_tokens=} for chunked prefill"
+        bs = len(seq_group_metadata_list)
+        if (self.scheduler_config.chunked_prefill_enabled
+                and envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT):
+            assert target_query_len <= self.max_num_batched_tokens, (
+                f"{target_query_len=} exceeds "
+                f"{self.max_num_batched_tokens=} for chunked prefill"
+            )
+
             max_prompt_len = self.max_num_batched_tokens
         else:
-            bs = len(seq_group_metadata_list)
             if bs > 1 and self.use_merged_prefill:
                 bs = 1
             max_prompt_len = max(
@@ -2069,7 +2079,9 @@ def _prepare_prompt(
         if self.vllm_config.cache_config.enable_prefix_caching:
             assert self.scheduler_config.max_num_prefill_seqs == 1
             assert bs == 1, (
-                "Prefix caching or chunked prefill with multiple sequences is not supported yet.")
+                "Prefix caching or chunked prefill with multiple sequences "
+                "is not supported yet."
+            )
             # prefix caching or chunked prefill
 
             max_num_block = max(len(bt) for bt in prefix_block_tables)
@@ -2079,9 +2091,10 @@ def _prepare_prompt(
                                   ([_PAD_BLOCK_ID] * (max_num_block - len(bt)))
                                   for bt in prefix_block_tables))
 
-            if self.scheduler_config.chunked_prefill_enabled and not envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-                if max_prompt_len < max_num_block * self.block_size:
-                    max_prompt_len = max_num_block * self.block_size
+            if (self.scheduler_config.chunked_prefill_enabled
+                    and not envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT
+                    and max_prompt_len < max_num_block * self.block_size):
+                max_prompt_len = max_num_block * self.block_size
             pad_len = len(prefix_block_list)
             prefix_block_list = pad_list(prefix_block_list, pad_len,
                                          _PAD_BLOCK_ID)
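
In the non-dynamic chunked-prefill path the padded prompt length must be large enough to cover every prefix block, otherwise the bucketed shape could not hold the cached context. A standalone arithmetic sketch of the adjustment above (numbers are made up):

block_size = 128
max_prompt_len = 1024        # bucketed prompt length chosen earlier
max_num_block = 10           # longest prefix block table in the batch

if max_prompt_len < max_num_block * block_size:
    max_prompt_len = max_num_block * block_size   # widened to 1280
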
@@ -2765,6 +2778,16 @@ def prepare_input_tensors(
                 prefill_reqs.append(seq_group_meta)
             else:
                 decode_reqs.append(seq_group_meta)
+        if self.scheduler_config.enable_chunked_prefill and len(decode_reqs) != 0:
+            decode_reqs, real_decode_batch_size, decode_batch_size_padded = (
+                self._add_dummy_seq(decode_reqs, False, align_worker))
+            seq_group_metadata_list = []
+            if len(prefill_reqs) != 0:
+                for req in prefill_reqs:
+                    seq_group_metadata_list.append(req)
+            for req in decode_reqs:
+                seq_group_metadata_list.append(req)
+            batch_size_padded = len(seq_group_metadata_list)
 
         # Prepare input tensors.
         (
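
The hunk above keeps prefill requests ahead of decode requests and pads only the decode portion to its bucket size, so the mixed batch has a fixed layout. A standalone sketch, with plain strings standing in for SequenceGroupMetadata and a hypothetical pad_decodes() in place of self._add_dummy_seq:

prefill_reqs = ["p0", "p1"]
decode_reqs = ["d0", "d1", "d2"]

def pad_decodes(reqs, bucket_sizes=(1, 2, 4, 8)):
    # Hypothetical stand-in for self._add_dummy_seq: pad to the next decode bucket.
    target = next(b for b in bucket_sizes if b >= len(reqs))
    return reqs + ["dummy"] * (target - len(reqs)), len(reqs), target

decode_reqs, real_decode_bs, decode_bs_padded = pad_decodes(decode_reqs)
seq_group_metadata_list = [*prefill_reqs, *decode_reqs]
batch_size_padded = len(seq_group_metadata_list)   # 2 prefills + 4 decode slots = 6
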
@@ -3571,6 +3594,145 @@ def warmup_scenario(self,
         if not is_dummy_run:
             gc.collect()
 
+    def warmup_scenario_mix(self,
+                            batch_size,
+                            seq_len,
+                            ctx,
+                            is_prompt,
+                            kv_caches,
+                            is_pt_profiler_run=False,
+                            is_lora_profile_run=False,
+                            temperature=0,
+                            img_args=None,
+                            num_iters=3,
+                            align_worker=False,
+                            is_dummy_run=False) -> None:
+        phase = 'mix'
+        use_graphs = is_dummy_run or self._use_graphs(batch_size, seq_len)
+        buckets = self.bucketing_manager.decode_buckets
+        num_candidates = len(buckets)
+        for idx, (decode_bs, _, decode_ctx) in enumerate(reversed(buckets)):
+            scenario_name = ("warmup_"
+                             f"{phase}_"
+                             f"prefill_bs{batch_size}_"
+                             f"prefill_seq{seq_len}_"
+                             f"prefill_ctx{ctx}_"
+                             f"decode_bs{decode_bs}_"
+                             f"decode_ctx{decode_ctx}_"
+                             f"graphs{'T' if use_graphs else 'F'}")
+
+            self.log_warmup(f"Graph/{'mix'}/{'decode'}", idx, num_candidates, decode_bs, 1,
+                            decode_ctx)
+            dummy_lora_requests: List[LoRARequest] = []
+            dummy_lora_requests_per_seq: List[LoRARequest] = []
+            if self.lora_config and is_lora_profile_run:
+                assert self.lora_manager is not None
+                with self.lora_manager.dummy_lora_cache():
+                    for idx in range(self.lora_config.max_loras):
+                        lora_id = idx + 1
+                        dummy_lora_request = LoRARequest(
+                            lora_name=f"warmup_{lora_id}",
+                            lora_int_id=lora_id,
+                            lora_local_path="/not/a/real/path",
+                        )
+                        self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                         rank=LORA_WARMUP_RANK)
+                        dummy_lora_requests.append(dummy_lora_request)
+                    dummy_lora_requests_per_seq = [
+                        dummy_lora_requests[idx % len(dummy_lora_requests)]
+                        for idx in range(batch_size)
+                    ]
+            self.profiler.start('internal', scenario_name)
+            times = num_iters if use_graphs or is_pt_profiler_run else 1
+            seqs = []
+            seqs_prefill = self.create_dummy_seq_group_metadata(
+                0,
+                seq_len + ctx * self.block_size,
+                True,
+                lora_request=dummy_lora_requests_per_seq[i]
+                if dummy_lora_requests_per_seq else None,
+                img_args=img_args,
+                temperature=temperature,
+                ctx=ctx)
+
+            seqs.append(seqs_prefill)
+            blocks = [decode_ctx // decode_bs for _ in range(decode_bs)]
+            blocks[0] += decode_ctx % decode_bs
+            for i, b in enumerate(blocks):
+                seqs_decode = self.create_dummy_seq_group_metadata(
+                    i,
+                    b * self.block_size - 1,
+                    False,
+                    lora_request=dummy_lora_requests_per_seq[i]
+                    if dummy_lora_requests_per_seq else None,
+                    temperature=temperature,
+                    ctx=decode_ctx)
+                seqs.append(seqs_decode)
+
+            if not is_dummy_run:
+                torch.hpu.synchronize()
+            profiler = None
+            if is_pt_profiler_run and self.is_driver_worker:
+                profiler = setup_profiler()
+                profiler.start()
+            for time_index in range(times):
+                inputs = self.prepare_model_input_align_worker(
+                    seqs, align_worker=align_worker)
+                # Chendi: Necessary fix for warmup with TP>1
+                if time_index == 0:
+                    if self.is_driver_worker:
+                        broadcast_tensor_dict(
+                            {"input_tokens": inputs.input_tokens}, src=0)
+                    else:
+                        broadcast_tensor_dict(src=0)
+                if self._is_fla_model():
+                    self.add_fla_dummy_data(inputs)
+                if is_prompt or self.is_single_step:
+                    intermediate_tensors = None
+                    if not get_pp_group().is_first_rank:
+                        intermediate_tensors = \
+                            self.model.make_empty_intermediate_tensors(
+                                batch_size=batch_size,
+                                context_size=seq_len if is_prompt else 1,
+                                dtype=self.model_config.dtype,
+                                device=self.device)
+                    self.execute_model(inputs,
+                                       kv_caches,
+                                       intermediate_tensors=intermediate_tensors,
+                                       warmup_mode=True,
+                                       ctx_blocks=ctx,
+                                       is_dummy_run=is_dummy_run,
+                                       is_pt_profiler_run=is_pt_profiler_run)
+                else:  # decode with multi-step
+                    inputs = dataclasses.replace(inputs,
+                                                 is_first_multi_step=True,
+                                                 is_last_step=False)
+                    self.execute_model(inputs,
+                                       kv_caches,
+                                       warmup_mode=True,
+                                       num_steps=2,
+                                       seqs=seqs,
+                                       ctx_blocks=ctx)
+                    inputs = dataclasses.replace(inputs,
+                                                 is_first_multi_step=False,
+                                                 is_last_step=True)
+                    self.execute_model(inputs,
+                                       kv_caches,
+                                       warmup_mode=True,
+                                       num_steps=2,
+                                       seqs=seqs,
+                                       ctx_blocks=ctx)
+                if not is_dummy_run:
+                    torch.hpu.synchronize()
+                if profiler:
+                    profiler.step()
+            if profiler:
+                profiler.stop()
+            self.profiler.end()
+            if not is_dummy_run:
+                gc.collect()
+
+
     def remove_all_loras(self):
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
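
warmup_scenario_mix pairs one chunked prefill with every decode bucket and spreads decode_ctx context blocks across the decode batch, giving the remainder to the first sequence. The split is just the integer division shown in the hunk above; a standalone sketch with hypothetical bucket values:

decode_bs, decode_ctx = 4, 10                                   # hypothetical bucket
blocks = [decode_ctx // decode_bs for _ in range(decode_bs)]    # [2, 2, 2, 2]
blocks[0] += decode_ctx % decode_bs                             # [4, 2, 2, 2]
assert sum(blocks) == decode_ctx
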
@@ -3665,6 +3827,30 @@ def warmup_graphs(self,
             total_mem += used_mem
             total_batch_seq += batch_seq
 
+        if self.scheduler_config.chunked_prefill_enabled and is_prompt:
+            for idx, (batch_size, query_len, ctx) in enumerate(reversed(buckets)):
+                # Graph memory usage is proportional to seq dimension in a batch
+                phase = f"Graph/{'mix'}/{'prompt'}"
+                seq_len = query_len + ctx * self.block_size
+                batch_seq = batch_size * seq_len
+                self.log_warmup(phase, idx, num_candidates, batch_size, query_len,
+                                ctx)
+                with HabanaMemoryProfiler() as mem_prof:
+                    self.warmup_scenario_mix(
+                        batch_size,
+                        query_len,
+                        ctx,
+                        is_prompt,
+                        kv_caches,
+                        temperature=1.0
+                        if batch_size not in warmed_random_sampler_bs else 0,
+                    )
+                warmed_random_sampler_bs.add(batch_size)
+                used_mem = align_workers(mem_prof.consumed_device_memory,
+                                         torch.distributed.ReduceOp.MAX)
+                total_mem += used_mem
+                total_batch_seq += batch_seq
+
         if is_prompt and self.is_mm_run():
             #For multimodal total_batch_seq and total_mem, we store it in the
             #attribute for now.
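
In the chunked-prefill branch above, each prompt bucket is a (batch_size, query_len, ctx) triple, and the warmed-up sequence length is the new query tokens plus the cached context expressed in blocks. A small sketch of that bookkeeping with made-up bucket values:

block_size = 128
buckets = [(1, 512, 0), (1, 512, 4), (2, 1024, 8)]   # (batch_size, query_len, ctx)

for batch_size, query_len, ctx in reversed(buckets):
    seq_len = query_len + ctx * block_size
    batch_seq = batch_size * seq_len
    print(f"bs={batch_size} query={query_len} ctx={ctx} "
          f"seq_len={seq_len} batch_seq={batch_seq}")
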
@@ -4110,9 +4296,10 @@ def _phase(self, attn_metadata):
     def _check_config(self, batch_size, seq_len, ctx, attn_metadata,
                       warmup_mode):
         is_prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
+        is_chunked_prefill = self.vllm_config.scheduler_config.enable_chunked_prefill
         cfg: Optional[tuple] = None
         assert cfg is None, "Configs changed between 2D and 3D"
-        if is_prefix_caching:
+        if is_prefix_caching or is_chunked_prefill:
             phase = self._phase(attn_metadata)
             num_blocks = self._num_blocks(attn_metadata)
             cfg = (batch_size, seq_len, num_blocks, phase)
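
With prefix caching or chunked prefill enabled, the warmed-up-configuration key grows to include the context block count, so a shape that differs only in num_blocks is reported separately. An illustrative sketch of the two key layouts (values are made up):

is_prefix_caching_or_chunked = True
phase, batch_size, seq_len, num_blocks = "prefill", 2, 1024, 8

cfg = ((batch_size, seq_len, num_blocks, phase)
       if is_prefix_caching_or_chunked
       else (phase, batch_size, seq_len))
# -> (2, 1024, 8, 'prefill')
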
@@ -4124,7 +4311,7 @@ def _check_config(self, batch_size, seq_len, ctx, attn_metadata,
         if not seen and not warmup_mode:
             logger.warning("Configuration: %s was not warmed-up!",
                            (phase.value, batch_size, seq_len,
-                            num_blocks) if is_prefix_caching else
+                            num_blocks) if is_prefix_caching or is_chunked_prefill else
                            (phase, batch_size, seq_len))
 
     def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],