From 1c9c163e287162acfabdc00723f7e537486274e1 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Wed, 12 Feb 2025 13:08:17 -0800 Subject: [PATCH] Add test with two loops PiperOrigin-RevId: 726168257 --- xla/service/gpu/gpu_p2p_pipeliner_test.cc | 126 ++++++++++++++++++++++ 1 file changed, 126 insertions(+) diff --git a/xla/service/gpu/gpu_p2p_pipeliner_test.cc b/xla/service/gpu/gpu_p2p_pipeliner_test.cc index 9e65016ad5211..a8f85039d6146 100644 --- a/xla/service/gpu/gpu_p2p_pipeliner_test.cc +++ b/xla/service/gpu/gpu_p2p_pipeliner_test.cc @@ -576,6 +576,132 @@ TEST_F(GpuP2PPipelinerTest, UnorderedElementsAre(recv_done_op, send_done_op)); } +TEST_F(GpuP2PPipelinerTest, TwoLoopsWithConflictingAllReduces) { + const char* kHloStr = R"( + HloModule test + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + cond { + param = (u32[], f32[64]) parameter(0) + i = u32[] get-tuple-element(param), index=0 + n = u32[] constant(2) + ROOT result = pred[] compare(i, n), direction=LT + } + + body { + param = (u32[], f32[64]) parameter(0) + i = u32[] get-tuple-element(param), index=0 + data = f32[64] get-tuple-element(param), index=1 + + // Decomposed cp_fwd. 
+ after-all = token[] after-all() + recv = (f32[64], u32[], token[]) recv(after-all), channel_id=1, + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + send = (f32[64], u32[], token[]) send(data, after-all), channel_id=2, + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, + control-predecessors={recv} + recv_done = (f32[64], token[]) recv-done(recv), channel_id=1, + control-predecessors={send} + send_done = token[] send-done(send), channel_id=2, + control-predecessors={recv_done} + recv_data = f32[64] get-tuple-element(recv_done), index=0 + + + c1 = u32[] constant(1) + i_ = u32[] add(u32[] i, u32[] c1) + + ROOT result = (u32[], f32[64]) tuple(i_, recv_data) + } + + ENTRY entry { + c0 = u32[] constant(0) + a = f32[] constant(42) + data = f32[64] broadcast(a), dimensions={} + + // Conflicting all-reduce before loop. + ar = f32[64] all-reduce(data), channel_id=3, replica_groups={{0,1,2,3}}, + to_apply=add + + while_a_init = (u32[], f32[64]) tuple(c0, ar) + while_a = (u32[], f32[64]) while(while_a_init), condition=cond, body=body + + // Conflicting all-reduce after loop. + while_a_dep_data = f32[64] get-tuple-element(while_a), index=1 + sandwitched_ar = f32[64] all-reduce(while_a_dep_data), channel_id=4, + replica_groups={{0,1,2,3}}, to_apply=add + + while_b_init = (u32[], f32[64]) tuple(c0, sandwitched_ar) + while_b = (u32[], f32[64]) while(while_b_init), condition=cond, body=body + + // Conflicting all-reduce after loop. + while_b_dep_data = f32[64] get-tuple-element(while_b), index=1 + ROOT final_ar = f32[64] all-reduce(while_b_dep_data), channel_id=5, + replica_groups={{0,1,2,3}}, to_apply=add + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kHloStr, config_)); + + // Run pass. + TF_ASSERT_OK_AND_ASSIGN( + bool changed, RunOptimizer(module.get(), + /*enable_partial_send_recv_pipelining=*/true)); + EXPECT_TRUE(changed); + + // Find ops around the while loops. 
+ HloInstruction* ar_op = FindInstruction(module.get(), "ar"); + HloInstruction* recv_a_op = FindInstruction(module.get(), "recv.1"); + HloInstruction* send_a_op = FindInstruction(module.get(), "send.1"); + HloInstruction* recv_done_a_op = FindInstruction(module.get(), "recv_done.2"); + HloInstruction* send_done_a_op = FindInstruction(module.get(), "send_done.2"); + HloInstruction* sandwitched_ar_op = + FindInstruction(module.get(), "sandwitched_ar"); + HloInstruction* recv_b_op = FindInstruction(module.get(), "recv.3"); + HloInstruction* send_b_op = FindInstruction(module.get(), "send.3"); + HloInstruction* recv_done_b_op = FindInstruction(module.get(), "recv_done.4"); + HloInstruction* send_done_b_op = FindInstruction(module.get(), "send_done.4"); + HloInstruction* final_ar_op = FindInstruction(module.get(), "final_ar"); + + // Find the two while loops. + HloInstruction* while_a_op = FindInstruction(module.get(), "while"); + HloInstruction* while_b_op = FindInstruction(module.get(), "while.1"); + + // Assert relation between send/recv ops and while loops. + EXPECT_THAT(while_a_op, GmockMatch(m::While(m::Tuple( + m::Op(), m::Op().Is(ar_op), m::Op().Is(recv_a_op), + m::Op().Is(send_a_op), m::Op())))); + EXPECT_THAT(while_b_op, + GmockMatch(m::While(m::Tuple( + m::Op(), m::Op().Is(sandwitched_ar_op), m::Op().Is(recv_b_op), + m::Op().Is(send_b_op), m::Op())))); + + // Expect control dependencies between loop-dominating conflicting + // collectives and peeled send/recv ops. Also, expect control dependencies + // between corresponding send/recv-done ops and all other conflicting + // collectives. 
+ EXPECT_THAT(recv_a_op->control_predecessors(), UnorderedElementsAre(ar_op)); + EXPECT_THAT(send_a_op->control_predecessors(), + UnorderedElementsAre(ar_op, recv_a_op)); + EXPECT_THAT(send_done_a_op->control_predecessors(), + UnorderedElementsAre(recv_done_a_op)); + EXPECT_THAT(sandwitched_ar_op->control_predecessors(), + UnorderedElementsAre(recv_done_a_op, send_done_a_op)); + EXPECT_THAT(recv_b_op->control_predecessors(), + UnorderedElementsAre(ar_op, sandwitched_ar_op)); + EXPECT_THAT(send_b_op->control_predecessors(), + UnorderedElementsAre(ar_op, sandwitched_ar_op, recv_b_op)); + EXPECT_THAT(send_done_b_op->control_predecessors(), + UnorderedElementsAre(recv_done_b_op)); + EXPECT_THAT(final_ar_op->control_predecessors(), + UnorderedElementsAre(send_done_b_op, recv_done_b_op, + recv_done_a_op, send_done_a_op)); +} + } // namespace } // namespace gpu } // namespace xla