Commit 84507b7

optimize longcontext decoding (#3510)
* optimize longcontext decoding
* add min split k
* more warps
1 parent 6bddb1c commit 84507b7

1 file changed: 23 additions, 5 deletions


lmdeploy/pytorch/kernels/cuda/pagedattention.py

Lines changed: 23 additions & 5 deletions
```diff
@@ -11,6 +11,8 @@
 
 from lmdeploy.utils import get_logger
 
+from .utils import get_device_props
+
 logger = get_logger('lmdeploy')
 
 TRITON_VERSION = version.parse(triton.__version__)
```
```diff
@@ -499,6 +501,19 @@ def _kernel_meta_sm9x(BLOCK_DMODEL: int, BLOCK_H: int):
     return _kernel_meta_default(BLOCK_DMODEL, BLOCK_H)
 
 
+def _get_split_k(device_idx: int, head_grid: int, batch_size: int):
+    """get split k."""
+    props = get_device_props(device_idx)
+    num_sm = props['multi_processor_count']
+    # estimated occupancy 12.5%
+    warps_per_sm = props['warps_per_sm'] // 8
+
+    SPLIT_K = triton.cdiv(num_sm * warps_per_sm // head_grid, triton.next_power_of_2(batch_size))
+    SPLIT_K = 1 << (SPLIT_K.bit_length() - 1)
+    SPLIT_K = max(min(SPLIT_K, 64), 4)
+    return SPLIT_K
+
+
 def paged_attention_fwd(
     q: Tensor,
     k: Tensor,
```
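The idea behind `_get_split_k` appears to be: take roughly one eighth of the device's resident warps (the 12.5% occupancy estimate), divide that budget by the head grid and the (power-of-two-rounded) batch size, round the result down to a power of two, and clamp it to [4, 64]. The sketch below replays that arithmetic outside Triton with plain-Python stand-ins for `triton.cdiv` and `triton.next_power_of_2` and made-up device numbers; it is an illustration of the heuristic, not code from the kernel.

```python
# Standalone sketch of the split-k heuristic, assuming hypothetical device numbers.

def cdiv(a: int, b: int) -> int:
    # ceiling division, mirroring triton.cdiv
    return -(-a // b)

def next_power_of_2(n: int) -> int:
    # smallest power of two >= n, for n >= 1, mirroring triton.next_power_of_2
    return 1 << (n - 1).bit_length()

def get_split_k(num_sm: int, warps_per_sm: int, head_grid: int, batch_size: int) -> int:
    active_warps = warps_per_sm // 8                # keep ~12.5% of resident warps busy
    split_k = cdiv(num_sm * active_warps // head_grid, next_power_of_2(batch_size))
    split_k = 1 << (split_k.bit_length() - 1)       # round down to a power of two
    return max(min(split_k, 64), 4)                 # clamp to [4, 64]

# e.g. a hypothetical 132-SM GPU with 64 resident warps per SM:
print(get_split_k(132, 64, head_grid=8, batch_size=1))    # 64: tiny batch, split KV heavily
print(get_split_k(132, 64, head_grid=8, batch_size=256))  # 4: large batch already fills the GPU
```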
```diff
@@ -579,15 +594,18 @@ def _get_block_d(Lk):
     is_decoding = q.shape[-3] == kv_seqlens.size(0)
     assert is_decoding, 'we only support decoding paged attention.'
 
-    SPLIT_K = 4
-    if quant_policy != 4:
-        acc = q.new_empty(batch, head, SPLIT_K, Lv + 2, dtype=torch.float32)
-    else:
-        acc = q.new_empty(batch, head, SPLIT_K, o.shape[-1] + 2, dtype=torch.float32)
     BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lq)
     p2_kv_group_num = triton.next_power_of_2(kv_group_num)
     BLOCK_H = max(16, min(BLOCK, p2_kv_group_num))
     grid_1 = triton.cdiv(head, min(BLOCK_H, kv_group_num))
+
+    SPLIT_K = _get_split_k(q.device.index, grid_1, batch)
+
+    if quant_policy != 4:
+        acc = q.new_empty(batch, head, SPLIT_K, Lv + 2, dtype=torch.float32)
+    else:
+        acc = q.new_empty(batch, head, SPLIT_K, o.shape[-1] + 2, dtype=torch.float32)
+
     grid = (
         grid_1,
         SPLIT_K,
```
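The accumulator keeps `Lv` output channels plus two extra slots per split; in flash-decoding-style kernels those slots typically hold the running softmax maximum and denominator of each split so the `SPLIT_K` partial results can be merged afterwards. Below is a rough PyTorch sketch of such a merge under that assumed layout; the actual slot order and the reduce kernel in this file may differ.

```python
import torch

def combine_split_k(acc: torch.Tensor, Lv: int) -> torch.Tensor:
    """Merge per-split partial attention results (layout is an assumption).

    acc: (batch, head, SPLIT_K, Lv + 2) float32, where acc[..., :Lv] is assumed to be
    the un-normalized weighted sum of values for a split, acc[..., Lv] its running
    softmax max m, and acc[..., Lv + 1] its running denominator l.
    """
    part = acc[..., :Lv]                                  # (b, h, k, Lv) partial sums
    m = acc[..., Lv]                                      # (b, h, k) per-split max
    l = acc[..., Lv + 1]                                  # (b, h, k) per-split denominator

    m_all = m.max(dim=-1, keepdim=True).values            # global max over the splits
    scale = torch.exp(m - m_all)                           # rescale each split to m_all
    num = (part * scale.unsqueeze(-1)).sum(dim=-2)         # (b, h, Lv) combined numerator
    den = (l * scale).sum(dim=-1, keepdim=True)            # (b, h, 1) combined denominator
    return num / den
```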
