27 changes: 24 additions & 3 deletions vllm/worker/hpu_model_runner.py
@@ -1898,17 +1898,33 @@ def create_dummy_seq_group_metadata(self,
                                         seq_len,
                                         is_prompt,
                                         lora_request=None,
-                                        temperature=0):
+                                        temperature=0,
+                                        is_profile_run=False):
         if self.is_pooler:
             sampling_params = None
         else:
             sampling_params = SamplingParams(temperature=temperature)
         num_blocks = math.ceil(seq_len / self.block_size)
         seq_len = max(seq_len, 1)
+        context_len = 0
         if is_prompt:
             input_len = seq_len
             output_len = 0
             block_tables = None
+
+            if (not is_profile_run
+                    and self.scheduler_config.chunked_prefill_enabled
+                    and self.scheduler_config.prefill_chunk_size):
+                # if chunked prefill is enabled and a prefill chunk size is
+                # specified, simplify the warmup by considering the chunk size
+                chunk_size = self.scheduler_config.prefill_chunk_size
+                if seq_len > chunk_size:
+                    chunks = seq_len // chunk_size
+                    chunk_remaining = seq_len % chunk_size
+                    if chunk_remaining == 0:
+                        chunks = chunks - 1
+                    context_len = chunks * chunk_size
+                    block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks}
         else:
             input_len = seq_len - 1
             output_len = 1
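Taken on its own, the chunking arithmetic above reduces a long warmup prompt to its final prefill chunk: every complete chunk before the last is treated as already-computed context. A minimal standalone sketch of that computation (the function name and the seq_len/chunk_size values are illustrative, not part of the PR):

```python
# Standalone sketch of the context_len computation in the hunk above.
def dummy_context_len(seq_len: int, chunk_size: int) -> int:
    """Tokens treated as pre-computed context for a chunked-prefill
    warmup sequence: all complete chunks before the final one."""
    if seq_len <= chunk_size:
        return 0
    chunks = seq_len // chunk_size
    if seq_len % chunk_size == 0:
        chunks -= 1  # keep one full chunk as fresh input
    return chunks * chunk_size

assert dummy_context_len(400, 512) == 0      # fits in a single chunk
assert dummy_context_len(1024, 512) == 512   # 2 exact chunks -> 1 as context
assert dummy_context_len(1300, 512) == 1024  # partial final chunk of 276
```

The `chunk_remaining == 0` branch exists so that a prompt which is an exact multiple of the chunk size still leaves one full chunk for the warmup forward pass to process.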
@@ -1917,6 +1933,9 @@ def create_dummy_seq_group_metadata(self,
         output_token_ids = [1] * output_len
         prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
         seq_data = SequenceData(prompt_token_ids_array)
+        if is_prompt and context_len > 0:
+            # set the _num_computed_tokens for the context len
+            seq_data.update_num_computed_tokens(context_len)
         seq_data.output_token_ids = output_token_ids
         return SequenceGroupMetadata(request_id=str(group_id),
                                      is_prompt=(output_len == 0),
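The `update_num_computed_tokens(context_len)` call is what makes the dummy prompt resemble a request whose earlier chunks were already prefilled, so warmup only runs the final chunk. A rough model of that bookkeeping (the `DummySeq` class here is hypothetical; the real accounting lives in vLLM's `SequenceData`):

```python
# Hypothetical stand-in mirroring the effect of
# SequenceData.update_num_computed_tokens on a warmup sequence.
from dataclasses import dataclass

@dataclass
class DummySeq:
    prompt_len: int
    num_computed_tokens: int = 0

    def update_num_computed_tokens(self, n: int) -> None:
        self.num_computed_tokens += n

    @property
    def uncomputed(self) -> int:
        # Tokens the warmup forward pass must still process.
        return self.prompt_len - self.num_computed_tokens

seq = DummySeq(prompt_len=1300)
seq.update_num_computed_tokens(1024)  # context_len from the sketch above
assert seq.uncomputed == 276          # only the final chunk is prefilled
```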
@@ -2020,7 +2039,8 @@ def warmup_scenario(self,
                     is_prompt,
                     lora_request=dummy_lora_requests_per_seq[i]
                     if dummy_lora_requests_per_seq else None,
-                    temperature=temperature) for i in range(batch_size)
+                    temperature=temperature,
+                    is_profile_run=is_profile_run) for i in range(batch_size)
             ]
         else:
             # FIXME: seq_len is actually number of blocks
@@ -2033,7 +2053,8 @@
                     is_prompt,
                     lora_request=dummy_lora_requests_per_seq[i]
                     if dummy_lora_requests_per_seq else None,
-                    temperature=temperature) for i, b in enumerate(blocks)
+                    temperature=temperature,
+                    is_profile_run=is_profile_run) for i, b in enumerate(blocks)
             ]
         if not is_dummy_run:
             torch.hpu.synchronize()