Skip to content

Commit

Permalink
Change DeviceList::FromProto() to use Client::MakeDeviceList() to create device lists
Browse files Browse the repository at this point in the history

This lets the IFRT implementation control how device lists are deserialized, effectively addressing the TODO for device list SerDes in a different way. This requires passing `Client*` to `DeviceList::FromProto()` and to the SerDes implementations that internally call `DeviceList::FromProto()`.

PiperOrigin-RevId: 726124799
  • Loading branch information
junwhanahn authored and Google-ML-Automation committed Feb 12, 2025
1 parent ac1e4f2 commit eabe200
Show file tree
Hide file tree
Showing 22 changed files with 80 additions and 140 deletions.
6 changes: 0 additions & 6 deletions xla/python/ifrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/container:node_hash_set",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
Expand Down Expand Up @@ -468,7 +467,6 @@ cc_library(
":ifrt",
":serdes",
"//xla:util",
"@com_google_absl//absl/status:statusor",
"@llvm-project//llvm:Support",
],
)
Expand Down Expand Up @@ -502,7 +500,6 @@ xla_cc_test(
":sharding_serdes",
"//xla/python/ifrt/ir:sharding_param",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/functional:bind_front",
"@com_google_googletest//:gtest_main",
],
)
Expand Down Expand Up @@ -530,7 +527,6 @@ xla_cc_test(
"//xla/pjrt:pjrt_layout",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/hash:hash_testing",
"@com_google_absl//absl/status:statusor",
"@com_google_googletest//:gtest_main",
"@llvm-project//llvm:Support",
],
Expand All @@ -551,7 +547,6 @@ xla_cc_test(
":ifrt",
"//xla/tsl/platform:env",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:platform_port",
Expand Down Expand Up @@ -754,7 +749,6 @@ xla_cc_test(
"//xla/tsl/platform:status_matchers",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/functional:bind_front",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:cord",
"@com_google_googletest//:gtest_main",
Expand Down
7 changes: 4 additions & 3 deletions xla/python/ifrt/array_spec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/python/ifrt/array_spec.pb.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/dtype.h"
#include "xla/python/ifrt/shape.h"
Expand All @@ -32,12 +33,12 @@ limitations under the License.
namespace xla {
namespace ifrt {

absl::StatusOr<ArraySpec> ArraySpec::FromProto(
DeviceList::LookupDeviceFunc lookup_device, const ArraySpecProto& proto) {
absl::StatusOr<ArraySpec> ArraySpec::FromProto(Client* client,
const ArraySpecProto& proto) {
TF_ASSIGN_OR_RETURN(auto dtype, DType::FromProto(proto.dtype()));
TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape()));
TF_ASSIGN_OR_RETURN(auto sharding,
Sharding::FromProto(lookup_device, proto.sharding()));
Sharding::FromProto(client, proto.sharding()));
std::shared_ptr<const xla::PjRtLayout> layout;
if (proto.has_layout()) {
TF_ASSIGN_OR_RETURN(layout, xla::PjRtLayout::Deserialize(proto.layout()));
Expand Down
7 changes: 4 additions & 3 deletions xla/python/ifrt/array_spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/python/ifrt/array_spec.pb.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/dtype.h"
#include "xla/python/ifrt/shape.h"
#include "xla/python/ifrt/sharding.h"

namespace xla {
namespace ifrt {

class Client;

// Specification of an array that groups the static properties of an `Array`
// together. Typically used for describing expected or requested static
// properties of an input/output array of an operation.
Expand Down Expand Up @@ -73,8 +74,8 @@ struct ArraySpec {
}

// Constructs `ArraySpec` from `ArraySpecProto`.
static absl::StatusOr<ArraySpec> FromProto(
DeviceList::LookupDeviceFunc lookup_device, const ArraySpecProto& proto);
static absl::StatusOr<ArraySpec> FromProto(Client* client,
const ArraySpecProto& proto);

// Returns a `ArraySpecProto` representation.
absl::StatusOr<ArraySpecProto> ToProto() const;
Expand Down
7 changes: 1 addition & 6 deletions xla/python/ifrt/array_spec_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@ limitations under the License.

#include <gtest/gtest.h>
#include "absl/hash/hash_testing.h"
#include "absl/status/statusor.h"
#include "llvm/Support/Casting.h"
#include "xla/layout_util.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/python/ifrt/array_spec.pb.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/device_test_util.h"
#include "xla/python/ifrt/dtype.h"
#include "xla/python/ifrt/memory.h"
Expand Down Expand Up @@ -64,12 +62,9 @@ TEST_P(ArraySpecTest, ToFromProto) {
/*shape=*/shape,
/*shard_shape=*/shard_shape)};

auto lookup_device_func = [&](DeviceId device_id) -> absl::StatusOr<Device*> {
return client()->LookupDevice(device_id);
};
TF_ASSERT_OK_AND_ASSIGN(const ArraySpecProto proto, spec.ToProto());
TF_ASSERT_OK_AND_ASSIGN(const ArraySpec array_spec_copy,
ArraySpec::FromProto(lookup_device_func, proto));
ArraySpec::FromProto(client(), proto));

EXPECT_EQ(array_spec_copy.dtype, dtype);
EXPECT_EQ(array_spec_copy.shape, shape);
Expand Down
16 changes: 7 additions & 9 deletions xla/python/ifrt/custom_call_program_serdes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,22 @@ class CustomCallProgramSerDes
}
TF_ASSIGN_OR_RETURN(
tsl::RCReference<DeviceList> devices,
DeviceList::FromProto(deserialize_program_options->lookup_device,
DeviceList::FromProto(deserialize_program_options->client,
proto.devices()));
std::vector<ArraySpec> input_specs;
input_specs.reserve(proto.input_specs_size());
for (const ArraySpecProto& spec_proto : proto.input_specs()) {
TF_ASSIGN_OR_RETURN(
ArraySpec spec,
ArraySpec::FromProto(deserialize_program_options->lookup_device,
spec_proto));
TF_ASSIGN_OR_RETURN(ArraySpec spec,
ArraySpec::FromProto(
deserialize_program_options->client, spec_proto));
input_specs.push_back(std::move(spec));
}
std::vector<ArraySpec> output_specs;
output_specs.reserve(proto.output_specs_size());
for (const ArraySpecProto& spec_proto : proto.output_specs()) {
TF_ASSIGN_OR_RETURN(
ArraySpec spec,
ArraySpec::FromProto(deserialize_program_options->lookup_device,
spec_proto));
TF_ASSIGN_OR_RETURN(ArraySpec spec,
ArraySpec::FromProto(
deserialize_program_options->client, spec_proto));
output_specs.push_back(std::move(spec));
}

Expand Down
5 changes: 1 addition & 4 deletions xla/python/ifrt/custom_call_program_serdes_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@ limitations under the License.

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/functional/bind_front.h"
#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "llvm/Support/Casting.h"
#include "xla/python/ifrt/array_spec.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/custom_call_program.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/device_test_util.h"
Expand Down Expand Up @@ -87,8 +85,7 @@ TEST_P(CustomCallProgramSerDesTest, RoundTrip) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<CustomCallProgram> deserialized_program,
Deserialize<CustomCallProgram>(
serialized, std::make_unique<DeserializeProgramOptions>(
absl::bind_front(&Client::LookupDevice, client()))));
serialized, std::make_unique<DeserializeProgramOptions>(client())));

EXPECT_EQ(deserialized_program->type, "test type");
EXPECT_EQ(deserialized_program->name, "test name");
Expand Down
13 changes: 7 additions & 6 deletions xla/python/ifrt/device_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ limitations under the License.

#include "absl/base/call_once.h"
#include "absl/base/optimization.h"
#include "absl/container/inlined_vector.h"
#include "absl/hash/hash.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/device.pb.h"
#include "xla/tsl/concurrency/ref_count.h"
Expand All @@ -41,16 +43,15 @@ char DeviceList::ID = 0;
char BasicDeviceList::ID = 0;

absl::StatusOr<tsl::RCReference<DeviceList>> DeviceList::FromProto(
LookupDeviceFunc lookup_device, const DeviceListProto& proto) {
// TODO(hyeontaek): Define SerDes for `DeviceList` and use it to remove this
// layering inversion.
BasicDeviceList::Devices devices;
xla::ifrt::Client* client, const DeviceListProto& proto) {
absl::InlinedVector<Device*, 1> devices;
devices.reserve(proto.device_ids_size());
for (int device_id : proto.device_ids()) {
TF_ASSIGN_OR_RETURN(Device * device, lookup_device(DeviceId(device_id)));
TF_ASSIGN_OR_RETURN(Device* const device,
client->LookupDevice(DeviceId(device_id)));
devices.push_back(device);
}
return BasicDeviceList::Create(std::move(devices));
return client->MakeDeviceList(devices);
}

DeviceListProto DeviceList::ToProto() const {
Expand Down
22 changes: 3 additions & 19 deletions xla/python/ifrt/device_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ limitations under the License.

#include "absl/base/call_once.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/function_ref.h"
#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
Expand All @@ -41,12 +40,6 @@ namespace ifrt {
class DeviceList : public tsl::ReferenceCounted<DeviceList>,
public llvm::RTTIExtends<DeviceList, llvm::RTTIRoot> {
public:
// Function that matches the semantics of `Client::LookupDevice()`.
// TODO(hyeontaek): Remove this type. In the future, a deserialization option
// will take `Client*` to allow constructing a complex `DeviceList` that is
// not just `BasicDeviceList`.
using LookupDeviceFunc = absl::FunctionRef<absl::StatusOr<Device*>(DeviceId)>;

// Not copyable or movable. `DeviceList` is a runtime object that may contain
// runtime-specific state that cannot be trivially copied or moved.
DeviceList(const DeviceList&) = delete;
Expand All @@ -55,10 +48,10 @@ class DeviceList : public tsl::ReferenceCounted<DeviceList>,
DeviceList& operator=(DeviceList&&) = delete;

// Constructs `DeviceList` from `DeviceListProto`. Devices are looked up using
// `lookup_device`. Device ids in the proto must be consistent with the
// devices returned by `lookup_device`.
// `client`. Device ids in the proto must be consistent with the devices
// returned by `client`.
static absl::StatusOr<tsl::RCReference<DeviceList>> FromProto(
LookupDeviceFunc lookup_device, const DeviceListProto& proto);
xla::ifrt::Client* client, const DeviceListProto& proto);

// Returns a `DeviceListProto` representation.
DeviceListProto ToProto() const;
Expand Down Expand Up @@ -137,15 +130,6 @@ class BasicDeviceList : public llvm::RTTIExtends<BasicDeviceList, DeviceList> {

~BasicDeviceList() override = default;

// Constructs `DeviceList` from `DeviceListProto`. Devices are looked up
// using `lookup_device`. Device ids in the proto must be consistent with
// the devices returned by `lookup_device`.
static absl::StatusOr<tsl::RCReference<DeviceList>> FromProto(
LookupDeviceFunc lookup_device, const DeviceListProto& proto);

// Returns a `DeviceListProto` representation.
DeviceListProto ToProto() const;

absl::Span<Device* const> devices() const override { return devices_; }

DeviceList* AddressableDeviceList() const override;
Expand Down
6 changes: 1 addition & 5 deletions xla/python/ifrt/device_list_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ limitations under the License.

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/device.pb.h"
Expand All @@ -44,11 +43,8 @@ class DeviceListTest : public test_util::DeviceTest {};
TEST_P(DeviceListTest, ToFromProto) {
auto device_list = GetDevices({0, 1});
DeviceListProto proto = device_list->ToProto();
auto lookup_device_func = [&](DeviceId device_id) -> absl::StatusOr<Device*> {
return client()->LookupDevice(device_id);
};
TF_ASSERT_OK_AND_ASSIGN(auto device_list_copy,
DeviceList::FromProto(lookup_device_func, proto));
DeviceList::FromProto(client(), proto));
EXPECT_EQ(*device_list_copy, *device_list);
}

Expand Down
8 changes: 3 additions & 5 deletions xla/python/ifrt/program_serdes.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ limitations under the License.
#define XLA_PYTHON_IFRT_PROGRAM_SERDES_H_

#include "llvm/Support/ExtensibleRTTI.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/serdes.h"

namespace xla {
Expand All @@ -26,13 +26,11 @@ namespace ifrt {
// Abstract options for deserializing a `Program`.
struct DeserializeProgramOptions
: llvm::RTTIExtends<DeserializeProgramOptions, DeserializeOptions> {
explicit DeserializeProgramOptions(DeviceList::LookupDeviceFunc lookup_device)
: lookup_device(lookup_device) {}
explicit DeserializeProgramOptions(Client* client) : client(client) {}

static char ID; // NOLINT

// Function that converts device ids to devices.
DeviceList::LookupDeviceFunc lookup_device;
Client* client;
};

} // namespace ifrt
Expand Down
8 changes: 4 additions & 4 deletions xla/python/ifrt/remap_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,21 +295,21 @@ absl::Status RemapPlan::Validate() const {
return absl::OkStatus();
}

absl::StatusOr<RemapPlan> RemapPlan::FromProto(
DeviceList::LookupDeviceFunc lookup_device, const RemapPlanProto& proto) {
absl::StatusOr<RemapPlan> RemapPlan::FromProto(Client* client,
const RemapPlanProto& proto) {
RemapPlan plan;

plan.input_specs.reserve(proto.input_specs_size());
for (const auto& input_spec_proto : proto.input_specs()) {
TF_ASSIGN_OR_RETURN(ArraySpec input_spec,
ArraySpec::FromProto(lookup_device, input_spec_proto));
ArraySpec::FromProto(client, input_spec_proto));
plan.input_specs.push_back(std::move(input_spec));
}

plan.output_specs.reserve(proto.output_specs_size());
for (const auto& output_spec_proto : proto.output_specs()) {
TF_ASSIGN_OR_RETURN(ArraySpec output_spec,
ArraySpec::FromProto(lookup_device, output_spec_proto));
ArraySpec::FromProto(client, output_spec_proto));
plan.output_specs.push_back(std::move(output_spec));
}

Expand Down
7 changes: 4 additions & 3 deletions xla/python/ifrt/remap_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/array_spec.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/remap_plan.pb.h"

namespace xla {
namespace ifrt {

class Client;

// Remap plan that describes how the shards from input `Array`s are mapped to
// the shards of output `Array`s.
//
Expand Down Expand Up @@ -95,8 +96,8 @@ struct RemapPlan {
// Constructs `RemapPlan` from `RemapPlanProto`. Devices are looked up
// using `client`. Device ids in the proto must be consistent with the
// devices returned by `client`.
static absl::StatusOr<RemapPlan> FromProto(
DeviceList::LookupDeviceFunc lookup_device, const RemapPlanProto& proto);
static absl::StatusOr<RemapPlan> FromProto(Client* client,
const RemapPlanProto& proto);

// Returns a `RemapPlanProto` representation.
absl::StatusOr<RemapPlanProto> ToProto() const;
Expand Down
6 changes: 2 additions & 4 deletions xla/python/ifrt/remap_plan_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,8 @@ TEST_P(RemapPlanTest, ToFromProto) {
/*to=*/{RemapPlan::Interval{0, 2, 1}, RemapPlan::Interval{2, 4, 1}}});

TF_ASSERT_OK_AND_ASSIGN(RemapPlanProto plan_proto, plan.ToProto());
TF_ASSERT_OK_AND_ASSIGN(
RemapPlan plan_copy,
RemapPlan::FromProto(absl::bind_front(&Client::LookupDevice, client()),
plan_proto));
TF_ASSERT_OK_AND_ASSIGN(RemapPlan plan_copy,
RemapPlan::FromProto(client(), plan_proto));

EXPECT_THAT(*plan_copy.mappings, ElementsAreArray(*plan.mappings));

Expand Down
5 changes: 2 additions & 3 deletions xla/python/ifrt/sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,10 @@ bool Sharding::operator==(const Sharding& other) const {
}

absl::StatusOr<std::unique_ptr<Sharding>> Sharding::FromProto(
DeviceList::LookupDeviceFunc lookup_device,
const ShardingProto& sharding_proto) {
Client* client, const ShardingProto& sharding_proto) {
return Deserialize<Sharding>(
sharding_proto.serialized_sharding(),
std::make_unique<DeserializeShardingOptions>(std::move(lookup_device)));
std::make_unique<DeserializeShardingOptions>(client));
}

absl::StatusOr<ShardingProto> Sharding::ToProto() const {
Expand Down
Loading

0 comments on commit eabe200

Please sign in to comment.