Skip to content
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
41d97d8
page head kernel
huangtingwei9988 Oct 10, 2025
f9805fb
fix
huangtingwei9988 Oct 10, 2025
9f49f81
fix
huangtingwei9988 Oct 10, 2025
fb85c0f
fix
huangtingwei9988 Oct 10, 2025
26d6a9e
fix
huangtingwei9988 Oct 10, 2025
642f228
fix
huangtingwei9988 Oct 10, 2025
6bc369c
fix
huangtingwei9988 Oct 10, 2025
8a028a7
fix
huangtingwei9988 Oct 10, 2025
2c4139c
fix
huangtingwei9988 Oct 10, 2025
964692c
add test case
huangtingwei9988 Oct 11, 2025
6aef21d
fix
huangtingwei9988 Oct 13, 2025
84e4d9d
fix
huangtingwei9988 Oct 13, 2025
403d360
fix
huangtingwei9988 Oct 13, 2025
5988f1c
fix
huangtingwei9988 Oct 13, 2025
b2ce9bc
fix
huangtingwei9988 Oct 13, 2025
1ca427a
add transfer_kv_per_layer_lf_phf
huangtingwei9988 Oct 14, 2025
ee97dcc
fix test
huangtingwei9988 Oct 14, 2025
8a87ff4
fix lint
huangtingwei9988 Oct 14, 2025
7863788
fix bug
huangtingwei9988 Oct 14, 2025
7b4670d
fix bug
huangtingwei9988 Oct 14, 2025
2c0eb75
Merge branch 'main' into page_head_io_kernel
huangtingwei9988 Oct 14, 2025
184f77f
Merge branch 'main' into page_head_io_kernel
huangtingwei9988 Oct 15, 2025
91c3ce3
fix typo
huangtingwei9988 Oct 15, 2025
2b52407
Merge branch 'page_head_io_kernel' of github.com:antgroup/sglang into…
huangtingwei9988 Oct 15, 2025
1084b2a
support memory_pool_host page head layout
huangtingwei9988 Oct 15, 2025
8f3d9dd
mooncake store support page head layout
huangtingwei9988 Oct 15, 2025
540aa2d
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Oct 20, 2025
0b91373
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Oct 22, 2025
b1a80e1
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Oct 26, 2025
9596c03
merge main
huangtingwei9988 Oct 26, 2025
2b901d4
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Oct 26, 2025
1f6193a
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Oct 26, 2025
f1c0223
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Oct 28, 2025
27173eb
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Oct 30, 2025
edb00f3
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Nov 3, 2025
f66e68b
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Nov 3, 2025
de0de62
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Nov 4, 2025
08adb1a
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Nov 5, 2025
5861707
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Nov 6, 2025
0148cee
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Nov 10, 2025
3da613f
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Nov 11, 2025
5a21d6f
resolve conflicts
huangtingwei9988 Nov 14, 2025
7e86008
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Nov 15, 2025
416b7f9
Merge branch 'main' into support_page_head_layout
huangtingwei9988 Nov 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 48 additions & 2 deletions python/sglang/srt/mem_cache/memory_pool_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
transfer_kv_all_layer,
transfer_kv_all_layer_direct_lf_pf,
transfer_kv_all_layer_lf_pf,
transfer_kv_all_layer_lf_ph,
transfer_kv_all_layer_mla,
transfer_kv_all_layer_mla_lf_pf,
transfer_kv_direct,
Expand All @@ -25,6 +26,7 @@
transfer_kv_per_layer_mla,
transfer_kv_per_layer_mla_pf_lf,
transfer_kv_per_layer_pf_lf,
transfer_kv_per_layer_ph_lf,
)
if _is_npu:
from sgl_kernel_npu.kvcacheio import TransferDirection, transfer_kv_dim_exchange
Expand Down Expand Up @@ -238,6 +240,15 @@ def init_kv_buffer(self):
self.head_num,
self.head_dim,
)
elif self.layout == "page_head":
dims = (
2,
self.page_num,
self.head_num,
self.page_size,
self.layer_num,
self.head_dim,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
Expand Down Expand Up @@ -292,6 +303,20 @@ def load_to_device_per_layer(
item_size=self.token_stride_size,
src_layout_dim=self.layout_dim,
)
elif self.layout == "page_head":
transfer_kv_per_layer_ph_lf(
src_k=self.k_buffer,
dst_k=device_pool.k_buffer[layer_id],
src_v=self.v_buffer,
dst_v=device_pool.v_buffer[layer_id],
src_indices=host_indices,
dst_indices=device_indices,
layer_id=layer_id,
item_size=self.token_stride_size,
src_layout_dim=self.layout_dim,
page_size=self.page_size,
head_num=self.head_num,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
elif io_backend == "direct":
Expand Down Expand Up @@ -366,6 +391,20 @@ def backup_from_device_all_layer(
dst_layout_dim=self.layout_dim,
num_layers=self.layer_num,
)
elif self.layout == "page_head":
transfer_kv_all_layer_lf_ph(
src_k_layers=device_pool.k_data_ptrs,
dst_k=self.k_buffer,
src_v_layers=device_pool.v_data_ptrs,
dst_v=self.v_buffer,
src_indices=device_indices,
dst_indices=host_indices,
item_size=self.token_stride_size,
dst_layout_dim=self.layout_dim,
num_layers=self.layer_num,
page_size=self.page_size,
head_num=self.head_num,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
elif io_backend == "direct":
Expand Down Expand Up @@ -409,7 +448,7 @@ def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
data_page = self.kv_buffer[:, :, index : index + self.page_size, :, :]
elif self.layout == "page_first":
data_page = self.kv_buffer[:, index : index + self.page_size, :, :, :]
elif self.layout == "page_first_direct":
elif self.layout in ["page_first_direct", "page_head"]:
real_index = index // self.page_size
data_page = self.kv_buffer[:, real_index : real_index + 1, :, :, :, :]
else:
Expand Down Expand Up @@ -450,6 +489,13 @@ def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
2, 1, self.layer_num, self.page_size, self.head_num, self.head_dim
)
)
elif self.layout == "page_head":
real_index = index // self.page_size
self.kv_buffer[:, real_index : real_index + 1, :, :, :, :] = (
data_page.reshape(
2, 1, self.head_num, self.page_size, self.layer_num, self.head_dim
)
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")

Expand Down Expand Up @@ -490,7 +536,7 @@ def get_page_buffer_meta(self, indices):
self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
)
element_size_list = [element_size] * len(ptr_list)
elif self.layout in ["page_first", "page_first_direct"]:
elif self.layout in ["page_first", "page_first_direct", "page_head"]:
for index in range(0, len(indices), self.page_size):
k_ptr = (
kv_buffer_data_ptr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def register_mem_pool_host(self, mem_pool_host: HostKVCache):
assert self.mem_pool_host.layout in [
"page_first",
"page_first_direct",
"page_head",
], "mooncake store storage backend only support page first or page first direct layout"
buffer = self.mem_pool_host.kv_buffer
try:
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -3054,6 +3054,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
"page_first",
"page_first_direct",
"page_first_kv_split",
"page_head",
],
default=ServerArgs.hicache_mem_layout,
help="The layout of host memory pool for hierarchical cache.",
Expand Down
Loading