Add control dependencies for peeled send/recv
For send/recv we have to ensure that they are not pipelined beyond any conflicting collective.

PiperOrigin-RevId: 726039862
frgossen authored and Google-ML-Automation committed Feb 12, 2025
1 parent e7e6f6a commit 0c852dc
Showing 4 changed files with 229 additions and 1 deletion.
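Sketch of the intended effect, based on the new test added in this commit (instruction names and shapes are illustrative, not the exact output of the pass): a conflicting collective before the loop becomes a control predecessor of the peeled send/recv, and a conflicting collective after the loop gets the peeled send/recv-done ops as control predecessors.

  // Conflicting collective before the loop must complete first.
  ar = f32[64] all-reduce(data), ...
  recv = (f32[64], u32[], token[]) recv(after-all), control-predecessors={ar}
  send = (f32[64], u32[], token[]) send(data, after-all), control-predecessors={recv, ar}
  while_init = (...) tuple(c0, ar, recv, send, ...)
  result = (...) while(while_init), condition=cond, body=body
  recv_done = (f32[64], token[]) recv-done(get-tuple-element(result), ...), ...
  send_done = token[] send-done(get-tuple-element(result), ...), ...
  // Conflicting collective after the loop must wait for the peeled done ops.
  final_ar = f32[64] all-reduce(...), control-predecessors={recv_done, send_done}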
6 changes: 6 additions & 0 deletions xla/service/collective_conflict_analysis.h
@@ -54,6 +54,12 @@ std::vector<HloInstruction*> FindAllConflictingCollectives(
const HloComputation* computation,
const std::vector<HloInstruction*>& seed_collectives);

inline std::vector<HloInstruction*> FindAllConflictingCollectives(
HloInstruction* seed_collective) {
return FindAllConflictingCollectives(seed_collective->parent(),
{seed_collective});
}

} // namespace xla

#endif // XLA_SERVICE_COLLECTIVE_CONFLICT_ANALYSIS_H_
3 changes: 3 additions & 0 deletions xla/service/gpu/BUILD
@@ -2242,11 +2242,14 @@ xla_cc_test(
"//xla/hlo/testlib:filecheck",
"//xla/service:hlo_module_config",
"//xla/service:hlo_verifier",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_googletest//:gtest",
"@tsl//tsl/platform:statusor",
],
)

123 changes: 122 additions & 1 deletion xla/service/gpu/gpu_p2p_pipeliner.cc
@@ -27,7 +27,9 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/pass/hlo_pass_pipeline.h"
@@ -238,6 +240,120 @@ static absl::Status PostProcessRotatedSendRecvOps(
return absl::OkStatus();
}

// For a peeled send/recv instruction, find the corresponding send/recv-done
// instruction after the while loop.
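// The done op either consumes the peeled send/recv directly, or the result is
// threaded through the surrounding while loop:
//   send/recv -> tuple -> while -> get-tuple-element -> send/recv-done.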
static HloInstruction* FindSendRecvDoneInstruction(HloInstruction* instr) {
CHECK(instr->opcode() == HloOpcode::kRecv ||
instr->opcode() == HloOpcode::kSend);
CHECK_EQ(instr->user_count(), 1);
HloInstruction* candidate = instr->users().front();
if (candidate->opcode() == HloOpcode::kTuple) {
HloInstruction* tuple_op = candidate;
int64_t i = tuple_op->operand_index(instr);
CHECK_EQ(tuple_op->user_count(), 1);
HloInstruction* while_op = tuple_op->users().front();
CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
for (HloInstruction* user : while_op->users()) {
HloGetTupleElementInstruction* gte_op =
DynCast<HloGetTupleElementInstruction>(user);
if (gte_op == nullptr || gte_op->tuple_index() != i) continue;
CHECK_EQ(gte_op->user_count(), 1);
candidate = gte_op->users().front();
break;
}
}
CHECK(candidate->opcode() == HloOpcode::kRecvDone ||
candidate->opcode() == HloOpcode::kSendDone);
return candidate;
}

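// Adds control dependencies from every instruction in `from_instructions` to
// `to_instr`.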
static absl::Status AddControlDependencies(
std::vector<HloInstruction*>& from_instructions, HloInstruction* to_instr) {
for (HloInstruction* from_instr : from_instructions) {
TF_RETURN_IF_ERROR(from_instr->AddControlDependencyTo(to_instr));
}
return absl::OkStatus();
}

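// Adds control dependencies from `from_instr` to every instruction in
// `to_instructions`.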
static absl::Status AddControlDependencies(
HloInstruction* from_instr,
absl::flat_hash_set<HloInstruction*>& to_instructions) {
for (HloInstruction* to_instr : to_instructions) {
TF_RETURN_IF_ERROR(from_instr->AddControlDependencyTo(to_instr));
}
return absl::OkStatus();
}

static absl::Status PostProcessPeeledSendRecvOps(
std::vector<HloInstruction*>& peeled_send_recvs) {
// Convert to set for faster lookup.
absl::flat_hash_set<HloInstruction*> peeled_send_recvs_set;
peeled_send_recvs_set.insert(peeled_send_recvs.begin(),
peeled_send_recvs.end());

// Add control dependencies between conflicting collectives and peeled
// send/recv ops.
for (HloInstruction* peeled_instr : peeled_send_recvs) {
CHECK(peeled_instr->opcode() == HloOpcode::kRecv ||
peeled_instr->opcode() == HloOpcode::kSend);

// Find all conflicting collectives that were not peeled out of the loop.
absl::flat_hash_set<HloInstruction*> unpeeled_conflicting_collectives;
for (HloInstruction* instr : FindAllConflictingCollectives(peeled_instr)) {
if (peeled_send_recvs_set.contains(instr)) continue;
unpeeled_conflicting_collectives.insert(instr);
}

// Find the while loop.
CHECK_EQ(peeled_instr->user_count(), 1);
HloInstruction* tuple_op = peeled_instr->users().front();
CHECK_EQ(tuple_op->opcode(), HloOpcode::kTuple);
CHECK_EQ(tuple_op->user_count(), 1);
HloInstruction* while_op = tuple_op->users().front();
CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);

// We separate unpeeled conflicting collectives into two categories:
// 1. Those that may dominate the while loop (the while loop may have a data
// dependency on them, `may_dominate_while_loop`).
// 2. Those that are known to not dominate the while loop (remaining
// instructions in `unpeeled_conflicting_collectives`).
std::vector<HloInstruction*> may_dominate_while_loop;
for (HloInstruction* instr :
while_op->parent()->MakeInstructionPostOrder()) {
// All instructions in post order that come after the while loop are known
// to not dominate it.
if (instr == while_op) {
break;
}
// If we're looking at an instruction that is an unpeeled conflicting
// collective, it is possible that it dominates the while loop. Move it
// into the first category set.
if (unpeeled_conflicting_collectives.contains(instr)) {
may_dominate_while_loop.push_back(instr);
unpeeled_conflicting_collectives.erase(instr);
}
}

// Add control dependencies from dominating conflicting collectives to the
// peeled send/recv instruction. This guarantees that the conflicting
// collectives cannot slip in between the peeled send/recv instructions,
// where they could cause a deadlock.
TF_RETURN_IF_ERROR(
AddControlDependencies(may_dominate_while_loop, peeled_instr));

// Add control dependencies from the final peeled send/recv-done
// instruction to the conflicting collectives that do not dominate the
// while loop. This guarantees that the conflicting collectives cannot slip
// in between the peeled send/recv instructions, where they could cause a
// deadlock.
HloInstruction* done_op = FindSendRecvDoneInstruction(peeled_instr);
TF_RETURN_IF_ERROR(
AddControlDependencies(done_op, unpeeled_conflicting_collectives));
}

return absl::OkStatus();
}

absl::StatusOr<bool> GpuP2PPipeliner::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
@@ -249,10 +365,14 @@ absl::StatusOr<bool> GpuP2PPipeliner::Run(

// If partial send/recv pipelining is enabled, collect send/recv instructions
// for post-processing.
std::vector<HloInstruction*> peeled_send_recvs;
std::vector<HloInstruction*> rotated_send_recvs;
if (enable_partial_send_recv_pipelining_) {
should_process = PipelineOnlySendRecvStart;
postprocess_backward_peeled_op = std::nullopt;
postprocess_backward_peeled_op = [&](HloInstruction* it) {
peeled_send_recvs.push_back(it);
return absl::OkStatus();
};
postprocess_backward_rotated_op = [&](HloInstruction* it) {
rotated_send_recvs.push_back(it);
return absl::OkStatus();
@@ -283,6 +403,7 @@ absl::StatusOr<bool> GpuP2PPipeliner::Run(
// Post-process rotated and peeled send/recv ops to add control dependencies
// with conflicting collectives.
TF_RETURN_IF_ERROR(PostProcessRotatedSendRecvOps(rotated_send_recvs));
TF_RETURN_IF_ERROR(PostProcessPeeledSendRecvOps(peeled_send_recvs));

return changed;
}
98 changes: 98 additions & 0 deletions xla/service/gpu/gpu_p2p_pipeliner_test.cc
@@ -34,13 +34,16 @@ limitations under the License.
#include "xla/hlo/testlib/filecheck.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_verifier.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/util.h"

namespace xla {
namespace gpu {
namespace {

namespace m = xla::match;
using ::testing::UnorderedElementsAre;

class GpuP2PPipelinerTest : public HloTestBase {
@@ -478,6 +481,101 @@ TEST_F(GpuP2PPipelinerTest, OneSendRecvWithOneConflictingAllReduce) {
UnorderedElementsAre(send_done_op));
}

TEST_F(GpuP2PPipelinerTest,
OneSendRecvWithConflictingAllReduceBeforeAndAfterLoop) {
const char* kHloStr = R"(
HloModule test
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
cond {
param = (u32[], f32[64]) parameter(0)
i = u32[] get-tuple-element(param), index=0
n = u32[] constant(2)
ROOT result = pred[] compare(i, n), direction=LT
}
body {
param = (u32[], f32[64]) parameter(0)
i = u32[] get-tuple-element(param), index=0
data = f32[64] get-tuple-element(param), index=1
// Decomposed cp_fwd.
after-all = token[] after-all()
recv = (f32[64], u32[], token[]) recv(after-all), channel_id=1,
frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}
send = (f32[64], u32[], token[]) send(data, after-all), channel_id=1,
frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}},
control-predecessors={recv}
recv_done = (f32[64], token[]) recv-done(recv), channel_id=1,
control-predecessors={send}
send_done = token[] send-done(send), channel_id=1,
control-predecessors={recv_done}
recv_data = f32[64] get-tuple-element(recv_done), index=0
c1 = u32[] constant(1)
i_ = u32[] add(u32[] i, u32[] c1)
ROOT result = (u32[], f32[64]) tuple(i_, recv_data)
}
ENTRY entry {
c0 = u32[] constant(0)
a = f32[] constant(42)
data = f32[64] broadcast(a), dimensions={}
// Conflicting all-reduce before loop.
ar = f32[64] all-reduce(data), channel_id=2, replica_groups={{0,1,2,3}},
to_apply=add
while_init = (u32[], f32[64]) tuple(c0, ar)
result = (u32[], f32[64]) while(while_init), condition=cond,
body=body
// Conflicting all-reduce after loop.
while_dep_data = f32[64] get-tuple-element(result), index=1
ROOT final_ar = f32[64] all-reduce(while_dep_data), channel_id=3,
replica_groups={{0,1,2,3}}, to_apply=add
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kHloStr, config_));

// Run pass.
TF_ASSERT_OK_AND_ASSIGN(
bool changed, RunOptimizer(module.get(),
/*enable_partial_send_recv_pipelining=*/true));
EXPECT_TRUE(changed);

// Find ops around the while loop.
HloInstruction* ar_op = FindInstruction(module.get(), "ar");
HloInstruction* recv_op = FindInstruction(module.get(), "recv.1");
HloInstruction* send_op = FindInstruction(module.get(), "send.1");
HloInstruction* while_op = FindInstruction(module.get(), "while");
HloInstruction* recv_done_op = FindInstruction(module.get(), "recv_done.2");
HloInstruction* send_done_op = FindInstruction(module.get(), "send_done.2");
HloInstruction* final_ar_op = FindInstruction(module.get(), "final_ar");
EXPECT_THAT(while_op, GmockMatch(m::While(m::Tuple(
m::Op(), m::Op().Is(ar_op), m::Op().Is(recv_op),
m::Op().Is(send_op), m::Op()))));

// Expect control dependencies from the conflicting all-reduce before the
// while loop to the peeled send/recv, and from the peeled send/recv-done to
// the conflicting all-reduce after the loop.
EXPECT_THAT(recv_op->control_predecessors(), UnorderedElementsAre(ar_op));
EXPECT_THAT(send_op->control_predecessors(),
UnorderedElementsAre(recv_op, ar_op));
EXPECT_THAT(send_done_op->control_predecessors(),
UnorderedElementsAre(recv_done_op));
EXPECT_THAT(final_ar_op->control_predecessors(),
UnorderedElementsAre(recv_done_op, send_done_op));
}

} // namespace
} // namespace gpu
} // namespace xla
