kahfizulkifli commented on Jun 17, 2025

In transformers-neuronx, it is important to ensure that the distributed machine learning pipeline has the same meaning (semantic equivalence) as the original single-device pipeline. However, we found that one of the distributed operators used in these pipelines (all-to-all) doesn't fulfill the behavior defined in the XLA documentation, which causes its outputs to be incorrect. The bug shows up when running the model with the GQA shard-over-batch feature across multiple devices.

According to the all-to-all semantics defined in XLA, the operation transposes the sharding of a tensor from one dimension to another. We found that in transformers-neuronx, concatenating the per-device outputs along the new sharding dimension does not reproduce the original (unsharded) tensor.
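For reference, here is a minimal host-side NumPy sketch of the documented semantics (the `reference_all_to_all` helper and the shapes are our own illustration, not the transformers-neuronx API): each core splits its shard into tp_degree blocks along split_dim, sends block i to core i, and concatenates the received blocks along concat_dim, so the global tensor ends up sharded along split_dim.

```python
import numpy as np

def reference_all_to_all(shards, split_dim, concat_dim):
    """Host-side model of the XLA array AllToAll: every participant splits its
    shard into len(shards) blocks along split_dim, sends block i to participant
    i, and concatenates the blocks it receives along concat_dim."""
    n = len(shards)
    blocks = [np.split(s, n, axis=split_dim) for s in shards]
    return [np.concatenate([blocks[k][i] for k in range(n)], axis=concat_dim)
            for i in range(n)]

# Global tensor sharded over dim 1 across two cores.
x = np.arange(4 * 8).reshape(4, 8)
shards = np.split(x, 2, axis=1)                                    # each (4, 4)
outputs = reference_all_to_all(shards, split_dim=0, concat_dim=1)  # each (2, 8)

# After the op the tensor should be sharded over dim 0, so concatenating the
# per-core outputs along dim 0 must reproduce the original tensor. This is the
# property that the current transformers-neuronx all-to-all violates.
assert np.array_equal(np.concatenate(outputs, axis=0), x)
```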

Steps to reproduce the bug:

  1. Set up the transformers-neuronx library on an AWS Neuron Trainium machine
  2. Download the pretrained Llama-3 model from Hugging Face: https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/tree/main
  3. Edit config.json in the Llama-3 folder to use a single layer ("num_attention_heads": 4, "num_hidden_layers": 1, "num_key_value_heads": 1)
  4. Run the provided test script in two modes, baseline and distributed, with collectives layout BSH:
# Baseline mode
python llama_driver.py run <model_path_folder> --tp_degree=1 --gqa=shard-over-batch --debug > tp_1.txt
# Distributed mode
python llama_driver.py run <model_path_folder> --tp_degree=2 --gqa=shard-over-batch --debug > tp_2.txt
  5. Check that the outputs of the two runs differ from one another

Here are some of the generated outputs for the baseline and distributed runs.

split_dim=0 and concat_dim=1 case

Baseline:
dot.53
torch.Size([256, 4096])
tensor([-9.7422e-02,  2.6171e-01,  2.9402e-01, -1.4613e-01,  4.4988e-01,
        -5.7943e-01, -5.1829e-02,  8.0281e-02,  1.1778e+00,  8.7789e-02,
        -8.3667e-01, -1.5258e+00,  1.0470e+00,  5.0003e-01, -2.6673e-01,
         8.2140e-01,  1.2367e-01, -2.5464e-01, -4.1388e-01,  6.9485e-01,
         1.1211e-01, -9.0016e-01, -7.1786e-01,  1.0526e+00, -6.7589e-01,
        -1.9913e-01, -3.2404e-01,  1.4165e+00,  4.8524e-02, -1.0322e+00,

Distributed:

transpose.59
torch.Size([256, 4096])
tensor([-9.7422e-02,  1.3232e-01,  6.9574e-01, -1.3009e-01,  3.3184e-01,
         5.1237e-01, -7.2935e-02,  3.9345e-01,  7.3009e-02,  1.0174e-01,
         6.9550e-01,  5.0049e-01,  6.3394e-03,  5.2025e-02,  6.7904e-02,
        -1.4282e-01, -9.7422e-02,  1.3232e-01,  6.9574e-01, -1.3009e-01,
         5.3184e-01,  2.1237e-01, -7.2935e-02,  3.9345e-01,  7.3009e-02,
         1.0174e-01,  5.9550e-01,  5.0049e-01,  6.3394e-03,  5.2025e-02,
         7.7904e-02, -1.4282e-01, -9.7422e-02,  1.3232e-01,  6.9574e-01,

To fix this layout transformation, we propose changing the reshape-transpose sequences around the all-to-all operation for the different split_dim and concat_dim combinations, adding a split of the axis along split_dim by tp_degree.

In the split_dim=0, concat_dim=1 case, after the all-to-all operation we split the result into (tp_degree, a // tp_degree, b) to recover the correct groupings received from the different cores. We then swap dimensions 0 and 1 and reshape so that the sharding dimension comes first (see the sketch below).
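A minimal host-side NumPy sketch of this post-processing step, under the assumption (ours, not taken from the transformers-neuronx code) that the raw exchange delivers the tp_degree received blocks stacked contiguously along dimension 0; the shapes are illustrative only:

```python
import numpy as np

tp_degree = 2
A, B = 4, 6                       # global shape
a, b = A, B // tp_degree          # per-core shard shape before the op

x = np.arange(A * B).reshape(A, B)
shards = np.split(x, tp_degree, axis=1)            # core k holds x[:, k*b:(k+1)*b]

# Assumed raw exchange: each core splits its shard along dim 0, sends block i
# to core i, and the received blocks arrive stacked along dim 0.
blocks = [np.split(s, tp_degree, axis=0) for s in shards]
received = [np.concatenate([blocks[k][i] for k in range(tp_degree)], axis=0)
            for i in range(tp_degree)]             # each (a, b)

# Proposed post-processing for split_dim=0, concat_dim=1: regroup the received
# blocks as (tp_degree, a // tp_degree, b), swap dims 0 and 1, and flatten so
# the blocks end up laid out along the concat dimension.
fixed = [r.reshape(tp_degree, a // tp_degree, b)
          .transpose(1, 0, 2)
          .reshape(a // tp_degree, tp_degree * b) for r in received]

# The tensor is now sharded along dim 0; stitching the per-core outputs back
# together along dim 0 reproduces the original tensor.
assert np.array_equal(np.concatenate(fixed, axis=0), x)
```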

In the split_dim=1, concat_dim=0 case, before the all-to-all operation we split the input into (a, tp_degree, b // tp_degree) to form the correct groupings before communicating with the different cores. We then swap dimensions 0 and 1 and reshape so that the sharding dimension comes first (see the sketch below).
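A matching host-side NumPy sketch of this pre-processing step, under the same assumption about the raw exchange; again the shapes are illustrative only:

```python
import numpy as np

tp_degree = 2
A, B = 4, 6                       # global shape
a, b = A // tp_degree, B          # per-core shard shape before the op

x = np.arange(A * B).reshape(A, B)
shards = np.split(x, tp_degree, axis=0)            # core k holds x[k*a:(k+1)*a, :]

# Proposed pre-processing for split_dim=1, concat_dim=0: regroup each shard as
# (a, tp_degree, b // tp_degree), swap dims 0 and 1, and flatten so the blocks
# destined for each core sit contiguously along the leading dimension.
prepped = [s.reshape(a, tp_degree, b // tp_degree)
            .transpose(1, 0, 2)
            .reshape(tp_degree * a, b // tp_degree) for s in shards]

# Assumed raw exchange: each core splits the prepared tensor along dim 0,
# sends block i to core i, and the received blocks arrive stacked along dim 0.
blocks = [np.split(p, tp_degree, axis=0) for p in prepped]
received = [np.concatenate([blocks[k][i] for k in range(tp_degree)], axis=0)
            for i in range(tp_degree)]             # each (A, b // tp_degree)

# The tensor is now sharded along dim 1; stitching the per-core outputs back
# together along dim 1 reproduces the original tensor.
assert np.array_equal(np.concatenate(received, axis=1), x)
```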

If we run the script again with the updated all_to_all, the results of the two runs become equal to each other:

Baseline:

dot.53
torch.Size([256, 4096])
tensor([-9.7422e-02,  2.6171e-01,  2.9402e-01, -1.4613e-01,  4.4988e-01,
        -5.7943e-01, -5.1829e-02,  8.0281e-02,  1.1778e+00,  8.7789e-02,
        -8.3667e-01, -1.5258e+00,  1.0470e+00,  5.0003e-01, -2.6673e-01,
         8.2140e-01,  1.2367e-01, -2.5464e-01, -4.1388e-01,  6.9485e-01,
         1.1211e-01, -9.0016e-01, -7.1786e-01,  1.0526e+00, -6.7589e-01,
        -1.9913e-01, -3.2404e-01,  1.4165e+00,  4.8524e-02, -1.0322e+00,

Distributed:

transpose.59
torch.Size([256, 4096])
tensor([-9.7422e-02,  1.3232e-01,  6.9574e-01, -1.3009e-01,  3.3184e-01,
         2.1237e-01, -7.2935e-02,  3.9345e-01,  7.3009e-02,  1.0174e-01,
         6.9550e-01,  5.0049e-01,  6.3394e-03,  5.2025e-02,  6.7904e-02,
        -1.4282e-01, -9.7422e-02,  1.3232e-01,  6.9574e-01, -1.3009e-01,
         3.3184e-01,  2.1237e-01, -7.2935e-02,  3.9345e-01,  7.3009e-02,
         1.0174e-01,  5.9550e-01,  5.0049e-01,  6.3394e-03,  5.2025e-02,
         8.7904e-02, -1.4282e-01, -9.7422e-02,  1.3232e-01,  6.9574e-01,
