From 7196adf838eb9f9073c73867448816b3f7f94036 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:22:01 +0000 Subject: [PATCH 01/30] update to use separate scheduler Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 144 +----- vllm/v1/core/sched/scheduler_disagg.py | 630 +++++++++++++++++++++++++ vllm/v1/engine/core.py | 4 +- 3 files changed, 634 insertions(+), 144 deletions(-) create mode 100644 vllm/v1/core/sched/scheduler_disagg.py diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 17b89b68f0bb..a32b6f583af9 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -2,21 +2,15 @@ from __future__ import annotations -import itertools import time from collections import defaultdict, deque from collections.abc import Iterable from typing import Optional, Union -from vllm import envs from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.sampling_params import KVTransferParams from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager @@ -71,14 +65,6 @@ def __init__( self.kv_events_config is not None and self.kv_events_config.enable_kv_cache_events) - # Create KVConnector for the Scheduler. Note that each Worker - # will have a corresponding KVConnector with Role=WORKER. - # KV Connector pushes/pull of remote KVs for P/D and offloading. - self.connector = None - if self.vllm_config.kv_transfer_config is not None: - self.connector = KVConnectorFactory.create_connector_v1( - config=self.vllm_config, role=KVConnectorRole.SCHEDULER) - self.kv_event_publisher = EventPublisherFactory.create( self.kv_events_config) @@ -99,9 +85,6 @@ def __init__( # This is flushed at the end of each scheduling step. self.finished_req_ids: set[str] = set() - # Requests in states for tracking KV transfers for P/D disagg - self.finished_recving_kv_req_ids: set[str] = set() - # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. # Request id -> deque of CachedRequestData @@ -314,27 +297,6 @@ def schedule(self) -> SchedulerOutput: request = self.waiting[0] - # Skip request if the remote KV recv is still waiting - # for the requests to arrive. - if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: - if request.request_id in self.finished_recving_kv_req_ids: - assert self.kv_cache_manager.enable_caching - # Now that the KVs have been recved, we can cache - # them and set num_computed_tokens. - self.kv_cache_manager.cache_blocks( - request, - num_tokens=0, - num_computed_tokens=(len(request.all_token_ids) - - 1)) - self.finished_recving_kv_req_ids.remove( - request.request_id) - request.status = RequestStatus.WAITING - self.kv_cache_manager.free(request) - else: - self.waiting.popleft() - skipped_waiting_requests.appendleft(request) - continue - # Skip request if the structured output request is still waiting # for FSM compilation. 
if request.status == RequestStatus.WAITING_FOR_FSM: @@ -362,47 +324,6 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.get_computed_blocks( request) - # Get externally-cached tokens if using a KVConnector. - num_external_tokens = ( - 0 if self.connector is None else - self.connector.get_num_new_matched_tokens( - request, num_computed_tokens)) - - # Total computed tokens (local + external). - num_computed_tokens += num_external_tokens - - if request.do_remote_prefill and num_external_tokens > 0: - # Allocate slots for the external tokens, but skip - # caching until after the KV transfer is done. - new_blocks = self.kv_cache_manager.allocate_slots( - request, - num_external_tokens, - computed_blocks, - skip_cache_blocks=True) - if new_blocks is None: - # Requests cannot be scheduled - break - - self.waiting.popleft() - skipped_waiting_requests.appendleft(request) - request.status = RequestStatus.WAITING_FOR_REMOTE_KVS - - # KVConnector: update internal state after allocation. - # This information is used to determine if a load is - # needed for this request. - if self.connector is not None: - self.connector.update_state_after_alloc( - request, - [ - b.block_id for b in itertools.chain( - computed_blocks, new_blocks) - ], - num_external_tokens, - ) - # We should only trigger a KV transfer once per request. - request.do_remote_prefill = False - continue - # Number of tokens to be scheduled. # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed reqs, @@ -430,7 +351,7 @@ def schedule(self) -> SchedulerOutput: new_blocks = self.kv_cache_manager.allocate_slots( request, - num_new_tokens + num_external_tokens, + num_new_tokens, computed_blocks, num_lookahead_tokens=self.num_lookahead_tokens, ) @@ -438,19 +359,6 @@ def schedule(self) -> SchedulerOutput: # The request cannot be scheduled. break - # KVConnector: update internal state after allocation. - # This information is used to determine if a load is - # needed for this request. - if self.connector is not None: - self.connector.update_state_after_alloc( - request, - [ - b.block_id for b in itertools.chain( - computed_blocks, new_blocks) - ], - num_external_tokens, - ) - self.waiting.popleft() if request.use_structured_output: structured_output_request_ids[ @@ -558,14 +466,6 @@ def schedule(self) -> SchedulerOutput: grammar_bitmask=grammar_bitmask, ) - # NOTE(Kuntai): this function is designed for multiple purposes: - # 1. Plan the KV cache store - # 2. Wrap up all the KV cache load / save ops into an opaque object - # 3. Clear the internal states of the connector - if self.connector is not None: - meta = self.connector.build_connector_meta(scheduler_output) - scheduler_output.kv_connector_metadata = meta - events = self.kv_cache_manager.take_events() if events: batch = KVEventBatch(ts=time.time(), events=events) @@ -806,32 +706,6 @@ def update_from_output( # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids: - # Stop request after the first token if doing a remote_decode. - # NOTE(rob): req is not freed (or preempted) in the EngineCore - # until the xfer is done to ensure we do not free the KV blocks. - kv_transfer_params = None - # TODO(rob): edge case where we get a stop for stop_strings - # inside AsyncLLM. 
- if request.do_remote_decode and not stopped: - request.status = RequestStatus.FINISHED_REMOTE_DECODE - self._free_request(request, skip_free_blocks=True) - stopped = True - - # TODO(rob): do this on a per-Connector basis. - remote_blocks = [ - block.block_id for block in - self.kv_cache_manager.get_computed_blocks(request)[0] - ] - - engine_id = self.vllm_config.kv_transfer_config.engine_id - kv_transfer_params = KVTransferParams( - do_remote_prefill=True, - remote_block_ids=remote_blocks, - remote_engine_id=engine_id, - remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST, - remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT, - ) - # Add EngineCoreOutput for this Request. outputs.append( EngineCoreOutput( @@ -842,7 +716,6 @@ def update_from_output( new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, events=request.take_events(), - kv_transfer_params=kv_transfer_params, )) else: @@ -852,14 +725,6 @@ def update_from_output( if not stopped: new_running.append(request) - # P/D: update recv and send status from last step. - for req_id in (model_runner_output.finished_recving or []): - logger.debug("Finished recving KV transfer for request %s", req_id) - self.finished_recving_kv_req_ids.add(req_id) - for req_id in (model_runner_output.finished_sending or []): - logger.debug("Finished sending KV transfer for request %s", req_id) - self._free_blocks(self.requests[req_id]) - # Return the cached request data to the queue so they can # be reused. Note: we cannot add stopped requests to this # since they are already freed above! @@ -923,13 +788,6 @@ def _free_request(self, self.encoder_cache_manager.free(request) self._cached_reqs_data.pop(request.request_id, None) self.finished_req_ids.add(request.request_id) - - if not skip_free_blocks: - self._free_blocks(request) - - def _free_blocks(self, request: Request): - assert request.is_finished() - assert request.request_id not in self._cached_reqs_data self.kv_cache_manager.free(request) self.kv_cache_manager.free_block_hashes(request) del self.requests[request.request_id] diff --git a/vllm/v1/core/sched/scheduler_disagg.py b/vllm/v1/core/sched/scheduler_disagg.py new file mode 100644 index 000000000000..ecceeeac01ce --- /dev/null +++ b/vllm/v1/core/sched/scheduler_disagg.py @@ -0,0 +1,630 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import itertools +import time +from collections import deque + +from vllm import envs +from vllm.distributed.kv_events import KVEventBatch +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole +from vllm.logger import init_logger +from vllm.sampling_params import KVTransferParams +from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.core.sched.utils import check_stop +from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, + EngineCoreOutputs) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus + +logger = init_logger(__name__) + + +class DisaggregatedScheduler(Scheduler): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # NOTE(rob): there is no reason to believe these are not + # supported. However, I would like to test them first + # before enabling them with P/D. 
+ if self.use_eagle or self.vllm_config.speculative_config: + raise NotImplementedError( + "Speculative Decoding is not yet supported with " + "KV Disaggregation.") + if self.lora_config: + raise NotImplementedError( + "LoRA is not yet supported with KV Disaggregation.") + + # Create KVConnector for the Scheduler. + if self.vllm_config.kv_transfer_config is not None: + raise ValueError("Using DisaggregatedScheduler but found unset " + "kv_transfer_config.") + self.connector = KVConnectorFactory.create_connector_v1( + config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + + # Requests in states for tracking KV transfers. + self.finished_recving_kv_req_ids: set[str] = set() + + def schedule(self) -> SchedulerOutput: + # NOTE(woosuk) on the scheduling algorithm: + # There's no "decoding phase" nor "prefill phase" in the scheduler. + # Each request just has the num_computed_tokens and + # num_tokens_with_spec. num_tokens_with_spec = + # len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids). + # At each step, the scheduler tries to assign tokens to the requests + # so that each request's num_computed_tokens can catch up its + # num_tokens_with_spec. This is general enough to cover + # chunked prefills, prefix caching, speculative decoding, + # and the "jump decoding" optimization in the future. + + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] + preempted_reqs: list[Request] = [] + + # NOTE: structured_output_request_ids maps + # a request's (request that uses structured output) + # request_id to the running request index. + # This will helps us determine to slice the grammar bitmask + # and only applies valid mask for requests that + # uses structured decoding. + structured_output_request_ids: dict[str, int] = {} + + req_to_new_block_ids: dict[str, list[int]] = {} + num_scheduled_tokens: dict[str, int] = {} + token_budget = self.max_num_scheduled_tokens + # Encoder-related. + scheduled_encoder_inputs: dict[str, list[int]] = {} + encoder_budget = self.max_num_encoder_input_tokens + + # For logging. + scheduled_timestamp = time.monotonic() + + # First, schedule the RUNNING requests. + req_index = 0 + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + + num_new_tokens = (request.num_tokens_with_spec - + request.num_computed_tokens) + if (0 < self.scheduler_config.long_prefill_token_threshold < + num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = min(num_new_tokens, token_budget) + + # Make sure the input position does not exceed the max model len. + # This is necessary when using spec decoding. + num_new_tokens = min( + num_new_tokens, + self.max_model_len - request.num_computed_tokens) + + # Schedule encoder inputs. + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, request.num_computed_tokens, num_new_tokens, + encoder_budget) + + if num_new_tokens == 0: + # The request cannot be scheduled because one of the following + # reasons: + # 1. No new tokens to schedule. This may happen when PP>1 and + # we have already scheduled all prompt tokens but they are + # not finished yet. + # 2. The encoder budget is exhausted. + # 3. The encoder cache is exhausted. 
+ # NOTE(woosuk): Here, by doing `continue` instead of `break`, + # we do not strictly follow the FCFS scheduling policy and + # allow the lower-priority requests to be scheduled. + req_index += 1 + continue + + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, num_new_tokens) + if new_blocks is None: + # The request cannot be scheduled. + # Preempt the lowest-priority request. + preempted_req = self.running.pop() + self.kv_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp) + + self.waiting.appendleft(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. + can_schedule = False + break + else: + # The request can be scheduled. + can_schedule = True + break + if not can_schedule: + break + assert new_blocks is not None + + # Schedule the request. + scheduled_running_reqs.append(request) + if request.use_structured_output: + # PERF: in case of chunked prefill, + # request might not include any new tokens. + # Therefore, we might introduce some additional + # cycle to fill in the bitmask, which could be a big no-op. + structured_output_request_ids[request.request_id] = req_index + req_to_new_block_ids[request.request_id] = [ + b.block_id for b in new_blocks + ] + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Use a temporary deque to collect requests that need to be skipped + # and put back at the head of the waiting queue later (e.g. for FSM + # or KVCacheSending). + skipped_waiting_requests: deque[Request] = deque() + + # Next, schedule the WAITING requests. + if not preempted_reqs: + while self.waiting and token_budget > 0: + if len(self.running) == self.max_num_running_reqs: + break + + request = self.waiting[0] + + # Skip request if the remote KV recv is still waiting + # for the requests to arrive. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + if request.request_id in self.finished_recving_kv_req_ids: + assert self.kv_cache_manager.enable_caching + # Now that the KVs have been recved, we can cache + # them and set num_computed_tokens. + self.kv_cache_manager.cache_blocks( + request, + num_tokens=0, + num_computed_tokens=(len(request.all_token_ids) - + 1)) + self.finished_recving_kv_req_ids.remove( + request.request_id) + request.status = RequestStatus.WAITING + self.kv_cache_manager.free(request) + else: + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + + # Skip request if the structured output request is still waiting + # for FSM compilation. + if request.status == RequestStatus.WAITING_FOR_FSM: + structured_output_req = request.structured_output_request + if structured_output_req and structured_output_req.grammar: + request.status = RequestStatus.WAITING + else: + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + + # Get already-cached tokens. 
+ computed_blocks, num_computed_tokens = \ + self.kv_cache_manager.get_computed_blocks( + request) + + # Get externally-cached tokens if using a KVConnector. + num_external_tokens = ( + 0 if self.connector is None else + self.connector.get_num_new_matched_tokens( + request, num_computed_tokens)) + + # Total computed tokens (local + external). + num_computed_tokens += num_external_tokens + + if request.do_remote_prefill and num_external_tokens > 0: + # Allocate slots for the external tokens, but skip + # caching until after the KV transfer is done. + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_external_tokens, + computed_blocks, + skip_cache_blocks=True) + if new_blocks is None: + # Requests cannot be scheduled + break + + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + + # KVConnector: update internal state after allocation. + # This information is used to determine if a load is + # needed for this request. + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + [ + b.block_id for b in itertools.chain( + computed_blocks, new_blocks) + ], + num_external_tokens, + ) + # We should only trigger a KV transfer once per request. + request.do_remote_prefill = False + continue + + # Number of tokens to be scheduled. + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed reqs, + # which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + if (0 < self.scheduler_config.long_prefill_token_threshold < + num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + else: + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens + num_external_tokens, + computed_blocks, + ) + if new_blocks is None: + # The request cannot be scheduled. + break + + # KVConnector: update internal state after allocation. + # This information is used to determine if a load is + # needed for this request. 
+ if self.connector is not None: + self.connector.update_state_after_alloc( + request, + [ + b.block_id for b in itertools.chain( + computed_blocks, new_blocks) + ], + num_external_tokens, + ) + + self.waiting.popleft() + if request.use_structured_output: + structured_output_request_ids[ + request.request_id] = req_index + req_index += 1 + self.running.append(request) + if self.log_stats: + request.record_event(EngineCoreEventType.SCHEDULED, + scheduled_timestamp) + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError( + f"Invalid request status: {request.status}") + + req_to_new_block_ids[request.request_id] = [ + b.block_id for b in computed_blocks + new_blocks + ] + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + + # Encoder-related. + if not request.do_remote_prefill and encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: + self.waiting.extendleft(skipped_waiting_requests) + + # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 + assert len(self.running) <= self.max_num_running_reqs + # Since some requests in the RUNNING queue may not be scheduled in + # this step, the total number of scheduled requests can be smaller than + # len(self.running). + assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + + len(scheduled_running_reqs) <= len(self.running)) + + # Get the longest common prefix among all requests in the running queue. + # This can be potentially used for cascade attention. + num_common_prefix_blocks = 0 + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request, len(self.running))) + + grammar_bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + scheduled_spec_decode_tokens={}) + # Construct the scheduler output. 
+ new_reqs_data = [ + NewRequestData.from_request(req, + req_to_new_block_ids[req.request_id]) + for req in scheduled_new_reqs + ] + resumed_reqs_data = [ + self._make_cached_request_data( + request=req, + num_scheduled_tokens=num_scheduled_tokens[req.request_id], + num_scheduled_spec_tokens=0, + new_block_ids=req_to_new_block_ids[req.request_id], + resumed_from_preemption=True, + ) for req in scheduled_resumed_reqs + ] + running_reqs_data = [ + self._make_cached_request_data( + request=req, + num_scheduled_tokens=num_scheduled_tokens[req.request_id], + num_scheduled_spec_tokens=0, + new_block_ids=req_to_new_block_ids[req.request_id], + resumed_from_preemption=False, + ) for req in scheduled_running_reqs + ] + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=resumed_reqs_data + running_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, + free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + structured_output_request_ids=structured_output_request_ids, + grammar_bitmask=grammar_bitmask, + ) + + # NOTE(Kuntai): this function is designed for multiple purposes: + # 1. Plan the KV cache store + # 2. Wrap up all the KV cache load / save ops into an opaque object + # 3. Clear the internal states of the connector + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + + events = self.kv_cache_manager.take_events() + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + + # Advance the number of computed tokens for the request AFTER + # the request is scheduled. + # 1. The scheduler_output of the current step has to include the + # original number of scheduled tokens to determine input IDs. + # 2. Advance the number of computed tokens here allowing us to + # schedule the prefill request again immediately in the next + # scheduling step. + # 3. If some tokens (e.g. spec tokens) are rejected later, the number of + # computed tokens will be adjusted in update_from_output. + for req_id, num_scheduled_token in num_scheduled_tokens.items(): + if req := self.requests.get(req_id): + req.num_computed_tokens += num_scheduled_token + + self.finished_req_ids = set() + return scheduler_output + + def update_from_output( + self, + scheduler_output: SchedulerOutput, + model_runner_output: ModelRunnerOutput, + ) -> EngineCoreOutputs: + sampled_token_ids = model_runner_output.sampled_token_ids + spec_token_ids = model_runner_output.spec_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + + new_running: list[Request] = [] + outputs: list[EngineCoreOutput] = [] + + # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below + # loop can be a performance bottleneck. We should do our best to avoid + # expensive operations inside the loop. 
+ for request in self.running: + req_id = request.request_id + num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) + if num_tokens_scheduled == 0: + # The request was not scheduled in this step. + new_running.append(request) + continue + + req_index = model_runner_output.req_id_to_index[req_id] + generated_token_ids = sampled_token_ids[req_index] + + cached_encoder_input_ids = ( + self.encoder_cache_manager.get_cached_input_ids(request)) + # OPTIMIZATION: Avoid list(set) if the set is empty. + if cached_encoder_input_ids: + for input_id in list(cached_encoder_input_ids): + mm_positions = request.mm_positions[input_id] + start_pos = mm_positions.offset + num_tokens = mm_positions.length + if start_pos + num_tokens <= request.num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + self.encoder_cache_manager.free_encoder_input( + request, input_id) + + stopped = False + new_logprobs = None + new_token_ids = generated_token_ids + + # Append generated tokens and check for stop. Note that if + # a request is still being prefilled, we expect the model runner + # to return empty token ids for the request. + for num_new, output_token_id in enumerate(new_token_ids, 1): + request.append_output_token_ids(output_token_id) + + # Check for stop and update request state. + # This must be called before we make the EngineCoreOutput. + stopped = check_stop(request, self.max_model_len) + if stopped: + self._free_request(request) + del new_token_ids[num_new:] # Trim new tokens if needed. + break + + # Extract sample logprobs if needed. + if request.sampling_params.logprobs is not None and logprobs: + # NOTE: once we support N tokens per step (spec decode), + # the outer lists can be of length > 1. + new_logprobs = logprobs.slice(req_index, req_index + 1) + + if new_token_ids and request.use_structured_output: + # NOTE: structured_output_request + # should not be None if use_structured_output, we have + # check above, so safe to ignore type warning + request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] + req_id, new_token_ids) + + # Add newly generated spec token ids to the request. + if spec_token_ids is not None: + if request.use_structured_output: + metadata = request.structured_output_request + assert metadata is not None and metadata.grammar is not None + # Needs to happen after new_token_ids are accepted. + request.spec_token_ids = metadata.grammar.validate_tokens( + spec_token_ids[req_index]) + else: + request.spec_token_ids = spec_token_ids[req_index] + + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids: + # Stop request after the first token if doing a remote_decode. + # NOTE(rob): req is not freed (or preempted) in the EngineCore + # until the xfer is done to ensure we do not free the KV blocks. + kv_transfer_params = None + # TODO(rob): edge case where we get a stop for stop_strings + # inside AsyncLLM. + if request.do_remote_decode and not stopped: + request.status = RequestStatus.FINISHED_REMOTE_DECODE + self._free_request(request, skip_free_blocks=True) + stopped = True + + # TODO(rob): do this on a per-Connector basis. 
+ remote_blocks = [ + block.block_id for block in + self.kv_cache_manager.get_computed_blocks(request)[0] + ] + + engine_id = self.vllm_config.kv_transfer_config.engine_id + kv_transfer_params = KVTransferParams( + do_remote_prefill=True, + remote_block_ids=remote_blocks, + remote_engine_id=engine_id, + remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST, + remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT, + ) + + # Add EngineCoreOutput for this Request. + outputs.append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids, + finish_reason=request.get_finished_reason(), + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + stop_reason=request.stop_reason, + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + )) + + else: + # Invariant: EngineCore returns no partial prefill outputs. + assert not prompt_logprobs_tensors + + if not stopped: + new_running.append(request) + + # P/D: update recv and send status from last step. + for req_id in (model_runner_output.finished_recving or []): + logger.debug("Finished recving KV transfer for request %s", req_id) + self.finished_recving_kv_req_ids.add(req_id) + for req_id in (model_runner_output.finished_sending or []): + logger.debug("Finished sending KV transfer for request %s", req_id) + self._free_blocks(self.requests[req_id]) + + # Return the cached request data to the queue so they can + # be reused. Note: we cannot add stopped requests to this + # since they are already freed above! + for req_data in scheduler_output.scheduled_cached_reqs: + # NOTE(rob): since we free stopped reqs above, adding stopped reqs + # to _cached_reqs_data will cause a memory leak. + if req_data.req_id not in self.finished_req_ids: + self._cached_reqs_data[req_data.req_id].append(req_data) + + self.running = new_running + engine_core_outputs = EngineCoreOutputs( + outputs=outputs, + scheduler_stats=self.make_stats(), + ) + if self.include_finished_set: + #TODO currently sending duplicates here, improve this + engine_core_outputs.finished_requests = ( + scheduler_output.finished_req_ids | self.finished_req_ids) + + return engine_core_outputs + + def _free_request(self, + request: Request, + skip_free_blocks: bool = False) -> None: + assert request.is_finished() + self.encoder_cache_manager.free(request) + self._cached_reqs_data.pop(request.request_id, None) + self.finished_req_ids.add(request.request_id) + + if not skip_free_blocks: + self._free_blocks(request) + + def _free_blocks(self, request: Request): + assert request.is_finished() + assert request.request_id not in self._cached_reqs_data + self.kv_cache_manager.free(request) + self.kv_cache_manager.free_block_hashes(request) + del self.requests[request.request_id] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e772615b7861..d13cef4a3abf 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -28,6 +28,8 @@ from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler +from vllm.v1.core.sched.scheduler_disagg import ( # noqa: E501 + DisaggregatedScheduler as V1DisaggregatedScheduler) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.mm_input_cache import MirroredProcessingCache @@ -86,7 +88,7 @@ def __init__(self, # This warning can be removed once the V1 Scheduler interface is # finalized and we can maintain support for scheduler 
classes that # implement it - if Scheduler is not V1Scheduler: + if Scheduler not in [V1Scheduler, V1DisaggregatedScheduler]: logger.warning( "Using configured V1 scheduler class %s. " "This scheduler interface is not public and " From 493cda1136ffcbb83bd2e3786e8dd6c9a1935e83 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:23:56 +0000 Subject: [PATCH 02/30] update to use separate scheduler Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler_disagg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler_disagg.py b/vllm/v1/core/sched/scheduler_disagg.py index ecceeeac01ce..6eae9f347a17 100644 --- a/vllm/v1/core/sched/scheduler_disagg.py +++ b/vllm/v1/core/sched/scheduler_disagg.py @@ -343,7 +343,7 @@ def schedule(self) -> SchedulerOutput: request.num_computed_tokens = num_computed_tokens # Encoder-related. - if not request.do_remote_prefill and encoder_inputs_to_schedule: + if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( encoder_inputs_to_schedule) # Allocate the encoder cache. From 24279188ff6468dcc79a93be0768363cbab8b399 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:24:18 +0000 Subject: [PATCH 03/30] update to use separate scheduler Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler_disagg.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/core/sched/scheduler_disagg.py b/vllm/v1/core/sched/scheduler_disagg.py index 6eae9f347a17..57fc98393275 100644 --- a/vllm/v1/core/sched/scheduler_disagg.py +++ b/vllm/v1/core/sched/scheduler_disagg.py @@ -443,8 +443,7 @@ def schedule(self) -> SchedulerOutput: # 3. If some tokens (e.g. spec tokens) are rejected later, the number of # computed tokens will be adjusted in update_from_output. 
for req_id, num_scheduled_token in num_scheduled_tokens.items(): - if req := self.requests.get(req_id): - req.num_computed_tokens += num_scheduled_token + self.requests[req_id].num_computed_tokens += num_scheduled_token self.finished_req_ids = set() return scheduler_output From 50cecf675502732a0799c919de78bd64f58c397a Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:25:34 +0000 Subject: [PATCH 04/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 8 +++----- vllm/v1/core/sched/scheduler_disagg.py | 3 +-- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a32b6f583af9..46db747e92b0 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -781,15 +781,13 @@ def finish_requests( request.status = finished_status self._free_request(request) - def _free_request(self, - request: Request, - skip_free_blocks: bool = False) -> None: + def _free_request(self, request: Request) -> None: assert request.is_finished() + self.kv_cache_manager.free(request) + self.kv_cache_manager.free_block_hashes(request) self.encoder_cache_manager.free(request) self._cached_reqs_data.pop(request.request_id, None) self.finished_req_ids.add(request.request_id) - self.kv_cache_manager.free(request) - self.kv_cache_manager.free_block_hashes(request) del self.requests[request.request_id] def get_num_unfinished_requests(self) -> int: diff --git a/vllm/v1/core/sched/scheduler_disagg.py b/vllm/v1/core/sched/scheduler_disagg.py index 57fc98393275..80cea05e4418 100644 --- a/vllm/v1/core/sched/scheduler_disagg.py +++ b/vllm/v1/core/sched/scheduler_disagg.py @@ -590,8 +590,7 @@ def update_from_output( self._free_blocks(self.requests[req_id]) # Return the cached request data to the queue so they can - # be reused. Note: we cannot add stopped requests to this - # since they are already freed above! + # be reused. for req_data in scheduler_output.scheduled_cached_reqs: # NOTE(rob): since we free stopped reqs above, adding stopped reqs # to _cached_reqs_data will cause a memory leak. From ba3e759197a7536293f04179cef21702995ee9c2 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:35:51 +0000 Subject: [PATCH 05/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/worker/gpu_model_runner.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3a8dae04ee0a..d0e8d62eba2c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1064,10 +1064,9 @@ def maybe_get_finished() -> tuple[set[str], set[str]]: return output # Prepare the decoder inputs. - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - attn_metadata, logits_indices, spec_decode_metadata = ( self._prepare_inputs(scheduler_output)) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): # Use piecewise CUDA graphs. 
@@ -1141,7 +1140,7 @@ def maybe_get_finished() -> tuple[set[str], set[str]]: num_tokens=num_input_tokens): maybe_setup_kv_connector() - model_output = self.model( + output = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, @@ -1152,9 +1151,9 @@ def maybe_get_finished() -> tuple[set[str], set[str]]: finished_sending, finished_recving = maybe_get_finished() if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = model_output + hidden_states, aux_hidden_states = output else: - hidden_states = model_output + hidden_states = output if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. From 096de020900e3afc5cac23294fdd4de945d73537 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:37:12 +0000 Subject: [PATCH 06/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/engine/output_processor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 1d98f15ebde3..0e901fc8327d 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -305,22 +305,22 @@ def process_outputs( 1) Compute stats for logging 2) Detokenize 3) Create and handle RequestOutput objects: - * If there is a queue (for usage with AsyncLLM), + * If there is a queue (for usage with AsyncLLM), put the RequestOutput objects into the queue for handling by the per-request generate() tasks. - * If there is no queue (for usage with LLMEngine), + * If there is no queue (for usage with LLMEngine), return a list of RequestOutput objects. ****************** NOTE FOR DEVELOPERS ****************** vLLM V1 minimizes the number of python loops over the full - batch to ensure system overheads are minimized. This is the + batch to ensure system overheads are minimized. This is the only function that should loop over EngineCoreOutputs. If you need to touch every element of the batch, do it from within the loop below. - + ********************************************************** """ From 0fefd4a9a0e47e2a4be735539579fff7d1eea818 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:39:10 +0000 Subject: [PATCH 07/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 2 +- vllm/v1/core/sched/scheduler_disagg.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 46db747e92b0..f6d6f517b758 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -387,7 +387,7 @@ def schedule(self) -> SchedulerOutput: request.num_computed_tokens = num_computed_tokens # Encoder-related. - if not request.do_remote_prefill and encoder_inputs_to_schedule: + if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( encoder_inputs_to_schedule) # Allocate the encoder cache. diff --git a/vllm/v1/core/sched/scheduler_disagg.py b/vllm/v1/core/sched/scheduler_disagg.py index 80cea05e4418..2a19a4c116b1 100644 --- a/vllm/v1/core/sched/scheduler_disagg.py +++ b/vllm/v1/core/sched/scheduler_disagg.py @@ -171,7 +171,7 @@ def schedule(self) -> SchedulerOutput: req_index += 1 # Encoder-related. - if encoder_inputs_to_schedule: + if not request.do_remote_prefill and encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( encoder_inputs_to_schedule) # Allocate the encoder cache. 
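The waiting-queue loop added in scheduler_disagg.py above parks requests that cannot progress yet (WAITING_FOR_REMOTE_KVS, WAITING_FOR_FSM) in a temporary deque and pushes them back onto the head of the queue afterwards. The standalone sketch below uses a stand-in can_schedule predicate rather than vLLM's request states; it only illustrates why the appendleft-then-extendleft pairing preserves the skipped requests' original FCFS order.

# Minimal, self-contained sketch of the skip-and-requeue pattern used in
# schedule(). Not vLLM code: requests are plain strings and can_schedule
# stands in for the WAITING_FOR_REMOTE_KVS / WAITING_FOR_FSM checks.
from collections import deque


def drain_waiting(waiting: deque, can_schedule) -> list:
    scheduled = []
    skipped: deque = deque()
    while waiting:
        req = waiting[0]
        if not can_schedule(req):
            # Park the request; appendleft here plus extendleft below
            # restores the original relative order of skipped requests.
            waiting.popleft()
            skipped.appendleft(req)
            continue
        waiting.popleft()
        scheduled.append(req)
    if skipped:
        waiting.extendleft(skipped)
    return scheduled


if __name__ == "__main__":
    q = deque(["A (waiting for remote KVs)", "B", "C (waiting for FSM)", "D"])
    done = drain_waiting(q, can_schedule=lambda r: "waiting" not in r)
    assert done == ["B", "D"]
    # Skipped requests are back at the head, still in FCFS order.
    assert list(q) == ["A (waiting for remote KVs)", "C (waiting for FSM)"]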
From eba69825a347623b2d349eb164c16f2ee0ff6e88 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:39:59 +0000 Subject: [PATCH 08/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/kv_cache_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 12c55be00375..5f647366dcdf 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -178,7 +178,7 @@ def allocate_slots( prefix caching. num_lookahead_tokens: The number of speculative tokens to allocate. This is used by spec decode proposers with kv-cache such - as eagle. + as eagle. skip_cache_blocks: Whether to skip caching the blocks. This is used by P/D when allocating blocks used in a KV transfer which will complete in a future step. From d97cbf99851904f0b68e72b8e2c2e3f262d7d181 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:40:17 +0000 Subject: [PATCH 09/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/kv_cache_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 5f647366dcdf..c4b83c89398d 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -307,7 +307,7 @@ def cache_blocks( num_cached_blocks = self.num_cached_block.get(request.request_id, len(new_computed_blocks)) - # Speculated tokens might be rejected in the future, so we do + # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( From 590c213982bcb0a7c18fc351f5ee37fa65d087b6 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:42:12 +0000 Subject: [PATCH 10/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/kv_cache_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index c4b83c89398d..ef05f895a714 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -306,7 +306,6 @@ def cache_blocks( # for a running request. num_cached_blocks = self.num_cached_block.get(request.request_id, len(new_computed_blocks)) - # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. 
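The kv_cache_manager.py comment edits above document a two-phase flow for remote prefills: allocate_slots(..., skip_cache_blocks=True) reserves blocks without hashing them into the prefix cache, and cache_blocks() is only called once the KV transfer has finished (when the request leaves WAITING_FOR_REMOTE_KVS). The toy manager below is a self-contained illustration of that ordering only; its internals are made up and it is not vLLM's KVCacheManager.

# Toy illustration (assumed/simplified internals) of deferring prefix-cache
# registration until the remote KV transfer completes.
class ToyKVCacheManager:
    def __init__(self):
        self.allocated: dict[str, int] = {}   # req_id -> number of blocks
        self.cached: set[str] = set()         # req_ids visible to prefix cache

    def allocate_slots(self, req_id: str, num_tokens: int,
                       block_size: int = 16,
                       skip_cache_blocks: bool = False) -> int:
        num_blocks = -(-num_tokens // block_size)  # ceil division
        self.allocated[req_id] = num_blocks
        if not skip_cache_blocks:
            self.cached.add(req_id)
        return num_blocks

    def cache_blocks(self, req_id: str) -> None:
        # Only now do the blocks become reusable by other requests.
        assert req_id in self.allocated, "blocks must be allocated first"
        self.cached.add(req_id)


if __name__ == "__main__":
    mgr = ToyKVCacheManager()
    # Step 1: remote prefill scheduled -> allocate, but do not cache yet.
    mgr.allocate_slots("req-0", num_tokens=100, skip_cache_blocks=True)
    assert "req-0" not in mgr.cached
    # Step 2: worker reports finished_recving -> blocks are cached.
    mgr.cache_blocks("req-0")
    assert "req-0" in mgr.cached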
From be0407a7ea751785719421d0d20c53ba96079c8b Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:46:57 +0000 Subject: [PATCH 11/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/engine/output_processor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 0e901fc8327d..30557f0f02dc 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -146,7 +146,7 @@ def make_request_output( new_token_ids: list[int], finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], - kv_transfer_params: KVTransferParams, + kv_transfer_params: Optional[KVTransferParams] = None, ) -> Optional[RequestOutput]: finished = finish_reason is not None @@ -176,7 +176,7 @@ def _new_request_output( request_id: str, outputs: list[CompletionOutput], finished: bool, - kv_transfer_params: KVTransferParams, + kv_transfer_params: Optional[KVTransferParams] = None, ) -> RequestOutput: if self.output_kind == RequestOutputKind.DELTA: @@ -320,7 +320,7 @@ def process_outputs( If you need to touch every element of the batch, do it from within the loop below. - + ********************************************************** """ From c905a48f1c90ef2504177575ba5ffe4a098d9034 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:52:25 +0000 Subject: [PATCH 12/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/distributed/kv_transfer/kv_connector/v1/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index a4a735890dab..ca9e19156719 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -22,6 +22,7 @@ import enum from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import TYPE_CHECKING import torch @@ -46,6 +47,7 @@ class KVConnectorRole(enum.Enum): WORKER = 1 +@dataclass class KVConnectorMetadata: pass From 0b2cc61839a8d8d67b9b9edf25a7247eb585e3ce Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:56:41 +0000 Subject: [PATCH 13/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f6d6f517b758..08bd33d6bd3a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -787,8 +787,8 @@ def _free_request(self, request: Request) -> None: self.kv_cache_manager.free_block_hashes(request) self.encoder_cache_manager.free(request) self._cached_reqs_data.pop(request.request_id, None) - self.finished_req_ids.add(request.request_id) del self.requests[request.request_id] + self.finished_req_ids.add(request.request_id) def get_num_unfinished_requests(self) -> int: return len(self.waiting) + len(self.running) From 23c3a6fd969e5ffe7f0556366c4df4068591c8e0 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:57:36 +0000 Subject: [PATCH 14/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 08bd33d6bd3a..f65c2bcf29d9 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -725,9 
+725,7 @@ def update_from_output( if not stopped: new_running.append(request) - # Return the cached request data to the queue so they can - # be reused. Note: we cannot add stopped requests to this - # since they are already freed above! + # Return the cached request data to the queue so they can be reused. for req_data in scheduler_output.scheduled_cached_reqs: # NOTE(rob): since we free stopped reqs above, adding stopped reqs # to _cached_reqs_data will cause a memory leak. From 000715da7d802fb0e931f95d9b875abeb7952339 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:59:07 +0000 Subject: [PATCH 15/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 3 +-- vllm/v1/core/sched/scheduler_disagg.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f65c2bcf29d9..be11519da555 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -481,8 +481,7 @@ def schedule(self) -> SchedulerOutput: # 3. If some tokens (e.g. spec tokens) are rejected later, the number of # computed tokens will be adjusted in update_from_output. for req_id, num_scheduled_token in num_scheduled_tokens.items(): - if req := self.requests.get(req_id): - req.num_computed_tokens += num_scheduled_token + self.requests[req_id].num_computed_tokens += num_scheduled_token self.finished_req_ids = set() return scheduler_output diff --git a/vllm/v1/core/sched/scheduler_disagg.py b/vllm/v1/core/sched/scheduler_disagg.py index 2a19a4c116b1..e93c7a5b00f5 100644 --- a/vllm/v1/core/sched/scheduler_disagg.py +++ b/vllm/v1/core/sched/scheduler_disagg.py @@ -443,7 +443,8 @@ def schedule(self) -> SchedulerOutput: # 3. If some tokens (e.g. spec tokens) are rejected later, the number of # computed tokens will be adjusted in update_from_output. for req_id, num_scheduled_token in num_scheduled_tokens.items(): - self.requests[req_id].num_computed_tokens += num_scheduled_token + if req := self.requests.get(req_id): + req.num_computed_tokens += num_scheduled_token self.finished_req_ids = set() return scheduler_output From 37ae9ad4d112c496c1db336b4af629a06a181eba Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 22:59:24 +0000 Subject: [PATCH 16/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index be11519da555..d7a8bd972ee4 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -714,8 +714,7 @@ def update_from_output( new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, - events=request.take_events(), - )) + events=request.take_events())) else: # Invariant: EngineCore returns no partial prefill outputs. 
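PATCH 11 makes kv_transfer_params an optional field in the output processor, carrying the object that the prefill-side update_from_output builds in PATCH 01. The sketch below mirrors only the fields used in this series (do_remote_prefill, remote_block_ids, remote_engine_id, remote_host, remote_port) with a hypothetical stand-in dataclass; it is not vLLM's KVTransferParams, and the proxy behaviour described in the comments is an assumption based on the toy_proxy_server in the tests.

# Standalone sketch of the P->D handoff these params enable.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyKVTransferParams:
    do_remote_prefill: bool = False
    do_remote_decode: bool = False
    remote_block_ids: Optional[list[int]] = None
    remote_engine_id: Optional[str] = None
    remote_host: Optional[str] = None
    remote_port: Optional[int] = None


def prefill_side_finish(block_ids: list[int], engine_id: str,
                        host: str, port: int) -> ToyKVTransferParams:
    # A prefill instance handling a do_remote_decode request stops after the
    # first token and attaches params like these to its output so the decode
    # side knows where to pull the KV cache from.
    return ToyKVTransferParams(do_remote_prefill=True,
                               remote_block_ids=block_ids,
                               remote_engine_id=engine_id,
                               remote_host=host,
                               remote_port=port)


if __name__ == "__main__":
    params = prefill_side_finish([0, 1, 2], "engine-abc", "10.0.0.1", 5600)
    # A proxy in front of both instances can forward these params with the
    # decode request; do_remote_prefill then triggers the
    # WAITING_FOR_REMOTE_KVS path in scheduler_disagg.py.
    assert params.do_remote_prefill and params.remote_block_ids == [0, 1, 2]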
From 59280ea405d605839f565eae8ec1f8446fe50c23 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 23:00:13 +0000 Subject: [PATCH 17/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/engine/output_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 30557f0f02dc..3f6f1f685e4c 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -320,7 +320,7 @@ def process_outputs( If you need to touch every element of the batch, do it from within the loop below. - + ********************************************************** """ From 70c766f669f6a1e0b30542b58e8501cd183b33c4 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 23:07:07 +0000 Subject: [PATCH 18/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 2 +- vllm/v1/core/sched/scheduler_disagg.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d7a8bd972ee4..3177eb913f2a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -326,7 +326,7 @@ def schedule(self) -> SchedulerOutput: # Number of tokens to be scheduled. # We use `request.num_tokens` instead of - # `request.num_prompt_tokens` to consider the resumed reqs, + # `request.num_prompt_tokens` to consider the resumed requests, # which have output tokens. num_new_tokens = request.num_tokens - num_computed_tokens if (0 < self.scheduler_config.long_prefill_token_threshold < diff --git a/vllm/v1/core/sched/scheduler_disagg.py b/vllm/v1/core/sched/scheduler_disagg.py index e93c7a5b00f5..c285a1e22756 100644 --- a/vllm/v1/core/sched/scheduler_disagg.py +++ b/vllm/v1/core/sched/scheduler_disagg.py @@ -272,7 +272,7 @@ def schedule(self) -> SchedulerOutput: # Number of tokens to be scheduled. # We use `request.num_tokens` instead of - # `request.num_prompt_tokens` to consider the resumed reqs, + # `request.num_prompt_tokens` to consider the resumed request, # which have output tokens. num_new_tokens = request.num_tokens - num_computed_tokens if (0 < self.scheduler_config.long_prefill_token_threshold < From 089b1d70110d6a4897872e9ed2d42a7ecf2d76bd Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 23:11:02 +0000 Subject: [PATCH 19/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/kv_cache_manager.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index ef05f895a714..9ccb7d15b56f 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -271,13 +271,9 @@ def allocate_slots( if not self.enable_caching: return new_blocks + # For disaggregated, avoid caching until KVs are recved. if skip_cache_blocks: - # NOTE(rob): this assert is valid because we only call - # skip_cache_blocks=True on the first time of WAITING - # during a P/D setup. assert request.request_id not in self.num_cached_block - # NOTE(rob): this is necessary so we don't double - # cache a block after is has finished recving. 
self.num_cached_block[request.request_id] = len( new_computed_blocks) return new_blocks From 17e90857a1003b4a9202a93e519e31ab5e576af9 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 23:32:42 +0000 Subject: [PATCH 20/30] updated Signed-off-by: rshaw@neuralmagic.com --- .../v1/kv_connector/{ => integration}/run_accuracy_test.sh | 0 tests/v1/kv_connector/{ => integration}/test_accuracy.py | 0 .../v1/kv_connector/{ => integration}/toy_proxy_server.py | 0 tests/v1/kv_connector/{ => unit}/__init__.py | 0 .../unit}/test_multi_connector.py | 0 tests/v1/kv_connector/{ => unit}/test_nixl_connector.py | 0 .../{ => unit}/test_remote_decode_lifecycle.py | 0 .../{ => unit}/test_remote_prefill_lifecycle.py | 0 tests/v1/kv_connector/{ => unit}/utils.py | 0 vllm/engine/arg_utils.py | 7 ++++++- vllm/v1/core/sched/scheduler_disagg.py | 4 ++-- 11 files changed, 8 insertions(+), 3 deletions(-) rename tests/v1/kv_connector/{ => integration}/run_accuracy_test.sh (100%) rename tests/v1/kv_connector/{ => integration}/test_accuracy.py (100%) rename tests/v1/kv_connector/{ => integration}/toy_proxy_server.py (100%) rename tests/v1/kv_connector/{ => unit}/__init__.py (100%) rename tests/v1/{kv_transfer => kv_connector/unit}/test_multi_connector.py (100%) rename tests/v1/kv_connector/{ => unit}/test_nixl_connector.py (100%) rename tests/v1/kv_connector/{ => unit}/test_remote_decode_lifecycle.py (100%) rename tests/v1/kv_connector/{ => unit}/test_remote_prefill_lifecycle.py (100%) rename tests/v1/kv_connector/{ => unit}/utils.py (100%) diff --git a/tests/v1/kv_connector/run_accuracy_test.sh b/tests/v1/kv_connector/integration/run_accuracy_test.sh similarity index 100% rename from tests/v1/kv_connector/run_accuracy_test.sh rename to tests/v1/kv_connector/integration/run_accuracy_test.sh diff --git a/tests/v1/kv_connector/test_accuracy.py b/tests/v1/kv_connector/integration/test_accuracy.py similarity index 100% rename from tests/v1/kv_connector/test_accuracy.py rename to tests/v1/kv_connector/integration/test_accuracy.py diff --git a/tests/v1/kv_connector/toy_proxy_server.py b/tests/v1/kv_connector/integration/toy_proxy_server.py similarity index 100% rename from tests/v1/kv_connector/toy_proxy_server.py rename to tests/v1/kv_connector/integration/toy_proxy_server.py diff --git a/tests/v1/kv_connector/__init__.py b/tests/v1/kv_connector/unit/__init__.py similarity index 100% rename from tests/v1/kv_connector/__init__.py rename to tests/v1/kv_connector/unit/__init__.py diff --git a/tests/v1/kv_transfer/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py similarity index 100% rename from tests/v1/kv_transfer/test_multi_connector.py rename to tests/v1/kv_connector/unit/test_multi_connector.py diff --git a/tests/v1/kv_connector/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py similarity index 100% rename from tests/v1/kv_connector/test_nixl_connector.py rename to tests/v1/kv_connector/unit/test_nixl_connector.py diff --git a/tests/v1/kv_connector/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py similarity index 100% rename from tests/v1/kv_connector/test_remote_decode_lifecycle.py rename to tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py diff --git a/tests/v1/kv_connector/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py similarity index 100% rename from tests/v1/kv_connector/test_remote_prefill_lifecycle.py rename to 
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py diff --git a/tests/v1/kv_connector/utils.py b/tests/v1/kv_connector/unit/utils.py similarity index 100% rename from tests/v1/kv_connector/utils.py rename to tests/v1/kv_connector/unit/utils.py diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index aefba620e189..57b14101b9ea 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1430,7 +1430,12 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None: # V1 should use the new scheduler by default. # Swap it only if this arg is set to the original V0 default if self.scheduler_cls == EngineArgs.scheduler_cls: - self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler" + if self.kv_transfer_config: + self.scheduler_cls = ( + "vllm.v1.core.sched.scheduler_disagg.DisaggregatedScheduler" + ) + else: + self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler" # When no user override, set the default values based on the usage # context. diff --git a/vllm/v1/core/sched/scheduler_disagg.py b/vllm/v1/core/sched/scheduler_disagg.py index c285a1e22756..35d1c2041b5d 100644 --- a/vllm/v1/core/sched/scheduler_disagg.py +++ b/vllm/v1/core/sched/scheduler_disagg.py @@ -41,8 +41,8 @@ def __init__(self, *args, **kwargs): "LoRA is not yet supported with KV Disaggregation.") # Create KVConnector for the Scheduler. - if self.vllm_config.kv_transfer_config is not None: - raise ValueError("Using DisaggregatedScheduler but found unset " + if self.vllm_config.kv_transfer_config is None: + raise ValueError("Using Disaggregated Scheduler but found unset " "kv_transfer_config.") self.connector = KVConnectorFactory.create_connector_v1( config=self.vllm_config, role=KVConnectorRole.SCHEDULER) From e5547593c5f7d67e02da311fe7fb42e568fc260c Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 23:47:27 +0000 Subject: [PATCH 21/30] updated Signed-off-by: rshaw@neuralmagic.com --- .buildkite/test-pipeline.yaml | 2 +- .../{integration => nixl_integration}/run_accuracy_test.sh | 4 ++-- .../{integration => nixl_integration}/test_accuracy.py | 0 .../{integration => nixl_integration}/toy_proxy_server.py | 2 -- 4 files changed, 3 insertions(+), 5 deletions(-) rename tests/v1/kv_connector/{integration => nixl_integration}/run_accuracy_test.sh (94%) rename tests/v1/kv_connector/{integration => nixl_integration}/test_accuracy.py (100%) rename tests/v1/kv_connector/{integration => nixl_integration}/toy_proxy_server.py (99%) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index ddadb9477623..ac4c3e5f8eb3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -209,7 +209,7 @@ steps: - pytest -v -s v1/worker - pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode - - pytest -v -s v1/kv_transfer + - pytest -v -s v1/kv_connector/unit - pytest -v -s v1/test_serial_utils.py - pytest -v -s v1/test_stats.py - pytest -v -s v1/test_utils.py diff --git a/tests/v1/kv_connector/integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh similarity index 94% rename from tests/v1/kv_connector/integration/run_accuracy_test.sh rename to tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index 802cd41db645..c44a8f9011bd 100755 --- a/tests/v1/kv_connector/integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -88,7 +88,7 @@ for PORT in "${DECODE_PORTS[@]}"; do done # Build the command for the proxy server with all the 
hosts and ports -PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/toy_proxy_server.py --port 8192" +PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192" # Add all prefill hosts and ports PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}" @@ -106,4 +106,4 @@ $PROXY_CMD & sleep 5 # Run lm eval. -python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/test_accuracy.py +python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py diff --git a/tests/v1/kv_connector/integration/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py similarity index 100% rename from tests/v1/kv_connector/integration/test_accuracy.py rename to tests/v1/kv_connector/nixl_integration/test_accuracy.py diff --git a/tests/v1/kv_connector/integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py similarity index 99% rename from tests/v1/kv_connector/integration/toy_proxy_server.py rename to tests/v1/kv_connector/nixl_integration/toy_proxy_server.py index cb9a20189424..85c2d88a6ae2 100644 --- a/tests/v1/kv_connector/integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -213,8 +213,6 @@ async def handle_completions(request: Request): # Get the next decode client in round-robin fashion decode_client_info = get_next_client(request.app, 'decode') - print(f"Using {prefill_client_info} {decode_client_info}") - # Stream response from decode service async def generate_stream(): async for chunk in stream_service_response( From 1c24c663f68a4fa9f3ddc36025f0bf33069a51d0 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 23:49:35 +0000 Subject: [PATCH 22/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 5ca70f2f67b6..f60547bdd4e9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3402,8 +3402,6 @@ class KVTransferConfig(BaseModel): kv_connector: Optional[str] = None # Engine ID for the KV transfers. - # Note(tms): sticking this here so the engine_id is consistent between - # scheduler-side and worker-side of the KVConnector engine_id: str = str(uuid.uuid4()) # The device used by kv connector to buffer the KV cache. From 8ac138e6e2bfb8eb97c6add32151e659669327fd Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 5 May 2025 23:49:59 +0000 Subject: [PATCH 23/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/core/sched/scheduler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 3177eb913f2a..936f5d0e3158 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -715,7 +715,6 @@ def update_from_output( new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, events=request.take_events())) - else: # Invariant: EngineCore returns no partial prefill outputs. 
assert not prompt_logprobs_tensors From f8239cf320f7eafee7ec7fb8be3f7ee78c429d47 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 6 May 2025 00:20:46 +0000 Subject: [PATCH 24/30] updated Signed-off-by: rshaw@neuralmagic.com --- .../kv_connector/unit/test_multi_connector.py | 239 ------------------ tests/v1/kv_connector/unit/utils.py | 8 +- 2 files changed, 4 insertions(+), 243 deletions(-) delete mode 100644 tests/v1/kv_connector/unit/test_multi_connector.py diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py deleted file mode 100644 index ed26ba0f0d33..000000000000 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ /dev/null @@ -1,239 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import filecmp -import shutil -import tempfile -from collections import defaultdict -from pathlib import Path - -from vllm import LLM, SamplingParams -from vllm.config import KVTransferConfig, VllmConfig -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa - SharedStorageConnector) - -MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" - -PROMPT_CONTEXT = "Hi " * 100 -PROMPTS = [ - PROMPT_CONTEXT + "Hello, my name is", - PROMPT_CONTEXT + "The capital of France is", -] - -SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20) - - -class TestSharedStorageConnector(SharedStorageConnector): - - def __init__(self, config: VllmConfig, role): - self.name = config.kv_transfer_config.kv_connector_extra_config["name"] - self._connector = SharedStorageConnector(config, role) - self.call_record: dict[str, int] = defaultdict(int) - # Use a unique temp file per connector - self._event_file = tempfile.gettempdir( - ) + f"/connector_{self.name}_events.log" - # Start with an empty file - with open(self._event_file, "w") as _: - pass - - def __getattribute__(self, name): - if name in ("_connector", "call_record", "name", "_event_file", - "__class__", "__dict__", "__getattribute__", - "__init__"): # avoid recursion - return object.__getattribute__(self, name) - if not hasattr(self._connector, name): - return object.__getattribute__(self, name) - attr = getattr(self._connector, name) - - if callable(attr): - - def wrapper(*args, **kwargs): - self.call_record[name] += 1 - # Log the event as a line to the file - try: - with open(self._event_file, "a") as f: - f.write(name + "\n") - except Exception as e: - print(f"[ERROR] Could not log event {name} " - f"for {self.name}: {e}") - return attr(*args, **kwargs) - - return wrapper - return attr - - -KVConnectorFactory.register_connector("TestSharedStorageConnector", - TestSharedStorageConnector.__module__, - TestSharedStorageConnector.__name__) - - -# Helper function to compare directories recursively -def _compare_directories(dir1: Path, dir2: Path) -> bool: - """Compares two directories recursively for identical content.""" - dcmp = filecmp.dircmp(dir1, dir2) - if dcmp.left_only or dcmp.right_only or dcmp.diff_files: - print(f"Differences found between {dir1} and {dir2}:") - print(f" Left only: {dcmp.left_only}") - print(f" Right only: {dcmp.right_only}") - print(f" Different files: {dcmp.diff_files}") - return False - for sub_dir in dcmp.common_dirs: - if not _compare_directories(dir1 / sub_dir, dir2 / sub_dir): - return False - return True - - -def test_multi_shared_storage_connector_consistency(): - """ - Tests that MultiConnector with two SharedStorageConnectors 
saves - identical KV cache data to separate storage locations. - """ - storage_1_path = Path("storage_1/") - storage_2_path = Path("storage_2/") - shutil.rmtree(storage_1_path, ignore_errors=True) - shutil.rmtree(storage_2_path, ignore_errors=True) - storage_1_path.mkdir() - storage_2_path.mkdir() - - # Configure MultiConnector with two SharedStorageConnectors - kv_transfer_config = KVTransferConfig( - kv_connector="MultiConnector", - kv_role="kv_both", - kv_connector_extra_config={ - "connectors": [{ - "kv_connector": "TestSharedStorageConnector", - "kv_role": "kv_both", - "kv_connector_extra_config": { - "shared_storage_path": str(storage_1_path), - "name": "storage1", - } - }, { - "kv_connector": "TestSharedStorageConnector", - "kv_role": "kv_both", - "kv_connector_extra_config": { - "shared_storage_path": str(storage_2_path), - "name": "storage2", - } - }] - }, - ) - - llm = LLM( - model=MODEL_NAME, - enforce_eager=True, - gpu_memory_utilization=0.5, - kv_transfer_config=kv_transfer_config, - ) - # Run generation - this should trigger saving KV cache - _ = llm.generate(PROMPTS, SAMPLING_PARAMS) - - # --- Verification --- - - # Check that both storage directories were populated - local_subdirs = list(storage_1_path.iterdir()) - external_subdirs = list(storage_2_path.iterdir()) - - assert len( - local_subdirs - ) > 0, f"Local storage path {storage_1_path} is empty after generation." - assert len(external_subdirs) > 0, ( - f"External storage path {storage_2_path} is empty after generation.") - assert len(local_subdirs) == len(external_subdirs), ( - f"Mismatch in number of cache entries: " - f"Local={len(local_subdirs)}, External={len(external_subdirs)}") - - # The subdirectories should correspond to the prompt hashes - # Since prompts are the same, the hash directories should be the same name - local_subdir_names = sorted([d.name for d in local_subdirs]) - external_subdir_names = sorted([d.name for d in external_subdirs]) - assert local_subdir_names == external_subdir_names, ( - "Cache directory names do not match between local and external storage" - ) - - # Compare the contents of each corresponding cache directory - for subdir_name in local_subdir_names: - print(f"Comparing contents of cache directory: {subdir_name}") - assert _compare_directories(storage_1_path / subdir_name, - storage_2_path / subdir_name), \ - (f"Contents differ for cache directory '{subdir_name}' between " - f"{storage_1_path} and {storage_2_path}") - - events = get_connector_events() - # get_num_new_matched_tokens will be called on each connector in turn. - # neither of them have hits so update_state_after_alloc won't be called. - assert events["storage1"][:3] == [ - 'get_num_new_matched_tokens', 'build_connector_meta', - 'bind_connector_metadata' - ] - assert events["storage2"][:3] == [ - 'get_num_new_matched_tokens', 'build_connector_meta', - 'bind_connector_metadata' - ] - - # Reset prefix cache or else we'll just get the tokens back from there. - llm.reset_prefix_cache() - - # Run generation again - this should trigger loading from the first - # connector. - _ = llm.generate(PROMPTS, SAMPLING_PARAMS) - - events = get_connector_events() - # get_num_new_matched_tokens will return new tokens from the first - # connector so update_state_after_alloc will be called once blocks - # are allocated for the first connector. - # get_num_new_matched_tokens *won't* be called on the second connector - # in this case. 
- assert events["storage1"][:4] == [ - 'get_num_new_matched_tokens', 'update_state_after_alloc', - 'build_connector_meta', 'bind_connector_metadata' - ] - assert events["storage2"][:2] == [ - 'build_connector_meta', 'bind_connector_metadata' - ] - - # Delete storage1 connector state - shutil.rmtree(storage_1_path) - - # Reset prefix cache or else we'll just get the tokens back from there. - llm.reset_prefix_cache() - - # Run generation again - this should trigger loading from the first - # connector. - _ = llm.generate(PROMPTS, SAMPLING_PARAMS) - - events = get_connector_events() - # get_num_new_matched_tokens will be called for the first connector but it - # won't have a hit so update_state_after_alloc won't be called. - # get_num_new_matched_tokens will also be called on the second connector, - # but it should have a hit so update_state_after_alloc will be called. - assert events["storage1"][:3] == [ - 'get_num_new_matched_tokens', 'build_connector_meta', - 'bind_connector_metadata' - ] - assert events["storage2"][:4] == [ - 'get_num_new_matched_tokens', 'update_state_after_alloc', - 'build_connector_meta', 'bind_connector_metadata' - ] - - # Clean up - shutil.rmtree(storage_1_path) - shutil.rmtree(storage_2_path) - - -def get_connector_events() -> dict[str, list[str]]: - # Read in connector events and reset the files. - import glob - event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log") - connector_events = {} - for fname in event_files: - name = fname.split("connector_")[1].split("_events.log")[0] - try: - with open(fname, "r+") as f: - connector_events[name] = [ - line.strip() for line in f if line.strip() - ] - f.truncate(0) - except Exception as e: - print(f"[ERROR] Could not read connector events for {name}: {e}") - - return connector_events diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index c5527bc0ee55..a681b1ad5f29 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -6,7 +6,7 @@ from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, ModelConfig, SchedulerConfig, VllmConfig) from vllm.sampling_params import KVTransferParams, SamplingParams -from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.core.sched.scheduler_disagg import DisaggregatedScheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) from vllm.v1.outputs import ModelRunnerOutput @@ -16,7 +16,7 @@ EOS_TOKEN_ID = 50256 -def assert_scheduler_empty(scheduler: Scheduler): +def assert_scheduler_empty(scheduler: DisaggregatedScheduler): """Confirm the scheduler is "empty" - i.e. no leaks.""" # Scheduler Metadata. 
assert len(scheduler.requests) == 0 @@ -88,7 +88,7 @@ def create_vllm_config( def create_scheduler( vllm_config: VllmConfig, num_blocks: int = 10000, -) -> Scheduler: +) -> DisaggregatedScheduler: """Initialize Scheduler For Testing.""" block_size = vllm_config.cache_config.block_size kv_cache_config = KVCacheConfig( @@ -101,7 +101,7 @@ def create_scheduler( ], ) vllm_config.cache_config.num_gpu_blocks = num_blocks - return Scheduler( + return DisaggregatedScheduler( vllm_config=vllm_config, kv_cache_config=kv_cache_config, log_stats=True, From 4a391084c1f39c8a0c575b8289676f1ec1dcca07 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 6 May 2025 00:28:04 +0000 Subject: [PATCH 25/30] updated Signed-off-by: rshaw@neuralmagic.com --- .../kv_connector/unit/test_multi_connector.py | 239 ++++++++++++++++++ 1 file changed, 239 insertions(+) create mode 100644 tests/v1/kv_connector/unit/test_multi_connector.py diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py new file mode 100644 index 000000000000..ed26ba0f0d33 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -0,0 +1,239 @@ +# SPDX-License-Identifier: Apache-2.0 +import filecmp +import shutil +import tempfile +from collections import defaultdict +from pathlib import Path + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa + SharedStorageConnector) + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" + +PROMPT_CONTEXT = "Hi " * 100 +PROMPTS = [ + PROMPT_CONTEXT + "Hello, my name is", + PROMPT_CONTEXT + "The capital of France is", +] + +SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20) + + +class TestSharedStorageConnector(SharedStorageConnector): + + def __init__(self, config: VllmConfig, role): + self.name = config.kv_transfer_config.kv_connector_extra_config["name"] + self._connector = SharedStorageConnector(config, role) + self.call_record: dict[str, int] = defaultdict(int) + # Use a unique temp file per connector + self._event_file = tempfile.gettempdir( + ) + f"/connector_{self.name}_events.log" + # Start with an empty file + with open(self._event_file, "w") as _: + pass + + def __getattribute__(self, name): + if name in ("_connector", "call_record", "name", "_event_file", + "__class__", "__dict__", "__getattribute__", + "__init__"): # avoid recursion + return object.__getattribute__(self, name) + if not hasattr(self._connector, name): + return object.__getattribute__(self, name) + attr = getattr(self._connector, name) + + if callable(attr): + + def wrapper(*args, **kwargs): + self.call_record[name] += 1 + # Log the event as a line to the file + try: + with open(self._event_file, "a") as f: + f.write(name + "\n") + except Exception as e: + print(f"[ERROR] Could not log event {name} " + f"for {self.name}: {e}") + return attr(*args, **kwargs) + + return wrapper + return attr + + +KVConnectorFactory.register_connector("TestSharedStorageConnector", + TestSharedStorageConnector.__module__, + TestSharedStorageConnector.__name__) + + +# Helper function to compare directories recursively +def _compare_directories(dir1: Path, dir2: Path) -> bool: + """Compares two directories recursively for identical content.""" + dcmp = filecmp.dircmp(dir1, dir2) + if dcmp.left_only or dcmp.right_only or dcmp.diff_files: + 
print(f"Differences found between {dir1} and {dir2}:") + print(f" Left only: {dcmp.left_only}") + print(f" Right only: {dcmp.right_only}") + print(f" Different files: {dcmp.diff_files}") + return False + for sub_dir in dcmp.common_dirs: + if not _compare_directories(dir1 / sub_dir, dir2 / sub_dir): + return False + return True + + +def test_multi_shared_storage_connector_consistency(): + """ + Tests that MultiConnector with two SharedStorageConnectors saves + identical KV cache data to separate storage locations. + """ + storage_1_path = Path("storage_1/") + storage_2_path = Path("storage_2/") + shutil.rmtree(storage_1_path, ignore_errors=True) + shutil.rmtree(storage_2_path, ignore_errors=True) + storage_1_path.mkdir() + storage_2_path.mkdir() + + # Configure MultiConnector with two SharedStorageConnectors + kv_transfer_config = KVTransferConfig( + kv_connector="MultiConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "connectors": [{ + "kv_connector": "TestSharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_1_path), + "name": "storage1", + } + }, { + "kv_connector": "TestSharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_2_path), + "name": "storage2", + } + }] + }, + ) + + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + gpu_memory_utilization=0.5, + kv_transfer_config=kv_transfer_config, + ) + # Run generation - this should trigger saving KV cache + _ = llm.generate(PROMPTS, SAMPLING_PARAMS) + + # --- Verification --- + + # Check that both storage directories were populated + local_subdirs = list(storage_1_path.iterdir()) + external_subdirs = list(storage_2_path.iterdir()) + + assert len( + local_subdirs + ) > 0, f"Local storage path {storage_1_path} is empty after generation." + assert len(external_subdirs) > 0, ( + f"External storage path {storage_2_path} is empty after generation.") + assert len(local_subdirs) == len(external_subdirs), ( + f"Mismatch in number of cache entries: " + f"Local={len(local_subdirs)}, External={len(external_subdirs)}") + + # The subdirectories should correspond to the prompt hashes + # Since prompts are the same, the hash directories should be the same name + local_subdir_names = sorted([d.name for d in local_subdirs]) + external_subdir_names = sorted([d.name for d in external_subdirs]) + assert local_subdir_names == external_subdir_names, ( + "Cache directory names do not match between local and external storage" + ) + + # Compare the contents of each corresponding cache directory + for subdir_name in local_subdir_names: + print(f"Comparing contents of cache directory: {subdir_name}") + assert _compare_directories(storage_1_path / subdir_name, + storage_2_path / subdir_name), \ + (f"Contents differ for cache directory '{subdir_name}' between " + f"{storage_1_path} and {storage_2_path}") + + events = get_connector_events() + # get_num_new_matched_tokens will be called on each connector in turn. + # neither of them have hits so update_state_after_alloc won't be called. + assert events["storage1"][:3] == [ + 'get_num_new_matched_tokens', 'build_connector_meta', + 'bind_connector_metadata' + ] + assert events["storage2"][:3] == [ + 'get_num_new_matched_tokens', 'build_connector_meta', + 'bind_connector_metadata' + ] + + # Reset prefix cache or else we'll just get the tokens back from there. + llm.reset_prefix_cache() + + # Run generation again - this should trigger loading from the first + # connector. 
+ _ = llm.generate(PROMPTS, SAMPLING_PARAMS) + + events = get_connector_events() + # get_num_new_matched_tokens will return new tokens from the first + # connector so update_state_after_alloc will be called once blocks + # are allocated for the first connector. + # get_num_new_matched_tokens *won't* be called on the second connector + # in this case. + assert events["storage1"][:4] == [ + 'get_num_new_matched_tokens', 'update_state_after_alloc', + 'build_connector_meta', 'bind_connector_metadata' + ] + assert events["storage2"][:2] == [ + 'build_connector_meta', 'bind_connector_metadata' + ] + + # Delete storage1 connector state + shutil.rmtree(storage_1_path) + + # Reset prefix cache or else we'll just get the tokens back from there. + llm.reset_prefix_cache() + + # Run generation again - this should trigger loading from the first + # connector. + _ = llm.generate(PROMPTS, SAMPLING_PARAMS) + + events = get_connector_events() + # get_num_new_matched_tokens will be called for the first connector but it + # won't have a hit so update_state_after_alloc won't be called. + # get_num_new_matched_tokens will also be called on the second connector, + # but it should have a hit so update_state_after_alloc will be called. + assert events["storage1"][:3] == [ + 'get_num_new_matched_tokens', 'build_connector_meta', + 'bind_connector_metadata' + ] + assert events["storage2"][:4] == [ + 'get_num_new_matched_tokens', 'update_state_after_alloc', + 'build_connector_meta', 'bind_connector_metadata' + ] + + # Clean up + shutil.rmtree(storage_1_path) + shutil.rmtree(storage_2_path) + + +def get_connector_events() -> dict[str, list[str]]: + # Read in connector events and reset the files. + import glob + event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log") + connector_events = {} + for fname in event_files: + name = fname.split("connector_")[1].split("_events.log")[0] + try: + with open(fname, "r+") as f: + connector_events[name] = [ + line.strip() for line in f if line.strip() + ] + f.truncate(0) + except Exception as e: + print(f"[ERROR] Could not read connector events for {name}: {e}") + + return connector_events From 6f328a2b20bc47c693fa1f9079193601458c0b8d Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 6 May 2025 00:36:38 +0000 Subject: [PATCH 26/30] remove multi-connector Signed-off-by: rshaw@neuralmagic.com --- .../kv_connector/unit/test_multi_connector.py | 239 ------------------ .../kv_transfer/kv_connector/factory.py | 5 - .../kv_connector/v1/multi_connector.py | 110 -------- 3 files changed, 354 deletions(-) delete mode 100644 tests/v1/kv_connector/unit/test_multi_connector.py delete mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py deleted file mode 100644 index ed26ba0f0d33..000000000000 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ /dev/null @@ -1,239 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import filecmp -import shutil -import tempfile -from collections import defaultdict -from pathlib import Path - -from vllm import LLM, SamplingParams -from vllm.config import KVTransferConfig, VllmConfig -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa - SharedStorageConnector) - -MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" - -PROMPT_CONTEXT = "Hi " * 100 -PROMPTS = [ - 
PROMPT_CONTEXT + "Hello, my name is", - PROMPT_CONTEXT + "The capital of France is", -] - -SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20) - - -class TestSharedStorageConnector(SharedStorageConnector): - - def __init__(self, config: VllmConfig, role): - self.name = config.kv_transfer_config.kv_connector_extra_config["name"] - self._connector = SharedStorageConnector(config, role) - self.call_record: dict[str, int] = defaultdict(int) - # Use a unique temp file per connector - self._event_file = tempfile.gettempdir( - ) + f"/connector_{self.name}_events.log" - # Start with an empty file - with open(self._event_file, "w") as _: - pass - - def __getattribute__(self, name): - if name in ("_connector", "call_record", "name", "_event_file", - "__class__", "__dict__", "__getattribute__", - "__init__"): # avoid recursion - return object.__getattribute__(self, name) - if not hasattr(self._connector, name): - return object.__getattribute__(self, name) - attr = getattr(self._connector, name) - - if callable(attr): - - def wrapper(*args, **kwargs): - self.call_record[name] += 1 - # Log the event as a line to the file - try: - with open(self._event_file, "a") as f: - f.write(name + "\n") - except Exception as e: - print(f"[ERROR] Could not log event {name} " - f"for {self.name}: {e}") - return attr(*args, **kwargs) - - return wrapper - return attr - - -KVConnectorFactory.register_connector("TestSharedStorageConnector", - TestSharedStorageConnector.__module__, - TestSharedStorageConnector.__name__) - - -# Helper function to compare directories recursively -def _compare_directories(dir1: Path, dir2: Path) -> bool: - """Compares two directories recursively for identical content.""" - dcmp = filecmp.dircmp(dir1, dir2) - if dcmp.left_only or dcmp.right_only or dcmp.diff_files: - print(f"Differences found between {dir1} and {dir2}:") - print(f" Left only: {dcmp.left_only}") - print(f" Right only: {dcmp.right_only}") - print(f" Different files: {dcmp.diff_files}") - return False - for sub_dir in dcmp.common_dirs: - if not _compare_directories(dir1 / sub_dir, dir2 / sub_dir): - return False - return True - - -def test_multi_shared_storage_connector_consistency(): - """ - Tests that MultiConnector with two SharedStorageConnectors saves - identical KV cache data to separate storage locations. 
- """ - storage_1_path = Path("storage_1/") - storage_2_path = Path("storage_2/") - shutil.rmtree(storage_1_path, ignore_errors=True) - shutil.rmtree(storage_2_path, ignore_errors=True) - storage_1_path.mkdir() - storage_2_path.mkdir() - - # Configure MultiConnector with two SharedStorageConnectors - kv_transfer_config = KVTransferConfig( - kv_connector="MultiConnector", - kv_role="kv_both", - kv_connector_extra_config={ - "connectors": [{ - "kv_connector": "TestSharedStorageConnector", - "kv_role": "kv_both", - "kv_connector_extra_config": { - "shared_storage_path": str(storage_1_path), - "name": "storage1", - } - }, { - "kv_connector": "TestSharedStorageConnector", - "kv_role": "kv_both", - "kv_connector_extra_config": { - "shared_storage_path": str(storage_2_path), - "name": "storage2", - } - }] - }, - ) - - llm = LLM( - model=MODEL_NAME, - enforce_eager=True, - gpu_memory_utilization=0.5, - kv_transfer_config=kv_transfer_config, - ) - # Run generation - this should trigger saving KV cache - _ = llm.generate(PROMPTS, SAMPLING_PARAMS) - - # --- Verification --- - - # Check that both storage directories were populated - local_subdirs = list(storage_1_path.iterdir()) - external_subdirs = list(storage_2_path.iterdir()) - - assert len( - local_subdirs - ) > 0, f"Local storage path {storage_1_path} is empty after generation." - assert len(external_subdirs) > 0, ( - f"External storage path {storage_2_path} is empty after generation.") - assert len(local_subdirs) == len(external_subdirs), ( - f"Mismatch in number of cache entries: " - f"Local={len(local_subdirs)}, External={len(external_subdirs)}") - - # The subdirectories should correspond to the prompt hashes - # Since prompts are the same, the hash directories should be the same name - local_subdir_names = sorted([d.name for d in local_subdirs]) - external_subdir_names = sorted([d.name for d in external_subdirs]) - assert local_subdir_names == external_subdir_names, ( - "Cache directory names do not match between local and external storage" - ) - - # Compare the contents of each corresponding cache directory - for subdir_name in local_subdir_names: - print(f"Comparing contents of cache directory: {subdir_name}") - assert _compare_directories(storage_1_path / subdir_name, - storage_2_path / subdir_name), \ - (f"Contents differ for cache directory '{subdir_name}' between " - f"{storage_1_path} and {storage_2_path}") - - events = get_connector_events() - # get_num_new_matched_tokens will be called on each connector in turn. - # neither of them have hits so update_state_after_alloc won't be called. - assert events["storage1"][:3] == [ - 'get_num_new_matched_tokens', 'build_connector_meta', - 'bind_connector_metadata' - ] - assert events["storage2"][:3] == [ - 'get_num_new_matched_tokens', 'build_connector_meta', - 'bind_connector_metadata' - ] - - # Reset prefix cache or else we'll just get the tokens back from there. - llm.reset_prefix_cache() - - # Run generation again - this should trigger loading from the first - # connector. - _ = llm.generate(PROMPTS, SAMPLING_PARAMS) - - events = get_connector_events() - # get_num_new_matched_tokens will return new tokens from the first - # connector so update_state_after_alloc will be called once blocks - # are allocated for the first connector. - # get_num_new_matched_tokens *won't* be called on the second connector - # in this case. 
- assert events["storage1"][:4] == [ - 'get_num_new_matched_tokens', 'update_state_after_alloc', - 'build_connector_meta', 'bind_connector_metadata' - ] - assert events["storage2"][:2] == [ - 'build_connector_meta', 'bind_connector_metadata' - ] - - # Delete storage1 connector state - shutil.rmtree(storage_1_path) - - # Reset prefix cache or else we'll just get the tokens back from there. - llm.reset_prefix_cache() - - # Run generation again - this should trigger loading from the first - # connector. - _ = llm.generate(PROMPTS, SAMPLING_PARAMS) - - events = get_connector_events() - # get_num_new_matched_tokens will be called for the first connector but it - # won't have a hit so update_state_after_alloc won't be called. - # get_num_new_matched_tokens will also be called on the second connector, - # but it should have a hit so update_state_after_alloc will be called. - assert events["storage1"][:3] == [ - 'get_num_new_matched_tokens', 'build_connector_meta', - 'bind_connector_metadata' - ] - assert events["storage2"][:4] == [ - 'get_num_new_matched_tokens', 'update_state_after_alloc', - 'build_connector_meta', 'bind_connector_metadata' - ] - - # Clean up - shutil.rmtree(storage_1_path) - shutil.rmtree(storage_2_path) - - -def get_connector_events() -> dict[str, list[str]]: - # Read in connector events and reset the files. - import glob - event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log") - connector_events = {} - for fname in event_files: - name = fname.split("connector_")[1].split("_events.log")[0] - try: - with open(fname, "r+") as f: - connector_events[name] = [ - line.strip() for line in f if line.strip() - ] - f.truncate(0) - except Exception as e: - print(f"[ERROR] Could not read connector events for {name}: {e}") - - return connector_events diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 9d342115ccff..54cb1871db3c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -110,8 +110,3 @@ def create_connector_v1( "NixlConnector", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector", "NixlConnector") - -KVConnectorFactory.register_connector( - "MultiConnector", - "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", - "MultiConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py deleted file mode 100644 index e8857d6e3677..000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ /dev/null @@ -1,110 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import copy -from typing import TYPE_CHECKING - -import torch - -from vllm.config import KVTransferConfig, VllmConfig -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) -from vllm.logger import init_logger -from vllm.v1.core.sched.output import SchedulerOutput - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionMetadata - from vllm.forward_context import ForwardContext - from vllm.v1.request import Request - -logger = init_logger(__name__) - - -class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...], - KVConnectorMetadata): - pass - - -class MultiConnector(KVConnectorBase_V1): - - def __init__(self, vllm_config: "VllmConfig", role: 
KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) - self._connectors = [] - ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "connectors") - assert ktcs is not None - for ktc in ktcs: - temp_config = copy.copy(vllm_config) - temp_config.kv_transfer_config = KVTransferConfig(**ktc) - self._connectors.append( - KVConnectorFactory.create_connector_v1(temp_config, role)) - - # A mapping from request id to the connector that is assigned to it. - self._requests_to_connector: dict[str, KVConnectorBase_V1] = {} - - # We must override the base class method here because we need to bind - # the metadata to each connector in the order of the connectors in the - # MultiKVConnectorMetadata. - def bind_connector_metadata( - self, connector_metadata: KVConnectorMetadata) -> None: - assert isinstance(connector_metadata, MultiKVConnectorMetadata) - for c, cm in zip(self._connectors, connector_metadata): - c.bind_connector_metadata(cm) - - def clear_connector_metadata(self) -> None: - for c in self._connectors: - c.clear_connector_metadata() - - # ============================== - # Worker-side methods - # ============================== - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: - for c in self._connectors: - c.start_load_kv(forward_context, **kwargs) - - def wait_for_layer_load(self, layer_name: str) -> None: - for c in self._connectors: - c.wait_for_layer_load(layer_name) - - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: - for c in self._connectors: - c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs) - - def wait_for_save(self): - for c in self._connectors: - c.wait_for_save() - - # ============================== - # Scheduler-side methods - # ============================== - def get_num_new_matched_tokens( - self, - request: "Request", - num_computed_tokens: int, - ) -> int: - for c in self._connectors: - toks = c.get_num_new_matched_tokens(request, num_computed_tokens) - # The first connector that has new matched tokens will be assigned - # to this request. - if toks > 0: - self._requests_to_connector[request.request_id] = c - return toks - return 0 - - def update_state_after_alloc(self, request: "Request", - block_ids: list[int], - num_external_tokens: int): - # If the request is not assigned to any connector, we do nothing. - if request.request_id not in self._requests_to_connector: - return - # We assume that the request is assigned to only one connector. 
- c = self._requests_to_connector.pop(request.request_id) - c.update_state_after_alloc(request, block_ids, num_external_tokens) - - def build_connector_meta( - self, - scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata: - return MultiKVConnectorMetadata( - c.build_connector_meta(scheduler_output) for c in self._connectors) From da291cee137278d82c58f9aa016b3deb740e387f Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 6 May 2025 00:38:55 +0000 Subject: [PATCH 27/30] remove multi-connector Signed-off-by: rshaw@neuralmagic.com --- vllm/engine/arg_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 57b14101b9ea..bc10b67f0cea 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1432,8 +1432,7 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None: if self.scheduler_cls == EngineArgs.scheduler_cls: if self.kv_transfer_config: self.scheduler_cls = ( - "vllm.v1.core.sched.scheduler_disagg.DisaggregatedScheduler" - ) + "vllm.v1.core.sched.scheduler_disagg.DisaggregatedScheduler") else: self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler" From b00157b5858b5aaa325c0e7bf58d2ee68f516b5d Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 6 May 2025 00:52:47 +0000 Subject: [PATCH 28/30] cleanup Signed-off-by: rshaw@neuralmagic.com --- vllm/forward_context.py | 21 +++++++++++++++++++++ vllm/v1/core/sched/scheduler_disagg.py | 2 -- vllm/v1/outputs.py | 18 ++++++++++-------- 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index c24ba0f45f9e..c75d8f088c5b 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -11,6 +11,10 @@ import vllm.envs as envs from vllm.config import VllmConfig +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.logger import init_logger if TYPE_CHECKING: @@ -97,6 +101,16 @@ def set_forward_context(attn_metadata: Any, attn_metadata=attn_metadata, dp_metadata=dp_metadata) + # KVConnector: trigger (possibly async) load before forward. + # Each attn layer will block until the reading is complete. + trigger_kv_transfer = (attn_metadata is not None + and has_kv_transfer_group() + and is_v1_kv_transfer_group()) + if trigger_kv_transfer: + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + kv_connector.start_load_kv(_forward_context) + try: yield finally: @@ -133,4 +147,11 @@ def set_forward_context(attn_metadata: Any, "(batchsize, count, median_time(ms)): %s"), forward_stats) + # KVConnector: each attn layer triggers (possibly async) save. + # Ensure all those operations complete before forward() is done. + if trigger_kv_transfer: + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + kv_connector.wait_for_save() + _forward_context = prev_context diff --git a/vllm/v1/core/sched/scheduler_disagg.py b/vllm/v1/core/sched/scheduler_disagg.py index 35d1c2041b5d..44427a9f38e6 100644 --- a/vllm/v1/core/sched/scheduler_disagg.py +++ b/vllm/v1/core/sched/scheduler_disagg.py @@ -540,8 +540,6 @@ def update_from_output( # NOTE(rob): req is not freed (or preempted) in the EngineCore # until the xfer is done to ensure we do not free the KV blocks. 
kv_transfer_params = None - # TODO(rob): edge case where we get a stop for stop_strings - # inside AsyncLLM. if request.do_remote_decode and not stopped: request.status = RequestStatus.FINISHED_REMOTE_DECODE self._free_request(request, skip_free_blocks=True) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index e8ce0df5ed8d..f4a240bc7b06 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -105,11 +105,13 @@ class ModelRunnerOutput: finished_recving: Optional[set[str]] = None -EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], - req_id_to_index={}, - sampled_token_ids=[], - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - finished_sending=None, - finished_recving=None) +EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + finished_sending=None, + finished_recving=None, +) From 5af868e6c20de33e29573fcefad6788a8aad6067 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 6 May 2025 01:10:11 +0000 Subject: [PATCH 29/30] cherry pick nixl Signed-off-by: rshaw@neuralmagic.com --- .../nixl_integration/run_accuracy_test.sh | 195 ++++++++++++------ .../nixl_integration/test_accuracy.py | 30 ++- .../nixl_integration/toy_proxy_server.py | 6 + .../unit/test_remote_prefill_lifecycle.py | 34 +++ .../kv_connector/v1/nixl_connector.py | 62 +++--- vllm/sampling_params.py | 1 - vllm/v1/core/sched/scheduler_disagg.py | 15 +- 7 files changed, 237 insertions(+), 106 deletions(-) diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index c44a8f9011bd..17eac2629687 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -1,9 +1,11 @@ #!/bin/bash - set -xe -# Model to run. -MODEL_NAME=Qwen/Qwen3-0.6B +# Models to run +MODELS=( +# "Qwen/Qwen3-0.6B" + "deepseek-ai/deepseek-vl2-tiny" +) # Number of prefill and decode instances to create NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1 @@ -24,86 +26,147 @@ wait_for_server() { done" && return 0 || return 1 } -# Arrays to store all hosts and ports -PREFILL_HOSTS=() -PREFILL_PORTS=() -DECODE_HOSTS=() -DECODE_PORTS=() +# Function to clean up previous instances +cleanup_instances() { + echo "Cleaning up any running vLLM instances..." 
+ pkill -f "vllm serve" || true + sleep 2 +} -# Start prefill instances -for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do - # Calculate GPU ID - we'll distribute across available GPUs - GPU_ID=$((i % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l))) - # Calculate port number (base port + instance number) - PORT=$((8100 + i)) - # Calculate side channel port - SIDE_CHANNEL_PORT=$((5559 + i)) +# Handle to get model-specific arguments for deepseek +get_model_args() { + local model_name=$1 + local extra_args="" - echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT" + if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then + extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code" + fi - CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $MODEL_NAME \ - --port $PORT \ - --enforce-eager \ - --disable-log-requests \ - --gpu-memory-utilization 0.2 \ - --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' & + echo "$extra_args" +} - # Store host and port for proxy configuration - PREFILL_HOSTS+=("localhost") - PREFILL_PORTS+=($PORT) -done -# Start decode instances -for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do - # Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs - GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l))) - # Calculate port number (base port + instance number) - PORT=$((8200 + i)) - # Calculate side channel port - SIDE_CHANNEL_PORT=$((5659 + i)) +# Function to run tests for a specific model +run_tests_for_model() { + local model_name=$1 + echo "================================" + echo "Testing model: $model_name" + echo "================================" + + # Get model-specific arguments + local model_args=$(get_model_args "$model_name") + + # Arrays to store all hosts and ports + PREFILL_HOSTS=() + PREFILL_PORTS=() + DECODE_HOSTS=() + DECODE_PORTS=() - echo "Starting decode instance $i on GPU $GPU_ID, port $PORT" + # Start prefill instances + for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do + # Calculate GPU ID - we'll distribute across available GPUs + GPU_ID=$((i % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l))) + # Calculate port number (base port + instance number) + PORT=$((8100 + i)) + # Calculate side channel port + SIDE_CHANNEL_PORT=$((5559 + i)) - CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $MODEL_NAME \ + echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT" + + # Build the command with or without model-specific args + BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ --port $PORT \ --enforce-eager \ --disable-log-requests \ --gpu-memory-utilization 0.2 \ - --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' & + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi + + eval "$FULL_CMD &" + + # Store host and port for proxy configuration + PREFILL_HOSTS+=("localhost") + PREFILL_PORTS+=($PORT) + done + + # Start decode instances + for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do + # Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs + GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(nvidia-smi --query-gpu=name --format=csv,noheader | 
wc -l))) + # Calculate port number (base port + instance number) + PORT=$((8200 + i)) + # Calculate side channel port + SIDE_CHANNEL_PORT=$((5659 + i)) + + echo "Starting decode instance $i on GPU $GPU_ID, port $PORT" + + # Build the command with or without model-specific args + BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + --port $PORT \ + --enforce-eager \ + --disable-log-requests \ + --gpu-memory-utilization 0.2 \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" - # Store host and port for proxy configuration - DECODE_HOSTS+=("localhost") - DECODE_PORTS+=($PORT) -done + if [ -n "$model_args" ]; then + FULL_CMD="$BASE_CMD $model_args" + else + FULL_CMD="$BASE_CMD" + fi -# Wait for all instances to start -for PORT in "${PREFILL_PORTS[@]}"; do - echo "Waiting for prefill instance on port $PORT to start..." - wait_for_server $PORT -done + eval "$FULL_CMD &" -for PORT in "${DECODE_PORTS[@]}"; do - echo "Waiting for decode instance on port $PORT to start..." - wait_for_server $PORT -done + # Store host and port for proxy configuration + DECODE_HOSTS+=("localhost") + DECODE_PORTS+=($PORT) + done + + # Wait for all instances to start + for PORT in "${PREFILL_PORTS[@]}"; do + echo "Waiting for prefill instance on port $PORT to start..." + wait_for_server $PORT + done -# Build the command for the proxy server with all the hosts and ports -PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192" + for PORT in "${DECODE_PORTS[@]}"; do + echo "Waiting for decode instance on port $PORT to start..." + wait_for_server $PORT + done -# Add all prefill hosts and ports -PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}" -PROXY_CMD+=" --prefiller-ports ${PREFILL_PORTS[@]}" + # Build the command for the proxy server with all the hosts and ports + PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192" -# Add all decode hosts and ports -PROXY_CMD+=" --decoder-hosts ${DECODE_HOSTS[@]}" -PROXY_CMD+=" --decoder-ports ${DECODE_PORTS[@]}" + # Add all prefill hosts and ports + PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}" + PROXY_CMD+=" --prefiller-ports ${PREFILL_PORTS[@]}" -# Start the proxy server -echo "Starting proxy server with command: $PROXY_CMD" -$PROXY_CMD & + # Add all decode hosts and ports + PROXY_CMD+=" --decoder-hosts ${DECODE_HOSTS[@]}" + PROXY_CMD+=" --decoder-ports ${DECODE_PORTS[@]}" -# Wait for the proxy to start -sleep 5 + # Start the proxy server + echo "Starting proxy server with command: $PROXY_CMD" + $PROXY_CMD & + + # Wait for the proxy to start + sleep 5 + + # Run lm eval for this model + echo "Running tests for $model_name" + TEST_MODEL=$model_name python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py + + # Clean up before running next model + cleanup_instances + sleep 3 +} + +# Run tests for each model +for model in "${MODELS[@]}"; do + run_tests_for_model "$model" +done -# Run lm eval. -python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py +echo "All tests completed!" 
diff --git a/tests/v1/kv_connector/nixl_integration/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py index c9cfc863277d..b1c03abbb396 100644 --- a/tests/v1/kv_connector/nixl_integration/test_accuracy.py +++ b/tests/v1/kv_connector/nixl_integration/test_accuracy.py @@ -1,17 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 +import os + import lm_eval import openai BASE_URL = "http://localhost:8192/v1" -MODEL_NAME = "Qwen/Qwen3-0.6B" NUM_CONCURRENT = 100 TASK = "gsm8k" FILTER = "exact_match,strict-match" RTOL = 0.03 -EXPECTED_VALUE = 0.41 + +# Model-specific expected values +EXPECTED_VALUES = { + "Qwen/Qwen3-0.6B": 0.41, + "deepseek-ai/deepseek-vl2-tiny": 0.20, +} SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501 +# Get model name from environment variable +MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B") + def run_simple_prompt(): client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL) @@ -19,14 +28,13 @@ def run_simple_prompt(): prompt=SIMPLE_PROMPT) print("-" * 50) - print("Completion results:") + print(f"Completion results for {MODEL_NAME}:") print(completion) print("-" * 50) def test_accuracy(): """Run the end to end accuracy test.""" - run_simple_prompt() model_args = (f"model={MODEL_NAME}," @@ -40,6 +48,14 @@ def test_accuracy(): ) measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + expected_value = EXPECTED_VALUES.get(MODEL_NAME) + + if expected_value is None: + print(f"Warning: No expected value found for {MODEL_NAME}. 
" + "Skipping accuracy check.") + print(f"Measured value: {measured_value}") + return + + assert (measured_value - RTOL < expected_value + and measured_value + RTOL > expected_value + ), f"Expected: {expected_value} | Measured: {measured_value}" diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py index 85c2d88a6ae2..ca5eca0c6231 100644 --- a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -10,6 +10,10 @@ from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse +from vllm.logger import init_logger + +logger = init_logger(__name__) + @asynccontextmanager async def lifespan(app: FastAPI): @@ -213,6 +217,8 @@ async def handle_completions(request: Request): # Get the next decode client in round-robin fashion decode_client_info = get_next_client(request.app, 'decode') + logger.debug("Using %s %s", prefill_client_info, decode_client_info) + # Stream response from decode service async def generate_stream(): async for chunk in stream_service_response( diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index b9deeda18e95..e6d254443e9d 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -272,3 +272,37 @@ def test_no_spurious_prefix_caching(): for block in remote_blocks: assert block.ref_cnt == 1 assert block._block_hash is None + + +def test_short_prompt_lifecycle(): + """Test lifecycle of a Remote Decode request with short prompt.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Not enough tokens for full block. + NUM_TOKENS = vllm_config.cache_config.block_size // 2 + request = create_request(request_id=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True) + + scheduler.add_request(request) + + # STEP (1): Prefill. + # (1a): schedule() + scheduler_output = scheduler.schedule() + assert len(scheduler.running) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + + # (1b): execute_model() + model_runner_output = create_model_runner_output(reqs=[request]) + + # (1c): update_from_output() + # Since tokens < block_size, there will be no kv xfer. + # So this should be cleaned up immediately. + _ = scheduler.update_from_output(scheduler_output, model_runner_output) + + # Confirm we do not have any memory leaks after req lifecycle. + # We need one more call to schedule() to clear data for persistent batch. + _ = scheduler.schedule() + assert_scheduler_empty(scheduler) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 3b3dd91f0e61..8a15955bd4fc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -331,20 +331,32 @@ def _nixl_handshake(self, host: str, port: int): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" - first_layer_name = next(iter(kv_caches)) - first_kv_cache = kv_caches[first_layer_name] + first_layer_name, first_kv_cache = next(iter(kv_caches.items())) + kv_elem_size = first_kv_cache.element_size() + + # TODO(tms): Find a more robust way to detect and handle MLA + use_mla = len(first_kv_cache.shape) == 3 + if use_mla: + # MLA case. 
+ self.num_blocks = first_kv_cache.shape[0] + block_rank = 2 # [block_size, latent_dim] + block_shape = first_kv_cache.shape[-block_rank:] + else: + # [2 (k and v), num_blocks, ...] + self.num_blocks = first_kv_cache.shape[1] + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] - # [2 (k and v), num_blocks, ...] - # TODO(tms): num_blocks will be in a different spot for MLA. - num_blocks = first_kv_cache.shape[1] - kv_elem_size = first_kv_cache[0].element_size() # TODO(tms): self.block_len needs to be per-layer for sliding window, # hybrid attn, etc - self.block_len = kv_elem_size * math.prod(first_kv_cache.shape[-3:]) - - logger.debug("Per layer kv cache size: %s", first_kv_cache[0].shape) - self.num_blocks = num_blocks - self.dst_num_blocks[self.engine_id] = num_blocks + self.block_len = kv_elem_size * math.prod(block_shape) + + logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla, + first_kv_cache.shape) + logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks, + block_shape) + logger.debug("Per layer kv cache size: %s", first_kv_cache.shape) + self.dst_num_blocks[self.engine_id] = self.num_blocks self.kv_caches = kv_caches kv_caches_base_addr = [] caches_data = [] @@ -355,10 +367,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # are non-contiguous (it's not locally guaranteed that they will be) # Disadvantage is that the encoded NixlAgentMetadata is now larger # (roughly 8KB vs 5KB). - for layer_name in kv_caches: - for cache in kv_caches[layer_name]: + for cache_or_caches in kv_caches.values(): + # Normalize to always be a list of caches + cache_list = [cache_or_caches] if use_mla else cache_or_caches + for cache in cache_list: base_addr = cache.data_ptr() - region_len = num_blocks * self.block_len + region_len = self.num_blocks * self.block_len caches_data.append((base_addr, region_len, self.rank, "")) kv_caches_base_addr.append(base_addr) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr @@ -438,10 +452,10 @@ def get_finished(self) -> tuple[set[str], set[str]]: In TP>1 setup, each rank exchanges KVs with its counterpart ranks independently. get_finished() runs in a worker creates the done_sending and done_recving sets that are sent to the - scheduler via ModelRunnerOutput by Rank 0. To avoid race - ensure trnxs are done before adding to finished, Ranks 1 to - N-1 communicate to Rank 0 once their transaction is done. - Rank 0 only returns finished once all ranks are complete. + scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs + are done before adding to finished, Ranks 1 to N-1 communicate + to Rank 0 once their transaction is done + Rank 0 returns + finished sets to Scheduler only once all ranks are done. """ done_sending = self._get_new_notifs() done_recving = self._pop_done_transfers(self._recving_transfers) @@ -579,18 +593,9 @@ def _read_blocks( # saturate IB with heterogeneous TP sizes. We should remove the staging # blocks until we are ready. - # NOTE(rob): we could potentially do the rearranging during the load_kv! - - # Note(tms): The remote_block_ids only contain full computed blocks, - # while the local_block_ids are all blocks allocated for this request, - # so truncate the local_block_ids to account for this. - del local_block_ids[len(remote_block_ids):] + assert len(local_block_ids) > 0 assert len(local_block_ids) == len(remote_block_ids) - # NOTE(rob): this can cause the remote blocks to not be freed? 
- if len(local_block_ids) == 0: - return - # Get side handles. local_xfer_side_handle = self.src_xfer_side_handle remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] @@ -621,7 +626,6 @@ def _read_blocks( def _get_block_descs_ids(self, engine_id: str, block_ids: list[int]) -> list[int]: """Get the descs ids for a set of block ids.""" - # TODO(rob): should we precompute this? # range(1) for MLA, range(2) otherwise. region_ids = range(self.num_regions) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index a27037c871de..0238cdfd4615 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -33,7 +33,6 @@ class KVTransferParams( omit_defaults=True, # type: ignore[call-arg] # required for @cached_property. dict=True): - # TODO(rob): we can handle xPyD and direct KV block Xfer remote_engine_id: Optional[str] = None remote_block_ids: Optional[list[int]] = None remote_host: Optional[str] = None diff --git a/vllm/v1/core/sched/scheduler_disagg.py b/vllm/v1/core/sched/scheduler_disagg.py index 44427a9f38e6..2eecb98436e3 100644 --- a/vllm/v1/core/sched/scheduler_disagg.py +++ b/vllm/v1/core/sched/scheduler_disagg.py @@ -462,6 +462,7 @@ def update_from_output( new_running: list[Request] = [] outputs: list[EngineCoreOutput] = [] + send_kv_no_op: list[str] = [] # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # loop can be a performance bottleneck. We should do our best to avoid @@ -548,8 +549,13 @@ def update_from_output( # TODO(rob): do this on a per-Connector basis. remote_blocks = [ block.block_id for block in - self.kv_cache_manager.get_computed_blocks(request)[0] + self.kv_cache_manager.req_to_blocks[request.request_id] + if block._block_hash is not None ] + # If prompt < block_size, then there will be no KV xfer. + # Free these requests so we don't have a mem leak. + if len(remote_blocks) == 0: + send_kv_no_op.append(request.request_id) engine_id = self.vllm_config.kv_transfer_config.engine_id kv_transfer_params = KVTransferParams( @@ -581,12 +587,15 @@ def update_from_output( new_running.append(request) # P/D: update recv and send status from last step. - for req_id in (model_runner_output.finished_recving or []): + for req_id in (model_runner_output.finished_recving or ()): logger.debug("Finished recving KV transfer for request %s", req_id) self.finished_recving_kv_req_ids.add(req_id) - for req_id in (model_runner_output.finished_sending or []): + for req_id in (model_runner_output.finished_sending or ()): logger.debug("Finished sending KV transfer for request %s", req_id) self._free_blocks(self.requests[req_id]) + for req_id in send_kv_no_op: + logger.debug("No op sending KV transfer for request %s", req_id) + self._free_blocks(self.requests[req_id]) # Return the cached request data to the queue so they can # be reused. From e22d44b1d3e8378da30dd463a2dc6f75bf52f8be Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 6 May 2025 01:13:18 +0000 Subject: [PATCH 30/30] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/forward_context.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index c75d8f088c5b..f6e33d35e3bb 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -101,16 +101,6 @@ def set_forward_context(attn_metadata: Any, attn_metadata=attn_metadata, dp_metadata=dp_metadata) - # KVConnector: trigger (possibly async) load before forward. - # Each attn layer will block until the reading is complete. 
- trigger_kv_transfer = (attn_metadata is not None - and has_kv_transfer_group() - and is_v1_kv_transfer_group()) - if trigger_kv_transfer: - kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) - kv_connector.start_load_kv(_forward_context) - try: yield finally: @@ -147,11 +137,4 @@ def set_forward_context(attn_metadata: Any, "(batchsize, count, median_time(ms)): %s"), forward_stats) - # KVConnector: each attn layer triggers (possibly async) save. - # Ensure all those operations complete before forward() is done. - if trigger_kv_transfer: - kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) - kv_connector.wait_for_save() - _forward_context = prev_context
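
For context on how the pieces in this series fit together: once a kv_transfer_config is supplied, _set_default_args_v1 points scheduler_cls at the DisaggregatedScheduler, and the NixlConnector (kv_role "kv_both") carries the KV blocks between prefill and decode instances. The sketch below is illustrative only and not part of the series: it assumes NIXL is installed and shows a single offline LLM engine purely to highlight the config-to-scheduler selection path, whereas the nixl_integration script above runs separate prefill and decode servers behind toy_proxy_server.py. The model name and sampling parameters are the ones used by the accuracy test.

# Illustrative sketch (assumes this patch series is applied and NIXL is
# available; single-engine setup shown only to highlight how setting
# kv_transfer_config changes the default scheduler).
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# Same connector/role the nixl_integration scripts pass via
# --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'.
kv_transfer_config = KVTransferConfig(
    kv_connector="NixlConnector",
    kv_role="kv_both",
)

llm = LLM(
    model="Qwen/Qwen3-0.6B",        # model used by test_accuracy.py
    enforce_eager=True,
    gpu_memory_utilization=0.2,
    kv_transfer_config=kv_transfer_config,
)

# Because kv_transfer_config is set, EngineArgs._set_default_args_v1 defaults
# scheduler_cls to
# "vllm.v1.core.sched.scheduler_disagg.DisaggregatedScheduler".
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0, max_tokens=20))
print(outputs[0].outputs[0].text)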