Skip to content

Commit

Permalink
[XLA] Relax HLO verifier restriction on channel ids since they are meaningless in spmd programs.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 726156241
  • Loading branch information
blakehechtman authored and Google-ML-Automation committed Feb 12, 2025
1 parent a40dbd6 commit df9764a
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 31 deletions.
27 changes: 0 additions & 27 deletions xla/service/hlo_verifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2511,9 +2511,6 @@ absl::Status VerifyOriginalValue(const HloModule& module) {
// collectives).
absl::Status VerifyChannels(const HloModule& module,
const HloVerifierOpts& opts) {
absl::flat_hash_map<int64_t, std::vector<const HloInstruction*>>
channel_instructions;

// Send/recv instruction must have a unique user. If it is the corresponding
// send-done/recv-done operation, channel IDs must match.
for (const HloComputation* computation : module.computations()) {
Expand All @@ -2522,7 +2519,6 @@ absl::Status VerifyChannels(const HloModule& module,
if (!channel_instr || !channel_instr->channel_id()) {
continue;
}
channel_instructions[*channel_instr->channel_id()].push_back(instruction);

switch (instruction->opcode()) {
case HloOpcode::kSend: {
Expand Down Expand Up @@ -2565,29 +2561,6 @@ absl::Status VerifyChannels(const HloModule& module,
}
}

// Iterate over each channel to check invariants.
for (auto& [channel_id, instructions] : channel_instructions) {
const HloInstruction* first = instructions[0];
if (const auto* sendrecv = DynCast<HloSendRecvInstruction>(first)) {
absl::flat_hash_set<HloOpcode> opcodes;
for (const HloInstruction* instr : instructions) {
opcodes.insert(instr->opcode());
auto cast = DynCast<HloSendRecvInstruction>(instr);
TF_RET_CHECK(cast != nullptr)
<< "channel " << channel_id
<< " is used for different types of channel instructions";
}
} else {
if (opts.verify_unique_channel_ids) {
for (const HloInstruction* instr : instructions) {
TF_RET_CHECK(first->opcode() == instr->opcode())
<< "channel " << channel_id
<< " is used for different types of channel instructions";
}
}
}
}

return absl::OkStatus();
}

Expand Down
6 changes: 2 additions & 4 deletions xla/service/hlo_verifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2325,8 +2325,7 @@ TEST_F(HloVerifierTest, ChannelVerifier) {
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kModuleStr));
EXPECT_THAT(verifier().Run(module.get()).status().message(),
HasSubstr("used for different types of channel instructions"));
TF_ASSERT_OK(verifier().Run(module.get()));
}

TEST_F(HloVerifierTest, ChannelVerifierPartiallyPipelinedAsyncRecv) {
Expand Down Expand Up @@ -2527,8 +2526,7 @@ TEST_F(HloVerifierTest, CollectiveChannelVerifier) {
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kModuleStr));
EXPECT_THAT(verifier().Run(module.get()).status().message(),
HasSubstr("used for different types of channel instructions"));
TF_ASSERT_OK(verifier().Run(module.get()));
}

TEST_F(HloVerifierTestLayoutSensitive, CollectivePermuteStartAndDone) {
Expand Down

0 comments on commit df9764a

Please sign in to comment.