Skip to content

Commit e5f39ec

Browse files
authored
[TRTLLM-9488][feat] add 'disable_flashinfer_sampling' config option (#9454)
Signed-off-by: ixlmar <[email protected]>
1 parent 930cdad commit e5f39ec

File tree

6 files changed

+18
-10
lines changed

6 files changed

+18
-10
lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
828828
max_batch_size: int,
829829
speculative_config: SpeculativeConfig,
830830
max_beam_width: int,
831-
disable_flash_infer_sampling: bool):
831+
disable_flashinfer_sampling: bool):
832832
max_num_sequences = max_batch_size * mapping.pp_size
833833
max_draft_len = (0 if speculative_config is None else
834834
speculative_config.max_draft_len)
@@ -841,7 +841,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
841841
max_total_draft_tokens=max_total_draft_tokens,
842842
max_num_sequences=max_num_sequences,
843843
max_beam_width=max_beam_width,
844-
disable_flash_infer_sampling=disable_flash_infer_sampling,
844+
disable_flashinfer_sampling=disable_flashinfer_sampling,
845845
)
846846

847847

@@ -857,15 +857,15 @@ def instantiate_sampler(
857857
speculative_config: SpeculativeConfig,
858858
decoding_config: trtllm.DecodingConfig,
859859
kv_cache_config: KvCacheConfig,
860-
disable_flash_infer_sampling: bool,
860+
disable_flashinfer_sampling: bool,
861861
):
862862
sampler_args = create_torch_sampler_args(
863863
mapping,
864864
max_seq_len=engine.max_seq_len,
865865
max_batch_size=max_batch_size,
866866
speculative_config=speculative_config,
867867
max_beam_width=max_beam_width,
868-
disable_flash_infer_sampling=disable_flash_infer_sampling,
868+
disable_flashinfer_sampling=disable_flashinfer_sampling,
869869
)
870870
decoding_mode = get_decoding_mode(decoding_config=decoding_config,
871871
max_beam_width=max_beam_width)

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def drafting_loop_wrapper(model):
533533
speculative_config=spec_config,
534534
decoding_config=decoding_config,
535535
kv_cache_config=kv_cache_config,
536-
disable_flash_infer_sampling=llm_args._disable_flash_infer_sampling,
536+
disable_flashinfer_sampling=llm_args.disable_flashinfer_sampling,
537537
)
538538
logger.info(f"Using Sampler: {type(sampler).__name__}")
539539

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,7 @@ class Args:
616616
max_num_sequences: int
617617
max_beam_width: int
618618
max_total_draft_tokens: int
619-
disable_flash_infer_sampling: bool = False
619+
disable_flashinfer_sampling: bool = False
620620

621621
def __init__(self, args: Args):
622622
self.max_seq_len = args.max_seq_len
@@ -652,7 +652,7 @@ def __init__(self, args: Args):
652652
}
653653

654654
self._grouped_sampler_cls: Type[GroupedStrategySampler]
655-
if IS_FLASHINFER_AVAILABLE and not args.disable_flash_infer_sampling:
655+
if IS_FLASHINFER_AVAILABLE and not args.disable_flashinfer_sampling:
656656
from .sampling_utils_flashinfer import FlashInferGroupedStrategySampler
657657

658658
self._grouped_sampler_cls = FlashInferGroupedStrategySampler

tensorrt_llm/llmapi/llm_args.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2707,8 +2707,12 @@ class TorchLlmArgs(BaseLlmArgs):
27072707
# PrivateVars
27082708
_quant_config: Optional[QuantConfig] = PrivateAttr(default=None)
27092709

2710-
_disable_flash_infer_sampling: bool = PrivateAttr(default=True)
2711-
"""Unless this is set to False, FlashInfer.sampling is not used, even if available."""
2710+
disable_flashinfer_sampling: bool = Field(
2711+
default=True,
2712+
description=
2713+
"Disable the use of FlashInfer.sampling. This option is likely to be removed in the future.",
2714+
status="prototype",
2715+
)
27122716

27132717
@property
27142718
def quant_config(self) -> QuantConfig:

tests/unittest/_torch/sampler/test_torch_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,7 @@ def _build_sampler(
10771077
max_beam_width=1, # currently the only supported value
10781078
max_num_sequences=num_seq_slots,
10791079
max_total_draft_tokens=max_draft_len,
1080-
disable_flash_infer_sampling=(not use_flashinfer),
1080+
disable_flashinfer_sampling=(not use_flashinfer),
10811081
)
10821082
)
10831083

tests/unittest/api_stability/references/llm.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ methods:
107107
annotation: bool
108108
default: False
109109
status: beta
110+
disable_flashinfer_sampling:
111+
annotation: bool
112+
default: True
113+
status: prototype
110114
moe_config:
111115
annotation: tensorrt_llm.llmapi.llm_args.MoeConfig
112116
status: beta

0 commit comments

Comments
 (0)