@@ -11,6 +11,8 @@
 
 from lmdeploy.utils import get_logger
 
+from .utils import get_device_props
+
 logger = get_logger('lmdeploy')
 
 TRITON_VERSION = version.parse(triton.__version__)
@@ -499,6 +501,19 @@ def _kernel_meta_sm9x(BLOCK_DMODEL: int, BLOCK_H: int): |
     return _kernel_meta_default(BLOCK_DMODEL, BLOCK_H)
 
 
+def _get_split_k(device_idx: int, head_grid: int, batch_size: int):
+    """Choose the SPLIT_K factor for the split-k decoding kernel."""
+    props = get_device_props(device_idx)
+    num_sm = props['multi_processor_count']
+    # assume an estimated occupancy of 12.5%
+    warps_per_sm = props['warps_per_sm'] // 8
+
+    SPLIT_K = triton.cdiv(num_sm * warps_per_sm // head_grid, triton.next_power_of_2(batch_size))
+    SPLIT_K = 1 << (SPLIT_K.bit_length() - 1)  # round down to a power of two
+    SPLIT_K = max(min(SPLIT_K, 64), 4)  # clamp to [4, 64]
+    return SPLIT_K
+
+
 def paged_attention_fwd(
     q: Tensor,
     k: Tensor,
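
A quick way to sanity-check the heuristic above is to run the same arithmetic in plain Python. The sketch below is not the library API: `cdiv`/`next_power_of_2` are re-implemented locally, and the device properties are stubbed with hypothetical A100-like numbers (108 SMs, 64 resident warps per SM).

```python
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length()  # smallest power of two >= n (n >= 1)

def split_k(num_sm: int, warps_per_sm: int, head_grid: int, batch_size: int) -> int:
    warps = warps_per_sm // 8  # estimated occupancy 12.5%
    sk = cdiv(num_sm * warps // head_grid, next_power_of_2(batch_size))
    sk = 1 << max(sk.bit_length() - 1, 0)  # round down to a power of two
    return max(min(sk, 64), 4)             # clamp to [4, 64]

# hypothetical A100-like device, head grid of 2
for batch in (1, 32, 256):
    print(batch, split_k(108, 64, head_grid=2, batch_size=batch))
# -> 1: 64, 32: 8, 256: 4
```

The clamp is doing the interesting work: small batches get more splits to keep the SMs busy, while large batches, which already fill the GPU on their own, never drop below 4 partial results per head.
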
@@ -579,15 +594,18 @@ def _get_block_d(Lk): |
     is_decoding = q.shape[-3] == kv_seqlens.size(0)
     assert is_decoding, 'we only support decoding paged attention.'
 
-    SPLIT_K = 4
-    if quant_policy != 4:
-        acc = q.new_empty(batch, head, SPLIT_K, Lv + 2, dtype=torch.float32)
-    else:
-        acc = q.new_empty(batch, head, SPLIT_K, o.shape[-1] + 2, dtype=torch.float32)
     BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lq)
     p2_kv_group_num = triton.next_power_of_2(kv_group_num)
     BLOCK_H = max(16, min(BLOCK, p2_kv_group_num))
     grid_1 = triton.cdiv(head, min(BLOCK_H, kv_group_num))
+
+    SPLIT_K = _get_split_k(q.device.index, grid_1, batch)
+
+    if quant_policy != 4:
+        acc = q.new_empty(batch, head, SPLIT_K, Lv + 2, dtype=torch.float32)
+    else:
+        acc = q.new_empty(batch, head, SPLIT_K, o.shape[-1] + 2, dtype=torch.float32)
+
     grid = (
         grid_1,
         SPLIT_K,
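
For reference, here is how the launch dimensions and the reduce buffer fit together. The sizes are hypothetical (head=32, kv_group_num=4, BLOCK=64, batch=8, Lv=128, with SPLIT_K=8 standing in for whatever `_get_split_k` returns), and the trailing `+ 2` slots presumably hold the per-split online-softmax statistics consumed by the reduce step.

```python
head, kv_group_num, BLOCK, batch, Lv, SPLIT_K = 32, 4, 64, 8, 128, 8

p2_kv_group_num = 1 << (kv_group_num - 1).bit_length()  # next_power_of_2
BLOCK_H = max(16, min(BLOCK, p2_kv_group_num))
grid_1 = -(-head // min(BLOCK_H, kv_group_num))          # cdiv over head blocks

# one program per (head block, k split, batch element); each split writes an
# Lv-wide partial output plus 2 reduction scalars into acc
grid = (grid_1, SPLIT_K, batch)
acc_shape = (batch, head, SPLIT_K, Lv + 2)
print(grid, acc_shape)  # (8, 8, 8) (8, 32, 8, 130)
```
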
|