Skip to content

Commit

Permalink
[IFRT] Add new Client::AssembleArrayFromSingleDeviceArrays API that t…
Browse files Browse the repository at this point in the history
…akes a `dtype` argument.

This is necessary to support IFRT/PJRT arrays with no buffers, as a step toward MPMD/pipeline parallelism.

PiperOrigin-RevId: 725908366
  • Loading branch information
emilyfertig authored and Google-ML-Automation committed Feb 12, 2025
1 parent e07a682 commit b92e86c
Show file tree
Hide file tree
Showing 14 changed files with 120 additions and 11 deletions.
15 changes: 15 additions & 0 deletions xla/backends/cpu/nanort/ifrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,21 @@ NanoIfrtClient::AssembleArrayFromSingleDeviceArrays(
array_copy_semantics);
}

absl::StatusOr<tsl::RCReference<ifrt::Array>>
NanoIfrtClient::AssembleArrayFromSingleDeviceArrays(
ifrt::DType dtype, ifrt::Shape shape,
absl::Nonnull<std::shared_ptr<const ifrt::Sharding>> sharding,
absl::Span<tsl::RCReference<ifrt::Array>> arrays,
ifrt::ArrayCopySemantics array_copy_semantics,
ifrt::SingleDeviceShardSemantics single_device_shard_semantics) {
// NanoRT devices always have at least one buffer, so we can use the buffer
// dtype.
TF_RET_CHECK(!arrays.empty());
TF_RET_CHECK(dtype == arrays.front()->dtype());
return AssembleArrayFromSingleDeviceArrays(shape, sharding, arrays,
array_copy_semantics);
}

absl::StatusOr<std::vector<tsl::RCReference<ifrt::Array>>>
NanoIfrtClient::CopyArrays(
absl::Span<tsl::RCReference<ifrt::Array>> arrays,
Expand Down
7 changes: 7 additions & 0 deletions xla/backends/cpu/nanort/ifrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ class NanoIfrtClient : public llvm::RTTIExtends<NanoIfrtClient, ifrt::Client> {
absl::Span<tsl::RCReference<ifrt::Array>> arrays,
ifrt::ArrayCopySemantics array_copy_semantics,
ifrt::SingleDeviceShardSemantics single_device_shard_semantics) override;
absl::StatusOr<tsl::RCReference<ifrt::Array>>
AssembleArrayFromSingleDeviceArrays(
ifrt::DType dtype, ifrt::Shape shape,
absl::Nonnull<std::shared_ptr<const ifrt::Sharding>> sharding,
absl::Span<tsl::RCReference<ifrt::Array>> arrays,
ifrt::ArrayCopySemantics array_copy_semantics,
ifrt::SingleDeviceShardSemantics single_device_shard_semantics) override;

absl::StatusOr<std::vector<tsl::RCReference<ifrt::Array>>> CopyArrays(
absl::Span<tsl::RCReference<ifrt::Array>> arrays,
Expand Down
8 changes: 4 additions & 4 deletions xla/python/ifrt/array_impl_test_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ TEST(ArrayImplTest, AssembleArray) {
TF_ASSERT_OK_AND_ASSIGN(
auto assembled_array,
client->AssembleArrayFromSingleDeviceArrays(
assembled_shape, assembled_sharding, absl::MakeSpan(arrays),
dtype, assembled_shape, assembled_sharding, absl::MakeSpan(arrays),
ArrayCopySemantics::kAlwaysCopy,
SingleDeviceShardSemantics::kAddressableShards));

Expand Down Expand Up @@ -482,7 +482,7 @@ TEST(ArrayImplTest, AssembleAndDisassembleSingleDeviceArray) {

TF_ASSERT_OK_AND_ASSIGN(auto assembled_array,
client->AssembleArrayFromSingleDeviceArrays(
shape, sharding, absl::MakeSpan(arrays),
dtype, shape, sharding, absl::MakeSpan(arrays),
ArrayCopySemantics::kAlwaysCopy,
SingleDeviceShardSemantics::kAddressableShards));

Expand Down Expand Up @@ -565,7 +565,7 @@ TEST(ArrayImplTest, CopyToDifferentDevice) {
TF_ASSERT_OK_AND_ASSIGN(
arrays.emplace_back(),
client->AssembleArrayFromSingleDeviceArrays(
shape, sharding, absl::MakeSpan(shards),
dtype, shape, sharding, absl::MakeSpan(shards),
ArrayCopySemantics::kAlwaysCopy,
SingleDeviceShardSemantics::kAddressableShards));
}
Expand All @@ -575,7 +575,7 @@ TEST(ArrayImplTest, CopyToDifferentDevice) {
TF_ASSERT_OK_AND_ASSIGN(
arrays.emplace_back(),
client->AssembleArrayFromSingleDeviceArrays(
shape, sharding, absl::MakeSpan(shards),
dtype, shape, sharding, absl::MakeSpan(shards),
ArrayCopySemantics::kAlwaysCopy,
SingleDeviceShardSemantics::kAddressableShards));
}
Expand Down
9 changes: 8 additions & 1 deletion xla/python/ifrt/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class Client : public llvm::RTTIExtends<Client, llvm::RTTIRoot> {

// Builds a larger array out of individual per-device shards.
// TODO(hyeontaek): Replace this API with the version that takes
// `SingleDeviceShardSemantics`.
// `SingleDeviceShardSemantics` and `dtype`.
virtual absl::StatusOr<tsl::RCReference<Array>>
AssembleArrayFromSingleDeviceArrays(
Shape shape, absl::Nonnull<std::shared_ptr<const Sharding>> sharding,
Expand All @@ -128,6 +128,13 @@ class Client : public llvm::RTTIExtends<Client, llvm::RTTIRoot> {
absl::Span<tsl::RCReference<Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) = 0;
virtual absl::StatusOr<tsl::RCReference<Array>>
AssembleArrayFromSingleDeviceArrays(
DType dtype, Shape shape,
absl::Nonnull<std::shared_ptr<const Sharding>> sharding,
absl::Span<tsl::RCReference<Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) = 0;

// Copies the arrays to a new set of devices.
//
Expand Down
11 changes: 11 additions & 0 deletions xla/python/ifrt/mock.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,17 @@ MockClient::MockClient(std::unique_ptr<xla::ifrt::Client> delegated)
std::move(shape), std::move(sharding), arrays,
array_copy_semantics, single_device_shard_semantics);
});
ON_CALL(*this, AssembleArrayFromSingleDeviceArrays(_, _, _, _, _, _))
.WillByDefault(
[this](DType dtype, Shape shape,
absl::Nonnull<std::shared_ptr<const Sharding>> sharding,
absl::Span<tsl::RCReference<Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) {
return delegated_->AssembleArrayFromSingleDeviceArrays(
std::move(dtype), std::move(shape), std::move(sharding), arrays,
array_copy_semantics, single_device_shard_semantics);
});
ON_CALL(*this, CopyArrays)
.WillByDefault([this](absl::Span<tsl::RCReference<Array>> arrays,
std::optional<tsl::RCReference<DeviceList>> devices,
Expand Down
8 changes: 8 additions & 0 deletions xla/python/ifrt/mock.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,14 @@ class MockClient : public llvm::RTTIExtends<MockClient, Client> {
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics),
(final));
MOCK_METHOD(absl::StatusOr<tsl::RCReference<Array>>,
AssembleArrayFromSingleDeviceArrays,
(DType dtype, Shape shape,
absl::Nonnull<std::shared_ptr<const Sharding>> sharding,
absl::Span<tsl::RCReference<Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics),
(final));
MOCK_METHOD(absl::StatusOr<std::vector<tsl::RCReference<Array>>>, CopyArrays,
(absl::Span<tsl::RCReference<Array>> arrays,
std::optional<tsl::RCReference<DeviceList>> devices,
Expand Down
11 changes: 11 additions & 0 deletions xla/python/ifrt_proxy/client/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,17 @@ Client::AssembleArrayFromSingleDeviceArrays(
array_copy_semantics, single_device_shard_semantics);
}

absl::StatusOr<tsl::RCReference<xla::ifrt::Array>>
Client::AssembleArrayFromSingleDeviceArrays(
DType dtype, Shape shape, std::shared_ptr<const Sharding> sharding,
absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) {
return Array::AssembleArrayFromSingleDeviceArrays(
this, rpc_helper_, std::move(shape), sharding, arrays,
array_copy_semantics, single_device_shard_semantics);
}

absl::StatusOr<std::vector<tsl::RCReference<xla::ifrt::Array>>>
Client::CopyArrays(
absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
Expand Down
6 changes: 6 additions & 0 deletions xla/python/ifrt_proxy/client/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ class Client final : public llvm::RTTIExtends<Client, xla::ifrt::Client> {
absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) override;
absl::StatusOr<tsl::RCReference<xla::ifrt::Array>>
AssembleArrayFromSingleDeviceArrays(
DType dtype, Shape shape, std::shared_ptr<const Sharding> sharding,
absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) override;

absl::StatusOr<std::vector<tsl::RCReference<Array>>> CopyArrays(
absl::Span<tsl::RCReference<Array>> arrays,
Expand Down
1 change: 1 addition & 0 deletions xla/python/ifrt_proxy/common/ifrt_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ message AssembleArrayFromSingleDeviceArraysRequest {
proto.ArrayCopySemantics copy_semantics = 4;
optional proto.SingleDeviceShardSemantics single_device_shard_semantics = 5;
fixed64 result_handle = 6;
optional DTypeProto dtype = 7;
}
message AssembleArrayFromSingleDeviceArraysResponse {
fixed64 array_handle = 1;
Expand Down
4 changes: 4 additions & 0 deletions xla/python/ifrt_proxy/common/versions.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ enum {
// related to LoadedExecutable.
kClientHandlesExecutableOptimization,

// kAssembleArrayFromSingleDeviceArraysWithDType adds a DType argument to
// AssembleArrayFromSingleDeviceArrays to support non-addressable arrays.
kAssembleArrayFromSingleDeviceArraysWithDType,

// kSentiel is used to derive kCurrent below. Keep this as the last value of
// the enum.
kSentiel,
Expand Down
17 changes: 15 additions & 2 deletions xla/python/pjrt_ifrt/pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,20 @@ PjRtClient::AssembleArrayFromSingleDeviceArrays(
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) {
DCHECK(this);
DCHECK(!arrays.empty());
DType dtype = arrays[0]->dtype();
return AssembleArrayFromSingleDeviceArrays(
dtype, std::move(shape), std::move(sharding), arrays,
array_copy_semantics, single_device_shard_semantics);
}

absl::StatusOr<tsl::RCReference<Array>>
PjRtClient::AssembleArrayFromSingleDeviceArrays(
DType dtype, Shape shape, std::shared_ptr<const Sharding> sharding,
absl::Span<tsl::RCReference<Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) {
DCHECK(this);
if (llvm::isa<const SingleDeviceSharding>(sharding.get())) {
// Assemble with SingleDeviceSharding is No-op.
if (arrays.size() != 1) {
Expand Down Expand Up @@ -1016,14 +1030,13 @@ PjRtClient::AssembleArrayFromSingleDeviceArrays(
"single-shard arrays: %d vs. %d",
sharding->devices()->AddressableDeviceList()->size(), arrays.size());
}
if (arrays[0]->dtype().kind() == DType::kString) {
if (dtype.kind() == DType::kString) {
return AssembleStringArrayFromSingleDeviceStringArrays(
shape, sharding, arrays, array_copy_semantics,
single_device_shard_semantics);
}
PjRtArray::PjRtBuffers buffers;
buffers.reserve(arrays.size());
DType dtype = arrays[0]->dtype();
for (int i = 0; i < arrays.size(); ++i) {
if (!llvm::isa<PjRtCompatibleArray>(arrays[i].get())) {
return InvalidArgument(
Expand Down
5 changes: 5 additions & 0 deletions xla/python/pjrt_ifrt/pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,11 @@ class PjRtClient final
absl::Span<tsl::RCReference<Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) override;
absl::StatusOr<tsl::RCReference<Array>> AssembleArrayFromSingleDeviceArrays(
DType dtype, Shape shape, std::shared_ptr<const Sharding> sharding,
absl::Span<tsl::RCReference<Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) override;

absl::StatusOr<std::vector<tsl::RCReference<Array>>> CopyArrays(
absl::Span<tsl::RCReference<Array>> arrays,
Expand Down
18 changes: 14 additions & 4 deletions xla/python/py_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,20 @@ tsl::RCReference<ifrt::Array> CreateIfRtArrayFromSingleDeviceShardedPyArrays(
// TODO(hyeontaek): Return a absl::Status.
throw nb::value_error(ifrt_sharding.status().ToString().c_str());
}
auto ifrt_array = client->AssembleArrayFromSingleDeviceArrays(
ifrt::Shape(shape), *std::move(ifrt_sharding),
absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput,
ifrt::SingleDeviceShardSemantics::kAddressableShards);
absl::StatusOr<tsl::RCReference<ifrt::Array>> ifrt_array;
// TODO(emilyaf): Always call the version that takes `dtype` once tokens are
// handled correctly.
if (ifrt_arrays.empty()) {
ifrt_array = client->AssembleArrayFromSingleDeviceArrays(
ifrt_dtype.value(), ifrt::Shape(shape), *std::move(ifrt_sharding),
absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput,
ifrt::SingleDeviceShardSemantics::kAddressableShards);
} else {
ifrt_array = client->AssembleArrayFromSingleDeviceArrays(
ifrt::Shape(shape), *std::move(ifrt_sharding),
absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput,
ifrt::SingleDeviceShardSemantics::kAddressableShards);
}
if (!ifrt_array.ok()) {
// TODO(hyeontaek): Return a absl::Status.
throw nb::value_error(ifrt_array.status().ToString().c_str());
Expand Down
11 changes: 11 additions & 0 deletions xla/python/py_compile_only_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,17 @@ class CompileOnlyIfRtClient final
"AssembleArrayFromSingleDeviceArrays not available with compile-only "
"client.");
}
absl::StatusOr<tsl::RCReference<ifrt::Array>>
AssembleArrayFromSingleDeviceArrays(
ifrt::DType dtype, ifrt::Shape shape,
std::shared_ptr<const ifrt::Sharding> sharding,
absl::Span<tsl::RCReference<ifrt::Array>> arrays,
ifrt::ArrayCopySemantics array_copy_semantics,
ifrt::SingleDeviceShardSemantics single_device_shard_semantics) override {
return Unimplemented(
"AssembleArrayFromSingleDeviceArrays not available with compile-only "
"client.");
}

absl::StatusOr<std::vector<tsl::RCReference<ifrt::Array>>> CopyArrays(
absl::Span<tsl::RCReference<ifrt::Array>> arrays,
Expand Down

0 comments on commit b92e86c

Please sign in to comment.