feat: infer device_ids and normalize tile assignment #9514

Open · wants to merge 11 commits into base: master
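
Throughout the C++ test diff below, bare `xla::OpSharding` protos are replaced by a `torch_xla::OpSharding` wrapper that also carries the denormalized tile assignment (the actual device ids backing each tile). The following is a minimal sketch of the construction and read-back pattern the tests use; the constructor and the `GetXlaOpSharding()` / `GetDenormalizedTileAssignment()` accessors are taken from the diff, while the header paths and exact return types are assumptions.

```cpp
// Sketch only: header locations are assumptions, not taken from this PR.
#include <cstdint>
#include <optional>
#include <vector>

#include "torch_xla/csrc/torch_xla_op_sharding.h"  // assumed path of the wrapper
#include "xla/hlo/ir/hlo_sharding.h"                // provides xla::HloSharding

void OpShardingWrapperPattern() {
  // Build the plain XLA sharding proto exactly as before.
  xla::OpSharding xla_sharding = xla::HloSharding::Replicate().ToProto();

  // The wrapper additionally records which physical devices the tiles map to;
  // the tests pass std::nullopt when no denormalized assignment is needed.
  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3};
  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
  torch_xla::OpSharding replicated_only(xla_sharding, std::nullopt);

  // Read back both pieces, as the updated assertions and compile path do.
  const auto& proto = sharding.GetXlaOpSharding();             // XLA proto
  auto device_ids = sharding.GetDenormalizedTileAssignment();  // device ids
  (void)proto;
  (void)device_ids;
  (void)replicated_only;
}
```

This mirrors how the tests keep building `xla::HloSharding::Tile(...)` / `Replicate()` protos as before, while preserving the original device ordering alongside them for later normalization.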
152 changes: 113 additions & 39 deletions test/cpp/test_xla_sharding.cpp
@@ -42,6 +42,34 @@ class XLAShardingTest : public AtenXlaTensorTestBase {
}
};

TEST_F(XLAShardingTest, NormalizeTileAssignment) {
// Test with an empty tile assignment
std::vector<int64_t> empty_tile_assignment = {};
auto normalized =
ShardingUtil::NormalizeTileAssignment(empty_tile_assignment);
EXPECT_TRUE(normalized.empty());

// Test with positive values
std::vector<int64_t> positive_tile_assignment = {3, 1, 4, 2};
normalized = ShardingUtil::NormalizeTileAssignment(positive_tile_assignment);
EXPECT_EQ(normalized, std::vector<int64_t>({2, 0, 3, 1}));

// Test with all identical values
std::vector<int64_t> identical_tile_assignment = {5, 5, 5, 5};
normalized = ShardingUtil::NormalizeTileAssignment(identical_tile_assignment);
EXPECT_EQ(normalized, std::vector<int64_t>({0, 0, 0, 0}));

// Test with negative values
std::vector<int64_t> negative_tile_assignment = {-3, -1, -4, -2};
EXPECT_THROW(ShardingUtil::NormalizeTileAssignment(negative_tile_assignment),
std::runtime_error);

// Test with mixed positive and negative values
std::vector<int64_t> mixed_tile_assignment = {3, -1, 4, 2};
EXPECT_THROW(ShardingUtil::NormalizeTileAssignment(mixed_tile_assignment),
std::runtime_error);
}
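
Judging from the expectations above, `NormalizeTileAssignment` remaps arbitrary non-negative device ids onto a dense zero-based range while preserving their relative order, and rejects negative ids. Below is a self-contained sketch of that behaviour, assuming this reading of the test; the real `ShardingUtil` implementation may differ.

```cpp
#include <cstdint>
#include <set>
#include <stdexcept>
#include <unordered_map>
#include <vector>

// Maps each device id to its rank among the sorted unique ids, e.g.
// {3, 1, 4, 2} -> {2, 0, 3, 1} and {5, 5, 5, 5} -> {0, 0, 0, 0}.
// An empty input yields an empty output; negative ids are rejected.
std::vector<int64_t> NormalizeTileAssignmentSketch(
    const std::vector<int64_t>& tile_assignment) {
  std::set<int64_t> unique_ids;
  for (int64_t id : tile_assignment) {
    if (id < 0) {
      throw std::runtime_error("tile assignment ids must be non-negative");
    }
    unique_ids.insert(id);
  }
  std::unordered_map<int64_t, int64_t> rank;
  int64_t next = 0;
  for (int64_t id : unique_ids) {
    rank[id] = next++;
  }
  std::vector<int64_t> normalized;
  normalized.reserve(tile_assignment.size());
  for (int64_t id : tile_assignment) {
    normalized.push_back(rank[id]);
  }
  return normalized;
}
```

Presumably this is what lets a submesh that starts at a non-zero device id (as exercised by the new submesh tests added to run_tests.sh) be expressed with logical ids starting at 0.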

TEST_F(XLAShardingTest, GetShardShape) {
auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape =
@@ -50,15 +78,19 @@ TEST_F(XLAShardingTest, GetShardShape) {
{0, 1},
{2, 3},
});
auto sharding = xla::HloSharding::Tile(mesh).ToProto();
auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3};
torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);

auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
// For tiled sharding, each dimension should be halved
EXPECT_EQ(shard_shape, std::vector<int64_t>({4, 4}));

sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
xla_sharding = xla::HloSharding::Replicate().ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec->sharding = sharding;
shard_shape = ShardingUtil::GetShardShape(sharding_spec);
// For replicated sharding, each dimension should be preserved
EXPECT_EQ(shard_shape, std::vector<int64_t>({8, 7}));
@@ -74,7 +106,9 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
{0, 1},
{2, 3},
});
auto sharding = xla::HloSharding::Tile(mesh).ToProto();
auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3};
torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
@@ -103,7 +137,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
EXPECT_EQ(slice.step(), 1);
}
}
sharding = xla::HloSharding::Replicate().ToProto();
xla_sharding = xla::HloSharding::Replicate().ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec->sharding = sharding;
shard_shape = ShardingUtil::GetShardShape(sharding_spec);
replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices(
@@ -121,16 +156,18 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
TEST_F(XLAShardingTest, ShardTensor) {
std::vector<std::string> devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3",
"TPU:4", "TPU:5", "TPU:6", "TPU:7"};
std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};

// 1D tiled
at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
xla::OpSharding sharding =
xla::OpSharding xla_sharding =
xla::HloSharding::Tile1D(
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
devices.size())
.ToProto();
torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -148,7 +185,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
{0, 1, 2, 3},
{4, 5, 6, 7},
});
sharding = xla::HloSharding::Tile(mesh).ToProto();
xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -160,15 +198,19 @@ TEST_F(XLAShardingTest, ShardTensor) {
// 3D tiled, the first dim is replicated and the last halved. The last shard
// size should be smaller in dim=1 because it's not evenly divisible.
xla::Array3D<int64_t> cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}});
sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto();
xla_sharding = xla::HloSharding::Tile(cube).ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec->sharding = sharding;
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/false);
EXPECT_EQ(shards.size(), 8);
EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 2, 2}));
EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 1, 2}));

// Replicated, all shards should be identical.
sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
xla_sharding = xla::HloSharding::Replicate().ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec->sharding = sharding;
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/false);
EXPECT_EQ(shards.size(), 8);
@@ -182,7 +224,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
sharding = xla::HloSharding::Tile(tesseract).ToProto();
xla_sharding = xla::HloSharding::Tile(tesseract).ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -206,7 +249,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
hypercube.FillIota(0);
sharding = xla::HloSharding::Tile(hypercube).ToProto();
xla_sharding = xla::HloSharding::Tile(hypercube).ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -234,7 +278,9 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
{4, 5, 0, 1},
{6, 7, 2, 3},
});
auto sharding = xla::HloSharding::Tile(mesh).ToProto();
auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
std::vector<int64_t> denormalized_tile_assignment = {4, 5, 0, 1, 6, 7, 2, 3};
torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
// For devices at the start of the mesh, all shards should have the same
@@ -251,7 +297,10 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
{0, 1, 4, 5},
{2, 3, 6, 7},
});
sharding_spec->sharding = xla::HloSharding::Tile(mesh).ToProto();
xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
denormalized_tile_assignment = {0, 1, 4, 5, 2, 3, 6, 7};
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec->sharding = sharding;
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/false);
EXPECT_EQ(shards.size(), 4);
@@ -278,7 +327,9 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
{{7}},
});

auto sharding = xla::HloSharding::Tile(mesh).ToProto();
auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, global_shape, /*minibatch=*/true);
auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec,
@@ -292,17 +343,21 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({
{0, 1, 2, 3},
{4, 5, 6, 7},
})
.ToProto(),
tensor_shape);
XLATensor::ShardingSpec tiled_3d(
xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto(),
tensor_shape);
XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto(),
tensor_shape);
auto xla_sharding = xla::HloSharding::Tile({
{0, 1, 2, 3},
{4, 5, 6, 7},
})
.ToProto();
std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
XLATensor::ShardingSpec tiled_2d(sharding, tensor_shape);
xla_sharding =
xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
XLATensor::ShardingSpec tiled_3d(sharding, tensor_shape);
xla_sharding = xla::HloSharding::Replicate().ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
XLATensor::ShardingSpec replicated(sharding, tensor_shape);
EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d));
EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d));
EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(replicated, replicated));
@@ -323,12 +378,17 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
std::vector<std::string> devices(3);
std::fill_n(devices.begin(), devices.size(),
bridge::GetDefaultDevice()->toString());
auto replicate_xla_sharding = xla::HloSharding::Replicate().ToProto();
auto unknown_xla_sharding = xla::HloSharding::Unknown().ToProto();
torch_xla::OpSharding replicate_sharding(replicate_xla_sharding,
std::nullopt);
torch_xla::OpSharding unknown_sharding(unknown_xla_sharding, std::nullopt);
std::vector<XLATensor::ShardingSpecPtr> shardings = {
nullptr,
std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Replicate().ToProto(), tensor_shape),
std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Unknown().ToProto(), tensor_shape)};
std::make_shared<XLATensor::ShardingSpec>(replicate_sharding,
tensor_shape),
std::make_shared<XLATensor::ShardingSpec>(unknown_sharding,
tensor_shape)};
std::vector<torch::lazy::BackendDataPtr> tensors_data =
CreateTensorsData(tensors, shardings, devices);

@@ -387,13 +447,29 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
auto y = xla::Add(x, xla::ConstantR0<float>(&b, 3));
xla::XlaComputation xla_computation =
GetValueOrThrow(b.Build(/*remove_dynamic_dimensions=*/false));

std::vector<XLATensorPtr> tensors{XLATensor::Create(
torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
bridge::GetDefaultDevice()->toString(), std::move(shape)))};
std::vector<std::vector<int64_t>> denormalized_tile_assignments;
for (auto tensor : tensors) {
auto sharding_spec = tensor->sharding_spec();
if (sharding_spec) {
denormalized_tile_assignments.push_back(
sharding_spec->sharding.GetDenormalizedTileAssignment());
}
}

std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
instances.push_back({std::move(xla_computation),
bridge::GetDefaultDevice()->toString(),
{bridge::GetDefaultDevice()->toString()},
&shape,
/*should_wrap_parameter=*/false,
/*is_sharded=*/true});
instances.push_back(
{std::move(xla_computation),
bridge::GetDefaultDevice()->toString(),
{bridge::GetDefaultDevice()->toString()},
&shape,
/*should_wrap_parameter=*/false,
/*is_sharded=*/true,
/*allow_spmd_sharding_propagation_to_output=*/true,
/*denormalized_tile_assignments=*/denormalized_tile_assignments});

std::vector<
std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
@@ -404,9 +480,6 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
"add", std::move(computations[0]->move_computation()));

// Prepare output sharding propagation, expect a sharded output placeholder.
std::vector<XLATensorPtr> tensors{XLATensor::Create(
torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
bridge::GetDefaultDevice()->toString(), std::move(shape)))};
std::vector<torch::lazy::BackendDataPtr> data_placeholders;
std::vector<XLATensor::ShardingSpecPtr> sharding_specs;
ShardingUtil::PrepareOutputShardingPropagation(
@@ -417,11 +490,12 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
if (n_devices > 1) {
// Tiled sharding requires multiple devices.
EXPECT_TRUE(xla::protobuf_util::HaveSameSerialization(
tiled, sharding_specs[0]->sharding));
tiled, sharding_specs[0]->sharding.GetXlaOpSharding()));
} else {
// Single device execution defaults to replication sharding.
EXPECT_TRUE(xla::protobuf_util::HaveSameSerialization(
xla::HloSharding::Replicate().ToProto(), sharding_specs[0]->sharding));
xla::HloSharding::Replicate().ToProto(),
sharding_specs[0]->sharding.GetXlaOpSharding()));
}

// Check if the placeholder is on an SPMD device (sharded) with no real values.
Expand Down
2 changes: 2 additions & 0 deletions test/run_tests.sh
@@ -261,6 +261,8 @@ function run_xla_op_tests3 {
run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py" "$@" --skip-gradient-checkpointing
run_test "$_TEST_DIR/test_gradient_accumulation.py"
run_save_tensor_hlo run_test "$_TEST_DIR/spmd/test_spmd_lowering_context.py"
run_test_multi_devices "$_TEST_DIR/spmd/test_submesh_zero_indexed.py"
run_test_multi_devices "$_TEST_DIR/spmd/test_submesh_non_zero_indexed.py"
run_test "$_TEST_DIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$_TEST_DIR/test_input_output_aliases.py"
run_test_without_functionalization "$_TEST_DIR/test_input_output_aliases.py"