14 changes: 11 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -8,7 +8,8 @@
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import \
MODEL_CLASS_VISION_ENCODER_MAPPING
from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
from tensorrt_llm._utils import (confidential_compute_enabled,
str_dtype_to_binding, torch_dtype_to_str)
from tensorrt_llm.bindings.executor import DecodingMode
from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig,
EagleDecodingConfig, KvCacheConfig,
@@ -824,7 +825,8 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
max_batch_size: int,
speculative_config: SpeculativeConfig,
max_beam_width: int,
disable_flash_infer_sampling: bool):
disable_flash_infer_sampling: bool,
enable_async_worker: bool):
max_num_sequences = max_batch_size * mapping.pp_size
max_draft_len = (0 if speculative_config is None else
speculative_config.max_draft_len)
@@ -838,6 +840,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
max_num_sequences=max_num_sequences,
max_beam_width=max_beam_width,
disable_flash_infer_sampling=disable_flash_infer_sampling,
enable_async_worker=enable_async_worker,
)


@@ -855,13 +858,17 @@ def instantiate_sampler(
kv_cache_config: KvCacheConfig,
disable_flash_infer_sampling: bool,
):
enable_async_worker = (confidential_compute_enabled()
or llm_args.enable_sampler_async_worker)

sampler_args = create_torch_sampler_args(
mapping,
max_seq_len=engine.max_seq_len,
max_batch_size=max_batch_size,
speculative_config=speculative_config,
max_beam_width=max_beam_width,
disable_flash_infer_sampling=disable_flash_infer_sampling,
enable_async_worker=enable_async_worker,
)
decoding_mode = get_decoding_mode(decoding_config=decoding_config,
max_beam_width=max_beam_width)
@@ -888,7 +895,8 @@ def instantiate_sampler(
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
decoding_config=decoding_config,
kv_cache_config=kv_cache_config)
kv_cache_config=kv_cache_config,
enable_async_worker=enable_async_worker)
if not engine.model.model_config.is_generation:
# NOTE: choose sampler based on model type
return EarlyStopSampler()
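For orientation, a minimal sketch of how the new flag is resolved (illustrative only, not part of the diff; resolve_enable_async_worker is a hypothetical helper name, while confidential_compute_enabled and the enable_sampler_async_worker attribute are taken from the change above):

from tensorrt_llm._utils import confidential_compute_enabled

def resolve_enable_async_worker(llm_args) -> bool:
    # The sampler's async worker is forced on under confidential compute and
    # can otherwise be opted into explicitly via the llm_args flag.
    return bool(confidential_compute_enabled()
                or llm_args.enable_sampler_async_worker)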
13 changes: 11 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -52,7 +52,8 @@
LlmResponse, get_draft_token_length)
from .model_engine import ModelEngine
from .resource_manager import ResourceManager
from .sampler import Sampler, SampleState, SampleStateTensors
from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState,
SampleStateTensors)
from .scheduler import RequestScheduler, ScheduledRequests

# Environment variable to specify iteration ranges for profiling start/stop.
@@ -359,6 +360,10 @@ def start_worker(self):
target=self._event_loop_wrapper, daemon=True)
self.worker_thread.start()
self.worker_started = True
# Start the sampler's async worker, if it is enabled
if (isinstance(self.sampler, AsyncWorkerMixin)
and self.sampler.async_worker_enabled()):
self.sampler.async_worker_start()

def _set_global_steady_clock_offset(self):
assert self.global_rank >= 0, "rank should be >= 0"
@@ -451,6 +456,10 @@ def shutdown(self):
keys = list(self.virtual_memory_pools.keys())
for key in keys:
del self.virtual_memory_pools[key]
# Stop the sampler's async worker, if it was used
if (isinstance(self.sampler, AsyncWorkerMixin)
and self.sampler.async_worker_enabled()):
self.sampler.async_worker_stop()

def can_enqueue_requests(self) -> bool:
"""
@@ -1603,7 +1612,7 @@ def _forward_step_inter_pp(self, scheduled_batch) -> SampleState:
self._update_request_states(scheduled_batch)
return self.sampler.SampleState(
scheduled_requests=scheduled_batch,
sampler_event=sampler_event,
sampler_event=SamplerEvent(cuda_event=sampler_event),
)

def _validate_request(self, request: LlmRequest):
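As an aside, a short sketch of what the SamplerEvent wrapper (defined in sampler.py below) means for code that consumes a SampleState: synchronize() now also waits on any outstanding async-worker copies before waiting on the recorded CUDA event. This is illustrative only; wait_for_host_tensors is a hypothetical helper name.

def wait_for_host_tensors(sample_state):
    # sample_state: a SampleState as returned by sample_async() or
    # _forward_step_inter_pp(); its sampler_event is now a SamplerEvent.
    if sample_state.sampler_event is not None:
        # Waits on pending worker futures (if any), then on the CUDA event,
        # after which the host tensors in sample_state.host are safe to read.
        sample_state.sampler_event.synchronize()
    return sample_state.host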
155 changes: 134 additions & 21 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -17,6 +17,7 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Iterable
from concurrent import futures
from dataclasses import dataclass
from functools import cached_property
from itertools import repeat
@@ -92,14 +93,25 @@ def values(self):
return vars(self).values()


@dataclass(kw_only=True)
class SamplerEvent:
cuda_event: torch.cuda.Event
worker_futures: Optional[list[futures.Future[Any]]] = None

def synchronize(self):
if self.worker_futures:
futures.wait(self.worker_futures)
self.cuda_event.synchronize()


@dataclass(kw_only=True)
class SampleState:
scheduled_requests: ScheduledRequests

device: Optional[SampleStateTensors] = None
host: Optional[SampleStateTensors] = None

sampler_event: Optional[torch.cuda.Event] = None
sampler_event: Optional[SamplerEvent] = None


class Sampler(ABC):
@@ -592,7 +604,106 @@ class SampleStateTorch(SampleState):
host: SampleStateTensorsHostTorch


class TorchSampler(Sampler):
class AsyncWorkerMixin:
"""
Mixin that adds the ability to fork off operations to run on a worker
thread (particularly D2H copies). If the async worker isn't active,
operations fall back to running synchronously on the calling thread.
"""

MAX_WORKERS = 1

def _async_worker_active(self) -> bool:
return hasattr(self, "_async_worker") and self._async_worker is not None

def _async_worker_init(self, enable_async_worker: bool):
self._enable_async_worker = enable_async_worker
self._async_worker = None
self._async_worker_futures: list[futures.Future[Any]] = []

def async_worker_enabled(self):
return hasattr(self, "_enable_async_worker") and self._enable_async_worker

def async_worker_start(self):
assert self.async_worker_enabled()
if not self._async_worker_active():

def _async_worker_initializer(device_id):
# The current device is set per thread, so we need to set it
# again here
torch.cuda.set_device(device_id)
# Submit the host copies in a separate stream to prevent the
# blocking copies from gating subsequent async work
torch.cuda.set_stream(torch.cuda.Stream())

self._async_worker = futures.ThreadPoolExecutor(
max_workers=self.MAX_WORKERS,
initializer=_async_worker_initializer,
initargs=(torch.cuda.current_device(),),
)

def async_worker_stop(self):
assert self.async_worker_enabled()
if self._async_worker_active():
self._async_worker.shutdown(wait=True)
self._async_worker = None

@torch.inference_mode()
def _async_worker_run(self, ready: torch.cuda.Event, func, /, *args, **kwargs):
# Make sure the async work takes place after all prior operations on
# the primary stream. synchronize() is intentionally chosen instead of
# wait() here; otherwise, blocking copies will stall subsequent CUDA
# API calls on the main thread
ready.synchronize()

# Do the work
result = func(*args, **kwargs)

# Work submitted to the async worker is expected to block at the end,
# consistent with the semantics of futures; make sure that we wait for
# everything to complete
torch.cuda.current_stream().synchronize()

return result

def _async_worker_submit(self, func, /, *args, **kwargs):
if self._async_worker_active():
# Record an event on the main thread/stream that we will
# synchronize with on the worker thread/stream
ready = torch.cuda.Event()
ready.record()

# Submit the async work
result = self._async_worker.submit(self._async_worker_run, ready, func, *args, **kwargs)

# Save the future, so that we can await it later
self._async_worker_futures.append(result)

return result
else:
# If the async worker is not in use, just execute the function
return func(*args, **kwargs)

def _copy_to_host(self, src: torch.Tensor) -> torch.Tensor:
dest = torch.empty_like(src, device="cpu", pin_memory=True)
self._async_worker_submit(dest.copy_, src, non_blocking=True)
return dest

def _sampler_event_get(self) -> SamplerEvent:
cuda_event = torch.cuda.Event()
cuda_event.record()

# Transfer ownership to worker_futures and re-initialize
if self._async_worker_active():
worker_futures = self._async_worker_futures
self._async_worker_futures = []
else:
worker_futures = None

return SamplerEvent(cuda_event=cuda_event, worker_futures=worker_futures)
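To illustrate the intended call sequence (this sketch is not part of the diff; _DemoAsyncCopier is a hypothetical minimal user of the mixin and a CUDA device is assumed to be available; the samplers below are the real users):

import torch

class _DemoAsyncCopier(AsyncWorkerMixin):
    def __init__(self, enable_async_worker: bool):
        self._async_worker_init(enable_async_worker)

copier = _DemoAsyncCopier(enable_async_worker=True)
copier.async_worker_start()                   # spins up the worker thread and its side stream
tokens_host = copier._copy_to_host(torch.randint(0, 100, (8,), device="cuda"))
event = copier._sampler_event_get()           # bundles a CUDA event with the pending futures
event.synchronize()                           # afterwards tokens_host is safe to read
copier.async_worker_stop()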


class TorchSampler(Sampler, AsyncWorkerMixin):
SampleState = SampleStateTorch

@override
@@ -616,6 +727,7 @@ class Args:
max_beam_width: int
max_total_draft_tokens: int
disable_flash_infer_sampling: bool = False
enable_async_worker: bool = False

def __init__(self, args: Args):
self.max_seq_len = args.max_seq_len
@@ -662,6 +774,8 @@ def __init__(self, args: Args):
self._global_seed = 42
self._generator = None

self._async_worker_init(args.enable_async_worker)

def get_generator(self, device: torch.device) -> torch.Generator:
"""Get a deterministic generator for the specified device.

@@ -950,7 +1064,7 @@ def _tree_sampling_batch(
new_draft_tokens_cuda.transpose(0, 1).to(torch.int, non_blocking=True).unsqueeze(dim=-1)
)

new_draft_tokens_host = int_new_draft_tokens.to("cpu", non_blocking=True)
new_draft_tokens_host = self._copy_to_host(int_new_draft_tokens)

return new_draft_tokens_host

@@ -1169,10 +1283,9 @@ def sample_async(
self._write_finish_reasons(
requests, finish_reasons=finish_reasons, seq_slots=seq_slots, new_tokens=new_tokens
)
finish_reasons_host = finish_reasons.to(device="cpu", non_blocking=True)
finish_reasons_host = self._copy_to_host(finish_reasons)

sampler_event = torch.cuda.Event()
sampler_event.record()
sampler_event = self._sampler_event_get()
return SampleStateTorch(
scheduled_requests=scheduled_requests,
device=SampleStateTensors(new_tokens=new_tokens),
@@ -1432,7 +1545,7 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool:
new_tokens_cuda.view(-1, *new_tokens_cuda.shape[2:])[:, beam, ...].scatter_(
0, batch_dest_indices_1d_cuda, batch_next_tokens_cuda_int
)
new_tokens_host = new_tokens_cuda.to("cpu", non_blocking=True)
new_tokens_host = self._copy_to_host(new_tokens_cuda)

return new_tokens_host

@@ -1793,10 +1906,8 @@ def _process_requests(
logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1
)
# Use a single D2H copy to reduce overheads
topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
topk_vals.copy_(topk_vals_cuda, non_blocking=True)
topk_indices.copy_(topk_indices_cuda, non_blocking=True)
topk_vals = self._copy_to_host(topk_vals_cuda)
topk_indices = self._copy_to_host(topk_indices_cuda)
current_offset = 0
for req_id, steps in zip(
logprobs_req_indices, req_num_steps[logprobs_req_indices].tolist()
@@ -1875,7 +1986,7 @@ class SampleStateTRTLLM(SampleState):
host: Optional[SampleStateTensorsHostTRTLLM] = None


class TRTLLMSampler(Sampler):
class TRTLLMSampler(Sampler, AsyncWorkerMixin):
MAX_DECODING_TOKENS = 1 # It must be 1 when not in speculative decoding
SampleState = SampleStateTRTLLM

@@ -1895,6 +2006,7 @@ def __init__(
max_beam_width: int,
decoding_config: Optional[DecodingConfig] = None,
kv_cache_config: Optional[KvCacheConfig] = None,
enable_async_worker: bool = False,
):
vocab_size = model.config.vocab_size
num_hidden_layers = model.config.num_hidden_layers
@@ -1945,6 +2057,8 @@ def __init__(self, args: Args):
self._initialize_store()
self._instantiate_algorithms()

self._async_worker_init(enable_async_worker)

def _initialize_store(self):
torch_stream = torch.cuda.current_stream().cuda_stream
cuda_stream = CudaStream(torch_stream)
@@ -2102,17 +2216,17 @@ def sample_async(
finalize_events[request.request_id] = self._finalize_request(request, False)
elif request.streaming:
finalize_events[request.request_id] = self._finalize_request(request, True)
gathered_ids = self.store["decoder_state"].gathered_ids.to("cpu", non_blocking=True)
new_output_tokens = self.store["decoder_state"].all_new_tokens.to("cpu", non_blocking=True)
finished_sum = self.store["decoder_state"].finished_sum.to("cpu", non_blocking=True)
finish_reasons = self.store["decoder_state"].finish_reasons.to("cpu", non_blocking=True)
sequence_lengths = self.store["decoder_state"].sequence_lengths.to("cpu", non_blocking=True)
gathered_ids = self._copy_to_host(self.store["decoder_state"].gathered_ids)
new_output_tokens = self._copy_to_host(self.store["decoder_state"].all_new_tokens)
finished_sum = self._copy_to_host(self.store["decoder_state"].finished_sum)
finish_reasons = self._copy_to_host(self.store["decoder_state"].finish_reasons)
sequence_lengths = self._copy_to_host(self.store["decoder_state"].sequence_lengths)

log_probs = None
cum_log_probs = None
if any(request.py_return_log_probs for request in scheduled_requests.all_requests()):
log_probs = self.store["decoder_state"].log_probs.to("cpu", non_blocking=True)
cum_log_probs = self.store["decoder_state"].cum_log_probs.to("cpu", non_blocking=True)
log_probs = self._copy_to_host(self.store["decoder_state"].log_probs)
cum_log_probs = self._copy_to_host(self.store["decoder_state"].cum_log_probs)

device = SampleStateTensors(new_tokens=self.store["decoder_state"].all_new_tokens)

@@ -2126,8 +2240,7 @@ def sample_async(
gathered_ids=gathered_ids,
)

sampler_event = torch.cuda.Event()
sampler_event.record()
sampler_event = self._sampler_event_get()

self.micro_batch_idx = (self.micro_batch_idx + 1) % self.num_micro_batches
