@@ -112,7 +112,7 @@ def top_p_sampling_batch(
112112 top_p : float ,
113113 temperature : float ,
114114 generator : Optional [torch .Generator ] = None ,
115- ) -> tuple [torch .Tensor , torch .Tensor , Optional [ torch . Tensor ] ]:
115+ ) -> tuple [torch .Tensor , torch .Tensor ]:
116116 # NB: To be replaced by a more efficient implementation.
117117 return top_k_top_p_sampling_batch (
118118 logits ,
@@ -128,7 +128,7 @@ def temperature_sampling_batch(
128128 * ,
129129 temperature : float ,
130130 generator : Optional [torch .Generator ] = None ,
131- ) -> tuple [torch .Tensor , torch .Tensor , Optional [ torch . Tensor ] ]:
131+ ) -> tuple [torch .Tensor , torch .Tensor ]:
132132 # NB: To be replaced by a more efficient implementation.
133133 return top_k_top_p_sampling_batch (
134134 logits ,
@@ -146,20 +146,7 @@ def top_k_top_p_sampling_batch(
146146 top_p : float ,
147147 temperature : float ,
148148 generator : Optional [torch .Generator ] = None ,
149- ) -> tuple [torch .Tensor , torch .Tensor , Optional [torch .Tensor ]]:
150- """
151- Perform top-k and top-p sampling.
152-
153- Args:
154- logits: Input logits tensor [batch_size, vocab_size]
155- top_k: Top-k value
156- top_p: Top-p (nucleus sampling) value
157- temperature: Temperature for sampling
158- generator: Optional torch random generator
159-
160- Returns:
161- Tuple of (sampled_tokens, softmax_probs)
162- """
149+ ) -> tuple [torch .Tensor , torch .Tensor ]:
163150 logits_dim = logits .dim ()
164151 assert logits_dim == 2 , "logits should be 2D: [batch_size, vocab_size]"
165152 assert temperature > 0 , "non-greedy sampling requires valid temperature"
@@ -212,16 +199,6 @@ def greedy_search_sampling_batch(
212199 * ,
213200 return_probs : bool = True ,
214201) -> tuple [torch .Tensor , Optional [torch .Tensor ]]:
215- """
216- Perform greedy sampling.
217-
218- Args:
219- logits: Input logits tensor
220- return_probs: If True, return softmax probabilities
221-
222- Returns:
223- Tuple of (sampled_tokens, softmax_probs)
224- """
225202 next_tokens = torch .argmax (logits , dim = - 1 )
226203 softmax : Optional [torch .Tensor ] = None
227204 if return_probs :
0 commit comments