
Disk cache and CPU cache feature #997


Open · wants to merge 64 commits into base: main

Commits (64)
2b5c260
add multi_level_kv_cache
Jul 29, 2025
8672506
fix
Jul 29, 2025
d2a4c53
add shmdict
Jul 29, 2025
bb230a6
add shm dict
Jul 30, 2025
c4caedb
add shm dict
Jul 30, 2025
1851371
fix cpu cache client
Jul 30, 2025
506167d
fix
Jul 30, 2025
44bec0c
fix
Jul 30, 2025
557c536
disk and cpu cache enable mix
Jul 30, 2025
fa5ec1f
add hash utils
hiworldwzj Jul 30, 2025
414645a
fix
hiworldwzj Jul 30, 2025
6752565
add start args for cpu cache and disk cache
Jul 31, 2025
738cfea
add CpuCacheMatch List
Jul 31, 2025
a3dffb5
add multi_level_cache manager
Jul 31, 2025
b617a83
fix
Jul 31, 2025
257f57c
fix
Jul 31, 2025
d6d3851
add calcu_cpu_cache_page_num
Jul 31, 2025
906996e
improve radix cache
Aug 1, 2025
70d762a
improve pd p impl
Aug 1, 2025
a896b3a
add multi_level_cache_manager.py
Aug 1, 2025
565916c
fix
hiworldwzj Aug 2, 2025
7033c0b
add kv cache offload kernel
hiworldwzj Aug 2, 2025
11e67e9
add to do
hiworldwzj Aug 2, 2025
acb6ade
add kv_cache_utils.py
hiworldwzj Aug 3, 2025
3830612
add register_shm_ptr_to_pin
Aug 4, 2025
7d852a7
fix
Aug 4, 2025
1cf9d74
fix
Aug 4, 2025
287ce25
fix
Aug 4, 2025
de03e94
fix
Aug 4, 2025
516e9bc
fix
Aug 4, 2025
79b06b6
fix
Aug 4, 2025
ab46882
fix
Aug 4, 2025
4f9269f
fix
Aug 4, 2025
3363fd3
fix
Aug 4, 2025
8d54b6c
fix
Aug 4, 2025
fc9cbff
fix
hiworldwzj Aug 4, 2025
50b2e9c
fix
hiworldwzj Aug 4, 2025
7218378
fix
hiworldwzj Aug 4, 2025
542b4c2
fix
hiworldwzj Aug 4, 2025
728d447
fix
hiworldwzj Aug 4, 2025
989ca56
fix
hiworldwzj Aug 4, 2025
7683430
fix radix cache insert
hiworldwzj Aug 4, 2025
accd36c
fix
hiworldwzj Aug 4, 2025
4a18b3d
fix
Aug 11, 2025
5437ee7
fix
Aug 11, 2025
4cc6d08
fix
Aug 11, 2025
b011825
add multi_level_kv_cache start
Aug 11, 2025
088345b
add multi_level_kv_cache start
Aug 11, 2025
c3d5e61
rename
Aug 11, 2025
151333f
rename multi level kv cache
Aug 11, 2025
2711e8b
fix
Aug 11, 2025
c9a1838
fix
Aug 11, 2025
937af93
fix
Aug 11, 2025
ff7fa9c
add draft test.py
Aug 11, 2025
df1d500
fix
hiworldwzj Aug 11, 2025
6ef4dcd
fix first version
hiworldwzj Aug 11, 2025
31d57af
add cpu_prompt_cache_len
Aug 12, 2025
9d56ce9
add cpu_prompt_cache_len
Aug 12, 2025
cae4432
fix
Aug 12, 2025
5cbec55
fix
Aug 12, 2025
f4cdbed
fix
Aug 12, 2025
053c922
Fix
Aug 12, 2025
b5ca416
fix
Aug 12, 2025
2d1cc47
fix
Aug 13, 2025
228 changes: 228 additions & 0 deletions lightllm/common/basemodel/triton_kernel/kv_cache_offload.py
@@ -0,0 +1,228 @@
import torch

import triton
import triton.language as tl


@triton.jit
def _offload_gpu_kv_to_cpu(
    token_indexes_ptr,
    gpu_kv_cache_ptr,
    gpu_stride0,
    gpu_stride1,
    gpu_stride2,
    gpu_stride3,
    cpu_kv_cache_ptr,
    cpu_stride0,
    cpu_stride1,
    cpu_stride2,
    cpu_stride3,
    cpu_stride4,
    page_indexes_ptr,
    page_readies_ptr,
    layer_num,
    head_all_dim,
    BLOCK_HEAD_ALL_DIM: tl.constexpr,
    TOKEN_BLOCK: tl.constexpr,
):
    block_index = tl.program_id(0)
    cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64)
    # a page index of -1 means no cpu page was assigned to this block; skip it
    if cpu_page_index == -1:
        return

    # pages already marked ready hold valid data and need no copy
    ready_state = tl.load(page_readies_ptr + block_index)
    if ready_state:
        return

    token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
    token_indexes = tl.load(token_indexes_ptr + token_range).to(tl.int64)
    head_all_dim_range = tl.arange(0, BLOCK_HEAD_ALL_DIM)

    gpu_stride0 = tl.cast(gpu_stride0, dtype=tl.int64)

    for layer_index in range(layer_num):
        gpu_ptr = (
            gpu_kv_cache_ptr
            + layer_index * gpu_stride0
            + token_indexes[:, None] * gpu_stride1
            + head_all_dim_range[None, :]
        )
        gpu_data = tl.load(gpu_ptr, mask=(head_all_dim_range[None, :] < head_all_dim), other=0.0)
        cpu_ptr = (
            cpu_kv_cache_ptr
            + cpu_page_index * cpu_stride0
            + layer_index * cpu_stride1
            + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
            + head_all_dim_range[None, :]
        )
        tl.store(
            cpu_ptr,
            gpu_data,
            mask=(head_all_dim_range[None, :] < head_all_dim),
        )
    return



@torch.no_grad()
def offload_gpu_kv_to_cpu(
    token_indexes: torch.Tensor,
    gpu_kv_cache: torch.Tensor,
    cpu_kv_cache: torch.Tensor,
    page_indexes: torch.Tensor,
    page_readies: torch.Tensor,
):
    """
    This function is used to offload GPU KV cache to CPU KV cache.
    Args:
        token_indexes: (token_num,)
        gpu_kv_cache: (layer_num, token_num, head_num, head_dim)
        cpu_kv_cache: (all_page_num, layer_num, token_block_size, head_num, head_dim)
        page_indexes: (page_num,)
        page_readies: (page_num,)
    """
    token_block_size = cpu_kv_cache.shape[2]
    token_num = page_indexes.shape[0] * token_block_size
    assert token_indexes.shape[0] >= token_num
    assert page_indexes.shape == page_readies.shape
    page_num = page_indexes.shape[0]
    head_all_dim = gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2]
    BLOCK_HEAD_ALL_DIM = triton.next_power_of_2(head_all_dim)

    grid = (page_num,)
    num_warps = 4

    _offload_gpu_kv_to_cpu[grid](
        token_indexes_ptr=token_indexes,
        gpu_kv_cache_ptr=gpu_kv_cache,
        gpu_stride0=gpu_kv_cache.stride(0),
        gpu_stride1=gpu_kv_cache.stride(1),
        gpu_stride2=gpu_kv_cache.stride(2),
        gpu_stride3=gpu_kv_cache.stride(3),
        cpu_kv_cache_ptr=cpu_kv_cache,
        cpu_stride0=cpu_kv_cache.stride(0),
        cpu_stride1=cpu_kv_cache.stride(1),
        cpu_stride2=cpu_kv_cache.stride(2),
        cpu_stride3=cpu_kv_cache.stride(3),
        cpu_stride4=cpu_kv_cache.stride(4),
        page_indexes_ptr=page_indexes,
        page_readies_ptr=page_readies,
        layer_num=gpu_kv_cache.shape[0],
        head_all_dim=head_all_dim,
        BLOCK_HEAD_ALL_DIM=BLOCK_HEAD_ALL_DIM,
        TOKEN_BLOCK=token_block_size,
        num_warps=num_warps,
        num_stages=1,
    )
    return
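
For orientation, a minimal usage sketch of offload_gpu_kv_to_cpu with made-up sizes. It is not part of the diff; it assumes a CUDA device and that cpu_kv_cache is pinned host memory the GPU can address directly (the register_shm_ptr_to_pin commit suggests this is how the cpu buffer is prepared):

# Illustrative only, not part of this PR: toy call into offload_gpu_kv_to_cpu.
# Assumes a CUDA device; cpu_kv_cache must be pinned host memory the GPU can address.
layer_num, head_num, head_dim = 2, 8, 64
token_block_size, page_num = 16, 4
token_num = page_num * token_block_size

gpu_kv_cache = torch.randn(layer_num, 1024, head_num, head_dim, dtype=torch.float16, device="cuda")
cpu_kv_cache = torch.zeros(32, layer_num, token_block_size, head_num, head_dim, dtype=torch.float16).pin_memory()
token_indexes = torch.arange(token_num, dtype=torch.int32, device="cuda")      # gpu token slots to copy out
page_indexes = torch.tensor([3, 7, -1, 11], dtype=torch.int32, device="cuda")  # -1 pages are skipped
page_readies = torch.zeros(page_num, dtype=torch.bool, device="cuda")          # ready pages are skipped

offload_gpu_kv_to_cpu(token_indexes, gpu_kv_cache, cpu_kv_cache, page_indexes, page_readies)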


@triton.jit
def _load_cpu_cache_to_gpu(
    token_indexes_ptr,
    gpu_kv_cache_ptr,
    gpu_stride0,
    gpu_stride1,
    gpu_stride2,
    gpu_stride3,
    cpu_kv_cache_ptr,
    cpu_stride0,
    cpu_stride1,
    cpu_stride2,
    cpu_stride3,
    cpu_stride4,
    page_indexes_ptr,
    layer_num,
    head_all_dim,
    all_move_token_num,
    BLOCK_HEAD_ALL_DIM: tl.constexpr,
    TOKEN_BLOCK: tl.constexpr,
):
    block_index = tl.program_id(0)
    cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64)
    if cpu_page_index == -1:
        return

    gpu_stride0 = tl.cast(gpu_stride0, dtype=tl.int64)
    # the copied tokens are aligned to the end of the page range, so the first
    # padded_size slots fall outside mem_indexes and are masked off below
    padded_size = TOKEN_BLOCK * tl.num_programs(0) - all_move_token_num
    head_all_dim_range = tl.arange(0, BLOCK_HEAD_ALL_DIM)
    token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
    token_range = token_range - padded_size

    token_mask = token_range >= 0
    head_dim_mask = head_all_dim_range < head_all_dim

    token_indexes = tl.load(token_indexes_ptr + token_range, mask=token_mask, other=0).to(tl.int64)

    cpu_page_index = tl.load(page_indexes_ptr + block_index)

Review comment — medium:

The cpu_page_index is loaded twice. The load at line 157 is redundant, as the value was already loaded at line 140; removing it gives a minor performance improvement and improves code clarity.
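
A minimal suggested diff for this fix might look like (sketch, not part of the original review):

-    cpu_page_index = tl.load(page_indexes_ptr + block_index)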

    for layer_index in range(layer_num):
        cpu_ptr = (
            cpu_kv_cache_ptr
            + cpu_page_index * cpu_stride0
            + layer_index * cpu_stride1
            + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
            + head_all_dim_range[None, :]
        )
        cpu_data = tl.load(cpu_ptr, mask=head_dim_mask[None, :], other=0.0)

        gpu_ptr = (
            gpu_kv_cache_ptr
            + layer_index * gpu_stride0
            + token_indexes[:, None] * gpu_stride1
            + head_all_dim_range[None, :]
        )
        tl.store(
            gpu_ptr,
            cpu_data,
            mask=token_mask[:, None] & head_dim_mask[None, :],
        )
    return
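
The end-alignment is easiest to see with concrete numbers; a small worked sketch of the padding arithmetic (illustrative values only, not part of the diff):

# Illustrative: the padding arithmetic in _load_cpu_cache_to_gpu.
TOKEN_BLOCK = 16
num_programs = 2                # one program per cpu page
all_move_token_num = 25         # tokens actually being loaded
padded_size = TOKEN_BLOCK * num_programs - all_move_token_num  # 32 - 25 = 7

# Program 0 covers logical positions -7..8; the 7 negative slots are masked out,
# so the two pages together map exactly onto mem_indexes[0:25].
token_range = [i - padded_size for i in range(TOKEN_BLOCK)]    # [-7, ..., 8]
valid = [t for t in token_range if t >= 0]                     # [0, 1, ..., 8]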


@torch.no_grad()
def load_cpu_kv_to_gpu(
    mem_indexes: torch.Tensor,
    gpu_kv_cache: torch.Tensor,
    cpu_kv_cache: torch.Tensor,
    page_indexes: torch.Tensor,
):
    """
    this function is used to offload GPU KV cache to CPU KV cache.

Review comment — medium:

The docstring is incorrect. It states that the function offloads from GPU to CPU, but it should be loading from CPU to GPU.

Suggested change:
-    this function is used to offload GPU KV cache to CPU KV cache.
+    this function is used to load CPU KV cache to GPU KV cache.

    Args:
        mem_indexes: (token_num,)
        gpu_kv_cache: (layer_num, token_num, head_num, head_dim)
        cpu_kv_cache: (page_num, layer_num, token_block_size, head_num, head_dim)
        page_indexes: (page_num,)
    """
    token_block_size = cpu_kv_cache.shape[2]
    token_num = page_indexes.shape[0] * token_block_size
    assert mem_indexes.shape[0] >= token_num
    page_num = page_indexes.shape[0]
    BLOCK_HEAD_ALL_DIM = triton.next_power_of_2(gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2])

    grid = (page_num,)
    num_warps = 1

    _offload_gpu_kv_to_cpu[grid](
        token_indexes_ptr=mem_indexes,
        gpu_kv_cache_ptr=gpu_kv_cache,
        gpu_stride0=gpu_kv_cache.stride(0),
        gpu_stride1=gpu_kv_cache.stride(1),
        gpu_stride2=gpu_kv_cache.stride(2),
        gpu_stride3=gpu_kv_cache.stride(3),
        cpu_kv_cache_ptr=cpu_kv_cache,
        cpu_stride0=cpu_kv_cache.stride(0),
        cpu_stride1=cpu_kv_cache.stride(1),
        cpu_stride2=cpu_kv_cache.stride(2),
        cpu_stride3=cpu_kv_cache.stride(3),
        cpu_stride4=cpu_kv_cache.stride(4),
        page_indexes_ptr=page_indexes,
        layer_num=gpu_kv_cache.shape[0],
        head_all_dim=gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2],
        all_move_token_num=len(mem_indexes),
        BLOCK_HEAD_ALL_DIM=BLOCK_HEAD_ALL_DIM,
        TOKEN_BLOCK=token_block_size,
        num_warps=num_warps,
        num_stages=1,
    )
Comment on lines +206 to +227

Review comment — critical:

This function incorrectly calls the _offload_gpu_kv_to_cpu kernel instead of _load_cpu_cache_to_gpu. This is a critical error that will cause data to be moved in the wrong direction, and the arguments passed do not match the intended kernel's signature.

Suggested change:

    _load_cpu_cache_to_gpu[grid](
        token_indexes_ptr=mem_indexes,
        gpu_kv_cache_ptr=gpu_kv_cache,
        gpu_stride0=gpu_kv_cache.stride(0),
        gpu_stride1=gpu_kv_cache.stride(1),
        gpu_stride2=gpu_kv_cache.stride(2),
        gpu_stride3=gpu_kv_cache.stride(3),
        cpu_kv_cache_ptr=cpu_kv_cache,
        cpu_stride0=cpu_kv_cache.stride(0),
        cpu_stride1=cpu_kv_cache.stride(1),
        cpu_stride2=cpu_kv_cache.stride(2),
        cpu_stride3=cpu_kv_cache.stride(3),
        cpu_stride4=cpu_kv_cache.stride(4),
        page_indexes_ptr=page_indexes,
        layer_num=gpu_kv_cache.shape[0],
        head_all_dim=gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2],
        all_move_token_num=len(mem_indexes),
        BLOCK_HEAD_ALL_DIM=BLOCK_HEAD_ALL_DIM,
        TOKEN_BLOCK=token_block_size,
        num_warps=num_warps,
        num_stages=1,
    )

    return
21 changes: 21 additions & 0 deletions lightllm/server/api_cli.py
@@ -477,4 +477,25 @@ def make_argument_parser() -> argparse.ArgumentParser:
        default=0.03,
        help="""The interval of the schedule time, default is 30ms.""",
    )
    parser.add_argument(
        "--enable_cpu_cache",
        action="store_true",
        help="""enable cpu cache to store kv cache.""",
    )
    parser.add_argument(
        "--cpu_cache_storage_size",
        type=float,
        default=2,
        help="""The capacity of the cpu cache, in GB.""",
    )
    parser.add_argument(
        "--cpu_cache_token_page_size",
        type=int,
        default=256,
        help="""The token page size of the cpu cache.""",
    )
    parser.add_argument("--enable_disk_cache", action="store_true", help="""enable disk cache to store kv cache.""")
    parser.add_argument(
        "--disk_cache_storage_size", type=float, default=10, help="""The capacity of the disk cache, in GB."""
    )
    return parser
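
For reference, a sketch of how the new flags combine (not part of the diff; it assumes the parser's remaining options all have defaults):

# Illustrative: parsing the new cache flags; values are made up.
from lightllm.server.api_cli import make_argument_parser

args = make_argument_parser().parse_args(
    [
        "--enable_cpu_cache",
        "--cpu_cache_storage_size", "8",       # 8 GB of pinned cpu cache
        "--cpu_cache_token_page_size", "256",  # tokens per cpu cache page
        "--enable_disk_cache",
        "--disk_cache_storage_size", "100",    # 100 GB of disk cache
    ]
)
print(args.enable_cpu_cache, args.cpu_cache_storage_size, args.enable_disk_cache)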
16 changes: 4 additions & 12 deletions lightllm/server/api_http.py
@@ -38,6 +38,7 @@
 from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
 from fastapi.responses import Response, StreamingResponse, JSONResponse
 from lightllm.server.core.objs.sampling_params import SamplingParams
+from lightllm.server.core.objs import StartArgs
 from .multimodal_params import MultimodalParams
 from .httpserver.manager import HttpServerManager
 from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
@@ -71,7 +72,7 @@ class G_Objs:
     httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None
     shared_token_load: TokenLoad = None

-    def set_args(self, args):
+    def set_args(self, args: StartArgs):
         self.args = args
         from .api_lightllm import lightllm_generate, lightllm_generate_stream
         from .api_tgi import tgi_generate_impl, tgi_generate_stream_impl
@@ -86,22 +87,13 @@ def set_args(self, args):
         if args.run_mode == "pd_master":
             self.metric_client = MetricClient(args.metric_port)
             self.httpserver_manager = HttpServerManagerForPDMaster(
-                args,
-                metric_port=args.metric_port,
+                args=args,
             )
         else:
             init_tokenizer(args)  # for openai api
             SamplingParams.load_generation_cfg(args.model_dir)
             self.metric_client = MetricClient(args.metric_port)
-            self.httpserver_manager = HttpServerManager(
-                args,
-                router_port=args.router_port,
-                cache_port=args.cache_port,
-                detokenization_pub_port=args.detokenization_pub_port,
-                visual_port=args.visual_port,
-                enable_multimodal=args.enable_multimodal,
-                metric_port=args.metric_port,
-            )
+            self.httpserver_manager = HttpServerManager(args=args)
         dp_size_in_node = max(1, args.dp // args.nnodes)  # compat for multi-node pure-tp runs, where 1 // 2 == 0 must be handled
         self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node)