From 5268b93b546139d2fb863637601c582cb8382c3a Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Wed, 12 Feb 2025 04:29:53 -0800 Subject: [PATCH] [XLA:GPU] Add support for multiple updates per replica in RaggedAllToAll. A recent proposal suggested to extend the API of ra2a to allow to send multiple updates in one op. Before we would need to emit multiple chained ra2a to achieve the same effect. PiperOrigin-RevId: 726001087 --- .../runtime/nccl_ragged_all_to_all_thunk.cc | 120 ++++++---- .../runtime/nccl_ragged_all_to_all_thunk.h | 14 +- xla/tests/collective_ops_e2e_test.cc | 216 ++++++++++++------ 3 files changed, 224 insertions(+), 126 deletions(-) diff --git a/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.cc b/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.cc index 85a12564b7eee..85e61dadcb7d7 100644 --- a/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.cc +++ b/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.cc @@ -60,7 +60,9 @@ NcclRaggedAllToAllConfig GetNcclRaggedAllToAllConfig( const HloRaggedAllToAllInstruction* instr) { NcclRaggedAllToAllConfig config; config.config = GetNcclCollectiveConfig(instr, std::nullopt); - config.num_ragged_rows = instr->operand(2)->shape().dimensions(0); + + const Shape& input_size_shape = instr->operand(2)->shape(); + config.num_total_updates = input_size_shape.dimensions(0); config.ragged_row_element_size = ShapeUtil::ElementsIn(instr->shape()) / instr->shape().dimensions(0); return config; @@ -166,7 +168,7 @@ absl::Status NcclRaggedAllToAllStartThunk::Initialize( for (int64_t i = 0; i < kNumRaggedMetadataOperands; ++i) { TF_ASSIGN_OR_RETURN(std::unique_ptr alloc, params.executor->HostMemoryAllocate( - config_.num_ragged_rows * sizeof(int64_t))); + config_.num_total_updates * sizeof(int64_t))); allocs.push_back(std::move(alloc)); } host_buffer_allocs_.emplace(params.executor, std::move(allocs)); @@ -175,7 +177,7 @@ absl::Status NcclRaggedAllToAllStartThunk::Initialize( if (!device_buffer_allocs_.contains(params.executor)) { se::DeviceMemoryHandle output_offsets_device_buffer{ params.executor, - params.executor->Allocate(config_.num_ragged_rows * sizeof(int64_t))}; + params.executor->Allocate(config_.num_total_updates * sizeof(int64_t))}; if (output_offsets_device_buffer.memory().is_null()) { return absl::InternalError("Failed to allocate output offsets buffer."); @@ -275,14 +277,15 @@ absl::Status NcclRaggedAllToAllStartThunk::RunNcclCollective( receive_pointer_maps_[stream.parent()]->opaque()); } return xla::gpu::RunMemCpyRaggedAllToAll( - collectives, config_.ragged_row_element_size, device_buffers, stream, - comm_handle.comm, ragged_metadata_allocs, send_pointer, - receive_pointer_map); + collectives, config_.ragged_row_element_size, config_.num_total_updates, + device_buffers, stream, comm_handle.comm, ragged_metadata_allocs, + send_pointer, receive_pointer_map); } return xla::gpu::RunRaggedAllToAll( - collectives, config_.ragged_row_element_size, device_buffers, stream, - comm_handle.comm, ragged_metadata_allocs, output_offsets_device_buffer); + collectives, config_.ragged_row_element_size, config_.num_total_updates, + device_buffers, stream, comm_handle.comm, ragged_metadata_allocs, + output_offsets_device_buffer); } AsyncStreamKind NcclRaggedAllToAllStartThunk::GetAsyncStreamKind() const { @@ -292,21 +295,26 @@ AsyncStreamKind NcclRaggedAllToAllStartThunk::GetAsyncStreamKind() const { // Runs AllToAll on a buffer that contains ragged tensor metadata. absl::Status RunAllToAllOnIndexBuffer( GpuCollectives* collectives, const se::DeviceMemoryBase& source_buffer, + int64_t num_updates_per_replica, const se::DeviceMemoryBase& destination_buffer, PrimitiveType element_type, se::Stream& stream, Communicator* comm) { TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks()); TF_RETURN_IF_ERROR(collectives->GroupStart()); for (int peer = 0; peer < num_ranks; ++peer) { + int64_t offset = peer * num_updates_per_replica; se::DeviceMemoryBase send_slice = collectives->Slice( - source_buffer, element_type, /*offset=*/peer, /*count=*/1); - se::DeviceMemoryBase recv_slice = collectives->Slice( - destination_buffer, element_type, /*offset=*/peer, /*count=*/1); + source_buffer, element_type, offset, /*count=*/num_updates_per_replica); + se::DeviceMemoryBase recv_slice = + collectives->Slice(destination_buffer, element_type, offset, + /*count=*/num_updates_per_replica); - TF_RETURN_IF_ERROR(comm->Send(send_slice, element_type, /*count=*/1, + TF_RETURN_IF_ERROR(comm->Send(send_slice, element_type, + /*count=*/num_updates_per_replica, RankId(peer), GpuCollectives::On(stream))); - TF_RETURN_IF_ERROR(comm->Recv(recv_slice, element_type, /*count=*/1, + TF_RETURN_IF_ERROR(comm->Recv(recv_slice, element_type, + /*count=*/num_updates_per_replica, RankId(peer), GpuCollectives::On(stream))); } @@ -316,6 +324,7 @@ absl::Status RunAllToAllOnIndexBuffer( absl::Status RunRaggedAllToAll( GpuCollectives* collectives, int64_t ragged_row_element_size, + int64_t num_total_updates, const std::vector& original_buffers, se::Stream& stream, Communicator* comm, const std::vector& ragged_metadata_allocs, const se::DeviceMemoryBase& output_offsets_device_buffer) { @@ -329,6 +338,8 @@ absl::Status RunRaggedAllToAll( std::vector buffers = original_buffers; + int64_t num_updates_per_replica = num_total_updates / num_ranks; + // `output_offsets` of the RaggedAllToAll instruction are sharded in a way, // that `output_offset[i]` is an offset in the i-th peer output buffer. To // make it work for NCCL model with send/recv, we need to know offsets in the @@ -337,8 +348,8 @@ absl::Status RunRaggedAllToAll( DeviceBufferPair& output_offsets_buffer_pair = buffers[4]; TF_RETURN_IF_ERROR(RunAllToAllOnIndexBuffer( collectives, output_offsets_buffer_pair.source_buffer, - output_offsets_device_buffer, output_offsets_buffer_pair.element_type, - stream, comm)); + num_updates_per_replica, output_offsets_device_buffer, + output_offsets_buffer_pair.element_type, stream, comm)); output_offsets_buffer_pair.source_buffer = output_offsets_device_buffer; TF_ASSIGN_OR_RETURN( @@ -353,24 +364,27 @@ absl::Status RunRaggedAllToAll( TF_RETURN_IF_ERROR(collectives->GroupStart()); const DeviceBufferPair& data_buffer = buffers[0]; - for (int peer = 0; peer < num_ranks; ++peer) { - se::DeviceMemoryBase send_slice = - collectives->Slice(data_buffer.source_buffer, data_buffer.element_type, - input_offsets[peer] * ragged_row_element_size, - send_sizes[peer] * ragged_row_element_size); - - se::DeviceMemoryBase recv_slice = collectives->Slice( - data_buffer.destination_buffer, data_buffer.element_type, - output_offsets[peer] * ragged_row_element_size, - recv_sizes[peer] * ragged_row_element_size); - - TF_RETURN_IF_ERROR(comm->Send(send_slice, data_buffer.element_type, - send_sizes[peer] * ragged_row_element_size, - RankId(peer), GpuCollectives::On(stream))); - - TF_RETURN_IF_ERROR(comm->Recv(recv_slice, data_buffer.element_type, - recv_sizes[peer] * ragged_row_element_size, - RankId(peer), GpuCollectives::On(stream))); + for (int64_t i = 0; i < num_updates_per_replica; ++i) { + for (int peer = 0; peer < num_ranks; ++peer) { + int64_t idx = peer * num_updates_per_replica + i; + se::DeviceMemoryBase send_slice = collectives->Slice( + data_buffer.source_buffer, data_buffer.element_type, + input_offsets[idx] * ragged_row_element_size, + send_sizes[idx] * ragged_row_element_size); + + se::DeviceMemoryBase recv_slice = collectives->Slice( + data_buffer.destination_buffer, data_buffer.element_type, + output_offsets[idx] * ragged_row_element_size, + recv_sizes[idx] * ragged_row_element_size); + + TF_RETURN_IF_ERROR(comm->Send(send_slice, data_buffer.element_type, + send_sizes[idx] * ragged_row_element_size, + RankId(peer), GpuCollectives::On(stream))); + + TF_RETURN_IF_ERROR(comm->Recv(recv_slice, data_buffer.element_type, + recv_sizes[idx] * ragged_row_element_size, + RankId(peer), GpuCollectives::On(stream))); + } } return collectives->GroupEnd(); @@ -380,9 +394,10 @@ absl::Status RunRaggedAllToAll( // NcclCommunicator implementation. absl::Status RunMemCpyRaggedAllToAll( GpuCollectives* collectives, int64_t ragged_row_element_size, - const std::vector& buffers, se::Stream& stream, - Communicator* comm, const std::vector& ragged_metadata_allocs, - uint64_t* send_pointer, uint64_t receive_pointer_map[]) { + int64_t num_total_updates, const std::vector& buffers, + se::Stream& stream, Communicator* comm, + const std::vector& ragged_metadata_allocs, uint64_t* send_pointer, + uint64_t receive_pointer_map[]) { int device_ordinal = stream.parent()->device_ordinal(); VLOG(3) << "Performing mem-copy-ragged-all-to-all from device ordinal: " << device_ordinal; @@ -408,25 +423,30 @@ absl::Status RunMemCpyRaggedAllToAll( std::vector ragged_metadata, LoadRaggedTensorMetadata(stream, buffers, ragged_metadata_allocs)); + int64_t num_updates_per_replica = num_total_updates / num_ranks; + const IntegerOperandData& input_offsets = ragged_metadata[0]; const IntegerOperandData& send_sizes = ragged_metadata[1]; const IntegerOperandData& output_offsets = ragged_metadata[2]; // Transfer a slice of data to each peer's output buffer. - for (int peer = 0; peer < num_ranks; ++peer) { - se::DeviceMemoryBase send_slice = - collectives->Slice(data_buffer.source_buffer, data_buffer.element_type, - input_offsets[peer] * ragged_row_element_size, - send_sizes[peer] * ragged_row_element_size); - se::DeviceMemoryBase base_dst_addr = - se::DeviceMemoryBase(reinterpret_cast(receive_pointer_map[peer]), - data_buffer.destination_buffer.size()); - se::DeviceMemoryBase dst_slice = - collectives->Slice(base_dst_addr, data_buffer.element_type, - output_offsets[peer] * ragged_row_element_size, - send_sizes[peer] * ragged_row_element_size); - TF_RETURN_IF_ERROR( - stream.MemcpyD2D(&dst_slice, send_slice, send_slice.size())); + for (int64_t i = 0; i < num_updates_per_replica; ++i) { + for (int peer = 0; peer < num_ranks; ++peer) { + int64_t idx = peer * num_updates_per_replica + i; + se::DeviceMemoryBase send_slice = collectives->Slice( + data_buffer.source_buffer, data_buffer.element_type, + input_offsets[idx] * ragged_row_element_size, + send_sizes[idx] * ragged_row_element_size); + se::DeviceMemoryBase base_dst_addr = se::DeviceMemoryBase( + reinterpret_cast(receive_pointer_map[peer]), + data_buffer.destination_buffer.size()); + se::DeviceMemoryBase dst_slice = + collectives->Slice(base_dst_addr, data_buffer.element_type, + output_offsets[idx] * ragged_row_element_size, + send_sizes[idx] * ragged_row_element_size); + TF_RETURN_IF_ERROR( + stream.MemcpyD2D(&dst_slice, send_slice, send_slice.size())); + } } return absl::OkStatus(); diff --git a/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.h b/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.h index 337d0110cc746..984101ebe11d3 100644 --- a/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.h +++ b/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.h @@ -41,7 +41,7 @@ namespace gpu { struct NcclRaggedAllToAllConfig { NcclCollectiveConfig config; - int64_t num_ragged_rows = 1; + int64_t num_total_updates = 1; int64_t ragged_row_element_size = 1; }; @@ -112,15 +112,17 @@ class NcclRaggedAllToAllStartThunk : public NcclCollectiveThunk { absl::Status RunRaggedAllToAll( GpuCollectives* collectives, int64_t ragged_row_element_size, - const std::vector& buffers, se::Stream& stream, - Communicator* comm, const std::vector& ragged_metadata_allocs, + int64_t num_total_updates, const std::vector& buffers, + se::Stream& stream, Communicator* comm, + const std::vector& ragged_metadata_allocs, const se::DeviceMemoryBase& output_offsets_device_buffer); absl::Status RunMemCpyRaggedAllToAll( GpuCollectives* collectives, int64_t ragged_row_element_size, - const std::vector& buffers, se::Stream& stream, - Communicator* comm, const std::vector& ragged_metadata_allocs, - uint64_t* send_pointer, uint64_t receive_pointer_map[]); + int64_t num_total_updates, const std::vector& buffers, + se::Stream& stream, Communicator* comm, + const std::vector& ragged_metadata_allocs, uint64_t* send_pointer, + uint64_t receive_pointer_map[]); } // namespace gpu } // namespace xla diff --git a/xla/tests/collective_ops_e2e_test.cc b/xla/tests/collective_ops_e2e_test.cc index 09726651ffeb5..a4ead70013c2b 100644 --- a/xla/tests/collective_ops_e2e_test.cc +++ b/xla/tests/collective_ops_e2e_test.cc @@ -1883,13 +1883,15 @@ class RaggedAllToAllTest : public AsyncMemcpyCollectiveOps { // the test data we need to know the sizes of all ragged rows for each // replica. // - // `input_sizes` is a 2D array of shape [num_replicas, num_replicas]. - // `input_sizes[i, j]` is the number of elements in the j-th ragged row of the - // i-th replica input. + // `input_sizes` is an array of shape [num_replicas, num_replicas, + // num_updates_per_replica]. For concenivence, `input_sizes` can be a 2D + // array, in that case `num_updates_per_replica` is assumed to be 1. template - void CreateRandomTestData(HloModule* module, - const Array& input_sizes) { + void CreateRandomTestData(HloModule* module, Array input_sizes) { CHECK(inputs_.empty()); + if (input_sizes.num_dimensions() == 2) { + input_sizes.Reshape({input_sizes.dim(0), input_sizes.dim(1), 1}); + } auto ragged_all_to_all = FindInstruction(module, HloOpcode::kRaggedAllToAll); EXPECT_THAT(ragged_all_to_all, NotNull()); @@ -1899,75 +1901,35 @@ class RaggedAllToAllTest : public AsyncMemcpyCollectiveOps { ragged_all_to_all->shape().dimensions().begin(), ragged_all_to_all->shape().dimensions().end()}; - int64_t num_replicas = input_sizes.dim(0); + Array output_sizes = input_sizes; + output_sizes.TransposeDimensions({1, 0, 2}); + Array input_offsets = CalculateOffsetsFromSizes(input_sizes); + Array output_offsets = CalculateOffsetsFromSizes(output_sizes); + output_offsets.TransposeDimensions({1, 0, 2}); + + int64_t num_replicas = input_sizes.dim(0); std::vector> input_data(num_replicas, Array(ragged_tensor_sizes)); std::vector> output_data(num_replicas, Array(ragged_tensor_sizes)); - - Array output_sizes = input_sizes; - output_sizes.TransposeDimensions({1, 0}); - - // Computes ragged tensor offsets based on the sizes of the ragged rows. - auto get_offsets = [&](const Array& sizes) { - Array offsets(sizes.dimensions()); - for (int i = 0; i < num_replicas; ++i) { - for (int j = 1; j < num_replicas; ++j) { - offsets(i, j) = offsets(i, j - 1) + sizes(i, j - 1); - } - } - return offsets; - }; - - Array input_offsets = get_offsets(input_sizes); - Array output_offsets = get_offsets(output_sizes); - output_offsets.TransposeDimensions({1, 0}); - - std::vector chunk_sizes{ragged_tensor_sizes.begin(), - ragged_tensor_sizes.end()}; - - // Fill the input and output tensors with random data. An all-to-all is - // effective a transpose. We generate a chunk of random data for each pair - // of replicas and write the chunk starting from the (i, j) offset of the - // input tensor and starting from the (j, i) offset of the output tensor. - std::vector start_indices(ragged_tensor_sizes.size()); - for (int i = 0; i < num_replicas; ++i) { - for (int j = 0; j < num_replicas; ++j) { - chunk_sizes[0] = input_sizes(i, j); - - Array chunk_data(chunk_sizes); - chunk_data.FillRandomUniform(1, 127, /*seed=*/i * num_replicas + j); - - start_indices[0] = input_offsets(i, j); - input_data[i].UpdateSlice(chunk_data, start_indices); - - start_indices[0] = output_offsets(i, j); - output_data[j].UpdateSlice(chunk_data, start_indices); - } - } - - auto get_row = [&](int64_t row_id, const Array& data) { - Array row = - data.Slice({row_id, 0}, {row_id + 1, num_replicas}); - row.Reshape({num_replicas}); - return row; - }; + FillWithRandomData(input_data, output_data, input_offsets, output_offsets, + input_sizes); // Create literals from array data. for (int replica_id = 0; replica_id < num_replicas; ++replica_id) { inputs_.push_back(LiteralUtil::CreateFromArray(input_data[replica_id])); - input_offsets_.push_back( - LiteralUtil::CreateFromArray(get_row(replica_id, input_offsets))); - input_sizes_.push_back( - LiteralUtil::CreateFromArray(get_row(replica_id, input_sizes))); + input_offsets_.push_back(LiteralUtil::CreateFromArray( + GetReplicaSlice(replica_id, input_offsets))); + input_sizes_.push_back(LiteralUtil::CreateFromArray( + GetReplicaSlice(replica_id, input_sizes))); expected_outputs_.push_back( LiteralUtil::CreateFromArray(output_data[replica_id])); - output_offsets_.push_back( - LiteralUtil::CreateFromArray(get_row(replica_id, output_offsets))); - output_sizes_.push_back( - LiteralUtil::CreateFromArray(get_row(replica_id, output_sizes))); + output_offsets_.push_back(LiteralUtil::CreateFromArray( + GetReplicaSlice(replica_id, output_offsets))); + output_sizes_.push_back(LiteralUtil::CreateFromArray( + GetReplicaSlice(replica_id, output_sizes))); } // The ragged-all-to-all accepts an output tensor as a parameter to allow @@ -1987,6 +1949,74 @@ class RaggedAllToAllTest : public AsyncMemcpyCollectiveOps { return input_literal_ptrs; } + // Computes ragged tensor offsets based on the sizes of the ragged rows. + template + Array CalculateOffsetsFromSizes(const Array& sizes) { + int64_t num_replicas = sizes.dim(0); + int64_t num_updates_per_replica = sizes.dim(2); + Array offsets(sizes.dimensions()); + for (int i = 0; i < num_replicas; ++i) { + int64_t cur_offset = 0; + for (int j = 0; j < num_replicas; ++j) { + for (int k = 0; k < num_updates_per_replica; ++k) { + offsets(i, j, k) = cur_offset; + cur_offset += sizes(i, j, k); + } + } + } + return offsets; + } + + // Fill the input and output tensors with random data. An all-to-all is + // effectively a transpose. We generate a chunk of random data for each update + // of each pair of replicas and write the chunk starting from the (i, j, k) + // offset of the input tensor and starting from the (j, i, k) offset of the + // output tensor. + template + void FillWithRandomData(std::vector>& input_data, + std::vector>& output_data, + const Array& input_offsets, + const Array& output_offsets, + const Array& input_sizes) { + int64_t num_replicas = input_sizes.dim(0); + int64_t num_updates_per_replica = input_sizes.dim(2); + std::vector start_indices(input_data[0].num_dimensions()); + std::vector chunk_sizes{input_data[0].dimensions().begin(), + input_data[0].dimensions().end()}; + + for (int i = 0; i < num_replicas; ++i) { + for (int j = 0; j < num_replicas; ++j) { + for (int k = 0; k < num_updates_per_replica; ++k) { + chunk_sizes[0] = input_sizes(i, j, k); + + Array chunk_data(chunk_sizes); + chunk_data.FillRandomUniform( + 1, 127, + /*seed=*/(i * num_replicas + j) * num_updates_per_replica + k); + + start_indices[0] = input_offsets(i, j, k); + input_data[i].UpdateSlice(chunk_data, start_indices); + + start_indices[0] = output_offsets(i, j, k); + output_data[j].UpdateSlice(chunk_data, start_indices); + } + } + } + } + + // Returns a slice of input data that corresponds to the given replica. + template + Array GetReplicaSlice(int64_t replica_id, + const Array& data) { + int64_t num_replicas = data.dim(0); + int64_t num_updates_per_replica = data.dim(2); + Array replica_slice = + data.Slice({replica_id, 0, 0}, + {replica_id + 1, num_replicas, num_updates_per_replica}); + replica_slice.Reshape({num_replicas * num_updates_per_replica}); + return replica_slice; + } + // Literates for the input and output data, offset, and size parameters of the // ragged-all-to-all. Each vector contains one literal per replica. std::vector inputs_; @@ -2044,6 +2074,50 @@ XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_2GPUs) { EXPECT_TRUE(LiteralTestUtil::Equal(expected_outputs_[1], results[1])); } +XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_2GPUs_MultipleUpdates) { + absl::string_view kModuleReplicatedStr = R"( + HloModule module, num_partitions=1 + + ENTRY entry { + input = f32[8] parameter(0) + output = f32[8] parameter(1) + input_offsets = s32[4] parameter(2) + send_sizes = s32[4] parameter(3) + output_offsets = s32[4] parameter(4) + recv_sizes = s32[4] parameter(5) + ROOT ra2a = f32[8] ragged-all-to-all(input, output, input_offsets, + send_sizes, output_offsets, recv_sizes), replica_groups={{0,1}} + })"; + + const int64_t kNumReplicas = 2; + const int64_t kNumPartitions = 1; + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas * kNumPartitions); + + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config)); + + CreateRandomTestData( + module.get(), /*input_sizes=*/{/*replica_0=*/{{1, 2}, {2, 1}}, + /*replica_1=*/{{3, 1}, {1, 1}}}); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + HloTestBase::ExecuteReplicated(std::move(module), GetInputLiteralPtrs(), + /*num_replicas=*/kNumReplicas, + /*run_hlo_passes=*/true, + /*device_assignment=*/nullptr)); + ASSERT_EQ(results.size(), kNumReplicas); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_outputs_[0], results[0])); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_outputs_[1], results[1])); +} + XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_2GPUs_MultiDimData) { absl::string_view kModuleReplicatedStr = R"( HloModule module, num_partitions=1 @@ -2159,19 +2233,20 @@ XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_8GPUs) { HloModule module, num_partitions=1 ENTRY entry { - input = f32[128, 5, 32] parameter(0) - output = f32[128, 5, 32] parameter(1) - input_offsets = s32[8] parameter(2) - send_sizes = s32[8] parameter(3) - output_offsets = s32[8] parameter(4) - recv_sizes = s32[8] parameter(5) - ROOT ra2a = f32[128, 5, 32] ragged-all-to-all(input, output, + input = f32[512, 5, 32] parameter(0) + output = f32[512, 5, 32] parameter(1) + input_offsets = s32[32] parameter(2) + send_sizes = s32[32] parameter(3) + output_offsets = s32[32] parameter(4) + recv_sizes = s32[32] parameter(5) + ROOT ra2a = f32[512, 5, 32] ragged-all-to-all(input, output, input_offsets, send_sizes, output_offsets, recv_sizes), replica_groups={{0,1,2,3,4,5,6,7}} })"; const int64_t kNumReplicas = 8; const int64_t kNumPartitions = 1; + const int64_t kNumUpdatesPerReplica = 4; if (test_runner().device_count() < kNumReplicas * kNumPartitions) { GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions << " devices (" << test_runner().device_count() @@ -2184,7 +2259,8 @@ XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_8GPUs) { TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config)); - Array input_sizes({kNumReplicas, kNumReplicas}); + Array input_sizes( + {kNumReplicas, kNumReplicas, kNumUpdatesPerReplica}); input_sizes.FillRandomUniform(0, 10); CreateRandomTestData(module.get(), input_sizes);