diff --git a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py index d0b93c2bd19..0f49bcb9aef 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py @@ -235,11 +235,11 @@ def _sample( logits_shape = logits.shape logits = logits.view(-1, logits_shape[-1]) # sampling_batch expects 2D logits if isinstance(sampling_params.top_k, int) and sampling_params.top_k > 1: - idx_next, probs = top_k_sampling_batch( + idx_next, probs, _ = top_k_sampling_batch( logits, top_k=sampling_params.top_k, temperature=1.0 ) else: - idx_next, probs = greedy_search_sampling_batch(logits) + idx_next, probs, _ = greedy_search_sampling_batch(logits) idx_next = idx_next.view(logits_shape[:-1]) return idx_next, probs diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 01d3f35f876..b65123cd786 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -438,6 +438,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest): """LlmRequest wraps `bindings.internal.batch_manager.LlmRequest` but detour some features to Python implementation""" + _logprob_params = None + def __init__( self, *args, @@ -797,6 +799,8 @@ def executor_request_to_llm_request( py_multimodal_data=getattr(executor_request, "py_multimodal_data", None), kv_cache_retention_config=executor_request.kv_cache_retention_config) + llm_request._logprob_params = getattr(executor_request, "_logprob_params", + None) if child_req_ids: for child_id in child_req_ids: llm_request.create_child_request(child_id) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 5757f8efbc7..2f586794b86 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -974,7 +974,7 @@ def _process_draft_tokens_rejection_sampling( else _request_strategy(request, vocab_size=2**31) ) generator = self.get_generator(request.py_draft_logits.device) - _, draft_probs = sample( + _, draft_probs, _ = sample( draft_sampling_strategy, request.py_draft_logits, generator=generator, @@ -1776,7 +1776,16 @@ def _process_requests( # Handle top-k logprobs. This is done outside the sampling loop, # because the returned logprobs are specified to not reflect temperature scaling, # top-k/top-p masking, etc. 
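The call sites above (demollm.py, the rejection-sampling path, and the top-k logprobs block that follows) all adapt to the sampling helpers now returning a third element, the log-probabilities, next to the sampled tokens and the softmax. A minimal standalone sketch of that 3-tuple contract, simplified from the real helpers (the return_probs flag and batching details are omitted):

import torch
import torch.nn.functional as F

def greedy_search_sampling_batch(logits: torch.Tensor):
    # Greedy token pick plus both distributions derived from the same logits.
    next_tokens = torch.argmax(logits, dim=-1)
    softmax = torch.softmax(logits, dim=-1)
    logprobs = F.log_softmax(logits, dim=-1)
    return next_tokens, softmax, logprobs

logits = torch.randn(4, 32)                       # [batch_size, vocab_size]
tokens, probs, logprobs = greedy_search_sampling_batch(logits)
# Existing call sites that only need tokens/probs unpack and drop the new element:
tokens_only, probs_only, _ = greedy_search_sampling_batch(logits)
# The third element is the log of the second, just computed in a numerically stable way.
assert torch.allclose(logprobs, probs.log(), atol=1e-5)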
+ if return_log_probs: + logprobs_mode = None + for req in requests: + if req.py_num_logprobs: + logprob_params = getattr(req, "_logprob_params", None) + if logprob_params and hasattr(logprob_params, "logprobs_mode"): + logprobs_mode = logprob_params.logprobs_mode + break + assert logits_cuda.dim() == 2, "logits should be 2D" logprobs_req_indices = [ req_id for req_id, req in enumerate(requests) if req.py_num_logprobs @@ -1785,10 +1794,34 @@ def _process_requests( logprobs_logit_indices_cuda = logprobs_logit_indices.to( device=logits_cuda.device, non_blocking=True ) - logprobs_cuda = F.log_softmax( - logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True), - dim=-1, - ) + + # Compute logprobs based on mode + if logprobs_mode == "processed_logprobs": + # Process logits with the same transformations as sampling (temperature, top-k, top-p) + # but without actually sampling + logprobs_list = [] + for req_id in logprobs_req_indices: + req = requests[req_id] + strategy = _request_strategy(req, vocab_size=logits_cuda.size(1)) + req_logits_indices = logits_cuda_indexer[req_id] + req_logits = logits_cuda[req_logits_indices].to( + dtype=torch.float32, non_blocking=True + ) + # Use sample() to get processed logprobs (after temperature, top-k, top-p applied) + _, _, req_logprobs = sample(strategy, req_logits, return_probs=True) + logprobs_list.append(req_logprobs) + # Concatenate all logprobs + logprobs_cuda = torch.cat(logprobs_list, dim=0) + else: + # For raw_logprobs and other modes, use raw logits (before sampling modifications) + raw_logits_for_logprobs = raw_logits_cuda[:sum_steps] + logprobs_cuda = F.log_softmax( + raw_logits_for_logprobs[logprobs_logit_indices_cuda].to( + dtype=torch.float32, non_blocking=True + ), + dim=-1, + ) + topk_vals_cuda, topk_indices_cuda = torch.topk( logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1 ) diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py index 35e64afe4c2..148286c09e2 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py @@ -24,6 +24,7 @@ from typing import Generic, Literal, Optional, TypeAlias, TypeVar, cast import torch +import torch.nn.functional as F from tensorrt_llm.sampling_params import SamplingParams @@ -95,7 +96,7 @@ def top_k_sampling_batch( top_k: int, temperature: float, generator: Optional[torch.Generator] = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # NB: To be replaced by a more efficient implementation. return top_k_top_p_sampling_batch( logits, @@ -112,7 +113,7 @@ def top_p_sampling_batch( top_p: float, temperature: float, generator: Optional[torch.Generator] = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # NB: To be replaced by a more efficient implementation. return top_k_top_p_sampling_batch( logits, @@ -128,7 +129,7 @@ def temperature_sampling_batch( *, temperature: float, generator: Optional[torch.Generator] = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # NB: To be replaced by a more efficient implementation. 
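The logprobs_mode branch above distinguishes "processed_logprobs" (log-softmax taken after the same temperature/top-k/top-p transformations used for sampling) from the fallback path (log-softmax of the unmodified logits). A hedged standalone sketch of the difference; the top-k masking here is a toy transform for illustration, not the library's kernels:

import torch
import torch.nn.functional as F

def processed_logprobs(logits: torch.Tensor, top_k: int, temperature: float) -> torch.Tensor:
    # Same kind of transformations as sampling: temperature scaling, then top-k masking.
    scaled = logits.float() / temperature
    kth = torch.topk(scaled, k=top_k, dim=-1).values[..., -1, None]
    masked = torch.where(scaled < kth, torch.full_like(scaled, float("-inf")), scaled)
    return F.log_softmax(masked, dim=-1)

def raw_logprobs(logits: torch.Tensor) -> torch.Tensor:
    # Raw path: log-softmax of the unmodified logits.
    return F.log_softmax(logits.float(), dim=-1)

logits = torch.randn(2, 16)
proc = processed_logprobs(logits, top_k=4, temperature=0.8)
raw = raw_logprobs(logits)
# Only the processed view zeroes out the masked tokens (logprob == -inf).
assert torch.isinf(proc).any() and not torch.isinf(raw).any()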
return top_k_top_p_sampling_batch( logits, @@ -146,7 +147,7 @@ def top_k_top_p_sampling_batch( top_p: float, temperature: float, generator: Optional[torch.Generator] = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: logits_dim = logits.dim() assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]" assert temperature > 0, "non-greedy sampling requires valid temperature" @@ -189,21 +190,26 @@ def top_k_top_p_sampling_batch( # compute probability distribution softmax = torch.softmax(logits, dim=-1) + # compute log probabilities + logprobs = F.log_softmax(logits, dim=-1) + # sample from the distribution and generate result of [batch_size, 1] next_tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1) - return next_tokens, softmax + return next_tokens, softmax, logprobs def greedy_search_sampling_batch( logits, *, return_probs: bool = True, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: next_tokens = torch.argmax(logits, dim=-1) softmax: Optional[torch.Tensor] = None + logprobs: Optional[torch.Tensor] = None if return_probs: softmax = torch.softmax(logits, dim=-1) - return next_tokens, softmax + logprobs = F.log_softmax(logits, dim=-1) + return next_tokens, softmax, logprobs def get_rejected_indices( @@ -254,24 +260,36 @@ def sample( *, generator: Optional[torch.Generator] = None, return_probs: bool = True, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Sample from logits using the specified strategy. + + Args: + strategy: Sampling strategy tuple (strategy_name, *params) + logits: Input logits tensor + generator: Optional random generator + return_probs: If True, return softmax probabilities and log probabilities + + Returns: + Tuple of (sampled_tokens, softmax_probs, logprobs) + """ match strategy: case ("top_k", top_k, temperature): - tokens, softmax = top_k_sampling_batch( + tokens, softmax, logprobs = top_k_sampling_batch( logits, top_k=top_k, temperature=temperature, generator=generator, ) case ("top_p", top_p, temperature): - tokens, softmax = top_p_sampling_batch( + tokens, softmax, logprobs = top_p_sampling_batch( logits, top_p=top_p, generator=generator, temperature=temperature, ) case ("top_k_top_p", top_k, top_p, temperature): - tokens, softmax = top_k_top_p_sampling_batch( + tokens, softmax, logprobs = top_k_top_p_sampling_batch( logits, top_k=top_k, top_p=top_p, @@ -279,14 +297,16 @@ def sample( generator=generator, ) case ("temperature", temperature): - tokens, softmax = temperature_sampling_batch( + tokens, softmax, logprobs = temperature_sampling_batch( logits, temperature=temperature, generator=generator, ) case ("greedy", None): - tokens, softmax = greedy_search_sampling_batch(logits, return_probs=return_probs) - return tokens, softmax + tokens, softmax, logprobs = greedy_search_sampling_batch( + logits, return_probs=return_probs + ) + return tokens, softmax, logprobs GenericStrategyKeyType = TypeVar("GenericStrategyKeyType") @@ -338,12 +358,13 @@ def sample_grouped_strategies( assert all(strategy == group_key for strategy in strategies), "group must be consistent" - return sample( + tokens, probs, _ = sample( group_key, logits, generator=generator, return_probs=return_probs, ) + return tokens, probs class _AcceptSyncCompute: diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py 
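With every strategy helper returning (tokens, softmax, logprobs), the match-based dispatcher in sample() threads the third element through unchanged, and callers such as sample_grouped_strategies() simply discard it. A toy sketch of that dispatch shape, covering only two strategies for brevity:

import torch
import torch.nn.functional as F
from typing import Optional

def sample(strategy: tuple, logits: torch.Tensor,
           generator: Optional[torch.Generator] = None):
    match strategy:
        case ("greedy", None):
            tokens = torch.argmax(logits, dim=-1)
            softmax = torch.softmax(logits, dim=-1)
            logprobs = F.log_softmax(logits, dim=-1)
        case ("temperature", temperature):
            scaled = logits / temperature
            softmax = torch.softmax(scaled, dim=-1)
            logprobs = F.log_softmax(scaled, dim=-1)
            tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1)
        case _:
            raise NotImplementedError(f"strategy {strategy} is not covered by this sketch")
    return tokens, softmax, logprobs

logits = torch.randn(3, 8)
tokens, probs, logprobs = sample(("greedy", None), logits)
tokens, probs, _ = sample(("temperature", 0.7), logits)   # callers may ignore logprobs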
b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py index f8ce56a1672..734d320e079 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py +++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py @@ -137,7 +137,7 @@ def _sample_greedy_with_probs( group_logit_indices: Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: probs = self._prepare_probs_with_temperature(logits, group_logit_indices, None) - new_tokens, _ = greedy_search_sampling_batch(probs, return_probs=False) + new_tokens, _, _ = greedy_search_sampling_batch(probs, return_probs=False) return new_tokens, probs @classmethod @@ -370,7 +370,8 @@ def sample( ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if group_logit_indices is not None: logits = logits[group_logit_indices] - return greedy_search_sampling_batch(logits, return_probs=False) + tokens, probs, _ = greedy_search_sampling_batch(logits, return_probs=False) + return tokens, probs class TopKTopPSampleOnly(StrategyImplSampleOnly): def __init__(self, top_k: torch.Tensor, top_p: torch.Tensor, temperature: torch.Tensor): diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index 65423e3f8ed..6a2d2d35e8f 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -569,6 +569,11 @@ def _deduce_max_tokens(request: GenerationRequest, if self._is_pytorch_backend and request.scheduling_params is not None: executor_request.py_scheduling_params = request.scheduling_params + if self._is_pytorch_backend: + logprob_params = self._get_logprob_params(request) + if logprob_params is not None: + executor_request._logprob_params = logprob_params + if request.arrival_time is not None: executor_request.py_arrival_time = request.arrival_time diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index e7ab9192ad1..a204238780f 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -234,7 +234,8 @@ def _get_logprob_params( or self.postproc_config.num_postprocess_workers > 0, drop_generation_logits=( not request.sampling_params._need_return_generation_logits) - or self.postproc_config.num_postprocess_workers > 0) + or self.postproc_config.num_postprocess_workers > 0, + logprobs_mode=request.sampling_params.logprobs_mode) return logprob_params diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index d47743cf8f0..f92240e45c0 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -1016,6 +1016,13 @@ def compute_logprobs( - Generation logprobs (from generation_logits, TRT backend): used when backend doesn't compute them in sampler (e.g., TRT). - Generation logprobs (PyTorch backend): not used; computed in sampler, not here. + Args: + k_prompt_logprobs: Number of top logprobs to return for prompt tokens + k_logprobs: Number of top logprobs to return for generated tokens + context_logits: Logits for context/prompt tokens + generation_logits: Logits for generated tokens + output_token_ids: Token IDs of generated outputs + Returns: LogProbsResult, a NamedTuple containing: - prompt: Optional[List[Dict[token_id, Logprob]]] logprobs for prompt tokens. 
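The executor-side hunks attach the logprob settings to the request object: base_worker.py sets executor_request._logprob_params on the PyTorch backend, executor.py now includes logprobs_mode in LogprobParams, and the sampler reads it back defensively with getattr. A hedged sketch of that plumbing; the executor request and LogprobParams below are stand-ins for illustration, not the real bindings:

from types import SimpleNamespace
from typing import NamedTuple, Optional

class LogprobParams(NamedTuple):
    logprobs: Optional[int] = None
    prompt_logprobs: Optional[int] = None
    drop_context_logits: bool = False
    drop_generation_logits: bool = False
    logprobs_mode: str = "processed_logprobs"

# Worker side: hang the params off the (stand-in) executor request.
executor_request = SimpleNamespace()
executor_request._logprob_params = LogprobParams(logprobs=3,
                                                 logprobs_mode="processed_logprobs")

# Sampler side: recover the mode defensively, as the new hunk does.
logprob_params = getattr(executor_request, "_logprob_params", None)
logprobs_mode = getattr(logprob_params, "logprobs_mode", None) if logprob_params else None
assert logprobs_mode == "processed_logprobs"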
@@ -1034,6 +1041,7 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int, logits = logits[:len(tokens)] logprobs = F.log_softmax(logits.to("cuda", dtype=torch.float32), dim=-1) + topk_vals, topk_indices = torch.topk(logprobs, k=top_k, dim=-1) results: TokenLogprobs = [] diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py index b7ad63821ad..9ba52e464f5 100644 --- a/tensorrt_llm/sampling_params.py +++ b/tensorrt_llm/sampling_params.py @@ -2,7 +2,7 @@ import os from abc import ABC, abstractmethod from dataclasses import dataclass, field, fields -from typing import List, NamedTuple, Optional, Tuple, Union +from typing import List, Literal, NamedTuple, Optional, Tuple, Union import torch from pydantic import BaseModel @@ -10,6 +10,11 @@ from tensorrt_llm.bindings import executor as tllme from tensorrt_llm.logger import logger +# Logprobs mode: +# - "processed_logprobs": return log-softmax of greedy sampled logits +# TODO: add "return_raw_context_logits" and "return_raw_generation_logits" later +LogprobsMode = Literal["processed_logprobs"] + @dataclass(slots=True, kw_only=True) class GuidedDecodingParams: @@ -44,6 +49,8 @@ class LogprobParams(NamedTuple): drop_context_logits: bool = False # Drop the geneation_logits once the logprobs are computed drop_generation_logits: bool = False + # Logprobs mode: controls whether to return logprobs before or after sampling modifications + logprobs_mode: LogprobsMode = "processed_logprobs" class LogitsProcessor(ABC): @@ -174,6 +181,9 @@ class SamplingParams: logprobs (int, optional): Number of log probabilities to return per output token. Defaults to None. prompt_logprobs (int, optional): Number of log probabilities to return per prompt token. Defaults to None. + logprobs_mode (Literal['processed_logprobs']): Controls return logprobs after sampling modifications. Defaults to "processed_logprobs". + Options: + - "processed_logprobs": Return log-softmax of processed logits return_context_logits (bool): Controls if Result should contain the context logits. Defaults to False. return_generation_logits (bool): Controls if Result should contain the generation logits. Defaults to False. exclude_input_from_output (bool): Controls if output tokens in Result should include the input tokens. Defaults to True. @@ -250,6 +260,9 @@ class SamplingParams: return_perf_metrics: bool = False additional_model_outputs: Optional[List[str]] = None + # Logprobs mode: controls whether to return logprobs before or after sampling modifications + logprobs_mode: LogprobsMode = "processed_logprobs" + # Used in logprobs calculation in TRT flow to drop logits early if user did not explicitly request them. # Can be deprecated after migration to PyTorch backend. 
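compute_logprobs()/_topk_logprobs() in result.py turn a log-softmax tensor into per-token top-k dictionaries via torch.topk, which is also the shape the new sampler path produces. A hedged sketch of that extraction; Logprob here is a stand-in tuple, not the library's class:

import torch
import torch.nn.functional as F
from collections import namedtuple

Logprob = namedtuple("Logprob", ["logprob", "rank"])

def topk_logprob_dicts(logits: torch.Tensor, top_k: int):
    logprobs = F.log_softmax(logits.float(), dim=-1)
    topk_vals, topk_ids = torch.topk(logprobs, k=top_k, dim=-1)
    results = []
    for vals, ids in zip(topk_vals.tolist(), topk_ids.tolist()):
        # rank is 1-based: the most likely token gets rank 1.
        results.append({tok: Logprob(logprob=val, rank=rank + 1)
                        for rank, (tok, val) in enumerate(zip(ids, vals))})
    return results

per_token = topk_logprob_dicts(torch.randn(5, 100), top_k=3)
assert all(lp.logprob <= 0.0 for d in per_token for lp in d.values())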
_context_logits_auto_enabled: bool = False diff --git a/tests/unittest/api_stability/references/sampling_params.yaml b/tests/unittest/api_stability/references/sampling_params.yaml index d6b3e6156e3..948aee0b654 100644 --- a/tests/unittest/api_stability/references/sampling_params.yaml +++ b/tests/unittest/api_stability/references/sampling_params.yaml @@ -15,5 +15,8 @@ methods: prompt_ignore_length: annotation: Optional[int] default: null + logprobs_mode: + annotation: Literal['processed_logprobs'] + default: processed_logprobs return_annotation: None properties: {} diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 1bdd2dfbeb5..df0d4c7b184 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -917,6 +917,99 @@ def test_llm_return_logprobs_streaming(prompt_logprobs, logprobs, backend="pytorch") +@skip_ray +@pytest.mark.parametrize("temperature", [None, 0.8, 1.0]) +@pytest.mark.parametrize("top_k", [None, 10, 0]) +@pytest.mark.parametrize("top_p", [None, 0.5, 1.0]) +# temperature: 0.0 is greedy sampling and will be covered by below test +# top_k: 0 means all logits +# top_p: 1 means no top-p filtering +def test_llm_logprobs_modes_basic(temperature, top_k, top_p): + """ + Test processed_logprobs mode works correctly in PyTorch backend. + Validates that: + - processed_logprobs returns non-positive values (log probabilities) + """ + llm = LLM( + llama_model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), + ) + + prompts = ["The future of AI is"] + sampling_params = SamplingParams( + max_tokens=5, + logprobs=3, + temperature=temperature, + top_k=top_k, + top_p=top_p, + logprobs_mode="processed_logprobs", + seed=42, + return_context_logits=True, + return_generation_logits=True, + ) + + outputs = list(llm.generate(prompts, sampling_params)) + assert len(outputs) == 1 + + output = outputs[0] + assert len(output.outputs) == 1 + logprobs_list = output.outputs[0].logprobs + + assert logprobs_list is not None + assert len(logprobs_list) > 0 + + # Collect all logprob values + all_values = [] + for token_logprobs in logprobs_list: + for logprob_obj in token_logprobs.values(): + all_values.append(logprob_obj.logprob) + + # Validate that processed_logprobs returns non-positive values (log probabilities) + for val in all_values: + assert val <= 0.0, f"processed_logprobs should have non-positive values, got {val}" + + del llm + + +@skip_ray +def test_llm_logprobs_mode_backward_compatibility(): + """ + Test that default behavior without specifying logprobs_mode. + """ + llm = LLM( + llama_model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), + ) + + prompt = ["once upon a time"] + + # Explicit processed_logprobs + explicit_params = SamplingParams( + max_tokens=10, + logprobs=2, + logprobs_mode="processed_logprobs", + seed=123, + ) + explicit_outputs = list(llm.generate(prompt, explicit_params)) + + # Default (should be processed_logprobs) + default_params = SamplingParams( + max_tokens=10, + logprobs=2, + seed=123, + ) + default_outputs = list(llm.generate(prompt, default_params)) + + # Should produce same tokens + explicit_tokens = explicit_outputs[0].outputs[0].token_ids + default_tokens = default_outputs[0].outputs[0].token_ids + + assert explicit_tokens == default_tokens, ( + "Default should match explicit processed_logprobs") + + del llm + + class TestLlmError: def test_max_num_token_check(self):
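The backward-compatibility test and the sampling_params.yaml reference both rely on "processed_logprobs" being the default of SamplingParams.logprobs_mode. A quick sanity check of that default, assuming an installed tensorrt_llm that already contains this change:

from tensorrt_llm.sampling_params import SamplingParams

default_params = SamplingParams(max_tokens=10, logprobs=2, seed=123)
explicit_params = SamplingParams(max_tokens=10, logprobs=2, seed=123,
                                 logprobs_mode="processed_logprobs")

# The default mode is pinned to "processed_logprobs", so both configurations match.
assert default_params.logprobs_mode == "processed_logprobs"
assert default_params.logprobs_mode == explicit_params.logprobs_mode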