-
Notifications
You must be signed in to change notification settings - Fork 360
sharding validator Improvements #1810
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally this looks great - however we generally want to assert that params are fully sharded (not just sharded by at least one axis, but sharded by all of them)
|
||
# If the parameter is not sharded on all of the target axes, it's considered "problematic." | ||
if not is_sharded_on_all_target_axis: | ||
unsharded_params_total_size += p_leaf.size # Add to total unsharded parameter size |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note this is different logic than what existed originally --
Before we would sum up the local shard sizes and compare to the ideal fully sharded size.
Here we are summing up all non-fully params and comparing that to the number of params
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! This is a great feature to have!
Also, could you attach an example of logging with assertion failure of insufficient sharding in the PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
sharding validator improvement
Description
This change significantly enhances the sharding validation utility by providing detailed, actionable error messages. The previous version only indicates that some matrices were not fully sharded, but it failed to specify which matrices were unsharded by which axis. This forced developers into a tedious manual search to find the unsharded matrices. This improvement is able to show a more readable error message with large matrices that has not been fully sharded by which axes.
Specifically, for each replicated tensor, the error now shows:
Which Matrix: The full name of the parameter (e.g., layers.10.attention.query.kernel).
Its Current (Problematic) Sharding Status: The parameter's current PartitionSpec is displayed (e.g., PartitionSpec(None, 'expert')), making it clear why it's considered replicated on the main data-parallel axes.
Which Axes are not sharded: A list of available mesh axes that could have been used for sharding (e.g., could be sharded on: ['fsdp', 'tensor']), giving the developer an immediate hint for a valid solution.
FIXES: b/367055330
Tests
This change was validated using a comprehensive unit test suite, TestAssertParamsSufficientlySharded, which ensures the new validation logic is both correct and robust. The tests cover a wide variety of sharding configurations to confirm that the assertions pass and fail as expected.
Test Coverage
The test suite validates the following key scenarios:
Correctly Passing Scenarios:
Fully Sharded: A tensor that is fully sharded across all available primary mesh axes passes the assertion.
Sufficiently Sharded: A tensor sharded on at least one valid target axis (e.g., fsdp) passes, even if its other dimensions are replicated.
Complex Mesh: A tensor correctly sharded on a valid axis within a complex, multi-dimensional mesh also passes.
Correctly Failing Scenarios:
Completely Unsharded: A fully replicated parameter with no sharding specification correctly triggers an AssertionError.
Incorrectly Sharded: A tensor that is sharded, but not along any of the required target axes (e.g., sharded on 'sequence' but not on 'fsdp'), correctly fails.
Mixed Sharding: A model containing both correctly sharded and fully replicated parameters fails when the total size of replicated parameters exceeds the specified tolerance.
Each failing test case implicitly verifies that the raised AssertionError contains the new, detailed diagnostic message, including the names of problematic tensors and the sharding axes they could be mapped to. The tests also confirm that the companion utility, get_formatted_sharding_annotations, runs without error in all configurations.
Checklist
Before submitting this PR, please make sure (put X in square brackets):