-
Notifications
You must be signed in to change notification settings - Fork 557
XLAShardedTensor.to_local() support #9505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Hoomaaan
wants to merge
7
commits into
pytorch:master
Choose a base branch
from
Hoomaaan:toLocal_wspec
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
57dbdd8
Implement XLAShardedTensor._spec and test
aws-cph 566959e
Removed auto wrapping sharding propagation, added cached spec invalid…
aws-cph 909b01c
Removing lazy import
aws-cph 5107724
Added test for catching thrown error in spec
aws-cph 655128e
Test for Routing XLA device handling through distribute_tensor to ens…
Hoomaaan 933a964
[XLA] Implement XLAShardedTensor.to_local()
Hoomaaan 812a69a
run git_fix for yapf
Hoomaaan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import sys | ||
import unittest | ||
import torch | ||
import numpy as np | ||
|
||
from torch.distributed.tensor import DeviceMesh | ||
from torch.distributed._tensor import DTensor | ||
from torch.distributed.tensor.placement_types import Replicate, Shard | ||
import torch_xla | ||
import torch_xla.runtime as xr | ||
import torch_xla.core.xla_model as xm | ||
from torch_xla.distributed.spmd.xla_sharded_tensor import XLAShardedTensor | ||
import test_xla_sharding_base | ||
|
||
|
||
class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest): | ||
""" | ||
Test suite for the automatic conversion of regular tensors to XLAShardedTensor | ||
in DTensor.from_local() when using XLA device mesh. | ||
""" | ||
|
||
@classmethod | ||
def setUpClass(cls): | ||
super().setUpClass() | ||
|
||
def test_to_local(self): | ||
from torch.distributed.tensor import distribute_tensor | ||
world_size = xr.global_runtime_device_count() | ||
mesh = DeviceMesh("xla", list(range(world_size))) | ||
|
||
big_tensor = torch.randn(100000, 88) | ||
sharded_tensor = XLAShardedTensor(big_tensor, mesh, [Shard(0)]) | ||
|
||
local_tensor = sharded_tensor.to_local() | ||
|
||
# Verify the shapes are the same | ||
self.assertEqual(local_tensor.shape, big_tensor.shape) | ||
|
||
# Check the value of the tensor | ||
torch.testing.assert_close(local_tensor, big_tensor, check_device=False) | ||
|
||
def test_to_local_requires_grad(self): | ||
"""Test that gradients flow correctly through to_local().""" | ||
# Create a tensor with requires_grad=True | ||
world_size = xr.global_runtime_device_count() | ||
mesh = DeviceMesh("xla", list(range(world_size))) | ||
|
||
tensor = torch.randn(100_000, 88, requires_grad=True) | ||
|
||
# Create XLAShardedTensor | ||
sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)]) | ||
|
||
# Verify requires_grad is set | ||
self.assertTrue(sharded_tensor.requires_grad) | ||
|
||
res = sharded_tensor.sum() | ||
res.backward() | ||
|
||
# Verify grad are calculated | ||
self.assertTrue(sharded_tensor.grad is not None) | ||
|
||
# Call to local function | ||
local_tensor = sharded_tensor.to_local() | ||
|
||
# Verify requires_grad is preserved | ||
self.assertTrue(local_tensor.requires_grad) | ||
|
||
# All gradients should be 1.0 since we did a sum() | ||
self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor))) | ||
|
||
print("Gradient flow test successful") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove |
||
|
||
|
||
if __name__ == "__main__": | ||
result = unittest.main(exit=False) | ||
sys.exit(0 if result.result.wasSuccessful() else 1) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this method is necessary if there's no additional setup logic