diff --git a/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.cc b/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.cc
index 85a12564b7eee..85e61dadcb7d7 100644
--- a/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.cc
+++ b/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.cc
@@ -60,7 +60,9 @@ NcclRaggedAllToAllConfig GetNcclRaggedAllToAllConfig(
     const HloRaggedAllToAllInstruction* instr) {
   NcclRaggedAllToAllConfig config;
   config.config = GetNcclCollectiveConfig(instr, std::nullopt);
-  config.num_ragged_rows = instr->operand(2)->shape().dimensions(0);
+
+  const Shape& input_size_shape = instr->operand(2)->shape();
+  config.num_total_updates = input_size_shape.dimensions(0);
   config.ragged_row_element_size =
       ShapeUtil::ElementsIn(instr->shape()) / instr->shape().dimensions(0);
   return config;
@@ -166,7 +168,7 @@ absl::Status NcclRaggedAllToAllStartThunk::Initialize(
     for (int64_t i = 0; i < kNumRaggedMetadataOperands; ++i) {
       TF_ASSIGN_OR_RETURN(std::unique_ptr<se::MemoryAllocation> alloc,
                           params.executor->HostMemoryAllocate(
-                              config_.num_ragged_rows * sizeof(int64_t)));
+                              config_.num_total_updates * sizeof(int64_t)));
       allocs.push_back(std::move(alloc));
     }
     host_buffer_allocs_.emplace(params.executor, std::move(allocs));
@@ -175,7 +177,7 @@ absl::Status NcclRaggedAllToAllStartThunk::Initialize(
   if (!device_buffer_allocs_.contains(params.executor)) {
     se::DeviceMemoryHandle output_offsets_device_buffer{
         params.executor,
-        params.executor->Allocate(config_.num_ragged_rows * sizeof(int64_t))};
+        params.executor->Allocate(config_.num_total_updates * sizeof(int64_t))};
 
     if (output_offsets_device_buffer.memory().is_null()) {
       return absl::InternalError("Failed to allocate output offsets buffer.");
@@ -275,14 +277,15 @@ absl::Status NcclRaggedAllToAllStartThunk::RunNcclCollective(
           receive_pointer_maps_[stream.parent()]->opaque());
     }
     return xla::gpu::RunMemCpyRaggedAllToAll(
-        collectives, config_.ragged_row_element_size, device_buffers, stream,
-        comm_handle.comm, ragged_metadata_allocs, send_pointer,
-        receive_pointer_map);
+        collectives, config_.ragged_row_element_size, config_.num_total_updates,
+        device_buffers, stream, comm_handle.comm, ragged_metadata_allocs,
+        send_pointer, receive_pointer_map);
   }
 
   return xla::gpu::RunRaggedAllToAll(
-      collectives, config_.ragged_row_element_size, device_buffers, stream,
-      comm_handle.comm, ragged_metadata_allocs, output_offsets_device_buffer);
+      collectives, config_.ragged_row_element_size, config_.num_total_updates,
+      device_buffers, stream, comm_handle.comm, ragged_metadata_allocs,
+      output_offsets_device_buffer);
 }
 
 AsyncStreamKind NcclRaggedAllToAllStartThunk::GetAsyncStreamKind() const {
@@ -292,21 +295,26 @@ AsyncStreamKind NcclRaggedAllToAllStartThunk::GetAsyncStreamKind() const {
 
 // Runs AllToAll on a buffer that contains ragged tensor metadata.
 absl::Status RunAllToAllOnIndexBuffer(
     GpuCollectives* collectives, const se::DeviceMemoryBase& source_buffer,
+    int64_t num_updates_per_replica,
     const se::DeviceMemoryBase& destination_buffer, PrimitiveType element_type,
     se::Stream& stream, Communicator* comm) {
   TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks());
 
   TF_RETURN_IF_ERROR(collectives->GroupStart());
   for (int peer = 0; peer < num_ranks; ++peer) {
+    int64_t offset = peer * num_updates_per_replica;
     se::DeviceMemoryBase send_slice = collectives->Slice(
-        source_buffer, element_type, /*offset=*/peer, /*count=*/1);
-    se::DeviceMemoryBase recv_slice = collectives->Slice(
-        destination_buffer, element_type, /*offset=*/peer, /*count=*/1);
+        source_buffer, element_type, offset, /*count=*/num_updates_per_replica);
+    se::DeviceMemoryBase recv_slice =
+        collectives->Slice(destination_buffer, element_type, offset,
+                           /*count=*/num_updates_per_replica);
 
-    TF_RETURN_IF_ERROR(comm->Send(send_slice, element_type, /*count=*/1,
+    TF_RETURN_IF_ERROR(comm->Send(send_slice, element_type,
+                                  /*count=*/num_updates_per_replica,
                                   RankId(peer), GpuCollectives::On(stream)));
 
-    TF_RETURN_IF_ERROR(comm->Recv(recv_slice, element_type, /*count=*/1,
+    TF_RETURN_IF_ERROR(comm->Recv(recv_slice, element_type,
+                                  /*count=*/num_updates_per_replica,
                                   RankId(peer), GpuCollectives::On(stream)));
   }
@@ -316,6 +324,7 @@ absl::Status RunAllToAllOnIndexBuffer(
 
 absl::Status RunRaggedAllToAll(
     GpuCollectives* collectives, int64_t ragged_row_element_size,
+    int64_t num_total_updates,
     const std::vector<DeviceBufferPair>& original_buffers, se::Stream& stream,
     Communicator* comm, const std::vector<int64_t*>& ragged_metadata_allocs,
     const se::DeviceMemoryBase& output_offsets_device_buffer) {
@@ -329,6 +338,8 @@ absl::Status RunRaggedAllToAll(
 
   std::vector<DeviceBufferPair> buffers = original_buffers;
 
+  int64_t num_updates_per_replica = num_total_updates / num_ranks;
+
   // `output_offsets` of the RaggedAllToAll instruction are sharded in a way,
   // that `output_offset[i]` is an offset in the i-th peer output buffer. To
   // make it work for NCCL model with send/recv, we need to know offsets in the
@@ -337,8 +348,8 @@ absl::Status RunRaggedAllToAll(
   DeviceBufferPair& output_offsets_buffer_pair = buffers[4];
   TF_RETURN_IF_ERROR(RunAllToAllOnIndexBuffer(
       collectives, output_offsets_buffer_pair.source_buffer,
-      output_offsets_device_buffer, output_offsets_buffer_pair.element_type,
-      stream, comm));
+      num_updates_per_replica, output_offsets_device_buffer,
+      output_offsets_buffer_pair.element_type, stream, comm));
   output_offsets_buffer_pair.source_buffer = output_offsets_device_buffer;
 
   TF_ASSIGN_OR_RETURN(
@@ -353,24 +364,27 @@ absl::Status RunRaggedAllToAll(
   TF_RETURN_IF_ERROR(collectives->GroupStart());
 
   const DeviceBufferPair& data_buffer = buffers[0];
-  for (int peer = 0; peer < num_ranks; ++peer) {
-    se::DeviceMemoryBase send_slice =
-        collectives->Slice(data_buffer.source_buffer, data_buffer.element_type,
-                           input_offsets[peer] * ragged_row_element_size,
-                           send_sizes[peer] * ragged_row_element_size);
-
-    se::DeviceMemoryBase recv_slice = collectives->Slice(
-        data_buffer.destination_buffer, data_buffer.element_type,
-        output_offsets[peer] * ragged_row_element_size,
-        recv_sizes[peer] * ragged_row_element_size);
-
-    TF_RETURN_IF_ERROR(comm->Send(send_slice, data_buffer.element_type,
-                                  send_sizes[peer] * ragged_row_element_size,
-                                  RankId(peer), GpuCollectives::On(stream)));
-
-    TF_RETURN_IF_ERROR(comm->Recv(recv_slice, data_buffer.element_type,
-                                  recv_sizes[peer] * ragged_row_element_size,
-                                  RankId(peer), GpuCollectives::On(stream)));
+  for (int64_t i = 0; i < num_updates_per_replica; ++i) {
+    for (int peer = 0; peer < num_ranks; ++peer) {
+      int64_t idx = peer * num_updates_per_replica + i;
+      se::DeviceMemoryBase send_slice = collectives->Slice(
+          data_buffer.source_buffer, data_buffer.element_type,
+          input_offsets[idx] * ragged_row_element_size,
+          send_sizes[idx] * ragged_row_element_size);
+
+      se::DeviceMemoryBase recv_slice = collectives->Slice(
+          data_buffer.destination_buffer, data_buffer.element_type,
+          output_offsets[idx] * ragged_row_element_size,
+          recv_sizes[idx] * ragged_row_element_size);
+
+      TF_RETURN_IF_ERROR(comm->Send(send_slice, data_buffer.element_type,
+                                    send_sizes[idx] * ragged_row_element_size,
+                                    RankId(peer), GpuCollectives::On(stream)));
+
+      TF_RETURN_IF_ERROR(comm->Recv(recv_slice, data_buffer.element_type,
+                                    recv_sizes[idx] * ragged_row_element_size,
+                                    RankId(peer), GpuCollectives::On(stream)));
+    }
   }
 
   return collectives->GroupEnd();
@@ -380,9 +394,10 @@ absl::Status RunRaggedAllToAll(
 // NcclCommunicator implementation.
 absl::Status RunMemCpyRaggedAllToAll(
     GpuCollectives* collectives, int64_t ragged_row_element_size,
-    const std::vector<DeviceBufferPair>& buffers, se::Stream& stream,
-    Communicator* comm, const std::vector<int64_t*>& ragged_metadata_allocs,
-    uint64_t* send_pointer, uint64_t receive_pointer_map[]) {
+    int64_t num_total_updates, const std::vector<DeviceBufferPair>& buffers,
+    se::Stream& stream, Communicator* comm,
+    const std::vector<int64_t*>& ragged_metadata_allocs, uint64_t* send_pointer,
+    uint64_t receive_pointer_map[]) {
   int device_ordinal = stream.parent()->device_ordinal();
   VLOG(3) << "Performing mem-copy-ragged-all-to-all from device ordinal: "
           << device_ordinal;
@@ -408,25 +423,30 @@ absl::Status RunMemCpyRaggedAllToAll(
       std::vector<IntegerOperandData> ragged_metadata,
       LoadRaggedTensorMetadata(stream, buffers, ragged_metadata_allocs));
 
+  int64_t num_updates_per_replica = num_total_updates / num_ranks;
+
   const IntegerOperandData& input_offsets = ragged_metadata[0];
   const IntegerOperandData& send_sizes = ragged_metadata[1];
   const IntegerOperandData& output_offsets = ragged_metadata[2];
 
   // Transfer a slice of data to each peer's output buffer.
-  for (int peer = 0; peer < num_ranks; ++peer) {
-    se::DeviceMemoryBase send_slice =
-        collectives->Slice(data_buffer.source_buffer, data_buffer.element_type,
-                           input_offsets[peer] * ragged_row_element_size,
-                           send_sizes[peer] * ragged_row_element_size);
-    se::DeviceMemoryBase base_dst_addr =
-        se::DeviceMemoryBase(reinterpret_cast<void*>(receive_pointer_map[peer]),
-                             data_buffer.destination_buffer.size());
-    se::DeviceMemoryBase dst_slice =
-        collectives->Slice(base_dst_addr, data_buffer.element_type,
-                           output_offsets[peer] * ragged_row_element_size,
-                           send_sizes[peer] * ragged_row_element_size);
-    TF_RETURN_IF_ERROR(
-        stream.MemcpyD2D(&dst_slice, send_slice, send_slice.size()));
+  for (int64_t i = 0; i < num_updates_per_replica; ++i) {
+    for (int peer = 0; peer < num_ranks; ++peer) {
+      int64_t idx = peer * num_updates_per_replica + i;
+      se::DeviceMemoryBase send_slice = collectives->Slice(
+          data_buffer.source_buffer, data_buffer.element_type,
+          input_offsets[idx] * ragged_row_element_size,
+          send_sizes[idx] * ragged_row_element_size);
+      se::DeviceMemoryBase base_dst_addr = se::DeviceMemoryBase(
+          reinterpret_cast<void*>(receive_pointer_map[peer]),
+          data_buffer.destination_buffer.size());
+      se::DeviceMemoryBase dst_slice =
+          collectives->Slice(base_dst_addr, data_buffer.element_type,
+                             output_offsets[idx] * ragged_row_element_size,
+                             send_sizes[idx] * ragged_row_element_size);
+      TF_RETURN_IF_ERROR(
+          stream.MemcpyD2D(&dst_slice, send_slice, send_slice.size()));
+    }
   }
 
   return absl::OkStatus();
diff --git a/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.h b/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.h
index 337d0110cc746..984101ebe11d3 100644
--- a/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.h
+++ b/xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.h
@@ -41,7 +41,7 @@ namespace gpu {
 
 struct NcclRaggedAllToAllConfig {
   NcclCollectiveConfig config;
-  int64_t num_ragged_rows = 1;
+  int64_t num_total_updates = 1;
   int64_t ragged_row_element_size = 1;
 };
 
@@ -112,15 +112,17 @@ class NcclRaggedAllToAllStartThunk : public NcclCollectiveThunk {
 absl::Status RunRaggedAllToAll(
     GpuCollectives* collectives, int64_t ragged_row_element_size,
-    const std::vector<DeviceBufferPair>& buffers, se::Stream& stream,
-    Communicator* comm, const std::vector<int64_t*>& ragged_metadata_allocs,
+    int64_t num_total_updates, const std::vector<DeviceBufferPair>& buffers,
+    se::Stream& stream, Communicator* comm,
+    const std::vector<int64_t*>& ragged_metadata_allocs,
     const se::DeviceMemoryBase& output_offsets_device_buffer);
 
 absl::Status RunMemCpyRaggedAllToAll(
     GpuCollectives* collectives, int64_t ragged_row_element_size,
-    const std::vector<DeviceBufferPair>& buffers, se::Stream& stream,
-    Communicator* comm, const std::vector<int64_t*>& ragged_metadata_allocs,
-    uint64_t* send_pointer, uint64_t receive_pointer_map[]);
+    int64_t num_total_updates, const std::vector<DeviceBufferPair>& buffers,
+    se::Stream& stream, Communicator* comm,
+    const std::vector<int64_t*>& ragged_metadata_allocs, uint64_t* send_pointer,
+    uint64_t receive_pointer_map[]);
 
 }  // namespace gpu
 
 }  // namespace xla
diff --git a/xla/tests/collective_ops_e2e_test.cc b/xla/tests/collective_ops_e2e_test.cc
index 09726651ffeb5..a4ead70013c2b 100644
--- a/xla/tests/collective_ops_e2e_test.cc
+++ b/xla/tests/collective_ops_e2e_test.cc
@@ -1883,13 +1883,15 @@ class RaggedAllToAllTest : public AsyncMemcpyCollectiveOps {
   // the test data we need to know the sizes of all ragged rows for each
   // replica.
   //
-  // `input_sizes` is a 2D array of shape [num_replicas, num_replicas].
-  // `input_sizes[i, j]` is the number of elements in the j-th ragged row of the
-  // i-th replica input.
+  // `input_sizes` is an array of shape [num_replicas, num_replicas,
+  // num_updates_per_replica]. For convenience, `input_sizes` can be a 2D
+  // array, in which case `num_updates_per_replica` is assumed to be 1.
   template <typename T>
-  void CreateRandomTestData(HloModule* module,
-                            const Array<int64_t>& input_sizes) {
+  void CreateRandomTestData(HloModule* module, Array<int64_t> input_sizes) {
     CHECK(inputs_.empty());
+    if (input_sizes.num_dimensions() == 2) {
+      input_sizes.Reshape({input_sizes.dim(0), input_sizes.dim(1), 1});
+    }
     auto ragged_all_to_all =
         FindInstruction(module, HloOpcode::kRaggedAllToAll);
     EXPECT_THAT(ragged_all_to_all, NotNull());
@@ -1899,75 +1901,35 @@ class RaggedAllToAllTest : public AsyncMemcpyCollectiveOps {
       ragged_all_to_all->shape().dimensions().begin(),
       ragged_all_to_all->shape().dimensions().end()};
 
-    int64_t num_replicas = input_sizes.dim(0);
+    Array<int64_t> output_sizes = input_sizes;
+    output_sizes.TransposeDimensions({1, 0, 2});
+    Array<int64_t> input_offsets = CalculateOffsetsFromSizes(input_sizes);
+    Array<int64_t> output_offsets = CalculateOffsetsFromSizes(output_sizes);
+    output_offsets.TransposeDimensions({1, 0, 2});
+
+    int64_t num_replicas = input_sizes.dim(0);
     std::vector<Array<T>> input_data(num_replicas,
                                      Array<T>(ragged_tensor_sizes));
     std::vector<Array<T>> output_data(num_replicas,
                                       Array<T>(ragged_tensor_sizes));
-
-    Array<int64_t> output_sizes = input_sizes;
-    output_sizes.TransposeDimensions({1, 0});
-
-    // Computes ragged tensor offsets based on the sizes of the ragged rows.
-    auto get_offsets = [&](const Array<int64_t>& sizes) {
-      Array<int64_t> offsets(sizes.dimensions());
-      for (int i = 0; i < num_replicas; ++i) {
-        for (int j = 1; j < num_replicas; ++j) {
-          offsets(i, j) = offsets(i, j - 1) + sizes(i, j - 1);
-        }
-      }
-      return offsets;
-    };
-
-    Array<int64_t> input_offsets = get_offsets(input_sizes);
-    Array<int64_t> output_offsets = get_offsets(output_sizes);
-    output_offsets.TransposeDimensions({1, 0});
-
-    std::vector<int64_t> chunk_sizes{ragged_tensor_sizes.begin(),
-                                     ragged_tensor_sizes.end()};
-
-    // Fill the input and output tensors with random data. An all-to-all is
-    // effective a transpose. We generate a chunk of random data for each pair
-    // of replicas and write the chunk starting from the (i, j) offset of the
-    // input tensor and starting from the (j, i) offset of the output tensor.
-    std::vector<int64_t> start_indices(ragged_tensor_sizes.size());
-    for (int i = 0; i < num_replicas; ++i) {
-      for (int j = 0; j < num_replicas; ++j) {
-        chunk_sizes[0] = input_sizes(i, j);
-
-        Array<T> chunk_data(chunk_sizes);
-        chunk_data.FillRandomUniform(1, 127, /*seed=*/i * num_replicas + j);
-
-        start_indices[0] = input_offsets(i, j);
-        input_data[i].UpdateSlice(chunk_data, start_indices);
-
-        start_indices[0] = output_offsets(i, j);
-        output_data[j].UpdateSlice(chunk_data, start_indices);
-      }
-    }
-
-    auto get_row = [&](int64_t row_id, const Array<int64_t>& data) {
-      Array<int64_t> row =
-          data.Slice({row_id, 0}, {row_id + 1, num_replicas});
-      row.Reshape({num_replicas});
-      return row;
-    };
+    FillWithRandomData(input_data, output_data, input_offsets, output_offsets,
+                       input_sizes);
 
     // Create literals from array data.
     for (int replica_id = 0; replica_id < num_replicas; ++replica_id) {
       inputs_.push_back(LiteralUtil::CreateFromArray(input_data[replica_id]));
-      input_offsets_.push_back(
-          LiteralUtil::CreateFromArray(get_row(replica_id, input_offsets)));
-      input_sizes_.push_back(
-          LiteralUtil::CreateFromArray(get_row(replica_id, input_sizes)));
+      input_offsets_.push_back(LiteralUtil::CreateFromArray(
+          GetReplicaSlice(replica_id, input_offsets)));
+      input_sizes_.push_back(LiteralUtil::CreateFromArray(
+          GetReplicaSlice(replica_id, input_sizes)));
 
       expected_outputs_.push_back(
          LiteralUtil::CreateFromArray(output_data[replica_id]));
 
-      output_offsets_.push_back(
-          LiteralUtil::CreateFromArray(get_row(replica_id, output_offsets)));
-      output_sizes_.push_back(
-          LiteralUtil::CreateFromArray(get_row(replica_id, output_sizes)));
+      output_offsets_.push_back(LiteralUtil::CreateFromArray(
+          GetReplicaSlice(replica_id, output_offsets)));
+      output_sizes_.push_back(LiteralUtil::CreateFromArray(
+          GetReplicaSlice(replica_id, output_sizes)));
     }
 
     // The ragged-all-to-all accepts an output tensor as a parameter to allow
@@ -1987,6 +1949,74 @@ class RaggedAllToAllTest : public AsyncMemcpyCollectiveOps {
     return input_literal_ptrs;
   }
 
+  // Computes ragged tensor offsets based on the sizes of the ragged rows.
+  template <typename T>
+  Array<T> CalculateOffsetsFromSizes(const Array<T>& sizes) {
+    int64_t num_replicas = sizes.dim(0);
+    int64_t num_updates_per_replica = sizes.dim(2);
+    Array<T> offsets(sizes.dimensions());
+    for (int i = 0; i < num_replicas; ++i) {
+      int64_t cur_offset = 0;
+      for (int j = 0; j < num_replicas; ++j) {
+        for (int k = 0; k < num_updates_per_replica; ++k) {
+          offsets(i, j, k) = cur_offset;
+          cur_offset += sizes(i, j, k);
+        }
+      }
+    }
+    return offsets;
+  }
+
+  // Fills the input and output tensors with random data. An all-to-all is
+  // effectively a transpose. We generate a chunk of random data for each update
+  // of each pair of replicas and write the chunk starting from the (i, j, k)
+  // offset of the input tensor and starting from the (j, i, k) offset of the
+  // output tensor.
+  template <typename T>
+  void FillWithRandomData(std::vector<Array<T>>& input_data,
+                          std::vector<Array<T>>& output_data,
+                          const Array<int64_t>& input_offsets,
+                          const Array<int64_t>& output_offsets,
+                          const Array<int64_t>& input_sizes) {
+    int64_t num_replicas = input_sizes.dim(0);
+    int64_t num_updates_per_replica = input_sizes.dim(2);
+    std::vector<int64_t> start_indices(input_data[0].num_dimensions());
+    std::vector<int64_t> chunk_sizes{input_data[0].dimensions().begin(),
+                                     input_data[0].dimensions().end()};
+
+    for (int i = 0; i < num_replicas; ++i) {
+      for (int j = 0; j < num_replicas; ++j) {
+        for (int k = 0; k < num_updates_per_replica; ++k) {
+          chunk_sizes[0] = input_sizes(i, j, k);
+
+          Array<T> chunk_data(chunk_sizes);
+          chunk_data.FillRandomUniform(
+              1, 127,
+              /*seed=*/(i * num_replicas + j) * num_updates_per_replica + k);
+
+          start_indices[0] = input_offsets(i, j, k);
+          input_data[i].UpdateSlice(chunk_data, start_indices);
+
+          start_indices[0] = output_offsets(i, j, k);
+          output_data[j].UpdateSlice(chunk_data, start_indices);
+        }
+      }
+    }
+  }
+
+  // Returns a slice of input data that corresponds to the given replica.
+  template <typename T>
+  Array<T> GetReplicaSlice(int64_t replica_id,
+                           const Array<T>& data) {
+    int64_t num_replicas = data.dim(0);
+    int64_t num_updates_per_replica = data.dim(2);
+    Array<T> replica_slice =
+        data.Slice({replica_id, 0, 0},
+                   {replica_id + 1, num_replicas, num_updates_per_replica});
+    replica_slice.Reshape({num_replicas * num_updates_per_replica});
+    return replica_slice;
+  }
+
   // Literates for the input and output data, offset, and size parameters of the
   // ragged-all-to-all. Each vector contains one literal per replica.
   std::vector<Literal> inputs_;
@@ -2044,6 +2074,50 @@ XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_2GPUs) {
   EXPECT_TRUE(LiteralTestUtil::Equal(expected_outputs_[1], results[1]));
 }
 
+XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_2GPUs_MultipleUpdates) {
+  absl::string_view kModuleReplicatedStr = R"(
+  HloModule module, num_partitions=1
+
+  ENTRY entry {
+    input = f32[8] parameter(0)
+    output = f32[8] parameter(1)
+    input_offsets = s32[4] parameter(2)
+    send_sizes = s32[4] parameter(3)
+    output_offsets = s32[4] parameter(4)
+    recv_sizes = s32[4] parameter(5)
+    ROOT ra2a = f32[8] ragged-all-to-all(input, output, input_offsets,
+      send_sizes, output_offsets, recv_sizes), replica_groups={{0,1}}
+  })";
+
+  const int64_t kNumReplicas = 2;
+  const int64_t kNumPartitions = 1;
+  if (test_runner().device_count() < kNumReplicas * kNumPartitions) {
+    GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions
+                 << " devices (" << test_runner().device_count()
+                 << " available)";
+  }
+
+  HloModuleConfig config =
+      GetModuleConfigForTest(/*replica_count=*/kNumReplicas * kNumPartitions);
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config));
+
+  CreateRandomTestData<float>(
+      module.get(), /*input_sizes=*/{/*replica_0=*/{{1, 2}, {2, 1}},
+                                     /*replica_1=*/{{3, 1}, {1, 1}}});
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<Literal> results,
+      HloTestBase::ExecuteReplicated(std::move(module), GetInputLiteralPtrs(),
+                                     /*num_replicas=*/kNumReplicas,
+                                     /*run_hlo_passes=*/true,
+                                     /*device_assignment=*/nullptr));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected_outputs_[0], results[0]));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected_outputs_[1], results[1]));
+}
+
 XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_2GPUs_MultiDimData) {
   absl::string_view kModuleReplicatedStr = R"(
   HloModule module, num_partitions=1
@@ -2159,19 +2233,20 @@ XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_8GPUs) {
   HloModule module, num_partitions=1
   ENTRY entry {
-    input = f32[128, 5, 32] parameter(0)
-    output = f32[128, 5, 32] parameter(1)
-    input_offsets = s32[8] parameter(2)
-    send_sizes = s32[8] parameter(3)
-    output_offsets = s32[8] parameter(4)
-    recv_sizes = s32[8] parameter(5)
-    ROOT ra2a = f32[128, 5, 32] ragged-all-to-all(input, output,
+    input = f32[512, 5, 32] parameter(0)
+    output = f32[512, 5, 32] parameter(1)
+    input_offsets = s32[32] parameter(2)
+    send_sizes = s32[32] parameter(3)
+    output_offsets = s32[32] parameter(4)
+    recv_sizes = s32[32] parameter(5)
+    ROOT ra2a = f32[512, 5, 32] ragged-all-to-all(input, output,
       input_offsets, send_sizes, output_offsets, recv_sizes), replica_groups={{0,1,2,3,4,5,6,7}}
   })";
 
   const int64_t kNumReplicas = 8;
   const int64_t kNumPartitions = 1;
+  const int64_t kNumUpdatesPerReplica = 4;
   if (test_runner().device_count() < kNumReplicas * kNumPartitions) {
     GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions
                  << " devices (" << test_runner().device_count()
                  << " available)";
   }
@@ -2184,7 +2259,8 @@ XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_8GPUs) {
 
   TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config));
 
-  Array<int64_t> input_sizes({kNumReplicas, kNumReplicas});
+  Array<int64_t> input_sizes(
+      {kNumReplicas, kNumReplicas, kNumUpdatesPerReplica});
   input_sizes.FillRandomUniform(0, 10);
 
   CreateRandomTestData<float>(module.get(), input_sizes);
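
Reviewer note (not part of the patch): the core of this change is that the ragged metadata operands now hold one entry per update rather than one entry per peer, laid out as [num_ranks, num_updates_per_replica] flattened row-major, and both runtime paths walk that layout with idx = peer * num_updates_per_replica + i. The standalone sketch below illustrates only that indexing convention and the running-offset computation; the constants, the example sizes, and the program itself are illustrative and assume num_total_updates is divisible by the number of ranks, which is also what the thunk assumes when it computes num_updates_per_replica = num_total_updates / num_ranks.

// Standalone illustration of the flattened metadata layout (assumed names,
// not XLA code). One send is issued per (update, peer) pair.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  constexpr int64_t kNumRanks = 2;
  constexpr int64_t kNumTotalUpdates = 4;  // e.g. the s32[4] operands above
  constexpr int64_t kNumUpdatesPerReplica = kNumTotalUpdates / kNumRanks;

  // Per-update send sizes, flattened the same way as the thunk's ragged
  // metadata buffers: [peer][update] in row-major order.
  const std::vector<int64_t> send_sizes = {1, 2, 2, 1};

  // Running offsets within one replica's input, analogous to what the test's
  // CalculateOffsetsFromSizes helper produces for a single replica row.
  std::vector<int64_t> input_offsets(kNumTotalUpdates);
  int64_t cur_offset = 0;
  for (int64_t idx = 0; idx < kNumTotalUpdates; ++idx) {
    input_offsets[idx] = cur_offset;
    cur_offset += send_sizes[idx];
  }

  // The double loop mirrors the updated RunRaggedAllToAll structure.
  for (int64_t i = 0; i < kNumUpdatesPerReplica; ++i) {
    for (int64_t peer = 0; peer < kNumRanks; ++peer) {
      int64_t idx = peer * kNumUpdatesPerReplica + i;
      std::cout << "update " << i << " -> peer " << peer << ": offset "
                << input_offsets[idx] << ", size " << send_sizes[idx] << "\n";
    }
  }
  return 0;
}

With a single update per replica (kNumUpdatesPerReplica == 1), idx degenerates to peer, which is exactly the behavior of the code being replaced.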