2424from typing import Generic , Literal , Optional , TypeAlias , TypeVar , cast
2525
2626import torch
27+ import torch .nn .functional as F
2728
2829from tensorrt_llm .sampling_params import SamplingParams
2930
@@ -95,7 +96,7 @@ def top_k_sampling_batch(
9596 top_k : int ,
9697 temperature : float ,
9798 generator : Optional [torch .Generator ] = None ,
98- ) -> tuple [torch .Tensor , torch .Tensor ]:
99+ ) -> tuple [torch .Tensor , torch .Tensor , torch . Tensor ]:
99100 # NB: To be replaced by a more efficient implementation.
100101 return top_k_top_p_sampling_batch (
101102 logits ,
@@ -112,7 +113,7 @@ def top_p_sampling_batch(
112113 top_p : float ,
113114 temperature : float ,
114115 generator : Optional [torch .Generator ] = None ,
115- ) -> tuple [torch .Tensor , torch .Tensor ]:
116+ ) -> tuple [torch .Tensor , torch .Tensor , torch . Tensor ]:
116117 # NB: To be replaced by a more efficient implementation.
117118 return top_k_top_p_sampling_batch (
118119 logits ,
@@ -128,7 +129,7 @@ def temperature_sampling_batch(
128129 * ,
129130 temperature : float ,
130131 generator : Optional [torch .Generator ] = None ,
131- ) -> tuple [torch .Tensor , torch .Tensor ]:
132+ ) -> tuple [torch .Tensor , torch .Tensor , torch . Tensor ]:
132133 # NB: To be replaced by a more efficient implementation.
133134 return top_k_top_p_sampling_batch (
134135 logits ,
@@ -146,7 +147,7 @@ def top_k_top_p_sampling_batch(
146147 top_p : float ,
147148 temperature : float ,
148149 generator : Optional [torch .Generator ] = None ,
149- ) -> tuple [torch .Tensor , torch .Tensor ]:
150+ ) -> tuple [torch .Tensor , torch .Tensor , torch . Tensor ]:
150151 logits_dim = logits .dim ()
151152 assert logits_dim == 2 , "logits should be 2D: [batch_size, vocab_size]"
152153 assert temperature > 0 , "non-greedy sampling requires valid temperature"
@@ -189,21 +190,26 @@ def top_k_top_p_sampling_batch(
189190 # compute probability distribution
190191 softmax = torch .softmax (logits , dim = - 1 )
191192
193+ # compute log probabilities
194+ logprobs = F .log_softmax (logits , dim = - 1 )
195+
192196 # sample from the distribution and generate result of [batch_size, 1]
193197 next_tokens = torch .multinomial (softmax , num_samples = 1 , generator = generator ).squeeze (- 1 )
194- return next_tokens , softmax
198+ return next_tokens , softmax , logprobs
195199
196200
197201def greedy_search_sampling_batch (
198202 logits ,
199203 * ,
200204 return_probs : bool = True ,
201- ) -> tuple [torch .Tensor , Optional [torch .Tensor ]]:
205+ ) -> tuple [torch .Tensor , Optional [torch .Tensor ], Optional [ torch . Tensor ] ]:
202206 next_tokens = torch .argmax (logits , dim = - 1 )
203207 softmax : Optional [torch .Tensor ] = None
208+ logprobs : Optional [torch .Tensor ] = None
204209 if return_probs :
205210 softmax = torch .softmax (logits , dim = - 1 )
206- return next_tokens , softmax
211+ logprobs = F .log_softmax (logits , dim = - 1 )
212+ return next_tokens , softmax , logprobs
207213
208214
209215def get_rejected_indices (
@@ -248,71 +254,6 @@ def sample_rejected(
248254 return cast (int , new_token .item ())
249255
250256
251- def process_logits (
252- strategy : Strategy ,
253- logits : torch .Tensor ,
254- ) -> torch .Tensor :
255- """
256- Process logits according to the specified strategy (temperature, top-k, top-p)
257- without sampling. Returns processed logits ready for log_softmax.
258-
259- Args:
260- strategy: Sampling strategy tuple (strategy_name, *params)
261- logits: Input logits tensor [batch_size, vocab_size]
262-
263- Returns:
264- Processed logits tensor [batch_size, vocab_size]
265- """
266- logits = logits .clone ()
267- match strategy :
268- case ("top_k" , top_k , temperature ):
269- logits = logits / max (temperature , 1e-5 )
270- batch_size , vocab_size = logits .size ()
271- if top_k < vocab_size :
272- values , _ = torch .topk (logits , top_k , dim = - 1 )
273- min_values = values [:, - 1 ].unsqueeze (- 1 ).expand (batch_size , vocab_size )
274- logits = torch .where (
275- logits < min_values , torch .full_like (logits , float ("-inf" )), logits
276- )
277- case ("top_p" , top_p , temperature ):
278- logits = logits / max (temperature , 1e-5 )
279- if top_p < 1 :
280- sorted_logits , sorted_indices = torch .sort (logits , descending = True , dim = - 1 )
281- cumulative_probs = torch .cumsum (torch .softmax (sorted_logits , dim = - 1 ), dim = - 1 )
282- sorted_indices_to_remove = cumulative_probs > top_p
283- sorted_indices_to_remove [:, 1 :] = sorted_indices_to_remove [:, :- 1 ].clone ()
284- sorted_indices_to_remove [:, 0 ] = 0
285- indices_to_remove = sorted_indices_to_remove .scatter (
286- 1 , sorted_indices , sorted_indices_to_remove
287- )
288- logits = logits .masked_fill (indices_to_remove , float ("-inf" ))
289- case ("top_k_top_p" , top_k , top_p , temperature ):
290- logits = logits / max (temperature , 1e-5 )
291- batch_size , vocab_size = logits .size ()
292- if top_k < vocab_size :
293- values , _ = torch .topk (logits , top_k , dim = - 1 )
294- min_values = values [:, - 1 ].unsqueeze (- 1 ).expand (batch_size , vocab_size )
295- logits = torch .where (
296- logits < min_values , torch .full_like (logits , float ("-inf" )), logits
297- )
298- if top_p < 1 :
299- sorted_logits , sorted_indices = torch .sort (logits , descending = True , dim = - 1 )
300- cumulative_probs = torch .cumsum (torch .softmax (sorted_logits , dim = - 1 ), dim = - 1 )
301- sorted_indices_to_remove = cumulative_probs > top_p
302- sorted_indices_to_remove [:, 1 :] = sorted_indices_to_remove [:, :- 1 ].clone ()
303- sorted_indices_to_remove [:, 0 ] = 0
304- indices_to_remove = sorted_indices_to_remove .scatter (
305- 1 , sorted_indices , sorted_indices_to_remove
306- )
307- logits = logits .masked_fill (indices_to_remove , float ("-inf" ))
308- case ("temperature" , temperature ):
309- logits = logits / max (temperature , 1e-5 )
310- case ("greedy" , None ):
311- # No processing needed for greedy
312- pass
313- return logits
314-
315-
316257def sample (
317258 strategy : Strategy ,
318259 logits : torch .Tensor ,
@@ -327,43 +268,45 @@ def sample(
327268 strategy: Sampling strategy tuple (strategy_name, *params)
328269 logits: Input logits tensor
329270 generator: Optional random generator
330- return_probs: If True, return softmax probabilities
271+ return_probs: If True, return softmax probabilities and log probabilities
331272
332273 Returns:
333- Tuple of (sampled_tokens, softmax_probs)
274+ Tuple of (sampled_tokens, softmax_probs, logprobs )
334275 """
335276 match strategy :
336277 case ("top_k" , top_k , temperature ):
337- tokens , softmax = top_k_sampling_batch (
278+ tokens , softmax , logprobs = top_k_sampling_batch (
338279 logits ,
339280 top_k = top_k ,
340281 temperature = temperature ,
341282 generator = generator ,
342283 )
343284 case ("top_p" , top_p , temperature ):
344- tokens , softmax = top_p_sampling_batch (
285+ tokens , softmax , logprobs = top_p_sampling_batch (
345286 logits ,
346287 top_p = top_p ,
347288 generator = generator ,
348289 temperature = temperature ,
349290 )
350291 case ("top_k_top_p" , top_k , top_p , temperature ):
351- tokens , softmax = top_k_top_p_sampling_batch (
292+ tokens , softmax , logprobs = top_k_top_p_sampling_batch (
352293 logits ,
353294 top_k = top_k ,
354295 top_p = top_p ,
355296 temperature = temperature ,
356297 generator = generator ,
357298 )
358299 case ("temperature" , temperature ):
359- tokens , softmax = temperature_sampling_batch (
300+ tokens , softmax , logprobs = temperature_sampling_batch (
360301 logits ,
361302 temperature = temperature ,
362303 generator = generator ,
363304 )
364305 case ("greedy" , None ):
365- tokens , softmax = greedy_search_sampling_batch (logits , return_probs = return_probs )
366- return tokens , softmax
306+ tokens , softmax , logprobs = greedy_search_sampling_batch (
307+ logits , return_probs = return_probs
308+ )
309+ return tokens , softmax , logprobs
367310
368311
369312GenericStrategyKeyType = TypeVar ("GenericStrategyKeyType" )
@@ -415,12 +358,13 @@ def sample_grouped_strategies(
415358
416359 assert all (strategy == group_key for strategy in strategies ), "group must be consistent"
417360
418- return sample (
361+ tokens , probs , _ = sample (
419362 group_key ,
420363 logits ,
421364 generator = generator ,
422365 return_probs = return_probs ,
423366 )
367+ return tokens , probs
424368
425369
426370class _AcceptSyncCompute :
0 commit comments