Incorrect all-to-all layout transformation implementation #108
In transformers-neuronx, it is important to ensure that the distributed machine learning pipeline has the same meaning (semantic equivalence) as the original single-device pipeline. However, we found that one of the distributed operators used in these pipelines (all-to-all) doesn't fulfill the intended behavior defined in the XLA documentation, which causes the outputs of this operator to be incorrect. The bug shows up when running a model with the GQA shard-over-batch feature across multiple devices.
According to the all-to-all semantics defined in XLA, the operation transposes the sharding of a tensor from one dimension to another. We found that in transformers-neuronx, concatenating the per-device shards back along the intended sharded dimension does not reproduce the original single-device tensor.
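For reference, the sketch below simulates the array all-to-all semantics from the XLA documentation with plain NumPy (no Neuron devices involved); `simulate_all_to_all` and the tensor sizes are illustrative assumptions, not transformers-neuronx code. It demonstrates the property the bug violates: after the collective, concatenating the per-core results along the new sharded dimension must rebuild the original tensor.

```python
import numpy as np

def simulate_all_to_all(local_tensors, split_dim, concat_dim, tp_degree):
    """Reference semantics: core c splits its local tensor into tp_degree blocks
    along split_dim, sends block j to core j, and concatenates the blocks it
    receives along concat_dim in peer order."""
    blocks = [np.split(t, tp_degree, axis=split_dim) for t in local_tensors]
    return [np.concatenate([blocks[src][dst] for src in range(tp_degree)], axis=concat_dim)
            for dst in range(tp_degree)]

# Example: an (8, 8) tensor sharded over dim 0 across 4 cores is re-sharded over dim 1.
tp_degree = 4
full = np.arange(64).reshape(8, 8)
shards = np.split(full, tp_degree, axis=0)   # currently sharded along dim 0
resharded = simulate_all_to_all(shards, split_dim=1, concat_dim=0, tp_degree=tp_degree)
# Semantic equivalence: concatenating along the new sharded dimension rebuilds the tensor.
assert np.array_equal(np.concatenate(resharded, axis=1), full)
```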
Steps to reproduce the bug:
Here are some of the generated outputs for the baseline and the distributed runs.
split_dim=0 and concat_dim=1 case
Distributed:
For this all-to-all layout transformation, we propose changing the reshape-transpose sequences around the all-to-all operation for the different split_dim and concat_dim combinations, adding a split of the axis at split_dim into tp_degree groups.
In the split_dim=0, concat_dim=1 case, after the all-to-all operation we reshape the result to (tp_degree, a // tp_degree, b) to recover the correct groupings of the data received from the different cores. We then swap dimensions 0 and 1 and reshape so that the sharded dimension comes first, as sketched below.
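A minimal NumPy sketch of this sequence, assuming (purely for illustration, not the actual HLO builder code) that the raw all-to-all delivers the peer blocks stacked along dimension 0 in peer order:

```python
import numpy as np

tp_degree = 4
a, b = 8, 6                                   # illustrative local sizes

# Assumed per-core output of the raw all-to-all: the block from peer p occupies
# rows [p * a // tp_degree, (p + 1) * a // tp_degree).
received = np.arange(a * b).reshape(a, b)

grouped = received.reshape(tp_degree, a // tp_degree, b)   # expose the peer grouping
swapped = grouped.transpose(1, 0, 2)                       # (a // tp_degree, tp_degree, b)
result = swapped.reshape(a // tp_degree, tp_degree * b)    # peer blocks concatenated along dim 1
```

Here `result[i, p*b:(p+1)*b]` holds row i of the block received from peer p, which matches the XLA output shape (a // tp_degree, tp_degree * b) for split_dim=0, concat_dim=1.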
In the split_dim=1, concat_dim=0 case, before the all-to-all operation we reshape the input to (a, tp_degree, b // tp_degree) to form the correct groupings before communicating across the cores. We then swap dimensions 0 and 1 and reshape so that the sharded dimension comes first, as sketched below.
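Again as an illustrative NumPy sketch, assuming the raw collective splits its input into tp_degree contiguous chunks along dimension 0 and sends chunk p to core p:

```python
import numpy as np

tp_degree = 4
a, b = 6, 8                                   # illustrative local sizes

local = np.arange(a * b).reshape(a, b)

grouped = local.reshape(a, tp_degree, b // tp_degree)      # expose the peer grouping on dim 1
swapped = grouped.transpose(1, 0, 2)                       # (tp_degree, a, b // tp_degree)
prepared = swapped.reshape(tp_degree * a, b // tp_degree)  # rows [p*a, (p+1)*a) go to core p
# `prepared` is what gets fed to the raw all-to-all; with concat_dim=0 the receiving
# core then simply concatenates the incoming blocks along dimension 0.
```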
If we run the script again with the updated all_to_all, the baseline and distributed results now match each other.
Baseline:
Distributed: