Added top n log probs #2262
base: main
Changes from all commits: 2cbc985, 724f80b, d6286fd, 14247a5, bc4981c, d22a9b7, a8e4719, 4d3ebf1, 4a7e96e, af4cb7a, 144e017, 6a0935b, ffe1d77, 300c130
```diff
@@ -1749,7 +1749,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T
     def calculate_log_probs(
         self, logits: Tensor, new_tokens: Tensor, only_last_token_logits: Optional[bool] = False
-    ) -> List[List[float]]:
+    ) -> Tuple[List[List[float]], Tensor]:
```
Contributor: This is a breaking change; will this break any downstream code, or is this private?

Contributor: This is a good point that I missed. Subscribing to this thread; I am also very interested in the answer to this, since I intended to optimize the @wdykas I think you are the main author of this method. Do you know how much (if any) downstream dependence there should be on it?

Contributor (Author): This is only used by internal code, so it shouldn't be a problem.
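Since the thread flags the new return type as a breaking change, here is a minimal sketch (illustrative names only, not code from this repo) of why old call sites must switch to tuple unpacking:

```python
from typing import List, Tuple

# Stand-in for the old single-value return (illustrative only).
def calculate_log_probs_v1(selected: List[float]) -> List[List[float]]:
    return [[lp] for lp in selected]

# Stand-in for the new tuple return introduced by this PR's signature change.
def calculate_log_probs_v2(selected: List[float]) -> Tuple[List[List[float]], List[float]]:
    log_probs = selected  # in the real code this is the full log-prob tensor
    return [[lp] for lp in selected], log_probs

# Old call sites assign a single value...
lists_v1 = calculate_log_probs_v1([-0.5, -1.2])
# ...and must be updated to unpack, or they would silently bind the whole tuple:
lists_v2, raw = calculate_log_probs_v2([-0.5, -1.2])
print(lists_v2)  # [[-0.5], [-1.2]]
```

Because the method is only called internally (per the author's reply), all such call sites live inside the repo and can be updated in the same PR.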
```diff
         """Calculate log probs for all active requests and return them.

         TODO: @wdykas support top-n log probs.
@@ -1762,14 +1762,16 @@ def calculate_log_probs(
         Returns:
             List of lists where each inner list contains log probs for a request in the
             same order as the active requests (from paused_request_count to total_request_count).
+            log_probs (Tensor): Used to compute top n logprobs later if required.
         """

         # Calculate log_probs (sequence_length x vocab_size)
         log_probs = F.log_softmax(logits.squeeze(0).float(), dim=-1)

         if only_last_token_logits or self.is_decode_only():
             seq_idx = torch.arange(len(new_tokens), dtype=torch.int32, device=logits.device)
             selected_log_probs = log_probs[seq_idx, new_tokens]
-            return [[lp] for lp in selected_log_probs.flatten().tolist()]
+            return [[lp] for lp in selected_log_probs.flatten().tolist()], log_probs

         # Get the selected token ids for all tokens.
         # We shift the active token window left by one to remove the first prompt token for
@@ -1812,7 +1814,7 @@ def calculate_log_probs(
         # Convert each log prob tensor into a list
-        return [lp.tolist() for lp in selected_log_probs_list]
+        return [lp.tolist() for lp in selected_log_probs_list], log_probs

     def get_kvcache_utilization_stats(self) -> dict:
         """Compute KV cache buffer utilization stats for the current step.
```
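The PR threads the full `log_probs` tensor out of `calculate_log_probs` so that top-n log probs can be derived later. A sketch (not code from this PR; the helper name is hypothetical) of how that tensor could be reduced to top-n candidates with `torch.topk`:

```python
import torch
import torch.nn.functional as F

def top_n_from_log_probs(log_probs: torch.Tensor, n: int):
    """Pick the n most likely token ids and their log probs per position.

    `log_probs` is assumed to be (sequence_length, vocab_size), matching the
    tensor returned by the patched calculate_log_probs.
    """
    # topk returns values sorted in descending order along the last dim.
    top_vals, top_ids = torch.topk(log_probs, k=n, dim=-1)
    return top_vals, top_ids

torch.manual_seed(0)
logits = torch.randn(1, 4, 32)  # (batch=1, sequence_length, vocab_size)
# Same normalization as in the diff: squeeze the batch dim, log-softmax over vocab.
log_probs = F.log_softmax(logits.squeeze(0).float(), dim=-1)
top_vals, top_ids = top_n_from_log_probs(log_probs, n=3)
print(top_vals.shape)  # torch.Size([4, 3])
```

Keeping this as a separate post-processing step means the common single-log-prob path stays untouched and the extra `topk` cost is only paid when a request actually asks for top-n log probs.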