From 0c852dcab286724bcc7c6599579ddfc3f2297847 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Wed, 12 Feb 2025 07:01:00 -0800 Subject: [PATCH] Add control dependencies for peeled send/recv For send/recv we have to ensure that they are not pipelined beyond any conflicting collective. PiperOrigin-RevId: 726039862 --- xla/service/collective_conflict_analysis.h | 6 + xla/service/gpu/BUILD | 3 + xla/service/gpu/gpu_p2p_pipeliner.cc | 123 ++++++++++++++++++++- xla/service/gpu/gpu_p2p_pipeliner_test.cc | 98 ++++++++++++++++ 4 files changed, 229 insertions(+), 1 deletion(-) diff --git a/xla/service/collective_conflict_analysis.h b/xla/service/collective_conflict_analysis.h index 4c0218d2e7ba4..868bcff315cad 100644 --- a/xla/service/collective_conflict_analysis.h +++ b/xla/service/collective_conflict_analysis.h @@ -54,6 +54,12 @@ std::vector FindAllConflictingCollectives( const HloComputation* computation, const std::vector& seed_collectives); +inline std::vector FindAllConflictingCollectives( + HloInstruction* seed_collective) { + return FindAllConflictingCollectives(seed_collective->parent(), + {seed_collective}); +} + } // namespace xla #endif // XLA_SERVICE_COLLECTIVE_CONFLICT_ANALYSIS_H_ diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 08fae4015c367..dee798d709cd3 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -2242,11 +2242,14 @@ xla_cc_test( "//xla/hlo/testlib:filecheck", "//xla/service:hlo_module_config", "//xla/service:hlo_verifier", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/gpu/gpu_p2p_pipeliner.cc b/xla/service/gpu/gpu_p2p_pipeliner.cc index fb04136ab1a29..3567786dc4e7e 100644 --- a/xla/service/gpu/gpu_p2p_pipeliner.cc +++ 
b/xla/service/gpu/gpu_p2p_pipeliner.cc @@ -27,7 +27,9 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" @@ -238,6 +240,120 @@ static absl::Status PostProcessRotatedSendRecvOps( return absl::OkStatus(); } +// For a peeled send/recv instruction, find the corresponding send/recv-done +// instruction after the while loop. +static HloInstruction* FindSendRecvDoneInstruction(HloInstruction* instr) { + CHECK(instr->opcode() == HloOpcode::kRecv || + instr->opcode() == HloOpcode::kSend); + CHECK_EQ(instr->user_count(), 1); + HloInstruction* candidate = instr->users().front(); + if (candidate->opcode() == HloOpcode::kTuple) { + HloInstruction* tuple_op = candidate; + int64_t i = tuple_op->operand_index(instr); + CHECK_EQ(tuple_op->user_count(), 1); + HloInstruction* while_op = tuple_op->users().front(); + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + for (HloInstruction* user : while_op->users()) { + HloGetTupleElementInstruction* gte_op = + DynCast(user); + if (gte_op == nullptr || gte_op->tuple_index() != i) continue; + CHECK_EQ(gte_op->user_count(), 1); + candidate = gte_op->users().front(); + break; + } + } + CHECK(candidate->opcode() == HloOpcode::kRecvDone || + candidate->opcode() == HloOpcode::kSendDone); + return candidate; +} + +static absl::Status AddControlDependencies( + std::vector& from_instructions, HloInstruction* to_instr) { + for (HloInstruction* from_instr : from_instructions) { + TF_RETURN_IF_ERROR(from_instr->AddControlDependencyTo(to_instr)); + } + return absl::OkStatus(); +} + +static absl::Status AddControlDependencies( + HloInstruction* from_instr, + absl::flat_hash_set& to_instructions) { + for (HloInstruction* to_instr : 
to_instructions) { + TF_RETURN_IF_ERROR(from_instr->AddControlDependencyTo(to_instr)); + } + return absl::OkStatus(); +} + +static absl::Status PostProcessPeeledSendRecvOps( + std::vector& peeled_send_recvs) { + // Convert to set for faster lookup. + absl::flat_hash_set peeled_send_recvs_set; + peeled_send_recvs_set.insert(peeled_send_recvs.begin(), + peeled_send_recvs.end()); + + // Add control dependencies between conflicting collectives and peeled + // send/recv ops. + for (HloInstruction* peeled_instr : peeled_send_recvs) { + CHECK(peeled_instr->opcode() == HloOpcode::kRecv || + peeled_instr->opcode() == HloOpcode::kSend); + + // Find all conflicting collectives that were not peeled out of the loop. + absl::flat_hash_set unpeeled_conflicting_collectives; + for (HloInstruction* instr : FindAllConflictingCollectives(peeled_instr)) { + if (peeled_send_recvs_set.contains(instr)) continue; + unpeeled_conflicting_collectives.insert(instr); + } + + // Find the while loop. + CHECK_EQ(peeled_instr->user_count(), 1); + HloInstruction* tuple_op = peeled_instr->users().front(); + CHECK_EQ(tuple_op->opcode(), HloOpcode::kTuple); + CHECK_EQ(tuple_op->user_count(), 1); + HloInstruction* while_op = tuple_op->users().front(); + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + + // We separated unpeeled conflicting collectives into two categories: + // 1. Those that may dominate the while loop (the while loop may have a data + // dependency on them, `may_dominate_while_loop`). + // 2. Those that are known to not dominate the while loop (remaining + // instructions in `unpeeled_conflicting_collectives`). + std::vector may_dominate_while_loop; + for (HloInstruction* instr : + while_op->parent()->MakeInstructionPostOrder()) { + // All instructions in post order that come after the while loop are known + // to not dominate it. 
+ if (instr == while_op) { + break; + } + // If we're looking at an instruction that is an unpeeled conflicting + // collective, it is possible that it dominates the while loop. Move it + // into the first category set. + if (unpeeled_conflicting_collectives.contains(instr)) { + may_dominate_while_loop.push_back(instr); + unpeeled_conflicting_collectives.erase(instr); + } + } + + // Add control dependencies from dominating conflicting collectives to the + // peeled send/recv instruction. This guarantees that the conflicting + // collectives cannot slip in between the peeled send/recv instructions + // where it could cause a deadlock. + TF_RETURN_IF_ERROR( + AddControlDependencies(may_dominate_while_loop, peeled_instr)); + + // Add control dependencies from the final peeled send/recv-done + // instruction to the conflicting collectives that are dominated by the + // while loop. This guarantees that the conflicting collectives cannot slip + // in between the peeled send/recv instructions where it could cause a + // deadlock. + HloInstruction* done_op = FindSendRecvDoneInstruction(peeled_instr); + TF_RETURN_IF_ERROR( + AddControlDependencies(done_op, unpeeled_conflicting_collectives)); + } + + return absl::OkStatus(); +} + absl::StatusOr GpuP2PPipeliner::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { @@ -249,10 +365,14 @@ absl::StatusOr GpuP2PPipeliner::Run( // If partial send/recv pipelining is enabled, collect send/recv instructions // for post-processing. 
+ std::vector peeled_send_recvs; std::vector rotated_send_recvs; if (enable_partial_send_recv_pipelining_) { should_process = PipelineOnlySendRecvStart; - postprocess_backward_peeled_op = std::nullopt; + postprocess_backward_peeled_op = [&](HloInstruction* it) { + peeled_send_recvs.push_back(it); + return absl::OkStatus(); + }; postprocess_backward_rotated_op = [&](HloInstruction* it) { rotated_send_recvs.push_back(it); return absl::OkStatus(); @@ -283,6 +403,7 @@ absl::StatusOr GpuP2PPipeliner::Run( // Post-process rotated and peeled send/recv ops to add control dependencies // with conflicting collectives. TF_RETURN_IF_ERROR(PostProcessRotatedSendRecvOps(rotated_send_recvs)); + TF_RETURN_IF_ERROR(PostProcessPeeledSendRecvOps(peeled_send_recvs)); return changed; } diff --git a/xla/service/gpu/gpu_p2p_pipeliner_test.cc b/xla/service/gpu/gpu_p2p_pipeliner_test.cc index 5ce5ecfafe3a2..9e65016ad5211 100644 --- a/xla/service/gpu/gpu_p2p_pipeliner_test.cc +++ b/xla/service/gpu/gpu_p2p_pipeliner_test.cc @@ -34,6 +34,8 @@ limitations under the License. 
#include "xla/hlo/testlib/filecheck.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" @@ -41,6 +43,7 @@ namespace xla { namespace gpu { namespace { +namespace m = xla::match; using ::testing::UnorderedElementsAre; class GpuP2PPipelinerTest : public HloTestBase { @@ -478,6 +481,101 @@ TEST_F(GpuP2PPipelinerTest, OneSendRecvWithOneConflictingAllReduce) { UnorderedElementsAre(send_done_op)); } +TEST_F(GpuP2PPipelinerTest, + OneSendRecvWithConflictingAllReduceBeforeAndAfterLoop) { + const char* kHloStr = R"( + HloModule test + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + cond { + param = (u32[], f32[64]) parameter(0) + i = u32[] get-tuple-element(param), index=0 + n = u32[] constant(2) + ROOT result = pred[] compare(i, n), direction=LT + } + + body { + param = (u32[], f32[64]) parameter(0) + i = u32[] get-tuple-element(param), index=0 + data = f32[64] get-tuple-element(param), index=1 + + // Decomposed cp_fwd. 
+ after-all = token[] after-all() + recv = (f32[64], u32[], token[]) recv(after-all), channel_id=1, + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + send = (f32[64], u32[], token[]) send(data, after-all), channel_id=1, + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, + control-predecessors={recv} + recv_done = (f32[64], token[]) recv-done(recv), channel_id=1, + control-predecessors={send} + send_done = token[] send-done(send), channel_id=1, + control-predecessors={recv_done} + recv_data = f32[64] get-tuple-element(recv_done), index=0 + + + c1 = u32[] constant(1) + i_ = u32[] add(u32[] i, u32[] c1) + + ROOT result = (u32[], f32[64]) tuple(i_, recv_data) + } + + ENTRY entry { + c0 = u32[] constant(0) + a = f32[] constant(42) + data = f32[64] broadcast(a), dimensions={} + + // Conflicting all-reduce before loop. + ar = f32[64] all-reduce(data), channel_id=2, replica_groups={{0,1,2,3}}, + to_apply=add + + while_init = (u32[], f32[64]) tuple(c0, ar) + result = (u32[], f32[64]) while(while_init), condition=cond, + body=body + + // Conflicting all-reduce after loop. + while_dep_data = f32[64] get-tuple-element(result), index=1 + ROOT final_ar = f32[64] all-reduce(while_dep_data), channel_id=3, + replica_groups={{0,1,2,3}}, to_apply=add + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kHloStr, config_)); + + // Run pass. + TF_ASSERT_OK_AND_ASSIGN( + bool changed, RunOptimizer(module.get(), + /*enable_partial_send_recv_pipelining=*/true)); + EXPECT_TRUE(changed); + + // Find ops around the while loop. 
+ HloInstruction* ar_op = FindInstruction(module.get(), "ar"); + HloInstruction* recv_op = FindInstruction(module.get(), "recv.1"); + HloInstruction* send_op = FindInstruction(module.get(), "send.1"); + HloInstruction* while_op = FindInstruction(module.get(), "while"); + HloInstruction* recv_done_op = FindInstruction(module.get(), "recv_done.2"); + HloInstruction* send_done_op = FindInstruction(module.get(), "send_done.2"); + HloInstruction* final_ar_op = FindInstruction(module.get(), "final_ar"); + EXPECT_THAT(while_op, GmockMatch(m::While(m::Tuple( + m::Op(), m::Op().Is(ar_op), m::Op().Is(recv_op), + m::Op().Is(send_op), m::Op())))); + + // Expect control dependencies from conflicting all-reduce before the while + // loop to send/recv, expect control dependencies from send/recv-done to + // conflicting all-reduce after the loop. + EXPECT_THAT(recv_op->control_predecessors(), UnorderedElementsAre(ar_op)); + EXPECT_THAT(send_op->control_predecessors(), + UnorderedElementsAre(recv_op, ar_op)); + EXPECT_THAT(send_done_op->control_predecessors(), + UnorderedElementsAre(recv_done_op)); + EXPECT_THAT(final_ar_op->control_predecessors(), + UnorderedElementsAre(recv_done_op, send_done_op)); +} + } // namespace } // namespace gpu } // namespace xla