python/sglang/srt/layers/rotary_embedding.py (8 changes: 2 additions & 6 deletions)

@@ -113,7 +113,7 @@ def __init__(
         if not _is_cuda:
             cache = cache.to(dtype)

-        if dtype == torch.float32 or (
+        if (
             (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512])
             and not (_is_cpu and _is_cpu_amx_available)
             and not (_is_xpu)

@@ -273,11 +273,7 @@ def forward_cuda(
         offsets: Optional[torch.Tensor] = None,
         fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if (
-            _is_cuda
-            and (self.head_size in [64, 128, 256, 512])
-            and self.dtype != torch.float32
-        ):
+        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
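Net effect of the two hunks: the float32 carve-out is removed both from the cache-dtype decision in __init__ and from the kernel dispatch in forward_cuda, so on CUDA with a supported head size the fused sgl-kernel path now also serves float32. A minimal sketch of the resulting dispatch; the keyword arguments after query= and the fallback branch are truncated in the diff above, so they are only indicated here:

    # Sketch of the post-change dispatch in forward_cuda (not a verbatim
    # copy: the diff truncates the remaining kernel arguments and the
    # fallback branch).
    if _is_cuda and self.head_size in [64, 128, 256, 512]:
        # Fused CUDA kernel: now reached for float32 as well as bf16/fp16.
        apply_rope_with_cos_sin_cache_inplace(
            positions=positions,
            query=query,
            # ... remaining arguments as in the file (truncated in the diff)
        )
    else:
        # Non-fused fallback path, unchanged by this PR.
        ...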
test/srt/cpu/test_rope.py (6 changes: 6 additions & 0 deletions)

@@ -146,6 +146,12 @@ def single_test(
         (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8),
         (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4),
         (512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2),
+        (64, 64, 32, 8000, True, torch.float32, "cpu", 32, 32, 1, 1),
+        (256, 128, 4096, 10000, True, torch.float32, "cpu", 2, 512, 32, 8),
+        (512, 128, 311, 10000, True, torch.float32, "cpu", 3, 39, 4, 2),
+        (128, 128, 2048, 10000, False, torch.float32, "cpu", 2, 512, 32, 8),
+        (128, 128, 2048, 10000, False, torch.float32, "cpu", 2, 512, 16, 4),
+        (512, 128, 311, 10000, False, torch.float32, "cpu", 3, 39, 4, 2),
     ]

     for (
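The six added tuples mirror existing bfloat16 cases with dtype=torch.float32, covering both neox and non-neox styles on the CPU path. As a hedged, stand-alone illustration of what one such case exercises, assuming the RotaryEmbedding constructor takes (head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) in the order suggested by these tuples (the actual test harness may construct it differently):

    import torch

    from sglang.srt.layers.rotary_embedding import RotaryEmbedding

    # Hypothetical repro of the (128, 128, 2048, 10000, False, torch.float32,
    # "cpu", ...) case; argument order is assumed from the tuples above.
    rope = RotaryEmbedding(
        head_size=128,
        rotary_dim=128,
        max_position_embeddings=2048,
        base=10000,
        is_neox_style=False,
        dtype=torch.float32,
    )

    num_tokens, num_q_heads, num_kv_heads, head_size = 512, 32, 8, 128
    positions = torch.randint(0, 2048, (num_tokens,))
    query = torch.randn(num_tokens, num_q_heads * head_size, dtype=torch.float32)
    key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=torch.float32)
    # On CPU this takes the AMX path when available, else the native fallback.
    q_out, k_out = rope(positions, query, key)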
test/srt/rotary_embedding/test_mrope.py (2 changes: 1 addition & 1 deletion)

@@ -76,7 +76,7 @@ class MRoPETestInfo(NamedTuple):
     ],
 )
 @pytest.mark.parametrize("tp_size", [1, 2])
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
 @pytest.mark.parametrize("num_tokens", num_tokens_list)
 def test_mrope(
     model_name: str,
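With the widened parametrization, test_mrope now exercises float32 alongside bfloat16 across both tp sizes and every entry in num_tokens_list. The full matrix can be run directly:

    python -m pytest test/srt/rotary_embedding/test_mrope.py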