feat: infer device_ids and normalize tile assignment #9514
This is the PR for complete support of sub-meshing. It includes all changes related to the abstraction/wrapper class RFC, the normalize tile_assignment RFC, and the RFC for inferring device_ids for correct device_assignment in pjrt_computation_client.
It contains all the changes for the above RFCs, since we need all of them in place together for the tests to work/pass; however, the commits related to the abstraction can be ignored for this PR (we have a separate PR for that).
Once reviewed, and with the previous PR merged, I can create a new PR that will only include the commits related to inferring devices and normalizing the tile assignment.
Explanation of major changes -
denormalized_tile_assignment save/use
We utilize the lowering_context initialized during the compile call of xla_graph_executor to save a denormalized_tile_assignment containing the device_ids used by the user when creating the Mesh object. This denormalized_tile_assignment is then passed to ScheduleSyncTensorsGraph and pjrt_computation_client::Compile (along with other functions, as seen in the PR) so that the graph executes on the correct set of devices and not on the "normalized" device ids.
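As a minimal sketch of the user-facing scenario this is meant to support, assuming the standard torch_xla SPMD Python API (Mesh, mark_sharding) and the sub-meshing enabled by this PR; device counts, shapes, and names are illustrative only:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh, mark_sharding

xr.use_spmd()

# Hypothetical example: 8 devices are addressable, but the user builds a
# mesh over devices 4-7 only (a sub-mesh). The Mesh keeps these
# "denormalized" device ids; the HLO sharding annotation is normalized to
# 0..3, and the saved denormalized_tile_assignment lets the runtime map
# execution back onto devices 4-7.
device_ids = np.array([4, 5, 6, 7])
mesh = Mesh(device_ids, (4,), ('data',))

t = torch.randn(8, 128).to(xm.xla_device())
mark_sharding(t, mesh, ('data', None))  # shard dim 0 across the sub-mesh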
argument_handles changes
Along with this, we need to make changes to the argument_handles 2-D vector, which holds the device buffers for all the arguments/tensors used in a particular graph/execution. To explain further, consider the case of 8 addressable devices and a sub-mesh of 4: we define 2 tensors and do some computation (matmul) using those tensors, both sharded via the mark_sharding API. In that case, both arguments are sharded and present only on the 4 sub-mesh devices, not on all 8 devices.
In a second case, one tensor is sharded via the mark_sharding API and one is unsharded. When the argument handles are populated for the first time for all tensors/arguments, the layout differs depending on whether the sub-mesh is 0,1,2,3 or 4,5,6,7, and only the sub-mesh devices hold the sharded data.
Now, we need to make sure that the argument_handles we pass to the Execute call of the PJRT client contain valid buffer pointers, i.e. only those for the sub-mesh currently used by the process, so the argument_handles have to be adjusted accordingly (an illustrative sketch follows below). These two use-cases required a solution that is generic for both the global-mesh and local sub-meshing cases. Hence, we are changing the logic of how argument_handles are populated and how they are adjusted for the current sub-mesh being used for computations/execution.
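Purely as an illustration (this is not the actual C++ data structure from pjrt_computation_client, nor the diagrams from the PR), a rough Python sketch of the per-device argument handles for the 2-sharded-tensor case above, assuming a [device][argument] layout as in the PJRT Execute API:

# Hypothetical illustration: "buf" marks a valid shard buffer, None marks
# no buffer. 8 addressable devices, both tensors sharded on the sub-mesh
# {4, 5, 6, 7}.
initial_argument_handles = [
    #  arg0    arg1
    [None,   None ],   # device 0
    [None,   None ],   # device 1
    [None,   None ],   # device 2
    [None,   None ],   # device 3
    ["buf",  "buf"],   # device 4 (sub-mesh)
    ["buf",  "buf"],   # device 5 (sub-mesh)
    ["buf",  "buf"],   # device 6 (sub-mesh)
    ["buf",  "buf"],   # device 7 (sub-mesh)
]

# Before Execute, the handles are reduced to only the sub-mesh devices so
# that every pointer handed to the PJRT client is valid:
execute_argument_handles = [
    row for row in initial_argument_handles
    if all(buf is not None for buf in row)
]
assert len(execute_argument_handles) == 4  # one row per sub-mesh device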
New test files added -
test_submesh_zero_indexed.py - tests submeshes that start from device 0, using [0,1] for 2-device submeshes.
test_submesh_non_zero_indexed.py - tests submeshes that start from non-zero device indices, using [2,3] for 2-device submeshes.
These tests are divided into 3 categories -
Basic Pattern tests
These tests validate fundamental tensor operations with different sharding and synchronization patterns (a sketch of the Pattern 1 flow follows this list):
Pattern 1: shard both tensors → compute → cpu() → sync()
  Zero-indexed: test_pattern1_2dev(), test_pattern1_2dev_direct_device()
  Non-zero-indexed: test_pattern1_2dev(), test_pattern1_2dev_direct_device()
Pattern 2: shard both tensors → compute → sync() → cpu()
  Zero-indexed: test_pattern2_2dev(), test_pattern2_2dev_direct_device()
  Non-zero-indexed: test_pattern2_2dev_direct_device()
Pattern 3: shard one tensor → compute → cpu() → sync()
  Zero-indexed: test_pattern3_2dev(), test_pattern3_2dev_direct_device()
  Non-zero-indexed: test_pattern3_2dev_direct_device()
Pattern 4: shard one tensor → compute → sync() → cpu()
  Zero-indexed: test_pattern4_2dev_direct_device()
  Non-zero-indexed: test_pattern4_2dev_direct_device()
Pattern 5: modify tensor → shard one tensor → compute → sync() → cpu()
  Zero-indexed: test_pattern5_2dev_direct_device()
  Non-zero-indexed: test_pattern5_2dev_direct_device()
Pattern 6: modify tensor → shard both tensors → compute → sync() → cpu()
  Zero-indexed: test_pattern6_2dev_direct_device()
  Non-zero-indexed: test_pattern6_2dev_direct_device()
Pattern 7: modify tensor → shard one tensor → compute → cpu() → sync()
  Zero-indexed: test_pattern7_2dev_direct_device()
  Non-zero-indexed: test_pattern7_2dev_direct_device()
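A minimal sketch of the Pattern 1 flow (shard both tensors → compute → cpu() → sync()) on a zero-indexed 2-device sub-mesh; function name, shapes, and mesh layout are illustrative and may not match the actual test code:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import Mesh, mark_sharding


def pattern1_2dev(device_ids=(0, 1)):
    """Pattern 1: shard both tensors -> compute -> cpu() -> sync()."""
    mesh = Mesh(np.array(device_ids), (2, 1), ('x', 'y'))

    t1 = torch.randn(4, 4).to(xm.xla_device())
    t2 = torch.randn(4, 4).to(xm.xla_device())
    mark_sharding(t1, mesh, ('x', 'y'))
    mark_sharding(t2, mesh, ('x', 'y'))

    result = torch.matmul(t1, t2)  # compute
    host_result = result.cpu()     # cpu() before ...
    xm.mark_step()                 # ... sync()
    return host_result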
Single Tensor Tests
Simple addition operations with single tensors to validate basic sharding:
Zero-indexed: test_single_tensor_addition_2dev(), test_single_tensor_addition_2dev_direct_device()
Non-zero-indexed: test_single_tensor_addition_2dev_direct_device()
Advanced Direct Device Tests
Sophisticated tests that showcase complex scenarios with direct device creation (a combined sketch follows this list):
Complex Operations Test: test_complex_operations_direct_device() - torch.matmul(xt1 + xt2, xt3) with 3 sharded tensors
In-Place Operations Test: test_inplace_operations_direct_device() - xt1 *= 2.0; xt1 += xt2
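A combined sketch of the two advanced scenarios, assuming the standard torch_xla SPMD API and direct device creation; mesh size and tensor shapes are illustrative only:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import Mesh, mark_sharding

device = xm.xla_device()
mesh = Mesh(np.array([0, 1]), (2, 1), ('x', 'y'))

# Complex operations: torch.matmul(xt1 + xt2, xt3) with 3 sharded tensors,
# all created directly on the XLA device.
xt1 = torch.randn(4, 4, device=device)
xt2 = torch.randn(4, 4, device=device)
xt3 = torch.randn(4, 4, device=device)
for t in (xt1, xt2, xt3):
    mark_sharding(t, mesh, ('x', 'y'))
out = torch.matmul(xt1 + xt2, xt3)

# In-place operations on sharded tensors.
xt1 *= 2.0
xt1 += xt2
xm.mark_step()  # sync step that materializes the pending computation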
We are validating 2 Tensor Creation Approaches
1. CPU-to-Device Transfer (Traditional)
2. Direct Device Creation (Modern)
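An illustrative contrast of the two approaches (tensor shapes are placeholders):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# 1. CPU-to-device transfer (traditional): build the tensor on the host,
#    then move it to the XLA device.
t_transferred = torch.randn(8, 8).to(device)

# 2. Direct device creation (modern): allocate the tensor on the XLA
#    device up front, avoiding the intermediate host copy.
t_direct = torch.randn(8, 8, device=device)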