Skip to content

Commit

Permalink
Split BasicDeviceList into its own BUILD target and make it visible…
Browse files Browse the repository at this point in the history
… only to IFRT implementations

After this CL, IFRT users will no longer have visibility to `BasicDeviceList`. This ensures that IFRT users use `Client::MakeDeviceList()` to create a device list instead of directly calling `BasicDeviceList::Create()`.

PiperOrigin-RevId: 726540745
  • Loading branch information
junwhanahn authored and Google-ML-Automation committed Feb 13, 2025
1 parent 1b07f67 commit ee4c408
Show file tree
Hide file tree
Showing 25 changed files with 264 additions and 148 deletions.
1 change: 1 addition & 0 deletions xla/backends/cpu/nanort/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ cc_library(
"//xla/pjrt:utils",
"//xla/python/ifrt",
"//xla/python/ifrt:attribute_map",
"//xla/python/ifrt:basic_device_list",
"//xla/python/ifrt/hlo:hlo_program",
"//xla/python/pjrt_ifrt:pjrt_dtype",
"//xla/python/pjrt_ifrt:xla_ifrt",
Expand Down
1 change: 1 addition & 0 deletions xla/backends/cpu/nanort/ifrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ limitations under the License.
#include "xla/pjrt/utils.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/attribute_map.h"
#include "xla/python/ifrt/basic_device_list.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/compiler.h"
#include "xla/python/ifrt/device.h"
Expand Down
1 change: 1 addition & 0 deletions xla/python/compile_only_ifrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ cc_library(
"//xla/pjrt:pjrt_layout",
"//xla/python/ifrt",
"//xla/python/ifrt:attribute_map",
"//xla/python/ifrt:basic_device_list",
"//xla/python/pjrt_ifrt",
"//xla/python/pjrt_ifrt:pjrt_attribute_map_util",
"//xla/python/pjrt_ifrt:pjrt_dtype",
Expand Down
1 change: 1 addition & 0 deletions xla/python/compile_only_ifrt/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#include "xla/pjrt/pjrt_layout.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/attribute_map.h"
#include "xla/python/ifrt/basic_device_list.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/compiler.h"
#include "xla/python/ifrt/device.h"
Expand Down
26 changes: 25 additions & 1 deletion xla/python/ifrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ cc_library(
"//xla/tsl/platform:logging",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down Expand Up @@ -304,6 +303,7 @@ cc_library(
":internal",
]),
deps = [
":basic_device_list",
":ifrt",
":mock",
":test_util",
Expand Down Expand Up @@ -445,6 +445,7 @@ cc_library(
]),
deps = [
":attribute_map",
":basic_device_list",
":ifrt",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
Expand Down Expand Up @@ -629,6 +630,29 @@ xla_cc_test(
],
)

cc_library(
name = "basic_device_list",
srcs = ["basic_device_list.cc"],
hdrs = ["basic_device_list.h"],
compatible_with = get_compatible_with_portable(),
visibility = internal_visibility([
":friends",
":internal",
]),
deps = [
":device_proto_cc",
":ifrt",
"//xla/tsl/concurrency:ref_count",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
],
)

tf_proto_library(
name = "dtype_proto",
srcs = ["dtype.proto"],
Expand Down
104 changes: 104 additions & 0 deletions xla/python/ifrt/basic_device_list.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/python/ifrt/basic_device_list.h"

#include <atomic>
#include <cstdint>
#include <initializer_list>
#include <string>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/base/optimization.h"
#include "absl/hash/hash.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/device.pb.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/tsl/concurrency/ref_count.h"

namespace xla {
namespace ifrt {

char BasicDeviceList::ID = 0;

tsl::RCReference<DeviceList> BasicDeviceList::Create(Devices devices) {
return tsl::MakeRef<BasicDeviceList>(std::move(devices));
}

tsl::RCReference<DeviceList> BasicDeviceList::Create(
absl::Span<Device* const> devices) {
return Create(Devices(devices.begin(), devices.end()));
}

tsl::RCReference<DeviceList> BasicDeviceList::Create(
std::initializer_list<Device*> devices) {
return Create(Devices(devices.begin(), devices.end()));
}

BasicDeviceList::BasicDeviceList(Devices devices)
: devices_(std::move(devices)), hash_(kUnsetHash) {}

DeviceList* BasicDeviceList::AddressableDeviceList() const {
absl::call_once(addressable_device_list_cache_.once_flag, [this] {
Devices addressable_devices;
for (Device* device : devices_) {
if (device->IsAddressable()) {
addressable_devices.push_back(device);
}
}
const bool already_fully_addressable =
addressable_devices.size() == devices_.size();
if (already_fully_addressable) {
// `device_list_holder` is intentionally unset. We skip storing a
// reference-counted copy in the holder to avoid creating a self cycle.
addressable_device_list_cache_.device_list =
const_cast<BasicDeviceList*>(this);
} else {
addressable_device_list_cache_.device_list_holder =
BasicDeviceList::Create(std::move(addressable_devices));
addressable_device_list_cache_.device_list =
addressable_device_list_cache_.device_list_holder.get();
}
});
return addressable_device_list_cache_.device_list;
}

uint64_t BasicDeviceList::hash() const {
uint64_t hash = hash_.load(std::memory_order_relaxed);
if (ABSL_PREDICT_FALSE(hash == kUnsetHash)) {
hash = absl::HashOf(devices());
if (ABSL_PREDICT_FALSE(hash == kUnsetHash)) {
++hash;
}
hash_.store(hash, std::memory_order_relaxed);
}
return hash;
}

std::string BasicDeviceList::ToString() const {
return absl::StrCat("BasicDeviceList([",
absl::StrJoin(devices_, ",",
[](std::string* out, Device* device) {
absl::StrAppend(out, device->ToString());
}),
"])");
}

} // namespace ifrt
} // namespace xla
111 changes: 111 additions & 0 deletions xla/python/ifrt/basic_device_list.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_PYTHON_IFRT_BASIC_DEVICE_LIST_H_
#define XLA_PYTHON_IFRT_BASIC_DEVICE_LIST_H_

#include <atomic>
#include <cstdint>
#include <initializer_list>
#include <string>

#include "absl/base/call_once.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/device.pb.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/tsl/concurrency/ref_count.h"

namespace xla {
namespace ifrt {

// Simple implementation of `DeviceList` that contains a list of devices without
// creating any runtime object in the IFRT implementation.
//
// This is a transitory type that will be replaced with (1) a non-IFRT container
// defined by the user code (e.g., `std::vector<Device*>`) or (2) IFRT
// implementation's own `DeviceList` constructed from its `xla::ifrt::Client`
// API implementation.
//
// Note for IFRT API users: This class is primarily intended for IFRT
// implementations. Please use `Client::MakeDeviceList()` instead.
class BasicDeviceList : public llvm::RTTIExtends<BasicDeviceList, DeviceList> {
public:
// Number of devices to inline in `Devices`.
static constexpr int kInlineDeviceSize = 1;

// TODO(hyeontaek): Consider using variant<Device*, std::vector<Device*>> for
// better performance.
using Devices = absl::InlinedVector<Device*, kInlineDeviceSize>;

// Constructor with a pre-populated `devices`.
static tsl::RCReference<DeviceList> Create(Devices devices);
static tsl::RCReference<DeviceList> Create(absl::Span<Device* const> devices);
static tsl::RCReference<DeviceList> Create(
std::initializer_list<Device*> devices);

~BasicDeviceList() override = default;

absl::Span<Device* const> devices() const override { return devices_; }

DeviceList* AddressableDeviceList() const override;

bool operator==(const DeviceList& other) const override {
if (this == &other) {
return true;
}
const auto* other_basic_device_list =
llvm::dyn_cast<BasicDeviceList>(&other);
if (other_basic_device_list == nullptr) {
return false;
}
return devices_ == other_basic_device_list->devices_;
}

uint64_t hash() const override;

static char ID; // NOLINT

private:
explicit BasicDeviceList(Devices devices);

template <typename T, typename... Args>
friend tsl::RCReference<T> tsl::MakeRef(Args&&... args);

std::string ToString() const override;

Devices devices_;

// Addressable device list is dynamically computed and cached.
struct AddressableDeviceListCache {
absl::once_flag once_flag;
DeviceList* device_list = nullptr;
tsl::RCReference<DeviceList> device_list_holder;
};
mutable AddressableDeviceListCache addressable_device_list_cache_;

// Cached hash. 0 indicates the hash needs to be computed and cached.
// May be written multiple times with the same non-zero value.
static constexpr uint64_t kUnsetHash = 0;
mutable std::atomic<uint64_t> hash_;
};

} // namespace ifrt
} // namespace xla

#endif // XLA_PYTHON_IFRT_BASIC_DEVICE_LIST_H_
Loading

0 comments on commit ee4c408

Please sign in to comment.