From d73250c2f9c665aadda8c5f572b06c45a2732eb4 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 13 Feb 2025 14:52:07 -0800 Subject: [PATCH] [pjrt] Removed PjRtDevice overloads of `PjRtClient::CreateBuffersForAsyncHostToDevice` I also pulled in the `Shape`->`ShapeSpec` conversion code into the default implementation, since it was duplicated in a few clients. PiperOrigin-RevId: 726646980 --- xla/pjrt/BUILD | 14 ++++++ xla/pjrt/cpu/BUILD | 1 + xla/pjrt/cpu/cpu_client.cc | 49 +++----------------- xla/pjrt/cpu/cpu_client.h | 10 ---- xla/pjrt/cpu/cpu_client_test.cc | 45 +++++++++--------- xla/pjrt/gpu/se_gpu_pjrt_client.cc | 40 ---------------- xla/pjrt/gpu/se_gpu_pjrt_client.h | 14 +----- xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 16 ++++--- xla/pjrt/pjrt_c_api_client.cc | 38 ---------------- xla/pjrt/pjrt_c_api_client.h | 14 ------ xla/pjrt/pjrt_c_api_client_test.cc | 2 +- xla/pjrt/pjrt_client.h | 32 ++++--------- xla/pjrt/pjrt_client_utils.cc | 53 ++++++++++++++++++++++ xla/pjrt/pjrt_client_utils.h | 35 ++++++++++++++ xla/pjrt/pjrt_stream_executor_client.h | 12 ----- xla/pjrt/tf_pjrt_client.h | 6 --- xla/python/transfer/streaming_ifrt_test.cc | 6 ++- 17 files changed, 155 insertions(+), 232 deletions(-) create mode 100644 xla/pjrt/pjrt_client_utils.cc create mode 100644 xla/pjrt/pjrt_client_utils.h diff --git a/xla/pjrt/BUILD b/xla/pjrt/BUILD index cc78016955ffa..36bb2e215ec4d 100644 --- a/xla/pjrt/BUILD +++ b/xla/pjrt/BUILD @@ -240,6 +240,20 @@ cc_library( ], ) +cc_library( + name = "pjrt_client_utils", + srcs = ["pjrt_client_utils.cc"], + hdrs = ["pjrt_client_utils.h"], + visibility = internal_visibility(["//xla:friends"]), + deps = [ + ":pjrt_client", + "//xla:shape_util", + "//xla:util", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "pjrt_client_test_common", testonly = 1, diff --git a/xla/pjrt/cpu/BUILD b/xla/pjrt/cpu/BUILD index 4e0f2b980daea..6ec0f339f7da4 100644 --- a/xla/pjrt/cpu/BUILD +++ b/xla/pjrt/cpu/BUILD @@ -167,6 +167,7 @@ cc_library( "//xla/pjrt:layout_mode", "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_client_utils", "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_compiler", "//xla/pjrt:pjrt_executable", diff --git a/xla/pjrt/cpu/cpu_client.cc b/xla/pjrt/cpu/cpu_client.cc index 8f17c2442e506..6c4209ea6518a 100644 --- a/xla/pjrt/cpu/cpu_client.cc +++ b/xla/pjrt/cpu/cpu_client.cc @@ -75,6 +75,7 @@ limitations under the License. 
#include "xla/pjrt/layout_mode.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_client_utils.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" @@ -947,41 +948,14 @@ TfrtCpuClient::CreateUninitializedBuffer(const Shape& shape, tensorflow::down_cast(device), this); } -absl::StatusOr> -TfrtCpuClient::CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtDevice* device) { - auto* tfrt_device = tensorflow::down_cast(device); - return TfrtCpuAsyncHostToDeviceTransferManager::Create(shapes, tfrt_device, - this); -} - absl::StatusOr> TfrtCpuClient::CreateBuffersForAsyncHostToDevice( absl::Span shapes, PjRtMemorySpace* memory_space) { CHECK_EQ(memory_space->devices().size(), 1); - return CreateBuffersForAsyncHostToDevice(shapes, memory_space->devices()[0]); -} - -static absl::StatusOr> ConvertShapeSpecToShapes( - absl::Span shape_specs, - std::optional>> device_layouts) { - if (device_layouts.has_value() && - device_layouts->size() != shape_specs.size()) { - return InvalidArgument( - "Number of layouts %d does not match the number of shapes %d", - device_layouts->size(), shape_specs.size()); - } - std::vector device_shapes; - device_shapes.reserve(shape_specs.size()); - for (size_t i = 0; i < shape_specs.size(); ++i) { - auto& shape_spec = shape_specs[i]; - Shape& device_shape = device_shapes.emplace_back( - ShapeUtil::MakeShape(shape_spec.element_type, shape_spec.dims)); - if (device_layouts.has_value() && (*device_layouts)[i].has_value()) { - *device_shape.mutable_layout() = *(*device_layouts)[i]; - } - } - return device_shapes; + auto* tfrt_device = + tensorflow::down_cast(memory_space->devices().front()); + return TfrtCpuAsyncHostToDeviceTransferManager::Create(shapes, tfrt_device, + this); } absl::StatusOr> @@ -990,22 +964,11 @@ TfrtCpuClient::CreateBuffersForAsyncHostToDevice( std::optional>> device_layouts, PjRtMemorySpace* memory_space) { TF_ASSIGN_OR_RETURN(std::vector device_shapes, - ConvertShapeSpecToShapes(shape_specs, device_layouts)); + ConvertShapeSpecsToShapes(shape_specs, device_layouts)); return CreateBuffersForAsyncHostToDevice(absl::MakeSpan(device_shapes), memory_space); } -absl::StatusOr> -TfrtCpuClient::CreateBuffersForAsyncHostToDevice( - absl::Span shape_specs, - std::optional>> device_layouts, - PjRtDevice* device) { - TF_ASSIGN_OR_RETURN(std::vector device_shapes, - ConvertShapeSpecToShapes(shape_specs, device_layouts)); - return CreateBuffersForAsyncHostToDevice(absl::MakeSpan(device_shapes), - device); -} - absl::StatusOr> TfrtCpuClient::BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, diff --git a/xla/pjrt/cpu/cpu_client.h b/xla/pjrt/cpu/cpu_client.h index ed43819c9672f..a56986be79ed4 100644 --- a/xla/pjrt/cpu/cpu_client.h +++ b/xla/pjrt/cpu/cpu_client.h @@ -144,10 +144,6 @@ class TfrtCpuClient final : public PjRtClient { absl::StatusOr> CreateUninitializedBuffer( const Shape& shape, PjRtMemorySpace* device) override; - absl::StatusOr> - CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtDevice* device) override; - absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, PjRtMemorySpace* memory_space) override; @@ -158,12 +154,6 @@ class TfrtCpuClient final : public PjRtClient { std::optional>> device_layouts, PjRtMemorySpace* memory_space) override; - absl::StatusOr> - CreateBuffersForAsyncHostToDevice( - absl::Span shape_specs, - std::optional>> device_layouts, - PjRtDevice* 
device) override; - absl::StatusOr> BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, diff --git a/xla/pjrt/cpu/cpu_client_test.cc b/xla/pjrt/cpu/cpu_client_test.cc index aea5b47ff2207..48e74c4d090d7 100644 --- a/xla/pjrt/cpu/cpu_client_test.cc +++ b/xla/pjrt/cpu/cpu_client_test.cc @@ -225,7 +225,7 @@ TEST(TfrtCpuClientTest, AsyncTransferRawData) { xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( - {shape}, client->addressable_devices()[0])); + {shape}, client->memory_spaces()[0])); auto buffer = transfer_manager->RetrieveBuffer(0); auto ready_future = buffer->GetReadyFuture(); EXPECT_THAT(ready_future.IsReady(), IsFalse()); @@ -245,8 +245,8 @@ TEST(TfrtCpuClientTest, AsyncTransferWithSpecs) { PjRtClient::ShapeSpec shape_spec{U32, {3, 2}}; TF_ASSERT_OK_AND_ASSIGN( auto transfer_manager, - client->CreateBuffersForAsyncHostToDevice( - {shape_spec}, std::nullopt, client->addressable_devices()[0])); + client->CreateBuffersForAsyncHostToDevice({shape_spec}, std::nullopt, + client->memory_spaces()[0])); auto buffer = transfer_manager->RetrieveBuffer(0); auto ready_future = buffer->GetReadyFuture(); EXPECT_THAT(ready_future.IsReady(), IsFalse()); @@ -266,7 +266,7 @@ TEST(TfrtCpuClientTest, AsyncTransferLiteral) { xla::Shape shape = xla::ShapeUtil::MakeShape(F32, {128, 256}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( - {shape}, client->addressable_devices()[0])); + {shape}, client->memory_spaces()[0])); auto buffer = transfer_manager->RetrieveBuffer(0); auto ready_future = buffer->GetReadyFuture(); EXPECT_THAT(ready_future.IsReady(), IsFalse()); @@ -282,7 +282,7 @@ TEST(TfrtCpuClientTest, AsyncTransferLiteralInt4) { xla::Shape shape = xla::ShapeUtil::MakeShape(S4, {128, 256}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( - {shape}, client->addressable_devices()[0])); + {shape}, client->memory_spaces()[0])); auto buffer = transfer_manager->RetrieveBuffer(0); auto ready_future = buffer->GetReadyFuture(); EXPECT_THAT(ready_future.IsReady(), IsFalse()); @@ -310,7 +310,7 @@ TEST(TfrtCpuClientTest, AsyncTransferCallsOnDone) { xla::Shape shape = ShapeUtil::MakeShape(F32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( - {shape}, client->addressable_devices()[0])); + {shape}, client->memory_spaces()[0])); auto buffer = transfer_manager->RetrieveBuffer(0); auto ready_future = buffer->GetReadyFuture(); EXPECT_THAT(ready_future.IsReady(), IsFalse()); @@ -328,7 +328,7 @@ TEST(TfrtCpuClientTest, AsyncTransferNeverTransferred) { xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( - {shape}, client->addressable_devices()[0])); + {shape}, client->memory_spaces()[0])); auto buffer = transfer_manager->RetrieveBuffer(0); transfer_manager.reset(); EXPECT_THAT( @@ -343,11 +343,11 @@ TEST(TfrtCpuClientTest, AsyncTransferBufferCount) { xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( - {shape}, client->addressable_devices()[0])); + {shape}, client->memory_spaces()[0])); EXPECT_EQ(transfer_manager->buffer_count(), 1); - TF_ASSERT_OK_AND_ASSIGN( - transfer_manager, client->CreateBuffersForAsyncHostToDevice( - {shape, shape}, 
client->addressable_devices()[0])); + TF_ASSERT_OK_AND_ASSIGN(transfer_manager, + client->CreateBuffersForAsyncHostToDevice( + {shape, shape}, client->memory_spaces()[0])); EXPECT_EQ(transfer_manager->buffer_count(), 2); } @@ -356,7 +356,7 @@ TEST(TfrtCpuClientTest, AsyncTransferBufferSize) { xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( - {shape}, client->addressable_devices()[0])); + {shape}, client->memory_spaces()[0])); EXPECT_EQ(transfer_manager->buffer_size(0), 3 * 2 * 4); } @@ -364,9 +364,9 @@ TEST(TfrtCpuClientTest, AsyncTransferDevice) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(CpuClientOptions())); xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); auto* device = client->addressable_devices()[0]; - TF_ASSERT_OK_AND_ASSIGN( - auto transfer_manager, - client->CreateBuffersForAsyncHostToDevice({shape}, device)); + TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, + client->CreateBuffersForAsyncHostToDevice( + {shape}, *device->default_memory_space())); EXPECT_EQ(transfer_manager->device(), device); } @@ -375,7 +375,7 @@ TEST(TfrtCpuClientTest, AsyncTransferSetBufferError) { xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( - {shape}, client->addressable_devices()[0])); + {shape}, client->memory_spaces()[0])); auto buffer = transfer_manager->RetrieveBuffer(0); transfer_manager->SetBufferError(0, Internal("foobar")); EXPECT_THAT( @@ -399,7 +399,7 @@ TEST(TfrtCpuClientTest, AsyncTransferRawDataToSubBuffer) { xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( - {shape}, client->addressable_devices()[0])); + {shape}, client->memory_spaces()[0])); auto buffer = transfer_manager->RetrieveBuffer(0); auto ready_future = buffer->GetReadyFuture(); EXPECT_THAT(ready_future.IsReady(), IsFalse()); @@ -477,12 +477,9 @@ ENTRY Identity() -> f32[2, 2] { ASSERT_TRUE(!fingerprint.empty()); Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); - TF_ASSERT_OK_AND_ASSIGN( - auto* memory_space, - client->addressable_devices()[0]->default_memory_space()); - TF_ASSERT_OK_AND_ASSIGN( - auto transfer_manager, - client->CreateBuffersForAsyncHostToDevice({shape}, memory_space)); + TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, + client->CreateBuffersForAsyncHostToDevice( + {shape}, client->memory_spaces()[0])); auto buffer = transfer_manager->RetrieveBuffer(0); transfer_manager->SetBufferError(0, Internal("foobar")); @@ -666,7 +663,7 @@ TEST(TfrtCpuClientTest, CopyRawToHost) { xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( - {shape}, client->addressable_devices()[0])); + {shape}, client->memory_spaces()[0])); auto buffer = transfer_manager->RetrieveBuffer(0); auto ready_future = buffer->GetReadyFuture(); EXPECT_THAT(ready_future.IsReady(), IsFalse()); diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 7375673e2f3df..3399a5e1f51c5 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -596,32 +596,6 @@ absl::string_view StreamExecutorGpuClient::platform_version() const { #endif // TENSORFLOW_USE_ROCM && defined(TF_ROCM_VERSION) } -absl::StatusOr> -StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( - absl::Span shape_specs, - std::optional>> 
device_layouts, - PjRtDevice* device) { - auto* stream_executor_device = - tensorflow::down_cast(device); - return xla::AsyncHostToDeviceTransferManager::Create( - shape_specs, std::move(device_layouts), stream_executor_device, this, - /*memory_space=*/nullptr); -} - -absl::StatusOr> -StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( - absl::Span shapes, PjRtDevice* device) { - absl::InlinedVector shape_specs; - shape_specs.reserve(shapes.size()); - for (const auto& shape : shapes) { - shape_specs.emplace_back(PjRtClient::ShapeSpec{ - shape.element_type(), - DimensionVector(shape.dimensions().begin(), shape.dimensions().end())}); - } - return CreateBuffersForAsyncHostToDevice( - shape_specs, /*device_layouts=*/std::nullopt, device); -} - absl::StatusOr> StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( absl::Span shape_specs, @@ -636,20 +610,6 @@ StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( memory_space); } -absl::StatusOr> -StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( - absl::Span shapes, PjRtMemorySpace* memory_space) { - absl::InlinedVector shape_specs; - shape_specs.reserve(shapes.size()); - for (const auto& shape : shapes) { - shape_specs.emplace_back(PjRtClient::ShapeSpec{ - shape.element_type(), - DimensionVector(shape.dimensions().begin(), shape.dimensions().end())}); - } - return CreateBuffersForAsyncHostToDevice( - shape_specs, /*device_layouts=*/std::nullopt, memory_space); -} - absl::StatusOr StreamExecutorGpuClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) const { diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.h b/xla/pjrt/gpu/se_gpu_pjrt_client.h index 137202eb1c8ae..8f200a9a484ca 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -203,26 +203,14 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { int num_replicas, int num_partitions) const override; absl::string_view platform_version() const override; - absl::StatusOr> - CreateBuffersForAsyncHostToDevice( - absl::Span shape_specs, - std::optional>> device_layouts, - PjRtDevice* device) override; - - absl::StatusOr> - CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtDevice* device) override; + using PjRtStreamExecutorClient::CreateBuffersForAsyncHostToDevice; absl::StatusOr> CreateBuffersForAsyncHostToDevice( absl::Span shape_specs, std::optional>> device_layouts, PjRtMemorySpace* memory_space) override; - absl::StatusOr> - CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtMemorySpace* memory_space) override; - PjRtFuture<> CopyRawSubBufferToHost(PjRtBuffer* buffer, PjRtFuture dst, int64_t offset, int64_t transfer_size) override; diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index 4653fa3d265e0..b47997c64e1e6 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -501,11 +501,12 @@ TEST(StreamExecutorGpuClientTest, ToLiteralAsync) { GetStreamExecutorGpuClient(GpuClientOptions())); ASSERT_GE(client->addressable_devices().size(), 1); + auto* d = client->addressable_devices()[0]; auto src_literal = LiteralUtil::CreateR1({41.0f, 42.0f, 43.0f, 44.0f}); TF_ASSERT_OK_AND_ASSIGN( auto transfer_manager, - client->CreateBuffersForAsyncHostToDevice( - {src_literal.shape()}, client->addressable_devices()[0])); + client->CreateBuffersForAsyncHostToDevice({src_literal.shape()}, + *d->default_memory_space())); auto buffer = transfer_manager->RetrieveBuffer(0); absl::Mutex mu; @@ -586,11 +587,12 @@ 
TEST(StreamExecutorGpuClientTest, ToLiteralAsyncBeforeBufferReady) { GetStreamExecutorGpuClient(GpuClientOptions())); ASSERT_GE(client->addressable_devices().size(), 1); + auto* d = client->addressable_devices()[0]; auto src_literal = LiteralUtil::CreateR1({41.0f, 42.0f, 43.0f, 44.0f}); TF_ASSERT_OK_AND_ASSIGN( auto transfer_manager, - client->CreateBuffersForAsyncHostToDevice( - {src_literal.shape()}, client->addressable_devices()[0])); + client->CreateBuffersForAsyncHostToDevice({src_literal.shape()}, + *d->default_memory_space())); auto buffer = transfer_manager->RetrieveBuffer(0); absl::Mutex mu; @@ -626,6 +628,7 @@ TEST(StreamExecutorGpuClientTest, FromHostAsync) { GetStreamExecutorGpuClient(GpuClientOptions())); ASSERT_GE(client->addressable_devices().size(), 1); + auto* d = client->addressable_devices()[0]; std::vector src_literals; std::vector src_shapes; for (int i = 0; i < 4; ++i) { @@ -636,7 +639,7 @@ TEST(StreamExecutorGpuClientTest, FromHostAsync) { } TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( - src_shapes, client->addressable_devices()[0])); + src_shapes, *d->default_memory_space())); std::vector> buffers; for (int i = 0; i < src_shapes.size(); ++i) { buffers.emplace_back(transfer_manager->RetrieveBuffer(i)); @@ -949,7 +952,8 @@ TEST(StreamExecutorGpuClientTest, AsyncCopyToDevice) { auto src_literal = LiteralUtil::CreateR1({41.0f, 42.0f, 43.0f, 44.0f}); TF_ASSERT_OK_AND_ASSIGN( auto transfer_manager, - client->CreateBuffersForAsyncHostToDevice({src_literal.shape()}, d0)); + client->CreateBuffersForAsyncHostToDevice({src_literal.shape()}, + *d0->default_memory_space())); auto src_buffer = transfer_manager->RetrieveBuffer(0); // CopyToMemorySpace won't be enqueued until src_buffer is available. 
auto local_recv_buffer = diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc index be05c632655bf..fa4f036b80ec2 100644 --- a/xla/pjrt/pjrt_c_api_client.cc +++ b/xla/pjrt/pjrt_c_api_client.cc @@ -872,44 +872,6 @@ PjRtCApiClient::CreateBuffersForAsyncHostToDevice( this, args.transfer_manager); } -absl::StatusOr> -PjRtCApiClient::CreateBuffersForAsyncHostToDevice( - absl::Span shape_specs, - std::optional>> device_layouts, - PjRtDevice* device) { - TF_ASSIGN_OR_RETURN(auto memory_space, device->default_memory_space()); - return CreateBuffersForAsyncHostToDevice(shape_specs, device_layouts, - memory_space); -} - -absl::StatusOr> -PjRtCApiClient::CreateBuffersForAsyncHostToDevice( - absl::Span shapes, PjRtDevice* device) { - absl::InlinedVector shape_specs; - shape_specs.reserve(shapes.size()); - for (const auto& shape : shapes) { - shape_specs.emplace_back(PjRtClient::ShapeSpec{ - shape.element_type(), - DimensionVector(shape.dimensions().begin(), shape.dimensions().end())}); - } - return CreateBuffersForAsyncHostToDevice( - shape_specs, /*device_layouts=*/std::nullopt, device); -} - -absl::StatusOr> -PjRtCApiClient::CreateBuffersForAsyncHostToDevice( - absl::Span shapes, PjRtMemorySpace* memory_space) { - absl::InlinedVector shape_specs; - shape_specs.reserve(shapes.size()); - for (const auto& shape : shapes) { - shape_specs.emplace_back(PjRtClient::ShapeSpec{ - shape.element_type(), - DimensionVector(shape.dimensions().begin(), shape.dimensions().end())}); - } - return CreateBuffersForAsyncHostToDevice( - shape_specs, /*device_layouts=*/std::nullopt, memory_space); -} - const PJRT_Api* PjRtCApiClient::pjrt_c_api() const { return c_api_; } // --------------------------------- Device Descriptions ----------------------- diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h index 36cf10bd6990d..355d36a4310a7 100644 --- a/xla/pjrt/pjrt_c_api_client.h +++ b/xla/pjrt/pjrt_c_api_client.h @@ -331,26 +331,12 @@ class PjRtCApiClient : public PjRtClient { absl::StatusOr GetTopologyDescription() const override; - absl::StatusOr> - CreateBuffersForAsyncHostToDevice( - absl::Span shape_specs, - std::optional>> device_layouts, - PjRtDevice* device) override; - absl::StatusOr> CreateBuffersForAsyncHostToDevice( absl::Span shape_specs, std::optional>> device_layouts, PjRtMemorySpace* memory_space) override; - absl::StatusOr> - CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtDevice* device) override; - - absl::StatusOr> - CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtMemorySpace* memory_space) override; - absl::StatusOr> BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, diff --git a/xla/pjrt/pjrt_c_api_client_test.cc b/xla/pjrt/pjrt_c_api_client_test.cc index 46f70975db705..514d45d3dbba7 100644 --- a/xla/pjrt/pjrt_c_api_client_test.cc +++ b/xla/pjrt/pjrt_c_api_client_test.cc @@ -181,7 +181,7 @@ TEST(PjRtCApiClientTest, CreateBuffersForAsyncHostToDeviceWithShape) { /*minor_to_major=*/{1, 0, 2}); std::vector host_shapes = {host_shape}; auto status_or_transfer_manager = client->CreateBuffersForAsyncHostToDevice( - absl::MakeSpan(host_shapes), client->addressable_devices()[0]); + absl::MakeSpan(host_shapes), client->memory_spaces()[0]); EXPECT_TRUE(status_or_transfer_manager.ok()) << status_or_transfer_manager.status(); } diff --git a/xla/pjrt/pjrt_client.h b/xla/pjrt/pjrt_client.h index 659c403e25c52..2e714bbe925ff 100644 --- a/xla/pjrt/pjrt_client.h +++ b/xla/pjrt/pjrt_client.h @@ -764,18 
+764,6 @@ class PjRtClient { // `device_layouts` itself is not specified, then all buffers will use the // default device layout. virtual absl::StatusOr> - CreateBuffersForAsyncHostToDevice( - absl::Span shape_specs, - std::optional>> device_layouts, - PjRtDevice* device) { - return absl::UnimplementedError(absl::StrCat( - "CreateBuffersForAsyncHostToDevice with ShapeSpec and Layout is " - "not implemented on platform: ", - platform_name())); - } - - // Variant of CreateBuffersForAsyncHostToDevice with PjRtMemorySpace. - virtual absl::StatusOr> CreateBuffersForAsyncHostToDevice( absl::Span shape_specs, std::optional>> device_layouts, @@ -789,19 +777,17 @@ class PjRtClient { // Returns a manager for async transfers into a set of buffers with on-host // shapes 'shapes'. virtual absl::StatusOr> - CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtDevice* device) { - return Unimplemented( - "CreateBuffersForAsyncHostToDevice with on host is not implemented."); - } - - // Variant of CreateBuffersForAsyncHostToDevice with PjRtMemorySpace. - virtual absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, PjRtMemorySpace* memory_space) { - return Unimplemented( - "CreateBuffersForAsyncHostToDevice with PjRtMemorySpace is not " - "implemented."); + absl::InlinedVector shape_specs; + shape_specs.reserve(shapes.size()); + for (const auto& shape : shapes) { + shape_specs.emplace_back(ShapeSpec{ + shape.element_type(), DimensionVector(shape.dimensions().begin(), + shape.dimensions().end())}); + } + return CreateBuffersForAsyncHostToDevice( + shape_specs, /*device_layouts=*/std::nullopt, memory_space); } // Describes the semantics the caller to BufferFromHostBuffer expects from the diff --git a/xla/pjrt/pjrt_client_utils.cc b/xla/pjrt/pjrt_client_utils.cc new file mode 100644 index 0000000000000..63851475d4eb1 --- /dev/null +++ b/xla/pjrt/pjrt_client_utils.cc @@ -0,0 +1,53 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/pjrt/pjrt_client_utils.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/layout.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/shape_util.h" +#include "xla/util.h" + +namespace xla { + +absl::StatusOr> ConvertShapeSpecsToShapes( + absl::Span shape_specs, + std::optional>> device_layouts) { + if (device_layouts.has_value() && + device_layouts->size() != shape_specs.size()) { + return InvalidArgument( + "Number of layouts %d does not match the number of shapes %d", + device_layouts->size(), shape_specs.size()); + } + std::vector device_shapes; + device_shapes.reserve(shape_specs.size()); + for (size_t i = 0; i < shape_specs.size(); ++i) { + auto& shape_spec = shape_specs[i]; + Shape& device_shape = device_shapes.emplace_back( + ShapeUtil::MakeShape(shape_spec.element_type, shape_spec.dims)); + if (device_layouts.has_value() && (*device_layouts)[i].has_value()) { + *device_shape.mutable_layout() = *(*device_layouts)[i]; + } + } + return device_shapes; +} + +} // namespace xla diff --git a/xla/pjrt/pjrt_client_utils.h b/xla/pjrt/pjrt_client_utils.h new file mode 100644 index 0000000000000..fe9f27949931b --- /dev/null +++ b/xla/pjrt/pjrt_client_utils.h @@ -0,0 +1,35 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_PJRT_PJRT_CLIENT_UTILS_H_ +#define XLA_PJRT_PJRT_CLIENT_UTILS_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/layout.h" +#include "xla/pjrt/pjrt_client.h" + +namespace xla { + +absl::StatusOr> ConvertShapeSpecsToShapes( + absl::Span shape_specs, + std::optional>> device_layouts); + +} // namespace xla + +#endif // XLA_PJRT_PJRT_CLIENT_UTILS_H_ diff --git a/xla/pjrt/pjrt_stream_executor_client.h b/xla/pjrt/pjrt_stream_executor_client.h index 6f8969795e1fd..2654372331cb3 100644 --- a/xla/pjrt/pjrt_stream_executor_client.h +++ b/xla/pjrt/pjrt_stream_executor_client.h @@ -345,18 +345,6 @@ class PjRtStreamExecutorClient : public PjRtClient { absl::StatusOr> CreateErrorBuffer( absl::Status error, const Shape& shape, PjRtMemorySpace* memory) override; - absl::StatusOr> - CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtDevice* device) override { - return Unimplemented("Async transfer to buffers not implemented"); - }; - - absl::StatusOr> - CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtMemorySpace* memory_space) override { - return Unimplemented("Async transfer to buffers not implemented"); - }; - absl::StatusOr> BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, std::optional> byte_strides, diff --git a/xla/pjrt/tf_pjrt_client.h b/xla/pjrt/tf_pjrt_client.h index bdc9cf98532b9..3cd2f0b9b36a7 100644 --- a/xla/pjrt/tf_pjrt_client.h +++ b/xla/pjrt/tf_pjrt_client.h @@ -276,12 +276,6 @@ class TfPjRtClient : public PjRtClient { "CreateUninitializedBuffer not supported for TfPjRtClient."); } absl::StatusOr> - CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtDevice* device) override { - return Unimplemented( - "AsyncHostToDeviceTransferManager not supported for Tf."); - } - absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, PjRtMemorySpace* memory_space) override { return Unimplemented( diff --git a/xla/python/transfer/streaming_ifrt_test.cc b/xla/python/transfer/streaming_ifrt_test.cc index f9d9b6e855eeb..b2e40e2c3b3ca 100644 --- a/xla/python/transfer/streaming_ifrt_test.cc +++ b/xla/python/transfer/streaming_ifrt_test.cc @@ -74,9 +74,11 @@ absl::StatusOr SetupTransferDestList( xla::ifrt::PjRtClient* ifrt_client, size_t xfer_size) { auto* pjrt_client = ifrt_client->pjrt_client(); // CHECK_EQ(pjrt_client->platform_id(), xla::TpuId()); + TF_ASSIGN_OR_RETURN(auto* pjrt_memory_space, + device->pjrt_device()->default_memory_space()); TF_ASSIGN_OR_RETURN(auto atm_owned, - pjrt_client->CreateBuffersForAsyncHostToDevice( - {shape}, device->pjrt_device())); + pjrt_client->CreateBuffersForAsyncHostToDevice( + {shape}, pjrt_memory_space)); auto atm = std::shared_ptr( std::move(atm_owned)); SingleBufferCopyPlan results;
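---

For reference, a minimal sketch (not part of the patch) of the call-site migration this change implies: callers that previously passed a `PjRtDevice*` to `CreateBuffersForAsyncHostToDevice` now resolve the device's default memory space and call the `PjRtMemorySpace*` overload, which forwards to the `ShapeSpec` overload via the new default implementation in `PjRtClient`. The helper name `EnqueueAsyncTransfer` and the concrete u32[3,2] shape below are hypothetical; the PJRT calls themselves mirror the ones used in the updated tests above.

    #include <memory>

    #include "absl/status/statusor.h"
    #include "xla/pjrt/pjrt_client.h"
    #include "xla/shape_util.h"

    // Hypothetical helper: enqueue an async host-to-device transfer for one
    // u32[3,2] buffer and hand back the (not-yet-ready) destination buffer.
    absl::StatusOr<std::unique_ptr<xla::PjRtBuffer>> EnqueueAsyncTransfer(
        xla::PjRtClient* client, xla::PjRtDevice* device) {
      xla::Shape shape = xla::ShapeUtil::MakeShape(xla::U32, {3, 2});

      // Before this change: CreateBuffersForAsyncHostToDevice({shape}, device).
      // Now: resolve the device's default memory space first.
      absl::StatusOr<xla::PjRtMemorySpace*> memory_space =
          device->default_memory_space();
      if (!memory_space.ok()) return memory_space.status();

      // The Shape-span overload builds ShapeSpecs internally and delegates to
      // the ShapeSpec + layout overload with default layouts.
      auto transfer_manager =
          client->CreateBuffersForAsyncHostToDevice({shape}, *memory_space);
      if (!transfer_manager.ok()) return transfer_manager.status();

      // The destination buffer can be retrieved immediately; it becomes ready
      // once the transfer manager receives the host data.
      return (*transfer_manager)->RetrieveBuffer(0);
    }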