diff --git a/xla/backends/gpu/runtime/nccl_collective_permute_thunk.cc b/xla/backends/gpu/runtime/nccl_collective_permute_thunk.cc
index b777adcb43907..24685afa7ce73 100644
--- a/xla/backends/gpu/runtime/nccl_collective_permute_thunk.cc
+++ b/xla/backends/gpu/runtime/nccl_collective_permute_thunk.cc
@@ -177,14 +177,15 @@ absl::Status NcclCollectivePermuteStartThunk::Initialize(
   if (p2p_memcpy_enabled_) {
     TF_ASSIGN_OR_RETURN(const int64_t current_id,
                         GetCurrentId(params.collective_params, config_));
-    absl::MutexLock lock(&barrier_mutex_);
-    if (barrier_flags_.find(current_id) == barrier_flags_.end()) {
-      TF_ASSIGN_OR_RETURN(
-          std::unique_ptr<se::MemoryAllocation> alloc,
-          params.stream->parent()->HostMemoryAllocate(sizeof(uint8_t)));
-      barrier_flags_[current_id] = std::move(alloc);
+    {
+      absl::MutexLock lock(&barrier_mutex_);
+      if (receiver_barrier_events_.find(current_id) ==
+          receiver_barrier_events_.end()) {
+        TF_ASSIGN_OR_RETURN(auto receiver_event,
+                            params.executor->CreateEvent());
+        receiver_barrier_events_.emplace(current_id, std::move(receiver_event));
+      }
     }
-
     TF_ASSIGN_OR_RETURN(
         std::vector<DeviceBufferPair> device_buffers,
         ConvertToDeviceBuffers(params.buffer_allocations, {buffers_},
@@ -214,6 +215,18 @@ absl::Status NcclCollectivePermuteStartThunk::Initialize(
   return absl::OkStatus();
 }
 
+struct CallRendezvousKey {
+  RunId run_id;
+
+  template <typename H>
+  friend H AbslHashValue(H h, const CallRendezvousKey& key) {
+    return H::combine(std::move(h), key.run_id);
+  }
+};
+
+bool operator==(const CallRendezvousKey& a, const CallRendezvousKey& b) {
+  return a.run_id == b.run_id;
+}
 
 absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(
     const ExecuteParams& params, se::Stream& stream,
@@ -238,11 +251,34 @@ absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(
   TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params));
 
   if (use_memcpy) {
-    se::DeviceMemoryBase sync_var_address =
-        se::DeviceMemoryBase(barrier_flags_[current_id]->opaque());
-    TF_RETURN_IF_ERROR(comm_handle.comm->AllReduce(
-        sync_var_address, sync_var_address, PrimitiveType::U8, 1,
-        ReductionKind::MIN, GpuCollectives::On(stream)));
+    std::optional<int64_t> source_id = source_target.source;
+    std::optional<int64_t> target_id = source_target.target;
+    // Due to the one-sided push mechanism of memcpy p2p, we need to make
+    // sure the buffer on the receiving side is ready before the sender
+    // pushes the data. The receiving side records an event, and the sender
+    // waits for that event before pushing.
+    if (source_id) {
+      absl::MutexLock lock(&barrier_mutex_);
+      auto receiver_event = receiver_barrier_events_.find(current_id);
+      TF_RETURN_IF_ERROR(stream.RecordEvent(receiver_event->second.get()));
+    }
+    auto rendezvous_name = absl::StrFormat(
+        "rendezvous of collective-permute; run_id=%d; op id:%d",
+        params.collective_params->run_id.ToInt(), config_.config.op_id);
+    auto rendezvous_key = CallRendezvousKey{params.collective_params->run_id};
+
+    // Perform a rendezvous to make sure all receivers have their events
+    // recorded.
+    Rendezvous(rendezvous_name, rendezvous_key, device_count_,
+               /*warn_stuck_timeout=*/absl::Seconds(20),
+               /*terminate_timeout=*/absl::Seconds(40));
+
+    // On the sending side, wait for the recorded event from the receiving side.
+    if (target_id) {
+      absl::MutexLock lock(&barrier_mutex_);
+      auto receiver_event = receiver_barrier_events_.find(*target_id);
+      TF_RETURN_IF_ERROR(stream.WaitFor(receiver_event->second.get()));
+    }
   }
 
   return ::xla::gpu::RunCollectivePermute(
diff --git a/xla/backends/gpu/runtime/nccl_collective_permute_thunk.h b/xla/backends/gpu/runtime/nccl_collective_permute_thunk.h
index f91840fe1d76f..e6e1b6c8fbf78 100644
--- a/xla/backends/gpu/runtime/nccl_collective_permute_thunk.h
+++ b/xla/backends/gpu/runtime/nccl_collective_permute_thunk.h
@@ -121,8 +121,8 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {
   std::vector<Buffer> buffers_;
   RecvPtrMap recv_ptr_map_;
   absl::Mutex barrier_mutex_;
-  std::unordered_map<int64_t, std::unique_ptr<se::MemoryAllocation>>
-      barrier_flags_;
+  std::unordered_map<int64_t, std::unique_ptr<se::Event>>
+      receiver_barrier_events_;
   bool p2p_memcpy_enabled_ = false;
   int64_t device_count_;
 };
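
Note: the following standalone C++20 program is a host-only sketch of the record/rendezvous/wait protocol this patch introduces, not the thunk's actual code. ReadyEvent, device_body, and kNumDevices are hypothetical names; ReadyEvent stands in for se::Event, std::barrier stands in for xla::gpu::Rendezvous, printf stands in for the device-to-device push, and a ring permutation is assumed so that every device is both a sender and a receiver.

    // Sketch: each "device" marks its receive buffer ready, all devices
    // rendezvous, and only then does a sender consult its target's event.
    #include <barrier>
    #include <condition_variable>
    #include <cstdio>
    #include <mutex>
    #include <thread>
    #include <vector>

    constexpr int kNumDevices = 4;

    // Stand-in for se::Event: signaled once the receiver's buffer is ready.
    class ReadyEvent {
     public:
      void Record() {
        std::lock_guard<std::mutex> lock(mu_);
        recorded_ = true;
        cv_.notify_all();
      }
      void Wait() {
        std::unique_lock<std::mutex> lock(mu_);
        cv_.wait(lock, [this] { return recorded_; });
      }

     private:
      std::mutex mu_;
      std::condition_variable cv_;
      bool recorded_ = false;
    };

    int main() {
      std::barrier rendezvous(kNumDevices);         // stand-in for Rendezvous()
      std::vector<ReadyEvent> events(kNumDevices);  // one event per device id

      auto device_body = [&](int device_id) {
        int target_id = (device_id + 1) % kNumDevices;  // ring permutation
        // Receiving side: announce that the local receive buffer is ready
        // (RecordEvent in the patch).
        events[device_id].Record();
        // No sender may proceed until every receiver has recorded its event.
        rendezvous.arrive_and_wait();
        // Sending side: wait on the *target's* event, then push the data
        // (WaitFor in the patch, which orders the GPU stream, not the host).
        events[target_id].Wait();
        std::printf("device %d pushes to device %d\n", device_id, target_id);
      };

      std::vector<std::thread> devices;
      for (int i = 0; i < kNumDevices; ++i) devices.emplace_back(device_body, i);
      for (auto& t : devices) t.join();
      return 0;
    }

The design point the sketch illustrates: receiver_barrier_events_ is keyed by device id (current_id), so a sender can look up its target's event via *target_id, and the host-side rendezvous is what guarantees that entry has already been recorded before the sender's stream is told to wait on it.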