From 96a9745ed6c3cf5394429c1dffacb88102eee471 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Thu, 13 Feb 2025 07:25:23 -0800 Subject: [PATCH] Split `CompileOnlyIfRtClient` into its own directory PiperOrigin-RevId: 726476996 --- xla/python/BUILD | 1 + xla/python/compile_only_ifrt/BUILD | 34 +++ xla/python/compile_only_ifrt/client.cc | 25 ++ xla/python/compile_only_ifrt/client.h | 353 +++++++++++++++++++++++++ xla/python/py_compile_only_client.cc | 326 +---------------------- 5 files changed, 414 insertions(+), 325 deletions(-) create mode 100644 xla/python/compile_only_ifrt/BUILD create mode 100644 xla/python/compile_only_ifrt/client.cc create mode 100644 xla/python/compile_only_ifrt/client.h diff --git a/xla/python/BUILD b/xla/python/BUILD index 148504d42189d..3d547fdbcb3ab 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -392,6 +392,7 @@ cc_library( "//xla/pjrt:transpose", "//xla/pjrt/distributed", "//xla/pjrt/distributed:client", + "//xla/python/compile_only_ifrt:client", "//xla/python/ifrt", "//xla/python/ifrt:attribute_map", "//xla/python/ifrt:custom_call_program", diff --git a/xla/python/compile_only_ifrt/BUILD b/xla/python/compile_only_ifrt/BUILD new file mode 100644 index 0000000000000..824ca6e921f1d --- /dev/null +++ b/xla/python/compile_only_ifrt/BUILD @@ -0,0 +1,34 @@ +load("//xla/tsl:tsl.bzl", "internal_visibility") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//xla/python:__subpackages__", + ]), +) + +cc_library( + name = "client", + srcs = ["client.cc"], + hdrs = ["client.h"], + deps = [ + "//xla:shape_util", + "//xla/pjrt:host_memory_spaces", + "//xla/pjrt:pjrt_device_description", + "//xla/pjrt:pjrt_layout", + "//xla/python/ifrt", + "//xla/python/ifrt:attribute_map", + "//xla/python/pjrt_ifrt", + "//xla/python/pjrt_ifrt:pjrt_attribute_map_util", + "//xla/python/pjrt_ifrt:pjrt_dtype", + "//xla/service:computation_placer_hdr", + "//xla/tsl/concurrency:ref_count", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + ], +) diff --git a/xla/python/compile_only_ifrt/client.cc b/xla/python/compile_only_ifrt/client.cc new file mode 100644 index 0000000000000..f837702ca2d5d --- /dev/null +++ b/xla/python/compile_only_ifrt/client.cc @@ -0,0 +1,25 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/compile_only_ifrt/client.h" + +namespace xla { + +[[maybe_unused]] char CompileOnlyMemory::ID = 0; +[[maybe_unused]] char CompileOnlyDevice::ID = 0; +[[maybe_unused]] char CompileOnlyIfrtCompiler::ID = 0; +[[maybe_unused]] char CompileOnlyIfRtClient::ID = 0; + +} // namespace xla diff --git a/xla/python/compile_only_ifrt/client.h b/xla/python/compile_only_ifrt/client.h new file mode 100644 index 0000000000000..ac4c33d659a83 --- /dev/null +++ b/xla/python/compile_only_ifrt/client.h @@ -0,0 +1,353 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_COMPILE_ONLY_IFRT_CLIENT_H_ +#define XLA_PYTHON_COMPILE_ONLY_IFRT_CLIENT_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/pjrt/host_memory_spaces.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/topology.h" +#include "xla/python/ifrt/value.h" +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/service/computation_placer.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla { + +class CompileOnlyMemory + : public llvm::RTTIExtends { + public: + explicit CompileOnlyMemory( + int id, const PjRtMemorySpaceDescription* memory_description, + ifrt::Device* device) + : id_(id), + kind_(memory_description->kind()), + debug_string_(absl::StrFormat("CompileOnlyMemory(id=%d, kind=%s)", id, + memory_description->kind())), + device_(device) {} + + ifrt::MemoryId Id() const override { return ifrt::MemoryId(id_); } + + const ifrt::MemoryKind& Kind() const override { return kind_; } + + absl::string_view ToString() const override { return debug_string_; } + absl::string_view DebugString() const override { return debug_string_; } + + absl::Span Devices() const override { + return absl::Span{&device_, 1}; + } + + static char ID; // NOLINT + + private: + int id_; + ifrt::MemoryKind kind_; + std::string debug_string_; + ifrt::Device* device_; +}; + +class CompileOnlyDevice + : public llvm::RTTIExtends { + public: + explicit CompileOnlyDevice(const PjRtDeviceDescription* description) + : description_(std::move(description)), + attributes_(ifrt::FromPjRtAttributeMap(description_->Attributes())) {} + + const PjRtDeviceDescription& description() const { return *description_; } + + ifrt::Client* client() const override { return nullptr; } + bool IsAddressable() const override { return false; } + ifrt::DeviceId Id() const override { + return ifrt::DeviceId(description_->id()); + } + + int ProcessIndex() const override { return description_->process_index(); } + + absl::string_view Kind() const override { + return description_->device_kind(); + } + + absl::string_view ToString() const override { + return description_->ToString(); + } + + absl::string_view DebugString() const override { + return description_->DebugString(); + } + + absl::Span Memories() const override { + return unowned_memories_; + } + absl::StatusOr DefaultMemory() const override { + if (default_memory_) { + return default_memory_; + } + return Unimplemented("DefaultMemory is not supported"); + } + + const ifrt::AttributeMap& Attributes() const override { return attributes_; } + + void AttachMemory(std::unique_ptr memory) { + unowned_memories_.push_back(memory.get()); + owned_memories_.push_back(std::move(memory)); + } + + void SetDefaultMemory(ifrt::Memory* memory) { default_memory_ = memory; } + + static char ID; // NOLINT + + private: + const PjRtDeviceDescription* description_; + ifrt::AttributeMap attributes_; + ifrt::Memory* default_memory_ = nullptr; + std::vector unowned_memories_; + std::vector> owned_memories_; +}; + +class CompileOnlyIfrtCompiler final + : public llvm::RTTIExtends { + public: + absl::StatusOr> Compile( + std::unique_ptr program, + std::unique_ptr options) override { + return Unimplemented("Compile not implemented."); + } + + absl::StatusOr> Compile( + std::unique_ptr program, const ifrt::Topology& topology, + std::unique_ptr options) override { + return Unimplemented("Compile not implemented."); + } + + absl::StatusOr> + DeserializeLoadedExecutable( + absl::string_view serialized, + std::unique_ptr options) override { + return Unimplemented("DeserializeLoadedExecutable not implemented."); + } + + static char ID; // NOLINT +}; + +class CompileOnlyIfRtClient final + : public llvm::RTTIExtends { + public: + explicit CompileOnlyIfRtClient(std::shared_ptr topology) + : topology_(std::move(topology)), + descriptions_(topology_->DeviceDescriptions()), + attributes_(ifrt::AttributeMap::Map()) { + int offset = 0; + for (auto& description : descriptions_) { + owned_devices_.push_back( + std::make_unique(description.get())); + auto* device = owned_devices_.back().get(); + devices_.push_back(device); + if (description->process_index() == process_index()) { + auto default_memory = description->default_memory_space(); + for (auto* memory_description : description->memory_spaces()) { + auto memory = std::make_unique( + offset, memory_description, device); + if (default_memory.ok() && memory_description == *default_memory) { + device->SetDefaultMemory(memory.get()); + } + device->AttachMemory(std::move(memory)); + ++offset; + } + } + } + } + + absl::StatusOr> MakeArrayFromHostBuffer( + const void* data, ifrt::DType dtype, ifrt::Shape shape, + std::optional> byte_strides, + std::shared_ptr sharding, + HostBufferSemantics semantics, + std::function on_done_with_host_buffer) override { + return Unimplemented( + "MakeArrayFromHostBuffer not available with compile-only client."); + } + + absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( + ifrt::Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ifrt::ArrayCopySemantics semantics) override { + return Unimplemented( + "AssembleArrayFromSingleDeviceArrays not available with compile-only " + "client."); + } + absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( + ifrt::Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ifrt::ArrayCopySemantics array_copy_semantics, + ifrt::SingleDeviceShardSemantics single_device_shard_semantics) override { + return Unimplemented( + "AssembleArrayFromSingleDeviceArrays not available with compile-only " + "client."); + } + absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( + ifrt::DType dtype, ifrt::Shape shape, + std::shared_ptr sharding, + absl::Span> arrays, + ifrt::ArrayCopySemantics array_copy_semantics, + ifrt::SingleDeviceShardSemantics single_device_shard_semantics) override { + return Unimplemented( + "AssembleArrayFromSingleDeviceArrays not available with compile-only " + "client."); + } + + absl::StatusOr>> CopyArrays( + absl::Span> arrays, + std::optional> devices, + std::optional memory_kind, + ifrt::ArrayCopySemantics semantics) override { + return Unimplemented("CopyArrays not available with compile-only client."); + } + + absl::StatusOr>> RemapArrays( + const ifrt::RemapPlan& plan, + absl::Span> arrays, + ifrt::ArrayCopySemantics semantics) override { + return Unimplemented("RemapArrays not available with compile-only client."); + } + + ifrt::Future<> GetReadyFuture( + absl::Span> values) override { + return ifrt::Future<>(Unimplemented( + "GetReadyFuture not available with compile-only client.")); + } + + absl::StatusOr> MakeTuple( + absl::Span> values) override { + return Unimplemented("MakeTuple not available with compile-only client."); + } + + absl::string_view runtime_type() const override { + return "compile_only_runtime"; + } + + absl::string_view platform_name() const override { + return topology_->platform_name(); + } + absl::string_view platform_version() const override { + return topology_->platform_version(); + } + ifrt::PlatformId platform_id() const override { + return topology_->platform_id(); + } + const ifrt::AttributeMap& Attributes() const override { return attributes_; } + + int device_count() const override { return devices().size(); } + int addressable_device_count() const override { return 0; } + absl::Span devices() const override { return devices_; } + absl::Span addressable_devices() const override { + return {}; + } + int process_index() const override { return 0; } + absl::Span GetAllDevices() const override { + return devices_; + } + absl::StatusOr GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const override { + return Unimplemented( + "GetDefaultDeviceAssignment not available with compile-only client."); + } + absl::StatusOr LookupDevice( + ifrt::DeviceId device_id) const override { + return Unimplemented( + "LookupDevice not available with compile-only client."); + } + + absl::StatusOr LookupAddressableDevice( + int local_hardware_id) const override { + return Unimplemented( + "LookupAddressableDevice not available with compile-only client."); + } + + tsl::RCReference MakeDeviceList( + absl::Span devices) const override { + return ifrt::BasicDeviceList::Create(devices); + } + + ifrt::Compiler* GetDefaultCompiler() override { return &default_compiler_; } + + static char ID; // NOLINT + + const ifrt::PjRtTopology& topology() const { return *topology_; } + + absl::StatusOr> GetTopologyForDevices( + const tsl::RCReference& devices) const override { + return topology_; + } + + absl::StatusOr> GetDefaultLayout( + ifrt::DType dtype, absl::Span dims, ifrt::Device* device, + ifrt::MemoryKind memory_kind) const override { + if (memory_kind == ifrt::MemoryKind(UnpinnedHostMemorySpace::kKind)) { + return std::make_shared( + LayoutUtil::MakeDescendingLayout(dims.size())); + } + TF_ASSIGN_OR_RETURN(PrimitiveType element_type, ToPrimitiveType(dtype)); + TF_ASSIGN_OR_RETURN(xla::Layout layout, + topology_->GetDefaultLayout(element_type, dims)); + return std::make_shared(std::move(layout)); + } + + private: + CompileOnlyIfrtCompiler default_compiler_; + std::shared_ptr topology_; + std::vector> descriptions_; + ifrt::AttributeMap attributes_; + std::vector> owned_devices_; + std::vector devices_; +}; + +} // namespace xla + +#endif // XLA_PYTHON_COMPILE_ONLY_IFRT_CLIENT_H_ diff --git a/xla/python/py_compile_only_client.cc b/xla/python/py_compile_only_client.cc index 99ddce8986bee..7423d2f38bf73 100644 --- a/xla/python/py_compile_only_client.cc +++ b/xla/python/py_compile_only_client.cc @@ -15,21 +15,14 @@ limitations under the License. #include "xla/python/py_compile_only_client.h" -#include -#include #include -#include -#include #include #include -#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" -#include "absl/types/span.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/ExtensibleRTTI.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" @@ -37,41 +30,17 @@ limitations under the License. #include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "xla/layout.h" -#include "xla/layout_util.h" -#include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_compiler.h" -#include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" -#include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" -#include "xla/python/ifrt/array.h" -#include "xla/python/ifrt/attribute_map.h" -#include "xla/python/ifrt/client.h" -#include "xla/python/ifrt/compiler.h" -#include "xla/python/ifrt/device.h" -#include "xla/python/ifrt/device_list.h" -#include "xla/python/ifrt/dtype.h" +#include "xla/python/compile_only_ifrt/client.h" #include "xla/python/ifrt/executable.h" -#include "xla/python/ifrt/future.h" -#include "xla/python/ifrt/memory.h" -#include "xla/python/ifrt/program.h" -#include "xla/python/ifrt/remap_plan.h" -#include "xla/python/ifrt/shape.h" -#include "xla/python/ifrt/sharding.h" -#include "xla/python/ifrt/topology.h" -#include "xla/python/ifrt/tuple.h" -#include "xla/python/ifrt/value.h" #include "xla/python/nb_class_ptr.h" -#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" -#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/python/py_client.h" -#include "xla/service/computation_placer.h" -#include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" @@ -85,299 +54,6 @@ namespace xla { namespace { -class CompileOnlyMemory - : public llvm::RTTIExtends { - public: - explicit CompileOnlyMemory( - int id, const PjRtMemorySpaceDescription* memory_description, - ifrt::Device* device) - : id_(id), - kind_(memory_description->kind()), - debug_string_(absl::StrFormat("CompileOnlyMemory(id=%d, kind=%s)", id, - memory_description->kind())), - device_(device) {} - - ifrt::MemoryId Id() const override { return ifrt::MemoryId(id_); } - - const ifrt::MemoryKind& Kind() const override { return kind_; } - - absl::string_view ToString() const override { return debug_string_; } - absl::string_view DebugString() const override { return debug_string_; } - - absl::Span Devices() const override { - return absl::Span{&device_, 1}; - } - - static char ID; // NOLINT - - private: - int id_; - ifrt::MemoryKind kind_; - std::string debug_string_; - ifrt::Device* device_; -}; - -[[maybe_unused]] char CompileOnlyMemory::ID = 0; - -class CompileOnlyDevice - : public llvm::RTTIExtends { - public: - explicit CompileOnlyDevice(const PjRtDeviceDescription* description) - : description_(std::move(description)), - attributes_(ifrt::FromPjRtAttributeMap(description_->Attributes())) {} - - const PjRtDeviceDescription& description() const { return *description_; } - - ifrt::Client* client() const override { return nullptr; } - bool IsAddressable() const override { return false; } - ifrt::DeviceId Id() const override { - return ifrt::DeviceId(description_->id()); - } - - int ProcessIndex() const override { return description_->process_index(); } - - absl::string_view Kind() const override { - return description_->device_kind(); - } - - absl::string_view ToString() const override { - return description_->ToString(); - } - - absl::string_view DebugString() const override { - return description_->DebugString(); - } - - absl::Span Memories() const override { - return unowned_memories_; - } - absl::StatusOr DefaultMemory() const override { - if (default_memory_) { - return default_memory_; - } - return Unimplemented("DefaultMemory is not supported"); - } - - const ifrt::AttributeMap& Attributes() const override { return attributes_; } - - void AttachMemory(std::unique_ptr memory) { - unowned_memories_.push_back(memory.get()); - owned_memories_.push_back(std::move(memory)); - } - - void SetDefaultMemory(ifrt::Memory* memory) { default_memory_ = memory; } - - private: - const PjRtDeviceDescription* description_; - ifrt::AttributeMap attributes_; - ifrt::Memory* default_memory_ = nullptr; - std::vector unowned_memories_; - std::vector> owned_memories_; -}; - -class InvalidIfrtCompiler final - : public llvm::RTTIExtends { - public: - absl::StatusOr> Compile( - std::unique_ptr program, - std::unique_ptr options) override { - return Unimplemented("Compile not implemented."); - } - - absl::StatusOr> Compile( - std::unique_ptr program, const ifrt::Topology& topology, - std::unique_ptr options) override { - return Unimplemented("Compile not implemented."); - } - - absl::StatusOr> - DeserializeLoadedExecutable( - absl::string_view serialized, - std::unique_ptr options) override { - return Unimplemented("DeserializeLoadedExecutable not implemented."); - } - - static char ID; // NOLINT -}; -[[maybe_unused]] char InvalidIfrtCompiler::ID = 0; - -class CompileOnlyIfRtClient final - : public llvm::RTTIExtends { - public: - explicit CompileOnlyIfRtClient(std::shared_ptr topology) - : topology_(std::move(topology)), - descriptions_(topology_->DeviceDescriptions()), - attributes_(ifrt::AttributeMap::Map()) { - int offset = 0; - for (auto& description : descriptions_) { - owned_devices_.push_back( - std::make_unique(description.get())); - auto* device = owned_devices_.back().get(); - devices_.push_back(device); - if (description->process_index() == process_index()) { - auto default_memory = description->default_memory_space(); - for (auto* memory_description : description->memory_spaces()) { - auto memory = std::make_unique( - offset, memory_description, device); - if (default_memory.ok() && memory_description == *default_memory) { - device->SetDefaultMemory(memory.get()); - } - device->AttachMemory(std::move(memory)); - ++offset; - } - } - } - } - - absl::StatusOr> MakeArrayFromHostBuffer( - const void* data, ifrt::DType dtype, ifrt::Shape shape, - std::optional> byte_strides, - std::shared_ptr sharding, - HostBufferSemantics semantics, - std::function on_done_with_host_buffer) override { - return Unimplemented( - "MakeArrayFromHostBuffer not available with compile-only client."); - } - - absl::StatusOr> - AssembleArrayFromSingleDeviceArrays( - ifrt::Shape shape, std::shared_ptr sharding, - absl::Span> arrays, - ifrt::ArrayCopySemantics semantics) override { - return Unimplemented( - "AssembleArrayFromSingleDeviceArrays not available with compile-only " - "client."); - } - absl::StatusOr> - AssembleArrayFromSingleDeviceArrays( - ifrt::Shape shape, std::shared_ptr sharding, - absl::Span> arrays, - ifrt::ArrayCopySemantics array_copy_semantics, - ifrt::SingleDeviceShardSemantics single_device_shard_semantics) override { - return Unimplemented( - "AssembleArrayFromSingleDeviceArrays not available with compile-only " - "client."); - } - absl::StatusOr> - AssembleArrayFromSingleDeviceArrays( - ifrt::DType dtype, ifrt::Shape shape, - std::shared_ptr sharding, - absl::Span> arrays, - ifrt::ArrayCopySemantics array_copy_semantics, - ifrt::SingleDeviceShardSemantics single_device_shard_semantics) override { - return Unimplemented( - "AssembleArrayFromSingleDeviceArrays not available with compile-only " - "client."); - } - - absl::StatusOr>> CopyArrays( - absl::Span> arrays, - std::optional> devices, - std::optional memory_kind, - ifrt::ArrayCopySemantics semantics) override { - return Unimplemented("CopyArrays not available with compile-only client."); - } - - absl::StatusOr>> RemapArrays( - const ifrt::RemapPlan& plan, - absl::Span> arrays, - ifrt::ArrayCopySemantics semantics) override { - return Unimplemented("RemapArrays not available with compile-only client."); - } - - ifrt::Future<> GetReadyFuture( - absl::Span> values) override { - return ifrt::Future<>(Unimplemented( - "GetReadyFuture not available with compile-only client.")); - } - - absl::StatusOr> MakeTuple( - absl::Span> values) override { - return Unimplemented("MakeTuple not available with compile-only client."); - } - - absl::string_view runtime_type() const override { - return "compile_only_runtime"; - } - - absl::string_view platform_name() const override { - return topology_->platform_name(); - } - absl::string_view platform_version() const override { - return topology_->platform_version(); - } - ifrt::PlatformId platform_id() const override { - return topology_->platform_id(); - } - const ifrt::AttributeMap& Attributes() const override { return attributes_; } - - int device_count() const override { return devices().size(); } - int addressable_device_count() const override { return 0; } - absl::Span devices() const override { return devices_; } - absl::Span addressable_devices() const override { - return {}; - } - int process_index() const override { return 0; } - absl::Span GetAllDevices() const override { - return devices_; - } - absl::StatusOr GetDefaultDeviceAssignment( - int num_replicas, int num_partitions) const override { - return Unimplemented( - "GetDefaultDeviceAssignment not available with compile-only client."); - } - absl::StatusOr LookupDevice( - ifrt::DeviceId device_id) const override { - return Unimplemented( - "LookupDevice not available with compile-only client."); - } - - absl::StatusOr LookupAddressableDevice( - int local_hardware_id) const override { - return Unimplemented( - "LookupAddressableDevice not available with compile-only client."); - } - - tsl::RCReference MakeDeviceList( - absl::Span devices) const override { - return ifrt::BasicDeviceList::Create(devices); - } - - ifrt::Compiler* GetDefaultCompiler() override { return &default_compiler_; } - - static char ID; // NOLINT - - const ifrt::PjRtTopology& topology() const { return *topology_; } - - absl::StatusOr> GetTopologyForDevices( - const tsl::RCReference& devices) const override { - return topology_; - } - - absl::StatusOr> GetDefaultLayout( - ifrt::DType dtype, absl::Span dims, ifrt::Device* device, - ifrt::MemoryKind memory_kind) const override { - if (memory_kind == ifrt::MemoryKind(UnpinnedHostMemorySpace::kKind)) { - return std::make_shared( - LayoutUtil::MakeDescendingLayout(dims.size())); - } - TF_ASSIGN_OR_RETURN(PrimitiveType element_type, ToPrimitiveType(dtype)); - TF_ASSIGN_OR_RETURN(xla::Layout layout, - topology_->GetDefaultLayout(element_type, dims)); - return std::make_shared(std::move(layout)); - } - - private: - InvalidIfrtCompiler default_compiler_; - std::shared_ptr topology_; - std::vector> descriptions_; - ifrt::AttributeMap attributes_; - std::vector> owned_devices_; - std::vector devices_; -}; - -[[maybe_unused]] char CompileOnlyIfRtClient::ID = 0; - class CompileOnlyPyClient : public PyClient { public: using PyClient::PyClient;