@@ -273,6 +273,7 @@ def offload_gpu_kv_to_cpu(
 def _load_cpu_cache_to_gpu(
     gpu_mem_indexes_ptr,
     copy_token_num,
+    copy_block_num,
     cpu_mem_indexes_ptr,
     cpu_page_indexes_ptr,
     gpu_kv_cache_ptr,
@@ -299,74 +300,76 @@ def _load_cpu_cache_to_gpu(
     BLOCK_HEAD_DIM: tl.constexpr,
     TOKEN_BLOCK: tl.constexpr,
 ):
-    block_index = tl.program_id(0)
-    token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
-    token_mask = token_range < copy_token_num
-    gpu_mem_indexes = tl.load(gpu_mem_indexes_ptr + token_range, mask=token_mask).to(tl.int64)
-    cpu_mem_indexes = tl.load(cpu_mem_indexes_ptr + token_range, mask=token_mask).to(tl.int64)
-    cpu_page_indexes = tl.load(cpu_page_indexes_ptr + token_range, mask=token_mask).to(tl.int64)
-
-    head_dim_range = tl.arange(0, BLOCK_HEAD_DIM)
-    head_dim_mask = head_dim_range < head_dim
-
-    for layer_index in range(layer_num):
-        move_mask = token_mask[:, None] & head_dim_mask[None, :]
-
-        for k_head_index in range(cpu_k_head_num):
-            gpu_k_head_index = k_head_index + gpu_k_start_head_index
-            cpu_k_head_index = k_head_index + cpu_k_start_head_index
-
-            cpu_ptr = (
-                cpu_kv_cache_ptr
-                + cpu_page_indexes[:, None] * cpu_stride0
-                + layer_index.to(tl.int64) * cpu_stride1
-                + cpu_mem_indexes[:, None] * cpu_stride2
-                + cpu_k_head_index * cpu_stride3
-                + head_dim_range[None, :]
-            )
-            cpu_data = tl.load(cpu_ptr, mask=move_mask, other=0.0)
-
-            gpu_ptr = (
-                gpu_kv_cache_ptr
-                + layer_index.to(tl.int64) * gpu_stride0
-                + gpu_mem_indexes[:, None] * gpu_stride1
-                + gpu_k_head_index * gpu_stride2
-                + head_dim_range[None, :]
-            )
+    block_index_start = tl.program_id(0)
+    split_block_num = tl.num_programs(0)
+    for block_index in range(block_index_start, copy_block_num, split_block_num):
+        token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
+        token_mask = token_range < copy_token_num
+        gpu_mem_indexes = tl.load(gpu_mem_indexes_ptr + token_range, mask=token_mask).to(tl.int64)
+        cpu_mem_indexes = tl.load(cpu_mem_indexes_ptr + token_range, mask=token_mask).to(tl.int64)
+        cpu_page_indexes = tl.load(cpu_page_indexes_ptr + token_range, mask=token_mask).to(tl.int64)
 
-            tl.store(
-                gpu_ptr,
-                cpu_data,
-                mask=move_mask,
-            )
+        head_dim_range = tl.arange(0, BLOCK_HEAD_DIM)
+        head_dim_mask = head_dim_range < head_dim
 
-        for v_head_index in range(cpu_v_head_num):
-            gpu_v_head_index = v_head_index + gpu_v_start_head_index
-            cpu_v_head_index = v_head_index + cpu_v_start_head_index
-
-            cpu_ptr = (
-                cpu_kv_cache_ptr
-                + cpu_page_indexes[:, None] * cpu_stride0
-                + layer_index.to(tl.int64) * cpu_stride1
-                + cpu_mem_indexes[:, None] * cpu_stride2
-                + cpu_v_head_index * cpu_stride3
-                + head_dim_range[None, :]
-            )
-            cpu_data = tl.load(cpu_ptr, mask=move_mask, other=0.0)
-
-            gpu_ptr = (
-                gpu_kv_cache_ptr
-                + layer_index.to(tl.int64) * gpu_stride0
-                + gpu_mem_indexes[:, None] * gpu_stride1
-                + gpu_v_head_index * gpu_stride2
-                + head_dim_range[None, :]
-            )
+        for layer_index in range(layer_num):
+            move_mask = token_mask[:, None] & head_dim_mask[None, :]
 
-            tl.store(
-                gpu_ptr,
-                cpu_data,
-                mask=move_mask,
-            )
+            for k_head_index in range(cpu_k_head_num):
+                gpu_k_head_index = k_head_index + gpu_k_start_head_index
+                cpu_k_head_index = k_head_index + cpu_k_start_head_index
+
+                cpu_ptr = (
+                    cpu_kv_cache_ptr
+                    + cpu_page_indexes[:, None] * cpu_stride0
+                    + layer_index.to(tl.int64) * cpu_stride1
+                    + cpu_mem_indexes[:, None] * cpu_stride2
+                    + cpu_k_head_index * cpu_stride3
+                    + head_dim_range[None, :]
+                )
+                cpu_data = tl.load(cpu_ptr, mask=move_mask, other=0.0)
+
+                gpu_ptr = (
+                    gpu_kv_cache_ptr
+                    + layer_index.to(tl.int64) * gpu_stride0
+                    + gpu_mem_indexes[:, None] * gpu_stride1
+                    + gpu_k_head_index * gpu_stride2
+                    + head_dim_range[None, :]
+                )
+
+                tl.store(
+                    gpu_ptr,
+                    cpu_data,
+                    mask=move_mask,
+                )
+
+            for v_head_index in range(cpu_v_head_num):
+                gpu_v_head_index = v_head_index + gpu_v_start_head_index
+                cpu_v_head_index = v_head_index + cpu_v_start_head_index
+
+                cpu_ptr = (
+                    cpu_kv_cache_ptr
+                    + cpu_page_indexes[:, None] * cpu_stride0
+                    + layer_index.to(tl.int64) * cpu_stride1
+                    + cpu_mem_indexes[:, None] * cpu_stride2
+                    + cpu_v_head_index * cpu_stride3
+                    + head_dim_range[None, :]
+                )
+                cpu_data = tl.load(cpu_ptr, mask=move_mask, other=0.0)
+
+                gpu_ptr = (
+                    gpu_kv_cache_ptr
+                    + layer_index.to(tl.int64) * gpu_stride0
+                    + gpu_mem_indexes[:, None] * gpu_stride1
+                    + gpu_v_head_index * gpu_stride2
+                    + head_dim_range[None, :]
+                )
+
+                tl.store(
+                    gpu_ptr,
+                    cpu_data,
+                    mask=move_mask,
+                )
     return
 
 
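The rewritten kernel replaces the one-program-per-token-block launch with a grid-stride loop: a fixed number of programs each walk the token blocks with stride `tl.num_programs(0)`, decoupling the launch grid from the amount of work. Below is a minimal, self-contained sketch of that pattern on a plain copy; all names here (`_grid_stride_copy`, `src`, `dst`) are illustrative and not part of the patch.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _grid_stride_copy(src_ptr, dst_ptr, n_elements, block_num, BLOCK: tl.constexpr):
    # Each program starts at its own program id and strides by the total
    # number of launched programs, so any grid size covers all blocks.
    for block_index in range(tl.program_id(0), block_num, tl.num_programs(0)):
        offsets = block_index * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < n_elements
        tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)


src = torch.arange(1000, dtype=torch.float32, device="cuda")
dst = torch.empty_like(src)
BLOCK = 128
block_num = triton.cdiv(src.numel(), BLOCK)
# Launch only 4 programs even though there are 8 blocks of work;
# the stride loop inside the kernel picks up the remainder.
_grid_stride_copy[(4,)](src, dst, src.numel(), block_num, BLOCK=BLOCK)
assert torch.equal(dst, src)
```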
@@ -378,6 +381,7 @@ def load_cpu_kv_to_gpu(
     page_indexes: torch.Tensor,
     tp_index: int,
     tp_world_size: int,
+    grid_num: int,
     _cache_data={},
 ):
     """
@@ -489,12 +493,13 @@ def load_cpu_kv_to_gpu(
 
     TOKEN_BLOCK = 128
 
-    grid = (triton.cdiv(move_token_num, TOKEN_BLOCK),)
+    grid = (grid_num,)
     num_warps = 4
 
     _load_cpu_cache_to_gpu[grid](
         gpu_mem_indexes_ptr=gpu_mem_indexes,
         copy_token_num=move_token_num,
+        copy_block_num=triton.cdiv(move_token_num, TOKEN_BLOCK),
         cpu_mem_indexes_ptr=cpu_mem_indexes,
         cpu_page_indexes_ptr=cpu_page_indexes,
         gpu_kv_cache_ptr=gpu_kv_cache,
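Callers of `load_cpu_kv_to_gpu` must now supply `grid_num` themselves; this diff only threads the value through to the launch. One plausible heuristic, assuming a persistent-kernel style launch, is one program per SM, capped by the number of token blocks. The helper below is a hypothetical sketch of that choice, not code from this PR.

```python
import torch
import triton


def pick_grid_num(move_token_num: int, token_block: int = 128) -> int:
    """Hypothetical heuristic: one persistent program per SM, capped by the work size."""
    copy_block_num = triton.cdiv(move_token_num, token_block)
    sm_count = torch.cuda.get_device_properties(0).multi_processor_count
    # Never launch more programs than there are blocks of work.
    return max(1, min(copy_block_num, sm_count))
```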