From 9ed71af4c5b1f4e29d80c5a0db20ae72366db24d Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Fri, 14 Feb 2025 12:35:33 -0800 Subject: [PATCH] Handling incompatible operand type during HLO -> Mhlo conversion. PiperOrigin-RevId: 727026128 --- xla/hlo/tools/hlo_opt/BUILD | 1 + xla/hlo/tools/hlo_opt/opt_lib.cc | 2 + .../hlo_to_mhlo/hlo_function_importer.cc | 38 +++++++++++++++++++ xla/hlo/translate/hlo_to_mhlo/tests/BUILD | 1 + .../tests/operand_convert_for_convolution.hlo | 36 ++++++++++++++++++ 5 files changed, 78 insertions(+) create mode 100644 xla/hlo/translate/hlo_to_mhlo/tests/operand_convert_for_convolution.hlo diff --git a/xla/hlo/tools/hlo_opt/BUILD b/xla/hlo/tools/hlo_opt/BUILD index 0b9a9215a741b..9f2e6209fbe6c 100644 --- a/xla/hlo/tools/hlo_opt/BUILD +++ b/xla/hlo/tools/hlo_opt/BUILD @@ -62,6 +62,7 @@ cc_library( "//xla/hlo/transforms/simplifiers:broadcast_canonicalizer", "//xla/hlo/transforms/simplifiers:conditional_canonicalizer", "//xla/hlo/transforms/simplifiers:convert_mover", + "//xla/hlo/transforms/simplifiers:convert_operand_folding", "//xla/hlo/transforms/simplifiers:convolution_group_converter", "//xla/hlo/transforms/simplifiers:dynamic_dimension_simplifier", "//xla/hlo/transforms/simplifiers:flatten_call_graph", diff --git a/xla/hlo/tools/hlo_opt/opt_lib.cc b/xla/hlo/tools/hlo_opt/opt_lib.cc index e4baeff7e28c7..eb6d2af995739 100644 --- a/xla/hlo/tools/hlo_opt/opt_lib.cc +++ b/xla/hlo/tools/hlo_opt/opt_lib.cc @@ -65,6 +65,7 @@ limitations under the License. #include "xla/hlo/transforms/simplifiers/broadcast_canonicalizer.h" #include "xla/hlo/transforms/simplifiers/conditional_canonicalizer.h" #include "xla/hlo/transforms/simplifiers/convert_mover.h" +#include "xla/hlo/transforms/simplifiers/convert_operand_folder.h" #include "xla/hlo/transforms/simplifiers/convolution_group_converter.h" #include "xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.h" #include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" @@ -199,6 +200,7 @@ void OptProvider::RegisterAllHardwareIndependentPasses() { RegisterPass(); RegisterPass(); RegisterPass(); + RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(); diff --git a/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc index 160372d3081b5..2dad0c1ffb115 100644 --- a/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -52,6 +52,7 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Region.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -74,6 +75,7 @@ limitations under the License. #include "xla/layout.h" #include "xla/literal.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/primitive_util.h" #include "xla/protobuf_util.h" #include "xla/service/hlo.pb.h" #include "xla/shape_util.h" @@ -1859,6 +1861,42 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( "precision_config", ConvertPrecisionConfig(&instruction->precision_config(), builder_))); + // If the element types of the operands for convolution are different, + // insert a convert op to convert the operands to the common element type + // while preserving the values. + auto lhs = operands[0]; + auto rhs = operands[1]; + auto lhs_element_type = instruction->operand(0)->shape().element_type(); + auto rhs_element_type = instruction->operand(1)->shape().element_type(); + if (lhs_element_type != rhs_element_type) { + if (primitive_util::CastPreservesValues(lhs_element_type, + rhs_element_type)) { + auto convert_op_return_type = + mlir::cast(lhs.getType()) + .clone(mlir::getElementTypeOrSelf(rhs)); + lhs = func_builder->create( + loc, convert_op_return_type, lhs); + } else if (primitive_util::CastPreservesValues(rhs_element_type, + lhs_element_type)) { + auto convert_op_return_type = + mlir::cast(rhs.getType()) + .clone(mlir::getElementTypeOrSelf(lhs)); + rhs = func_builder->create( + loc, convert_op_return_type, rhs); + } else { + return InvalidArgument( + "Unsupported conversion between element types of operands (%s " + "and %s) for convolution.", + instruction->operand(0)->shape().ToString(), + instruction->operand(1)->shape().ToString()); + } + return func_builder + ->create( + loc, result_type, std::vector{lhs, rhs}, + attributes) + .getOperation(); + } + return func_builder ->create(loc, result_type, operands, attributes) diff --git a/xla/hlo/translate/hlo_to_mhlo/tests/BUILD b/xla/hlo/translate/hlo_to_mhlo/tests/BUILD index e94d4299f1256..2a5b4a9bbcd9a 100644 --- a/xla/hlo/translate/hlo_to_mhlo/tests/BUILD +++ b/xla/hlo/translate/hlo_to_mhlo/tests/BUILD @@ -28,6 +28,7 @@ lit_test_suite( "location.hlo", "module_attributes.hlo", "module_config.hlo", + "operand_convert_for_convolution.hlo", "ragged_dot.hlo", "simple.hlo", "spmd_module_sharding.hlo", diff --git a/xla/hlo/translate/hlo_to_mhlo/tests/operand_convert_for_convolution.hlo b/xla/hlo/translate/hlo_to_mhlo/tests/operand_convert_for_convolution.hlo new file mode 100644 index 0000000000000..0ddf97f39ebb7 --- /dev/null +++ b/xla/hlo/translate/hlo_to_mhlo/tests/operand_convert_for_convolution.hlo @@ -0,0 +1,36 @@ +// RUN: hlo-translate --hlo-to-mlir --emit-mhlo --split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: func @main +HloModule main, entry_computation_layout={(pred[28,1,1]{2,1,0}, f32[1,28,1]{2,1,0})->f32[1,28,1]{2,1,0}} + +ENTRY %main (p0: pred[28,1,1], p1: f32[1,28,1]) -> f32[1,28,1] { + %p1 = f32[1,28,1]{2,1,0} parameter(1) + %p0 = pred[28,1,1]{2,1,0} parameter(0) + // CHECK: %[[convert:.*]] = mhlo.convert %arg0 : (tensor<28x1x1xi1>) -> tensor<28x1x1xf32> + // CHECK-NEXT: mhlo.convolution(%arg1, %[[convert]]) + ROOT %convolution.9 = f32[1,28,1]{2,1,0} convolution(f32[1,28,1]{2,1,0} %p1, pred[28,1,1]{2,1,0} %p0), window={size=28 pad=13_14}, dim_labels=b0f_0io->b0f +} + +// ----- + +// CHECK-LABEL: func @main +HloModule main, entry_computation_layout={(f32[28,1,1]{2,1,0}, pred[1,28,1]{2,1,0})->f32[1,28,1]{2,1,0}} + +ENTRY %main.10 (p0: f32[28,1,1], p1: pred[1,28,1]) -> f32[1,28,1] { + %p1 = pred[1,28,1]{2,1,0} parameter(1) + %p0 = f32[28,1,1]{2,1,0} parameter(0) + // CHECK: %[[CONVERT:.*]] = mhlo.convert %arg1 : (tensor<1x28x1xi1>) -> tensor<1x28x1xf32> + // CHECK-NEXT: mhlo.convolution(%[[CONVERT]], %arg0) + ROOT %convolution.9 = f32[1,28,1]{2,1,0} convolution(pred[1,28,1]{2,1,0} %p1, f32[28,1,1]{2,1,0} %p0), window={size=28 pad=13_14}, dim_labels=b0f_0io->b0f +} + +// ----- + +// expected-error@-3 {{Unsupported conversion between element types of operands (f8e8m0fnu[1,28,1] and pred[28,1,1]) for convolution.}} +HloModule main, entry_computation_layout={(pred[28,1,1]{2,1,0}, f8e8m0fnu[1,28,1]{2,1,0})->pred[1,28,1]{2,1,0}} + +ENTRY %main.10 (p0: pred[28,1,1], p1: f8e8m0fnu[1,28,1]) -> pred[1,28,1] { + %p1 = f8e8m0fnu[1,28,1]{2,1,0} parameter(1) + %p0 = pred[28,1,1]{2,1,0} parameter(0) + ROOT %convolution.9 = pred[1,28,1]{2,1,0} convolution(f8e8m0fnu[1,28,1]{2,1,0} %p1, pred[28,1,1]{2,1,0} %p0), window={size=28 pad=13_14}, dim_labels=b0f_0io->b0f +}