9 changes: 7 additions & 2 deletions vllm/core/scheduler.py
@@ -2083,8 +2083,13 @@ def _chunk_new_tokens_to_schedule(
         # If prefill_chunk_size is specified, chunk with the specific size
         # The prefill_chunk_size must be divisible by the block size
         assert scheduler_config.prefill_chunk_size % block_size == 0
-        remaining_token_budget = min(
-            remaining_token_budget, scheduler_config.prefill_chunk_size)
+        if remaining_token_budget >= scheduler_config.prefill_chunk_size:
+            remaining_token_budget = scheduler_config.prefill_chunk_size
+        else:
+            # If the sequence has to be chunked, make sure the context
+            # blocks are a multiple of prefill_chunk_size
+            if num_new_tokens > remaining_token_budget:
+                remaining_token_budget = 0
         num_new_tokens = min(num_new_tokens, remaining_token_budget)
 
         return num_new_tokens
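Reviewer note: the new branch changes behavior when the leftover token budget is smaller than one full chunk: rather than scheduling a partial chunk, scheduling is now deferred so every chunk boundary stays aligned to prefill_chunk_size. A minimal standalone sketch of that rule, with hypothetical values (the function and argument names only mirror the config fields used above):

def clamp_to_chunk(num_new_tokens, remaining_token_budget, prefill_chunk_size):
    # A full chunk fits in the budget: schedule exactly one chunk.
    if remaining_token_budget >= prefill_chunk_size:
        remaining_token_budget = prefill_chunk_size
    # The budget is smaller than a chunk and cannot hold the whole
    # remainder: defer scheduling entirely so context stays chunk-aligned.
    elif num_new_tokens > remaining_token_budget:
        remaining_token_budget = 0
    return min(num_new_tokens, remaining_token_budget)

assert clamp_to_chunk(5000, 4096, 2048) == 2048  # full chunk available
assert clamp_to_chunk(5000, 1024, 2048) == 0     # partial budget: defer
assert clamp_to_chunk(512, 1024, 2048) == 512    # tail fits: schedule it all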
126 changes: 101 additions & 25 deletions vllm/worker/hpu_model_runner.py
@@ -1023,16 +1023,20 @@ def _use_graphs(self,
                     batch_size,
                     seq_len,
                     is_prompt,
-                    is_profile_run=False):
-        if is_prompt and batch_size * seq_len > self.max_seq_len_to_capture:
+                    is_profile_run=False,
+                    context_blocks=0):
+        if is_prompt and ((batch_size * seq_len +
+                           context_blocks * self.block_size) >
+                          self.max_seq_len_to_capture):
             return False
         if self.enforce_eager:
             return False
         if is_profile_run:
             return False
         if self.skip_warmup:
             return True
-        return (batch_size, seq_len, is_prompt) in self.graphed_buckets
+        return (batch_size, seq_len, is_prompt, context_blocks
+                ) in self.graphed_buckets
 
     def _is_valid_bucket(self, bucket):
         return bucket[0] * bucket[1] <= self.max_num_batched_tokens
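Reviewer note: prompt buckets are now rejected for graph capture once cached context pushes the total token count past the capture limit. A minimal sketch of the new size check, assuming hypothetical values block_size = 128 and max_seq_len_to_capture = 8192:

def fits_graph_capture(batch_size, seq_len, context_blocks,
                       block_size=128, max_seq_len_to_capture=8192):
    # Context tokens already held in the KV cache count toward the
    # capture limit, matching the updated prompt check above.
    return (batch_size * seq_len
            + context_blocks * block_size) <= max_seq_len_to_capture

assert fits_graph_capture(2, 2048, 32)      # 4096 + 4096 tokens, still fits
assert not fits_graph_capture(2, 2048, 33)  # 8320 tokens exceed the limit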
@@ -1898,17 +1902,25 @@ def create_dummy_seq_group_metadata(self,
                                         seq_len,
                                         is_prompt,
                                         lora_request=None,
-                                        temperature=0):
+                                        temperature=0,
+                                        context_blocks=0):
         if self.is_pooler:
             sampling_params = None
         else:
             sampling_params = SamplingParams(temperature=temperature)
         num_blocks = math.ceil(seq_len / self.block_size)
         seq_len = max(seq_len, 1)
+        context_len = 0
         if is_prompt:
             input_len = seq_len
             output_len = 0
             block_tables = None
+            if context_blocks > 0:
+                context_len = context_blocks * self.block_size
+                # input length and num blocks need to account for the context
+                input_len = seq_len + context_len
+                num_blocks = math.ceil(input_len / self.block_size)
+                block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks}
         else:
             input_len = seq_len - 1
             output_len = 1
@@ -1917,6 +1929,9 @@
         output_token_ids = [1] * output_len
         prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
         seq_data = SequenceData(prompt_token_ids_array)
+        if is_prompt and context_len > 0:
+            # set the _num_computed_tokens for the context len
+            seq_data.update_num_computed_tokens(context_len)
         seq_data.output_token_ids = output_token_ids
         return SequenceGroupMetadata(request_id=str(group_id),
                                      is_prompt=(output_len == 0),
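Reviewer note: the context handling above is easier to follow with concrete numbers. A minimal sketch assuming block_size = 128 (all values hypothetical; this mirrors only the arithmetic):

import math

block_size = 128
seq_len, context_blocks = 512, 4                # new chunk plus prior context

context_len = context_blocks * block_size       # 512 tokens already computed
input_len = seq_len + context_len               # prompt spans context + chunk
num_blocks = math.ceil(input_len / block_size)  # pad blocks in the block table

assert (context_len, input_len, num_blocks) == (512, 1024, 8)
# update_num_computed_tokens(context_len) then marks the first 512 tokens as
# already computed, so the dummy prefill only processes the new 512-token chunk.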
@@ -1979,9 +1994,11 @@ def warmup_scenario(self,
                         temperature=0,
                         num_iters=3,
                         align_worker=False,
-                        is_dummy_run=False) -> None:
+                        is_dummy_run=False,
+                        context_blocks=0) -> None:
         use_graphs = (is_dummy_run) or self._use_graphs(
-            batch_size, seq_len, is_prompt, is_profile_run=is_profile_run)
+            batch_size, seq_len, is_prompt, is_profile_run=is_profile_run,
+            context_blocks=context_blocks)
         scenario_name = ("warmup_"
                          f"{'prompt' if is_prompt else 'decode'}_"
                          f"bs{batch_size}_"
@@ -2020,7 +2037,8 @@ def warmup_scenario(self,
                     is_prompt,
                     lora_request=dummy_lora_requests_per_seq[i]
                     if dummy_lora_requests_per_seq else None,
-                    temperature=temperature) for i in range(batch_size)
+                    temperature=temperature,
+                    context_blocks=context_blocks) for i in range(batch_size)
             ]
         else:
             # FIXME: seq_len is actually number of blocks
@@ -2145,23 +2163,49 @@ def list_loras(self) -> Set[int]:
             raise RuntimeError("LoRA is not enabled.")
         return self.lora_manager.list_adapters()
 
-    def log_warmup(self, phase, i, max_i, batch_size, seq_len):
+    def log_warmup(self, phase, i, max_i, batch_size, seq_len,
+                   context_blocks=0):
         free_mem = format_bytes(
             HabanaMemoryProfiler.current_free_device_memory())
         dim = "num_blocks"
+        context_blocks_info = ""
         if "Prompt" in phase:
             dim = "seq_len"
+            if context_blocks > 0:
+                context_blocks_info = f"num_blocks:{context_blocks} "
         msg = (f"[Warmup][{phase}][{i+1}/{max_i}] "
                f"batch_size:{batch_size} "
                f"{dim}:{seq_len} "
+               f"{context_blocks_info}"
               f"free_mem:{free_mem}")
         logger.info(msg)
 
     def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
         for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
-            self.log_warmup('Prompt' if is_prompt else 'Decode', i,
-                            len(buckets), batch_size, seq_len)
-            self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
+            if (is_prompt and self.scheduler_config.chunked_prefill_enabled
+                    and self.scheduler_config.prefill_chunk_size
+                    and batch_size <= (
+                        self.scheduler_config.max_num_batched_tokens
+                        // self.scheduler_config.prefill_chunk_size
+                    )):
+                chunk_size = self.scheduler_config.prefill_chunk_size
+                max_chunks = self.max_model_len // chunk_size
+                if self.max_model_len % chunk_size > 0:
+                    max_chunks += 1
+                chunk_blocks = chunk_size // self.block_size
+                for chunk in range(max_chunks):
+                    context_blocks = chunk * chunk_blocks
+                    self.log_warmup('Prompt' if is_prompt else 'Decode', i,
+                                    len(buckets), batch_size, seq_len,
+                                    context_blocks)
+                    self.warmup_scenario(batch_size, seq_len, is_prompt,
+                                         kv_caches,
+                                         context_blocks=context_blocks)
+                    gc.collect()
+            else:
+                self.log_warmup('Prompt' if is_prompt else 'Decode', i,
+                                len(buckets), batch_size, seq_len)
+                self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
 
     def warmup_graphs(self,
                       strategy,
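Reviewer note: the chunk loop above warms every context depth a chunked prefill can reach, stepping one chunk at a time. A minimal sketch of the enumeration with assumed values max_model_len = 4096, prefill_chunk_size = 1024, block_size = 128:

max_model_len, chunk_size, block_size = 4096, 1024, 128

max_chunks = max_model_len // chunk_size    # ceil-style rounding, exactly
if max_model_len % chunk_size > 0:          # as in the hunk above
    max_chunks += 1
chunk_blocks = chunk_size // block_size     # 8 blocks per chunk

context_depths = [chunk * chunk_blocks for chunk in range(max_chunks)]
assert context_depths == [0, 8, 16, 24]
# Each iteration replays the same bucket with progressively more context
# blocks, covering every chunked-prefill stage up to the model length.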
@@ -2195,18 +2239,48 @@ def warmup_graphs(self,
                     batch_seq > self.max_seq_len_to_capture:
                 captured_all = False
                 continue
-            graphed_bucket = (batch_size, seq_len, is_prompt)
-            if graphed_bucket in self.graphed_buckets:
-                continue
-            self.graphed_buckets.add(graphed_bucket)
-            self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
-            with HabanaMemoryProfiler() as mem_prof:
-                self.warmup_scenario(batch_size,
-                                     seq_len,
-                                     is_prompt,
-                                     kv_caches,
-                                     temperature=1.0 if batch_size
-                                     not in warmed_random_sampler_bs else 0)
+            if (is_prompt and self.scheduler_config.chunked_prefill_enabled
+                    and self.scheduler_config.prefill_chunk_size
+                    and batch_size <= (
+                        self.scheduler_config.max_num_batched_tokens
+                        // self.scheduler_config.prefill_chunk_size
+                    )):
+                chunk_size = self.scheduler_config.prefill_chunk_size
+                max_chunks = self.max_model_len // chunk_size
+                if self.max_model_len % chunk_size > 0:
+                    max_chunks += 1
+                chunk_blocks = chunk_size // self.block_size
+                for chunk in range(max_chunks):
+                    context_blocks = chunk * chunk_blocks
+                    graphed_bucket = (batch_size, seq_len, is_prompt,
+                                      context_blocks)
+                    if graphed_bucket in self.graphed_buckets:
+                        continue
+                    self.graphed_buckets.add(graphed_bucket)
+                    self.log_warmup(phase, idx, num_candidates, batch_size,
+                                    seq_len, context_blocks=context_blocks)
+                    with HabanaMemoryProfiler() as mem_prof:
+                        self.warmup_scenario(
+                            batch_size,
+                            seq_len,
+                            is_prompt,
+                            kv_caches,
+                            temperature=1.0 if batch_size
+                            not in warmed_random_sampler_bs else 0,
+                            context_blocks=context_blocks)
+            else:
+                graphed_bucket = (batch_size, seq_len, is_prompt, 0)
+                if graphed_bucket in self.graphed_buckets:
+                    continue
+                self.graphed_buckets.add(graphed_bucket)
+                self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
+                with HabanaMemoryProfiler() as mem_prof:
+                    self.warmup_scenario(batch_size,
+                                         seq_len,
+                                         is_prompt,
+                                         kv_caches,
+                                         temperature=1.0 if batch_size
+                                         not in warmed_random_sampler_bs else 0)
             warmed_random_sampler_bs.add(batch_size)
             used_mem = align_workers(mem_prof.consumed_device_memory,
                                      torch.distributed.ReduceOp.MAX)
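Reviewer note: both branches above key captured graphs by a four-element bucket, so warming the same (batch_size, seq_len) at different context depths never collides, and repeated warmups are deduplicated by the set. A minimal illustration with hypothetical values:

graphed_buckets = set()
for context_blocks in (0, 8, 16):           # chunked-prefill warmups
    graphed_buckets.add((2, 1024, True, context_blocks))
graphed_buckets.add((2, 1024, True, 0))     # unchunked path: already present
assert len(graphed_buckets) == 3            # one graph per context depth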
@@ -2236,7 +2310,7 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
             is_prompt = phase == 'prompt'
             graphs = graph == 't'
             if graphs:
-                self.graphed_buckets.add((int(bs), int(seq_len), is_prompt))
+                self.graphed_buckets.add((int(bs), int(seq_len), is_prompt, 0))
             self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
                                  True)
         raise AssertionError("Finished profiling")
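Reviewer note: the trailing 0 keeps the profiling path's keys the same shape as the four-element lookup in _use_graphs. A short illustration of why a stale three-element key would silently miss:

buckets = {(4, 1024, True, 0)}
assert (4, 1024, True, 0) in buckets     # new four-element key matches
assert (4, 1024, True) not in buckets    # old three-element key never hits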
@@ -2834,10 +2908,12 @@ def execute_model(
         assert is_prompt is not None
         batch_size = input_tokens.size(0)
         seq_len = self._seq_len(attn_metadata)
+        num_blocks = self._num_blocks(attn_metadata)
         use_graphs = self._use_graphs(batch_size,
                                       seq_len,
                                       is_prompt,
-                                      is_profile_run=profile_run_mode)
+                                      is_profile_run=profile_run_mode,
+                                      context_blocks=num_blocks)
         self._check_config(batch_size, seq_len, attn_metadata, warmup_mode)
 
         lora_mask: torch.Tensor = None
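Reviewer note: at execution time the bucket key is rebuilt from the live batch, so a prompt only runs under a captured HPU graph when its exact context depth was warmed. A minimal sketch of the lookup (bucket contents are hypothetical, and _num_blocks is assumed to return the context block count carried in attn_metadata):

def should_use_graph(graphed_buckets, batch_size, seq_len, is_prompt,
                     num_blocks):
    # Mirrors the updated call above: context blocks become part of the
    # bucket key; a miss falls back to eager execution.
    return (batch_size, seq_len, is_prompt, num_blocks) in graphed_buckets

warmed = {(2, 1024, True, 0), (2, 1024, True, 8)}
assert should_use_graph(warmed, 2, 1024, True, 8)       # warmed chunk stage
assert not should_use_graph(warmed, 2, 1024, True, 12)  # unseen depth: eager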