
Commit 77a92be

fix: minor updates and fixes for unit_tests to match current code (#1083)
1 parent 687d37c commit 77a92be

5 files changed: +27 -15 lines changed


unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py

Lines changed: 0 additions & 5 deletions

@@ -13,15 +13,10 @@ def test_gen_decode_params_basic():
         b_kv_seq_len,
         b1_cu_kv_seq_len,
         position_ids,
-        max_q_seq_len,
-        max_kv_seq_len,
     ) = gen_decode_params(b_seq_len)
 
     true_b_q_seq_len = torch.ones_like(b_seq_len)
-    b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids, max_q_seq_len, max_kv_seq_len
 
-    assert max_q_seq_len == 1
-    assert max_kv_seq_len == b_seq_len.max().item()
     assert torch.equal(b_q_seq_len, true_b_q_seq_len)
     assert torch.equal(b1_cu_q_seq_len, torch.nn.functional.pad(torch.cumsum(true_b_q_seq_len, dim=0), (1, 0), value=0))
     assert torch.equal(b_kv_seq_len, b_seq_len)
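
After this change gen_decode_params returns five values rather than seven, and the two max lengths are no longer part of its contract. A minimal sketch of the updated call site; the import path is an assumption mirroring the test file's directory layout, and the tensor values are illustrative:

import torch
# Assumed import path, inferred from the test's location under unit_tests/.
from lightllm.common.basemodel.triton_kernel.gen_decode_params import gen_decode_params

b_seq_len = torch.tensor([3, 7, 5], dtype=torch.int32, device="cuda")
(
    b_q_seq_len,      # all ones: decode emits one query token per request
    b1_cu_q_seq_len,  # cumulative q lengths, padded with a leading zero
    b_kv_seq_len,     # equals b_seq_len for decode
    b1_cu_kv_seq_len,
    position_ids,
) = gen_decode_params(b_seq_len)

# The dropped max_* return values can be recomputed locally when needed:
max_q_seq_len = b_q_seq_len.max().item()    # 1 for pure decode
max_kv_seq_len = b_kv_seq_len.max().item()  # same as b_seq_len.max().item()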

unit_tests/common/basemodel/triton_kernel/test_gen_prefill_params.py

Lines changed: 0 additions & 4 deletions

@@ -20,12 +20,8 @@ def test_gen_prefill_params_basic():
         b_kv_seq_len,
         b1_cu_kv_seq_len,
         position_ids,
-        max_q_seq_len,
-        max_kv_seq_len,
     ) = gen_prefill_params(input_token_num, b_ready_cache_len, b_seq_len)
 
-    assert max_q_seq_len == true_b_q_seq_len.max().item()
-    assert max_kv_seq_len == b_seq_len.max().item()
     assert torch.equal(b_q_seq_len, true_b_q_seq_len)
     assert torch.equal(b1_cu_q_seq_len, torch.nn.functional.pad(torch.cumsum(true_b_q_seq_len, dim=0), (1, 0), value=0))
     assert torch.equal(b_kv_seq_len, b_seq_len)
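
The prefill helper mirrors the decode change: max_q_seq_len and max_kv_seq_len left the return tuple, and per the deleted asserts they equal the maxima of the per-request length tensors. A minimal sketch under the same import-path assumption, with illustrative values:

import torch
# Assumed import path, inferred from the test file's location.
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_prefill_params

b_ready_cache_len = torch.tensor([0, 2], dtype=torch.int32, device="cuda")
b_seq_len = torch.tensor([4, 6], dtype=torch.int32, device="cuda")
input_token_num = int((b_seq_len - b_ready_cache_len).sum().item())  # 8 new tokens

b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids = gen_prefill_params(
    input_token_num, b_ready_cache_len, b_seq_len
)

# Recover the removed values exactly as the deleted asserts defined them:
max_q_seq_len = b_q_seq_len.max().item()
max_kv_seq_len = b_kv_seq_len.max().item()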

unit_tests/common/fused_moe/test_softmax_topk.py

Lines changed: 2 additions & 2 deletions

@@ -18,13 +18,13 @@ def benchmark(M, N, K, renorm, runs):
     sgl_vals = torch.empty((M, K), dtype=torch.float32, device="cuda")
     sgl_ids = torch.empty((M, K), dtype=torch.int32, device="cuda")
     # Warm-up
-    sgl_ops.topk_softmax(sgl_vals, sgl_ids, torch.empty_like(sgl_ids), gating)
+    sgl_ops.topk_softmax(sgl_vals, sgl_ids, gating)
     torch.cuda.synchronize()
     start = torch.cuda.Event(True)
     end = torch.cuda.Event(True)
     start.record()
     for _ in range(runs):
-        sgl_ops.topk_softmax(sgl_vals, sgl_ids, torch.empty_like(sgl_ids), gating)
+        sgl_ops.topk_softmax(sgl_vals, sgl_ids, gating)
     if renorm:
         sgl_vals.div_(sgl_vals.sum(-1, keepdim=True).clamp_min(1e-8))
 

unit_tests/server/core/objs/test_sampling_params.py

Lines changed: 2 additions & 2 deletions

@@ -127,8 +127,8 @@ def test_decode_node_initialization():
     }
     node.initialize(data)
     assert node.exists is True
-    assert node.node_id_high == (12345678901234567890 >> 64) & 0xFFFFFFFFFFFFFFFF
-    assert node.node_id_low == 12345678901234567890 & 0xFFFFFFFFFFFFFFFF
+    assert node.node_id.node_id_high == (12345678901234567890 >> 64) & 0xFFFFFFFFFFFFFFFF
+    assert node.node_id.node_id_low == 12345678901234567890 & 0xFFFFFFFFFFFFFFFF
     assert node.ip[0] == 192
     assert node.ip[1] == 168
     assert node.ip[2] == 1
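
The assertion change only relocates the two 64-bit halves onto a nested node_id field; the masking arithmetic is unchanged. A self-contained sketch of that split, using the same 20-digit id as the test:

NODE_ID = 12345678901234567890
MASK64 = 0xFFFFFFFFFFFFFFFF

node_id_high = (NODE_ID >> 64) & MASK64  # upper 64 bits: 0, since this id fits in 64 bits
node_id_low = NODE_ID & MASK64           # lower 64 bits

# The two halves reassemble losslessly into the original id.
assert (node_id_high << 64) | node_id_low == NODE_ID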

unit_tests/server/core/objs/test_shm_req_manager.py

Lines changed: 23 additions & 2 deletions

@@ -1,13 +1,34 @@
+import os
 import pytest
 import time
-from lightllm.utils.envs_utils import set_env_start_args
+from easydict import EasyDict
+from lightllm.utils.envs_utils import set_env_start_args, get_env_start_args
 from lightllm.server.core.objs.shm_req_manager import ShmReqManager
 
 
 @pytest.fixture(scope="module", autouse=True)
 def setup_env():
-    set_env_start_args({"running_max_req_size": 10, "disable_chunked_prefill": True, "token_healing_mode": False})
+    original = os.environ.get("LIGHTLLM_START_ARGS")
+    set_env_start_args(
+        EasyDict(
+            running_max_req_size=10,
+            disable_chunked_prefill=True,
+            token_healing_mode=False,
+            enable_flashinfer_prefill=False,
+            enable_flashinfer_decode=False,
+        )
+    )
+    # clear the lru_cache if used
+    if hasattr(get_env_start_args, "cache_clear"):
+        get_env_start_args.cache_clear()
+
     yield
+    if original is not None:
+        os.environ["LIGHTLLM_START_ARGS"] = original
+    else:
+        os.environ.pop("LIGHTLLM_START_ARGS", None)
+    if hasattr(get_env_start_args, "cache_clear"):
+        get_env_start_args.cache_clear()
 
 
 @pytest.fixture(scope="module")
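
The expanded fixture does three things the old one did not: it snapshots LIGHTLLM_START_ARGS so the module cannot leak state into other tests, it passes an attribute-style EasyDict (set_env_start_args apparently serializes the args into that env var), and it clears any lru_cache on get_env_start_args so stale values are not served after the env changes. The cache staleness is the subtle part; a self-contained sketch of why cache_clear is needed, with a hypothetical memoized getter standing in for the real one:

import json
import os
from functools import lru_cache

@lru_cache(maxsize=None)
def get_start_args():
    # Memoized stand-in for get_env_start_args: reads the env var only once.
    return json.loads(os.environ.get("LIGHTLLM_START_ARGS", "{}"))

os.environ["LIGHTLLM_START_ARGS"] = json.dumps({"running_max_req_size": 10})
print(get_start_args()["running_max_req_size"])  # 10

os.environ["LIGHTLLM_START_ARGS"] = json.dumps({"running_max_req_size": 99})
print(get_start_args()["running_max_req_size"])  # still 10 -- the cache is stale

get_start_args.cache_clear()  # what the fixture's hasattr guard invokes
print(get_start_args()["running_max_req_size"])  # 99 after clearing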
