Skip to content

Commit

Permalink
[XLA] Don't check recursively for thread mismatches
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 725849380
  • Loading branch information
vsytch authored and Google-ML-Automation committed Feb 12, 2025
1 parent 8c3aea8 commit d6be12c
Showing 1 changed file with 7 additions and 37 deletions.
44 changes: 7 additions & 37 deletions xla/service/hlo_verifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,27 +106,6 @@ int64_t GetSubgroupSize(HloCollectiveInstruction* hlo,
}
}

absl::Status CheckNestedComputationThreadNameEqual(
const HloComputation* comp, bool skip_nested_async_op_check) {
for (const HloInstruction* instr : comp->instructions()) {
if (skip_nested_async_op_check && instr->IsAsynchronous()) {
continue;
}
for (const HloComputation* called_cmp : instr->called_computations()) {
if (called_cmp->execution_thread() != comp->execution_thread()) {
return Internal(
"Nested computations expects same computation's thread name: %s vs "
"%s, in called computation `%s` vs caller computation `%s`",
called_cmp->execution_thread(), comp->execution_thread(),
called_cmp->name(), comp->name());
}
TF_RETURN_IF_ERROR(CheckNestedComputationThreadNameEqual(
called_cmp, skip_nested_async_op_check));
}
}
return absl::OkStatus();
}

absl::Status CheckUnaryOpWithResultAccuracy(HloInstruction* unary) {
HloOpcode opcode = unary->opcode();
if (unary->has_result_accuracy()) {
Expand Down Expand Up @@ -1651,13 +1630,11 @@ absl::Status CheckAsyncOpComputationThreadName(const HloInstruction* async_op) {
HloOpcodeString(async_op->opcode()), async_execution_thread,
async_op->async_wrapped_computation()->execution_thread());
}
return CheckNestedComputationThreadNameEqual(
async_op->async_wrapped_computation(),
/*skip_nested_async_op_check=*/false);
return absl::OkStatus();
}

absl::Status CheckCallableInstructionThreadName(
const HloInstruction* instruction, bool skip_nested_async_op_check) {
const HloInstruction* instruction) {
for (const HloComputation* computation : instruction->called_computations()) {
if (instruction->parent() != nullptr) {
if (instruction->parent()->execution_thread() !=
Expand All @@ -1669,8 +1646,6 @@ absl::Status CheckCallableInstructionThreadName(
computation->execution_thread());
}
}
TF_RETURN_IF_ERROR(CheckNestedComputationThreadNameEqual(
computation, skip_nested_async_op_check));
}
return absl::OkStatus();
}
Expand Down Expand Up @@ -2806,8 +2781,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
}

absl::Status HandleFusion(HloInstruction* fusion) override {
TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(
fusion, /*skip_nested_async_op_check*/ false));
TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(fusion));
return CheckFusionInstruction(fusion);
}

Expand Down Expand Up @@ -2854,8 +2828,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
xla_while->operand_count(), xla_while->ToString());
}
// Allow kWhile to contain computations on separate thread.
TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(
xla_while, /*skip_nested_async_op_check=*/true));
TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(xla_while));

// Verify consistency of sharding of while instructions and related
// instructions (parameters, root) in its called computations.
Expand All @@ -2869,8 +2842,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {

absl::Status HandleCall(HloInstruction* call) override {
if (opts_.verify_call_nested_computation_thread_name) {
return CheckCallableInstructionThreadName(
call, /*skip_nested_async_op_check=*/true);
return CheckCallableInstructionThreadName(call);
}
return absl::OkStatus();
}
Expand All @@ -2893,8 +2865,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
branch_computation->root_instruction());
}
// Allow kConditional to contain computations on separate thread.
TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(
conditional, /*skip_nested_async_op_check=*/true));
TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(conditional));

// Verify consistency of sharding of conditional instructions and roots of
// its branches.
Expand Down Expand Up @@ -2955,8 +2926,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
absl::Status HandleCustomCall(HloInstruction* hlo) override {
if (opts_.verify_call_nested_computation_thread_name) {
// Allow kCustomCall to contain computations on separate thread.
return CheckCallableInstructionThreadName(
hlo, /*skip_nested_async_op_check=*/true);
return CheckCallableInstructionThreadName(hlo);
}
return absl::OkStatus();
}
Expand Down

0 comments on commit d6be12c

Please sign in to comment.