
Commit c9bade2

Use the sample function instead of process_logits, and apply changes based on review comments.
Signed-off-by: Wangshanshan <[email protected]>
1 parent 30e7b79 commit c9bade2

File tree

4 files changed: 44 additions & 216 deletions


tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 4 additions & 2 deletions
@@ -438,6 +438,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
     """LlmRequest wraps `bindings.internal.batch_manager.LlmRequest`
     but detour some features to Python implementation"""
 
+    _logprob_params = None
+
     def __init__(
         self,
         *args,
@@ -797,8 +799,8 @@ def executor_request_to_llm_request(
         py_multimodal_data=getattr(executor_request, "py_multimodal_data",
                                    None),
         kv_cache_retention_config=executor_request.kv_cache_retention_config)
-    if hasattr(executor_request, "_logprob_params"):
-        llm_request._logprob_params = executor_request._logprob_params
+    llm_request._logprob_params = getattr(executor_request, "_logprob_params",
+                                          None)
     if child_req_ids:
         for child_id in child_req_ids:
            llm_request.create_child_request(child_id)
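
A minimal, self-contained sketch (not part of the diff; FakeExecutorRequest and FakeLlmRequest are hypothetical stand-ins, not TensorRT-LLM types) of why the class-level default plus getattr(..., None) replaces the hasattr check: the attribute is now defined on every request, falling back to None when the executor request never set it.

class FakeExecutorRequest:
    pass  # may or may not carry _logprob_params

class FakeLlmRequest:
    _logprob_params = None  # class-level default, mirroring the diff

def convert(executor_request: FakeExecutorRequest) -> FakeLlmRequest:
    llm_request = FakeLlmRequest()
    # Single unconditional assignment; falls back to None when the attribute is absent.
    llm_request._logprob_params = getattr(executor_request, "_logprob_params", None)
    return llm_request

req_without = FakeExecutorRequest()
req_with = FakeExecutorRequest()
req_with._logprob_params = {"logprobs": 5}

assert convert(req_without)._logprob_params is None
assert convert(req_with)._logprob_params == {"logprobs": 5}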

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 7 additions & 10 deletions
@@ -68,7 +68,6 @@
     Strategy,
     UtilsSamplingParams,
     get_rejected_indices,
-    process_logits,
     resolve_sampling_strategy,
     sample,
     sample_rejected,
@@ -975,7 +974,7 @@ def _process_draft_tokens_rejection_sampling(
                 else _request_strategy(request, vocab_size=2**31)
             )
             generator = self.get_generator(request.py_draft_logits.device)
-            _, draft_probs = sample(
+            _, draft_probs, _ = sample(
                 draft_sampling_strategy,
                 request.py_draft_logits,
                 generator=generator,
@@ -1800,21 +1799,19 @@ def _process_requests(
             if logprobs_mode == "processed_logprobs":
                 # Process logits with the same transformations as sampling (temperature, top-k, top-p)
                 # but without actually sampling
-                processed_logits_list = []
+                logprobs_list = []
                 for req_id in logprobs_req_indices:
                     req = requests[req_id]
                     strategy = _request_strategy(req, vocab_size=logits_cuda.size(1))
                     req_logits_indices = logits_cuda_indexer[req_id]
                     req_logits = logits_cuda[req_logits_indices].to(
                         dtype=torch.float32, non_blocking=True
                     )
-                    # Apply the same processing as sampling would apply
-                    processed_req_logits = process_logits(strategy, req_logits)
-                    processed_logits_list.append(processed_req_logits)
-                # Concatenate all processed logits
-                processed_logits_cuda = torch.cat(processed_logits_list, dim=0)
-                # Apply log_softmax to get log probabilities
-                logprobs_cuda = F.log_softmax(processed_logits_cuda, dim=-1)
+                    # Use sample() to get processed logprobs (after temperature, top-k, top-p applied)
+                    _, _, req_logprobs = sample(strategy, req_logits, return_probs=True)
+                    logprobs_list.append(req_logprobs)
+                # Concatenate all logprobs
+                logprobs_cuda = torch.cat(logprobs_list, dim=0)
             else:
                 # For raw_logprobs and other modes, use raw logits (before sampling modifications)
                 raw_logits_for_logprobs = raw_logits_cuda[:sum_steps]
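
As a rough illustration only, the sketch below mimics the new "processed_logprobs" path with a toy temperature_sample() helper standing in for sampling_utils.sample(); only the 3-tuple return shape (tokens, probs, logprobs) and the per-request concatenation mirror this commit, the rest is assumed.

from typing import Optional

import torch
import torch.nn.functional as F

def temperature_sample(logits: torch.Tensor, temperature: float,
                       generator: Optional[torch.Generator] = None):
    # Log probabilities come from the same scaled logits the sampler draws from,
    # rather than being recomputed separately from raw logits.
    scaled = logits / max(temperature, 1e-5)
    probs = torch.softmax(scaled, dim=-1)
    logprobs = F.log_softmax(scaled, dim=-1)
    tokens = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)
    return tokens, probs, logprobs

per_request_logits = [torch.randn(3, 8), torch.randn(2, 8)]
logprobs_list = []
for req_logits in per_request_logits:
    _, _, req_logprobs = temperature_sample(req_logits, temperature=0.7)
    logprobs_list.append(req_logprobs)
logprobs_all = torch.cat(logprobs_list, dim=0)  # [5, 8], one row per generated step
assert torch.allclose(logprobs_all.exp().sum(dim=-1), torch.ones(5), atol=1e-5)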

tensorrt_llm/_torch/pyexecutor/sampling_utils.py

Lines changed: 25 additions & 81 deletions
@@ -24,6 +24,7 @@
 from typing import Generic, Literal, Optional, TypeAlias, TypeVar, cast
 
 import torch
+import torch.nn.functional as F
 
 from tensorrt_llm.sampling_params import SamplingParams
 
@@ -95,7 +96,7 @@ def top_k_sampling_batch(
     top_k: int,
     temperature: float,
     generator: Optional[torch.Generator] = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     # NB: To be replaced by a more efficient implementation.
     return top_k_top_p_sampling_batch(
         logits,
@@ -112,7 +113,7 @@ def top_p_sampling_batch(
     top_p: float,
     temperature: float,
     generator: Optional[torch.Generator] = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     # NB: To be replaced by a more efficient implementation.
     return top_k_top_p_sampling_batch(
         logits,
@@ -128,7 +129,7 @@ def temperature_sampling_batch(
     *,
     temperature: float,
     generator: Optional[torch.Generator] = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     # NB: To be replaced by a more efficient implementation.
     return top_k_top_p_sampling_batch(
         logits,
@@ -146,7 +147,7 @@ def top_k_top_p_sampling_batch(
     top_p: float,
     temperature: float,
     generator: Optional[torch.Generator] = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     logits_dim = logits.dim()
     assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
     assert temperature > 0, "non-greedy sampling requires valid temperature"
@@ -189,21 +190,26 @@ def top_k_top_p_sampling_batch(
     # compute probability distribution
     softmax = torch.softmax(logits, dim=-1)
 
+    # compute log probabilities
+    logprobs = F.log_softmax(logits, dim=-1)
+
     # sample from the distribution and generate result of [batch_size, 1]
     next_tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1)
-    return next_tokens, softmax
+    return next_tokens, softmax, logprobs
 
 
 def greedy_search_sampling_batch(
     logits,
     *,
     return_probs: bool = True,
-) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
     next_tokens = torch.argmax(logits, dim=-1)
     softmax: Optional[torch.Tensor] = None
+    logprobs: Optional[torch.Tensor] = None
     if return_probs:
         softmax = torch.softmax(logits, dim=-1)
-    return next_tokens, softmax
+        logprobs = F.log_softmax(logits, dim=-1)
+    return next_tokens, softmax, logprobs
 
 
 def get_rejected_indices(
@@ -248,71 +254,6 @@ def sample_rejected(
     return cast(int, new_token.item())
 
 
-def process_logits(
-    strategy: Strategy,
-    logits: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Process logits according to the specified strategy (temperature, top-k, top-p)
-    without sampling. Returns processed logits ready for log_softmax.
-
-    Args:
-        strategy: Sampling strategy tuple (strategy_name, *params)
-        logits: Input logits tensor [batch_size, vocab_size]
-
-    Returns:
-        Processed logits tensor [batch_size, vocab_size]
-    """
-    logits = logits.clone()
-    match strategy:
-        case ("top_k", top_k, temperature):
-            logits = logits / max(temperature, 1e-5)
-            batch_size, vocab_size = logits.size()
-            if top_k < vocab_size:
-                values, _ = torch.topk(logits, top_k, dim=-1)
-                min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)
-                logits = torch.where(
-                    logits < min_values, torch.full_like(logits, float("-inf")), logits
-                )
-        case ("top_p", top_p, temperature):
-            logits = logits / max(temperature, 1e-5)
-            if top_p < 1:
-                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-                sorted_indices_to_remove = cumulative_probs > top_p
-                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
-                sorted_indices_to_remove[:, 0] = 0
-                indices_to_remove = sorted_indices_to_remove.scatter(
-                    1, sorted_indices, sorted_indices_to_remove
-                )
-                logits = logits.masked_fill(indices_to_remove, float("-inf"))
-        case ("top_k_top_p", top_k, top_p, temperature):
-            logits = logits / max(temperature, 1e-5)
-            batch_size, vocab_size = logits.size()
-            if top_k < vocab_size:
-                values, _ = torch.topk(logits, top_k, dim=-1)
-                min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)
-                logits = torch.where(
-                    logits < min_values, torch.full_like(logits, float("-inf")), logits
-                )
-            if top_p < 1:
-                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-                sorted_indices_to_remove = cumulative_probs > top_p
-                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
-                sorted_indices_to_remove[:, 0] = 0
-                indices_to_remove = sorted_indices_to_remove.scatter(
-                    1, sorted_indices, sorted_indices_to_remove
-                )
-                logits = logits.masked_fill(indices_to_remove, float("-inf"))
-        case ("temperature", temperature):
-            logits = logits / max(temperature, 1e-5)
-        case ("greedy", None):
-            # No processing needed for greedy
-            pass
-    return logits
-
-
 def sample(
     strategy: Strategy,
     logits: torch.Tensor,
@@ -327,43 +268,45 @@ def sample(
         strategy: Sampling strategy tuple (strategy_name, *params)
         logits: Input logits tensor
        generator: Optional random generator
-        return_probs: If True, return softmax probabilities
+        return_probs: If True, return softmax probabilities and log probabilities
 
     Returns:
-        Tuple of (sampled_tokens, softmax_probs)
+        Tuple of (sampled_tokens, softmax_probs, logprobs)
     """
     match strategy:
         case ("top_k", top_k, temperature):
-            tokens, softmax = top_k_sampling_batch(
+            tokens, softmax, logprobs = top_k_sampling_batch(
                logits,
                top_k=top_k,
                temperature=temperature,
                generator=generator,
            )
        case ("top_p", top_p, temperature):
-            tokens, softmax = top_p_sampling_batch(
+            tokens, softmax, logprobs = top_p_sampling_batch(
                logits,
                top_p=top_p,
                generator=generator,
                temperature=temperature,
            )
        case ("top_k_top_p", top_k, top_p, temperature):
-            tokens, softmax = top_k_top_p_sampling_batch(
+            tokens, softmax, logprobs = top_k_top_p_sampling_batch(
                logits,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                generator=generator,
            )
        case ("temperature", temperature):
-            tokens, softmax = temperature_sampling_batch(
+            tokens, softmax, logprobs = temperature_sampling_batch(
                logits,
                temperature=temperature,
                generator=generator,
            )
        case ("greedy", None):
-            tokens, softmax = greedy_search_sampling_batch(logits, return_probs=return_probs)
-    return tokens, softmax
+            tokens, softmax, logprobs = greedy_search_sampling_batch(
+                logits, return_probs=return_probs
+            )
+    return tokens, softmax, logprobs
 
 
 GenericStrategyKeyType = TypeVar("GenericStrategyKeyType")
@@ -415,12 +358,13 @@ def sample_grouped_strategies(
 
     assert all(strategy == group_key for strategy in strategies), "group must be consistent"
 
-    return sample(
+    tokens, probs, _ = sample(
        group_key,
        logits,
        generator=generator,
        return_probs=return_probs,
    )
+    return tokens, probs
 
 
 class _AcceptSyncCompute:
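
For reference, a toy sketch of the widened return contract (toy_greedy_sampling_batch is hypothetical, not the library function): every *_sampling_batch helper and sample() now return (tokens, softmax, logprobs), and callers that do not need the log probabilities discard the third element, as sample_grouped_strategies does above.

from typing import Optional

import torch
import torch.nn.functional as F

def toy_greedy_sampling_batch(
    logits: torch.Tensor, *, return_probs: bool = True
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Mirrors the shape of the greedy path in this diff: argmax tokens plus
    # optional softmax probabilities and log probabilities.
    next_tokens = torch.argmax(logits, dim=-1)
    softmax: Optional[torch.Tensor] = None
    logprobs: Optional[torch.Tensor] = None
    if return_probs:
        softmax = torch.softmax(logits, dim=-1)
        logprobs = F.log_softmax(logits, dim=-1)
    return next_tokens, softmax, logprobs

logits = torch.randn(4, 16)
tokens, probs, logprobs = toy_greedy_sampling_batch(logits)
assert logprobs is not None and logprobs.shape == logits.shape
# Callers that only need tokens and probabilities drop the new third element:
tokens_only, probs_only, _ = toy_greedy_sampling_batch(logits, return_probs=False)
assert probs_only is None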
