File tree Expand file tree Collapse file tree 2 files changed +8
-0
lines changed
lightllm/common/basemodel/triton_kernel Expand file tree Collapse file tree 2 files changed +8
-0
lines changed Original file line number Diff line number Diff line change @@ -16,6 +16,7 @@ def _fwd_kernel_scatter(
16
16
num_size ,
17
17
HAS_OUT_IS_NONE : tl .constexpr ,
18
18
BLOCK : tl .constexpr ,
19
+ OLD_VERSION_TRITON : tl .constexpr ,
19
20
):
20
21
block_index = tl .program_id (0 )
21
22
block_range = block_index * BLOCK + tl .arange (0 , BLOCK )
@@ -27,6 +28,8 @@ def _fwd_kernel_scatter(
27
28
28
29
if not HAS_OUT_IS_NONE :
29
30
cur_has_out = tl .load (b_has_out + block_range , mask = block_mask , other = False )
31
+ if OLD_VERSION_TRITON :
32
+ cur_has_out = cur_has_out != 0
30
33
tl .store (
31
34
req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index ,
32
35
cur_next_token_id ,
@@ -76,6 +79,7 @@ def scatter_token(
76
79
num_size = batch_size ,
77
80
HAS_OUT_IS_NONE = b_has_out is None ,
78
81
BLOCK = BLOCK ,
82
+ OLD_VERSION_TRITON = triton .__version__ < "3.2.0" ,
79
83
num_warps = num_warps ,
80
84
num_stages = 1 ,
81
85
)
Original file line number Diff line number Diff line change @@ -125,6 +125,7 @@ def _token_id_counter_update_kernel(
125
125
batch_size ,
126
126
HAS_MASK : tl .constexpr ,
127
127
BLOCK : tl .constexpr ,
128
+ OLD_VERSION_TRITON : tl .constexpr ,
128
129
):
129
130
130
131
block_start_index = tl .program_id (0 ) * BLOCK
@@ -136,6 +137,8 @@ def _token_id_counter_update_kernel(
136
137
137
138
if HAS_MASK :
138
139
mask = tl .load (mask_ptr + offs , mask = loc_mask , other = False )
140
+ if OLD_VERSION_TRITON :
141
+ mask = mask != 0
139
142
tl .atomic_add (
140
143
req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n ,
141
144
1 ,
@@ -170,6 +173,7 @@ def update_req_to_token_id_counter(
170
173
batch_size = batch_size ,
171
174
HAS_MASK = has_mask ,
172
175
BLOCK = BLOCK ,
176
+ OLD_VERSION_TRITON = triton .__version__ < "3.2.0" ,
173
177
num_warps = 1 ,
174
178
)
175
179
return
You can’t perform that action at this time.
0 commit comments