From f99bc570d6c270c58a1e1667c7323f669c8ca27b Mon Sep 17 00:00:00 2001
From: Zixuan Jiang
Date: Thu, 13 Feb 2025 23:11:22 -0800
Subject: [PATCH] Refactor `SpmdPartitioningVisitor::HandleReshape`. No
 behavior change.

This change recovers cl/717991433 with a modification. The previous change
was not a pure refactoring, since it assumed that the inference chain
`in_sharding_1 -> out_sharding -> in_sharding_2` yields
`in_sharding_1 == in_sharding_2`. This assumption may be false. In the added
test, we reshape 24 -> 6x4 and obtain the following inferred shardings.

```
in_sharding_1: [4]
out_sharding:  [2,1,2] last_tile_dim_replicate
in_sharding_2: [2,2] last_tile_dim_replicate
```

This change should be a pure refactoring without behavior change.

PiperOrigin-RevId: 726786386
---
 xla/service/spmd/spmd_partitioner.cc      | 66 ++++++++++++-----------
 xla/service/spmd/spmd_partitioner_test.cc | 28 ++++++++++
 2 files changed, 62 insertions(+), 32 deletions(-)

diff --git a/xla/service/spmd/spmd_partitioner.cc b/xla/service/spmd/spmd_partitioner.cc
index 35420f237e26a..10627460e2c9b 100644
--- a/xla/service/spmd/spmd_partitioner.cc
+++ b/xla/service/spmd/spmd_partitioner.cc
@@ -3095,44 +3095,46 @@ absl::Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
     return DefaultAction(hlo);
   }
 
+  const Shape& in_shape = hlo->operand(0)->shape();
+  const Shape& out_shape = hlo->shape();
   auto operand = GetPartitionedHlo(hlo->operand(0));
-  auto desired_operand = [&](const HloSharding& output_sharding)
-      -> std::optional<HloInstruction*> {
-    // The output shape is the source and the operand shape is the target to
-    // get desired_operand_sharding.
-    std::optional<HloSharding> desired_operand_sharding =
-        hlo_sharding_util::ReshapeSharding(
-            hlo->shape(), hlo->operand(0)->shape(), output_sharding);
-    if (desired_operand_sharding.has_value() &&
-        output_sharding.NumTiles() == desired_operand_sharding->NumTiles()) {
-      return b_.AddInstruction(hlo->CloneWithNewOperands(
-          MakePartitionedShape(hlo->shape(), output_sharding),
-          {operand.Reshard(*desired_operand_sharding).hlo()}));
+
+  std::vector<std::pair<HloSharding, HloSharding>> sharding_pairs;
+  auto insert_sharding_pair = [&](const HloSharding& in_sharding,
+                                  const HloSharding& out_sharding) {
+    if (in_sharding.NumTiles() == out_sharding.NumTiles()) {
+      sharding_pairs.push_back(std::make_pair(in_sharding, out_sharding));
     }
-    return std::nullopt;
   };
-  // Try the original output sharding at first.
-  if (auto operand_hlo = desired_operand(hlo->sharding())) {
-    SetPartitionedHlo(hlo, [&] { return *operand_hlo; });
+  if (std::optional<HloSharding> in_sharding =
+          hlo_sharding_util::ReshapeSharding(out_shape, in_shape, sharding)) {
+    insert_sharding_pair(std::move(*in_sharding), sharding);
+  }
+  if (std::optional<HloSharding> out_sharding =
+          hlo_sharding_util::ReshapeSharding(in_shape, out_shape,
+                                             operand.sharding())) {
+    if (std::optional<HloSharding> in_sharding =
+            hlo_sharding_util::ReshapeSharding(out_shape, in_shape,
+                                               *out_sharding)) {
+      // `in_sharding` and `operand.sharding()` may be different.
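+      // For example, reshaping 24 -> 6x4 with operand sharding [4] infers
+      // out_sharding [2,1,2] last_tile_dim_replicate, which in turn infers
+      // in_sharding [2,2] last_tile_dim_replicate rather than the original
+      // [4].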
+      insert_sharding_pair(std::move(*in_sharding), std::move(*out_sharding));
+    }
+  }
+
+  if (!sharding_pairs.empty()) {
+    const auto& [in_sharding, out_sharding] = sharding_pairs[0];
+    PartitionedHlo reshard_input = operand.Reshard(in_sharding);
+    HloInstruction* reshape = b_.AddInstruction(hlo->CloneWithNewOperands(
+        MakePartitionedShape(hlo->shape(), out_sharding),
+        {reshard_input.hlo()}));
+    reshape->set_sharding(out_sharding);
+    PartitionedHlo reshard_reshape =
+        PartitionedHlo(reshape, hlo->shape(), MakePartitioningState())
+            .Reshard(sharding);
+    SetPartitionedHlo(hlo, [&] { return reshard_reshape.hlo(); });
     return absl::OkStatus();
   }
-  // Then try the desired_output_sharding.
-  std::optional<HloSharding> desired_output_sharding =
-      hlo_sharding_util::ReshapeSharding(hlo->operand(0)->shape(), hlo->shape(),
-                                         operand.sharding());
-  if (desired_output_sharding.has_value()) {
-    if (auto operand_hlo = desired_operand(*desired_output_sharding)) {
-      (*operand_hlo)->set_sharding(*desired_output_sharding);
-      SetPartitionedHlo(hlo, [&] {
-        return PartitionedHlo(*operand_hlo, hlo->shape(),
-                              MakePartitioningState())
-            .Reshard(hlo->sharding())
-            .hlo();
-      });
-      return absl::OkStatus();
-    }
-  }
 
   auto shard_reshape = [](PartitionedHlo& operand, const HloSharding& sharding,
diff --git a/xla/service/spmd/spmd_partitioner_test.cc b/xla/service/spmd/spmd_partitioner_test.cc
index 61fcd5a83154d..ff02b2e1a12f7 100644
--- a/xla/service/spmd/spmd_partitioner_test.cc
+++ b/xla/service/spmd/spmd_partitioner_test.cc
@@ -3960,6 +3960,34 @@ ENTRY %reshape {
                           op::Shape("bf16[40,16,8]")));
 }
 
+TEST_P(SpmdPartitioningTest, ReshapeWithReshard5) {
+  absl::string_view hlo_string = R"(
+HloModule module
+
+ENTRY %reshape {
+  p0 = bf16[24] parameter(0), sharding={devices=[4]<=[4]}
+  ROOT reshape = bf16[6,4] reshape(p0), sharding={devices=[4,1]<=[4]}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/4));
+
+  auto param = AllOf(op::Parameter(0), op::Shape("bf16[6]"));
+  // Reshard param from {devices=[4]<=[4]} to {devices=[2,2]<=[4]
+  // last_tile_dim_replicate}
+  auto reshard_param = AllOf(op::AllReduce(op::DynamicUpdateSlice(_, param, _)),
+                             op::Shape("bf16[12]"));
+
+  auto reshape = AllOf(op::Reshape(reshard_param), op::Shape("bf16[3,4]"));
+
+  // Reshard reshape from {devices=[2,1,2]<=[4] last_tile_dim_replicate} to
+  // {devices=[4,1]<=[4]}
+  auto concat = op::Concatenate(
+      reshape, op::Pad(op::CollectivePermute(op::Slice(reshape)), _));
+  auto reshard_reshape = op::DynamicSlice(op::DynamicSlice(concat, _, _), _, _);
+  EXPECT_THAT(module->entry_computation()->root_instruction(), reshard_reshape);
+}
+
 TEST_P(SpmdPartitioningTest, PartialReplicateShardableReshape) {
   absl::string_view hlo_string = R"(
 HloModule module
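As a worked companion to the commit message, below is a toy, self-contained
sketch of the divisibility argument behind the inference chain above.
`KeepableTiles` is a hypothetical helper invented for this sketch; it models
only the 24 -> 6x4 case on a power-of-two mesh, not XLA's actual
`hlo_sharding_util::ReshapeSharding` algorithm.

```
// Toy sketch, not XLA code: reproduces the inferred shardings quoted in the
// commit message for reshape bf16[24] -> bf16[6,4] on 4 devices.
#include <cstdint>
#include <iostream>

// Hypothetical helper: largest power-of-two factor of `tiles` that divides
// `dim`, i.e. how much of a `tiles`-way split a dimension of size `dim` can
// keep. The leftover mesh factor falls back to replication.
int64_t KeepableTiles(int64_t dim, int64_t tiles) {
  while (tiles > 1 && dim % tiles != 0) tiles /= 2;
  return tiles;
}

int main() {
  const int64_t devices = 4;  // in_sharding_1: [4], 6 elements per device.
  // out_sharding for bf16[6,4]: dim 0 (size 6) keeps only a 2-way split of
  // the 4-way tiling, since 4 does not divide 6; the leftover factor of 2
  // becomes replication.
  const int64_t kept = KeepableTiles(/*dim=*/6, /*tiles=*/devices);  // == 2
  std::cout << "out_sharding:  [" << kept << ",1," << devices / kept
            << "] last_tile_dim_replicate\n";
  // in_sharding_2 inferred back from out_sharding: the 2-way split of dim 0
  // maps onto the 1D input, and the replicated factor stays replicated.
  std::cout << "in_sharding_2: [" << kept << "," << devices / kept
            << "] last_tile_dim_replicate\n";
  // in_sharding_1 ([4]) != in_sharding_2 ([2,2] last_tile_dim_replicate),
  // so the round trip does not return to the original sharding.
  return 0;
}
```

This is why the refactored `HandleReshape` collects explicit
`(in_sharding, out_sharding)` pairs and reshards at both ends, instead of
assuming the operand's sharding survives the inference round trip.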