@@ -15,6 +15,60 @@ def get_scalar_type(num_bits: int, has_zp: bool):
     return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
 
 
+def moe_wna16_marlin_gemm(
+    a: torch.Tensor,
+    c_or_none: Optional[torch.Tensor],
+    b_q_weight: torch.Tensor,
+    b_scales: torch.Tensor,
+    b_zeros_or_none: Optional[torch.Tensor],
+    g_idx_or_none: Optional[torch.Tensor],
+    perm_or_none: Optional[torch.Tensor],
+    workspace: torch.Tensor,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor,
+    topk_weights: torch.Tensor,
+    moe_block_size: int,
+    top_k: int,
+    mul_topk_weights: bool,
+    is_ep: bool,
+    b_q_type_id: int,
+    size_m: int,
+    size_n: int,
+    size_k: int,
+    is_k_full: bool,
+    use_atomic_add: bool,
+    use_fp32_reduce: bool,
+    is_zp_float: bool,
+):
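+    """Thin Python wrapper around the sgl_kernel Marlin WNA16 MoE GEMM custom op.
+
+    Dispatches to torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default, forwarding
+    the tensor arguments positionally and the remaining scalar options (block
+    size, top-k handling, problem sizes, and kernel flags) as keyword arguments.
+    """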
+    return torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
+        a,
+        c_or_none,
+        b_q_weight,
+        b_scales,
+        b_zeros_or_none,
+        g_idx_or_none,
+        perm_or_none,
+        workspace,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        topk_weights,
+        moe_block_size=moe_block_size,
+        top_k=top_k,
+        mul_topk_weights=mul_topk_weights,
+        is_ep=is_ep,
+        b_q_type_id=b_q_type_id,
+        size_m=size_m,
+        size_n=size_n,
+        size_k=size_k,
+        is_k_full=is_k_full,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+        is_zp_float=is_zp_float,
+    )
+
+
 def fused_marlin_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,