
Commit 46dd988

Authored by ixlmar
[https://nvbugs/5661877][fix] fix test regression in TestBatchedSampling::test_samples (#9215)
Signed-off-by: ixlmar <[email protected]>
1 parent 0f77fec commit 46dd988

File tree

3 files changed: +56 additions, -66 deletions


tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 28 additions & 6 deletions
@@ -278,11 +278,14 @@ def _group_requests_by_strategy_key(
     group_dict: dict[tuple[GenericStrategyKeyType, bool], tuple[list[int], list[Strategy]]] = (
         defaultdict(lambda: ([], []))
     )
+
     for req_index, req in enumerate(requests):
         strategy = _request_strategy(req, vocab_size=vocab_size)
-        # In the overlap path, py_draft_logits is not updated yet,
-        # so we use get_draft_token_length() for the checking.
-        speculation_needs_probs = get_draft_token_length(req) > 0 and strategy is not GREEDY
+        speculation_needs_probs = (
+            # NB: This criterion needs to be consistent with the gating of rejection sampling in
+            # process_draft_tokens.
+            TorchSampler._speculation_could_use_rejection_sampling(req, strategy)
+        )
         strategy_key = strategy_to_key(strategy, speculation_needs_probs)
         group_dict_entry = group_dict[(strategy_key, speculation_needs_probs)]
         group_dict_entry[0].append(req_index)
@@ -1026,6 +1029,17 @@ def _process_draft_tokens_rejection_sampling(
 
         return num_accepted
 
+    @staticmethod
+    def _speculation_could_use_rejection_sampling(
+        request: LlmRequest, strategy: Optional[Strategy] = None
+    ) -> bool:
+        if strategy is None:
+            strategy = _request_strategy(
+                request,
+                vocab_size=2**31,  # vocab_size does not affect greediness
+            )
+        return get_draft_token_length(request) > 0 and strategy != GREEDY
+
     def process_draft_tokens(
         self,
         request: LlmRequest,
@@ -1034,9 +1048,17 @@ def process_draft_tokens(
         finish_reasons: FinishReasonsList,
         resource_manager: Optional[ResourceManager] = None,
     ) -> int:
-        if (
-            _request_strategy(request, vocab_size=2**31) == GREEDY
-            or request.py_draft_logits is None
+        if not (
+            self._speculation_could_use_rejection_sampling(request)
+            # NB: '_speculation_could_use_rejection_sampling' is called in sample_async, which precludes
+            # inspection of .py_draft_logits, because it is not set yet when the overlap path
+            # is used.
+            #
+            # OTOH, some drafters (e.g. NGram) do not provide draft logits, precluding rejection
+            # sampling. The current solution accepts that .py_target_probs may sometimes be
+            # computed, even though .py_draft_logits may never be set and the target probs
+            # may ultimately not be required.
+            and request.py_draft_logits is not None
        ):
             spec_tree_manager = self.get_spec_tree_manager(resource_manager)
             if spec_tree_manager is not None:
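
The core of the fix is that the check deciding whether target probabilities must be produced (while grouping requests in sample_async) and the check gating rejection sampling (in process_draft_tokens) are now derived from the same predicate, _speculation_could_use_rejection_sampling, with the late py_draft_logits check applied only where the logits can actually be inspected. Below is a minimal, self-contained sketch of that pattern; FakeRequest and the function names are illustrative stand-ins, not the TensorRT-LLM API.

# Minimal sketch of the shared-gating pattern introduced by this commit.
# FakeRequest and the names below are illustrative stand-ins, not the actual
# TensorRT-LLM classes or signatures.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeRequest:
    draft_tokens: list[int] = field(default_factory=list)
    greedy: bool = False
    # May remain None, e.g. for drafters that do not produce logits (NGram-style).
    draft_logits: Optional[list[list[float]]] = None
    target_probs: Optional[list[list[float]]] = None


def could_use_rejection_sampling(req: FakeRequest) -> bool:
    # Single source of truth: non-greedy speculation *may* need rejection sampling,
    # regardless of whether draft logits are available yet.
    return len(req.draft_tokens) > 0 and not req.greedy


def group_and_prepare(requests: list[FakeRequest]) -> None:
    # Early stage (analogous to grouping in sample_async): decide per request
    # whether target probabilities must be produced. Draft logits cannot be
    # inspected here (the overlap path sets them later), so the decision is
    # conservative: probs may be computed even if rejection sampling never runs.
    for req in requests:
        if could_use_rejection_sampling(req):
            req.target_probs = [[1.0]] * len(req.draft_tokens)  # placeholder values


def process_draft_tokens(req: FakeRequest) -> str:
    # Late stage: rejection sampling runs only if the same predicate holds *and*
    # draft logits actually exist; otherwise fall back to greedy-style acceptance.
    if could_use_rejection_sampling(req) and req.draft_logits is not None:
        assert req.target_probs is not None  # guaranteed by the shared predicate
        return "rejection_sampling"
    return "greedy_acceptance"


if __name__ == "__main__":
    req = FakeRequest(draft_tokens=[1, 2, 3], greedy=False, draft_logits=None)
    group_and_prepare([req])
    assert process_draft_tokens(req) == "greedy_acceptance"
    assert req.target_probs is not None  # computed conservatively, possibly unused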

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 1 deletion
@@ -390,7 +390,6 @@ examples/test_multimodal.py::test_llm_multimodal_general[llava-onevision-qwen2-7
 examples/test_multimodal.py::test_llm_multimodal_general[llava-onevision-qwen2-7b-ov-hf-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5655832)
 examples/test_multimodal.py::test_llm_multimodal_general[Qwen2-VL-7B-Instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:4] SKIP (https://nvbugs/5655832)
 disaggregated/test_disaggregated.py::test_disaggregated_mixed[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5661926)
-unittest/_torch/sampler/test_torch_sampler.py::TestBatchedSampling SKIP (https://nvbugs/5661877)
 test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True] SKIP (https://nvbugs/5568836)
 test_e2e.py::test_trtllm_multimodal_benchmark_serving SKIP (https://nvbugs/5647825)
 unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL] SKIP (https://nvbugs/5664904)

tests/unittest/_torch/sampler/test_torch_sampler.py

Lines changed: 28 additions & 59 deletions
@@ -945,15 +945,11 @@ def seq_slot_assignment(
     def mock_requests(
         self,
         sampling_params_list: list[SamplingParams],
-        with_draft_logits: bool,
-        vocab_size: int,
         seq_slot_assignment: tuple[list[int], int],
         draft_lens: list[int],
     ) -> ScheduledRequests:
         return self._build_mock_requests(
             sampling_params_list=sampling_params_list,
-            with_draft_logits=with_draft_logits,
-            vocab_size=vocab_size,
             seq_slot_assignment=seq_slot_assignment,
             draft_lens=draft_lens,
         )
@@ -962,8 +958,6 @@ def _build_mock_requests(
         self,
         sampling_params_list: list[SamplingParams],
         *,
-        with_draft_logits: bool,
-        vocab_size: int,
         seq_slot_assignment: tuple[list[int], int],
         draft_lens: list[int],
     ) -> ScheduledRequests:
@@ -975,21 +969,9 @@ def __init__(
                 self,
                 sampling_params_list: list[SamplingParams],
                 *,
-                with_draft_logits: bool,
                 draft_lens: list[int],
             ):
                 self._sampling_params_list = sampling_params_list
-                self._with_draft_logits = with_draft_logits
-
-                def _attach_draft_logits(req: LlmRequest) -> LlmRequest:
-                    draft_len = len(req.py_draft_tokens)
-                    if draft_len and with_draft_logits:
-                        req.py_draft_logits = torch.testing.make_tensor(  # type: ignore
-                            (draft_len, vocab_size),
-                            dtype=torch.float32,
-                            device="cuda",
-                        )
-                    return req
 
                 # NB:
                 #  - stop words are tested in test_write_finish_reasons
@@ -999,24 +981,22 @@ def _attach_draft_logits(req: LlmRequest) -> LlmRequest:
                 #  - py_return_log_probs is tested elsewhere
                 #  - code paths gated by py_return_context_logits tested in test_select_generated_logits
                 self._gen_requests = [
-                    _attach_draft_logits(
-                        LlmRequest(
-                            request_id=seq_slot,
-                            max_new_tokens=(2 * draft_len),  # not used by tested code
-                            input_tokens=[12],  # not used by tested code
-                            sampling_config=SamplingConfig(sampling_params._get_sampling_config()),
-                            seq_slot=seq_slot,
-                            is_streaming=False,  # not relevant for tested code
-                            draft_tokens=(  # 'len(.py_draft_tokens)' is inspected by get_draft_token_length
-                                torch.testing.make_tensor(
-                                    (draft_len,),
-                                    dtype=torch.int32,
-                                    device="cpu",
-                                ).tolist()
-                                if draft_len
-                                else None
-                            ),
-                        )
+                    LlmRequest(
+                        request_id=seq_slot,
+                        max_new_tokens=(2 * draft_len),  # not used by tested code
+                        input_tokens=[12],  # not used by tested code
+                        sampling_config=SamplingConfig(sampling_params._get_sampling_config()),
+                        seq_slot=seq_slot,
+                        is_streaming=False,  # not relevant for tested code
+                        draft_tokens=(  # 'len(.py_draft_tokens)' is inspected by get_draft_token_length
+                            torch.testing.make_tensor(
+                                (draft_len,),
+                                dtype=torch.int32,
+                                device="cpu",
+                            ).tolist()
+                            if draft_len
+                            else None
+                        ),
                     )
                     for sampling_params, seq_slot, draft_len in zip(
                         sampling_params_list, seq_slots, draft_lens
@@ -1040,9 +1020,7 @@ def all_requests(self) -> list[LlmRequest]:
         with torch.inference_mode(True):
             return cast(
                 ScheduledRequests,
-                ScheduledRequestsMock(
-                    sampling_params_list, with_draft_logits=with_draft_logits, draft_lens=draft_lens
-                ),
+                ScheduledRequestsMock(sampling_params_list, draft_lens=draft_lens),
             )
 
     @pytest.fixture(scope="function")
@@ -1184,20 +1162,17 @@ def test_backend_selection(
             "max_draft_len",
             "draft_lens",
             "sampling_params_list",
-            "with_draft_logits",
             "params_label",
             "allow_zero_draft_len",
             "vocab_size",
         ),
         [
-            # NB: with_draft_logits=True and non-zero draft len ensures that
-            # LlmRequest.py_target_probs is set.
+            # NB: non-zero draft len ensures that LlmRequest.py_target_probs is set.
             pytest.param(
                 use_flashinfer,
                 3,
                 [3] * len(sampling_params_list),
                 sampling_params_list,
-                True,
                 params_label,
                 False,
                 vocab_size,
@@ -1225,7 +1200,6 @@ def test_probs(
         allow_zero_draft_len: bool,  # used by fixtures
         sampling_params_list: list[SamplingParams],
         seq_slot_assignment: tuple[list[int], int],
-        with_draft_logits: bool,
     ):
         """Validate probabilities returned by sample_async.
 
@@ -1255,9 +1229,7 @@ def _uut_provider(is_warmup: bool) -> Generator[Callable[[], None], None, None]:
                 # requests.
                 uut_mock_requests = self._build_mock_requests(
                     sampling_params_list=sampling_params_list,
-                    vocab_size=vocab_size,
                     seq_slot_assignment=seq_slot_assignment,
-                    with_draft_logits=with_draft_logits,
                     draft_lens=draft_lens,
                 )
             else:
@@ -1427,11 +1399,8 @@ def _compute_probs(
         )
         mock_requests_with_probs = self._build_mock_requests(
             sampling_params_list=sampling_params_list,
-            vocab_size=vocab_size,
             seq_slot_assignment=seq_slot_assignment,
-            # NB: with_draft_logits=True and non-zero draft len ensures that
-            # LlmRequest.py_target_probs is set.
-            with_draft_logits=True,
+            # NB: non-zero draft len ensures that LlmRequest.py_target_probs is set.
             draft_lens=([draft_len_with_probs] * len(sampling_params_list)),
         )
         # zero-pad logits to draft_len_with_probs
@@ -1818,6 +1787,12 @@ def _validate_token_frequencies(
 
         # Perform G-test (asymptotically approximated by Pearson's chi-square test) to
        # check that sampled tokens are consistent with the expected probs.
+        #
+        # NB: Need to use FP64 to avoid negative test statistic values.
+        test_token_counts_ma = test_token_counts_ma.astype(np.float64)
+        test_expected_counts_ma = test_expected_counts_ma.astype(np.float64)
+        test_expected_counts_ma /= test_expected_counts_ma.sum(axis=-1, keepdims=True)
+        test_expected_counts_ma *= num_samples
         test_result = power_divergence(
             f_obs=test_token_counts_ma,
             f_exp=test_expected_counts_ma,
@@ -1847,7 +1822,6 @@ def _validate_token_frequencies(
             "use_flashinfer",
             "max_draft_len",
             "sampling_params_list",
-            "with_draft_logits",
             "allow_zero_draft_len",
             "bypass_sampling",
             "vocab_size",
@@ -1857,7 +1831,6 @@ def _validate_token_frequencies(
                 use_flashinfer,
                 max_draft_len,
                 sampling_params_list,
-                with_draft_logits,
                 allow_zero_draft_len,
                 # Run full sampling test only for uniform batches, with/without probs, but skip
                 # sampling statistics when varying draft lens etc. to validate batch handling:
@@ -1868,22 +1841,20 @@ def _validate_token_frequencies(
                 id=(
                     f"{'FlashInfer' if use_flashinfer else 'Torch'}"
                     f"-draft_len={0 if allow_zero_draft_len else 1}..{max_draft_len}"
-                    f"-return_probs={with_draft_logits}-{params_label}"
+                    f"-{params_label}"
                 ),
             )
             # https://stackoverflow.com/a/75421799, does not work with nested loops
             for (
                 use_flashinfer,
                 is_mixed,
-                with_draft_logits,
                 max_draft_len,
                 allow_zero_draft_len,
                 _build_test_cases,
                 vocab_size,
             ) in product(
                 [False, True],
                 [False, True],
-                [True, False],
                 [0, 3],
                 [False, True],
                 [_build_test_cases],
@@ -1895,8 +1866,7 @@ def _validate_token_frequencies(
                 include_uniform=(not is_mixed),
                 include_mixed=is_mixed,
             )
-            if (allow_zero_draft_len or max_draft_len > 0)
-            and (not with_draft_logits or max_draft_len > 0)
+            if allow_zero_draft_len or max_draft_len > 0
         ],
     )
     def test_samples(
@@ -1908,7 +1878,6 @@ def test_samples(
         vocab_size: int,
         sampling_params_list: list[SamplingParams],
         seq_slot_assignment: tuple[list[int], int],
-        with_draft_logits: bool,
         max_draft_len: int,
         use_flashinfer: bool,
         allow_zero_draft_len: bool,  # used by fixtures
@@ -2038,7 +2007,7 @@ def _uut(res=res):
             probs = probs[: (draft_len + 1)]
 
             # check probs are returned only when needed
-            should_return_probs = draft_len and with_draft_logits
+            should_return_probs = bool(draft_len)
             assert (
                 hasattr(req, "py_target_probs") and req.py_target_probs is not None
             ) == should_return_probs
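
The hunk in _validate_token_frequencies above also hardens the statistical check: counts are cast to FP64 and the expected counts are renormalized to sum to num_samples before the G-test is run via scipy.stats.power_divergence. The following is a standalone sketch of that computation; the counts and probabilities are made up for illustration and are not taken from the test.

# Sketch of the numeric fix applied before the G-test (made-up data, standalone).
import numpy as np
from scipy.stats import power_divergence

rng = np.random.default_rng(0)

num_samples = 50_000
expected_probs = np.array([0.7, 0.2, 0.05, 0.05])

# Observed token counts (here simulated from the expected distribution).
observed_counts = rng.multinomial(num_samples, expected_probs).astype(np.float64)

# Expected counts in FP64, renormalized so they sum exactly to num_samples.
# Working in lower precision can leave the observed and expected totals slightly
# inconsistent and, per the commit, even yield a negative test statistic; recent
# SciPy versions also check that the two totals agree.
expected_counts = expected_probs.astype(np.float64)
expected_counts /= expected_counts.sum(axis=-1, keepdims=True)
expected_counts *= num_samples

# lambda_="log-likelihood" selects the G-test member of the power-divergence family.
result = power_divergence(
    f_obs=observed_counts,
    f_exp=expected_counts,
    lambda_="log-likelihood",
)
assert result.statistic >= 0.0
print(f"G = {result.statistic:.3f}, p = {result.pvalue:.3f}")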
