Skip to content

Commit

Permalink
[pjrt] Removed PjRtDevice overloads of `PjRtClient::CreateBuffersForAsyncHostToDevice`
Browse files Browse the repository at this point in the history

I also pulled the `ShapeSpec`->`Shape` conversion code into the default
implementation, since it was duplicated in a few clients.

PiperOrigin-RevId: 726646980
  • Loading branch information
superbobry authored and Google-ML-Automation committed Feb 13, 2025
1 parent 5322b44 commit d73250c
Show file tree
Hide file tree
Showing 17 changed files with 155 additions and 232 deletions.
14 changes: 14 additions & 0 deletions xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,20 @@ cc_library(
],
)

# Shared helper utilities for PjRtClient implementations — notably the
# ShapeSpec -> Shape conversion previously duplicated across clients.
# Visible only to //xla friends, not part of the public PJRT API surface.
cc_library(
name = "pjrt_client_utils",
srcs = ["pjrt_client_utils.cc"],
hdrs = ["pjrt_client_utils.h"],
visibility = internal_visibility(["//xla:friends"]),
deps = [
":pjrt_client",
"//xla:shape_util",
"//xla:util",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
],
)

cc_library(
name = "pjrt_client_test_common",
testonly = 1,
Expand Down
1 change: 1 addition & 0 deletions xla/pjrt/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ cc_library(
"//xla/pjrt:layout_mode",
"//xla/pjrt:mlir_to_hlo",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_client_utils",
"//xla/pjrt:pjrt_common",
"//xla/pjrt:pjrt_compiler",
"//xla/pjrt:pjrt_executable",
Expand Down
49 changes: 6 additions & 43 deletions xla/pjrt/cpu/cpu_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ limitations under the License.
#include "xla/pjrt/layout_mode.h"
#include "xla/pjrt/mlir_to_hlo.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_client_utils.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_executable.h"
Expand Down Expand Up @@ -947,41 +948,14 @@ TfrtCpuClient::CreateUninitializedBuffer(const Shape& shape,
tensorflow::down_cast<TfrtCpuDevice*>(device), this);
}

absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
TfrtCpuClient::CreateBuffersForAsyncHostToDevice(absl::Span<const Shape> shapes,
PjRtDevice* device) {
auto* tfrt_device = tensorflow::down_cast<TfrtCpuDevice*>(device);
return TfrtCpuAsyncHostToDeviceTransferManager::Create(shapes, tfrt_device,
this);
}

absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
TfrtCpuClient::CreateBuffersForAsyncHostToDevice(
absl::Span<const Shape> shapes, PjRtMemorySpace* memory_space) {
CHECK_EQ(memory_space->devices().size(), 1);
return CreateBuffersForAsyncHostToDevice(shapes, memory_space->devices()[0]);
}

static absl::StatusOr<std::vector<xla::Shape>> ConvertShapeSpecToShapes(
absl::Span<const PjRtClient::ShapeSpec> shape_specs,
std::optional<absl::Span<const std::optional<Layout>>> device_layouts) {
if (device_layouts.has_value() &&
device_layouts->size() != shape_specs.size()) {
return InvalidArgument(
"Number of layouts %d does not match the number of shapes %d",
device_layouts->size(), shape_specs.size());
}
std::vector<xla::Shape> device_shapes;
device_shapes.reserve(shape_specs.size());
for (size_t i = 0; i < shape_specs.size(); ++i) {
auto& shape_spec = shape_specs[i];
Shape& device_shape = device_shapes.emplace_back(
ShapeUtil::MakeShape(shape_spec.element_type, shape_spec.dims));
if (device_layouts.has_value() && (*device_layouts)[i].has_value()) {
*device_shape.mutable_layout() = *(*device_layouts)[i];
}
}
return device_shapes;
auto* tfrt_device =
tensorflow::down_cast<TfrtCpuDevice*>(memory_space->devices().front());
return TfrtCpuAsyncHostToDeviceTransferManager::Create(shapes, tfrt_device,
this);
}

absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
Expand All @@ -990,22 +964,11 @@ TfrtCpuClient::CreateBuffersForAsyncHostToDevice(
std::optional<absl::Span<const std::optional<Layout>>> device_layouts,
PjRtMemorySpace* memory_space) {
TF_ASSIGN_OR_RETURN(std::vector<xla::Shape> device_shapes,
ConvertShapeSpecToShapes(shape_specs, device_layouts));
ConvertShapeSpecsToShapes(shape_specs, device_layouts));
return CreateBuffersForAsyncHostToDevice(absl::MakeSpan(device_shapes),
memory_space);
}

absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
TfrtCpuClient::CreateBuffersForAsyncHostToDevice(
absl::Span<const PjRtClient::ShapeSpec> shape_specs,
std::optional<absl::Span<const std::optional<Layout>>> device_layouts,
PjRtDevice* device) {
TF_ASSIGN_OR_RETURN(std::vector<xla::Shape> device_shapes,
ConvertShapeSpecToShapes(shape_specs, device_layouts));
return CreateBuffersForAsyncHostToDevice(absl::MakeSpan(device_shapes),
device);
}

absl::StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuClient::BufferFromHostBuffer(
const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
std::optional<absl::Span<int64_t const>> byte_strides,
Expand Down
10 changes: 0 additions & 10 deletions xla/pjrt/cpu/cpu_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,6 @@ class TfrtCpuClient final : public PjRtClient {
absl::StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtMemorySpace* device) override;

absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
CreateBuffersForAsyncHostToDevice(absl::Span<const Shape> shapes,
PjRtDevice* device) override;

absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
CreateBuffersForAsyncHostToDevice(absl::Span<const Shape> shapes,
PjRtMemorySpace* memory_space) override;
Expand All @@ -158,12 +154,6 @@ class TfrtCpuClient final : public PjRtClient {
std::optional<absl::Span<const std::optional<Layout>>> device_layouts,
PjRtMemorySpace* memory_space) override;

absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
CreateBuffersForAsyncHostToDevice(
absl::Span<const PjRtClient::ShapeSpec> shape_specs,
std::optional<absl::Span<const std::optional<Layout>>> device_layouts,
PjRtDevice* device) override;

absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
std::optional<absl::Span<int64_t const>> byte_strides,
Expand Down
45 changes: 21 additions & 24 deletions xla/pjrt/cpu/cpu_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ TEST(TfrtCpuClientTest, AsyncTransferRawData) {
xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2});
TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice(
{shape}, client->addressable_devices()[0]));
{shape}, client->memory_spaces()[0]));
auto buffer = transfer_manager->RetrieveBuffer(0);
auto ready_future = buffer->GetReadyFuture();
EXPECT_THAT(ready_future.IsReady(), IsFalse());
Expand All @@ -245,8 +245,8 @@ TEST(TfrtCpuClientTest, AsyncTransferWithSpecs) {
PjRtClient::ShapeSpec shape_spec{U32, {3, 2}};
TF_ASSERT_OK_AND_ASSIGN(
auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice(
{shape_spec}, std::nullopt, client->addressable_devices()[0]));
client->CreateBuffersForAsyncHostToDevice({shape_spec}, std::nullopt,
client->memory_spaces()[0]));
auto buffer = transfer_manager->RetrieveBuffer(0);
auto ready_future = buffer->GetReadyFuture();
EXPECT_THAT(ready_future.IsReady(), IsFalse());
Expand All @@ -266,7 +266,7 @@ TEST(TfrtCpuClientTest, AsyncTransferLiteral) {
xla::Shape shape = xla::ShapeUtil::MakeShape(F32, {128, 256});
TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice(
{shape}, client->addressable_devices()[0]));
{shape}, client->memory_spaces()[0]));
auto buffer = transfer_manager->RetrieveBuffer(0);
auto ready_future = buffer->GetReadyFuture();
EXPECT_THAT(ready_future.IsReady(), IsFalse());
Expand All @@ -282,7 +282,7 @@ TEST(TfrtCpuClientTest, AsyncTransferLiteralInt4) {
xla::Shape shape = xla::ShapeUtil::MakeShape(S4, {128, 256});
TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice(
{shape}, client->addressable_devices()[0]));
{shape}, client->memory_spaces()[0]));
auto buffer = transfer_manager->RetrieveBuffer(0);
auto ready_future = buffer->GetReadyFuture();
EXPECT_THAT(ready_future.IsReady(), IsFalse());
Expand Down Expand Up @@ -310,7 +310,7 @@ TEST(TfrtCpuClientTest, AsyncTransferCallsOnDone) {
xla::Shape shape = ShapeUtil::MakeShape(F32, {3, 2});
TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice(
{shape}, client->addressable_devices()[0]));
{shape}, client->memory_spaces()[0]));
auto buffer = transfer_manager->RetrieveBuffer(0);
auto ready_future = buffer->GetReadyFuture();
EXPECT_THAT(ready_future.IsReady(), IsFalse());
Expand All @@ -328,7 +328,7 @@ TEST(TfrtCpuClientTest, AsyncTransferNeverTransferred) {
xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2});
TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice(
{shape}, client->addressable_devices()[0]));
{shape}, client->memory_spaces()[0]));
auto buffer = transfer_manager->RetrieveBuffer(0);
transfer_manager.reset();
EXPECT_THAT(
Expand All @@ -343,11 +343,11 @@ TEST(TfrtCpuClientTest, AsyncTransferBufferCount) {
xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2});
TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice(
{shape}, client->addressable_devices()[0]));
{shape}, client->memory_spaces()[0]));
EXPECT_EQ(transfer_manager->buffer_count(), 1);
TF_ASSERT_OK_AND_ASSIGN(
transfer_manager, client->CreateBuffersForAsyncHostToDevice(
{shape, shape}, client->addressable_devices()[0]));
TF_ASSERT_OK_AND_ASSIGN(transfer_manager,
client->CreateBuffersForAsyncHostToDevice(
{shape, shape}, client->memory_spaces()[0]));
EXPECT_EQ(transfer_manager->buffer_count(), 2);
}

Expand All @@ -356,17 +356,17 @@ TEST(TfrtCpuClientTest, AsyncTransferBufferSize) {
xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2});
TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice(
{shape}, client->addressable_devices()[0]));
{shape}, client->memory_spaces()[0]));
EXPECT_EQ(transfer_manager->buffer_size(0), 3 * 2 * 4);
}

TEST(TfrtCpuClientTest, AsyncTransferDevice) {
TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(CpuClientOptions()));
xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2});
auto* device = client->addressable_devices()[0];
TF_ASSERT_OK_AND_ASSIGN(
auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice({shape}, device));
TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice(
{shape}, *device->default_memory_space()));
EXPECT_EQ(transfer_manager->device(), device);
}

Expand All @@ -375,7 +375,7 @@ TEST(TfrtCpuClientTest, AsyncTransferSetBufferError) {
xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2});
TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice(
{shape}, client->addressable_devices()[0]));
{shape}, client->memory_spaces()[0]));
auto buffer = transfer_manager->RetrieveBuffer(0);
transfer_manager->SetBufferError(0, Internal("foobar"));
EXPECT_THAT(
Expand All @@ -399,7 +399,7 @@ TEST(TfrtCpuClientTest, AsyncTransferRawDataToSubBuffer) {
xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2});
TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice(
{shape}, client->addressable_devices()[0]));
{shape}, client->memory_spaces()[0]));
auto buffer = transfer_manager->RetrieveBuffer(0);
auto ready_future = buffer->GetReadyFuture();
EXPECT_THAT(ready_future.IsReady(), IsFalse());
Expand Down Expand Up @@ -477,12 +477,9 @@ ENTRY Identity() -> f32[2, 2] {
ASSERT_TRUE(!fingerprint.empty());

Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
TF_ASSERT_OK_AND_ASSIGN(
auto* memory_space,
client->addressable_devices()[0]->default_memory_space());
TF_ASSERT_OK_AND_ASSIGN(
auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice({shape}, memory_space));
TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice(
{shape}, client->memory_spaces()[0]));
auto buffer = transfer_manager->RetrieveBuffer(0);
transfer_manager->SetBufferError(0, Internal("foobar"));

Expand Down Expand Up @@ -666,7 +663,7 @@ TEST(TfrtCpuClientTest, CopyRawToHost) {
xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2});
TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager,
client->CreateBuffersForAsyncHostToDevice(
{shape}, client->addressable_devices()[0]));
{shape}, client->memory_spaces()[0]));
auto buffer = transfer_manager->RetrieveBuffer(0);
auto ready_future = buffer->GetReadyFuture();
EXPECT_THAT(ready_future.IsReady(), IsFalse());
Expand Down
40 changes: 0 additions & 40 deletions xla/pjrt/gpu/se_gpu_pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -596,32 +596,6 @@ absl::string_view StreamExecutorGpuClient::platform_version() const {
#endif // TENSORFLOW_USE_ROCM && defined(TF_ROCM_VERSION)
}

absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice(
absl::Span<const PjRtClient::ShapeSpec> shape_specs,
std::optional<absl::Span<const std::optional<Layout>>> device_layouts,
PjRtDevice* device) {
auto* stream_executor_device =
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device);
return xla::AsyncHostToDeviceTransferManager::Create(
shape_specs, std::move(device_layouts), stream_executor_device, this,
/*memory_space=*/nullptr);
}

absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice(
absl::Span<const Shape> shapes, PjRtDevice* device) {
absl::InlinedVector<PjRtClient::ShapeSpec, 4> shape_specs;
shape_specs.reserve(shapes.size());
for (const auto& shape : shapes) {
shape_specs.emplace_back(PjRtClient::ShapeSpec{
shape.element_type(),
DimensionVector(shape.dimensions().begin(), shape.dimensions().end())});
}
return CreateBuffersForAsyncHostToDevice(
shape_specs, /*device_layouts=*/std::nullopt, device);
}

absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice(
absl::Span<const PjRtClient::ShapeSpec> shape_specs,
Expand All @@ -636,20 +610,6 @@ StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice(
memory_space);
}

absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice(
absl::Span<const Shape> shapes, PjRtMemorySpace* memory_space) {
absl::InlinedVector<PjRtClient::ShapeSpec, 4> shape_specs;
shape_specs.reserve(shapes.size());
for (const auto& shape : shapes) {
shape_specs.emplace_back(PjRtClient::ShapeSpec{
shape.element_type(),
DimensionVector(shape.dimensions().begin(), shape.dimensions().end())});
}
return CreateBuffersForAsyncHostToDevice(
shape_specs, /*device_layouts=*/std::nullopt, memory_space);
}

absl::StatusOr<xla::DeviceAssignment>
StreamExecutorGpuClient::GetDefaultDeviceAssignment(int num_replicas,
int num_partitions) const {
Expand Down
14 changes: 1 addition & 13 deletions xla/pjrt/gpu/se_gpu_pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,26 +203,14 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient {
int num_replicas, int num_partitions) const override;

absl::string_view platform_version() const override;
absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
CreateBuffersForAsyncHostToDevice(
absl::Span<const PjRtClient::ShapeSpec> shape_specs,
std::optional<absl::Span<const std::optional<Layout>>> device_layouts,
PjRtDevice* device) override;

absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
CreateBuffersForAsyncHostToDevice(absl::Span<const Shape> shapes,
PjRtDevice* device) override;

using PjRtStreamExecutorClient::CreateBuffersForAsyncHostToDevice;
absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
CreateBuffersForAsyncHostToDevice(
absl::Span<const PjRtClient::ShapeSpec> shape_specs,
std::optional<absl::Span<const std::optional<Layout>>> device_layouts,
PjRtMemorySpace* memory_space) override;

absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
CreateBuffersForAsyncHostToDevice(absl::Span<const Shape> shapes,
PjRtMemorySpace* memory_space) override;

PjRtFuture<> CopyRawSubBufferToHost(PjRtBuffer* buffer, PjRtFuture<void*> dst,
int64_t offset,
int64_t transfer_size) override;
Expand Down
Loading

0 comments on commit d73250c

Please sign in to comment.