Skip to content

Commit

Permalink
Move users of BasicDeviceList::Create() to Client::MakeDeviceList()
Browse files Browse the repository at this point in the history
IFRT is moving towards runtime-controlled device list creation. This CL moves most explicit device list creation from `BasicDeviceList::Create()` to `Client::MakeDeviceList()`. Once the migration is done, `BasicDeviceList::Create()` will be reserved for IFRT implementations only, and all IFRT users will be expected to use `Client::MakeDeviceList()` to create device lists.

PiperOrigin-RevId: 725785696
  • Loading branch information
junwhanahn authored and Google-ML-Automation committed Feb 11, 2025
1 parent b81efed commit e9c0445
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 38 deletions.
2 changes: 1 addition & 1 deletion xla/python/pjit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable,
xla::ifrt::Client* const ifrt_client =
executable.ifrt_loaded_executable()->client();
tsl::RCReference<xla::ifrt::DeviceList> ifrt_devices =
xla::ifrt::BasicDeviceList::Create({addressable_devices[0]});
ifrt_client->MakeDeviceList({addressable_devices[0]});
for (auto& [key, group] : copy_groups) {
TF_ASSIGN_OR_RETURN(
auto copied_ifrt_arrays,
Expand Down
17 changes: 8 additions & 9 deletions xla/python/pmap_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,19 +154,18 @@ absl::StatusOr<ShardArgResult> ShardArg(
return xla::InvalidArgument("Array has been deleted.");
}
if (result.ifrt_array->sharding().devices()->devices() != devices) {
xla::ifrt::BasicDeviceList::Devices ifrt_devices;
absl::InlinedVector<xla::ifrt::Device*, 1> ifrt_devices;
ifrt_devices.reserve(devices.size());
ifrt_devices.insert(ifrt_devices.end(), devices.begin(),
devices.end());
// pmap does not support memory_kind for now.
auto* ifrt_client = result.ifrt_array->client();
TF_ASSIGN_OR_RETURN(
auto copied_ifrt_arrays,
ifrt_client->CopyArrays(
absl::MakeSpan(&result.ifrt_array, 1),
xla::ifrt::BasicDeviceList::Create(std::move(ifrt_devices)),
xla::ifrt::MemoryKind(),
xla::ifrt::ArrayCopySemantics::kReuseInput));
TF_ASSIGN_OR_RETURN(auto copied_ifrt_arrays,
ifrt_client->CopyArrays(
absl::MakeSpan(&result.ifrt_array, 1),
ifrt_client->MakeDeviceList(ifrt_devices),
xla::ifrt::MemoryKind(),
xla::ifrt::ArrayCopySemantics::kReuseInput));
result.ifrt_array = std::move(copied_ifrt_arrays.front());
}
return result;
Expand All @@ -188,7 +187,7 @@ absl::StatusOr<ShardArgResult> ShardArg(

std::vector<tsl::RCReference<xla::ifrt::Array>> per_device_arrays;
per_device_arrays.reserve(n_devices);
xla::ifrt::BasicDeviceList::Devices devices;
absl::InlinedVector<xla::ifrt::Device*, 1> devices;
devices.reserve(n_devices);
// TODO(hyeontaek): The created array will never be disassembled. We should
// omit collecting shapes and make the OpaqueSharding non-disassemblable?
Expand Down
20 changes: 12 additions & 8 deletions xla/python/py_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ limitations under the License.
#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
Expand Down Expand Up @@ -153,7 +154,7 @@ tsl::RCReference<ifrt::Array> CreateIfRtArrayFromSingleDeviceShardedPyArrays(
}
std::vector<tsl::RCReference<ifrt::Array>> ifrt_arrays;
ifrt_arrays.reserve(py_arrays.size());
ifrt::BasicDeviceList::Devices devices;
absl::InlinedVector<ifrt::Device*, 1> devices;
devices.reserve(py_arrays.size());
absl::flat_hash_set<ifrt::Device*> device_set;
device_set.reserve(py_arrays.size());
Expand Down Expand Up @@ -197,8 +198,9 @@ tsl::RCReference<ifrt::Array> CreateIfRtArrayFromSingleDeviceShardedPyArrays(
.c_str());
}
}
ifrt::Client* client = ifrt_arrays.front()->client();
tsl::RCReference<ifrt::DeviceList> device_list =
ifrt::BasicDeviceList::Create(std::move(devices));
client->MakeDeviceList(devices);
if (device_set.size() != device_list->size()) {
throw nb::value_error(
absl::StrFormat(
Expand All @@ -207,7 +209,6 @@ tsl::RCReference<ifrt::Array> CreateIfRtArrayFromSingleDeviceShardedPyArrays(
*device_list)
.c_str());
}
ifrt::Client* client = ifrt_arrays.front()->client();

auto ifrt_dtype = DtypeToIfRtDType(dtype);
if (!ifrt_dtype.ok()) {
Expand Down Expand Up @@ -696,7 +697,7 @@ absl::Status PyArray::set_arrays(nb::object obj) {
py_arrays().clear();
std::vector<tsl::RCReference<ifrt::Array>> ifrt_arrays;
ifrt_arrays.reserve(list.size());
ifrt::BasicDeviceList::Devices devices;
absl::InlinedVector<ifrt::Device*, 1> devices;
devices.reserve(list.size());
std::vector<ifrt::Shape> shapes;
shapes.reserve(list.size());
Expand Down Expand Up @@ -1248,7 +1249,7 @@ absl::StatusOr<PyArray> PyArray::BatchedDevicePut(
nb::list owning_pylist;
std::vector<tsl::RCReference<ifrt::Array>> ifrt_arrays;

xla::ifrt::BasicDeviceList::Devices devices;
absl::InlinedVector<ifrt::Device*, 1> devices;
devices.reserve(n_devices);
std::vector<xla::ifrt::Shape> shapes;
shapes.reserve(n_devices);
Expand Down Expand Up @@ -2015,16 +2016,19 @@ absl::Status PyArray::RegisterTypes(nb::module_& m) {
absl::Span<const std::vector<const PyDevice*>> dst_device_lists,
absl::Span<const nb::object> shardings,
absl::Span<const ifrt::ArrayCopySemantics> array_copy_semantics) {
if (arrays.empty()) {
return std::vector<PyArray>();
}
auto* client = arrays[0].ifrt_array()->client();
std::vector<tsl::RCReference<ifrt::DeviceList>> device_lists;
device_lists.reserve(dst_device_lists.size());
for (const auto& dst_devices : dst_device_lists) {
ifrt::BasicDeviceList::Devices devices;
absl::InlinedVector<ifrt::Device*, 1> devices;
devices.reserve(dst_devices.size());
for (auto& d : dst_devices) {
devices.push_back(d->device());
}
device_lists.push_back(
ifrt::BasicDeviceList::Create(std::move(devices)));
device_lists.push_back(client->MakeDeviceList(devices));
}
return xla::ValueOrThrow(PyArray::BatchedCopyToDeviceWithSharding(
arrays, device_lists, shardings, array_copy_semantics));
Expand Down
11 changes: 6 additions & 5 deletions xla/python/py_device_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/hash/hash.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
Expand Down Expand Up @@ -58,7 +59,7 @@ PyDeviceList::PyDeviceList(nb::tuple py_device_assignment)
device_list_ = xla::ifrt::BasicDeviceList::Create({});
return;
}
xla::ifrt::BasicDeviceList::Devices devices;
absl::InlinedVector<xla::ifrt::Device*, 1> devices;
devices.reserve(py_device_assignment.size());
for (nb::handle obj : py_device_assignment) {
if (!nb::isinstance<xla::PyDevice>(obj.ptr())) {
Expand All @@ -75,7 +76,7 @@ PyDeviceList::PyDeviceList(nb::tuple py_device_assignment)
}
devices.push_back(py_device->device());
}
device_list_ = xla::ifrt::BasicDeviceList::Create(std::move(devices));
device_list_ = py_client_->ifrt_client()->MakeDeviceList(devices);
}

PyDeviceList::~PyDeviceList() {
Expand Down Expand Up @@ -303,7 +304,7 @@ bool PyDeviceList::IsFullyAddressable() {
if (!self->addressable_device_list_.has_value()) {
switch (self->device_list_.index()) {
case 0: {
xla::ifrt::BasicDeviceList::Devices addressable_devices;
absl::InlinedVector<xla::ifrt::Device*, 1> addressable_devices;
const int process_index =
self->py_client_ ? self->py_client_->process_index() : 0;
for (xla::ifrt::Device* device :
Expand All @@ -313,8 +314,8 @@ bool PyDeviceList::IsFullyAddressable() {
}
}
self->addressable_device_list_ = xla::make_nb_class<PyDeviceList>(
self->py_client_,
xla::ifrt::BasicDeviceList::Create(std::move(addressable_devices)));
self->py_client_, self->py_client_->ifrt_client()->MakeDeviceList(
addressable_devices));
break;
}
case 1: {
Expand Down
18 changes: 9 additions & 9 deletions xla/python/py_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
Expand Down Expand Up @@ -152,7 +153,7 @@ struct ShardedBufferAdapter<ExecuteShardedArg> {
// shape information is unused.
std::vector<tsl::RCReference<ifrt::Array>> ifrt_arrays;
ifrt_arrays.reserve(arg_vector.size());
ifrt::BasicDeviceList::Devices devices;
absl::InlinedVector<ifrt::Device*, 1> devices;
devices.reserve(arg_vector.size());
for (auto& arr : arg_vector) {
CHECK_EQ(arr.ifrt_array()->sharding().devices()->size(), 1)
Expand All @@ -165,14 +166,13 @@ struct ShardedBufferAdapter<ExecuteShardedArg> {
// Use a dummy shape.
// TODO(hyeontaek): Find a way to compute a correct shape.
// TODO(yashkatariya): Plumb sharding or memory_kind here.
auto ifrt_array =
ifrt_arrays.front()->client()->AssembleArrayFromSingleDeviceArrays(
ifrt_arrays.front()->shape(),
ifrt::OpaqueSharding::Create(
ifrt::BasicDeviceList::Create(std::move(devices)),
ifrt::MemoryKind()),
absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput,
ifrt::SingleDeviceShardSemantics::kAddressableShards);
ifrt::Client* client = ifrt_arrays.front()->client();
auto ifrt_array = client->AssembleArrayFromSingleDeviceArrays(
ifrt_arrays.front()->shape(),
ifrt::OpaqueSharding::Create(client->MakeDeviceList(devices),
ifrt::MemoryKind()),
absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput,
ifrt::SingleDeviceShardSemantics::kAddressableShards);
TF_CHECK_OK(ifrt_array.status());
return *ifrt_array;
}
Expand Down
12 changes: 10 additions & 2 deletions xla/python/py_program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
Expand All @@ -39,6 +41,7 @@ limitations under the License.
#include "xla/python/ifrt/array_spec.h"
#include "xla/python/ifrt/compiler.h"
#include "xla/python/ifrt/custom_call_program.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/hlo/hlo_program.h"
#include "xla/python/ifrt/host_callback.h"
Expand Down Expand Up @@ -73,12 +76,17 @@ absl::StatusOr<tsl::RCReference<ifrt::DeviceList>> GetDeviceList(
return nb::cast<const jax::PyDeviceList*>(devices)->ifrt_device_list();
} else {
auto py_devices = nb::cast<std::vector<nb_class_ptr<PyDevice>>>(devices);
ifrt::BasicDeviceList::Devices ifrt_devices;
if (py_devices.empty()) {
return absl::InvalidArgumentError(
"Colocated Python program requires at least one device");
}
absl::InlinedVector<ifrt::Device*, 1> ifrt_devices;
ifrt_devices.reserve(py_devices.size());
for (const nb_class_ptr<PyDevice>& py_device : py_devices) {
ifrt_devices.push_back(py_device->device());
}
return ifrt::BasicDeviceList::Create(std::move(ifrt_devices));
return py_devices.front()->client()->ifrt_client()->MakeDeviceList(
ifrt_devices);
}
}

Expand Down
4 changes: 2 additions & 2 deletions xla/python/py_values.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,12 +427,12 @@ absl::StatusOr<DevicePutResultFn> HandlePyArray(
} else {
return [ifrt_array = tsl::FormRef(ifrt_array), to_device, to_memory_kind,
owning_pybuffer = py_array.weak_type()]() mutable
-> absl::StatusOr<DevicePutResult> {
-> absl::StatusOr<DevicePutResult> {
auto* ifrt_client = ifrt_array->client();
TF_ASSIGN_OR_RETURN(
auto copied_ifrt_arrays,
ifrt_client->CopyArrays(absl::MakeSpan(&ifrt_array, 1),
ifrt::BasicDeviceList::Create({to_device}),
ifrt_client->MakeDeviceList({to_device}),
to_memory_kind,
ifrt::ArrayCopySemantics::kReuseInput));
return DevicePutResult(std::move(copied_ifrt_arrays[0]),
Expand Down
4 changes: 2 additions & 2 deletions xla/python/xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ NB_MODULE(xla_extension, m) {
"get_topology_for_devices requires >= 1 devices.");
}
auto client = py_devices[0]->client();
ifrt::BasicDeviceList::Devices ifrt_devices;
absl::InlinedVector<ifrt::Device*, 1> ifrt_devices;
ifrt_devices.reserve(py_devices.size());
for (const auto& py_device : py_devices) {
if (py_device->client().get() != client.get()) {
Expand All @@ -461,7 +461,7 @@ NB_MODULE(xla_extension, m) {
ifrt_devices.push_back(py_device->device());
}
tsl::RCReference<ifrt::DeviceList> device_list =
ifrt::BasicDeviceList::Create(std::move(ifrt_devices));
client->ifrt_client()->MakeDeviceList(ifrt_devices);
return xla::ValueOrThrow(
client->ifrt_client()->GetTopologyForDevices(device_list));
});
Expand Down

0 comments on commit e9c0445

Please sign in to comment.