Draft
Changes from all commits (16 commits)
a8d4fbc  Refactor memory cache allocators to remove unnecessary parameters and… (keyboardAnt, Nov 10, 2025)
62f0d26  Add RadixKey class for improved key handling in radix_cache (keyboardAnt, Nov 10, 2025)
78b0325  Remove unnecessary 'need_sort' parameter from TokenToKVPoolAllocator … (keyboardAnt, Nov 10, 2025)
2cb0793  Add new parameters to RadixCache and SWARadixCache for enhanced funct… (keyboardAnt, Nov 11, 2025)
ce0638c  Refactor memory management in allocators to unify page freeing methods (keyboardAnt, Nov 11, 2025)
bad9979  Enhance memory management by introducing back-compat alias for free_p… (keyboardAnt, Nov 11, 2025)
8fee621  Revert out-of-scope edits (Ascend, chunk_cache, model_runner, benches… (keyboardAnt, Nov 11, 2025)
311cfb0  Remove 'need_sort' parameter from TokenToKVPoolAllocator and related … (keyboardAnt, Nov 11, 2025)
7c8b1ea  Enhance memory management by adjusting page ID calculations across al… (keyboardAnt, Nov 11, 2025)
0a1f0b3  Refactor free_pages calls in BaseTokenToKVPoolAllocator to avoid attr… (keyboardAnt, Nov 11, 2025)
bc3f272  Enhance memory leak detection in SchedulerRuntimeCheckerMixin by refi… (keyboardAnt, Nov 11, 2025)
d3233ee  Enhance cache_unfinished_req method in RadixCache to support chunked … (keyboardAnt, Nov 11, 2025)
293bc8a  Refactor memory allocation methods in allocators to remove CPU tensor… (keyboardAnt, Nov 11, 2025)
724b07a  Refactor memory management methods to replace free_pages with free_pa… (keyboardAnt, Nov 11, 2025)
efb68f2  Enhance PagedTokenToKVPoolAllocator to improve page ID handling (keyboardAnt, Nov 11, 2025)
7343561  Enhance debugging capabilities in SchedulerRuntimeCheckerMixin and Ra… (keyboardAnt, Nov 11, 2025)
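
Several of the commits above (ce0638c, bad9979, 724b07a) describe routing all page freeing through one method while keeping a back-compat alias so existing call sites keep working. The actual method names are truncated in the commit messages, so the snippet below is only a generic sketch of the alias pattern with placeholder names (free and free_legacy), not this PR's API.

class AllocatorFreeAliasSketch:
    """Illustrative only: a unified free path plus a back-compat alias,
    with placeholder names (the commit messages truncate the real ones)."""

    def __init__(self):
        self._pending = []  # page ids queued for reuse

    def free(self, page_ids):
        # Single entry point every caller is migrated to.
        self._pending.extend(page_ids)

    # Back-compat alias: older call sites keep working during the migration.
    free_legacy = free


alloc = AllocatorFreeAliasSketch()
alloc.free([1, 2])
alloc.free_legacy([3])  # routes to the same unified method
print(alloc._pending)   # [1, 2, 3]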
benchmark/hf3fs/bench_zerocopy.py (1 change: 0 additions & 1 deletion)
@@ -67,7 +67,6 @@
dtype=kv_cache_dtype,
device=device,
kvcache=token_to_kv_pool,
need_sort=False,
)

kv_cache = token_to_kv_pool_allocator.get_kvcache()
python/sglang/srt/managers/scheduler.py (6 changes: 6 additions & 0 deletions)
@@ -1578,6 +1578,9 @@ def handle_batch_embedding_request(
self.handle_embedding_request(tokenized_req)

def _get_token_info(self):
# Make staged frees visible before measuring
if hasattr(self.token_to_kv_pool_allocator, "merge_and_sort_free"):
self.token_to_kv_pool_allocator.merge_and_sort_free()
available_size = self.token_to_kv_pool_allocator.available_size()
evictable_size = self.tree_cache.evictable_size()
num_used = self.max_total_num_tokens - (available_size + evictable_size)
@@ -1586,6 +1589,9 @@ def _get_token_info(self):

def _get_mamba_token_info(self):
is_radix_tree = isinstance(self.tree_cache, MambaRadixCache)
# Make staged frees visible before measuring
if hasattr(self.token_to_kv_pool_allocator, "merge_and_sort_free"):
self.token_to_kv_pool_allocator.merge_and_sort_free()
full_available_size = self.token_to_kv_pool_allocator.available_size()
full_evictable_size = (
self.tree_cache.full_evictable_size() if is_radix_tree else 0
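
The two scheduler hunks above flush staged frees before measuring capacity. A minimal sketch of the assumed allocator behavior behind that ordering: frees are staged in release_pages and only counted by available_size() once merged into free_pages. Only the names free_pages, release_pages, merge_and_sort_free, available_size, and page_size come from this diff; the implementation below is illustrative, not the real SGLang allocator.

import torch


class PagedAllocatorSketch:
    """Illustrative allocator with deferred frees (not the real implementation)."""

    def __init__(self, num_pages: int, page_size: int):
        self.page_size = page_size
        self.free_pages = torch.arange(num_pages, dtype=torch.int64)
        self.release_pages = torch.empty(0, dtype=torch.int64)  # staged frees

    def free(self, page_ids: torch.Tensor):
        # Cheap path: stage freed pages instead of sorting them in immediately.
        self.release_pages = torch.cat([self.release_pages, page_ids])

    def merge_and_sort_free(self):
        # Fold staged frees back into the sorted free list.
        if len(self.release_pages) > 0:
            merged = torch.cat([self.free_pages, self.release_pages])
            self.free_pages = torch.sort(merged).values
            self.release_pages = torch.empty(0, dtype=torch.int64)

    def available_size(self) -> int:
        # Only merged pages are visible here, hence the scheduler's
        # merge_and_sort_free() call right before reading this value.
        return len(self.free_pages) * self.page_size

Under that assumption, skipping the merge during an idle-time accounting pass could undercount available tokens and trip the leak check in the next file.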
python/sglang/srt/managers/scheduler_runtime_checker_mixin.py (146 changes: 139 additions & 7 deletions)
@@ -62,14 +62,21 @@ def _check_mamba_memory(self: Scheduler):
def _check_radix_cache_memory(self: Scheduler):
_, _, available_size, evictable_size = self._get_token_info()
protected_size = self.tree_cache.protected_size()
memory_leak = (available_size + evictable_size) != (
# self.max_total_num_tokens
# if not self.enable_hierarchical_cache
# else self.max_total_num_tokens - protected_size
self.max_total_num_tokens
- protected_size
total_accounted = available_size + evictable_size + protected_size
diff = self.max_total_num_tokens - total_accounted
# Allow a small slack for in-flight/unaccounted tokens that are about to be
# reflected by the tree/allocator (e.g., post-insert duplicates or tails).
# This avoids transient false-positives during idle checks.
try:
page_size = int(getattr(self, "page_size", 1))
except Exception:
page_size = 1
tolerance = max(32, page_size)
memory_leak = not (0 <= diff <= tolerance)
token_msg = (
f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, "
f"{protected_size=}, diff={diff}, tolerance={tolerance}\n"
)
token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
return memory_leak, token_msg

def _check_runtime_mem_leak(self: Scheduler):
@@ -149,6 +156,126 @@ def check_memory(self: Scheduler):
memory_leak, token_msg = self._check_radix_cache_memory()

if memory_leak:
# Extra diagnostics to help pinpoint mismatched accounting
try:
alloc = self.token_to_kv_pool_allocator
free_len = (
int(len(getattr(alloc, "free_pages", [])))
if getattr(alloc, "free_pages", None) is not None
else -1
)
release_len = (
int(len(getattr(alloc, "release_pages", [])))
if getattr(alloc, "release_pages", None) is not None
else -1
)
# Some trees expose total_size() for quick sanity
tree_total = None
try:
if hasattr(self.tree_cache, "total_size"):
total = self.tree_cache.total_size()
tree_total = total if isinstance(total, int) else str(total)
except Exception:
tree_total = "n/a"
print(
f"DEBUG {self.max_total_num_tokens=} "
f"free={free_len} "
f"release={release_len} "
f"evictable={self.tree_cache.evictable_size()} "
f"protected={self.tree_cache.protected_size()} "
f"avail={self.token_to_kv_pool_allocator.available_size()} "
f"tree_total={tree_total}"
)
except Exception:
pass
# Extra detailed breakdown including staged frees inside allocator free_group
try:
alloc = self.token_to_kv_pool_allocator
page_size_dbg = int(getattr(alloc, "page_size", 1))
is_open_free_group = bool(
hasattr(alloc, "is_not_in_free_group")
and (not alloc.is_not_in_free_group)
)
fg_list = getattr(alloc, "free_group", None)
staged_groups = int(len(fg_list)) if fg_list is not None else 0
staged_pages = 0
if isinstance(fg_list, list) and staged_groups > 0:
# Sum lengths of page-id tensors staged for grouped frees
staged_pages = int(
sum(int(len(t)) for t in fg_list if t is not None)
)
staged_tokens = staged_pages * page_size_dbg

avail_now = int(self.token_to_kv_pool_allocator.available_size())
evictable_now = int(self.tree_cache.evictable_size())
protected_now = int(self.tree_cache.protected_size())
total_accounted = avail_now + evictable_now + protected_now
diff_now = int(self.max_total_num_tokens - total_accounted)
diff_with_staged = int(
self.max_total_num_tokens - (total_accounted + staged_tokens)
)
reserved_decode = int(
getattr(self.server_args, "num_reserved_decode_tokens", 0)
)
running_nonempty = bool(
self.running_batch is not None and not self.running_batch.is_empty()
)
print(
"DEBUG+ breakdown: "
f"page_size={page_size_dbg} "
f"staged_groups={staged_groups} "
f"staged_pages={staged_pages} "
f"staged_tokens={staged_tokens} "
f"avail_now={avail_now} "
f"evictable_now={evictable_now} "
f"protected_now={protected_now} "
f"total_accounted={total_accounted} "
f"diff_now={diff_now} "
f"diff_with_staged={diff_with_staged} "
f"reserved_decode={reserved_decode} "
f"running_nonempty={running_nonempty}"
)
except Exception:
pass
# Decode-boundary slack estimate (pages that will be allocated next decode step)
try:
last_slack_pages = -1
run_slack_pages = -1
if getattr(self, "last_batch", None) is not None:
last_slack_pages = int(self.last_batch.new_page_count_next_decode())
if (
getattr(self, "running_batch", None) is not None
and not self.running_batch.is_empty()
):
run_slack_pages = int(
self.running_batch.new_page_count_next_decode()
)
slack_pages_total = max(
0, (0 if last_slack_pages < 0 else last_slack_pages)
) + max(0, (0 if run_slack_pages < 0 else run_slack_pages))
slack_tokens = slack_pages_total * page_size_dbg
diff_minus_slack = diff_now - slack_tokens
last_bs = (
self.last_batch.batch_size() if self.last_batch is not None else 0
)
run_bs = (
self.running_batch.batch_size()
if self.running_batch is not None
and not self.running_batch.is_empty()
else 0
)
print(
"DEBUG+ decode_slack: "
f"last_slack_pages={last_slack_pages} "
f"run_slack_pages={run_slack_pages} "
f"slack_pages_total={slack_pages_total} "
f"slack_tokens={slack_tokens} "
f"diff_minus_slack={diff_minus_slack} "
f"last_batch_size={last_bs} "
f"running_batch_size={run_bs}"
)
except Exception:
pass
msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
raise ValueError(msg)

@@ -217,6 +344,11 @@ def check_tree_cache(self: Scheduler):
self.tree_cache.sanity_check()

def self_check_during_idle(self: Scheduler):
# Skip idle checks if there is an in-flight running batch to avoid counting
# tokens that are not yet reflected in the radix tree or allocator lists.
if self.running_batch is not None and not self.running_batch.is_empty():
return

if self.disaggregation_mode == DisaggregationMode.PREFILL:
if len(self.disagg_prefill_inflight_queue) > 0:
return
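
The new leak condition in _check_radix_cache_memory can be restated as a standalone helper for reasoning about the numbers it prints. The formula and tolerance mirror the diff above; the example values are hypothetical.

def radix_cache_leak_check(
    max_total_num_tokens: int,
    available_size: int,
    evictable_size: int,
    protected_size: int,
    page_size: int = 1,
):
    total_accounted = available_size + evictable_size + protected_size
    diff = max_total_num_tokens - total_accounted
    tolerance = max(32, page_size)
    # A negative diff (more tokens accounted for than exist) or a diff larger
    # than the allowed slack is treated as a leak.
    memory_leak = not (0 <= diff <= tolerance)
    return memory_leak, diff, tolerance


# Hypothetical numbers: 100_000 total, 95_000 available, 4_900 evictable,
# 60 protected, page_size=16 -> diff=40 > tolerance=32 -> flagged as a leak.
print(radix_cache_leak_check(100_000, 95_000, 4_900, 60, page_size=16))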