Skip to content

Commit 9901f76

Browse files
authored
optimize moe gate for 128 experts (#3500)
1 parent b4854b1 commit 9901f76

File tree

1 file changed

+32
-32
lines changed

1 file changed

+32
-32
lines changed

src/turbomind/kernels/gemm/moe_utils_v2.cu

Lines changed: 32 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -612,49 +612,49 @@ void invokeMoeGate_V2(int* f2n, // [e*n] -> n
612 612
softmax,
613 613
norm_topk,
614 614
routed_scale);
615-
};
616 615

617-
auto fail = [&] {
618-
std::cerr << __FILE__ << "(" << __LINE__ << "): unsupported moe config: expert_num=" << experts
619-
<< ", top_k=" << experts_per_token << ", softmax=" << softmax << ", norm_topk=" << norm_topk << "\n";
620-
std::abort();
616+
return true;
621 617
};
622 618

623 619
if (!softmax && norm_topk) {
624 620
// norm top-k is part of softmax impl
625-
fail();
621+
TM_CHECK(0) << softmax << " " << norm_topk;
626 622
}
627 623

628-
if (experts <= 8) {
629-
if (experts_per_token <= 2) {
630-
invoke(_Int<8>, _Int<2>, _Int<8>, _Int<4>);
631-
}
632-
else {
633-
invoke(_Int<8>, _Int<8>, _Int<8>, _Int<4>);
634-
}
635-
}
636-
else if (experts <= 64) {
637-
if (experts_per_token <= 4) {
638-
invoke(_Int<64>, _Int<4>, _Int<16>, _Int<4>);
639-
}
640-
else if (experts_per_token <= 8) {
641-
invoke(_Int<64>, _Int<8>, _Int<16>, _Int<4>);
624+
auto dispatch = [&] {
625+
if (experts <= 8) {
626+
if (experts_per_token <= 2) {
627+
return invoke(_Int<8>, _Int<2>, _Int<8>, _Int<4>);
628+
}
629+
else {
630+
return invoke(_Int<8>, _Int<8>, _Int<8>, _Int<4>);
631+
}
642 632
}
643-
else {
644-
fail();
633+
else if (experts <= 64) {
634+
if (experts_per_token <= 4) {
635+
return invoke(_Int<64>, _Int<4>, _Int<16>, _Int<4>);
636+
}
637+
else if (experts_per_token <= 8) {
638+
return invoke(_Int<64>, _Int<8>, _Int<16>, _Int<4>);
639+
}
645 640
}
646-
}
647-
else if (experts <= 160) {
648-
if (experts_per_token <= 8) {
649-
invoke(_Int<160>, _Int<8>, _Int<10>, _Int<2>);
641+
else if (experts <= 128) {
642+
if (experts_per_token <= 8) {
643+
return invoke(_Int<128>, _Int<8>, _Int<16>, _Int<4>);
644+
}
650 645
}
651-
else {
652-
fail();
646+
else if (experts <= 160) {
647+
if (experts_per_token <= 8) {
648+
return invoke(_Int<160>, _Int<8>, _Int<10>, _Int<2>);
649+
}
653 650
}
654-
}
655-
else {
656-
fail();
657-
}
651+
return false;
652+
};
653+
654+
auto success = dispatch();
655+
656+
TM_CHECK(success) << "unsupported moe config: expert_num=" << experts << ", top_k=" << experts_per_token
657+
<< ", softmax=" << softmax << ", norm_topk=" << norm_topk;
658 658

659 659
{
660 660
constexpr int threads = (1 << base_log_tile) / kMoeGateVecSize;

0 commit comments

Comments
 (0)