Disk cache and cpu Cache feature #997

Open · wants to merge 64 commits into main

Conversation

hiworldwzj (Collaborator)

No description provided.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a significant new feature for CPU- and disk-based KV caching. The implementation is extensive, touching many parts of the system, from low-level Triton kernels to high-level server logic. While the overall direction is good, several critical issues need to be addressed, including incorrect kernel calls, logical errors in loops, and improper shared memory allocation. There are also numerous typos (e.g., 'chuncked' instead of 'chunked') that should be corrected for code clarity and consistency.


    # create the shared memory segment
    key = args.cpu_kv_cache_shm_id  # key for the shared memory segment
    size = 1024  # shared memory size

critical

The shared memory size is hardcoded to 1024 bytes, which is far too small for the KV cache. The actual required size should be calculated using calcu_cpu_cache_meta().calcu_size() and used here.

Suggested change
- size = 1024  # shared memory size
+ size = calcu_cpu_cache_meta().calcu_size()  # shared memory size
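
For scale, a minimal, self-contained sketch of sizing the segment from the cache geometry rather than a constant. All names and dimensions here are illustrative (in the PR, this computation is what calcu_cpu_cache_meta().calcu_size() should provide), and multiprocessing.shared_memory stands in for whatever shm API the PR actually uses:

    from multiprocessing import shared_memory

    # Illustrative geometry; storing K and V together gives the leading factor of 2.
    def cpu_kv_cache_nbytes(layer_num, page_num, tokens_per_page,
                            head_num, head_dim, elem_size=2):  # fp16 -> 2 bytes
        return 2 * layer_num * page_num * tokens_per_page * head_num * head_dim * elem_size

    size = cpu_kv_cache_nbytes(layer_num=2, page_num=16, tokens_per_page=64,
                               head_num=8, head_dim=128)  # 8 MiB -- far beyond 1024 bytes
    shm = shared_memory.SharedMemory(name="cpu_kv_cache_demo", create=True, size=size)
    shm.close()
    shm.unlink()  # demo cleanup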

            self.multi_level_cache_manager.cpu_cache_handle_queue.append(trans_task)
        else:
            true_finished_reqs.append(req)
        return true_finished_reqs

critical

The return true_finished_reqs statement is inside the for loop, which will cause the function to exit after processing only the first finished request. This is incorrect and will lead to other finished requests not being processed for CPU cache offloading. The return statement should be moved outside the loop.

Suggested change
-         return true_finished_reqs
+     return true_finished_reqs
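
A sketch of the corrected control flow (the enclosing method and helper names are assumed from the excerpt, not taken verbatim from the PR):

    def handle_finished_reqs(self, finished_reqs: List[InferReq]) -> List[InferReq]:
        true_finished_reqs = []
        for req in finished_reqs:
            trans_task = self._build_offload_task(req)  # hypothetical helper
            if trans_task is not None:
                self.multi_level_cache_manager.cpu_cache_handle_queue.append(trans_task)
            else:
                true_finished_reqs.append(req)
        # Return outside the loop, after every finished request has been examined.
        return true_finished_reqs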

Comment on lines 186 to 197
        while cur_link_item is not None:
            if cur_link_item.key == key:
                # remove item
                pre_item = cur_link_item.get_pre_item()
                pre_item.next_index = cur_link_item.next_index
                if cur_link_item.next_index != -1:
                    next_item = cur_link_item.get_next_item()
                    next_item.pre_index = pre_item.self_index

                self.link_items.add_item_to_tail(index=cur_link_item.self_index)
            else:
                cur_link_item = cur_link_item.get_next_item()

critical

The remove method in ShmDict has a bug. After finding and removing an item, the while loop continues without updating cur_link_item, leading to an infinite loop. Assuming keys are unique, the loop should terminate after removal.

Suggested change
          while cur_link_item is not None:
              if cur_link_item.key == key:
                  # remove item
                  pre_item = cur_link_item.get_pre_item()
                  pre_item.next_index = cur_link_item.next_index
                  if cur_link_item.next_index != -1:
                      next_item = cur_link_item.get_next_item()
                      next_item.pre_index = pre_item.self_index
                  self.link_items.add_item_to_tail(index=cur_link_item.self_index)
+                 return
              else:
                  cur_link_item = cur_link_item.get_next_item()
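
As a standalone illustration of the bug class (plain object references instead of ShmDict's shared-memory indices): after the matched node is unlinked, the cursor is never advanced, so `while cur is not None` re-tests the same node forever. Returning immediately after removal terminates the scan and is correct when keys are unique:

    class Node:
        def __init__(self, key):
            self.key = key
            self.pre = None
            self.next = None

    def remove(head, key):
        """Remove the first node with `key`; `head` is a sentinel node."""
        cur = head.next
        while cur is not None:
            if cur.key == key:
                cur.pre.next = cur.next            # unlink the match
                if cur.next is not None:
                    cur.next.pre = cur.pre
                return cur                         # keys are unique: stop here
            cur = cur.next                         # advance only on non-match
        return None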

Comment on lines +206 to +227
    _offload_gpu_kv_to_cpu[grid](
        token_indexes_ptr=mem_indexes,
        gpu_kv_cache_ptr=gpu_kv_cache,
        gpu_stride0=gpu_kv_cache.stride(0),
        gpu_stride1=gpu_kv_cache.stride(1),
        gpu_stride2=gpu_kv_cache.stride(2),
        gpu_stride3=gpu_kv_cache.stride(3),
        cpu_kv_cache_ptr=cpu_kv_cache,
        cpu_stride0=cpu_kv_cache.stride(0),
        cpu_stride1=cpu_kv_cache.stride(1),
        cpu_stride2=cpu_kv_cache.stride(2),
        cpu_stride3=cpu_kv_cache.stride(3),
        cpu_stride4=cpu_kv_cache.stride(4),
        page_indexes_ptr=page_indexes,
        layer_num=gpu_kv_cache.shape[0],
        head_all_dim=gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2],
        all_move_token_num=len(mem_indexes),
        BLOCK_HEAD_ALL_DIM=BLOCK_HEAD_ALL_DIM,
        TOKEN_BLOCK=token_block_size,
        num_warps=num_warps,
        num_stages=1,
    )

critical

This function incorrectly calls the _offload_gpu_kv_to_cpu kernel instead of _load_cpu_cache_to_gpu. This is a critical error that will cause data to be moved in the wrong direction. The arguments passed are also incorrect for the intended operation, and the chuncked_size parameter is missing.

    _load_cpu_cache_to_gpu[grid](
        token_indexes_ptr=mem_indexes,
        gpu_kv_cache_ptr=gpu_kv_cache,
        gpu_stride0=gpu_kv_cache.stride(0),
        gpu_stride1=gpu_kv_cache.stride(1),
        gpu_stride2=gpu_kv_cache.stride(2),
        gpu_stride3=gpu_kv_cache.stride(3),
        cpu_kv_cache_ptr=cpu_kv_cache,
        cpu_stride0=cpu_kv_cache.stride(0),
        cpu_stride1=cpu_kv_cache.stride(1),
        cpu_stride2=cpu_kv_cache.stride(2),
        cpu_stride3=cpu_kv_cache.stride(3),
        cpu_stride4=cpu_kv_cache.stride(4),
        page_indexes_ptr=page_indexes,
        layer_num=gpu_kv_cache.shape[0],
        head_all_dim=gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2],
        all_move_token_num=len(mem_indexes),
        chuncked_size=token_num,
        BLOCK_HEAD_ALL_DIM=BLOCK_HEAD_ALL_DIM,
        TOKEN_BLOCK=token_block_size,
        num_warps=num_warps,
        num_stages=1,
    )
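
To make the direction distinction concrete, here is a host-side sketch of both movements in plain torch (not the PR's Triton kernels, which do this on-device). The layouts are assumed: gpu_kv as [layer, token_slot, head, dim], cpu_kv as [page, layer, page_token, head, dim], mem_indexes as a 1-D index tensor on the same device as gpu_kv, and page_tokens as the per-page token capacity:

    import torch

    def offload_gpu_to_cpu(gpu_kv, cpu_kv, mem_indexes, page_indexes, page_tokens):
        # Scatter GPU token slots into the paged CPU cache.
        toks = gpu_kv[:, mem_indexes]                       # [layer, n, head, dim]
        n = mem_indexes.numel()
        for i, page in enumerate(page_indexes.tolist()):
            s, e = i * page_tokens, min((i + 1) * page_tokens, n)
            cpu_kv[page, :, : e - s] = toks[:, s:e].cpu()

    def load_cpu_to_gpu(gpu_kv, cpu_kv, mem_indexes, page_indexes, page_tokens):
        # Gather paged CPU cache back into GPU token slots: the inverse movement.
        n = mem_indexes.numel()
        for i, page in enumerate(page_indexes.tolist()):
            s, e = i * page_tokens, min((i + 1) * page_tokens, n)
            gpu_kv[:, mem_indexes[s:e]] = cpu_kv[page, :, : e - s].to(gpu_kv.device)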

        req: Req = req
        finded_page_indexes = []
        for token_chuncked_hash_value in req.token_hash_list.get_all():
            page_index = self.cpu_cache_client.query_one_page(token_chuncked_hash_value)

high

The method query_one_page returns a tuple (page_index, is_ready), but the result is assigned to a single variable page_index. This will cause incorrect behavior as page_index will be a tuple. The tuple should be unpacked. Also, the variable token_chuncked_hash_value has a typo.

Suggested change
- page_index = self.cpu_cache_client.query_one_page(token_chuncked_hash_value)
+ page_index, _ = self.cpu_cache_client.query_one_page(token_chunked_hash_value)
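
The suggestion discards is_ready with `_`; a slightly fuller sketch of the matching loop (the (page_index, is_ready) contract is taken from the comment above, and the -1 miss sentinel is assumed from similar code in this PR) would also check readiness so half-written pages are not counted as prefix hits:

    finded_page_indexes = []
    for token_chunked_hash_value in req.token_hash_list.get_all():
        page_index, is_ready = self.cpu_cache_client.query_one_page(token_chunked_hash_value)
        if page_index == -1 or not is_ready:
            break  # stop at the first miss so the matched prefix stays contiguous
        finded_page_indexes.append(page_index)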

    ) -> Optional["TransTask"]:
        with torch.cuda.stream(cpu_kv_cache_stream):
            all_token_hash_list = req.shm_req.token_hash_list.get_all()
            block_size = req.cur_kv_len // self.args.cpu_cache_token_chuncked_size

medium

There is a typo in an attribute name. cpu_cache_token_chuncked_size should be cpu_cache_token_chunked_size.

Suggested change
- block_size = req.cur_kv_len // self.args.cpu_cache_token_chuncked_size
+ block_size = req.cur_kv_len // self.args.cpu_cache_token_chunked_size


    def fill_cpu_cache_to_reqs(self, reqs: List[InferReq]):
        idle_token_num = g_infer_context.get_can_alloc_token_num()
        token_chuncked_size = self.args.cpu_cache_token_chuncked_size

medium

There is a typo in a variable name. token_chuncked_size should be token_chunked_size.

Suggested change
- token_chuncked_size = self.args.cpu_cache_token_chuncked_size
+ token_chunked_size = self.args.cpu_cache_token_chunked_size

        for req in reqs:
            if req.shm_req.group_req_id == req.shm_req.request_id:
                page_list = req.shm_req.cpu_cache_match_page_indexes.get_all()
                match_tokens = len(page_list) * token_chuncked_size

medium

There is a typo in a variable name. token_chuncked_size should be token_chunked_size.

Suggested change
- match_tokens = len(page_list) * token_chuncked_size
+ match_tokens = len(page_list) * token_chunked_size

    if cpu_page_index == -1:
        return

    first_block_start_index = chuncked_size * tl.num_programs(0) - all_move_token_num

medium

There is a typo in the variable name chuncked_size. It should be chunked_size.

Suggested change
- first_block_start_index = chuncked_size * tl.num_programs(0) - all_move_token_num
+ first_block_start_index = chunked_size * tl.num_programs(0) - all_move_token_num
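
For intuition about the formula above (an assumed reading, since the kernel's full context is not shown): if chunked_size is the per-page token capacity, then tl.num_programs(0) blocks give a total capacity of chunked_size * tl.num_programs(0). With chunked_size = 64, four programs, and all_move_token_num = 200, the expression yields 256 - 200 = 56, i.e. the first block is partial and starts 56 tokens into its page while the remaining blocks are full.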

    layer_num,
    head_all_dim,
    all_move_token_num,
    chuncked_size,

medium

There is a typo in the parameter name chuncked_size. It should be chunked_size for consistency and correctness.

Suggested change
- chuncked_size,
+ chunked_size,
