-
Notifications
You must be signed in to change notification settings - Fork 347
Trying out logprobs and top logprobs for testing rather than logits. #745
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
assert_verbose_allclose( | ||
expected_output["logits"], | ||
actual_output["logits"], | ||
expected_logprobs, | ||
actual_logprobs, | ||
atol=logits_atol, | ||
rtol=logits_rtol, | ||
rtol=logits_rtol | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't have to compare all logprobs
k = 5 | ||
exp_topk_vals, _ = torch.topk(expected_logprobs, k, dim=-1) | ||
act_topk_vals, _ = torch.topk(actual_logprobs, k, dim=-1) | ||
|
||
# Compare top-k logprobs with tolerance (ignoring order) | ||
max_diff = torch.max(torch.abs(exp_topk_vals - act_topk_vals)).item() | ||
print(f"Top-{k} logprobs max diff: {max_diff:.6f}") | ||
assert torch.all(torch.abs(exp_topk_vals - act_topk_vals) < (logits_atol + logits_rtol * torch.abs(exp_topk_vals))), ( | ||
f"Top-{k} logprobs are not all close (atol={logits_atol}). " | ||
f"Max diff: {max_diff:.6f}" | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make it a test util
1e-1, # 1e-1 | ||
1e-1, # 1e-2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After removing all logprobs comparison, we can try setting it lower.
sglang only has atol and sets it to 5e-2 (decode_tolerance)
verl sets (atol, rtol) = (1e-2, 1e-5)
, but it's mean of all logprobs not topk
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does not work with lower tolerance.
For gemma3, it passes when atol=1e-1 and rtol=1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tested this out with fp32, it fails for most of the models where old logic for checking the logits is passing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we are comparing values in log-space, the total tolerance here is actually relative tolerance.
actual_logprobs = torch.nn.functional.log_softmax(actual_output["logits"], dim=-1) | ||
expected_logprobs = torch.nn.functional.log_softmax(expected_output["logits"], dim=-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make log_softmax()
and topk()
a util function, so we can call it in run_mini_model()
to avoid storing all logits
|
||
actual_logprobs = torch.nn.functional.log_softmax(actual_output["logits"], dim=-1) | ||
expected_logprobs = torch.nn.functional.log_softmax(expected_output["logits"], dim=-1) | ||
check_logprobs(actual_logprobs,expected_logprobs, atol=logits_atol,rtol=logits_rtol) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assuming we have top logprobs calculated and stored, we only need to call
assert_verbose_allclose(actual_top_logprobs, expected_top_logprobs, atol=logprob_atol, rtol=logprob_rtol)
.
Summary
Just testing out logprobs as mentioned in #742
It worked for the models where the test using logits was not working.
Also, tried to setup 1e-1 tolerance for qwen (previously 1) and it passed.
Testing Done
make test
to ensure correctnessmake checkstyle
to ensure code stylemake test-convergence
to ensure convergence