Skip to content

Commit 8e9f05e

Browse files
authored
Update marlin moe kernel interface (#13322)
1 parent bc08352 commit 8e9f05e

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

sgl-kernel/python/sgl_kernel/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
silu_and_mul,
3535
)
3636
from sgl_kernel.expert_specialization import es_fp8_blockwise_scaled_grouped_mm
37-
from sgl_kernel.fused_moe import fused_marlin_moe
37+
from sgl_kernel.fused_moe import fused_marlin_moe, moe_wna16_marlin_gemm
3838
from sgl_kernel.gemm import (
3939
awq_dequantize,
4040
bmm_fp8,

sgl-kernel/python/sgl_kernel/fused_moe.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,60 @@ def get_scalar_type(num_bits: int, has_zp: bool):
1515
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
1616

1717

18+
def moe_wna16_marlin_gemm(
    a: torch.Tensor,
    c_or_none: Optional[torch.Tensor],
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    b_zeros_or_none: Optional[torch.Tensor],
    g_idx_or_none: Optional[torch.Tensor],
    perm_or_none: Optional[torch.Tensor],
    workspace: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    topk_weights: torch.Tensor,
    moe_block_size: int,
    top_k: int,
    mul_topk_weights: bool,
    is_ep: bool,
    b_q_type_id: int,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    use_atomic_add: bool,
    use_fp32_reduce: bool,
    is_zp_float: bool,
):
    """Invoke the ``sgl_kernel.moe_wna16_marlin_gemm`` custom op.

    This is a thin pass-through wrapper: every argument is forwarded
    unchanged to ``torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default``,
    and the op's result is returned as-is. Tensor arguments are passed
    positionally and scalar/flag arguments by keyword, matching the op's
    registered schema.

    NOTE(review): argument semantics (e.g. that ``b_q_weight``/``b_scales``
    hold Marlin-packed quantized weights, or the exact meaning of
    ``b_q_type_id``) are defined by the native kernel registration, not
    visible here — consult the sgl-kernel op schema before relying on them.
    """
    # Bind the op once; the ``.default`` overload is the only one used here.
    marlin_gemm_op = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default
    return marlin_gemm_op(
        # Tensor inputs (positional, in schema order).
        a,
        c_or_none,
        b_q_weight,
        b_scales,
        b_zeros_or_none,
        g_idx_or_none,
        perm_or_none,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        # Scalar configuration (keyword, mirroring this wrapper's signature).
        moe_block_size=moe_block_size,
        top_k=top_k,
        mul_topk_weights=mul_topk_weights,
        is_ep=is_ep,
        b_q_type_id=b_q_type_id,
        size_m=size_m,
        size_n=size_n,
        size_k=size_k,
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=is_zp_float,
    )
70+
71+
1872
def fused_marlin_moe(
1973
hidden_states: torch.Tensor,
2074
w1: torch.Tensor,

0 commit comments

Comments
 (0)