diff --git a/xla/pjrt/BUILD b/xla/pjrt/BUILD index 36bb2e215ec4d..947e0ee21677c 100644 --- a/xla/pjrt/BUILD +++ b/xla/pjrt/BUILD @@ -100,10 +100,12 @@ cc_library( "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:event", "//xla/tsl/concurrency:async_value", + "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", @@ -132,6 +134,7 @@ xla_cc_test( "//xla/hlo/testlib:test", "//xla/service:cpu_plugin", "//xla/stream_executor:device_memory_allocator", + "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -543,6 +546,7 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:stream", + "//xla/tsl/concurrency:ref_count", "//xla/tsl/framework:allocator", "//xla/tsl/platform:env", "@com_google_absl//absl/algorithm:container", diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 3399a5e1f51c5..71d584b24af26 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -208,7 +208,7 @@ class AsyncHostToDeviceTransferManager buffer_sizes_.reserve(buffer_ptrs_.size()); for (const auto& ptr : buffer_ptrs_) { DCHECK_EQ(ptr->device_memory().size(), 1); - buffer_sizes_.push_back(ptr->device_memory()[0].size()); + buffer_sizes_.push_back(ptr->device_memory()[0]->mem().size()); } last_transfer_started_.resize(buffer_ptrs_.size(), false); } @@ -375,7 +375,7 @@ class AsyncHostToDeviceTransferManager buffer_index); } DCHECK_EQ(buffer_ptrs_[buffer_index]->device_memory().size(), 1); - auto& buffer_memory = buffer_ptrs_[buffer_index]->device_memory()[0]; + auto& buffer_memory = buffer_ptrs_[buffer_index]->device_memory()[0]->mem(); se::DeviceMemoryBase sub_buffer; CHECK_LE(offset, buffer_memory.size()); CHECK_LE(transfer_size, buffer_memory.size() - offset); @@ -676,7 +676,7 @@ PjRtFuture<> StreamExecutorGpuClient::CopyRawSubBufferToHost( return; } - auto& device_memory = device_buffer->device_memory()[0]; + auto& device_memory = device_buffer->device_memory()[0]->mem(); if (offset < 0 || offset > device_memory.size() || device_memory.size() - offset < transfer_size) { promise.Set( diff --git a/xla/pjrt/pjrt_stream_executor_client.cc b/xla/pjrt/pjrt_stream_executor_client.cc index a6043252ac78c..f87e918163f6a 100644 --- a/xla/pjrt/pjrt_stream_executor_client.cc +++ b/xla/pjrt/pjrt_stream_executor_client.cc @@ -136,6 +136,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/framework/allocator.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -609,10 +610,11 @@ void PjRtStreamExecutorBuffer::ScopedHold::ConvertUsageHold( SetState(kConverted); } -void PjRtStreamExecutorBuffer::ScopedHold::ConfirmDonation() { +void PjRtStreamExecutorBuffer::ScopedHold::ConfirmDonation( + bool unsafe_release) { CHECK(ok()); CHECK_EQ(type_, kDonation); - parent_->ConfirmDonation(buffer().get()); + parent_->ConfirmDonation(buffer().get(), unsafe_release); SetState(kDonated); } @@ -679,7 +681,7 @@ class ScopedHoldAsExternalReference : public PjRtBuffer::ExternalReference { : external_reference_(std::move(hold)) { CHECK(external_reference_.type() == PjRtStreamExecutorBuffer::ScopedHold::kExternalReference); - data_ptr_ = external_reference_->device_memory().front().opaque(); + data_ptr_ = external_reference_->device_memory().front()->opaque(); } ~ScopedHoldAsExternalReference() override = default; @@ -713,7 +715,7 @@ class TrackedDeviceBufferExternalReference explicit TrackedDeviceBufferExternalReference( std::shared_ptr tracked_device_buffer) : tracked_device_buffer_(std::move(tracked_device_buffer)) { - data_ptr_ = tracked_device_buffer_->device_memory()[0].opaque(); + data_ptr_ = tracked_device_buffer_->device_memory()[0]->opaque(); } ~TrackedDeviceBufferExternalReference() override = default; @@ -756,7 +758,7 @@ PjRtStreamExecutorBuffer::DonateWithControlDependency(PjRtFuture<> dependency) { } // Copy all the data in the existing tracked_buffer. - absl::InlinedVector buffers( + absl::InlinedVector, 4> buffers( tracked_buffer->device_memory().begin(), tracked_buffer->device_memory().end()); auto original_definition_events = tracked_buffer->definition_events(); @@ -773,9 +775,7 @@ PjRtStreamExecutorBuffer::DonateWithControlDependency(PjRtFuture<> dependency) { original_definition_events.end()); auto new_device_buffer = std::make_shared( - tracked_buffer->allocator(), device(), std::move(buffers), - std::move(definition_events), - /*on_delete_callback=*/nullptr); + device(), std::move(buffers), std::move(definition_events)); // Make the new buffer which is identical to the old, except for the new // definition event. @@ -800,7 +800,7 @@ PjRtStreamExecutorBuffer::DonateWithControlDependency(PjRtFuture<> dependency) { local_device->ReturnStreamToPool(std::move(stream)); }); - tracked_buffer.ConfirmDonation(); + tracked_buffer.ConfirmDonation(false); return new_buffer; } @@ -958,7 +958,8 @@ PjRtStreamExecutorClient::BufferFromHostBufferInternal( // memory that has already been allocated, and a possible Event // allocation. - se::DeviceMemoryBase device_memory = device_buffer->device_memory()[0]; + se::DeviceMemoryBase device_memory = + device_buffer->device_memory()[0]->mem(); // If applicable on the backend, stage the transfer via host memory // allocated via the host_memory_allocator. On GPU, this is pinned @@ -1082,12 +1083,9 @@ PjRtStreamExecutorClient::CreateErrorBuffer(absl::Status error, definition_event->SetDefinedStatus(error); // Create an empty buffer. - auto* se_client = tensorflow::down_cast(this); - absl::Span buffers; auto dummy_device_buffer = std::make_shared( - se_client->allocator(), device, buffers, - absl::MakeSpan(&definition_event, 1), - /*on_delete_callback=*/nullptr); + device, absl::Span>(), + absl::MakeSpan(&definition_event, 1)); auto py_buffer = std::make_unique( shape, std::move(dummy_device_buffer), this, device, @@ -1243,7 +1241,9 @@ PjRtStreamExecutorClient::CreateViewOfDeviceBuffer( CHECK_EQ(memory_space->devices().size(), 1); auto* device = memory_space->devices().front(); - se::DeviceMemoryBase buffer(device_ptr, ShapeUtil::ByteSizeOf(shape)); + auto buffer = RawSEDeviceMemory::CreateForeign( + se::DeviceMemoryBase(device_ptr, ShapeUtil::ByteSizeOf(shape)), + std::move(on_delete_callback)); TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, tensorflow::down_cast(device) @@ -1268,9 +1268,9 @@ PjRtStreamExecutorClient::CreateViewOfDeviceBuffer( definition_stream); auto device_buffer = std::make_shared( - /*allocator=*/nullptr, device, - std::initializer_list{buffer}, definition_events, - std::move(on_delete_callback)); + device, + std::initializer_list>{buffer}, + definition_events); return std::unique_ptr(std::make_unique( shape, std::move(device_buffer), this, device, device->default_memory_space().value_or(nullptr))); @@ -1597,7 +1597,7 @@ void PjRtStreamExecutorBuffer::ConvertUsageHold( } void PjRtStreamExecutorBuffer::ConfirmDonation( - TrackedDeviceBuffer* device_buffer) { + TrackedDeviceBuffer* device_buffer, bool unsafe_release) { { absl::MutexLock lock(&mu_); CHECK_EQ(holds_[ScopedHold::kUsage], 0); @@ -1609,7 +1609,7 @@ void PjRtStreamExecutorBuffer::ConfirmDonation( device_buffer->LockUseAndTransferUsageEvents(); // Give up ownership of the device memory so we don't free it when the last // reference to device_buffer_ goes away. - device_buffer->ReleaseDeviceMemory(); + device_buffer->ReleaseDeviceMemory(unsafe_release); // Make *this invalid so it can't be used again. Any threads blocking in // Release or GetBufferWithHold will see an invalid buffer and return. device_buffer_.reset(); @@ -1758,7 +1758,7 @@ absl::StatusOr PjRtStreamExecutorBuffer::GetOnDeviceSizeInBytes() return InvalidArgument( "GetOnDeviceSizeInBytes called on tuple-shaped buffer"); } - return device_buffer_->device_memory()[0].size(); + return device_buffer_->device_memory()[0]->mem().size(); } PjRtFuture<> PjRtStreamExecutorBuffer::CopyRawToHost(void* dst, int64_t offset, @@ -3064,7 +3064,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper( // Even though there was an error we need to call ConfirmDonation, which // renders b invalid, since the computation has been enqueued and b has // been donated. - b.ConfirmDonation(); + b.ConfirmDonation(true); } } return event_or.status(); @@ -3085,7 +3085,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper( stream, &buffers_to_release); } else { CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation); - b.ConfirmDonation(); + b.ConfirmDonation(true); } } diff --git a/xla/pjrt/pjrt_stream_executor_client.h b/xla/pjrt/pjrt_stream_executor_client.h index 2654372331cb3..86e98de8edafa 100644 --- a/xla/pjrt/pjrt_stream_executor_client.h +++ b/xla/pjrt/pjrt_stream_executor_client.h @@ -621,7 +621,8 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { // Confirms that the buffer was successfully donated to an execution. // Only valid for holds of type kDonation. Causes the buffer to become // invalid. - void ConfirmDonation(); + // TODO(parkers): Only allow safe releases. + void ConfirmDonation(bool unsafe_release); // Adds the held device buffers in order to 'iterator'. Used to add the // buffers to an ExecutionInput. We require but do not verify that @@ -819,7 +820,7 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { // Drops a donation hold and makes *this invalid for further use. Does a // sanity check that buffer==device_buffer_. Called after device_buffer_ was // successfully donated to an execution. - void ConfirmDonation(TrackedDeviceBuffer* device_buffer); + void ConfirmDonation(TrackedDeviceBuffer* device_buffer, bool unsafe_release); // Drops a hold without taking any other action. Does a sanity check that // buffer==device_buffer_ or device_buffer_==nullptr. diff --git a/xla/pjrt/tracked_device_buffer.cc b/xla/pjrt/tracked_device_buffer.cc index ca551c4aa81b8..0b406519e0caa 100644 --- a/xla/pjrt/tracked_device_buffer.cc +++ b/xla/pjrt/tracked_device_buffer.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -41,6 +42,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/event.h" +#include "xla/tsl/concurrency/ref_count.h" #include "tsl/platform/logging.h" #include "tsl/profiler/lib/connected_traceme.h" #include "tsl/profiler/lib/context_types.h" @@ -184,6 +186,61 @@ void BufferSequencingEvent::ExecuteFutureTasks() { thread_pool_->Schedule(std::move(call_all_task_callbacks)); } +class AllocatedRawSEDeviceMemory : public RawSEDeviceMemory { + public: + AllocatedRawSEDeviceMemory(se::DeviceMemoryBase value, int device_ordinal, + se::DeviceMemoryAllocator* allocator) + : RawSEDeviceMemory(value), + allocator_(allocator), + device_ordinal_(device_ordinal) {} + + ~AllocatedRawSEDeviceMemory() override { + if (allocator_) { + absl::Status status = allocator_->Deallocate(device_ordinal_, mem()); + if (!status.ok()) { + LOG(ERROR) << "Buffer deallocation failed: " << status; + } + } + } + + void UnsafeReleaseMemory() override { allocator_ = nullptr; } + + private: + se::DeviceMemoryAllocator* allocator_; + int device_ordinal_; +}; + +tsl::RCReference RawSEDeviceMemory::Create( + se::DeviceMemoryBase value, PjRtDevice* device, + se::DeviceMemoryAllocator* allocator) { + return tsl::MakeRef( + value, device->local_device_id().value(), allocator); +} + +class ForeignRawSEDeviceMemory : public RawSEDeviceMemory { + public: + ForeignRawSEDeviceMemory(se::DeviceMemoryBase value, + absl::AnyInvocable on_delete_callback) + : RawSEDeviceMemory(value), + on_delete_callback_(std::move(on_delete_callback)) {} + + ~ForeignRawSEDeviceMemory() override { std::move(on_delete_callback_)(); } + + void UnsafeReleaseMemory() override { + LOG(FATAL) << "ForeignRawSEDeviceMemory cannot be donated."; + } + + private: + absl::AnyInvocable on_delete_callback_; +}; + +tsl::RCReference RawSEDeviceMemory::CreateForeign( + se::DeviceMemoryBase value, + absl::AnyInvocable on_delete_callback) { + return tsl::MakeRef(value, + std::move(on_delete_callback)); +} + /* static */ std::shared_ptr TrackedDeviceBuffer::FromScopedShapedBuffer( ScopedShapedBuffer* shaped_buffer, @@ -191,21 +248,21 @@ TrackedDeviceBuffer::FromScopedShapedBuffer( PjRtDevice* device) { ShapeTree::iterator iterator = shaped_buffer->buffers().begin(); - std::vector buffers; + std::vector> buffers; buffers.reserve(1); ShapeUtil::ForEachSubshape( shaped_buffer->on_device_shape(), [&](const Shape&, const ShapeIndex&) { CHECK(iterator != shaped_buffer->buffers().end()); - buffers.push_back(iterator->second); + buffers.push_back(RawSEDeviceMemory::Create( + iterator->second, device, shaped_buffer->memory_allocator())); iterator->second = se::DeviceMemoryBase(); ++iterator; }); CHECK(iterator == shaped_buffer->buffers().end()); return std::make_shared( - shaped_buffer->memory_allocator(), device, - absl::Span(buffers), definition_events, - /*on_delete_callback=*/nullptr); + device, absl::Span>(buffers), + definition_events); } ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer( @@ -215,9 +272,9 @@ ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer( device_->local_hardware_id().value()); ShapeTree::iterator iterator = shaped_buffer.buffers().begin(); - for (const se::DeviceMemoryBase& buf : device_memory_) { + for (const tsl::RCReference& buf : device_memory_) { CHECK(iterator != shaped_buffer.buffers().end()); - iterator->second = buf; + iterator->second = buf->mem(); ++iterator; } CHECK(iterator == shaped_buffer.buffers().end()); @@ -230,10 +287,10 @@ ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer( void TrackedDeviceBuffer::AddToInputAsImmutable( ShapeTree::iterator* iterator, const ShapeTree::iterator& end) const { - for (const se::DeviceMemoryBase& buf : device_memory_) { + for (const tsl::RCReference& buf : device_memory_) { CHECK(*iterator != end); // Set buffers to be case (1) in the comment on ExecutionInput. - (*iterator)->second = MaybeOwningDeviceMemory(buf); + (*iterator)->second = MaybeOwningDeviceMemory(buf->mem()); ++(*iterator); } } @@ -243,42 +300,35 @@ void TrackedDeviceBuffer::AddToInputAsDonated( const ShapeTree::iterator& end, ExecutionInput* execution_input, se::DeviceMemoryAllocator* allocator) const { - for (const se::DeviceMemoryBase& buf : device_memory_) { + for (const tsl::RCReference& buf : device_memory_) { CHECK(*iterator != end); // Set buffers to be case (2) in the comment on ExecutionInput. (*iterator)->second = MaybeOwningDeviceMemory(se::OwningDeviceMemory( - buf, device_->local_device_id().value(), allocator)); + buf->mem(), device_->local_device_id().value(), allocator)); execution_input->SetUnownedIndex((*iterator)->first); ++(*iterator); } } TrackedDeviceBuffer::TrackedDeviceBuffer( - se::DeviceMemoryAllocator* allocator, PjRtDevice* device, - absl::Span device_memory, - absl::Span> definition_events, - absl::AnyInvocable on_delete_callback) - : allocator_(allocator), - device_(device), + PjRtDevice* device, + absl::Span const> device_memory, + absl::Span> definition_events) + : device_(device), device_memory_(device_memory.begin(), device_memory.end()), definition_events_(std::make_move_iterator(definition_events.begin()), std::make_move_iterator(definition_events.end())), - in_use_(true), - on_delete_callback_(std::move(on_delete_callback)) {} - -TrackedDeviceBuffer::~TrackedDeviceBuffer() { - if (allocator_) { - for (const se::DeviceMemoryBase& buffer : device_memory_) { - absl::Status status = - allocator_->Deallocate(device_->local_device_id().value(), buffer); - if (!status.ok()) { - LOG(ERROR) << "Buffer deallocation failed: " << status; - } + in_use_(true) {} + +TrackedDeviceBuffer::~TrackedDeviceBuffer() = default; + +void TrackedDeviceBuffer::ReleaseDeviceMemory(bool unsafe_release) { + if (unsafe_release) { + for (auto& mem : device_memory_) { + mem->UnsafeReleaseMemory(); } } - if (on_delete_callback_) { - std::move(on_delete_callback_)(); - } + device_memory_.clear(); } void TrackedDeviceBuffer::AddUsageEvent( diff --git a/xla/pjrt/tracked_device_buffer.h b/xla/pjrt/tracked_device_buffer.h index 4dbf7881014e7..4e020b42e96bb 100644 --- a/xla/pjrt/tracked_device_buffer.h +++ b/xla/pjrt/tracked_device_buffer.h @@ -195,6 +195,32 @@ class BufferSequencingEvent { tsl::AsyncValueRef defined_status_ ABSL_GUARDED_BY(mu_); }; +// TODO(parkers): Implement PjRtRawBuffer API. +class RawSEDeviceMemory : public tsl::ReferenceCounted { + public: + explicit RawSEDeviceMemory(se::DeviceMemoryBase value) : value_(value) {} + + virtual ~RawSEDeviceMemory() = default; + + const se::DeviceMemoryBase& mem() const { return value_; } + + void* opaque() const { return value_.opaque(); } + + // TODO(parkers): Donate this ref-counted object instead of the underlying + // buffer. + virtual void UnsafeReleaseMemory() = 0; + + static tsl::RCReference Create( + se::DeviceMemoryBase value, PjRtDevice* device, + se::DeviceMemoryAllocator* allocator); + static tsl::RCReference CreateForeign( + se::DeviceMemoryBase value, + absl::AnyInvocable on_delete_callback); + + private: + se::DeviceMemoryBase value_; +}; + // Class that represents a tuple of device buffers. Like a ScopedShapedBuffer it // owns all of the device memory in the tuple. It also tracks the definition and // usage of the memory on streams, to allow for synchronized usage and deletion @@ -247,11 +273,11 @@ class TrackedDeviceBuffer { ExecutionInput* execution_input, se::DeviceMemoryAllocator* allocator) const; - se::DeviceMemoryAllocator* allocator() const { return allocator_; } - absl::InlinedVector& device_memory() { + absl::InlinedVector, 1>& device_memory() { return device_memory_; } - const absl::InlinedVector& device_memory() const { + const absl::InlinedVector, 1>& + device_memory() const { return device_memory_; } absl::Span> definition_events() @@ -264,7 +290,7 @@ class TrackedDeviceBuffer { // Relinquishes ownership of the buffer's device memory, e.g., after the // buffer is passed to a computation that aliases its inputs to outputs. - void ReleaseDeviceMemory() { device_memory_.clear(); } + void ReleaseDeviceMemory(bool unsafe_release); // Indicates that the buffer has been used on a stream. // @@ -287,21 +313,18 @@ class TrackedDeviceBuffer { StreamAndEventContainer LockUseAndTransferUsageEvents(); TrackedDeviceBuffer() : in_use_(true) {} - TrackedDeviceBuffer(se::DeviceMemoryAllocator* allocator, PjRtDevice* device, - absl::Span device_memory, - absl::Span> - definition_events, - absl::AnyInvocable on_delete_callback); + TrackedDeviceBuffer( + PjRtDevice* device, + absl::Span const> device_memory, + absl::Span> + definition_events); ~TrackedDeviceBuffer(); private: - // Are the buffers in device_memory_ owned? If so, which allocator and device? - // May be nullptr, indicating the buffers are not owned. - se::DeviceMemoryAllocator* allocator_; PjRtDevice* device_; // Each host-side buffer may have several buffers on-device. - absl::InlinedVector device_memory_; + absl::InlinedVector, 1> device_memory_; // Events that are triggered when the content of one or more buffers is ready // during multistream execution. May be nullptr, which is used in the diff --git a/xla/pjrt/tracked_device_buffer_test.cc b/xla/pjrt/tracked_device_buffer_test.cc index b8d4b61a75dd4..dd5d883b0921c 100644 --- a/xla/pjrt/tracked_device_buffer_test.cc +++ b/xla/pjrt/tracked_device_buffer_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" #include "tsl/platform/statusor.h" @@ -81,7 +82,7 @@ class TestDevice : public PjRtDevice { absl::StatusOr> MakeArray( const Shape& shape, LocalClient* client, PjRtDevice* device) { - std::vector device_buffers; + std::vector> device_buffers; TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( client->backend().transfer_manager()->HostShapeToDeviceShape(shape), [&](const Shape& subshape, const ShapeIndex&) -> absl::Status { @@ -91,12 +92,14 @@ absl::StatusOr> MakeArray( /*device_ordinal=*/0, client->backend().transfer_manager()->GetByteSizeRequirement( subshape))); - device_buffers.push_back(device_memory.Release()); + device_buffers.push_back( + RawSEDeviceMemory::Create(device_memory.Release(), device, + client->backend().memory_allocator())); return absl::OkStatus(); })); return std::make_shared( - client->backend().memory_allocator(), device, device_buffers, - absl::Span>(), nullptr); + device, device_buffers, + absl::Span>()); } TEST(TrackedDeviceBufferTest, AsShapedBuffer) { @@ -114,8 +117,8 @@ TEST(TrackedDeviceBufferTest, AsShapedBuffer) { ASSERT_EQ(b_buffer->device_memory().size(), 1); ASSERT_EQ(c_buffer->device_memory().size(), 1); std::vector expected_buffer_sequence = { - a_buffer->device_memory()[0], b_buffer->device_memory()[0], - c_buffer->device_memory()[0]}; + a_buffer->device_memory()[0]->mem(), b_buffer->device_memory()[0]->mem(), + c_buffer->device_memory()[0]->mem()}; ShapedBuffer shaped_a = a_buffer->AsShapedBuffer( client->backend().transfer_manager()->HostShapeToDeviceShape(a_shape)); ShapedBuffer shaped_b = b_buffer->AsShapedBuffer(