Adding Tensor_layout for Tensor parallelism for Autosharding #21792
base: master
Conversation
Summary of Changes

Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request establishes foundational components for tensor parallelism within the JAX backend, crucial for Autosharding. It provides core collective communication primitives such as all_reduce and all_gather.
Code Review
This pull request introduces foundational components for tensor parallelism in Keras, specifically for the JAX backend. It adds all_reduce and all_gather collective operations, which are essential for distributed computations. Additionally, it provides a split_tensor_for_parallelism utility for sharding tensors across devices. The changes are well-tested, covering both even and uneven tensor splitting. My review includes a few suggestions to improve documentation accuracy and code simplicity, and to align with the repository's style guide regarding docstring examples.
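As a quick illustration of the all_gather semantics described above, here is a hedged sketch built directly on jax.lax.all_gather under pmap; it shows the expected behavior, not the PR's actual implementation:

```python
import jax
import numpy as np

n = jax.local_device_count()
data = np.arange(n * 2, dtype="float32").reshape(n, 2)

# Under pmap, each device sees one row of `data`; all_gather then
# reassembles the full array (with a new leading device axis) on
# every participating device.
out = jax.pmap(
    lambda x: jax.lax.all_gather(x, axis_name="batch"),
    axis_name="batch",
)(data)
# out[i] is the complete gathered array, identical on each device i
```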
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Can you rebase to make the tests pass?
Codecov Report

❌ Patch coverage is below target. Additional details and impacted files:

@@ Coverage Diff @@
## master #21792 +/- ##
==========================================
- Coverage 82.66% 82.32% -0.35%
==========================================
Files 577 583 +6
Lines 59453 60364 +911
Branches 9320 9514 +194
==========================================
+ Hits 49148 49694 +546
- Misses 7902 8222 +320
- Partials 2403 2448 +45
Flags with carried forward coverage won't be shown.
def test_all_reduce(self):
    devices = jax.devices()
    num_devices = len(devices)
    input_data = np.ones((num_devices, 2), dtype="float32")
input_data needs to be sharded across the devices to make this test valid.
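One way to make the sharding in such a test observable on a CPU-only machine is to force multiple host devices before jax is first imported; the XLA flag below is real, while the surrounding setup is an illustrative sketch rather than the PR's test code:

```python
import os

# Must be set before jax is imported for the first time in the process.
os.environ.setdefault(
    "XLA_FLAGS", "--xla_force_host_platform_device_count=4"
)

import jax
import numpy as np

devices = jax.devices()
# pmap shards the leading axis of its input: one slice per device.
data = np.ones((len(devices), 2), dtype="float32")
out = jax.pmap(lambda x: x * 2)(data)
```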
def sum_fn(x):
    return backend_dlib.all_reduce(x, op="sum", axis_name="batch")

result_sum = jax.pmap(sum_fn, axis_name="batch")(input_data)
Shouldn't the pmap be part of the all_reduce implementation?
Same for "mean".
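For context, a minimal sketch of how an all_reduce built on jax.lax.psum/pmean behaves when the pmap lives inside the implementation's calling convention; this standalone version is illustrative, not the PR's backend_dlib code:

```python
import jax
import numpy as np

def all_reduce(x, op="sum", axis_name="batch"):
    # Illustrative only: inside a pmap'd function, lax.psum / lax.pmean
    # reduce x elementwise across all devices mapped over axis_name.
    if op == "sum":
        return jax.lax.psum(x, axis_name=axis_name)
    if op == "mean":
        return jax.lax.pmean(x, axis_name=axis_name)
    raise ValueError(f"Unsupported op: {op}")

n = jax.local_device_count()
data = np.ones((n, 2), dtype="float32")
out = jax.pmap(lambda x: all_reduce(x, op="sum"), axis_name="batch")(data)
# every device's slice now holds the elementwise sum across devices
```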
num_devices = len(devices)

input_data = np.arange(num_devices, dtype="float32").reshape(
    num_devices, 1, 1
)
input_data needs to be sharded across the devices to make this test valid.
if hasattr(layer, "_kernel") and layer._kernel is not None:
    kernel_shape = layer._kernel.shape
    if len(kernel_shape) == 2:
        input_dim = kernel_shape[0]
        output_dim = kernel_shape[1]
If it's a keras.layers.Dense layer, this should always work.
Do we actually need anything between lines 40-63?
    isinstance(layer, (layers.Embedding,))
    or "Embedding" in layer.__class__.__name__
):
    if hasattr(layer, "weights"):
weights is a property of Layer, it cannot be missing, remove this if.
for weight in layer.weights:
    if "embedding" in weight.name or "weight" in weight.name:
        key_found = False
        for attr_candidate in [
            "embeddings",
            "position_embeddings",
            "weight",
        ]:
            if getattr(layer, attr_candidate, None) is weight:
                state_rules[f"{full_name}.{attr_candidate}"] = (
                    split_rule(dim=1)
                )
                key_found = True
                break

        if not key_found:
            clean_name = weight.name.split("/")[-1]
            state_rules[f"{full_name}.{clean_name}"] = split_rule(
                dim=1
            )
I'm a bit confused about this. So if it's an embedding, the rule is split_rule(dim=1). Great.
Why do you need all the extra code? Why not do:

state_rules[f"{full_name}.embedding"] = split_rule(dim=1)

and remove lines 145-164.
The {attr_candidate} or the {clean_name} is not reversible. You cannot find the weight back from the name because you removed some stuff.
Also, do you need to keep a reference to the weights in general? Like the kernel for dense?
) in self.tensor_parallel_config.state_rules.items():
    if re.search(p, norm_param_name) and hasattr(a, "dim"):
        sharding_dim = a.dim
        break
This is not safe, you're matching substrings.
For instance if p is "dense.kernel" and norm_param_name is "einsum_dense.kernel", it will match, even though you got the wrong layer.
You want an exact match and manipulating names or paths just makes it harder. You should just identify the variables themselves directly. So you want to key your state_rules by id(variable). For instance, in autoconfig, you just do state_rules[id(layer._kernel)] = split_rule(dim=0). This will tremendously simplify this code and the code in autoconfig.
Optimizers already have a way to map optimizer variables to model variables, you don't need to recreate a different way to do this mapping.
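A sketch of the id()-keyed lookup the reviewer suggests; SplitRule and the stand-in Variable class are hypothetical names for illustration, not Keras APIs:

```python
class SplitRule:
    """Hypothetical rule object: which dimension to shard along."""
    def __init__(self, dim):
        self.dim = dim

class Variable:
    """Stand-in for a Keras variable object."""
    pass

kernel = Variable()
other_kernel = Variable()

# Key rules by object identity instead of reconstructed name strings.
state_rules = {id(kernel): SplitRule(dim=0)}

# Lookup is exact: "einsum_dense.kernel" can never be confused with
# "dense.kernel", because no substring matching is involved.
rule = state_rules.get(id(kernel))
missing = state_rules.get(id(other_kernel))
```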
This change introduces core building blocks for tensor parallelism by adding two key components.
First, it adds crucial collective operations, all_reduce and all_gather, to the JAX backend. These allow multiple devices to synchronize data by summing tensors (like gradients) or gathering individual slices back into a full tensor. Second, it adds the high-level tensor sharding logic (split_tensor_for_parallelism), which uses ops.array_split to intelligently slice large tensors, even unevenly, for distribution across devices. New tests confirm this new parallel logic, including the uneven splitting, works as expected.
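The uneven splitting described here matches numpy's array_split semantics (which Keras's ops.array_split mirrors); a hedged sketch where the function name follows the PR but the body is illustrative:

```python
import numpy as np

def split_tensor_for_parallelism(tensor, num_splits, dim=0):
    # array_split tolerates sizes that don't divide evenly: the first
    # (size % num_splits) pieces each receive one extra element.
    return np.array_split(tensor, num_splits, axis=dim)

parts = split_tensor_for_parallelism(np.arange(10), 3)
# piece sizes: 4, 3, 3
```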
The tests on this PR will pass after PR #21697 is merged.