Skip to content

XLAShardedTensor.to_local() support #9505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/neuron/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ function run_xla_op_tests3 {
#run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py"
run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_to_local.py"
run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
#run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py"
Expand Down
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ function run_xla_op_tests3 {
run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py"
run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_to_local.py"
run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py"
Expand Down
76 changes: 76 additions & 0 deletions test/spmd/test_xla_dtensor_to_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import sys
import unittest
import torch
import numpy as np

from torch.distributed.tensor import DeviceMesh
from torch.distributed._tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate, Shard
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd.xla_sharded_tensor import XLAShardedTensor
import test_xla_sharding_base


class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest):
"""
Test suite for XLAShardedTensor.to_local(): checks the shape and values of
the returned tensor and that requires_grad and grad are carried over.
"""

@classmethod
def setUpClass(cls):
  # Pure pass-through: defers to the base class and adds no setup of its own.
  super().setUpClass()
Comment on lines +22 to +24
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this method is necessary if there's no additional setup logic


def test_to_local(self):
  """Check that to_local() returns a tensor with the source shape and values."""
  world_size = xr.global_runtime_device_count()
  mesh = DeviceMesh("xla", list(range(world_size)))

  big_tensor = torch.randn(100000, 88)
  sharded_tensor = XLAShardedTensor(big_tensor, mesh, [Shard(0)])

  local_tensor = sharded_tensor.to_local()

  # Verify the shapes are the same
  self.assertEqual(local_tensor.shape, big_tensor.shape)

  # Compare values only; device equality is deliberately not asserted.
  torch.testing.assert_close(local_tensor, big_tensor, check_device=False)

def test_to_local_requires_grad(self):
  """Check that to_local() carries requires_grad and grad over to its result."""
  world_size = xr.global_runtime_device_count()
  mesh = DeviceMesh("xla", list(range(world_size)))

  # Source tensor that tracks gradients
  tensor = torch.randn(100_000, 88, requires_grad=True)

  # Create XLAShardedTensor
  sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])

  # Verify requires_grad is set
  self.assertTrue(sharded_tensor.requires_grad)

  res = sharded_tensor.sum()
  res.backward()

  # Verify a gradient was populated by backward()
  self.assertIsNotNone(sharded_tensor.grad)

  # Call to local function
  local_tensor = sharded_tensor.to_local()

  # Verify requires_grad is preserved
  self.assertTrue(local_tensor.requires_grad)

  # All gradients should be 1.0 since we did a sum()
  self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove



# Run the suite without letting unittest exit on its own, then report the
# outcome through the process exit status (0 on success, 1 otherwise).
if __name__ == "__main__":
  result = unittest.main(exit=False)
  exit_code = 0 if result.result.wasSuccessful() else 1
  sys.exit(exit_code)
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
run_test "$_TEST_DIR/spmd/test_fsdp_v2.py"
run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
run_test "$_TEST_DIR/spmd/test_xla_dtensor_to_local.py"
run_test "$_TEST_DIR/test_gradient_accumulation.py"
XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v
run_test "$_TEST_DIR/test_autocast.py"
Expand Down
25 changes: 24 additions & 1 deletion torch_xla/distributed/spmd/xla_sharded_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __new__(cls,
dtype=elem.dtype,
layout=elem.layout,
device=elem.device,
requires_grad=kwargs.get("requires_grad", False))
requires_grad=kwargs.get("requires_grad", elem.requires_grad))
r.global_tensor = elem.detach() if r.requires_grad else elem

# Initialize mesh, partition, and spec information
Expand Down Expand Up @@ -150,6 +150,29 @@ def load_local_shards_(self, shards: List[XLAShard]):
# Invalidate cached spec since the global_tensor data has changed
self.invalidate_spec_cache()

def to_local(self):
  """
  Return a copy of the tensor wrapped by this XLAShardedTensor.

  Despite the name, the value returned is a clone of ``self.global_tensor``
  rather than a single per-device shard. When this XLAShardedTensor requires
  grad, the clone is marked as requiring grad as well and its ``grad``
  attribute is set to this tensor's ``grad`` (the same object, not a copy).

  Returns:
      torch.Tensor: A clone of ``global_tensor`` with ``requires_grad`` and
      ``grad`` mirrored from this XLAShardedTensor.
  """
  local_copy = self.global_tensor.clone()

  # Nothing more to restore when gradients are not tracked.
  if not self.requires_grad:
    return local_copy

  # The stored global tensor is detached, so the autograd state has to be
  # re-attached to the copy by hand.
  local_copy.requires_grad = self.requires_grad
  local_copy.grad = self.grad
  return local_copy

@property
def sharding_spec(self):
return torch_xla._XLAC._get_xla_sharding_spec(self.global_tensor)
Expand Down
Loading