@@ -219,6 +219,91 @@ def moe_align1(
)
+@triton.jit
+def moe_align_fused_kernel(
+    topk_ids_ptr,  # [token_num, topk]
+    topk_weights_ptr,  # [token_num, topk]
+    expert_to_token_index_ptr,  # [expert_num, token_num * topk]
+    expert_to_weight_ptr,  # [expert_num, token_num * topk]
+    expert_token_num_ptr,  # [expert_num]
+    token_num,
+    topk_num: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    token_block = tl.program_id(0)
+    offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offs < token_num * topk_num
+
+    expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0)
+    weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0)
+
+    # Use atomic_add to reserve a write slot for each token under its expert.
+    write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask)
+
+    # Scatter each token's flat index and routing weight into its expert's row.
+    tl.store(
+        expert_to_token_index_ptr + expert_ids * (token_num * topk_num) + write_pos,
+        offs,
+        mask=mask,
+    )
+    tl.store(
+        expert_to_weight_ptr + expert_ids * (token_num * topk_num) + write_pos,
+        weights,
+        mask=mask,
+    )
+
+
+def _get_moe_align_fused_static_key(
+    topk_weights: torch.Tensor,
+) -> dict:
+    topk_num = topk_weights.shape[1]
+    return {
+        "topk_num": topk_num,
+    }
+
+
+def _get_moe_align_fused_configs():
+    return [
+        {
+            "BLOCK_SIZE": bt,
+            "num_warps": nw,
+        }
+        for nw in [1, 2, 4, 8]
+        for bt in [128, 256, 512, 1024, 2048]
+    ]
+
+
+@autotune(
+    kernel_name="moe_align_fused:v1",
+    configs_gen_func=_get_moe_align_fused_configs,
+    static_key_func=_get_moe_align_fused_static_key,
+    run_key_func=lambda topk_ids: topk_ids.shape[0],
+    mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"],
+)
+def moe_align_fused(
+    expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None
+):
+    token_num, topk_num = topk_ids.shape
+    if run_config is None:
+        run_config = {}
+    BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256)
+    num_warps = run_config.get("num_warps", 4)
+
+    grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),)
+    moe_align_fused_kernel[grid](
+        topk_ids,
+        topk_weights,
+        expert_to_token_index,
+        expert_to_weight,
+        expert_token_num,
+        token_num,
+        topk_num,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return expert_to_token_index, expert_to_weight, expert_token_num
+
+
@triton.jit
def moe_align2_kernel(
    experts_token_num_ptr,  # [expert_num,]
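
For reference, below is a minimal standalone invocation of the new moe_align_fused wrapper, written against the signature added above. This is an illustrative sketch, not part of the commit: the sizes (token_num=8, topk=2, expert_num=4) are made up, the CUDA device is assumed, and the dtypes are chosen to match the int32/float32 output buffers used at the call site further down; the real topk_ids dtype may differ. Note that expert_token_num must be zero-initialized, since the kernel uses it as a per-expert atomic counter.

import torch

token_num, topk, expert_num = 8, 2, 4  # illustrative sizes only
topk_ids = torch.randint(0, expert_num, (token_num, topk), dtype=torch.int32, device="cuda")
topk_weights = torch.rand((token_num, topk), dtype=torch.float32, device="cuda")

expert_to_token_index = torch.empty((expert_num, token_num * topk), dtype=torch.int32, device="cuda")
expert_to_weight = torch.empty((expert_num, token_num * topk), dtype=torch.float32, device="cuda")
# Zero-initialized: the kernel atomically increments one counter per expert.
expert_token_num = torch.zeros((expert_num,), dtype=torch.int32, device="cuda")

moe_align_fused(
    expert_to_token_index=expert_to_token_index,
    expert_to_weight=expert_to_weight,
    expert_token_num=expert_token_num,
    topk_ids=topk_ids,
    topk_weights=topk_weights,
)
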
@@ -719,9 +804,14 @@ def fused_experts_impl(

    expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda")
    expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda")
-    moe_align(topk_ids=curr_topk_ids, out=expert_to_tokens)
-    expert_to_token_num = torch.empty((E,), dtype=torch.int32, device="cuda")
-    moe_align1(expert_to_tokens, curr_topk_weights, expert_to_weights, expert_to_token_num, topk=topk_num)
+    expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda")
+    moe_align_fused(
+        expert_to_token_index=expert_to_tokens,
+        expert_to_weight=expert_to_weights,
+        expert_token_num=expert_to_token_num,
+        topk_ids=curr_topk_ids,
+        topk_weights=curr_topk_weights,
+    )

    reused_mblock_infos = grouped_matmul(
        curr_topk_ids.numel(),
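
Continuing the sketch above, one possible sanity check (again my assumption, not part of this commit) is to compare the fused kernel's output against a pure-PyTorch reference. Because write slots are claimed with atomic_add, the ordering of entries within an expert's row is not deterministic, so the check compares per-expert counts and membership rather than exact positions.

# Reference counts per expert, computed from the routing table itself.
counts = torch.bincount(topk_ids.flatten().to(torch.int64), minlength=expert_num)
assert torch.equal(expert_token_num.to(torch.int64).cpu(), counts.cpu())

# Per-expert membership: every flat (token, slot) index routed to expert e
# must appear exactly once in that expert's row, in any order.
flat_ids = topk_ids.flatten()
for e in range(expert_num):
    n = int(expert_token_num[e])
    got = set(expert_to_token_index[e, :n].tolist())
    want = set((flat_ids == e).nonzero(as_tuple=True)[0].tolist())
    assert got == want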