Commit 6b58763
Author: wangzaijun
Commit message: fix
1 parent 33cd95d · commit 6b58763

File tree: 2 files changed, +73 −66 lines

lightllm/common/basemodel/triton_kernel/kv_cache_offload.py

Lines changed: 71 additions & 66 deletions
@@ -273,6 +273,7 @@ def offload_gpu_kv_to_cpu(
 def _load_cpu_cache_to_gpu(
     gpu_mem_indexes_ptr,
     copy_token_num,
+    copy_block_num,
     cpu_mem_indexes_ptr,
     cpu_page_indexes_ptr,
     gpu_kv_cache_ptr,
@@ -299,74 +300,76 @@ def _load_cpu_cache_to_gpu(
     BLOCK_HEAD_DIM: tl.constexpr,
     TOKEN_BLOCK: tl.constexpr,
 ):
-    block_index = tl.program_id(0)
-    token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
-    token_mask = token_range < copy_token_num
-    gpu_mem_indexes = tl.load(gpu_mem_indexes_ptr + token_range, mask=token_mask).to(tl.int64)
-    cpu_mem_indexes = tl.load(cpu_mem_indexes_ptr + token_range, mask=token_mask).to(tl.int64)
-    cpu_page_indexes = tl.load(cpu_page_indexes_ptr + token_range, mask=token_mask).to(tl.int64)
-
-    head_dim_range = tl.arange(0, BLOCK_HEAD_DIM)
-    head_dim_mask = head_dim_range < head_dim
-
-    for layer_index in range(layer_num):
-        move_mask = token_mask[:, None] & head_dim_mask[None, :]
-
-        for k_head_index in range(cpu_k_head_num):
-            gpu_k_head_index = k_head_index + gpu_k_start_head_index
-            cpu_k_head_index = k_head_index + cpu_k_start_head_index
-
-            cpu_ptr = (
-                cpu_kv_cache_ptr
-                + cpu_page_indexes[:, None] * cpu_stride0
-                + layer_index.to(tl.int64) * cpu_stride1
-                + cpu_mem_indexes[:, None] * cpu_stride2
-                + cpu_k_head_index * cpu_stride3
-                + head_dim_range[None, :]
-            )
-            cpu_data = tl.load(cpu_ptr, mask=move_mask, other=0.0)
-
-            gpu_ptr = (
-                gpu_kv_cache_ptr
-                + layer_index.to(tl.int64) * gpu_stride0
-                + gpu_mem_indexes[:, None] * gpu_stride1
-                + gpu_k_head_index * gpu_stride2
-                + head_dim_range[None, :]
-            )
+    block_index_start = tl.program_id(0)
+    split_block_num = tl.num_programs(0)
+    for block_index in range(block_index_start, copy_block_num, split_block_num):
+        token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
+        token_mask = token_range < copy_token_num
+        gpu_mem_indexes = tl.load(gpu_mem_indexes_ptr + token_range, mask=token_mask).to(tl.int64)
+        cpu_mem_indexes = tl.load(cpu_mem_indexes_ptr + token_range, mask=token_mask).to(tl.int64)
+        cpu_page_indexes = tl.load(cpu_page_indexes_ptr + token_range, mask=token_mask).to(tl.int64)
 
-            tl.store(
-                gpu_ptr,
-                cpu_data,
-                mask=move_mask,
-            )
+        head_dim_range = tl.arange(0, BLOCK_HEAD_DIM)
+        head_dim_mask = head_dim_range < head_dim
 
-        for v_head_index in range(cpu_v_head_num):
-            gpu_v_head_index = v_head_index + gpu_v_start_head_index
-            cpu_v_head_index = v_head_index + cpu_v_start_head_index
-
-            cpu_ptr = (
-                cpu_kv_cache_ptr
-                + cpu_page_indexes[:, None] * cpu_stride0
-                + layer_index.to(tl.int64) * cpu_stride1
-                + cpu_mem_indexes[:, None] * cpu_stride2
-                + cpu_v_head_index * cpu_stride3
-                + head_dim_range[None, :]
-            )
-            cpu_data = tl.load(cpu_ptr, mask=move_mask, other=0.0)
-
-            gpu_ptr = (
-                gpu_kv_cache_ptr
-                + layer_index.to(tl.int64) * gpu_stride0
-                + gpu_mem_indexes[:, None] * gpu_stride1
-                + gpu_v_head_index * gpu_stride2
-                + head_dim_range[None, :]
-            )
+        for layer_index in range(layer_num):
+            move_mask = token_mask[:, None] & head_dim_mask[None, :]
 
-            tl.store(
-                gpu_ptr,
-                cpu_data,
-                mask=move_mask,
-            )
+            for k_head_index in range(cpu_k_head_num):
+                gpu_k_head_index = k_head_index + gpu_k_start_head_index
+                cpu_k_head_index = k_head_index + cpu_k_start_head_index
+
+                cpu_ptr = (
+                    cpu_kv_cache_ptr
+                    + cpu_page_indexes[:, None] * cpu_stride0
+                    + layer_index.to(tl.int64) * cpu_stride1
+                    + cpu_mem_indexes[:, None] * cpu_stride2
+                    + cpu_k_head_index * cpu_stride3
+                    + head_dim_range[None, :]
+                )
+                cpu_data = tl.load(cpu_ptr, mask=move_mask, other=0.0)
+
+                gpu_ptr = (
+                    gpu_kv_cache_ptr
+                    + layer_index.to(tl.int64) * gpu_stride0
+                    + gpu_mem_indexes[:, None] * gpu_stride1
+                    + gpu_k_head_index * gpu_stride2
+                    + head_dim_range[None, :]
+                )
+
+                tl.store(
+                    gpu_ptr,
+                    cpu_data,
+                    mask=move_mask,
+                )
+
+            for v_head_index in range(cpu_v_head_num):
+                gpu_v_head_index = v_head_index + gpu_v_start_head_index
+                cpu_v_head_index = v_head_index + cpu_v_start_head_index
+
+                cpu_ptr = (
+                    cpu_kv_cache_ptr
+                    + cpu_page_indexes[:, None] * cpu_stride0
+                    + layer_index.to(tl.int64) * cpu_stride1
+                    + cpu_mem_indexes[:, None] * cpu_stride2
+                    + cpu_v_head_index * cpu_stride3
+                    + head_dim_range[None, :]
+                )
+                cpu_data = tl.load(cpu_ptr, mask=move_mask, other=0.0)
+
+                gpu_ptr = (
+                    gpu_kv_cache_ptr
+                    + layer_index.to(tl.int64) * gpu_stride0
+                    + gpu_mem_indexes[:, None] * gpu_stride1
+                    + gpu_v_head_index * gpu_stride2
+                    + head_dim_range[None, :]
+                )
+
+                tl.store(
+                    gpu_ptr,
+                    cpu_data,
+                    mask=move_mask,
+                )
     return
 
 
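The rewrite above switches the kernel from one program per token block to a persistent, grid-stride style loop: each program starts at tl.program_id(0) and walks over the token blocks with stride tl.num_programs(0), so the grid size can be fixed independently of how much data is copied. Below is a minimal sketch of the same pattern on a flat 1-D copy; it is illustrative only, not the lightllm kernel, and the name _copy_grid_stride is invented here.

import torch
import triton
import triton.language as tl


@triton.jit
def _copy_grid_stride(
    src_ptr,
    dst_ptr,
    copy_token_num,
    copy_block_num,
    TOKEN_BLOCK: tl.constexpr,
):
    # Same pattern as the commit: a fixed number of programs, each striding
    # over the token blocks instead of owning exactly one block.
    block_index_start = tl.program_id(0)
    split_block_num = tl.num_programs(0)
    for block_index in range(block_index_start, copy_block_num, split_block_num):
        offs = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
        mask = offs < copy_token_num
        vals = tl.load(src_ptr + offs, mask=mask, other=0.0)
        tl.store(dst_ptr + offs, vals, mask=mask)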
@@ -378,6 +381,7 @@ def load_cpu_kv_to_gpu(
     page_indexes: torch.Tensor,
     tp_index: int,
     tp_world_size: int,
+    grid_num: int,
     _cache_data={},
 ):
     """
@@ -489,12 +493,13 @@ def load_cpu_kv_to_gpu(
 
     TOKEN_BLOCK = 128
 
-    grid = (triton.cdiv(move_token_num, TOKEN_BLOCK),)
+    grid = (grid_num,)
     num_warps = 4
 
     _load_cpu_cache_to_gpu[grid](
         gpu_mem_indexes_ptr=gpu_mem_indexes,
         copy_token_num=move_token_num,
+        copy_block_num=triton.cdiv(move_token_num, TOKEN_BLOCK),
         cpu_mem_indexes_ptr=cpu_mem_indexes,
         cpu_page_indexes_ptr=cpu_page_indexes,
         gpu_kv_cache_ptr=gpu_kv_cache,
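The host-side consequence is visible in this hunk: grid no longer scales with move_token_num, and the block count travels as an ordinary kernel argument instead. A hedged usage sketch for the toy kernel above, with invented names and a default grid_num chosen only for illustration:

def copy_with_fixed_grid(src: torch.Tensor, dst: torch.Tensor, grid_num: int = 16):
    # src and dst are assumed to be same-sized CUDA tensors. The grid is a
    # fixed tuning knob; the amount of work is carried by copy_block_num.
    TOKEN_BLOCK = 128
    n = src.numel()
    _copy_grid_stride[(grid_num,)](
        src_ptr=src,
        dst_ptr=dst,
        copy_token_num=n,
        copy_block_num=triton.cdiv(n, TOKEN_BLOCK),
        TOKEN_BLOCK=TOKEN_BLOCK,
    )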

lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py

Lines changed: 2 additions & 0 deletions
@@ -62,6 +62,7 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]):
 
         mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=need_token_num)
 
+        # g_infer_context.get_overlap_stream().synchronize()
         # Copy the contents of the CPU pages into the GPU pages
         load_cpu_kv_to_gpu(
             gpu_mem_indexes=mem_indexes.cuda(non_blocking=True),
@@ -70,6 +71,7 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]):
             page_indexes=torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True),
             tp_index=self.backend.rank_in_dp,
             tp_world_size=self.backend.dp_world_size,
+            grid_num=1 if self.args.enable_fa3 else 16,  # TODO: a more effective allocation strategy.
         )
 
         torch.cuda.current_stream().synchronize()
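The hard-coded grid_num=1 if self.args.enable_fa3 else 16 caps how many SMs the persistent copy kernel can occupy, which matters when it runs alongside other kernels; the TODO in the diff notes that a better policy is still open. One possible direction, purely hypothetical and not lightllm API (pick_grid_num and sm_fraction are invented here), would derive the grid from the device's SM count:

import torch

def pick_grid_num(minimize_interference: bool, sm_fraction: float = 0.125) -> int:
    # Bound the persistent kernel to a share of the device's SMs.
    sm_count = torch.cuda.get_device_properties(
        torch.cuda.current_device()
    ).multi_processor_count
    if minimize_interference:
        return 1
    return max(1, int(sm_count * sm_fraction))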
