Skip to content

Commit 2c4be50

Browse files
shihaobai and Claude Code
authored and committed
add moe_align_fused (#1054)
1 parent 03633c1 commit 2c4be50

File tree

3 files changed

+193
-3
lines changed

3 files changed

+193
-3
lines changed

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,91 @@ def moe_align1(
219219
)
220220

221221

222+
@triton.jit
def moe_align_fused_kernel(
    topk_ids_ptr,  # [token_num, topk]
    topk_weights_ptr,  # [token_num, topk]
    expert_to_token_index_ptr,  # [expert_num, token_num * topk]
    expert_to_weight_ptr,  # [expert_num, token_num * topk]
    expert_token_num_ptr,  # [expert_num]
    token_num,
    topk_num: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter each (token, topk) assignment into per-expert buffers.

    topk_ids/topk_weights are treated as flat arrays of length
    token_num * topk_num. For every flat position this kernel atomically
    increments the destination expert's counter in ``expert_token_num_ptr``
    to reserve a slot, then stores the flat index and its routing weight at
    that slot of the expert's row.

    NOTE(review): ``expert_token_num_ptr`` must be zero-initialized before
    launch — the counters double as write cursors. Slot order within an
    expert is whatever the atomics produce, not a guaranteed token order.
    """
    token_block = tl.program_id(0)
    offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < token_num * topk_num

    expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0)
    weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0)

    # Reserve a per-element write position with atomic_add on the expert's
    # counter; each lane gets a unique slot in that expert's row.
    write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask)

    # Store the flat (token * topk) index and its weight into the reserved slot.
    tl.store(
        expert_to_token_index_ptr + expert_ids * (token_num * topk_num) + write_pos,
        offs,
        mask=mask,
    )
    tl.store(
        expert_to_weight_ptr + expert_ids * (token_num * topk_num) + write_pos,
        weights,
        mask=mask,
    )
254+
255+
256+
def _get_moe_align_fused_static_key(
257+
topk_weights: torch.Tensor,
258+
) -> dict:
259+
topk_num = topk_weights.shape[1]
260+
return {
261+
"topk_num": topk_num,
262+
}
263+
264+
265+
def _get_moe_align_fused_configs():
266+
return [
267+
{
268+
"BLOCK_SIZE": bt,
269+
"num_warps": nw,
270+
}
271+
for nw in [1, 2, 4, 8]
272+
for bt in [128, 256, 512, 1024, 2048]
273+
]
274+
275+
276+
@autotune(
    kernel_name="moe_align_fused:v1",
    configs_gen_func=_get_moe_align_fused_configs,
    static_key_func=_get_moe_align_fused_static_key,
    run_key_func=lambda topk_ids: topk_ids.shape[0],
    mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"],
)
def moe_align_fused(
    expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None
):
    """Group token->expert assignments into per-expert buffers in one pass.

    Fused replacement for the moe_align + moe_align1 pair: for each
    (token, topk) entry of ``topk_ids``, writes its flat index into
    ``expert_to_token_index[expert]`` and its weight into
    ``expert_to_weight[expert]``, accumulating per-expert counts in
    ``expert_token_num``.

    Args:
        expert_to_token_index: int32 [expert_num, token_num * topk] output.
        expert_to_weight: float32 [expert_num, token_num * topk] output.
        expert_token_num: int32 [expert_num] output of per-expert counts.
        topk_ids: [token_num, topk] chosen expert ids per token.
        topk_weights: [token_num, topk] routing weights per token.
        run_config: optional tuning config ("BLOCK_SIZE", "num_warps").

    Returns:
        The three output tensors, mutated in place.
    """
    token_num, topk_num = topk_ids.shape
    if run_config is None:
        run_config = {}
    BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256)
    num_warps = run_config.get("num_warps", 4)

    # The kernel reserves write slots with atomic_add on expert_token_num, so
    # the counters must start at zero. Reset them here instead of relying on
    # every caller to pass a zero-filled tensor; otherwise a reused/empty
    # buffer would make write positions overflow the per-expert rows.
    expert_token_num.zero_()

    grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),)
    moe_align_fused_kernel[grid](
        topk_ids,
        topk_weights,
        expert_to_token_index,
        expert_to_weight,
        expert_token_num,
        token_num,
        topk_num,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return expert_to_token_index, expert_to_weight, expert_token_num
305+
306+
222307
@triton.jit
223308
def moe_align2_kernel(
224309
experts_token_num_ptr, # [expert_num,]
@@ -719,9 +804,14 @@ def fused_experts_impl(
719804

720805
expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda")
721806
expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda")
722-
moe_align(topk_ids=curr_topk_ids, out=expert_to_tokens)
723-
expert_to_token_num = torch.empty((E,), dtype=torch.int32, device="cuda")
724-
moe_align1(expert_to_tokens, curr_topk_weights, expert_to_weights, expert_to_token_num, topk=topk_num)
807+
expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda")
808+
moe_align_fused(
809+
expert_to_token_index=expert_to_tokens,
810+
expert_to_weight=expert_to_weights,
811+
expert_token_num=expert_to_token_num,
812+
topk_ids=curr_topk_ids,
813+
topk_weights=curr_topk_weights,
814+
)
725815

726816
reused_mblock_infos = grouped_matmul(
727817
curr_topk_ids.numel(),
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
{
2+
"1": {
3+
"BLOCK_SIZE": 256,
4+
"num_warps": 8
5+
},
6+
"100": {
7+
"BLOCK_SIZE": 128,
8+
"num_warps": 4
9+
},
10+
"1024": {
11+
"BLOCK_SIZE": 256,
12+
"num_warps": 8
13+
},
14+
"128": {
15+
"BLOCK_SIZE": 128,
16+
"num_warps": 4
17+
},
18+
"16": {
19+
"BLOCK_SIZE": 128,
20+
"num_warps": 4
21+
},
22+
"16384": {
23+
"BLOCK_SIZE": 128,
24+
"num_warps": 4
25+
},
26+
"2048": {
27+
"BLOCK_SIZE": 128,
28+
"num_warps": 8
29+
},
30+
"256": {
31+
"BLOCK_SIZE": 128,
32+
"num_warps": 4
33+
},
34+
"32": {
35+
"BLOCK_SIZE": 128,
36+
"num_warps": 4
37+
},
38+
"4096": {
39+
"BLOCK_SIZE": 128,
40+
"num_warps": 4
41+
},
42+
"64": {
43+
"BLOCK_SIZE": 128,
44+
"num_warps": 4
45+
},
46+
"8": {
47+
"BLOCK_SIZE": 256,
48+
"num_warps": 8
49+
}
50+
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
{
2+
"1": {
3+
"BLOCK_SIZE": 256,
4+
"num_warps": 8
5+
},
6+
"100": {
7+
"BLOCK_SIZE": 128,
8+
"num_warps": 4
9+
},
10+
"1024": {
11+
"BLOCK_SIZE": 256,
12+
"num_warps": 4
13+
},
14+
"128": {
15+
"BLOCK_SIZE": 256,
16+
"num_warps": 8
17+
},
18+
"16": {
19+
"BLOCK_SIZE": 128,
20+
"num_warps": 4
21+
},
22+
"16384": {
23+
"BLOCK_SIZE": 256,
24+
"num_warps": 8
25+
},
26+
"2048": {
27+
"BLOCK_SIZE": 256,
28+
"num_warps": 8
29+
},
30+
"256": {
31+
"BLOCK_SIZE": 256,
32+
"num_warps": 8
33+
},
34+
"32": {
35+
"BLOCK_SIZE": 128,
36+
"num_warps": 4
37+
},
38+
"4096": {
39+
"BLOCK_SIZE": 128,
40+
"num_warps": 4
41+
},
42+
"64": {
43+
"BLOCK_SIZE": 128,
44+
"num_warps": 4
45+
},
46+
"8": {
47+
"BLOCK_SIZE": 256,
48+
"num_warps": 8
49+
}
50+
}

0 commit comments

Comments
 (0)