Skip to content

Commit

Permalink
Move users of BasicDeviceList::Create() to Client::MakeDeviceList()
Browse files Browse the repository at this point in the history
IFRT is moving towards runtime-controlled device list creation. This CL moves most of explicit device list creation from `BasicDeviceLIst::Create()` to `Client::MakeDeviceList()`. Once the migration is done, `BasicDeviceList::Create()` will be reserved only for IFRT implementations and all IFRT users will be expected to use `Client::MakeDeviceList::Create()` to create device lists.

PiperOrigin-RevId: 725330452
  • Loading branch information
junwhanahn authored and Google-ML-Automation committed Feb 10, 2025
1 parent 841b1bf commit ec483e3
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 20 deletions.
3 changes: 3 additions & 0 deletions xla/python/ifrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ cc_library(
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down Expand Up @@ -321,6 +322,7 @@ cc_library(
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/status",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
Expand Down Expand Up @@ -617,6 +619,7 @@ cc_library(
"//xla/tsl/platform:status_matchers",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
Expand Down
26 changes: 13 additions & 13 deletions xla/python/ifrt/array_impl_test_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/synchronization/notification.h"
#include "absl/time/clock.h"
Expand Down Expand Up @@ -311,7 +312,7 @@ TEST(ArrayImplTest, MakeArrayFromHostBufferReplicated) {
std::iota(data->begin(), data->end(), 0);
absl::Span<Device* const> devices = client->addressable_devices();
std::shared_ptr<const Sharding> sharding = ConcreteEvenSharding::Create(
BasicDeviceList::Create(devices), MemoryKind(), shape,
client->MakeDeviceList(devices), MemoryKind(), shape,
/*shard_shape=*/shape, /*is_fully_replicated=*/true);

TF_ASSERT_OK_AND_ASSIGN(
Expand Down Expand Up @@ -375,9 +376,8 @@ TEST(ArrayImplTest, AssembleArray) {
std::vector<tsl::RCReference<Array>> arrays({array0, array1});
Shape assembled_shape({4, 3});
std::shared_ptr<const Sharding> assembled_sharding = OpaqueSharding::Create(
BasicDeviceList::Create(
{array0->sharding().devices()->devices().front(),
array1->sharding().devices()->devices().front()}),
client->MakeDeviceList({array0->sharding().devices()->devices().front(),
array1->sharding().devices()->devices().front()}),
MemoryKind());
TF_ASSERT_OK_AND_ASSIGN(
auto assembled_array,
Expand Down Expand Up @@ -423,9 +423,9 @@ TEST(ArrayImplTest, AssembleAndDisassembleArray) {
Shape assembled_shape({4, 3});
ShardingParam sharding_param(
/*dim_shards=*/{2, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 1}});
auto ifrt_device_list = BasicDeviceList::Create(
{array0->sharding().devices()->devices().front(),
array1->sharding().devices()->devices().front()});
auto ifrt_device_list =
client->MakeDeviceList({array0->sharding().devices()->devices().front(),
array1->sharding().devices()->devices().front()});
TF_ASSERT_OK_AND_ASSIGN(
std::shared_ptr<const Sharding> sharding_param_sharding,
ShardingParamSharding::Create(std::move(sharding_param), ifrt_device_list,
Expand Down Expand Up @@ -537,7 +537,7 @@ TEST(ArrayImplTest, CopyToSameDevices) {
TEST(ArrayImplTest, CopyToDifferentDevice) {
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
tsl::RCReference<DeviceList> devices =
BasicDeviceList::Create(client->addressable_devices());
client->MakeDeviceList(client->addressable_devices());

DType dtype(DType::kF32);
Shape shape({2, 3});
Expand Down Expand Up @@ -580,22 +580,22 @@ TEST(ArrayImplTest, CopyToDifferentDevice) {
SingleDeviceShardSemantics::kAddressableShards));
}

BasicDeviceList::Devices new_devices;
absl::InlinedVector<xla::ifrt::Device*, 1> new_devices;
for (auto it = devices->devices().rbegin(); it != devices->devices().rend();
++it) {
new_devices.push_back(*it);
}
TF_ASSERT_OK_AND_ASSIGN(
auto new_arrays,
client->CopyArrays(absl::MakeSpan(arrays),
BasicDeviceList::Create(new_devices), MemoryKind(),
client->MakeDeviceList(new_devices), MemoryKind(),
ArrayCopySemantics::kAlwaysCopy));

for (int i = 0; i < arrays.size(); ++i) {
TF_ASSERT_OK_AND_ASSIGN(
auto expected_sharding,
arrays[i]->sharding().WithDeviceAssignment(
BasicDeviceList::Create(new_devices), MemoryKind()));
client->MakeDeviceList(new_devices), MemoryKind()));
EXPECT_EQ(new_arrays[i]->sharding(), *expected_sharding);

TF_ASSERT_OK_AND_ASSIGN(
Expand Down Expand Up @@ -637,7 +637,7 @@ TEST(ArrayImplTest, CopyMixedSourceDevices) {
Device* new_device = client->addressable_devices().at(1);
EXPECT_THAT(client
->CopyArrays(absl::MakeSpan(arrays),
BasicDeviceList::Create({new_device}),
client->MakeDeviceList({new_device}),
MemoryKind(), ArrayCopySemantics::kAlwaysCopy)
.status(),
StatusIs(absl::StatusCode::kInvalidArgument));
Expand Down Expand Up @@ -671,7 +671,7 @@ TEST(ArrayImplTest, CopyMixedSourceMemoryKind) {
Device* new_device = client->addressable_devices().at(1);
EXPECT_THAT(client
->CopyArrays(absl::MakeSpan(arrays),
BasicDeviceList::Create({new_device}),
client->MakeDeviceList({new_device}),
MemoryKind(), ArrayCopySemantics::kAlwaysCopy)
.status(),
StatusIs(absl::StatusCode::kInvalidArgument));
Expand Down
2 changes: 1 addition & 1 deletion xla/python/ifrt/ir/tests/executable_impl_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ IfrtIrExecutableImplTestBase::PickDevices(int count) {
absl::Span<Device* const> devices = client_->devices();
TF_RET_CHECK(count <= devices.size())
<< "Requested " << count << " devices. Only have " << devices.size();
return BasicDeviceList::Create(devices.first(count));
return client_->MakeDeviceList(devices.first(count));
}

} // namespace test_util
Expand Down
5 changes: 3 additions & 2 deletions xla/python/ifrt/remap_impl_test_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
Expand Down Expand Up @@ -110,7 +111,7 @@ absl::StatusOr<tsl::RCReference<Array>> CreateArray(

std::vector<tsl::RCReference<Array>> shards;
shards.reserve(base_values.size());
BasicDeviceList::Devices devices;
absl::InlinedVector<xla::ifrt::Device*, 1> devices;
devices.reserve(device_indices.size());

for (int i = 0; i < base_values.size(); ++i) {
Expand All @@ -132,7 +133,7 @@ absl::StatusOr<tsl::RCReference<Array>> CreateArray(
}

std::shared_ptr<const Sharding> assembled_sharding =
ConcreteEvenSharding::Create(BasicDeviceList::Create(std::move(devices)),
ConcreteEvenSharding::Create(client->MakeDeviceList(devices),
MemoryKind(),
/*shape=*/shape,
/*shard_shape=*/std::move(shard_shape));
Expand Down
9 changes: 5 additions & 4 deletions xla/python/ifrt/test_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -91,7 +92,7 @@ void SetTestFilterIfNotUserSpecified(absl::string_view custom_filter) {

absl::StatusOr<tsl::RCReference<DeviceList>> GetDevices(
Client* client, absl::Span<const int> device_indices) {
BasicDeviceList::Devices devices;
absl::InlinedVector<xla::ifrt::Device*, 1> devices;
devices.reserve(device_indices.size());
const absl::Span<Device* const> client_devices = client->devices();
for (int device_index : device_indices) {
Expand All @@ -101,12 +102,12 @@ absl::StatusOr<tsl::RCReference<DeviceList>> GetDevices(
}
devices.push_back(client_devices[device_index]);
}
return BasicDeviceList::Create(std::move(devices));
return client->MakeDeviceList(devices);
}

absl::StatusOr<tsl::RCReference<DeviceList>> GetAddressableDevices(
Client* client, absl::Span<const int> device_indices) {
BasicDeviceList::Devices devices;
absl::InlinedVector<xla::ifrt::Device*, 1> devices;
devices.reserve(device_indices.size());
const absl::Span<Device* const> client_devices =
client->addressable_devices();
Expand All @@ -117,7 +118,7 @@ absl::StatusOr<tsl::RCReference<DeviceList>> GetAddressableDevices(
}
devices.push_back(client_devices[device_index]);
}
return BasicDeviceList::Create(std::move(devices));
return client->MakeDeviceList(std::move(devices));
}

} // namespace test_util
Expand Down

0 comments on commit ec483e3

Please sign in to comment.