Trying out logprobs and top logprobs for testing rather than logits. #745


Open · Manan17 wants to merge 4 commits into main

Conversation

@Manan17 (Contributor) commented Jun 5, 2025

Summary

Just testing out logprobs as mentioned in #742.
It works for the models where the logit-based test was failing.
Also tried setting a 1e-1 tolerance for qwen (previously 1), and it passed.

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Comment on lines 1201 to 1206
 assert_verbose_allclose(
-    expected_output["logits"],
-    actual_output["logits"],
+    expected_logprobs,
+    actual_logprobs,
     atol=logits_atol,
-    rtol=logits_rtol,
+    rtol=logits_rtol
 )
Collaborator

We don't have to compare all logprobs

Comment on lines 1207 to 1217
k = 5
exp_topk_vals, _ = torch.topk(expected_logprobs, k, dim=-1)
act_topk_vals, _ = torch.topk(actual_logprobs, k, dim=-1)

# Compare top-k logprobs with tolerance (ignoring order)
max_diff = torch.max(torch.abs(exp_topk_vals - act_topk_vals)).item()
print(f"Top-{k} logprobs max diff: {max_diff:.6f}")
assert torch.all(
    torch.abs(exp_topk_vals - act_topk_vals)
    < (logits_atol + logits_rtol * torch.abs(exp_topk_vals))
), (
    f"Top-{k} logprobs are not all close (atol={logits_atol}). "
    f"Max diff: {max_diff:.6f}"
)
Collaborator

Make it a test util
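A minimal sketch of what such a util could look like (the name check_logprobs follows the suggestion later in this thread; the exact signature and the k=5 default are assumptions):

import torch

def check_logprobs(actual_logprobs, expected_logprobs, atol, rtol, k=5):
    # Compare only the top-k logprob values at each position. topk() returns
    # values sorted in descending order, so an elementwise comparison matches
    # rank-for-rank and ignores which vocab ids produced them.
    exp_topk_vals, _ = torch.topk(expected_logprobs, k, dim=-1)
    act_topk_vals, _ = torch.topk(actual_logprobs, k, dim=-1)
    diff = torch.abs(exp_topk_vals - act_topk_vals)
    tol = atol + rtol * torch.abs(exp_topk_vals)
    assert torch.all(diff < tol), (
        f"Top-{k} logprobs are not all close (atol={atol}, rtol={rtol}). "
        f"Max diff: {diff.max().item():.6f}"
    )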

Comment on lines +899 to 900
1e-1, # 1e-1
1e-1, # 1e-2
Collaborator

After removing the all-logprobs comparison, we can try setting it lower.
sglang only has atol and sets it to 5e-2 (decode_tolerance).
verl sets (atol, rtol) = (1e-2, 1e-5), but that's on the mean of all logprobs, not on top-k.
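Roughly, the two styles mentioned above look like this (a paraphrased sketch of the checks as described, not the actual sglang/verl code):

# sglang-style: absolute tolerance only, e.g. decode_tolerance = 5e-2
assert torch.max(torch.abs(actual_logprobs - expected_logprobs)).item() < 5e-2

# verl-style: compare the mean over all logprobs rather than top-k values
torch.testing.assert_close(
    actual_logprobs.mean(),
    expected_logprobs.mean(),
    atol=1e-2,
    rtol=1e-5,
)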

@Manan17 (Contributor, Author) commented Jun 5, 2025

It does not work with a lower tolerance.
For gemma3, it passes when atol=1e-1 and rtol=1.

Contributor Author

I tested this with fp32; it fails for most of the models where the old logit-checking logic passes.

Collaborator

Since we are comparing values in log-space, the tolerance here effectively acts as a relative tolerance on the underlying probabilities.
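To spell that out with a quick illustration (not part of the PR): an absolute tolerance eps on logprobs bounds the ratio of the underlying probabilities, since |log p - log q| <= eps is equivalent to exp(-eps) <= p/q <= exp(eps), which for small eps is roughly an eps relative error on p.

import math

# Illustrative only: an atol of 5e-2 in log-space permits ~5% relative error
# in probability space, because exp(0.05) ≈ 1.051.
eps = 5e-2
p = 0.30
q = p * math.exp(eps)                   # probability at the edge of tolerance
print(abs(math.log(q) - math.log(p)))   # == eps (up to float error)
print(q / p - 1.0)                      # ≈ 0.0513, i.e. ~5% relative difference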

Comment on lines +1193 to +1194
actual_logprobs = torch.nn.functional.log_softmax(actual_output["logits"], dim=-1)
expected_logprobs = torch.nn.functional.log_softmax(expected_output["logits"], dim=-1)
@Tcc0403 (Collaborator) commented Jun 7, 2025

Make log_softmax() and topk() a util function, so we can call it in run_mini_model() to avoid storing all logits


actual_logprobs = torch.nn.functional.log_softmax(actual_output["logits"], dim=-1)
expected_logprobs = torch.nn.functional.log_softmax(expected_output["logits"], dim=-1)
check_logprobs(actual_logprobs, expected_logprobs, atol=logits_atol, rtol=logits_rtol)
@Tcc0403 (Collaborator) commented Jun 7, 2025

Assuming we have top logprobs calculated and stored, we only need to call
assert_verbose_allclose(actual_top_logprobs, expected_top_logprobs, atol=logprob_atol, rtol=logprob_rtol).
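A minimal sketch of how that could fit together (the helper name and the run_mini_model() wiring below are assumptions for illustration, not the repo's actual API):

import torch

def top_logprobs_from_logits(logits: torch.Tensor, k: int = 5) -> torch.Tensor:
    # Convert logits to logprobs, then keep only the k largest values per
    # position, so the full [batch, seq, vocab] tensor never has to be stored.
    logprobs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
    top_vals, _ = torch.topk(logprobs, k, dim=-1)
    return top_vals

# Inside run_mini_model() (hypothetical):
#     output["top_logprobs"] = top_logprobs_from_logits(output.pop("logits"))
# after which the test reduces to the single assert_verbose_allclose call above.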
