Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion csrc/deep_ep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1765,5 +1765,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer);

m.def("is_sm90_compiled", deep_ep::is_sm90_compiled);
m.attr("topk_idx_t") = py::cast(c10::CppTypeToScalarType<deep_ep::topk_idx_t>::value);
m.attr("topk_idx_t") =
py::reinterpret_borrow<py::object>((PyObject*)torch::getTHPDtype(c10::CppTypeToScalarType<deep_ep::topk_idx_t>::value));
}
6 changes: 3 additions & 3 deletions tests/test_low_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import torch
import torch.distributed as dist
from functools import partial
from typing import Literal
from typing import Literal, Set

import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back


def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: set[int]):
def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: Set[int]):
# Simulates rank failure when the rank first calls the corresponding communication API
failed_api_ranks = {
# API -> rank to fail (rank fails when it first calls the corresponding communication API)
Expand All @@ -29,7 +29,7 @@ def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "cl


def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], buffer: deep_ep.Buffer, mask_status: torch.Tensor,
expected_masked_ranks: set[int]):
expected_masked_ranks: Set[int]):
buffer.low_latency_query_mask_buffer(mask_status)
assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks

Expand Down