Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions fastdeploy/engine/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@

class RequestStatus(Enum):
WAITING = 0
RUNNING = 1
PREEMPTED = 2
FINISHED = 3
ABORT = 4
RUNNING_PREFILL = 1
RUNNING_DECODE = 2
PREEMPTED = 3
FINISHED = 4
ABORT = 5


class RequestType(Enum):
Expand Down
174 changes: 120 additions & 54 deletions fastdeploy/engine/sched/resource_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,18 +216,12 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l
self.bos_client = None
self.async_preprocess_pool = ThreadPoolExecutor(max_workers=4)

self.init_reserve_output_block_num = (
envs.FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
) # int
self.decay_output_block_num = (
envs.FD_RESERVE_DECAY_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
) # float
self.min_reserve_output_block_num = (
envs.FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
) # int
self.current_reserve_output_block_num = self.init_reserve_output_block_num
self.current_reserve_output_block_num_float = self.init_reserve_output_block_num
self.can_relax_prefill_strategy = True
self.init_new_token_ratio = envs.FD_INIT_NEW_TOKEN_RATIO
self.min_new_token_ratio = envs.FD_MIN_NEW_TOKEN_RATIO
self.new_token_ratio_decay = envs.FD_NEW_TOKEN_RATIO_DECAY
self.clip_max_new_tokens = envs.FD_CLIP_MAX_NEW_TOKENS
self.retract_decode_steps = envs.FD_RETRACT_DECODE_STEPS
self.new_token_ratio = self.init_new_token_ratio
# Scheduler-side requests that have not been moved into resource manager waiting queue yet.
self.scheduler_unhandled_request_num = 0

Expand Down Expand Up @@ -313,6 +307,15 @@ def _can_preempt(self):
return True
return False

def _can_preempt_with_decode_task(self):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 _can_preempt 方法已成为死代码

_trigger_preempt 已从 self._can_preempt() 切换到 self._can_preempt_with_decode_task(),经全仓搜索确认 _can_preempt 不再有任何调用方。建议删除该方法,避免后续维护者误用不带 decode 状态过滤的旧版本。

"""
A request is preemptable if it does NOT use extend tables AND is in decode status.
"""
for req in self.running:
if not req.use_extend_tables and req.status == RequestStatus.RUNNING_DECODE:
return True
return False

def preempted_all(self):
with self.lock:
preempted_reqs = []
Expand Down Expand Up @@ -350,14 +353,38 @@ def wait_worker_inflight_requests_finish(self, timeout=60):
def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
"""
If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out.
Only requests that is in decode status can be preempted.
"""
can_schedule = False
while self._can_preempt():
if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
preempted_req = self.running.pop()
if preempted_req.use_extend_tables:
self.running.insert(0, preempted_req)
continue
while self._can_preempt_with_decode_task():
if self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
# The request can be scheduled.
can_schedule = True
break
else:
# Scan from back to front to find the last preemptable request
preempted_req = None
i = len(self.running) - 1
while i >= 0:
candidate = self.running[i]
# Skip requests that are not in decode status
if candidate.status != RequestStatus.RUNNING_DECODE:
i -= 1
continue
# Skip requests using extend tables
if candidate.use_extend_tables:
i -= 1
continue
# Found a valid preempt target
preempted_req = candidate
break

if preempted_req is None:
# No preemptable request found (all have no output tokens or use extend tables)
return False

# Remove the preempted request from the running list
self.running.pop(i)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
if self.config.scheduler_config.splitwise_role == "decode":
Expand Down Expand Up @@ -389,33 +416,78 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
llm_logger.debug(
f"preempt {preempted_req.request_id} in idx {preempted_req.idx} with generated ids {preempted_req.output_token_ids}"
)

llm_logger.debug(self.info())
self._info_each_block()
self._recompute_new_token_ratio_on_preemption()

if preempted_req == request:
# No more request to preempt.
can_schedule = False
break
else:
# The request can be scheduled.
can_schedule = True
break
self.current_reserve_output_block_num = self.init_reserve_output_block_num
self.current_reserve_output_block_num_float = self.init_reserve_output_block_num
self.can_relax_prefill_strategy = False

return can_schedule

def _recompute_new_token_ratio_on_preemption(self):
"""Recompute new_token_ratio based on actual decode progress of running requests.

Aligned with SGLang's retract_decode logic: estimate the ratio as the actual
fraction of max_tokens already decoded plus a small lookahead, rather than
naively resetting to the initial value. This avoids over-reserving when most
requests are near completion, and under-reserving when they've just started.

Formula:
ratio = (total_decoded + RETRACT_DECODE_STEPS * num_running) / (total_max_new_tokens + 1)
capped at init_new_token_ratio so preemption never makes the ratio more
aggressive than the initial setting.
"""
if not self.running:
self.new_token_ratio = self.init_new_token_ratio
return
total_decoded_tokens = sum(len(req.output_token_ids) for req in self.running)
total_max_new_tokens = 0
for req in self.running:
max_tokens = req.sampling_params.max_tokens
if max_tokens is None:
max_tokens = self.config.model_config.max_model_len - req.prompt_token_ids_len
total_max_new_tokens += max_tokens
num_running_decode = sum([1 if req.num_total_tokens > req.need_prefill_tokens else 0 for req in self.running])
new_ratio = (total_decoded_tokens + self.retract_decode_steps * num_running_decode) / (
total_max_new_tokens + 1
)
self.new_token_ratio = min(new_ratio, self.init_new_token_ratio)

def _get_running_request_reserve_blocks(self, request: Request) -> int:
    """Estimate KV-cache blocks to reserve for a running request's future decode tokens.

    Aligned with SGLang's per-request budget estimation:
        reserved_tokens = min(max_tokens - already_generated, CLIP_MAX_NEW_TOKENS)
                          * new_token_ratio
    then ceil-divided by block_size. The ratio decays each scheduling step so
    that the reservation gradually relaxes; on preemption it resets to the
    initial value.
    """
    # Token budget for this request: its explicit max_tokens, or whatever is
    # left of the model context after the prompt when no limit was given.
    token_budget = request.sampling_params.max_tokens
    if token_budget is None:
        token_budget = self.config.model_config.max_model_len - request.prompt_token_ids_len
    # Tokens still to be generated, capped by the global clip value, scaled by
    # the current (decaying) new-token ratio, and never negative.
    still_needed = token_budget - len(request.output_token_ids)
    capped = min(still_needed, self.clip_max_new_tokens)
    reserved_tokens = max(int(capped * self.new_token_ratio), 0)
    # Ceiling division by block size via the negated floor-division idiom.
    block_size = self.config.cache_config.block_size
    return -(-reserved_tokens // block_size)

def _get_can_schedule_prefill_threshold_block(self, num_chunk_new_block):
if self.can_relax_prefill_strategy:
can_schedule_block_num_threshold = num_chunk_new_block
else:
can_schedule_block_num_threshold = (
num_chunk_new_block + len(self.running) * self.current_reserve_output_block_num
"""Compute the minimum free blocks required to admit a new prefill request.

The threshold includes: (1) blocks needed for the prefill itself, and
(2) blocks reserved for all running decode requests' future output tokens,
estimated per-request via _get_running_request_reserve_blocks. This prevents
new prefills from starving ongoing decodes of KV-cache capacity.
"""
reserve_blocks = sum(self._get_running_request_reserve_blocks(req) for req in self.running)
can_schedule_block_num_threshold = num_chunk_new_block + reserve_blocks
if self.config.speculative_config.method is not None:
can_schedule_block_num_threshold = min(
can_schedule_block_num_threshold + 1, self.config.cache_config.max_block_num_per_seq
)
if self.config.speculative_config.method is not None:
can_schedule_block_num_threshold = min(
can_schedule_block_num_threshold + 1, self.config.cache_config.max_block_num_per_seq
)
return can_schedule_block_num_threshold

def _update_mm_hashes(self, request):
Expand Down Expand Up @@ -770,6 +842,7 @@ def get_enough_request(request, scheduled_reqs):
error_reqs: list[tuple[str, str]] = []
token_budget = self.config.scheduler_config.max_num_batched_tokens
need_abort_requests = [] # users trigger abortion
chunk_prefill_in_running_not_satisfied = False

# First, schedule the RUNNING requests.
req_index = 0
Expand Down Expand Up @@ -906,22 +979,17 @@ def _allocate_decode_and_extend():
req_index += 1
continue
num_new_block = self.get_new_block_nums(request, num_new_tokens)
can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block(num_new_block)
# Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
)
# Prepare prefill task
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
else: # Not enough blocks to allocate, trigger preemption
can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs)
if not can_schedule:
break
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
)
# Prepare prefill task
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
else: # Not enough blocks to allocate
chunk_prefill_in_running_not_satisfied = True
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 chunk prefill 分配失败后阻塞全部 waiting 请求——行为变更较大,请确认是否符合预期

之前此处的逻辑是:chunk prefill 分配不足时触发 _trigger_preempt 尝试腾出空间。现在改为直接 break 并设置 chunk_prefill_in_running_not_satisfied = True,这会导致本轮调度所有 waiting 请求都被跳过(包括那些可能只需少量 block 的小请求)。

这在高负载场景下可能导致 waiting 队列饥饿。请确认:

  1. 是否考虑过仅跳过「需要大量 block 的 waiting 请求」而允许小请求通过?
  2. 在 running 队列中有多个 chunk prefill 请求时,排在前面的请求失败是否应允许后面的请求继续尝试(当前是直接 break 整个循环)?

break # For chunk prefill request, if not satisfy condition for prefill, just break
token_budget -= num_new_tokens
request.num_computed_tokens += num_new_tokens
if (
Expand All @@ -939,7 +1007,7 @@ def _allocate_decode_and_extend():
self.running.remove(request)

# Second, schedule the WAITING requests.
if not preempted_reqs:
if (not preempted_reqs) and (not chunk_prefill_in_running_not_satisfied):
skip_requests: list[Request] = []
while self.waiting and token_budget > 0:
if (
Expand Down Expand Up @@ -1025,7 +1093,7 @@ def _allocate_decode_and_extend():
self.cache_manager.update_cache_blocks(
request, self.config.cache_config.block_size, request.num_computed_tokens
)
request.status = RequestStatus.RUNNING
request.status = RequestStatus.RUNNING_PREFILL
if self.config.scheduler_config.splitwise_role == "mixed":
allocated_position = self.get_available_position()
request.idx = allocated_position
Expand Down Expand Up @@ -1094,7 +1162,7 @@ def _allocate_decode_and_extend():
self.cache_manager.update_cache_blocks(
request, self.config.cache_config.block_size, request.num_computed_tokens
)
request.status = RequestStatus.RUNNING
request.status = RequestStatus.RUNNING_PREFILL
else:
if self.config.cache_config.enable_prefix_caching:
self._free_blocks(request)
Expand All @@ -1108,14 +1176,10 @@ def _allocate_decode_and_extend():

if scheduled_reqs:
llm_logger.debug(f"schedued_reqs: {scheduled_reqs}")
self.current_reserve_output_block_num_float -= self.decay_output_block_num
self.current_reserve_output_block_num = max(
int(self.current_reserve_output_block_num_float),
self.min_reserve_output_block_num,
0,
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_decay,
self.min_new_token_ratio,
)
if self.current_reserve_output_block_num == 0:
self.can_relax_prefill_strategy = True

self._log_console_scheduler_metrics(scheduled_reqs)

Expand Down Expand Up @@ -1334,6 +1398,7 @@ def pre_recycle_resource(self, request_id: str):
def add_request_in_p(self, requests: list[Request]):
with self.lock:
for request in requests:
request.status = RequestStatus.RUNNING_PREFILL
self.running.append(request)

def preallocate_resource_in_p(self, request: Request):
Expand Down Expand Up @@ -1467,6 +1532,7 @@ def add_prefilled_request(self, request_output: RequestOutput):
):
request.draft_token_ids = copy.deepcopy(request_output.outputs.draft_token_ids)
request.need_prefill_tokens = len(request.prompt_token_ids) + 1
request.status = RequestStatus.RUNNING_DECODE

request_output.metrics.decode_recv_req_time = request.metrics.decode_recv_req_time
request_output.metrics.decode_preallocate_req_time = request.metrics.decode_preallocate_req_time
Expand Down
6 changes: 6 additions & 0 deletions fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,12 @@ def _validate_split_kv_size(value: int) -> int:
# Whether to enable low latency in mixed scenario
"FD_XPU_ENABLE_MIXED_EP_MODE": lambda: bool(int(os.getenv("FD_XPU_ENABLE_MIXED_EP_MODE", "0"))),
# Reserve output blocks for decoding requests when schedule new prefill requests
"FD_INIT_NEW_TOKEN_RATIO": lambda: float(os.getenv("FD_INIT_NEW_TOKEN_RATIO", "0.7")),
"FD_MIN_NEW_TOKEN_RATIO": lambda: float(os.getenv("FD_MIN_NEW_TOKEN_RATIO", "0.1")),
"FD_NEW_TOKEN_RATIO_DECAY": lambda: float(os.getenv("FD_NEW_TOKEN_RATIO_DECAY", "0.001")),
"FD_CLIP_MAX_NEW_TOKENS": lambda: int(os.getenv("FD_CLIP_MAX_NEW_TOKENS", "4096")),
"FD_RETRACT_DECODE_STEPS": lambda: int(os.getenv("FD_RETRACT_DECODE_STEPS", "1024")),
# Legacy reserve block env vars (kept for backwards compatibility, no longer used)
"FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int(
os.getenv("FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "16")
),
Expand Down
3 changes: 3 additions & 0 deletions fastdeploy/output/token_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
Request,
RequestMetrics,
RequestOutput,
RequestStatus,
SpeculateMetrics,
)
from fastdeploy.inter_communicator import ZmqIpcServer
Expand Down Expand Up @@ -891,6 +892,8 @@ def _process_batch_output(self):
continue

self.total_step += 1
if task.status == RequestStatus.RUNNING_PREFILL:
task.status = RequestStatus.RUNNING_DECODE
current_time = time.time()
trace_carrier = None
if self.tokens_counter[task_id] == 0:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ einops
setproctitle
aistudio_sdk
p2pstore
mooncake-transfer-engine>=0.3.10.post1
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 mooncake-transfer-engine 依赖与本 PR 的调度优化主题无关

本 PR 标题为 Scheduler Optimization,但此处新增了 mooncake-transfer-engine 依赖。混合不相关变更会增加 review 难度和回滚风险。建议将此依赖变更拆到独立 PR 中提交。

py-cpuinfo
flashinfer-python-paddle @ https://xly-devops.bj.bcebos.com/flashinfer/flashinfer_python_paddle-0.4.1.2-py3-none-any.whl
flash_mask @ https://xly-devops.bj.bcebos.com/flashmask/flash_mask-4.0.0%2Bg4c84f74-py3-none-any.whl
Expand Down
4 changes: 2 additions & 2 deletions tests/engine/test_resource_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,15 @@ def test_preempted_all_with_normal_requests(self):
req1 = Mock(spec=Request)
req1.request_id = "req1"
req1.use_extend_tables = False
req1.status = RequestStatus.RUNNING
req1.status = RequestStatus.RUNNING_DECODE
req1.block_tables = [1, 2, 3]
req1.num_cached_blocks = 0
req1.idx = 0

req2 = Mock(spec=Request)
req2.request_id = "req2"
req2.use_extend_tables = False
req2.status = RequestStatus.RUNNING
req2.status = RequestStatus.RUNNING_DECODE
req2.block_tables = [4, 5]
req2.num_cached_blocks = 0
req2.idx = 1
Expand Down
27 changes: 2 additions & 25 deletions tests/v1/test_resource_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ def test_schedule_decode_and_waiting_prefill(self):

decode_request = _make_request(request_id="req-decode", prompt_token_ids=[1, 2])
decode_request.idx = 0
decode_request.status = RequestStatus.RUNNING
decode_request.status = RequestStatus.RUNNING_DECODE
decode_request.num_computed_tokens = 2
decode_request.output_token_ids = [99]
decode_request.block_tables = [1]
Expand All @@ -665,30 +665,7 @@ def test_schedule_decode_and_waiting_prefill(self):
self.assertGreaterEqual(len(scheduled_reqs), 2)
self.assertEqual(error_reqs, [])
self.assertIn(decode_request.request_id, manager.using_extend_tables_req_id)
self.assertEqual(waiting_request.status, RequestStatus.RUNNING)

def test_trigger_preempt_records_tasks(self):
manager = _build_manager()
_register_manager_cleanup(self, manager)
manager.cache_manager = MagicMock()
manager.cache_manager.num_gpu_blocks = 8
manager.cache_manager.gpu_free_block_list = list(range(8))
manager.cache_manager.can_allocate_gpu_blocks.side_effect = [False, True]
manager._free_blocks = MagicMock()
preempted_req = _make_request(request_id="req-preempted")
preempted_req.idx = 0
preempted_req.use_extend_tables = False
request = _make_request(request_id="req-target")
request.idx = 1
manager.running = [request, preempted_req]

preempted_reqs = []
scheduled_reqs = []
can_schedule = manager._trigger_preempt(request, 2, preempted_reqs, scheduled_reqs)
self.assertTrue(can_schedule)
self.assertIn(preempted_req.request_id, manager.to_be_rescheduled_request_id_set)
self.assertEqual(preempted_reqs[0], preempted_req)
self.assertEqual(scheduled_reqs[0].request_id, preempted_req.request_id)
self.assertEqual(waiting_request.status, RequestStatus.RUNNING_PREFILL)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 新的抢占逻辑缺少单元测试覆盖

test_trigger_preempt_records_tasks 被删除,但新逻辑(仅抢占 RUNNING_DECODE 请求、从后向前扫描、_recompute_new_token_ratio_on_preemption)没有对应的新测试。建议补充以下场景的测试:

  1. running 队列中混合 RUNNING_PREFILL 和 RUNNING_DECODE 请求时,仅 RUNNING_DECODE 被抢占
  2. 所有 running 请求都是 RUNNING_PREFILL 时,_trigger_preempt 返回 False
  3. _recompute_new_token_ratio_on_preemption 在不同 decode 进度下的 ratio 计算正确性


def test_available_position_and_real_bsz(self):
manager = _build_manager()
Expand Down
Loading