@@ -612,49 +612,49 @@ void invokeMoeGate_V2(int* f2n, // [e*n] -> n
612612 softmax,
613613 norm_topk,
614614 routed_scale);
615- };
616615
617- auto fail = [&] {
618- std::cerr << __FILE__ << " (" << __LINE__ << " ): unsupported moe config: expert_num=" << experts
619- << " , top_k=" << experts_per_token << " , softmax=" << softmax << " , norm_topk=" << norm_topk << " \n " ;
620- std::abort ();
616+ return true ;
621617 };
622618
623619 if (!softmax && norm_topk) {
624620 // norm top-k is part of softmax impl
625- fail () ;
621+ TM_CHECK ( 0 ) << softmax << " " << norm_topk ;
626622 }
627623
628- if (experts <= 8 ) {
629- if (experts_per_token <= 2 ) {
630- invoke (_Int<8 >, _Int<2 >, _Int<8 >, _Int<4 >);
631- }
632- else {
633- invoke (_Int<8 >, _Int<8 >, _Int<8 >, _Int<4 >);
634- }
635- }
636- else if (experts <= 64 ) {
637- if (experts_per_token <= 4 ) {
638- invoke (_Int<64 >, _Int<4 >, _Int<16 >, _Int<4 >);
639- }
640- else if (experts_per_token <= 8 ) {
641- invoke (_Int<64 >, _Int<8 >, _Int<16 >, _Int<4 >);
624+ auto dispatch = [&] {
625+ if (experts <= 8 ) {
626+ if (experts_per_token <= 2 ) {
627+ return invoke (_Int<8 >, _Int<2 >, _Int<8 >, _Int<4 >);
628+ }
629+ else {
630+ return invoke (_Int<8 >, _Int<8 >, _Int<8 >, _Int<4 >);
631+ }
642632 }
643- else {
644- fail ();
633+ else if (experts <= 64 ) {
634+ if (experts_per_token <= 4 ) {
635+ return invoke (_Int<64 >, _Int<4 >, _Int<16 >, _Int<4 >);
636+ }
637+ else if (experts_per_token <= 8 ) {
638+ return invoke (_Int<64 >, _Int<8 >, _Int<16 >, _Int<4 >);
639+ }
645640 }
646- }
647- else if (experts <= 160 ) {
648- if (experts_per_token <= 8 ) {
649- invoke (_Int< 160 >, _Int< 8 >, _Int< 10 >, _Int< 2 >);
641+ else if (experts <= 128 ) {
642+ if (experts_per_token <= 8 ) {
643+ return invoke (_Int< 128 >, _Int< 8 >, _Int< 16 >, _Int< 4 >);
644+ }
650645 }
651- else {
652- fail ();
646+ else if (experts <= 160 ) {
647+ if (experts_per_token <= 8 ) {
648+ return invoke (_Int<160 >, _Int<8 >, _Int<10 >, _Int<2 >);
649+ }
653650 }
654- }
655- else {
656- fail ();
657- }
651+ return false ;
652+ };
653+
654+ auto success = dispatch ();
655+
656+ TM_CHECK (success) << " unsupported moe config: expert_num=" << experts << " , top_k=" << experts_per_token
657+ << " , softmax=" << softmax << " , norm_topk=" << norm_topk;
658658
659659 {
660660 constexpr int threads = (1 << base_log_tile) / kMoeGateVecSize ;
0 commit comments