feat: infer device_ids and normalize tile assignment #9514


Open · wants to merge 11 commits into master
Conversation

@kvshbg-aws commented Jul 26, 2025

This PR provides complete support for sub-meshing and includes all changes related to the abstraction/wrapper class RFC, the normalize tile_assignment RFC, and the RFC for inferring device_ids for correct device assignment in pjrt_computation_client.

It contains all the changes related to the above RFCs, since all of them need to be in place together for the tests to pass; however, the commits related to the abstraction can be ignored for this PR (there is a separate PR for that).

Once this is reviewed and the previous PR is merged, I can create a new PR that only includes the commits related to inferring device_ids and normalizing the tile assignment.


Explanation of major changes -

denormalized_tile_assignment save/use

We use the lowering_context initialized during the compile call of xla_graph_executor to save a denormalized_tile_assignment containing the device_ids used by the user when creating the Mesh object. This denormalized_tile_assignment is then passed to ScheduleSyncTensorsGraph and pjrt_computation_client::Compile (along with other functions, as seen in the PR) so that the graph executes on the correct set of devices (and not on the "normalized" device ids).
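
To make the distinction concrete, here is a minimal Python sketch, assuming the standard torch_xla SPMD API and the sub-mesh support this PR adds; the normalized/denormalized variables below are illustrative of the concept, not the actual C++ data structures touched by the PR.

import numpy as np
import torch_xla.distributed.spmd as xs

# User creates a sub-mesh on devices 4-7 out of 8 addressable devices.
device_ids = np.array([4, 5, 6, 7])
mesh = xs.Mesh(device_ids, mesh_shape=(4, 1), axis_names=('x', 'y'))

# Conceptually, the tile assignment carried by the HLO sharding is
# normalized to start at 0, while the original device_ids are kept as the
# denormalized assignment and used at execution time.
denormalized_tile_assignment = device_ids.tolist()          # [4, 5, 6, 7]
normalized_tile_assignment = list(range(len(device_ids)))   # [0, 1, 2, 3]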

argument_handles changes

Along with this, we need to change the argument_handles 2D vector, which holds the device buffers for all the arguments/tensors used in a particular graph execution. To explain further, consider the case of 8 addressable devices and a submesh of 4: we define 2 tensors and do some computation (matmul) using those tensors.

  1. Both tensors are sharded using the mark_sharding API (a minimal sketch of this case appears after its explanation)

In this case, the argument_handles should look like the following -

=== argument_handles (2D Matrix) ===
Dimensions: 4 devices (rows) x 2 arguments (cols)
  Device\Arg         Arg[0]         Arg[1]
------------------------------------------
     Dev[0]|0x5643accd9d10 0x5643ac667590 
     Dev[1]|0x5643accda280 0x5643ac667400 
     Dev[2]|0x5643accda7f0 0x5643accc1300 
     Dev[3]|0x5643accdad60 0x5643accc67f0 
=== End argument_handles after ===

i.e. both arguments are sharded and present only on the 4 devices of the sub-mesh, not on all 8 devices
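
For reference, a minimal sketch of this case, assuming the torch_xla SPMD API and the sub-mesh support added in this PR (shapes and partition specs are illustrative):

import numpy as np
import torch
import torch_xla
import torch_xla.distributed.spmd as xs

# 4-device sub-mesh out of 8 addressable devices.
mesh = xs.Mesh(np.array([0, 1, 2, 3]), (4, 1), ('x', 'y'))

xt1 = torch.randn(4, 4, device=torch_xla.device())
xt2 = torch.randn(4, 4, device=torch_xla.device())

# Both tensors are sharded on the sub-mesh, so every sub-mesh device holds
# one buffer per argument -- the 4 x 2 argument_handles matrix shown above.
xs.mark_sharding(xt1, mesh, ('x', 'y'))
xs.mark_sharding(xt2, mesh, ('x', 'y'))

result = torch.matmul(xt1, xt2)
torch_xla.sync()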

  2. One tensor is sharded using the mark_sharding API and one is unsharded

The argument_handles should look like this when first populated for all tensors/arguments.
For a submesh of 0,1,2,3:

=== argument_handles before (2D Matrix) ===
Dimensions: 8 devices (rows) x 2 arguments (cols)
  Device\Arg         Arg[0]         Arg[1]
------------------------------------------
     Dev[0]|0x5643accd9d10 0x5643ac667590 
     Dev[1]|0x5643accda280 0x5643ac667400 
     Dev[2]|0x5643accda7f0 0x5643accc1300 
     Dev[3]|0x5643accdad60 0x5643accc67f0 
     Dev[4]|0x5643accdb2d0        nullptr 
     Dev[5]|0x5643accdba50        nullptr 
     Dev[6]|0x5643accdc1d0        nullptr 
     Dev[7]|0x5643accdc950        nullptr 
=== End argument_handles before ===

For a submesh of 4,5,6,7:

=== argument_handles before (2D Matrix) ===
Dimensions: 8 devices (rows) x 2 arguments (cols)
  Device\Arg         Arg[0]         Arg[1]
------------------------------------------
     Dev[0]|0x5643accd9d10        nullptr 
     Dev[1]|0x5643accda280        nullptr 
     Dev[2]|0x5643accda7f0        nullptr 
     Dev[3]|0x5643accdad60        nullptr 
     Dev[4]|0x5643accdb2d0 0x5643accc67f0 
     Dev[5]|0x5643accdba50 0x5643ac667400 
     Dev[6]|0x5643accdc1d0 0x5643accc1300 
     Dev[7]|0x5643accdc950 0x5643ac667590 
=== End argument_handles before ===

Now, we need to make sure that the argument_handles we pass to the Execute call of the PJRT client contains only valid buffer pointers, i.e. only those of the submesh currently used by the process. Hence, the argument_handles should be changed to -

=== argument_handles after (2D Matrix) ===
Dimensions: 4 devices (rows) x 2 arguments (cols)
  Device\Arg         Arg[0]         Arg[1]
------------------------------------------
     Dev[x]|0x5643accd9d10 0x5643ac667590 
     Dev[x]|0x5643accda280 0x5643ac667400 
     Dev[x]|0x5643accda7f0 0x5643accc1300 
     Dev[x]|0x5643accdad60 0x5643accc67f0 
=== End argument_handles after ===

Thus, these 2 use-cases required a solution that is generic for both the global mesh and local sub-meshing cases. Hence, we are changing the logic of how argument_handles is populated and how it is adjusted for the sub-mesh currently being used for computation/execution.
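
The sketch below illustrates the filtering step in Python; the real change lives in the C++ PJRT computation client, and the helper name and data layout here are purely illustrative.

def filter_argument_handles(argument_handles, denormalized_tile_assignment):
    # argument_handles: one row per addressable device, where each row is a
    # list of per-argument buffers (or None when no shard lives on that device).
    # denormalized_tile_assignment: device ids of the current sub-mesh,
    # e.g. [4, 5, 6, 7].
    submesh = set(denormalized_tile_assignment)
    return [row for device_id, row in enumerate(argument_handles)
            if device_id in submesh]

# Example: with 8 addressable devices and a sub-mesh of [4, 5, 6, 7], the
# rows for devices 0-3 (which may contain nullptr entries) are dropped, so
# Execute only sees valid buffers for the sub-mesh devices.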


New test files added -

  1. test_submesh_zero_indexed.py
    Tests submeshes that start from device 0, using [0,1] for 2-device submeshes.

  2. test_submesh_non_zero_indexed.py
    Tests submeshes that start from non-zero device indices, using [2,3] for 2-device submeshes.

These tests are divided into 3 categories -

Basic Pattern Tests

These tests validate fundamental tensor operations with different sharding and synchronization patterns (a minimal sketch of Pattern 1 follows the list):

Pattern 1: shard both tensors → compute → cpu() → sync()
  • Zero-indexed: test_pattern1_2dev(), test_pattern1_2dev_direct_device()
  • Non-zero-indexed: test_pattern1_2dev(), test_pattern1_2dev_direct_device()
Pattern 2: shard both tensors → compute → sync() → cpu()
  • Zero-indexed: test_pattern2_2dev(), test_pattern2_2dev_direct_device()
  • Non-zero-indexed: test_pattern2_2dev_direct_device()
Pattern 3: shard one tensor → compute → cpu() → sync()
  • Zero-indexed: test_pattern3_2dev(), test_pattern3_2dev_direct_device()
  • Non-zero-indexed: test_pattern3_2dev_direct_device()
Pattern 4: shard one tensor → compute → sync() → cpu()
  • Zero-indexed: test_pattern4_2dev_direct_device()
  • Non-zero-indexed: test_pattern4_2dev_direct_device()
Pattern 5: modify tensor → shard one tensor → compute → sync() → cpu()
  • Zero-indexed: test_pattern5_2dev_direct_device()
  • Non-zero-indexed: test_pattern5_2dev_direct_device()
Pattern 6: modify tensor → shard both tensors → compute → sync() → cpu()
  • Zero-indexed: test_pattern6_2dev_direct_device()
  • Non-zero-indexed: test_pattern6_2dev_direct_device()
Pattern 7: modify tensor → shard one tensor → compute → cpu() → sync()
  • Zero-indexed: test_pattern7_2dev_direct_device()
  • Non-zero-indexed: test_pattern7_2dev_direct_device()
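
A minimal sketch of Pattern 1 on a zero-indexed 2-device sub-mesh, assuming the torch_xla SPMD API; the actual test code may differ.

import numpy as np
import torch
import torch_xla
import torch_xla.distributed.spmd as xs

mesh = xs.Mesh(np.array([0, 1]), (2, 1), ('x', 'y'))  # zero-indexed sub-mesh

xt1 = torch.randn(4, 4, device=torch_xla.device())
xt2 = torch.randn(4, 4, device=torch_xla.device())
xs.mark_sharding(xt1, mesh, ('x', 'y'))   # shard both tensors
xs.mark_sharding(xt2, mesh, ('x', 'y'))

result = torch.matmul(xt1, xt2)           # compute
result_cpu = result.cpu()                 # cpu() before ...
torch_xla.sync()                          # ... sync()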

Single Tensor Tests

Simple addition operations with single tensors to validate basic sharding:

  • Zero-indexed:

    • test_single_tensor_addition_2dev()
    • test_single_tensor_addition_2dev_direct_device()
  • Non-zero-indexed:

    • test_single_tensor_addition_2dev_direct_device()

Advanced Direct Device Tests

Sophisticated tests that showcase complex scenarios with direct device creation (a combined sketch follows the list):

Complex Operations Test
  • Test: test_complex_operations_direct_device()
  • Operation: torch.matmul(xt1 + xt2, xt3) with 3 sharded tensors
  • Purpose: Validates multi-tensor operations with complex mathematical expressions
In-Place Operations Test
  • Test: test_inplace_operations_direct_device()
  • Operations: xt1 *= 2.0; xt1 += xt2
  • Purpose: Ensures in-place operations work correctly with sharded tensors
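
A combined sketch of the two scenarios above, again assuming the torch_xla SPMD API; names and shapes are illustrative.

import numpy as np
import torch
import torch_xla
import torch_xla.distributed.spmd as xs

mesh = xs.Mesh(np.array([0, 1]), (2, 1), ('x', 'y'))

xt1 = torch.randn(4, 4, device=torch_xla.device())
xt2 = torch.randn(4, 4, device=torch_xla.device())
xt3 = torch.randn(4, 4, device=torch_xla.device())
for t in (xt1, xt2, xt3):
    xs.mark_sharding(t, mesh, ('x', 'y'))

# Complex multi-tensor expression on sharded tensors.
complex_result = torch.matmul(xt1 + xt2, xt3)

# In-place operations on a sharded tensor.
xt1 *= 2.0
xt1 += xt2
torch_xla.sync()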

We are validating 2 tensor creation approaches -

1. CPU-to-Device Transfer (Traditional)

t1 = torch.randn(4, 4, device='cpu')
t2 = torch.randn(4, 4, device='cpu')
xt1 = t1.to(torch_xla.device())
xt2 = t2.to(torch_xla.device())
expected = torch.matmul(t1, t2)  # Expected calculated on CPU

2. Direct Device Creation (Modern)

xt1 = torch.randn(4, 4, device=torch_xla.device())
xt2 = torch.randn(4, 4, device=torch_xla.device())
expected = torch.matmul(xt1.cpu(), xt2.cpu())  # Expected calculated from device tensors moved to CPU

@kvshbg-aws force-pushed the kvshbg-aws/local-spmd-normalize-infer branch from a974e2c to 21ad6c4 on July 27, 2025 at 00:23
@kvshbg-aws changed the title from "Kvshbg aws/local spmd normalize infer" to "feat: infer device_ids and normalize tile assignment" on Jul 28, 2025
@kvshbg-aws marked this pull request as ready for review on August 5, 2025 at 19:14