Commit 447e521

Better conform to TRT-LLM style guidelines and add an LLM API argument to explicitly enable the async worker for testing purposes
Signed-off-by: Dan Hansen <[email protected]>
1 parent 6e31c1c commit 447e521

File tree

6 files changed: +34 -24 lines changed


tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 6 additions & 5 deletions
@@ -815,7 +815,7 @@ def create_py_executor_instance(
 def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
                               max_batch_size: int,
                               speculative_config: SpeculativeConfig,
-                              max_beam_width: int, use_async_worker: bool):
+                              max_beam_width: int, enable_async_worker: bool):
     max_num_sequences = max_batch_size * mapping.pp_size
     max_draft_len = (0 if speculative_config is None else
                      speculative_config.max_draft_len)
@@ -828,7 +828,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
         max_total_draft_tokens=max_total_draft_tokens,
         max_num_sequences=max_num_sequences,
         max_beam_width=max_beam_width,
-        use_async_worker=use_async_worker,
+        enable_async_worker=enable_async_worker,
     )


@@ -839,15 +839,16 @@ def instantiate_sampler(engine: PyTorchModelEngine,
                         speculative_config: SpeculativeConfig,
                         decoding_config: trtllm.DecodingConfig,
                         kv_cache_config: KvCacheConfig):
-    use_async_worker = confidential_compute_enabled()
+    enable_async_worker = (confidential_compute_enabled() or
+                           pytorch_backend_config.sampler_enable_async_worker)

     sampler_args = create_torch_sampler_args(
         mapping,
         max_seq_len=engine.max_seq_len,
         max_batch_size=max_batch_size,
         speculative_config=speculative_config,
         max_beam_width=max_beam_width,
-        use_async_worker=use_async_worker)
+        enable_async_worker=enable_async_worker)
     decoding_mode = get_decoding_mode(decoding_config=decoding_config,
                                       max_beam_width=max_beam_width)
     if mapping.cp_config.get('cp_type') == CpType.STAR:
@@ -874,7 +875,7 @@ def instantiate_sampler(engine: PyTorchModelEngine,
             max_beam_width=max_beam_width,
             decoding_config=decoding_config,
             kv_cache_config=kv_cache_config,
-            use_async_worker=use_async_worker)
+            enable_async_worker=enable_async_worker)
     if not engine.model.model_config.is_generation:
         # NOTE: choose sampler based on model type
         return EarlyStopSampler()

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ class PyTorchConfig:
     The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto.
     Defaults to auto, which will use TorchSampler unless BeamSearch is requested.
     """
+    sampler_enable_async_worker: bool = False

     kv_cache_dtype: str = "auto"
     mamba_ssm_cache_dtype: str = "auto"

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 2 additions & 2 deletions
@@ -52,7 +52,7 @@
                           LlmResponse, get_draft_token_length)
 from .model_engine import ModelEngine
 from .resource_manager import ResourceManager
-from .sampler import Sampler, SampleState, SampleStateTensors
+from .sampler import AsyncWorkerMixin, Sampler, SampleState, SampleStateTensors
 from .scheduler import RequestScheduler, ScheduledRequests

 # Environment variable to specify iteration ranges for profiling start/stop.
@@ -479,7 +479,7 @@ def shutdown(self):
         if self.draft_model_engine is not None:
             del self.draft_model_engine
         # Stop the sampler's async worker, if it was used
-        if hasattr(self.sampler, 'async_worker_stop'):
+        if isinstance(self.sampler, AsyncWorkerMixin):
             self.sampler.async_worker_stop()

     def can_enqueue_requests(self) -> bool:

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 2 additions & 1 deletion
@@ -38,6 +38,7 @@
 from .kv_cache_connector import KvCacheConnectorManager
 from .model_engine import PyTorchModelEngine
 from .py_executor import PyExecutor
+from .sampler import AsyncWorkerMixin


 class _ExecutorCreationStage(enum.Enum):
@@ -181,7 +182,7 @@ def update_sampler_max_seq_len(max_seq_len, sampler):


 def maybe_start_sampler_async_worker(sampler):
-    if hasattr(sampler, 'async_worker_start') and sampler.use_async_worker:
+    if isinstance(sampler, AsyncWorkerMixin) and sampler.enable_async_worker:
         sampler.async_worker_start()

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 17 additions & 16 deletions
@@ -17,7 +17,7 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from collections.abc import Iterable
-from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent import futures
 from dataclasses import dataclass
 from itertools import repeat
 from typing import Any, Callable, List, Optional, TypeVar, cast
@@ -82,7 +82,7 @@

 @dataclass(kw_only=True)
 class SampleStateTensors:
-    new_tokens: torch.Tensor | Future[torch.Tensor]
+    new_tokens: torch.Tensor | futures.Future[torch.Tensor]
     log_probs: torch.Tensor | None = None

     def values(self):
@@ -574,12 +574,12 @@ class AsyncWorkerMixin:
     def _async_worker_active(self) -> bool:
         return self._async_worker is not None

-    def _async_worker_init(self, use_async_worker: bool):
-        self.use_async_worker = use_async_worker
+    def _async_worker_init(self, enable_async_worker: bool):
+        self.enable_async_worker = enable_async_worker
         self._async_worker = None

     def async_worker_start(self):
-        assert self.use_async_worker
+        assert self.enable_async_worker
         assert not self._async_worker_active()

         def _async_worker_initializer(device_id):
@@ -590,7 +590,7 @@ def _async_worker_initializer(device_id):
             # blocking copies from gating subsequent async work
             torch.cuda.set_stream(torch.cuda.Stream())

-        self._async_worker = ThreadPoolExecutor(
+        self._async_worker = futures.ThreadPoolExecutor(
             max_workers=1,
             initializer=_async_worker_initializer,
             initargs=(torch.cuda.current_device(),),
@@ -653,7 +653,7 @@ class Args:
         max_num_sequences: int
         max_beam_width: int
         max_total_draft_tokens: int
-        use_async_worker: Optional[bool] = False
+        enable_async_worker: Optional[bool] = False

     def __init__(self, args: Args):
         self.max_seq_len = args.max_seq_len
@@ -674,7 +674,7 @@ def __init__(self, args: Args):
         self._global_seed = 42
         self._generator = None

-        self._async_worker_init(args.use_async_worker)
+        self._async_worker_init(args.enable_async_worker)

     def get_generator(self, device: torch.device) -> torch.Generator:
         """Get a deterministic generator for the specified device.
@@ -755,9 +755,8 @@ def handle_logprobs(
         if request.py_return_log_probs:
             if self._async_worker_active():
                 # These should be futures if we used the async worker
-                assert isinstance(request.py_topk_logprobs_values, Future) and isinstance(
-                    request.py_topk_logprobs_vals, Future
-                )
+                assert isinstance(request.py_topk_logprobs_values, futures.Future)
+                assert isinstance(request.py_topk_logprobs_vals, futures.Future)
                 topk_log_probs_vals = request.py_topk_logprobs_vals.result()
                 topk_log_probs_indices = request.py_topk_logprobs_indices.result()
             else:
@@ -1079,7 +1078,7 @@ def update_requests(
         assert state.host is not None

         if self._async_worker_active():
-            assert isinstance(state.host.new_tokens, Future)
+            assert isinstance(state.host.new_tokens, futures.Future)
             new_tokens = state.host.new_tokens.result()
         else:
             new_tokens = state.host.new_tokens
@@ -1686,7 +1685,9 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
 class SampleStateTRTLLM(SampleState):
     finalize_events: dict[str, CudaEvent] | None = None
     """`Optional` to accommodate `_forward_step_inter_pp` which creates a `SampleState` without `finalize_events`"""
-    host: Optional[SampleStateTensorsHostTRTLLM | Future[SampleStateTensorsHostTRTLLM]] = None
+    host: Optional[SampleStateTensorsHostTRTLLM | futures.Future[SampleStateTensorsHostTRTLLM]] = (
+        None
+    )


 class TRTLLMSampler(Sampler, AsyncWorkerMixin):
@@ -1709,7 +1710,7 @@ def __init__(
         max_beam_width: int,
         decoding_config: Optional[DecodingConfig] = None,
         kv_cache_config: Optional[KvCacheConfig] = None,
-        use_async_worker: Optional[bool] = False,
+        enable_async_worker: Optional[bool] = False,
     ):
         vocab_size = model.config.vocab_size
         num_hidden_layers = model.config.num_hidden_layers
@@ -1760,7 +1761,7 @@ def __init__(
         self._initialize_store()
         self._instantiate_algorithms()

-        self._async_worker_init(use_async_worker)
+        self._async_worker_init(enable_async_worker)

     def _initialize_store(self):
         torch_stream = torch.cuda.current_stream().cuda_stream
@@ -1983,7 +1984,7 @@ def update_requests(

         if self._async_worker_active():
             # Wait for and "unpack" the host tensors
-            assert isinstance(state.host, Future)
+            assert isinstance(state.host, futures.Future)
             state.host = state.host.result()

         beam_width = self.beam_width(state.scheduled_requests.all_requests())
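
For readers following the futures-based flow above, here is a minimal, self-contained sketch (not the TRT-LLM implementation) of the pattern the AsyncWorkerMixin changes imply: a single-thread executor started on demand, work submitted to it off the critical path, and consumers resolving concurrent.futures.Future results. The AsyncCopyWorker class and its copy_to_host helper are illustrative stand-ins; the real mixin additionally pins the CUDA device and installs a dedicated copy stream in the worker's initializer, as shown in the diff.

# Simplified sketch only; mirrors the interface names from the diff, not its internals.
from concurrent import futures


class AsyncCopyWorker:
    """Toy stand-in for AsyncWorkerMixin: one worker thread, results exposed as futures."""

    def __init__(self, enable_async_worker: bool):
        self.enable_async_worker = enable_async_worker
        self._async_worker = None

    def _async_worker_active(self) -> bool:
        return self._async_worker is not None

    def async_worker_start(self):
        assert self.enable_async_worker and not self._async_worker_active()
        # A real implementation would pass an initializer that sets the CUDA
        # device and a dedicated copy stream, as the diff does.
        self._async_worker = futures.ThreadPoolExecutor(max_workers=1)

    def async_worker_stop(self):
        if self._async_worker_active():
            self._async_worker.shutdown(wait=True)
            self._async_worker = None

    def copy_to_host(self, data):
        # Placeholder for a device-to-host copy; returns a Future when the worker is active.
        if self._async_worker_active():
            return self._async_worker.submit(list, data)
        return list(data)


worker = AsyncCopyWorker(enable_async_worker=True)
worker.async_worker_start()
result = worker.copy_to_host(range(4))
print(result.result() if isinstance(result, futures.Future) else result)  # [0, 1, 2, 3]
worker.async_worker_stop()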

tensorrt_llm/llmapi/llm_args.py

Lines changed: 6 additions & 0 deletions
@@ -2449,6 +2449,11 @@ class TorchLlmArgs(BaseLlmArgs):
         "The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto. Defaults to auto, which will use TorchSampler unless BeamSearch is requested.",
         status="beta")

+    sampler_enable_async_worker: bool = Field(
+        default=False,
+        description="Enable the async worker in the sampler for D->H copies",
+        status="beta")
+
     enable_iter_perf_stats: bool = Field(
         default=False,
         description="Enable iteration performance statistics.",
@@ -2822,6 +2827,7 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             use_low_precision_moe_combine=self.moe_config.
             use_low_precision_moe_combine,
             sampler_type=self.sampler_type,
+            sampler_enable_async_worker=self.sampler_enable_async_worker,
             kv_cache_dtype=self.kv_cache_config.dtype,
             mamba_ssm_cache_dtype=self.kv_cache_config.mamba_ssm_cache_dtype,
             enable_iter_perf_stats=self.enable_iter_perf_stats,
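
As a usage note, a hedged sketch of how the new flag could be exercised through the LLM API for testing: this assumes sampler_enable_async_worker is passed as a constructor keyword like other TorchLlmArgs fields, and the model path below is a placeholder.

# Illustrative only: assumes the TorchLlmArgs field is accepted as an LLM kwarg.
from tensorrt_llm import LLM

llm = LLM(
    model="/path/to/model",            # placeholder checkpoint path
    sampler_enable_async_worker=True,  # explicitly enable the sampler's async D->H worker
)
outputs = llm.generate(["Hello, world"])
print(outputs[0].outputs[0].text)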
