
Commit 326305d
[TRTLLM-6756][chore] Update test cases for beam search sampling after merge
- Modified `model_kwargs` to include `sampler_type` for improved test configuration.
- Adjusted the `llm_cuda_graph` fixture to remove the now-unnecessary `sampler_type` parameter.
- Enhanced clarity in `test_torch_sampler.py` by adding comments regarding the `is_context_init_state` attribute.

Signed-off-by: Stefan Niebler <[email protected]>

Parent: 67593dd

File tree: 3 files changed (+20, −13 lines)

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 9 additions & 6 deletions

```diff
@@ -1093,12 +1093,16 @@ def _process_draft_tokens_tree(
         # Take the longest accepted path as the next new token.
         num_accepted_draft_tokens = 0
         for idx in eagle_paths[longest_match_path_idx][:longest_accepted_len]:
-            add_token(request, new_tokens_list, beam_idx=self.DEFAULT_BEAM_IDX, step=cast(int, idx.item()))
+            add_token(
+                request, new_tokens_list, beam_idx=DEFAULT_BEAM_IDX, step=cast(int, idx.item())
+            )
             num_accepted_draft_tokens += 1
-            if self.finish_if_reason(request,
-                                     finish_reasons,
-                                     step=num_accepted_draft_tokens,
-                                     beam_idx=DEFAULT_BEAM_IDX,):
+            if self.finish_if_reason(
+                request,
+                finish_reasons,
+                step=num_accepted_draft_tokens,
+                beam_idx=DEFAULT_BEAM_IDX,
+            ):
                 break
 
         assert num_accepted_draft_tokens <= longest_accepted_len
@@ -1108,7 +1112,6 @@ def _process_draft_tokens_tree(
 
         return num_accepted_draft_tokens - 1
 
-
     def setup_sampler_step(self, requests: ScheduledRequests):
         """Setup the sampler step for the requests
```
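For context, the reformatted block is an early-exit acceptance loop over the longest matched draft path; the call sites now reference `DEFAULT_BEAM_IDX` without the `self.` prefix, which suggests it is a module-level constant. Below is a minimal, self-contained sketch of that loop pattern; `accept_longest_path`, its parameters, and the value of `DEFAULT_BEAM_IDX` are illustrative stand-ins, not the real TensorRT-LLM helpers.

```python
DEFAULT_BEAM_IDX = 0  # assumed value; the diff only shows the name


def accept_longest_path(path: list[int], max_len: int, is_finished) -> int:
    """Accept draft tokens along `path`, stopping early on a finish reason."""
    num_accepted = 0
    for idx in path[:max_len]:
        # add_token(...) would record token `idx` for DEFAULT_BEAM_IDX here.
        num_accepted += 1
        if is_finished(step=num_accepted, beam_idx=DEFAULT_BEAM_IDX):
            break
    assert num_accepted <= max_len
    # Mirrors `return num_accepted_draft_tokens - 1` in the diff: the last
    # accepted token is the next new token, not a draft token.
    return num_accepted - 1


# Example: a finish reason at step 2 stops acceptance after two tokens.
accepted = accept_longest_path([5, 9, 2], 3, lambda step, beam_idx: step == 2)
assert accepted == 1
```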
tests/unittest/_torch/sampler/test_beam_search.py

Lines changed: 7 additions & 5 deletions

```diff
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import pathlib as _pl
+from typing import Any
 
 import pytest
 import torch
@@ -49,7 +50,7 @@ def sampler_type(request):
 
 
 @pytest.fixture(scope="module")
-def model_kwargs(fixed_params) -> dict[str, Any]:
+def model_kwargs(fixed_params, sampler_type) -> dict[str, Any]:
     assert fixed_params[
         "max_beam_width"] == 2, "This test only works for a beam width of 2"
     return dict(
@@ -58,6 +59,7 @@ def model_kwargs(fixed_params) -> dict[str, Any]:
             weight_loader=DummyWeightLoader(),
             config_loader=DummyConfigLoader(),
         ),
+        sampler_type=sampler_type,
     )
 
 
@@ -72,19 +74,18 @@ def _build_llm(fixed_params, input_prompts, model_kwargs):
         max_beam_width=fixed_params["max_beam_width"],
         disable_overlap_scheduler=True,
         cuda_graph_config=None,
-        sampler_type=sampler_type,
     )
 
 
 @pytest.fixture(scope="module")
 def llm(fixed_params, input_prompts, model_kwargs):
-    return _build_llm(fixed_params, input_prompts, model_kwargs)
+    llm = _build_llm(fixed_params, input_prompts, model_kwargs)
     yield llm
     llm.shutdown()
 
 
 @pytest.fixture(scope="module")
-def llm_cuda_graph(fixed_params, input_prompts, sampler_type, model_kwargs):
+def llm_cuda_graph(fixed_params, input_prompts, model_kwargs):
     llm = LLM(
         **model_kwargs,
         kv_cache_config=KvCacheConfig(max_tokens=10000),
@@ -96,7 +97,6 @@ def llm_cuda_graph(fixed_params, input_prompts, sampler_type, model_kwargs):
         disable_overlap_scheduler=False,
         cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8],
                                           enable_padding=True),
-        sampler_type=sampler_type,
     )
     yield llm
     llm.shutdown()
@@ -327,7 +327,9 @@ def test_beam_search_e2e_cuda_graph_and_overlap(
                              sampling_params)
 
 
+###########################################################################
 # Unit tests
+###########################################################################
 class GeneralTestParams:
     # Test Parameters for the update_beam_history and finish_beams tests
     beam_width = 3
```
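The net effect of these test changes: `sampler_type` is injected once through the `model_kwargs` fixture instead of being passed to each LLM constructor, and the `llm` fixture yields (rather than returns) so `shutdown()` runs at module teardown. A minimal, self-contained sketch of that fixture wiring follows; `FakeLLM` and the parametrization values are assumptions for illustration, not the real `LLM` class or the real `sampler_type` fixture.

```python
import pytest


class FakeLLM:
    """Stand-in for the real LLM; records kwargs and supports shutdown."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def shutdown(self):
        pass


@pytest.fixture(scope="module", params=["TorchSampler", "TRTLLMSampler"])
def sampler_type(request):
    # Assumed parametrization; the real fixture lives in test_beam_search.py.
    return request.param


@pytest.fixture(scope="module")
def model_kwargs(sampler_type):
    # sampler_type flows in here once, so every LLM built from these kwargs
    # inherits it without each downstream fixture passing it again.
    return dict(model="dummy", sampler_type=sampler_type)


@pytest.fixture(scope="module")
def llm(model_kwargs):
    llm = FakeLLM(**model_kwargs)
    yield llm  # yield instead of return so the teardown below actually runs
    llm.shutdown()


def test_llm_has_sampler_type(llm, sampler_type):
    assert llm.kwargs["sampler_type"] == sampler_type
```

This also explains why the old `return _build_llm(...)` was a bug: any code after a `return` in a fixture never executes, so `llm.shutdown()` was unreachable until the commit switched to assign-then-yield.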

tests/unittest/_torch/sampler/test_torch_sampler.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -83,9 +83,11 @@ class TestStrategySelection:
 
     class MockLlmRequest:
         sampling_config: SamplingConfig
-        is_context_init_state: bool  # Not used in this test
+        is_context_init_state: bool  # Torch sampler accesses this, but it does not affect this test
 
-        def get_beam_width_by_iter(self, for_next_iteration: bool) -> int:
+        def get_beam_width_by_iter(
+            self, for_next_iteration: bool
+        ) -> int:  # Torch sampler accesses this, but it does not affect this test
             return self.sampling_config.beam_width
 
     def _check_params(self, params: SamplingParams):
```
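A short sketch of why the mock declares members the test never asserts on: the torch sampler reads `is_context_init_state` and calls `get_beam_width_by_iter`, so the mock must provide both or the sampler would raise `AttributeError`. The types below are simplified stand-ins, not the real `SamplingConfig` or request classes.

```python
from dataclasses import dataclass


@dataclass
class SamplingConfig:  # simplified stand-in for the real config type
    beam_width: int = 1


class MockLlmRequest:
    sampling_config = SamplingConfig(beam_width=2)
    # Read by the sampler during strategy selection; its value is irrelevant
    # to what the test checks, but the attribute must exist.
    is_context_init_state = False

    def get_beam_width_by_iter(self, for_next_iteration: bool) -> int:
        # Called by the sampler; the test only needs it to return the
        # configured beam width.
        return self.sampling_config.beam_width


assert MockLlmRequest().get_beam_width_by_iter(for_next_iteration=True) == 2
```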
