Skip to content

Commit

Permalink
[MHLO->StableHLO] Allow MHLO with XLA features to be partially import…
Browse files Browse the repository at this point in the history
…ed to StableHLO+CHLO

PiperOrigin-RevId: 726647695
  • Loading branch information
GleasonK authored and Google-ML-Automation committed Feb 13, 2025
1 parent d73250c commit 789ff04
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,27 @@ namespace {

struct HloLegalizeToStablehloPass
: public impl::HloLegalizeToStablehloPassBase<HloLegalizeToStablehloPass> {
HloLegalizeToStablehloPass()
: HloLegalizeToStablehloPassBase<HloLegalizeToStablehloPass>() {}
explicit HloLegalizeToStablehloPass(
const HloLegalizeToStablehloPassOptions& opts)
: HloLegalizeToStablehloPassBase<HloLegalizeToStablehloPass>(opts) {}

void runOnOperation() override {
ConversionTarget target(getContext());
target.addIllegalDialect<mhlo::MhloDialect>();
target.addLegalDialect<stablehlo::StablehloDialect>();

if (allow_xla_features_) {
// These ops do not exist in StableHLO.
target.addLegalOp<
mhlo::AddDependencyOp, mhlo::AsyncDoneOp, mhlo::AsyncStartOp,
mhlo::AsyncUpdateOp, mhlo::BitcastOp, mhlo::CopyOp, mhlo::DomainOp,
mhlo::ErfOp, mhlo::FusionOp, mhlo::MinimumBroadcastShapesOp,
mhlo::RaggedDotOp, mhlo::SparseDotOp, mhlo::StochasticConvertOp,
mhlo::TopKOp, mhlo::TraceOp, mhlo::XlaRngGetAndUpdateStateOp>();
}

stablehlo::HloToStablehloTypeConverter converter;
RewritePatternSet patterns(&getContext());
stablehlo::populateHloToStablehloPatterns(
Expand All @@ -60,10 +76,5 @@ struct HloLegalizeToStablehloPass

} // namespace

std::unique_ptr<mlir::OperationPass<ModuleOp>>
createHloLegalizeToStablehloPass() {
return std::make_unique<HloLegalizeToStablehloPass>();
}

} // namespace mhlo
} // namespace mlir
7 changes: 5 additions & 2 deletions xla/mlir_hlo/mhlo/transforms/mhlo_passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,16 @@ def CollapseElementwiseMapPass

def HloLegalizeToStablehloPass : Pass<"hlo-legalize-to-stablehlo", "ModuleOp"> {
let summary = "Legalize HLO to StableHLO.";
let constructor = "createHloLegalizeToStablehloPass()";
let dependentDialects = ["stablehlo::StablehloDialect"];
let options = [
Option<"allow_experimental_features_", "allow-experimental-features",
"bool", /*default=*/"false",
"Allow legalization of experimental MHLO features via StableHLO "
"custom_call">
"custom_call">,
Option<"allow_xla_features_", "allow-xla-features", "bool",
/*default=*/"false",
"Allow XLA's MHLO ops not in StableHLO to remain present after "
"legalization (copy, add_dependency, fusion, etc.)">
];
}

Expand Down
3 changes: 0 additions & 3 deletions xla/mlir_hlo/mhlo/transforms/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,6 @@ std::unique_ptr<OperationPass<func::FuncOp>> createCollapseElementwiseMapPass();
// Pass to replace unsigned types with signless integers.
std::unique_ptr<OperationPass<ModuleOp>> createConvertToSignlessPass();

// Legalizes from the MHLO dialect to the StableHLO dialect.
std::unique_ptr<OperationPass<ModuleOp>> createHloLegalizeToStablehloPass();

// Legalizes from the StableHLO dialect to the MHLO dialect.
std::unique_ptr<OperationPass<ModuleOp>> createStablehloLegalizeToHloPass();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: mlir-hlo-opt %s -hlo-legalize-to-stablehlo=allow-xla-features --split-input-file | FileCheck %s

func.func @async_ops(%arg1: tensor<128x32xf32>) -> tensor<128x128xf32> attributes {execution_thread = "main"} {
// CHECK: stablehlo.all_gather
%0 = "mhlo.all_gather"(%arg1) {
all_gather_dim = 1 : i64,
channel_handle = #mhlo.channel_handle<handle = 1, type = 0>,
shard_count = 4,
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
use_global_device_ids
} : (tensor<128x32xf32>) -> tensor<128x128xf32>
return %0 : tensor<128x128xf32>
}

func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> {
// CHECK: mhlo.async_start
%0 = "mhlo.async_start"(%arg0) {called_computation = @async_ops, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle<tensor<128x32xf32>, tensor<128x128xf32>>
// CHECK: mhlo.async_done
%1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle<tensor<128x32xf32>, tensor<128x128xf32>>) -> tensor<128x128xf32>
return %1 : tensor<128x128xf32>
}

// -----

// CHECK-LABEL: func @copy
func.func @copy() -> tensor<2x1xi32> {
// CHECK: %[[CST:.*]] = stablehlo.constant dense<{{.*}}> : tensor<2x1xi32>
// CHECK: %[[COPY:.*]] = mhlo.copy %[[CST]] : tensor<2x1xi32>
%0 = mhlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
%1 = "mhlo.copy"(%0) : (tensor<2x1xi32>) -> tensor<2x1xi32>

// CHECK: return %[[COPY]]
func.return %1 : tensor<2x1xi32>
}

0 comments on commit 789ff04

Please sign in to comment.