diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/xla/pjrt/gpu/se_gpu_pjrt_client.cc index dcd01c8d331a9..caba83542002c 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -155,6 +155,11 @@ limitations under the License. namespace xla { +template +static bool IsMemorySpaceKind(const PjRtMemorySpace* memory_space) { + return memory_space->kind_id() == MemorySpaceKind::kKindId; +} + absl::Status RunCallbackOnStream( se::Stream* stream, AsyncWorkRunner* async_work_runner, absl::AnyInvocable callback, @@ -1344,6 +1349,19 @@ absl::StatusOr StreamExecutorGpuClient::GetDefaultLayout( return topology_->GetDefaultLayout(element_type, dims); } +absl::StatusOr StreamExecutorGpuClient::GetCopyDestinationShape( + const xla::Shape& shape, PjRtMemorySpace* src_memory_space, + PjRtMemorySpace* dst_memory_space) { + if (this != dst_memory_space->client() || + IsMemorySpaceKind(src_memory_space) != + IsMemorySpaceKind(dst_memory_space)) { + return CommonPjRtClient::GetCopyDestinationShape(shape, src_memory_space, + dst_memory_space); + } + return MakeDefaultShapeForMemorySpace( + dst_memory_space, shape, shape.has_layout() ? &shape.layout() : nullptr); +} + absl::StatusOr> StreamExecutorGpuClient::CompileAndLoad(MaybeOwningMlirModule module, CompileOptions options) { diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.h b/xla/pjrt/gpu/se_gpu_pjrt_client.h index 12ff74374e56e..b79937cbb0a10 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -182,6 +182,10 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { absl::StatusOr GetDefaultLayout( PrimitiveType element_type, absl::Span dims) override; + absl::StatusOr GetCopyDestinationShape( + const xla::Shape& shape, PjRtMemorySpace* src_memory_space, + PjRtMemorySpace* dst_memory_space) override; + absl::StatusOr> LoadSerialized( absl::string_view serialized, std::optional options, const LoadOptions& load_options); diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index a4706e077ffff..df9132e12fe7d 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -66,6 +66,7 @@ limitations under the License. #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/test.h" #include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/pjrt/device_event.h" @@ -1520,16 +1521,33 @@ TEST(StreamExecutorGpuClientTest, GetTopologyDescriptionWithGlobalDevicesTest) { TEST(PjRtCpuClientTest, CopyToMemorySpace) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(DefaultOptions())); + xla::Shape shape = xla::ShapeUtil::MakeShape(S32, {128, 256}); + TF_ASSERT_OK_AND_ASSIGN(auto literal, xla::MakeFakeLiteral(shape)); for (auto* memory_space : client->memory_spaces()) { - xla::Shape shape = xla::ShapeUtil::MakeShape(S32, {128, 256}); - TF_ASSERT_OK_AND_ASSIGN(auto literal, xla::MakeFakeLiteral(shape)); TF_ASSERT_OK_AND_ASSIGN( auto buffer, client->BufferFromHostLiteral(literal, memory_space)); TF_ASSERT_OK_AND_ASSIGN(buffer, buffer->CopyToMemorySpace(buffer->memory_space())); TF_ASSERT_OK_AND_ASSIGN(auto received_literal, buffer->ToLiteral().Await()); - EXPECT_THAT(received_literal->data(), - ElementsAreArray(literal.data())); + EXPECT_EQ(*received_literal, literal); + } +} + +TEST(PjRtCpuClientTest, CopyToMemorySpaceWithCustomLayout) { + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(DefaultOptions())); + xla::Shape shape = xla::ShapeUtil::MakeShape(S32, {128, 256}); + TF_ASSERT_OK_AND_ASSIGN(auto literal, xla::MakeFakeLiteral(shape)); + Layout device_layout = LayoutUtil::MakeAscendingLayout(2); + for (auto* memory_space : client->memory_spaces()) { + TF_ASSERT_OK_AND_ASSIGN( + auto buffer, + client->BufferFromHostLiteral(literal, memory_space, &device_layout)); + TF_ASSERT_OK_AND_ASSIGN(buffer, + buffer->CopyToMemorySpace(buffer->memory_space())); + EXPECT_EQ(buffer->layout()->xla_layout(), device_layout); + TF_ASSERT_OK_AND_ASSIGN(auto received_literal, buffer->ToLiteral().Await()); + EXPECT_EQ(*received_literal, literal); } }