
Commit f567f30

feat: 'disable_flashinfer_sampling' config option
Signed-off-by: ixlmar <[email protected]>
1 parent: 6dd2fcd

File tree: 5 files changed (+12, -10 lines)

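For context, a minimal usage sketch of the renamed option (an assumption, not part of this commit: it presumes TorchLlmArgs fields such as this one can be passed as keyword arguments to the PyTorch-backend `LLM` constructor, and the model path is a placeholder):

```python
# Hypothetical usage sketch, not from this commit. The flag defaults to
# True (FlashInfer sampling disabled); set it to False to opt in when
# the flashinfer package is installed.
from tensorrt_llm import LLM

llm = LLM(
    model="/path/to/model",             # placeholder checkpoint path
    disable_flashinfer_sampling=False,  # opt in to FlashInfer.sampling
)
```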

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -824,7 +824,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
                               max_batch_size: int,
                               speculative_config: SpeculativeConfig,
                               max_beam_width: int,
-                              disable_flash_infer_sampling: bool):
+                              disable_flashinfer_sampling: bool):
     max_num_sequences = max_batch_size * mapping.pp_size
     max_draft_len = (0 if speculative_config is None else
                      speculative_config.max_draft_len)
@@ -837,7 +837,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
         max_total_draft_tokens=max_total_draft_tokens,
         max_num_sequences=max_num_sequences,
         max_beam_width=max_beam_width,
-        disable_flash_infer_sampling=disable_flash_infer_sampling,
+        disable_flashinfer_sampling=disable_flashinfer_sampling,
     )


@@ -853,15 +853,15 @@ def instantiate_sampler(
     speculative_config: SpeculativeConfig,
     decoding_config: trtllm.DecodingConfig,
     kv_cache_config: KvCacheConfig,
-    disable_flash_infer_sampling: bool,
+    disable_flashinfer_sampling: bool,
 ):
     sampler_args = create_torch_sampler_args(
         mapping,
         max_seq_len=engine.max_seq_len,
         max_batch_size=max_batch_size,
         speculative_config=speculative_config,
         max_beam_width=max_beam_width,
-        disable_flash_infer_sampling=disable_flash_infer_sampling,
+        disable_flashinfer_sampling=disable_flashinfer_sampling,
     )
     decoding_mode = get_decoding_mode(decoding_config=decoding_config,
                                       max_beam_width=max_beam_width)
```

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -516,7 +516,7 @@ def drafting_loop_wrapper(model):
         speculative_config=spec_config,
         decoding_config=decoding_config,
         kv_cache_config=kv_cache_config,
-        disable_flash_infer_sampling=llm_args._disable_flash_infer_sampling,
+        disable_flashinfer_sampling=llm_args.disable_flashinfer_sampling,
     )
     logger.info(f"Using Sampler: {type(sampler).__name__}")

```
522522

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -615,7 +615,7 @@ class Args:
         max_num_sequences: int
         max_beam_width: int
         max_total_draft_tokens: int
-        disable_flash_infer_sampling: bool = False
+        disable_flashinfer_sampling: bool = False

     def __init__(self, args: Args):
         self.max_seq_len = args.max_seq_len
@@ -651,7 +651,7 @@ def __init__(self, args: Args):
         }

         self._grouped_sampler_cls: Type[GroupedStrategySampler]
-        if IS_FLASHINFER_AVAILABLE and not args.disable_flash_infer_sampling:
+        if IS_FLASHINFER_AVAILABLE and not args.disable_flashinfer_sampling:
             from .sampling_utils_flashinfer import FlashInferGroupedStrategySampler

             self._grouped_sampler_cls = FlashInferGroupedStrategySampler
```
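The second hunk above is where the flag takes effect. As a self-contained illustration of that gating pattern (the sampler classes below are made-up stand-ins, not TensorRT-LLM's):

```python
# Illustrative sketch of the availability-plus-opt-out gate shown in
# the hunk above; both sampler classes here are hypothetical stand-ins.
class TorchGroupedSampler:        # pure-PyTorch fallback (stand-in)
    pass


class FlashInferGroupedSampler:   # FlashInfer-backed variant (stand-in)
    pass


try:
    import flashinfer  # noqa: F401  # probe the optional dependency
    IS_FLASHINFER_AVAILABLE = True
except ImportError:
    IS_FLASHINFER_AVAILABLE = False


def pick_grouped_sampler_cls(disable_flashinfer_sampling: bool) -> type:
    # FlashInfer is used only when it imports cleanly AND the user has
    # not opted out via the (default-True) disable flag.
    if IS_FLASHINFER_AVAILABLE and not disable_flashinfer_sampling:
        return FlashInferGroupedSampler
    return TorchGroupedSampler
```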

tensorrt_llm/llmapi/llm_args.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -2775,8 +2775,10 @@ class TorchLlmArgs(BaseLlmArgs):
     # PrivateVars
     _quant_config: Optional[QuantConfig] = PrivateAttr(default=None)

-    _disable_flash_infer_sampling: bool = PrivateAttr(default=True)
-    """Unless this is set to False, FlashInfer.sampling is not used, even if available."""
+    disable_flashinfer_sampling: bool = Field(
+        default=True,
+        description="Disable the use of FlashInfer.sampling.",
+        status="prototype")

     @property
     def quant_config(self) -> QuantConfig:
```
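Since the flag is now a public pydantic Field on TorchLlmArgs rather than a PrivateAttr, it participates in normal argument validation. A rough sketch (assuming, for illustration, that `model` is the only required field when constructing TorchLlmArgs directly):

```python
# Hypothetical validation round-trip for the new public field; the
# model path is a placeholder, and the assumption is that all other
# required TorchLlmArgs fields have defaults.
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

args = TorchLlmArgs(model="/path/to/model",
                    disable_flashinfer_sampling=False)
assert args.disable_flashinfer_sampling is False
```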

tests/unittest/_torch/sampler/test_torch_sampler.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1077,7 +1077,7 @@ def _build_sampler(
             max_beam_width=1,  # currently the only supported value
             max_num_sequences=num_seq_slots,
             max_total_draft_tokens=max_draft_len,
-            disable_flash_infer_sampling=(not use_flashinfer),
+            disable_flashinfer_sampling=(not use_flashinfer),
         )
     )

```
10831083
