Skip to content

Commit

Permalink
Move users of BasicDeviceList::Create() to Client::MakeDeviceList()
Browse files Browse the repository at this point in the history
IFRT is moving towards runtime-controlled device list creation. This CL moves most of explicit device list creation from `BasicDeviceList::Create()` to `Client::MakeDeviceList()`. Once the migration is done, `BasicDeviceList::Create()` will be reserved only for IFRT implementations and all IFRT users will be expected to use `Client::MakeDeviceList::Create()` to create device lists.

PiperOrigin-RevId: 725692548
  • Loading branch information
junwhanahn authored and Google-ML-Automation committed Feb 11, 2025
1 parent a00b023 commit 52d5cca
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions xla/python/pjrt_ifrt/basic_string_array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ limitations under the License.
#include "xla/pjrt/pjrt_future.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/dtype.h"
#include "xla/python/ifrt/future.h"
#include "xla/python/ifrt/memory.h"
Expand Down Expand Up @@ -320,7 +319,7 @@ TEST(MakeArrayFromHostBufferTest, FailureCases) {
// MakeArrayFromHostBuffer should check and fail if the sharding is not a
// SingleDeviceSharding.
std::shared_ptr<const Sharding> opaque_sharding =
OpaqueSharding::Create(BasicDeviceList::Create({device}), MemoryKind());
OpaqueSharding::Create(client->MakeDeviceList({device}), MemoryKind());
EXPECT_THAT(client->MakeArrayFromHostBuffer(
data, DType(DType::kString), shape,
/*byte_strides=*/std::nullopt, opaque_sharding,
Expand Down Expand Up @@ -400,7 +399,7 @@ absl::StatusOr<tsl::RCReference<Array>> MakeShardedStringTestArray(
}

std::shared_ptr<const Sharding> sharding = ConcreteEvenSharding::Create(
BasicDeviceList::Create({devices[0], devices[1]}), MemoryKind(),
client->MakeDeviceList({devices[0], devices[1]}), MemoryKind(),
Shape({2, 1}), Shape({1}), is_fully_replicated);

std::vector<tsl::RCReference<Array>> arrays;
Expand Down Expand Up @@ -441,7 +440,7 @@ TEST(AssembleArrayFromSingleDeviceArraysTest, FailsWithNonStringArrays) {
auto devices = client->addressable_devices();
ASSERT_GE(devices.size(), 2);
std::shared_ptr<const Sharding> opaque_sharding = OpaqueSharding::Create(
BasicDeviceList::Create({devices[0], devices[1]}), MemoryKind());
client->MakeDeviceList({devices[0], devices[1]}), MemoryKind());

std::vector<tsl::RCReference<Array>> arrays(2);
TF_ASSERT_OK_AND_ASSIGN(
Expand All @@ -462,7 +461,7 @@ TEST(AssembleArrayFromSingleDeviceArraysTest,
auto devices = client->addressable_devices();
ASSERT_GE(devices.size(), 2);
std::shared_ptr<const Sharding> opaque_sharding = OpaqueSharding::Create(
BasicDeviceList::Create({devices[0], devices[1]}), MemoryKind());
client->MakeDeviceList({devices[0], devices[1]}), MemoryKind());

std::vector<tsl::RCReference<Array>> arrays(2);
const std::vector<std::string> per_shard_contents({"abc", "def"});
Expand All @@ -485,7 +484,7 @@ TEST(AssembleArrayFromSingleDeviceArraysTest,
auto devices = client->addressable_devices();
ASSERT_GE(devices.size(), 2);
std::shared_ptr<const Sharding> opaque_sharding = OpaqueSharding::Create(
BasicDeviceList::Create({devices[0], devices[1]}), MemoryKind());
client->MakeDeviceList({devices[0], devices[1]}), MemoryKind());

// Make two non-ready single device sharded arrays.
std::vector<tsl::RCReference<Array>> arrays;
Expand Down Expand Up @@ -535,7 +534,7 @@ TEST(AssembleArrayFromSingleDeviceArraysTest,
auto devices = client->addressable_devices();
ASSERT_GE(devices.size(), 2);
std::shared_ptr<const Sharding> opaque_sharding = OpaqueSharding::Create(
BasicDeviceList::Create({devices[0], devices[1]}), MemoryKind());
client->MakeDeviceList({devices[0], devices[1]}), MemoryKind());

// Make two non-ready single device sharded arrays.
std::vector<tsl::RCReference<Array>> arrays;
Expand Down Expand Up @@ -665,7 +664,7 @@ TEST(CopyTest, SuccessSingleDeviceShardedArray) {
TF_ASSERT_OK_AND_ASSIGN(
auto new_arrays,
client->CopyArrays(absl::MakeSpan(arrays),
BasicDeviceList::Create({devices[1]}), MemoryKind(),
client->MakeDeviceList({devices[1]}), MemoryKind(),
ArrayCopySemantics::kAlwaysCopy));

auto new_basic_string_array =
Expand All @@ -691,7 +690,7 @@ TEST(CopyTest, SuccessMultiDeviceShardedArray) {
TF_ASSERT_OK_AND_ASSIGN(
auto new_arrays,
client->CopyArrays(absl::MakeSpan(arrays),
BasicDeviceList::Create({devices[2], devices[3]}),
client->MakeDeviceList({devices[2], devices[3]}),
MemoryKind(), ArrayCopySemantics::kAlwaysCopy));

auto new_basic_string_array =
Expand All @@ -718,7 +717,7 @@ TEST(CopyTest, FailsAfterDeletion) {
arrays[0]->Delete();

EXPECT_THAT(client->CopyArrays(absl::MakeSpan(arrays),
BasicDeviceList::Create({devices[1]}),
client->MakeDeviceList({devices[1]}),
MemoryKind(), ArrayCopySemantics::kAlwaysCopy),
StatusIs(absl::StatusCode::kFailedPrecondition));
}
Expand All @@ -737,7 +736,7 @@ TEST(CopyTest, FailsWithDifferentNumbersDevices) {

EXPECT_THAT(
client->CopyArrays(absl::MakeSpan(arrays),
BasicDeviceList::Create({devices[0], devices[1]}),
client->MakeDeviceList({devices[0], devices[1]}),
MemoryKind(), ArrayCopySemantics::kAlwaysCopy),
StatusIs(absl::StatusCode::kInvalidArgument));
}
Expand All @@ -758,7 +757,7 @@ TEST(CopyTest, NonReadySourceArraySuccessfullyBecomesReadyAfterCopy) {
auto promise = std::move(ret.second);

TF_ASSERT_OK(client->CopyArrays(
absl::MakeSpan(arrays), BasicDeviceList::Create({devices[1]}),
absl::MakeSpan(arrays), client->MakeDeviceList({devices[1]}),
MemoryKind(), ArrayCopySemantics::kAlwaysCopy));

absl::Notification done_readying_single_device_arrays;
Expand Down Expand Up @@ -796,7 +795,7 @@ TEST(CopyTest, NonReadySourceArrayFailsToBecomeReadyAfterCopy) {
auto promise = std::move(ret.second);

TF_ASSERT_OK(client->CopyArrays(
absl::MakeSpan(arrays), BasicDeviceList::Create({devices[1]}),
absl::MakeSpan(arrays), client->MakeDeviceList({devices[1]}),
MemoryKind(), ArrayCopySemantics::kAlwaysCopy));

absl::Notification done_readying_single_device_arrays;
Expand Down

0 comments on commit 52d5cca

Please sign in to comment.