Skip to content

Commit 7391a99

Browse files
SangChengCsangchengmeng
andauthored
fix grouped_topk tl.sort when numel=1 (#1101)
Co-authored-by: sangchengmeng <[email protected]>
1 parent 2d95b73 commit 7391a99

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

lightllm/common/fused_moe/grouped_topk.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,10 @@ def grouped_topk_kernel(
159159
axis=1,
160160
)
161161

162-
sorted_group_value = tl.sort(group_value, descending=True)
162+
if EXPERT_GROUP_NUM > 1:
163+
sorted_group_value = tl.sort(group_value, descending=True)
164+
else:
165+
sorted_group_value = group_value
163166
group_topk_value = tl.sum(tl.where(offs_group == group_topk_num - 1, sorted_group_value, 0.0))
164167
mask_group_scores = tl.where(
165168
((group_value >= group_topk_value)[:, None]) & ((offs_group_v < group_expert_num)[None, :]),

0 commit comments

Comments
 (0)