sharding validator Improvements #1810

Merged
merged 1 commit into from
Jun 13, 2025
Conversation

wei879-100
Collaborator

@wei879-100 wei879-100 commented Jun 6, 2025

Description

This change significantly enhances the sharding validation utility by providing detailed, actionable error messages. The previous version only indicated that some matrices were not fully sharded; it did not say which matrices were unsharded or along which axes. This forced developers into a tedious manual search for the unsharded matrices. The new error message is more readable: it lists each large matrix that is not fully sharded, together with the axes it is not sharded on.

Specifically, for each replicated tensor, the error now shows:

  • Which Matrix: The full name of the parameter (e.g., layers.10.attention.query.kernel).
  • Its Current (Problematic) Sharding Status: The parameter's current PartitionSpec is displayed (e.g., PartitionSpec(None, 'expert')), making it clear why it's considered replicated on the main data-parallel axes.
  • Which Axes Are Not Sharded: A list of available mesh axes that could have been used for sharding (e.g., could be sharded on: ['fsdp', 'tensor']), giving the developer an immediate hint for a valid solution.
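
The three fields above could be assembled roughly as follows. This is a minimal, framework-free sketch, not the code from this PR: the helper names `axes_used` and `describe_unsharded` are hypothetical, and a partition spec is modeled as a plain tuple whose entries are None, one axis name, or a tuple of axis names, mirroring how a PartitionSpec is written.

```python
from typing import List, Optional, Sequence, Set, Tuple, Union

# Assumption: one spec entry per tensor dimension, as in a PartitionSpec.
SpecEntry = Union[None, str, Tuple[str, ...]]


def axes_used(spec: Sequence[SpecEntry]) -> Set[str]:
    """Collect every mesh axis name that appears anywhere in the spec."""
    used: Set[str] = set()
    for entry in spec:
        if entry is None:
            continue  # this dimension is replicated
        if isinstance(entry, str):
            used.add(entry)
        else:
            used.update(entry)  # several mesh axes share one dimension
    return used


def describe_unsharded(name: str, spec: Sequence[SpecEntry],
                       target_axes: Sequence[str]) -> Optional[str]:
    """Return one diagnostic line, or None if every target axis is used."""
    used = axes_used(spec)
    missing: List[str] = [axis for axis in target_axes if axis not in used]
    if not missing:
        return None
    return f"{name}: spec={tuple(spec)!r}, could be sharded on: {missing}"


if __name__ == "__main__":
    print(describe_unsharded("layers.10.attention.query.kernel",
                             (None, "expert"), ["fsdp", "tensor"]))
```

Returning None for a fully sharded parameter lets a caller collect only the offending entries and join them into a single assertion message.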

FIXES: b/367055330

Tests

This change was validated using a comprehensive unit test suite, TestAssertParamsSufficientlySharded, which ensures the new validation logic is both correct and robust. The tests cover a wide variety of sharding configurations to confirm that the assertions pass and fail as expected.

Test Coverage
The test suite validates the following key scenarios:
Correctly Passing Scenarios:
  • Fully Sharded: A tensor that is fully sharded across all available primary mesh axes passes the assertion.
  • Sufficiently Sharded: A tensor sharded on at least one valid target axis (e.g., fsdp) passes, even if its other dimensions are replicated.
  • Complex Mesh: A tensor correctly sharded on a valid axis within a complex, multi-dimensional mesh also passes.
Correctly Failing Scenarios:
  • Completely Unsharded: A fully replicated parameter with no sharding specification correctly triggers an AssertionError.
  • Incorrectly Sharded: A tensor that is sharded, but not along any of the required target axes (e.g., sharded on 'sequence' but not on 'fsdp'), correctly fails.
  • Mixed Sharding: A model containing both correctly sharded and fully replicated parameters fails when the total size of replicated parameters exceeds the specified tolerance.
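
The pass and fail scenarios listed here can be illustrated with a small stand-in check. This is a sketch under stated assumptions rather than the validator from this PR: `check_sharding` is a hypothetical name, parameters are modeled as (element count, spec tuple) pairs, and the rule applied is the one described in the scenarios (a parameter counts as replicated when it uses none of the target axes).

```python
from typing import Dict, Sequence, Set, Tuple

# Assumption: name -> (number of elements, spec tuple of None / axis name / tuple of names).
Params = Dict[str, Tuple[int, tuple]]


def _axes_in(spec: tuple) -> Set[str]:
    """Mesh axis names mentioned anywhere in the spec."""
    axes: Set[str] = set()
    for entry in spec:
        if isinstance(entry, str):
            axes.add(entry)
        elif entry is not None:
            axes.update(entry)
    return axes


def check_sharding(params: Params, target_axes: Sequence[str],
                   tolerance: float) -> None:
    """Raise AssertionError if too large a share of elements is replicated."""
    targets = set(target_axes)
    total = sum(size for size, _ in params.values())
    # A parameter is treated as replicated if it uses none of the target axes.
    replicated = {name: (size, spec) for name, (size, spec) in params.items()
                  if not (_axes_in(spec) & targets)}
    replicated_size = sum(size for size, _ in replicated.values())
    if total > 0 and replicated_size / total > tolerance:
        details = "; ".join(f"{name} spec={spec!r}"
                            for name, (_, spec) in sorted(replicated.items()))
        raise AssertionError(
            f"{replicated_size} of {total} elements are replicated "
            f"on the target axes: {details}")


if __name__ == "__main__":
    mixed = {"good": (900, ("fsdp", None)), "bad": (100, (None, None))}
    check_sharding(mixed, ["fsdp"], tolerance=0.5)  # passes: 10% is within 50%
    try:
        check_sharding(mixed, ["fsdp"], tolerance=0.02)
    except AssertionError as err:
        print(err)
```

Expressing the tolerance as a fraction of all elements means small replicated parameters, such as biases, do not trip the check on their own.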

Each failing test case implicitly verifies that the raised AssertionError contains the new, detailed diagnostic message, including the names of problematic tensors and the sharding axes they could be mapped to. The tests also confirm that the companion utility, get_formatted_sharding_annotations, runs without error in all configurations.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

Collaborator

@gobbleturk gobbleturk left a comment

Generally this looks great - however we generally want to assert that params are fully sharded (not just sharded by at least one axis, but sharded by all of them)


# If the parameter is not sharded on all of the target axes, it's considered "problematic."
if not is_sharded_on_all_target_axis:
  unsharded_params_total_size += p_leaf.size  # Add to total unsharded parameter size
Collaborator

Note this is different logic than what existed originally --

Before we would sum up the local shard sizes and compare to the ideal fully sharded size.

Here we are summing up all params that are not fully sharded and comparing that to the total number of params
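
To make the difference concrete, the two measures can be computed side by side on a toy model. The exact formulas used before and after this PR are not shown in this thread, so both functions below are assumptions that follow the wording of the comment; the names and the (elements, ways split) representation are hypothetical.

```python
from typing import Dict, Tuple

# Assumption: name -> (total elements, number of devices the parameter is split across).
Params = Dict[str, Tuple[int, int]]


def local_vs_ideal_excess(params: Params, num_devices: int) -> float:
    """Earlier style: elements held per device relative to the fully sharded ideal."""
    local = sum(size / ways for size, ways in params.values())
    ideal = sum(size for size, _ in params.values()) / num_devices
    return local / ideal - 1.0


def not_fully_sharded_fraction(params: Params, num_devices: int) -> float:
    """Newer style: share of elements in parameters not split across every device."""
    total = sum(size for size, _ in params.values())
    partial = sum(size for size, ways in params.values() if ways < num_devices)
    return partial / total


if __name__ == "__main__":
    # 8 devices: "a" is fully sharded, "b" is only split two ways.
    toy = {"a": (800, 8), "b": (80, 2)}
    print(local_vs_ideal_excess(toy, 8))       # 140 local vs 110 ideal
    print(not_fully_sharded_fraction(toy, 8))  # 80 of 880 elements
```

On this toy input the first measure reports about 27% excess memory per device while the second reports about 9% of elements as not fully sharded, so the same tolerance value would behave differently under the two definitions.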

@gobbleturk gobbleturk assigned RissyRan and unassigned gobbleturk Jun 11, 2025
@wei879-100 wei879-100 removed their assignment Jun 11, 2025
Collaborator

@RissyRan RissyRan left a comment

LGTM! This is a great feature to have!

Also, could you attach an example of logging with assertion failure of insufficient sharding in the PR?

Collaborator

@RissyRan RissyRan left a comment

Thanks!

sharding validator improvement
@copybara-service copybara-service bot merged commit 244a071 into main Jun 13, 2025
18 checks passed
@copybara-service copybara-service bot deleted the weifan-shard branch June 13, 2025 21:13
3 participants