Skip to content

Commit

Permalink
Add test with two loops
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726168257
  • Loading branch information
frgossen authored and Google-ML-Automation committed Feb 12, 2025
1 parent 803886b commit 1c9c163
Showing 1 changed file with 126 additions and 0 deletions.
126 changes: 126 additions & 0 deletions xla/service/gpu/gpu_p2p_pipeliner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,132 @@ TEST_F(GpuP2PPipelinerTest,
UnorderedElementsAre(recv_done_op, send_done_op));
}

// Runs the optimizer, with partial send/recv pipelining enabled, on a module
// that contains two sequential while loops sharing the same body. Each loop
// body holds a decomposed send/recv pair, and the entry computation places an
// all-reduce before the first loop, between the two loops, and after the
// second loop. The test then checks the operands of both rewritten while-init
// tuples and the control predecessors of the send/recv ops, send-done ops, and
// the all-reduces that follow each loop.
TEST_F(GpuP2PPipelinerTest, TwoLoopsWithConflictingAllReduces) {
  const char* kHloStr = R"(
HloModule test
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
cond {
param = (u32[], f32[64]) parameter(0)
i = u32[] get-tuple-element(param), index=0
n = u32[] constant(2)
ROOT result = pred[] compare(i, n), direction=LT
}
body {
param = (u32[], f32[64]) parameter(0)
i = u32[] get-tuple-element(param), index=0
data = f32[64] get-tuple-element(param), index=1
// Decomposed cp_fwd.
after-all = token[] after-all()
recv = (f32[64], u32[], token[]) recv(after-all), channel_id=1,
frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}
send = (f32[64], u32[], token[]) send(data, after-all), channel_id=2,
frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}},
control-predecessors={recv}
recv_done = (f32[64], token[]) recv-done(recv), channel_id=1,
control-predecessors={send}
send_done = token[] send-done(send), channel_id=2,
control-predecessors={recv_done}
recv_data = f32[64] get-tuple-element(recv_done), index=0
c1 = u32[] constant(1)
i_ = u32[] add(u32[] i, u32[] c1)
ROOT result = (u32[], f32[64]) tuple(i_, recv_data)
}
ENTRY entry {
c0 = u32[] constant(0)
a = f32[] constant(42)
data = f32[64] broadcast(a), dimensions={}
// Conflicting all-reduce before loop.
ar = f32[64] all-reduce(data), channel_id=3, replica_groups={{0,1,2,3}},
to_apply=add
while_a_init = (u32[], f32[64]) tuple(c0, ar)
while_a = (u32[], f32[64]) while(while_a_init), condition=cond, body=body
// Conflicting all-reduce after loop.
while_a_dep_data = f32[64] get-tuple-element(while_a), index=1
sandwitched_ar = f32[64] all-reduce(while_a_dep_data), channel_id=4,
replica_groups={{0,1,2,3}}, to_apply=add
while_b_init = (u32[], f32[64]) tuple(c0, sandwitched_ar)
while_b = (u32[], f32[64]) while(while_b_init), condition=cond, body=body
// Conflicting all-reduce after loop.
while_b_dep_data = f32[64] get-tuple-element(while_b), index=1
ROOT final_ar = f32[64] all-reduce(while_b_dep_data), channel_id=5,
replica_groups={{0,1,2,3}}, to_apply=add
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kHloStr, config_));

  // Run pass.
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed, RunOptimizer(module.get(),
                                 /*enable_partial_send_recv_pipelining=*/true));
  EXPECT_TRUE(changed);

  // Find ops around the while loops. The suffixed names (e.g. `recv.1`) do not
  // appear in the input HLO above; they are the names this test expects to
  // exist in the module after the run.
  HloInstruction* ar_op = FindInstruction(module.get(), "ar");
  HloInstruction* recv_a_op = FindInstruction(module.get(), "recv.1");
  HloInstruction* send_a_op = FindInstruction(module.get(), "send.1");
  HloInstruction* recv_done_a_op = FindInstruction(module.get(), "recv_done.2");
  HloInstruction* send_done_a_op = FindInstruction(module.get(), "send_done.2");
  HloInstruction* sandwitched_ar_op =
      FindInstruction(module.get(), "sandwitched_ar");
  HloInstruction* recv_b_op = FindInstruction(module.get(), "recv.3");
  HloInstruction* send_b_op = FindInstruction(module.get(), "send.3");
  HloInstruction* recv_done_b_op = FindInstruction(module.get(), "recv_done.4");
  HloInstruction* send_done_b_op = FindInstruction(module.get(), "send_done.4");
  HloInstruction* final_ar_op = FindInstruction(module.get(), "final_ar");

  // Find the two while loops.
  HloInstruction* while_a_op = FindInstruction(module.get(), "while");
  HloInstruction* while_b_op = FindInstruction(module.get(), "while.1");

  // Stop here with a readable failure if any lookup came back empty.
  // FindInstruction is expected to return nullptr when no instruction has the
  // requested name, and several of these pointers are dereferenced below; a
  // change in instruction naming would otherwise crash the test binary.
  ASSERT_NE(ar_op, nullptr);
  ASSERT_NE(recv_a_op, nullptr);
  ASSERT_NE(send_a_op, nullptr);
  ASSERT_NE(recv_done_a_op, nullptr);
  ASSERT_NE(send_done_a_op, nullptr);
  ASSERT_NE(sandwitched_ar_op, nullptr);
  ASSERT_NE(recv_b_op, nullptr);
  ASSERT_NE(send_b_op, nullptr);
  ASSERT_NE(recv_done_b_op, nullptr);
  ASSERT_NE(send_done_b_op, nullptr);
  ASSERT_NE(final_ar_op, nullptr);
  ASSERT_NE(while_a_op, nullptr);
  ASSERT_NE(while_b_op, nullptr);

  // Assert relation between send/recv ops and while loops: each while's init
  // tuple must carry the preceding all-reduce and the matching recv/send ops
  // at operand positions 1, 2 and 3.
  EXPECT_THAT(while_a_op, GmockMatch(m::While(m::Tuple(
                              m::Op(), m::Op().Is(ar_op), m::Op().Is(recv_a_op),
                              m::Op().Is(send_a_op), m::Op()))));
  EXPECT_THAT(while_b_op,
              GmockMatch(m::While(m::Tuple(
                  m::Op(), m::Op().Is(sandwitched_ar_op), m::Op().Is(recv_b_op),
                  m::Op().Is(send_b_op), m::Op()))));

  // Expect control dependencies between loop-dominating conflicting
  // collectives and peeled send/recv ops. Also, expect control dependencies
  // between corresponding send/recv-done ops and all other conflicting
  // collectives.
  EXPECT_THAT(recv_a_op->control_predecessors(), UnorderedElementsAre(ar_op));
  EXPECT_THAT(send_a_op->control_predecessors(),
              UnorderedElementsAre(ar_op, recv_a_op));
  EXPECT_THAT(send_done_a_op->control_predecessors(),
              UnorderedElementsAre(recv_done_a_op));
  EXPECT_THAT(sandwitched_ar_op->control_predecessors(),
              UnorderedElementsAre(recv_done_a_op, send_done_a_op));
  EXPECT_THAT(recv_b_op->control_predecessors(),
              UnorderedElementsAre(ar_op, sandwitched_ar_op));
  EXPECT_THAT(send_b_op->control_predecessors(),
              UnorderedElementsAre(ar_op, sandwitched_ar_op, recv_b_op));
  EXPECT_THAT(send_done_b_op->control_predecessors(),
              UnorderedElementsAre(recv_done_b_op));
  EXPECT_THAT(final_ar_op->control_predecessors(),
              UnorderedElementsAre(send_done_b_op, recv_done_b_op,
                                   recv_done_a_op, send_done_a_op));
}

} // namespace
} // namespace gpu
} // namespace xla

0 comments on commit 1c9c163

Please sign in to comment.