Skip to content

Commit

Permalink
[hlo-translate] Accept VHLO in hlo-translate tool
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726606437
  • Loading branch information
GleasonK authored and Google-ML-Automation committed Feb 13, 2025
1 parent 0384136 commit e19d695
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 3 deletions.
1 change: 1 addition & 0 deletions xla/hlo/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ xla_cc_binary(
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TranslateLib",
"@stablehlo//:stablehlo_passes",
"@tsl//tsl/platform:protobuf",
],
)
Expand Down
11 changes: 11 additions & 0 deletions xla/hlo/tools/hlo_translate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h"
#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "stablehlo/transforms/Passes.h"
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/parser/hlo_parser.h"
Expand Down Expand Up @@ -188,6 +189,16 @@ static mlir::OwningOpRef<mlir::ModuleOp> HloToMlirTranslate(

static mlir::LogicalResult MlirToHloTranslate(mlir::ModuleOp mlir_module,
llvm::raw_ostream& output) {
// Also support portable artifacts in tooling, no-op if module is already
// StableHLO.
mlir::PassManager pm(mlir_module.getContext());
mlir::stablehlo::createStablehloDeserializePipeline(pm);
if (failed(pm.run(mlir_module))) {
mlir_module->emitError("Failed to deserialize StableHLO");
return mlir::failure();
}

// Convert to HLO
auto hlo_module_or_status = xla::ConvertStablehloToHlo(mlir_module);
if (!hlo_module_or_status.ok()) {
mlir_module->emitError(hlo_module_or_status.status().message());
Expand Down
10 changes: 7 additions & 3 deletions xla/hlo/translate/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,24 @@ lit_test_suite(
name = "all_tests",
srcs = enforce_glob(
[
"simple.mlir",
"emit_mhlo.hlo",
"emit_proto.mlir",
"print_large_constants.mlir",
"print_layouts.mlir",
"simple.hlo",
"emit_mhlo.hlo",
"simple.mlir",
"vhlo_input.mlir",
],
include = [
"*.mlir",
"*.hlo",
],
),
cfg = "//xla:lit.cfg.py",
data = [":test_utilities"],
data = [
"vhlo_input.mlir.bc",
":test_utilities",
],
tools = [
"//xla/hlo/tools:hlo-translate",
"@llvm-project//llvm:FileCheck",
Expand Down
18 changes: 18 additions & 0 deletions xla/hlo/translate/tests/vhlo_input.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: hlo-translate -mlir-to-hlo %s.bc | FileCheck %s

// File `vhlo_input.mlir.bc` is created by running the following command:
// $ stablehlo-translate --serialize --target=1.0.0 --strip-debuginfo vhlo_input.mlir > vhlo_input.mlir.bc
//
// The `.mlir.bc` file is used in the above RUN command, along with the
// filechecks specified in this file.

// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[]
func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
// CHECK: %Arg_0.1 = f32[4] parameter(0)
// CHECK: %Arg_1.2 = f32[4] parameter(1)
// CHECK: %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
%0 = stablehlo.add %arg0, %arg1 : tensor<4xf32>
// CHECK: ROOT %dot.4 = f32[] dot(f32[4] %add.3, f32[4] %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}
%1 = stablehlo.dot %0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
func.return %1 : tensor<f32>
}
Binary file added xla/hlo/translate/tests/vhlo_input.mlir.bc
Binary file not shown.

0 comments on commit e19d695

Please sign in to comment.