Skip to content

Commit

Permalink
[XLA] Add debug option for detecting cycles in fixed-point loops.
Browse files Browse the repository at this point in the history
Due to the way the "changed" signal is reported by passes within a fixed-point loop today, there are various scenarios in which a fixed-point loop that is "converged" may continue to run forever:

*  A composite pipeline is being run to fixed-point, and one pass exactly undoes the effect of another.
*  An individual pass falsely reports that it changed a module (perhaps because it undoes its own change).
*  The fixed-point loop sees the module go through a cycle of states.

While this check is too expensive to enable by default, it presents as a useful debug option. If we have reason to suspect one of the above scenarios is occurring, this option will allow us to identify the passes involved and address the root cause on an individual basis.

PiperOrigin-RevId: 726546024
  • Loading branch information
Google-ML-Automation committed Feb 13, 2025
1 parent ee4c408 commit 95cd71a
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 1 deletion.
6 changes: 6 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_unsupported_enable_ragged_all_to_all_decomposer(false);
opts.set_xla_gpu_experimental_pack_dot_operands_along_k_dimension(true);
opts.set_xla_unsupported_crash_on_hlo_pass_fix_max_iterations(false);
opts.set_xla_hlo_pass_fix_detect_cycles(false);
opts.set_xla_gpu_experimental_enable_sync_collective_combining(false);
return opts;
}
Expand Down Expand Up @@ -2237,6 +2238,11 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_unsupported_crash_on_hlo_pass_fix_max_iterations(),
"Crash if HloPassFix can not converge after a fixed number of "
"iterations."));
flag_list->push_back(tsl::Flag(
"xla_hlo_pass_fix_detect_cycles",
bool_setter_for(&DebugOptions::set_xla_hlo_pass_fix_detect_cycles),
debug_options->xla_hlo_pass_fix_detect_cycles(),
"Perform hash-based cycle detection in fixed-point loops."));
flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_enable_sync_collective_combining",
bool_setter_for(
Expand Down
1 change: 1 addition & 0 deletions xla/hlo/pass/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ cc_library(
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down
17 changes: 17 additions & 0 deletions xla/hlo/pass/hlo_pass_fix.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ limitations under the License.
#ifndef XLA_HLO_PASS_HLO_PASS_FIX_H_
#define XLA_HLO_PASS_HLO_PASS_FIX_H_

#include <cstddef>
#include <cstdint>
#include <type_traits>

#include "absl/container/flat_hash_set.h"
#include "absl/hash/hash.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -102,7 +104,21 @@ class HloPassFix : public Pass {
HloModule* module, RunState* run_state,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
VLOG(3) << "Running HloPassFix on " << Pass::name();

absl::flat_hash_set<size_t> hashes;
while (!run_state->changed_last_iteration.empty()) {
if (module->config().debug_options().xla_hlo_pass_fix_detect_cycles()) {
size_t hash = absl::HashOf(*module);
VLOG(3) << "Module hash for " << Pass::name() << " at iteration "
<< run_state->iteration << ": " << hash;
if (hashes.contains(hash)) {
LOG(WARNING) << "Cycle detected when running " << Pass::name()
<< " on iteration " << run_state->iteration
<< "; hash: " << hash;
} else {
hashes.insert(hash);
}
}
TF_RETURN_IF_ERROR(
RunOnChangedComputationsOnce(module, run_state, execution_threads));
VLOG(3) << Pass::name() << " iteration " << run_state->iteration
Expand All @@ -125,6 +141,7 @@ class HloPassFix : public Pass {
break;
}
}
VLOG(3) << "Finished running HloPassFix on " << Pass::name();
return absl::OkStatus();
}

Expand Down
5 changes: 4 additions & 1 deletion xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ message DebugOptions {
// XLA backend-agnostic options.
//--------------------------------------------------------------------------//
// go/keep-sorted start

// Perform hash-based cycle detection in fixed-point loops.
bool xla_hlo_pass_fix_detect_cycles = 370;
// Crash if HloPassFix can not converge after a fixed number of iterations.
bool xla_unsupported_crash_on_hlo_pass_fix_max_iterations = 363;
// go/keep-sorted end
Expand Down Expand Up @@ -1158,7 +1161,7 @@ message DebugOptions {

// Note: when adding a new flag, please add it to one of the hardware-specific
// or hardware-agnostic sections at the top of this proto message.
// Next id: 370
// Next id: 371

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
Expand Down

0 comments on commit 95cd71a

Please sign in to comment.