Commit

PR #22588: Use cuda event and Rendezvous instead of nccl allreduce as a barrier

Imported from GitHub PR #22588

Copybara import of the project:

--
5acbea5 by TJ Xu <[email protected]>:

Use cuda event and Rendezvous instead of nccl allreduce as a barrier

--
3d78e81 by TJ Xu <[email protected]>:

Improve comment for the motivation

Merging this change closes #22588

COPYBARA_INTEGRATE_REVIEW=#22588 from Tixxx:tixxx/event_barrier 3d78e81
PiperOrigin-RevId: 726643328
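The change swaps a device-side NCCL AllReduce, previously issued only as a barrier, for a per-receiver CUDA event plus a host-side rendezvous. As a standalone illustration of the event half of that pattern (this is not XLA code; the streams, kernels, and buffer names below are made up), an event recorded on one stream and waited on by another orders the two streams without any collective:

```cpp
// Minimal sketch, assuming plain CUDA runtime APIs: a "consumer" stream waits
// for an event recorded on a "producer" stream, the same primitive the thunk
// now uses per receiver instead of an allreduce barrier.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void produce(int* buf) { buf[threadIdx.x] = threadIdx.x; }
__global__ void consume(const int* buf, int* out) {
  out[threadIdx.x] = buf[threadIdx.x] * 2;
}

int main() {
  int *buf, *out;
  cudaMalloc(&buf, 32 * sizeof(int));
  cudaMalloc(&out, 32 * sizeof(int));

  cudaStream_t producer, consumer;
  cudaStreamCreate(&producer);
  cudaStreamCreate(&consumer);

  cudaEvent_t ready;
  cudaEventCreateWithFlags(&ready, cudaEventDisableTiming);

  // Producer fills the buffer and records "ready" on its stream.
  produce<<<1, 32, 0, producer>>>(buf);
  cudaEventRecord(ready, producer);

  // Consumer waits on the event instead of a device-wide barrier, then reads.
  cudaStreamWaitEvent(consumer, ready, 0);
  consume<<<1, 32, 0, consumer>>>(buf, out);

  cudaStreamSynchronize(consumer);
  std::printf("done\n");

  cudaEventDestroy(ready);
  cudaStreamDestroy(producer);
  cudaStreamDestroy(consumer);
  cudaFree(buf);
  cudaFree(out);
  return 0;
}
```

The event wait is scoped to the one producer/consumer pair that actually needs ordering, which is why it can stand in for an allreduce that synchronized every rank.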
Tixxx authored and Google-ML-Automation committed Feb 13, 2025
1 parent 4b352b7 commit 405efd8
Showing 2 changed files with 50 additions and 14 deletions.
60 changes: 48 additions & 12 deletions xla/backends/gpu/runtime/nccl_collective_permute_thunk.cc
@@ -177,14 +177,15 @@ absl::Status NcclCollectivePermuteStartThunk::Initialize(
  if (p2p_memcpy_enabled_) {
    TF_ASSIGN_OR_RETURN(const int64_t current_id,
                        GetCurrentId(params.collective_params, config_));
    absl::MutexLock lock(&barrier_mutex_);
    if (barrier_flags_.find(current_id) == barrier_flags_.end()) {
      TF_ASSIGN_OR_RETURN(
          std::unique_ptr<se::MemoryAllocation> alloc,
          params.stream->parent()->HostMemoryAllocate(sizeof(uint8_t)));
      barrier_flags_[current_id] = std::move(alloc);
    {
      absl::MutexLock lock(&barrier_mutex_);
      if (receiver_barrier_events_.find(current_id) ==
          receiver_barrier_events_.end()) {
        TF_ASSIGN_OR_RETURN(auto receiver_event,
                            params.executor->CreateEvent());
        receiver_barrier_events_.emplace(current_id, std::move(receiver_event));
      }
    }

    TF_ASSIGN_OR_RETURN(
        std::vector<DeviceBufferPair> device_buffers,
        ConvertToDeviceBuffers(params.buffer_allocations, {buffers_},
@@ -214,6 +215,18 @@ absl::Status NcclCollectivePermuteStartThunk::Initialize(

  return absl::OkStatus();
}
struct CallRendezvousKey {
  RunId run_id;

  template <typename H>
  friend H AbslHashValue(H h, const CallRendezvousKey& key) {
    return H::combine(std::move(h), key.run_id);
  }
};

bool operator==(const CallRendezvousKey& a, const CallRendezvousKey& b) {
  return a.run_id == b.run_id;
}

absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(
    const ExecuteParams& params, se::Stream& stream,
@@ -238,11 +251,34 @@ absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(

  TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params));
  if (use_memcpy) {
    se::DeviceMemoryBase sync_var_address =
        se::DeviceMemoryBase(barrier_flags_[current_id]->opaque());
    TF_RETURN_IF_ERROR(comm_handle.comm->AllReduce(
        sync_var_address, sync_var_address, PrimitiveType::U8, 1,
        ReductionKind::MIN, GpuCollectives::On(stream)));
    std::optional<int64_t> source_id = source_target.source;
    std::optional<int64_t> target_id = source_target.target;
    // Due to the one-sided push mechanism of memcpy p2p, we need to make sure
    // the buffer on the receiving side is ready before sender pushes the data.
    // Receiving side will record an event and the sender will wait for the
    // event before proceeding.
    if (source_id) {
      absl::MutexLock lock(&barrier_mutex_);
      auto receiver_event = receiver_barrier_events_.find(current_id);
      TF_RETURN_IF_ERROR(stream.RecordEvent(receiver_event->second.get()));
    }
    auto rendezvous_name = absl::StrFormat(
        "rendezvous of collective-permute; run_id=%d; op id:%d",
        params.collective_params->run_id.ToInt(), config_.config.op_id);
    auto rendezvous_key = CallRendezvousKey{params.collective_params->run_id};

    // Perform a rendezvous to make sure all receivers have their events
    // recorded.
    Rendezvous(rendezvous_name, rendezvous_key, device_count_,
               /*warn_stuck_timeout=*/absl::Seconds(20),
               /*terminate_timeout=*/absl::Seconds(40));

    // For sending side, wait for the recorded event from the receiving side.
    if (target_id) {
      absl::MutexLock lock(&barrier_mutex_);
      auto receiver_event = receiver_barrier_events_.find(*target_id);
      TF_RETURN_IF_ERROR(stream.WaitFor(receiver_event->second.get()));
    }
  }

  return ::xla::gpu::RunCollectivePermute(
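The comment in the hunk above describes the host-side protocol: each rank that will receive data records an event on its stream, all ranks meet at a rendezvous so every event is known to be recorded, and each rank that sends then waits on its target's event before pushing. A minimal model of that control flow using standard C++ threads (one thread per rank; a std::latch standing in for the Rendezvous call and a condition variable standing in for the recorded se::Event; the ring topology and every name here are illustrative, not XLA code) might look like:

```cpp
// Sketch of the record -> rendezvous -> wait-on-target barrier, modeled on the
// host with threads instead of GPU streams and CUDA events.
#include <condition_variable>
#include <cstdio>
#include <latch>
#include <mutex>
#include <thread>
#include <vector>

constexpr int kNumRanks = 4;

struct ReadySignal {  // stand-in for the per-receiver se::Event
  std::mutex mu;
  std::condition_variable cv;
  bool recorded = false;
};

int main() {
  std::vector<ReadySignal> ready(kNumRanks);
  std::latch rendezvous(kNumRanks);  // stand-in for the Rendezvous call
  std::vector<std::thread> ranks;

  for (int rank = 0; rank < kNumRanks; ++rank) {
    ranks.emplace_back([&, rank] {
      const int target = (rank + 1) % kNumRanks;  // ring: push to the next rank

      // Receiver side: signal that this rank's buffer is ready
      // (the thunk records a CUDA event on its stream here instead).
      {
        std::lock_guard<std::mutex> lock(ready[rank].mu);
        ready[rank].recorded = true;
      }
      ready[rank].cv.notify_all();

      // Rendezvous: no rank proceeds until every rank has signalled readiness.
      rendezvous.arrive_and_wait();

      // Sender side: wait for the *target's* readiness before pushing data
      // (the thunk makes its stream wait on the target's event instead).
      {
        std::unique_lock<std::mutex> lock(ready[target].mu);
        ready[target].cv.wait(lock, [&] { return ready[target].recorded; });
      }
      std::printf("rank %d may now push to rank %d\n", rank, target);
    });
  }
  for (auto& t : ranks) t.join();
  return 0;
}
```

In the real thunk the rendezvous only guarantees the event has been recorded; the actual ordering of the data push happens on the GPU streams via the event wait.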
4 changes: 2 additions & 2 deletions xla/backends/gpu/runtime/nccl_collective_permute_thunk.h
@@ -121,8 +121,8 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {
  std::vector<Buffer> buffers_;
  RecvPtrMap recv_ptr_map_;
  absl::Mutex barrier_mutex_;
  std::unordered_map<int64_t, std::unique_ptr<se::MemoryAllocation>>
      barrier_flags_;
  std::unordered_map<int64_t, std::unique_ptr<se::Event>>
      receiver_barrier_events_;
  bool p2p_memcpy_enabled_ = false;
  int64_t device_count_;
};
