Skip to content

Commit 32f7db7

Browse files
authored
v100 triton kernel fix (#1040)
1 parent 0f5d0cc commit 32f7db7

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

lightllm/common/basemodel/triton_kernel/gather_token_id.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def _fwd_kernel_scatter(
1616
num_size,
1717
HAS_OUT_IS_NONE: tl.constexpr,
1818
BLOCK: tl.constexpr,
19+
OLD_VERSION_TRITON: tl.constexpr,
1920
):
2021
block_index = tl.program_id(0)
2122
block_range = block_index * BLOCK + tl.arange(0, BLOCK)
@@ -27,6 +28,8 @@ def _fwd_kernel_scatter(
2728

2829
if not HAS_OUT_IS_NONE:
2930
cur_has_out = tl.load(b_has_out + block_range, mask=block_mask, other=False)
31+
if OLD_VERSION_TRITON:
32+
cur_has_out = cur_has_out != 0
3033
tl.store(
3134
req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index,
3235
cur_next_token_id,
@@ -76,6 +79,7 @@ def scatter_token(
7679
num_size=batch_size,
7780
HAS_OUT_IS_NONE=b_has_out is None,
7881
BLOCK=BLOCK,
82+
# Compare (major, minor) numerically: a plain string comparison is lexicographic and would order "3.10.0" before "3.2.0".
OLD_VERSION_TRITON=tuple(int(part) for part in triton.__version__.split(".")[:2]) < (3, 2),
7983
num_warps=num_warps,
8084
num_stages=1,
8185
)

lightllm/common/basemodel/triton_kernel/gen_sampling_params.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def _token_id_counter_update_kernel(
125125
batch_size,
126126
HAS_MASK: tl.constexpr,
127127
BLOCK: tl.constexpr,
128+
OLD_VERSION_TRITON: tl.constexpr,
128129
):
129130

130131
block_start_index = tl.program_id(0) * BLOCK
@@ -136,6 +137,8 @@ def _token_id_counter_update_kernel(
136137

137138
if HAS_MASK:
138139
mask = tl.load(mask_ptr + offs, mask=loc_mask, other=False)
140+
if OLD_VERSION_TRITON:
141+
mask = mask != 0
139142
tl.atomic_add(
140143
req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n,
141144
1,
@@ -170,6 +173,7 @@ def update_req_to_token_id_counter(
170173
batch_size=batch_size,
171174
HAS_MASK=has_mask,
172175
BLOCK=BLOCK,
176+
# Compare (major, minor) numerically: a plain string comparison is lexicographic and would order "3.10.0" before "3.2.0".
OLD_VERSION_TRITON=tuple(int(part) for part in triton.__version__.split(".")[:2]) < (3, 2),
173177
num_warps=1,
174178
)
175179
return

0 commit comments

Comments
 (0)