diff --git a/xla/backends/cpu/nanort/BUILD b/xla/backends/cpu/nanort/BUILD index efdfe24112134..8adb1898e6d5a 100644 --- a/xla/backends/cpu/nanort/BUILD +++ b/xla/backends/cpu/nanort/BUILD @@ -134,6 +134,7 @@ cc_library( "//xla/pjrt:utils", "//xla/python/ifrt", "//xla/python/ifrt:attribute_map", + "//xla/python/ifrt:basic_device_list", "//xla/python/ifrt/hlo:hlo_program", "//xla/python/pjrt_ifrt:pjrt_dtype", "//xla/python/pjrt_ifrt:xla_ifrt", diff --git a/xla/backends/cpu/nanort/ifrt_client.cc b/xla/backends/cpu/nanort/ifrt_client.cc index ecf82dfde78ff..ef844b2b8dbd7 100644 --- a/xla/backends/cpu/nanort/ifrt_client.cc +++ b/xla/backends/cpu/nanort/ifrt_client.cc @@ -58,6 +58,7 @@ limitations under the License. #include "xla/pjrt/utils.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" diff --git a/xla/python/compile_only_ifrt/BUILD b/xla/python/compile_only_ifrt/BUILD index 824ca6e921f1d..23a265a9ce84c 100644 --- a/xla/python/compile_only_ifrt/BUILD +++ b/xla/python/compile_only_ifrt/BUILD @@ -18,6 +18,7 @@ cc_library( "//xla/pjrt:pjrt_layout", "//xla/python/ifrt", "//xla/python/ifrt:attribute_map", + "//xla/python/ifrt:basic_device_list", "//xla/python/pjrt_ifrt", "//xla/python/pjrt_ifrt:pjrt_attribute_map_util", "//xla/python/pjrt_ifrt:pjrt_dtype", diff --git a/xla/python/compile_only_ifrt/client.h b/xla/python/compile_only_ifrt/client.h index ac4c33d659a83..06ec9abfac605 100644 --- a/xla/python/compile_only_ifrt/client.h +++ b/xla/python/compile_only_ifrt/client.h @@ -36,6 +36,7 @@ limitations under the License. #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" diff --git a/xla/python/ifrt/BUILD b/xla/python/ifrt/BUILD index d4b5496128d7d..de194f25dc4b1 100644 --- a/xla/python/ifrt/BUILD +++ b/xla/python/ifrt/BUILD @@ -100,7 +100,6 @@ cc_library( "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", @@ -304,6 +303,7 @@ cc_library( ":internal", ]), deps = [ + ":basic_device_list", ":ifrt", ":mock", ":test_util", @@ -445,6 +445,7 @@ cc_library( ]), deps = [ ":attribute_map", + ":basic_device_list", ":ifrt", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -629,6 +630,29 @@ xla_cc_test( ], ) +cc_library( + name = "basic_device_list", + srcs = ["basic_device_list.cc"], + hdrs = ["basic_device_list.h"], + compatible_with = get_compatible_with_portable(), + visibility = internal_visibility([ + ":friends", + ":internal", + ]), + deps = [ + ":device_proto_cc", + ":ifrt", + "//xla/tsl/concurrency:ref_count", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + ], +) + tf_proto_library( name = "dtype_proto", srcs = ["dtype.proto"], diff --git a/xla/python/ifrt/basic_device_list.cc b/xla/python/ifrt/basic_device_list.cc new file mode 100644 index 0000000000000..54a6e25a7e5df --- /dev/null +++ b/xla/python/ifrt/basic_device_list.cc @@ -0,0 +1,104 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt/basic_device_list.h" + +#include +#include +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/base/optimization.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device.pb.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/tsl/concurrency/ref_count.h" + +namespace xla { +namespace ifrt { + +char BasicDeviceList::ID = 0; + +tsl::RCReference BasicDeviceList::Create(Devices devices) { + return tsl::MakeRef(std::move(devices)); +} + +tsl::RCReference BasicDeviceList::Create( + absl::Span devices) { + return Create(Devices(devices.begin(), devices.end())); +} + +tsl::RCReference BasicDeviceList::Create( + std::initializer_list devices) { + return Create(Devices(devices.begin(), devices.end())); +} + +BasicDeviceList::BasicDeviceList(Devices devices) + : devices_(std::move(devices)), hash_(kUnsetHash) {} + +DeviceList* BasicDeviceList::AddressableDeviceList() const { + absl::call_once(addressable_device_list_cache_.once_flag, [this] { + Devices addressable_devices; + for (Device* device : devices_) { + if (device->IsAddressable()) { + addressable_devices.push_back(device); + } + } + const bool already_fully_addressable = + addressable_devices.size() == devices_.size(); + if (already_fully_addressable) { + // `device_list_holder` is intentionally unset. We skip storing a + // reference-counted copy in the holder to avoid creating a self cycle. + addressable_device_list_cache_.device_list = + const_cast(this); + } else { + addressable_device_list_cache_.device_list_holder = + BasicDeviceList::Create(std::move(addressable_devices)); + addressable_device_list_cache_.device_list = + addressable_device_list_cache_.device_list_holder.get(); + } + }); + return addressable_device_list_cache_.device_list; +} + +uint64_t BasicDeviceList::hash() const { + uint64_t hash = hash_.load(std::memory_order_relaxed); + if (ABSL_PREDICT_FALSE(hash == kUnsetHash)) { + hash = absl::HashOf(devices()); + if (ABSL_PREDICT_FALSE(hash == kUnsetHash)) { + ++hash; + } + hash_.store(hash, std::memory_order_relaxed); + } + return hash; +} + +std::string BasicDeviceList::ToString() const { + return absl::StrCat("BasicDeviceList([", + absl::StrJoin(devices_, ",", + [](std::string* out, Device* device) { + absl::StrAppend(out, device->ToString()); + }), + "])"); +} + +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt/basic_device_list.h b/xla/python/ifrt/basic_device_list.h new file mode 100644 index 0000000000000..65d335e11ff9b --- /dev/null +++ b/xla/python/ifrt/basic_device_list.h @@ -0,0 +1,111 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_BASIC_DEVICE_LIST_H_ +#define XLA_PYTHON_IFRT_BASIC_DEVICE_LIST_H_ + +#include +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device.pb.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/tsl/concurrency/ref_count.h" + +namespace xla { +namespace ifrt { + +// Simple implementation of `DeviceList` that contains a list of devices without +// creating any runtime object in the IFRT implementation. +// +// This is a transitory type that will be replaced with (1) a non-IFRT container +// defined by the user code (e.g., `std::vector`) or (2) IFRT +// implementation's own `DeviceList` constructed from its `xla::ifrt::Client` +// API implementation. +// +// Note for IFRT API users: This class is primarily intended for IFRT +// implementations. Please use `Client::MakeDeviceList()` instead. +class BasicDeviceList : public llvm::RTTIExtends { + public: + // Number of devices to inline in `Devices`. + static constexpr int kInlineDeviceSize = 1; + + // TODO(hyeontaek): Consider using variant> for + // better performance. + using Devices = absl::InlinedVector; + + // Constructor with a pre-populated `devices`. + static tsl::RCReference Create(Devices devices); + static tsl::RCReference Create(absl::Span devices); + static tsl::RCReference Create( + std::initializer_list devices); + + ~BasicDeviceList() override = default; + + absl::Span devices() const override { return devices_; } + + DeviceList* AddressableDeviceList() const override; + + bool operator==(const DeviceList& other) const override { + if (this == &other) { + return true; + } + const auto* other_basic_device_list = + llvm::dyn_cast(&other); + if (other_basic_device_list == nullptr) { + return false; + } + return devices_ == other_basic_device_list->devices_; + } + + uint64_t hash() const override; + + static char ID; // NOLINT + + private: + explicit BasicDeviceList(Devices devices); + + template + friend tsl::RCReference tsl::MakeRef(Args&&... args); + + std::string ToString() const override; + + Devices devices_; + + // Addressable device list is dynamically computed and cached. + struct AddressableDeviceListCache { + absl::once_flag once_flag; + DeviceList* device_list = nullptr; + tsl::RCReference device_list_holder; + }; + mutable AddressableDeviceListCache addressable_device_list_cache_; + + // Cached hash. 0 indicates the hash needs to be computed and cached. + // May be written multiple times with the same non-zero value. + static constexpr uint64_t kUnsetHash = 0; + mutable std::atomic hash_; +}; + +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_BASIC_DEVICE_LIST_H_ diff --git a/xla/python/ifrt/device_list.cc b/xla/python/ifrt/device_list.cc index 75efeea9b01a0..20e494df51454 100644 --- a/xla/python/ifrt/device_list.cc +++ b/xla/python/ifrt/device_list.cc @@ -15,20 +15,10 @@ limitations under the License. #include "xla/python/ifrt/device_list.h" -#include -#include -#include -#include -#include #include -#include "absl/base/call_once.h" -#include "absl/base/optimization.h" #include "absl/container/inlined_vector.h" -#include "absl/hash/hash.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/device.h" @@ -40,7 +30,6 @@ namespace xla { namespace ifrt { char DeviceList::ID = 0; -char BasicDeviceList::ID = 0; absl::StatusOr> DeviceList::FromProto( xla::ifrt::Client* client, const DeviceListProto& proto) { @@ -63,69 +52,6 @@ DeviceListProto DeviceList::ToProto() const { return proto; } -tsl::RCReference BasicDeviceList::Create(Devices devices) { - return tsl::MakeRef(std::move(devices)); -} - -tsl::RCReference BasicDeviceList::Create( - absl::Span devices) { - return Create(Devices(devices.begin(), devices.end())); -} - -tsl::RCReference BasicDeviceList::Create( - std::initializer_list devices) { - return Create(Devices(devices.begin(), devices.end())); -} - -BasicDeviceList::BasicDeviceList(Devices devices) - : devices_(std::move(devices)), hash_(kUnsetHash) {} - -DeviceList* BasicDeviceList::AddressableDeviceList() const { - absl::call_once(addressable_device_list_cache_.once_flag, [this] { - Devices addressable_devices; - for (Device* device : devices_) { - if (device->IsAddressable()) { - addressable_devices.push_back(device); - } - } - const bool already_fully_addressable = - addressable_devices.size() == devices_.size(); - if (already_fully_addressable) { - // `device_list_holder` is intentionally unset. We skip storing a - // reference-counted copy in the holder to avoid creating a self cycle. - addressable_device_list_cache_.device_list = - const_cast(this); - } else { - addressable_device_list_cache_.device_list_holder = - BasicDeviceList::Create(std::move(addressable_devices)); - addressable_device_list_cache_.device_list = - addressable_device_list_cache_.device_list_holder.get(); - } - }); - return addressable_device_list_cache_.device_list; -} - -uint64_t BasicDeviceList::hash() const { - uint64_t hash = hash_.load(std::memory_order_relaxed); - if (ABSL_PREDICT_FALSE(hash == kUnsetHash)) { - hash = absl::HashOf(devices()); - if (ABSL_PREDICT_FALSE(hash == kUnsetHash)) { - ++hash; - } - hash_.store(hash, std::memory_order_relaxed); - } - return hash; -} - -std::string BasicDeviceList::ToString() const { - return absl::StrCat("BasicDeviceList([", - absl::StrJoin(devices_, ",", - [](std::string* out, Device* device) { - absl::StrAppend(out, device->ToString()); - }), - "])"); -} - std::vector GetDeviceIds( const tsl::RCReference& device_list) { std::vector ids; diff --git a/xla/python/ifrt/device_list.h b/xla/python/ifrt/device_list.h index 37280fcc649b1..44a3dfd1ce8a4 100644 --- a/xla/python/ifrt/device_list.h +++ b/xla/python/ifrt/device_list.h @@ -16,18 +16,13 @@ limitations under the License. #ifndef XLA_PYTHON_IFRT_DEVICE_LIST_H_ #define XLA_PYTHON_IFRT_DEVICE_LIST_H_ -#include #include -#include #include #include -#include "absl/base/call_once.h" -#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "llvm/Support/Casting.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device.pb.h" @@ -106,74 +101,6 @@ class DeviceList : public tsl::ReferenceCounted, virtual std::string ToString() const = 0; }; -// Simple implementation of `DeviceList` that contains a list of devices without -// creating any runtime object in the IFRT implementation. -// -// This is a transitory type that will be replaced with (1) a non-IFRT container -// defined by the user code (e.g., `std::vector`) or (2) IFRT -// implementation's own `DeviceList` constructed from its `xla::ifrt::Client` -// API implementation. -class BasicDeviceList : public llvm::RTTIExtends { - public: - // Number of devices to inline in `Devices`. - static constexpr int kInlineDeviceSize = 1; - - // TODO(hyeontaek): Consider using variant> for - // better performance. - using Devices = absl::InlinedVector; - - // Constructor with a pre-populated `devices`. - static tsl::RCReference Create(Devices devices); - static tsl::RCReference Create(absl::Span devices); - static tsl::RCReference Create( - std::initializer_list devices); - - ~BasicDeviceList() override = default; - - absl::Span devices() const override { return devices_; } - - DeviceList* AddressableDeviceList() const override; - - bool operator==(const DeviceList& other) const override { - if (this == &other) { - return true; - } - const auto* other_basic_device_list = - llvm::dyn_cast(&other); - if (other_basic_device_list == nullptr) { - return false; - } - return devices_ == other_basic_device_list->devices_; - } - - uint64_t hash() const override; - - static char ID; // NOLINT - - private: - explicit BasicDeviceList(Devices devices); - - template - friend tsl::RCReference tsl::MakeRef(Args&&... args); - - std::string ToString() const override; - - Devices devices_; - - // Addressable device list is dynamically computed and cached. - struct AddressableDeviceListCache { - absl::once_flag once_flag; - DeviceList* device_list = nullptr; - tsl::RCReference device_list_holder; - }; - mutable AddressableDeviceListCache addressable_device_list_cache_; - - // Cached hash. 0 indicates the hash needs to be computed and cached. - // May be written multiple times with the same non-zero value. - static constexpr uint64_t kUnsetHash = 0; - mutable std::atomic hash_; -}; - // Returns the id of each device in `device_list`. std::vector GetDeviceIds( const tsl::RCReference& device_list); diff --git a/xla/python/ifrt/device_test_util.cc b/xla/python/ifrt/device_test_util.cc index ca5ec5389cb24..990805902a57b 100644 --- a/xla/python/ifrt/device_test_util.cc +++ b/xla/python/ifrt/device_test_util.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/memory.h" diff --git a/xla/python/ifrt/mock.h b/xla/python/ifrt/mock.h index f397cb5bb3afa..416d523c6388c 100644 --- a/xla/python/ifrt/mock.h +++ b/xla/python/ifrt/mock.h @@ -36,6 +36,7 @@ limitations under the License. #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" diff --git a/xla/python/ifrt/support/BUILD b/xla/python/ifrt/support/BUILD index 24dd54da9c065..065768d2ac37b 100644 --- a/xla/python/ifrt/support/BUILD +++ b/xla/python/ifrt/support/BUILD @@ -62,6 +62,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/ir:tile_assignment", "//xla/python/ifrt", + "//xla/python/ifrt:basic_device_list", "//xla/python/ifrt:mock", "//xla/python/ifrt:test_util", "//xla/python/ifrt/ir:sharding_param", diff --git a/xla/python/ifrt/support/sharding_conversions_test.cc b/xla/python/ifrt/support/sharding_conversions_test.cc index 9c6f6b077a03c..a7443bc06c15d 100644 --- a/xla/python/ifrt/support/sharding_conversions_test.cc +++ b/xla/python/ifrt/support/sharding_conversions_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/tile_assignment.h" +#include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/index_domain.h" diff --git a/xla/python/ifrt_proxy/client/BUILD b/xla/python/ifrt_proxy/client/BUILD index e6ce0ca9c3a4a..202d642208d00 100644 --- a/xla/python/ifrt_proxy/client/BUILD +++ b/xla/python/ifrt_proxy/client/BUILD @@ -159,6 +159,7 @@ cc_library( "//xla/pjrt:pjrt_device_description", "//xla/python/ifrt", "//xla/python/ifrt:attribute_map", + "//xla/python/ifrt:basic_device_list", "//xla/python/ifrt_proxy/common:common_serdes", "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", "//xla/python/ifrt_proxy/common:types", @@ -273,6 +274,7 @@ ifrt_proxy_cc_test( ":rpc_helper", ":version", "//xla/python/ifrt", + "//xla/python/ifrt:basic_device_list", "//xla/python/ifrt:mock", "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", "//xla/python/ifrt_proxy/common:types", @@ -547,6 +549,7 @@ ifrt_proxy_cc_test( "//xla:shape_util", "//xla/pjrt:pjrt_layout", "//xla/python/ifrt", + "//xla/python/ifrt:basic_device_list", "//xla/python/ifrt:mock", "//xla/python/ifrt:sharding_serdes", "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", diff --git a/xla/python/ifrt_proxy/client/array_test.cc b/xla/python/ifrt_proxy/client/array_test.cc index 78c979e97cb72..9129fd8a8a971 100644 --- a/xla/python/ifrt_proxy/client/array_test.cc +++ b/xla/python/ifrt_proxy/client/array_test.cc @@ -24,6 +24,7 @@ #include "absl/time/time.h" #include "absl/types/span.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" diff --git a/xla/python/ifrt_proxy/client/client.cc b/xla/python/ifrt_proxy/client/client.cc index 5df0c527ba621..eaa22a2188d44 100644 --- a/xla/python/ifrt_proxy/client/client.cc +++ b/xla/python/ifrt_proxy/client/client.cc @@ -35,6 +35,7 @@ #include "xla/pjrt/pjrt_device_description.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" diff --git a/xla/python/ifrt_proxy/client/executable_test.cc b/xla/python/ifrt_proxy/client/executable_test.cc index 697bfd969e3c1..bcab02ce60f71 100644 --- a/xla/python/ifrt_proxy/client/executable_test.cc +++ b/xla/python/ifrt_proxy/client/executable_test.cc @@ -27,6 +27,7 @@ #include "xla/layout_util.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" diff --git a/xla/python/ifrt_proxy/server/BUILD b/xla/python/ifrt_proxy/server/BUILD index 668f040679376..06ffc85d106cd 100644 --- a/xla/python/ifrt_proxy/server/BUILD +++ b/xla/python/ifrt_proxy/server/BUILD @@ -126,6 +126,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/pjrt:pjrt_layout", "//xla/python/ifrt", + "//xla/python/ifrt:basic_device_list", "//xla/python/ifrt:program_serdes", "//xla/python/ifrt:serdes", "//xla/python/ifrt:sharding_serdes", @@ -179,6 +180,7 @@ ifrt_proxy_cc_test( "//xla/pjrt:pjrt_layout", "//xla/python/ifrt", "//xla/python/ifrt:attribute_map", + "//xla/python/ifrt:basic_device_list", "//xla/python/ifrt:mock", "//xla/python/ifrt:program_serdes", "//xla/python/ifrt:serdes", diff --git a/xla/python/ifrt_proxy/server/ifrt_backend.cc b/xla/python/ifrt_proxy/server/ifrt_backend.cc index 3270fcca90620..5bb9e9fef955b 100644 --- a/xla/python/ifrt_proxy/server/ifrt_backend.cc +++ b/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -47,6 +47,7 @@ #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" diff --git a/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index 05155de5565d6..1ebc82a488f86 100644 --- a/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -46,6 +46,7 @@ #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" diff --git a/xla/python/pjrt_ifrt/BUILD b/xla/python/pjrt_ifrt/BUILD index 752877e8d800f..49b526b48cfdf 100644 --- a/xla/python/pjrt_ifrt/BUILD +++ b/xla/python/pjrt_ifrt/BUILD @@ -164,6 +164,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/ir:tile_assignment", "//xla/python/ifrt", + "//xla/python/ifrt:basic_device_list", "//xla/python/ifrt:device_test_util", "//xla/python/ifrt:tuple_impl_test_lib", "//xla/tsl/concurrency:ref_count", @@ -229,6 +230,7 @@ cc_library( "//xla/pjrt/distributed:topology_util", "//xla/python/ifrt", "//xla/python/ifrt:attribute_map", + "//xla/python/ifrt:basic_device_list", "//xla/python/ifrt/hlo:hlo_program", "//xla/service:computation_placer_hdr", "//xla/service:hlo_proto_cc", diff --git a/xla/python/pjrt_ifrt/pjrt_array.cc b/xla/python/pjrt_ifrt/pjrt_array.cc index b8ba14293830c..a1bc18f9a1e35 100644 --- a/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/xla/python/pjrt_ifrt/pjrt_array.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/utils.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" diff --git a/xla/python/pjrt_ifrt/pjrt_client.cc b/xla/python/pjrt_ifrt/pjrt_client.cc index 745265daf2d56..464cca7ac0834 100644 --- a/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/xla/python/pjrt_ifrt/pjrt_client.cc @@ -57,6 +57,7 @@ limitations under the License. #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" diff --git a/xla/python/pjrt_ifrt/pjrt_executable.cc b/xla/python/pjrt_ifrt/pjrt_executable.cc index 3b9ccbdfebee2..2b136123e93ae 100644 --- a/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/pjrt/pjrt_future.h" #include "xla/primitive_util.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" diff --git a/xla/python/pjrt_ifrt/xla_sharding_test.cc b/xla/python/pjrt_ifrt/xla_sharding_test.cc index 0c5a8f735bcfc..6c8e495d4f00c 100644 --- a/xla/python/pjrt_ifrt/xla_sharding_test.cc +++ b/xla/python/pjrt_ifrt/xla_sharding_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/tile_assignment.h" +#include "xla/python/ifrt/basic_device_list.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/device_test_util.h" #include "xla/python/ifrt/index.h"