-
Notifications
You must be signed in to change notification settings - Fork 508
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Split
BasicDeviceList
into its own BUILD target and make it visible…
… only to IFRT implementations After this CL, IFRT users will no longer have visibility to `BasicDeviceList`. This ensures that IFRT users use `Client::MakeDeviceList()` to create a device list instead of directly calling `BasicDeviceList::Create()`. PiperOrigin-RevId: 726540745
- Loading branch information
1 parent
1b07f67
commit ee4c408
Showing
25 changed files
with
264 additions
and
148 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
/* Copyright 2025 The OpenXLA Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include "xla/python/ifrt/basic_device_list.h" | ||
|
||
#include <atomic> | ||
#include <cstdint> | ||
#include <initializer_list> | ||
#include <string> | ||
#include <utility> | ||
|
||
#include "absl/base/call_once.h" | ||
#include "absl/base/optimization.h" | ||
#include "absl/hash/hash.h" | ||
#include "absl/strings/str_cat.h" | ||
#include "absl/strings/str_join.h" | ||
#include "absl/types/span.h" | ||
#include "xla/python/ifrt/device.h" | ||
#include "xla/python/ifrt/device.pb.h" | ||
#include "xla/python/ifrt/device_list.h" | ||
#include "xla/tsl/concurrency/ref_count.h" | ||
|
||
namespace xla { | ||
namespace ifrt { | ||
|
||
char BasicDeviceList::ID = 0; | ||
|
||
tsl::RCReference<DeviceList> BasicDeviceList::Create(Devices devices) { | ||
return tsl::MakeRef<BasicDeviceList>(std::move(devices)); | ||
} | ||
|
||
tsl::RCReference<DeviceList> BasicDeviceList::Create( | ||
absl::Span<Device* const> devices) { | ||
return Create(Devices(devices.begin(), devices.end())); | ||
} | ||
|
||
tsl::RCReference<DeviceList> BasicDeviceList::Create( | ||
std::initializer_list<Device*> devices) { | ||
return Create(Devices(devices.begin(), devices.end())); | ||
} | ||
|
||
BasicDeviceList::BasicDeviceList(Devices devices) | ||
: devices_(std::move(devices)), hash_(kUnsetHash) {} | ||
|
||
DeviceList* BasicDeviceList::AddressableDeviceList() const { | ||
absl::call_once(addressable_device_list_cache_.once_flag, [this] { | ||
Devices addressable_devices; | ||
for (Device* device : devices_) { | ||
if (device->IsAddressable()) { | ||
addressable_devices.push_back(device); | ||
} | ||
} | ||
const bool already_fully_addressable = | ||
addressable_devices.size() == devices_.size(); | ||
if (already_fully_addressable) { | ||
// `device_list_holder` is intentionally unset. We skip storing a | ||
// reference-counted copy in the holder to avoid creating a self cycle. | ||
addressable_device_list_cache_.device_list = | ||
const_cast<BasicDeviceList*>(this); | ||
} else { | ||
addressable_device_list_cache_.device_list_holder = | ||
BasicDeviceList::Create(std::move(addressable_devices)); | ||
addressable_device_list_cache_.device_list = | ||
addressable_device_list_cache_.device_list_holder.get(); | ||
} | ||
}); | ||
return addressable_device_list_cache_.device_list; | ||
} | ||
|
||
uint64_t BasicDeviceList::hash() const { | ||
uint64_t hash = hash_.load(std::memory_order_relaxed); | ||
if (ABSL_PREDICT_FALSE(hash == kUnsetHash)) { | ||
hash = absl::HashOf(devices()); | ||
if (ABSL_PREDICT_FALSE(hash == kUnsetHash)) { | ||
++hash; | ||
} | ||
hash_.store(hash, std::memory_order_relaxed); | ||
} | ||
return hash; | ||
} | ||
|
||
std::string BasicDeviceList::ToString() const { | ||
return absl::StrCat("BasicDeviceList([", | ||
absl::StrJoin(devices_, ",", | ||
[](std::string* out, Device* device) { | ||
absl::StrAppend(out, device->ToString()); | ||
}), | ||
"])"); | ||
} | ||
|
||
} // namespace ifrt | ||
} // namespace xla |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
/* Copyright 2025 The OpenXLA Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#ifndef XLA_PYTHON_IFRT_BASIC_DEVICE_LIST_H_ | ||
#define XLA_PYTHON_IFRT_BASIC_DEVICE_LIST_H_ | ||
|
||
#include <atomic> | ||
#include <cstdint> | ||
#include <initializer_list> | ||
#include <string> | ||
|
||
#include "absl/base/call_once.h" | ||
#include "absl/container/inlined_vector.h" | ||
#include "absl/types/span.h" | ||
#include "llvm/Support/Casting.h" | ||
#include "llvm/Support/ExtensibleRTTI.h" | ||
#include "xla/python/ifrt/device.h" | ||
#include "xla/python/ifrt/device.pb.h" | ||
#include "xla/python/ifrt/device_list.h" | ||
#include "xla/tsl/concurrency/ref_count.h" | ||
|
||
namespace xla { | ||
namespace ifrt { | ||
|
||
// Simple implementation of `DeviceList` that contains a list of devices without | ||
// creating any runtime object in the IFRT implementation. | ||
// | ||
// This is a transitory type that will be replaced with (1) a non-IFRT container | ||
// defined by the user code (e.g., `std::vector<Device*>`) or (2) IFRT | ||
// implementation's own `DeviceList` constructed from its `xla::ifrt::Client` | ||
// API implementation. | ||
// | ||
// Note for IFRT API users: This class is primarily intended for IFRT | ||
// implementations. Please use `Client::MakeDeviceList()` instead. | ||
class BasicDeviceList : public llvm::RTTIExtends<BasicDeviceList, DeviceList> { | ||
public: | ||
// Number of devices to inline in `Devices`. | ||
static constexpr int kInlineDeviceSize = 1; | ||
|
||
// TODO(hyeontaek): Consider using variant<Device*, std::vector<Device*>> for | ||
// better performance. | ||
using Devices = absl::InlinedVector<Device*, kInlineDeviceSize>; | ||
|
||
// Constructor with a pre-populated `devices`. | ||
static tsl::RCReference<DeviceList> Create(Devices devices); | ||
static tsl::RCReference<DeviceList> Create(absl::Span<Device* const> devices); | ||
static tsl::RCReference<DeviceList> Create( | ||
std::initializer_list<Device*> devices); | ||
|
||
~BasicDeviceList() override = default; | ||
|
||
absl::Span<Device* const> devices() const override { return devices_; } | ||
|
||
DeviceList* AddressableDeviceList() const override; | ||
|
||
bool operator==(const DeviceList& other) const override { | ||
if (this == &other) { | ||
return true; | ||
} | ||
const auto* other_basic_device_list = | ||
llvm::dyn_cast<BasicDeviceList>(&other); | ||
if (other_basic_device_list == nullptr) { | ||
return false; | ||
} | ||
return devices_ == other_basic_device_list->devices_; | ||
} | ||
|
||
uint64_t hash() const override; | ||
|
||
static char ID; // NOLINT | ||
|
||
private: | ||
explicit BasicDeviceList(Devices devices); | ||
|
||
template <typename T, typename... Args> | ||
friend tsl::RCReference<T> tsl::MakeRef(Args&&... args); | ||
|
||
std::string ToString() const override; | ||
|
||
Devices devices_; | ||
|
||
// Addressable device list is dynamically computed and cached. | ||
struct AddressableDeviceListCache { | ||
absl::once_flag once_flag; | ||
DeviceList* device_list = nullptr; | ||
tsl::RCReference<DeviceList> device_list_holder; | ||
}; | ||
mutable AddressableDeviceListCache addressable_device_list_cache_; | ||
|
||
// Cached hash. 0 indicates the hash needs to be computed and cached. | ||
// May be written multiple times with the same non-zero value. | ||
static constexpr uint64_t kUnsetHash = 0; | ||
mutable std::atomic<uint64_t> hash_; | ||
}; | ||
|
||
} // namespace ifrt | ||
} // namespace xla | ||
|
||
#endif // XLA_PYTHON_IFRT_BASIC_DEVICE_LIST_H_ |
Oops, something went wrong.