Skip to content

Commit

Permalink
Make hlo-opt return error status when given an invalid --passes a…
Browse files Browse the repository at this point in the history
…rgument.

`hlo-opt` previously logged an error in this case but did not indicate an error in its exit status.

Note that the program now exits immediately upon encountering an invalid `--passes` argument; it previously logged an error and continued executing.

Fixing this bug also revealed that the `algebraic_simplifier.hlo` test file wasn't doing anything because the `AlgebraicSimplifier` pass wasn't registered in `hlo-opt`. This CL therefore also registers that pass and updates its test accordingly.

PiperOrigin-RevId: 726576856
  • Loading branch information
mrguenther authored and Google-ML-Automation committed Feb 13, 2025
1 parent 5d0a237 commit ba741cd
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 20 deletions.
1 change: 1 addition & 0 deletions xla/hlo/tools/hlo_opt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ cc_library(
"//xla/hlo/transforms/expanders:rng_expander",
"//xla/hlo/transforms/expanders:stable_sort_expander",
"//xla/hlo/transforms/expanders:stochastic_convert_decomposer",
"//xla/hlo/transforms/simplifiers:algebraic_simplifier",
"//xla/hlo/transforms/simplifiers:all_reduce_folder",
"//xla/hlo/transforms/simplifiers:batch_dot_simplification",
"//xla/hlo/transforms/simplifiers:broadcast_canonicalizer",
Expand Down
10 changes: 6 additions & 4 deletions xla/hlo/tools/hlo_opt/opt_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ limitations under the License.
#include "xla/hlo/transforms/expanders/stable_sort_expander.h"
#include "xla/hlo/transforms/expanders/stochastic_convert_decomposer.h"
#include "xla/hlo/transforms/operand_upcaster.h"
#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h"
#include "xla/hlo/transforms/simplifiers/all_reduce_folder.h"
#include "xla/hlo/transforms/simplifiers/batch_dot_simplification.h"
#include "xla/hlo/transforms/simplifiers/broadcast_canonicalizer.h"
Expand Down Expand Up @@ -138,17 +139,17 @@ absl::StatusOr<std::optional<std::string>> OptProvider::GenerateStage(
return module->ToString();
}

absl::StatusOr<std::optional<std::string>>
OptProvider::BuildAndRunTransformPipeline(std::unique_ptr<HloModule> module,
const std::string& input_pass_names) {
absl::StatusOr<std::string> OptProvider::BuildAndRunTransformPipeline(
std::unique_ptr<HloModule> module, const std::string& input_pass_names) {
HloPassPipeline transforms_pipeline{"transforms_pipeline"};
for (const auto& pass_name :
std::vector<std::string>(absl::StrSplit(input_pass_names, ','))) {
auto it = pass_registry_.find(pass_name);
if (it != pass_registry_.end()) {
it->second(transforms_pipeline);
} else {
LOG(ERROR) << "Pass " << pass_name << " not found.";
return absl::InvalidArgumentError(
absl::StrCat("Pass ", pass_name, " not found."));
}
}
CHECK_OK(transforms_pipeline.Run(module.get(), {}));
Expand Down Expand Up @@ -186,6 +187,7 @@ void OptProvider::RegisterAllHardwareIndependentPasses() {
RegisterPass<BarToHelloModulePass>();
// Hardware-independent HLO passes
// go/keep-sorted start
RegisterPass<AlgebraicSimplifier>(AlgebraicSimplifierOptions());
RegisterPass<AllGatherBroadcastReorder>();
RegisterPass<AllReduceContiguous>();
RegisterPass<AllReduceFolder>();
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/tools/hlo_opt/opt_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class OptProvider {

// Runs input passes on a input module and returns the optimized module
// string.
absl::StatusOr<std::optional<std::string>> BuildAndRunTransformPipeline(
absl::StatusOr<std::string> BuildAndRunTransformPipeline(
std::unique_ptr<HloModule> input_module,
const std::string& input_pass_names);

Expand Down
6 changes: 4 additions & 2 deletions xla/hlo/transforms/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ lit_test_suite(
name = "hlo_opt_tests",
srcs = enforce_glob(
[
"run_single_pass.hlo",
"run_multiple_passes.hlo",
"algebraic_simplifier.hlo",
"hlo_opt_expect_failure.hlo",
"run_multiple_passes.hlo",
"run_single_pass.hlo",
],
include = [
"*.hlo",
Expand All @@ -35,5 +36,6 @@ lit_test_suite(
tools = [
"//xla/hlo/tools:hlo-opt",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)
30 changes: 17 additions & 13 deletions xla/hlo/transforms/tests/algebraic_simplifier.hlo
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
// RUN: hlo-opt %s --passes=algebraic_simplifier | FileCheck %s
// NOTE: Assertions have been autogenerated by hlo/tools/generate_hlo_test_checks.py
// RUN: hlo-opt %s --passes=algsimp | FileCheck %s

// CHECK-LABEL: HloModule m, entry_computation_layout={(s32[8]{0}, s32[8]{0}, s32[8]{0})->s32[8]{0}}

// CHECK-LABEL: ENTRY %test
// CHECK-NEXT: %[[p0:[^ ]+]] = s32[8]{0} parameter(0)
// CHECK-NEXT: %[[p1:[^ ]+]] = s32[8]{0} parameter(1)
// CHECK-NEXT: %[[add:[^ ]+]] = s32[8]{0} add(s32[8]{0} %[[p0]], s32[8]{0} %[[p1]])
// CHECK-NEXT: %[[p2:[^ ]+]] = s32[8]{0} parameter(2)
// CHECK-NEXT: ROOT %[[multiply:[^ ]+]] = s32[8]{0} multiply(s32[8]{0} %[[add]], s32[8]{0} %[[p2]])

HloModule m
ENTRY test {
// CHECK: %[[p0:.*]] = s32[8]{0} parameter(0)
// CHECK-NEXT: %[[p2:.*]] = s32[8]{0} parameter(2)
// CHECK-NEXT: %[[x:.*]] = s32[8]{0} multiply(s32[8]{0} %[[p0]], s32[8]{0} %[[p2]])
// CHECK-NEXT: %[[p1:.*]] = s32[8]{0} parameter(1)
// CHECK-NEXT: %[[y:.*]] = s32[8]{0} multiply(s32[8]{0} %[[p1]], s32[8]{0} %[[p2]])
// CHECK-NEXT: ROOT %[[sum:.*]] = s32[8]{0} add(s32[8]{0} %[[x]], s32[8]{0} %[[y]])
p0 = s32[8] parameter(0)
p1 = s32[8] parameter(1)
p2 = s32[8] parameter(2)
x = s32[8] multiply(p0, p2)
y = s32[8] multiply(p1, p2)
ROOT sum = s32[8] add(x, y)
p0 = s32[8] parameter(0)
p1 = s32[8] parameter(1)
p2 = s32[8] parameter(2)
x = s32[8] multiply(p0, p2)
y = s32[8] multiply(p1, p2)
ROOT sum = s32[8] add(x, y)
}
10 changes: 10 additions & 0 deletions xla/hlo/transforms/tests/hlo_opt_expect_failure.hlo
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: not hlo-opt %s --passes=nonexistent-optimization-pass 2>&1 \
// RUN: | FileCheck %s

// CHECK: INVALID_ARGUMENT: Pass nonexistent-optimization-pass not found.

HloModule NoOpModule

ENTRY no_op {
ROOT no_op = () tuple()
}

0 comments on commit ba741cd

Please sign in to comment.