Replace the remaining use of BasicDeviceList in the base IFRT
After this CL, `BasicDeviceList` can be split into its own build target. This will let us control its visibility so that IFRT users (as opposed to IFRT implementations) stop calling `BasicDeviceList` directly.

PiperOrigin-RevId: 726184429
junwhanahn authored and Google-ML-Automation committed Feb 12, 2025
1 parent d091218 commit ae9d50a
Showing 7 changed files with 54 additions and 26 deletions.
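For context, the call-pattern change applied throughout this commit, as a minimal sketch (the helper function is hypothetical, for illustration only; the `tsl::RCReference<DeviceList>` return type is assumed from how the diff uses the result):

```cpp
#include "xla/python/ifrt/client.h"       // xla::ifrt::Client::MakeDeviceList
#include "xla/python/ifrt/device.h"       // xla::ifrt::Device
#include "xla/python/ifrt/device_list.h"  // xla::ifrt::DeviceList

// Hypothetical helper showing the two styles side by side.
tsl::RCReference<xla::ifrt::DeviceList> MakeSingletonDeviceList(
    xla::ifrt::Device* device) {
  // Old style (what this commit removes): name the concrete implementation.
  //   return xla::ifrt::BasicDeviceList::Create({device});
  // New style: route through the Client, which picks the concrete
  // DeviceList type for its implementation.
  return device->client()->MakeDeviceList({device});
}
```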
13 changes: 9 additions & 4 deletions xla/python/ifrt/remap_plan.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include <vector>

 #include "absl/container/flat_hash_map.h"
+#include "absl/container/inlined_vector.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
@@ -179,7 +180,8 @@ absl::Status RemapPlan::Validate() const {
   }

   const int num_outputs = output_specs.size();
-  std::vector<BasicDeviceList::Devices> out_assigned_devices_list(num_outputs);
+  std::vector<absl::InlinedVector<Device*, 1>> out_assigned_devices_list(
+      num_outputs);
   for (int i = 0; i < num_outputs; ++i) {
     out_assigned_devices_list[i].resize(
         /*n=*/output_specs[i].sharding->devices()->size(),
@@ -235,7 +237,7 @@ absl::Status RemapPlan::Validate() const {
     std::vector<bool>& in_used_buffers = in_used_buffers_list[mapping.in_array];
     absl::Span<Device* const> in_devices =
         input_specs[mapping.in_array].sharding->devices()->devices();
-    BasicDeviceList::Devices& out_assigned_devices =
+    absl::InlinedVector<Device*, 1>& out_assigned_devices =
         out_assigned_devices_list[mapping.out_array];
     const int64_t in_shards_count = in_used_buffers.size();
     const int64_t out_shards_count = out_assigned_devices.size();
@@ -287,9 +289,12 @@ absl::Status RemapPlan::Validate() const {
         output_specs[i].sharding->devices()->devices()) {
       return InvalidArgument(
           "Output array %d devices and sharding devices do not match: "
-          "Expected %v, but got %v",
+          "Expected %v, but got [%s]",
           i, *output_specs[i].sharding->devices(),
-          *BasicDeviceList::Create(std::move(out_assigned_devices_list[i])));
+          absl::StrJoin(out_assigned_devices_list[i], ", ",
+                        [](std::string* s, Device* d) {
+                          absl::StrAppend(s, d->ToString());
+                        }));
     }
   }
   return absl::OkStatus();
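The hunk above swaps `%v` formatting of a `BasicDeviceList` for a hand-rolled join. A self-contained sketch of the same `absl::StrJoin` idiom — the `Device` struct here is a stand-in for `xla::ifrt::Device`, which also exposes `ToString()`:

```cpp
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

struct Device {  // Stand-in for xla::ifrt::Device.
  std::string name;
  std::string ToString() const { return name; }
};

std::string FormatDevices(const std::vector<Device*>& devices) {
  // StrJoin invokes the lambda formatter once per element, appending each
  // element's string form (separated by ", ") into *out.
  return absl::StrJoin(devices, ", ", [](std::string* out, Device* d) {
    absl::StrAppend(out, d->ToString());
  });
}
```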
7 changes: 7 additions & 0 deletions xla/python/ifrt/sharding.cc
@@ -36,6 +36,7 @@ limitations under the License.
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ExtensibleRTTI.h"
+#include "xla/python/ifrt/client.h"
 #include "xla/python/ifrt/device.h"
 #include "xla/python/ifrt/device_list.h"
 #include "xla/python/ifrt/index.h"
@@ -210,6 +211,12 @@ std::unique_ptr<SingleDeviceSharding> SingleDeviceSharding::Create(
       new SingleDeviceSharding(device, memory_kind));
 }

+SingleDeviceSharding::SingleDeviceSharding(Device* device,
+                                           MemoryKind memory_kind)
+    : llvm::RTTIExtends<SingleDeviceSharding, Sharding>(
+          device->client()->MakeDeviceList({device}), memory_kind,
+          /*is_fully_replicated=*/true) {}
+
 absl::StatusOr<Shape> SingleDeviceSharding::GetShardShape(
     const Shape& shape) const {
   return shape;
5 changes: 1 addition & 4 deletions xla/python/ifrt/sharding.h
@@ -263,10 +263,7 @@ class SingleDeviceSharding final
   static char ID;  // NOLINT

  private:
-  explicit SingleDeviceSharding(Device* device, MemoryKind memory_kind)
-      : llvm::RTTIExtends<SingleDeviceSharding, Sharding>(
-            BasicDeviceList::Create({device}), memory_kind,
-            /*is_fully_replicated=*/true) {}
+  explicit SingleDeviceSharding(Device* device, MemoryKind memory_kind);

   void Hash(absl::HashState state) const override;
 };
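The constructor moves out of the header because its body now needs the full `Client` definition (see the sharding.cc hunk above); an inline definition would presumably force sharding.h to include client.h. A generic sketch of the declare-in-header, define-in-source idiom, with hypothetical names:

```cpp
// ---- widget.h ----
class Client;  // Forward declaration suffices for the declaration below.

class Widget {
 public:
  explicit Widget(Client* client);  // Body lives in widget.cc.
};

// ---- widget.cc ----
// #include "widget.h"
// #include "client.h"  // The full Client definition is needed only here.
// Widget::Widget(Client* client) { /* may call client->...() */ }
```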
25 changes: 13 additions & 12 deletions xla/python/ifrt/sharding_test.cc
@@ -218,8 +218,7 @@ TEST_P(OpaqueShardingTest, CreateWithBadDeviceList) {
   EXPECT_DEATH(
       OpaqueSharding::Create(tsl::RCReference<DeviceList>(), MemoryKind()), "");

-  EXPECT_DEATH(
-      OpaqueSharding::Create(BasicDeviceList::Create({}), MemoryKind()), "");
+  EXPECT_DEATH(OpaqueSharding::Create(GetDevices({}), MemoryKind()), "");
 }

 TEST_P(OpaqueShardingTest, IsFullyReplicated) {
@@ -326,8 +325,8 @@ TEST_P(ConcreteShardingTest, CreateWithBadDeviceList) {
                                         MemoryKind(), Shape({}), {Shape({})}),
                "");

-  EXPECT_DEATH(ConcreteSharding::Create(BasicDeviceList::Create({}),
-                                        MemoryKind(), Shape({}), {Shape({})}),
+  EXPECT_DEATH(ConcreteSharding::Create(GetDevices({}), MemoryKind(), Shape({}),
+                                        {Shape({})}),
                "");
 }

@@ -520,8 +519,10 @@ TEST_P(ConcreteShardingTest, DisassembleDynamicShape) {
       DynamicShape shard_dynamic_shape3,
       DynamicShape::Create(Shape({7}), BoundedDynamicShapeTag({true})));
   std::vector<DynamicShape> shard_dynamic_shapes{
-      std::move(shard_dynamic_shape0), std::move(shard_dynamic_shape1),
-      std::move(shard_dynamic_shape2), std::move(shard_dynamic_shape3),
+      std::move(shard_dynamic_shape0),
+      std::move(shard_dynamic_shape1),
+      std::move(shard_dynamic_shape2),
+      std::move(shard_dynamic_shape3),
   };
   auto sharding = ConcreteSharding::Create(device_list, MemoryKind(),
                                            dynamic_shape, shard_dynamic_shapes);
@@ -610,8 +611,8 @@ TEST_P(ConcreteEvenShardingTest, CreateWithBadDeviceList) {
                    /*is_fully_replicated=*/true),
                "");

-  EXPECT_DEATH(ConcreteEvenSharding::Create(BasicDeviceList::Create({}),
-                                            MemoryKind(), Shape({}), Shape({}),
+  EXPECT_DEATH(ConcreteEvenSharding::Create(GetDevices({}), MemoryKind(),
+                                            Shape({}), Shape({}),
                                             /*is_fully_replicated=*/true),
                "");
 }
@@ -816,10 +817,10 @@ TEST_P(ShardingParamShardingTest, CreateWithBadDeviceList) {
           .value(),
       "");

-  EXPECT_DEATH(ShardingParamSharding::Create(param, BasicDeviceList::Create({}),
-                                             MemoryKind())
-                   .value(),
-               "");
+  EXPECT_DEATH(
+      ShardingParamSharding::Create(param, GetDevices({}), MemoryKind())
+          .value(),
+      "");
 }

 TEST_P(ShardingParamShardingTest, CreateFailsWhenDeviceCountNotMatch) {
8 changes: 8 additions & 0 deletions xla/python/ifrt_proxy/client/array_test.cc
@@ -24,6 +24,8 @@
 #include "absl/time/time.h"
 #include "absl/types/span.h"
 #include "xla/python/ifrt/array.h"
+#include "xla/python/ifrt/device.h"
+#include "xla/python/ifrt/device_list.h"
 #include "xla/python/ifrt/dtype.h"
 #include "xla/python/ifrt/future.h"
 #include "xla/python/ifrt/memory.h"
@@ -130,7 +132,13 @@ TEST_F(ArrayTest, FullyReplicatedShard) {
       .WillOnce(MockClientSessionReturnResponse(response));

   MockClient client;
+  ON_CALL(client, MakeDeviceList(_))
+      .WillByDefault([](absl::Span<xla::ifrt::Device* const> devices) {
+        return xla::ifrt::BasicDeviceList::Create(devices);
+      });
+
   MockDevice mock_device;
+  ON_CALL(mock_device, client()).WillByDefault(Return(&client));

   auto sharding = xla::ifrt::SingleDeviceSharding::Create(
       &mock_device, xla::ifrt::MemoryKind());
9 changes: 7 additions & 2 deletions xla/python/ifrt_proxy/client/executable_test.cc
@@ -174,10 +174,15 @@ TEST_F(LoadedExecutableTest, Metadata) {
 // TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS
 #if defined(PLATFORM_GOOGLE)
 TEST_F(LoadedExecutableTest, Execute) {
+  MockClient client;
+  ON_CALL(client, MakeDeviceList(_))
+      .WillByDefault([](absl::Span<xla::ifrt::Device* const> devices) {
+        return xla::ifrt::BasicDeviceList::Create(devices);
+      });
+
   MockDevice device;
+  ON_CALL(device, client()).WillByDefault(Return(&client));
   ON_CALL(device, Id()).WillByDefault(Return(DeviceId(1)));

-  MockClient client;
   ON_CALL(client, LookupDevice(DeviceId(1))).WillByDefault(Return(&device));

   LoadedExecutable executable(
13 changes: 9 additions & 4 deletions xla/python/ifrt_proxy/server/ifrt_backend_test.cc
@@ -252,6 +252,7 @@ class IfrtBackendHandlerTest : public IfrtBackendTest {
     std::vector<xla::ifrt::Device*> raw_device_ptrs;
     for (int i = 0; i < 2; ++i) {
       auto mock_device = std::make_unique<xla::ifrt::MockDevice>();
+      ON_CALL(*mock_device, client()).WillByDefault(Return(mock_client.get()));
       ON_CALL(*mock_device, Id()).WillByDefault(Return(DeviceId(i)));
       ON_CALL(*mock_device, IsAddressable()).WillByDefault(Return(true));
       raw_device_ptrs.push_back(mock_device.get());
@@ -270,6 +271,10 @@
           }
           return mock_devices_[id.value()].get();
         }));
+    ON_CALL(*mock_client, MakeDeviceList(_))
+        .WillByDefault([](absl::Span<xla::ifrt::Device* const> devices) {
+          return xla::ifrt::BasicDeviceList::Create(devices);
+        });

     // Remembering a raw pointer to the mock client here is OK, since most tests
     // anyway have to make the basic and tacit assumption that the backend will
@@ -1348,8 +1353,8 @@ TEST_P(IfrtBackendHandlerTest, LoadedExecutableMetadata) {
 // TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS
 #if defined(PLATFORM_GOOGLE)
 TEST_P(IfrtBackendHandlerTest, LoadedExecutableExecute) {
-  MockDevice device;
-  ON_CALL(device, Id()).WillByDefault(Return(DeviceId(0)));
+  TF_ASSERT_OK_AND_ASSIGN(xla::ifrt::Device* const device,
+                          mock_client_->LookupDevice(DeviceId(0)));

   MockLoadedExecutable* executable;
   uint64_t handle;
@@ -1365,7 +1370,7 @@ TEST_P(IfrtBackendHandlerTest, LoadedExecutableExecute) {
   constexpr int kNumOutputs = 2;

   Shape shape({2, 2});
-  auto sharding = SingleDeviceSharding::Create(&device, MemoryKind());
+  auto sharding = SingleDeviceSharding::Create(device, MemoryKind());

   auto make_array = [&]() {
     auto array = tsl::MakeRef<MockArray>();
@@ -1422,7 +1427,7 @@ TEST_P(IfrtBackendHandlerTest, LoadedExecutableExecute) {
       )pb"))));
   TF_ASSERT_OK_AND_ASSIGN(
       auto sharding_proto,
       SingleDeviceSharding::Create(&device, MemoryKind())->ToProto());
+      SingleDeviceSharding::Create(device, MemoryKind())->ToProto());
   for (const auto& output :
        response->loaded_executable_execute_response().outputs()) {
     EXPECT_THAT(output.sharding(), EquivToProto(sharding_proto));
