[XLA:GPU] Add support for multiple updates per replica in RaggedAllToAll.

A recent proposal suggested extending the RaggedAllToAll (ra2a) API to allow sending multiple updates in one op. Previously, we would have needed to emit multiple chained ra2a ops to achieve the same effect.

PiperOrigin-RevId: 726001087
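For intuition, here is a minimal sketch of the metadata layout the extended op assumes: the ragged metadata operands carry num_ranks * num_updates_per_replica entries grouped by destination peer, and update i for a given peer lives at flat index peer * num_updates_per_replica + i. The concrete numbers and the main() driver below are illustrative assumptions, not part of this change.

#include <cstdint>
#include <iostream>

// Flattened metadata index of the i-th update destined for `peer`, assuming
// the metadata arrays group all updates for peer 0 first, then peer 1, etc.
// This mirrors the `idx = peer * num_updates_per_replica + i` computation in
// the thunk changes below.
int64_t UpdateIndex(int64_t peer, int64_t i, int64_t num_updates_per_replica) {
  return peer * num_updates_per_replica + i;
}

int main() {
  // Hypothetical sizes: 4 ranks and 8 total updates (operand(2) has 8 rows),
  // so a single RaggedAllToAll now carries 2 updates per replica.
  const int64_t num_ranks = 4;
  const int64_t num_total_updates = 8;
  const int64_t num_updates_per_replica = num_total_updates / num_ranks;

  for (int64_t i = 0; i < num_updates_per_replica; ++i) {
    for (int64_t peer = 0; peer < num_ranks; ++peer) {
      std::cout << "update " << i << " for peer " << peer
                << " uses metadata slot "
                << UpdateIndex(peer, i, num_updates_per_replica) << "\n";
    }
  }
  return 0;
}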
olegshyshkov authored and Google-ML-Automation committed Feb 12, 2025
1 parent 105e240 commit 5268b93
Showing 3 changed files with 224 additions and 126 deletions.
120 changes: 70 additions & 50 deletions xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.cc
@@ -60,7 +60,9 @@ NcclRaggedAllToAllConfig GetNcclRaggedAllToAllConfig(
const HloRaggedAllToAllInstruction* instr) {
NcclRaggedAllToAllConfig config;
config.config = GetNcclCollectiveConfig(instr, std::nullopt);
config.num_ragged_rows = instr->operand(2)->shape().dimensions(0);

const Shape& input_size_shape = instr->operand(2)->shape();
config.num_total_updates = input_size_shape.dimensions(0);
config.ragged_row_element_size =
ShapeUtil::ElementsIn(instr->shape()) / instr->shape().dimensions(0);
return config;
@@ -166,7 +168,7 @@ absl::Status NcclRaggedAllToAllStartThunk::Initialize(
for (int64_t i = 0; i < kNumRaggedMetadataOperands; ++i) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<se::MemoryAllocation> alloc,
params.executor->HostMemoryAllocate(
config_.num_ragged_rows * sizeof(int64_t)));
config_.num_total_updates * sizeof(int64_t)));
allocs.push_back(std::move(alloc));
}
host_buffer_allocs_.emplace(params.executor, std::move(allocs));
@@ -175,7 +177,7 @@ absl::Status NcclRaggedAllToAllStartThunk::Initialize(
if (!device_buffer_allocs_.contains(params.executor)) {
se::DeviceMemoryHandle output_offsets_device_buffer{
params.executor,
params.executor->Allocate(config_.num_ragged_rows * sizeof(int64_t))};
params.executor->Allocate(config_.num_total_updates * sizeof(int64_t))};

if (output_offsets_device_buffer.memory().is_null()) {
return absl::InternalError("Failed to allocate output offsets buffer.");
@@ -275,14 +277,15 @@ absl::Status NcclRaggedAllToAllStartThunk::RunNcclCollective(
receive_pointer_maps_[stream.parent()]->opaque());
}
return xla::gpu::RunMemCpyRaggedAllToAll(
collectives, config_.ragged_row_element_size, device_buffers, stream,
comm_handle.comm, ragged_metadata_allocs, send_pointer,
receive_pointer_map);
collectives, config_.ragged_row_element_size, config_.num_total_updates,
device_buffers, stream, comm_handle.comm, ragged_metadata_allocs,
send_pointer, receive_pointer_map);
}

return xla::gpu::RunRaggedAllToAll(
collectives, config_.ragged_row_element_size, device_buffers, stream,
comm_handle.comm, ragged_metadata_allocs, output_offsets_device_buffer);
collectives, config_.ragged_row_element_size, config_.num_total_updates,
device_buffers, stream, comm_handle.comm, ragged_metadata_allocs,
output_offsets_device_buffer);
}

AsyncStreamKind NcclRaggedAllToAllStartThunk::GetAsyncStreamKind() const {
@@ -292,21 +295,26 @@ AsyncStreamKind NcclRaggedAllToAllStartThunk::GetAsyncStreamKind() const {
// Runs AllToAll on a buffer that contains ragged tensor metadata.
absl::Status RunAllToAllOnIndexBuffer(
GpuCollectives* collectives, const se::DeviceMemoryBase& source_buffer,
int64_t num_updates_per_replica,
const se::DeviceMemoryBase& destination_buffer, PrimitiveType element_type,
se::Stream& stream, Communicator* comm) {
TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks());

TF_RETURN_IF_ERROR(collectives->GroupStart());
for (int peer = 0; peer < num_ranks; ++peer) {
int64_t offset = peer * num_updates_per_replica;
se::DeviceMemoryBase send_slice = collectives->Slice(
source_buffer, element_type, /*offset=*/peer, /*count=*/1);
se::DeviceMemoryBase recv_slice = collectives->Slice(
destination_buffer, element_type, /*offset=*/peer, /*count=*/1);
source_buffer, element_type, offset, /*count=*/num_updates_per_replica);
se::DeviceMemoryBase recv_slice =
collectives->Slice(destination_buffer, element_type, offset,
/*count=*/num_updates_per_replica);

TF_RETURN_IF_ERROR(comm->Send(send_slice, element_type, /*count=*/1,
TF_RETURN_IF_ERROR(comm->Send(send_slice, element_type,
/*count=*/num_updates_per_replica,
RankId(peer), GpuCollectives::On(stream)));

TF_RETURN_IF_ERROR(comm->Recv(recv_slice, element_type, /*count=*/1,
TF_RETURN_IF_ERROR(comm->Recv(recv_slice, element_type,
/*count=*/num_updates_per_replica,
RankId(peer), GpuCollectives::On(stream)));
}

@@ -316,6 +324,7 @@ absl::Status RunAllToAllOnIndexBuffer(

absl::Status RunRaggedAllToAll(
GpuCollectives* collectives, int64_t ragged_row_element_size,
int64_t num_total_updates,
const std::vector<DeviceBufferPair>& original_buffers, se::Stream& stream,
Communicator* comm, const std::vector<int64_t*>& ragged_metadata_allocs,
const se::DeviceMemoryBase& output_offsets_device_buffer) {
@@ -329,6 +338,8 @@ absl::Status RunRaggedAllToAll(

std::vector<DeviceBufferPair> buffers = original_buffers;

int64_t num_updates_per_replica = num_total_updates / num_ranks;

// `output_offsets` of the RaggedAllToAll instruction are sharded in a way,
// that `output_offset[i]` is an offset in the i-th peer output buffer. To
// make it work for NCCL model with send/recv, we need to know offsets in the
Expand All @@ -337,8 +348,8 @@ absl::Status RunRaggedAllToAll(
DeviceBufferPair& output_offsets_buffer_pair = buffers[4];
TF_RETURN_IF_ERROR(RunAllToAllOnIndexBuffer(
collectives, output_offsets_buffer_pair.source_buffer,
output_offsets_device_buffer, output_offsets_buffer_pair.element_type,
stream, comm));
num_updates_per_replica, output_offsets_device_buffer,
output_offsets_buffer_pair.element_type, stream, comm));
output_offsets_buffer_pair.source_buffer = output_offsets_device_buffer;

TF_ASSIGN_OR_RETURN(
@@ -353,24 +364,27 @@
TF_RETURN_IF_ERROR(collectives->GroupStart());

const DeviceBufferPair& data_buffer = buffers[0];
for (int peer = 0; peer < num_ranks; ++peer) {
se::DeviceMemoryBase send_slice =
collectives->Slice(data_buffer.source_buffer, data_buffer.element_type,
input_offsets[peer] * ragged_row_element_size,
send_sizes[peer] * ragged_row_element_size);

se::DeviceMemoryBase recv_slice = collectives->Slice(
data_buffer.destination_buffer, data_buffer.element_type,
output_offsets[peer] * ragged_row_element_size,
recv_sizes[peer] * ragged_row_element_size);

TF_RETURN_IF_ERROR(comm->Send(send_slice, data_buffer.element_type,
send_sizes[peer] * ragged_row_element_size,
RankId(peer), GpuCollectives::On(stream)));

TF_RETURN_IF_ERROR(comm->Recv(recv_slice, data_buffer.element_type,
recv_sizes[peer] * ragged_row_element_size,
RankId(peer), GpuCollectives::On(stream)));
for (int64_t i = 0; i < num_updates_per_replica; ++i) {
for (int peer = 0; peer < num_ranks; ++peer) {
int64_t idx = peer * num_updates_per_replica + i;
se::DeviceMemoryBase send_slice = collectives->Slice(
data_buffer.source_buffer, data_buffer.element_type,
input_offsets[idx] * ragged_row_element_size,
send_sizes[idx] * ragged_row_element_size);

se::DeviceMemoryBase recv_slice = collectives->Slice(
data_buffer.destination_buffer, data_buffer.element_type,
output_offsets[idx] * ragged_row_element_size,
recv_sizes[idx] * ragged_row_element_size);

TF_RETURN_IF_ERROR(comm->Send(send_slice, data_buffer.element_type,
send_sizes[idx] * ragged_row_element_size,
RankId(peer), GpuCollectives::On(stream)));

TF_RETURN_IF_ERROR(comm->Recv(recv_slice, data_buffer.element_type,
recv_sizes[idx] * ragged_row_element_size,
RankId(peer), GpuCollectives::On(stream)));
}
}

return collectives->GroupEnd();
@@ -380,9 +394,10 @@ absl::Status RunRaggedAllToAll(
// NcclCommunicator implementation.
absl::Status RunMemCpyRaggedAllToAll(
GpuCollectives* collectives, int64_t ragged_row_element_size,
const std::vector<DeviceBufferPair>& buffers, se::Stream& stream,
Communicator* comm, const std::vector<int64_t*>& ragged_metadata_allocs,
uint64_t* send_pointer, uint64_t receive_pointer_map[]) {
int64_t num_total_updates, const std::vector<DeviceBufferPair>& buffers,
se::Stream& stream, Communicator* comm,
const std::vector<int64_t*>& ragged_metadata_allocs, uint64_t* send_pointer,
uint64_t receive_pointer_map[]) {
int device_ordinal = stream.parent()->device_ordinal();
VLOG(3) << "Performing mem-copy-ragged-all-to-all from device ordinal: "
<< device_ordinal;
@@ -408,25 +423,30 @@ absl::Status RunMemCpyRaggedAllToAll(
std::vector<IntegerOperandData> ragged_metadata,
LoadRaggedTensorMetadata(stream, buffers, ragged_metadata_allocs));

int64_t num_updates_per_replica = num_total_updates / num_ranks;

const IntegerOperandData& input_offsets = ragged_metadata[0];
const IntegerOperandData& send_sizes = ragged_metadata[1];
const IntegerOperandData& output_offsets = ragged_metadata[2];

// Transfer a slice of data to each peer's output buffer.
for (int peer = 0; peer < num_ranks; ++peer) {
se::DeviceMemoryBase send_slice =
collectives->Slice(data_buffer.source_buffer, data_buffer.element_type,
input_offsets[peer] * ragged_row_element_size,
send_sizes[peer] * ragged_row_element_size);
se::DeviceMemoryBase base_dst_addr =
se::DeviceMemoryBase(reinterpret_cast<void*>(receive_pointer_map[peer]),
data_buffer.destination_buffer.size());
se::DeviceMemoryBase dst_slice =
collectives->Slice(base_dst_addr, data_buffer.element_type,
output_offsets[peer] * ragged_row_element_size,
send_sizes[peer] * ragged_row_element_size);
TF_RETURN_IF_ERROR(
stream.MemcpyD2D(&dst_slice, send_slice, send_slice.size()));
for (int64_t i = 0; i < num_updates_per_replica; ++i) {
for (int peer = 0; peer < num_ranks; ++peer) {
int64_t idx = peer * num_updates_per_replica + i;
se::DeviceMemoryBase send_slice = collectives->Slice(
data_buffer.source_buffer, data_buffer.element_type,
input_offsets[idx] * ragged_row_element_size,
send_sizes[idx] * ragged_row_element_size);
se::DeviceMemoryBase base_dst_addr = se::DeviceMemoryBase(
reinterpret_cast<void*>(receive_pointer_map[peer]),
data_buffer.destination_buffer.size());
se::DeviceMemoryBase dst_slice =
collectives->Slice(base_dst_addr, data_buffer.element_type,
output_offsets[idx] * ragged_row_element_size,
send_sizes[idx] * ragged_row_element_size);
TF_RETURN_IF_ERROR(
stream.MemcpyD2D(&dst_slice, send_slice, send_slice.size()));
}
}

return absl::OkStatus();
14 changes: 8 additions & 6 deletions xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.h
@@ -41,7 +41,7 @@ namespace gpu {

struct NcclRaggedAllToAllConfig {
NcclCollectiveConfig config;
int64_t num_ragged_rows = 1;
int64_t num_total_updates = 1;
int64_t ragged_row_element_size = 1;
};

@@ -112,15 +112,17 @@ class NcclRaggedAllToAllStartThunk : public NcclCollectiveThunk {

absl::Status RunRaggedAllToAll(
GpuCollectives* collectives, int64_t ragged_row_element_size,
const std::vector<DeviceBufferPair>& buffers, se::Stream& stream,
Communicator* comm, const std::vector<int64_t*>& ragged_metadata_allocs,
int64_t num_total_updates, const std::vector<DeviceBufferPair>& buffers,
se::Stream& stream, Communicator* comm,
const std::vector<int64_t*>& ragged_metadata_allocs,
const se::DeviceMemoryBase& output_offsets_device_buffer);

absl::Status RunMemCpyRaggedAllToAll(
GpuCollectives* collectives, int64_t ragged_row_element_size,
const std::vector<DeviceBufferPair>& buffers, se::Stream& stream,
Communicator* comm, const std::vector<int64_t*>& ragged_metadata_allocs,
uint64_t* send_pointer, uint64_t receive_pointer_map[]);
int64_t num_total_updates, const std::vector<DeviceBufferPair>& buffers,
se::Stream& stream, Communicator* comm,
const std::vector<int64_t*>& ragged_metadata_allocs, uint64_t* send_pointer,
uint64_t receive_pointer_map[]);

} // namespace gpu
} // namespace xla
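To sanity-check the indexing end to end, the following host-only reference model (a sketch, not XLA code) mirrors the mem-copy path above: for every sender rank, update slot idx = peer * num_updates_per_replica + i copies send_sizes[idx] ragged rows starting at input_offsets[idx] of the sender's input into the peer's output at output_offsets[idx]. All buffer contents, sizes, and offsets in main() are made-up example values.

#include <cstdint>
#include <iostream>
#include <vector>

// Host-only reference model of a ragged all-to-all with multiple updates per
// replica. inputs[r] / outputs[r] are the flat buffers of rank r, measured in
// ragged rows of `row_elems` elements. Each metadata vector holds
// num_ranks * num_updates_per_replica entries, grouped by destination peer.
void ReferenceRaggedAllToAll(
    const std::vector<std::vector<float>>& inputs,
    std::vector<std::vector<float>>& outputs,
    const std::vector<std::vector<int64_t>>& input_offsets,
    const std::vector<std::vector<int64_t>>& send_sizes,
    const std::vector<std::vector<int64_t>>& output_offsets,
    int64_t num_updates_per_replica, int64_t row_elems) {
  const int64_t num_ranks = static_cast<int64_t>(inputs.size());
  for (int64_t rank = 0; rank < num_ranks; ++rank) {
    for (int64_t i = 0; i < num_updates_per_replica; ++i) {
      for (int64_t peer = 0; peer < num_ranks; ++peer) {
        const int64_t idx = peer * num_updates_per_replica + i;
        const int64_t src = input_offsets[rank][idx] * row_elems;
        const int64_t dst = output_offsets[rank][idx] * row_elems;
        const int64_t count = send_sizes[rank][idx] * row_elems;
        for (int64_t k = 0; k < count; ++k) {
          outputs[peer][dst + k] = inputs[rank][src + k];
        }
      }
    }
  }
}

int main() {
  // Made-up example: 2 ranks, 2 updates per replica, 1 element per ragged row.
  const int64_t num_updates_per_replica = 2;
  const int64_t row_elems = 1;
  std::vector<std::vector<float>> inputs = {{0.f, 1.f, 2.f, 3.f},
                                            {10.f, 11.f, 12.f, 13.f}};
  std::vector<std::vector<float>> outputs(2, std::vector<float>(4, -1.0f));
  // Metadata slots are ordered [peer0/upd0, peer0/upd1, peer1/upd0, peer1/upd1].
  std::vector<std::vector<int64_t>> input_offsets = {{0, 1, 2, 3}, {0, 1, 2, 3}};
  std::vector<std::vector<int64_t>> send_sizes = {{1, 1, 1, 1}, {1, 1, 1, 1}};
  std::vector<std::vector<int64_t>> output_offsets = {{0, 1, 0, 1}, {2, 3, 2, 3}};

  ReferenceRaggedAllToAll(inputs, outputs, input_offsets, send_sizes,
                          output_offsets, num_updates_per_replica, row_elems);

  // Expected: outputs[0] = {0, 1, 10, 11}, outputs[1] = {2, 3, 12, 13}.
  for (const auto& out : outputs) {
    for (float v : out) std::cout << v << " ";
    std::cout << "\n";
  }
  return 0;
}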