diff --git a/src/Conversion/ONNXToStablehlo/Math/Softmax.cpp b/src/Conversion/ONNXToStablehlo/Math/Softmax.cpp
index f1b0d61657..fa833d0fbd 100644
--- a/src/Conversion/ONNXToStablehlo/Math/Softmax.cpp
+++ b/src/Conversion/ONNXToStablehlo/Math/Softmax.cpp
@@ -126,10 +126,6 @@ struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern {
     Location loc = op->getLoc();
     Type outputType = *op->result_type_begin();
     assert(isRankedShapedType(outputType) && "Expected Ranked ShapedType");
-    assert(mlir::cast<ShapedType>(operand.getType())
-               .getElementType()
-               .isF32() &&
-           "Currently Only float32 is supported for input");
 
     // Exponential operation
     Value ElementwiseExpStableHLO = rewriter.create<stablehlo::ExpOp>(
@@ -204,4 +200,4 @@ void populateLoweringONNXSoftmaxOpToStablehloPattern(
     RewritePatternSet &patterns, MLIRContext *ctx) {
   patterns.insert<ONNXSoftmaxOpLoweringToStablehlo>(ctx);
 }
-} // namespace onnx_mlir
\ No newline at end of file
+} // namespace onnx_mlir
diff --git a/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax-Decompose.mlir b/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax-Decompose.mlir
new file mode 100644
index 0000000000..0da75f096a
--- /dev/null
+++ b/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax-Decompose.mlir
@@ -0,0 +1,101 @@
+// RUN: onnx-mlir-opt --decompose-onnx="target=stablehlo" --convert-onnx-to-stablehlo %s --canonicalize -split-input-file | FileCheck %s
+
+func.func @test_softmax(%arg0 : tensor<10x20x30xf32>) -> tensor<10x20x30xf32> {
+  %0 = "onnx.Softmax"(%arg0) {axis = 1: si64} : (tensor<10x20x30xf32>) -> tensor<10x20x30xf32>
+  "func.return"(%0) : (tensor<10x20x30xf32>) -> ()
+}
+
+// CHECK-LABEL: func.func @test_softmax
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x20x30xf32>) -> tensor<10x20x30xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = shape.const_shape [10, 1, 30] : tensor<3xindex>
+// CHECK-DAG: [[VAR_1_:%.+]] = shape.const_shape [10, 20, 30] : tensor<3xindex>
+// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
+// CHECK: [[VAR_4_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_3_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<10x20x30xf32>, tensor<f32>) -> tensor<10x30xf32>
+// CHECK-DAG: [[VAR_5_:%.+]] = stablehlo.dynamic_reshape [[VAR_4_]], [[VAR_0_]] : (tensor<10x30xf32>, tensor<3xindex>) -> tensor<10x1x30xf32>
+// CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x20x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32>
+// CHECK: [[VAR_7_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_5_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x1x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32>
+// CHECK: [[VAR_8_:%.+]] = stablehlo.subtract [[VAR_6_]], [[VAR_7_]] : tensor<10x20x30xf32>
+// CHECK: [[VAR_9_:%.+]] = stablehlo.exponential [[VAR_8_]] : tensor<10x20x30xf32>
+// CHECK: [[VAR_10_:%.+]] = stablehlo.reduce([[VAR_9_]] init: [[VAR_2_]]) applies stablehlo.add across dimensions = [1] : (tensor<10x20x30xf32>, tensor<f32>) -> tensor<10x30xf32>
+// CHECK-DAG: [[VAR_11_:%.+]] = stablehlo.dynamic_reshape [[VAR_10_]], [[VAR_0_]] : (tensor<10x30xf32>, tensor<3xindex>) -> tensor<10x1x30xf32>
+// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_9_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x20x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32>
+// CHECK: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_11_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x1x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32>
+// CHECK: [[VAR_14_:%.+]] = stablehlo.divide [[VAR_12_]], [[VAR_13_]] : tensor<10x20x30xf32>
+// CHECK: return [[VAR_14_]] : tensor<10x20x30xf32>
+// CHECK: }
+
+// -----
+
+func.func @test_softmax_dynamic(%arg0 : tensor) -> tensor {
+  %0 = "onnx.Softmax"(%arg0) {axis = 1: si64} : (tensor) -> tensor
+  "func.return"(%0) : (tensor) -> ()
+}
+
+// CHECK-LABEL: func.func @test_softmax_dynamic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor {
+// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
+// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor, tensor) -> tensor
+// CHECK-DAG: [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index
+// CHECK-DAG: [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index
+// CHECK: [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index
+// CHECK: [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex>
+// CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor, tensor<3xindex>) -> tensor
+// CHECK-DAG: [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex>
+// CHECK: [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor -> tensor<3xindex>
+// CHECK: [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
+// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor
+// CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor
+// CHECK: [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor
+// CHECK: [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor
+// CHECK-DAG: [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor, tensor) -> tensor
+// CHECK-DAG: [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index
+// CHECK-DAG: [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index
+// CHECK: [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index
+// CHECK: [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex>
+// CHECK-DAG: [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor, tensor<3xindex>) -> tensor
+// CHECK-DAG: [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex>
+// CHECK: [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor -> tensor<3xindex>
+// CHECK: [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
+// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor
+// CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor
+// CHECK: [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor
+// CHECK: return [[VAR_28_]] : tensor
+// CHECK: }
+
+
+// -----
+
+func.func @test_softmax_2d(%arg0 : tensor<1x10xf32>) -> tensor<1x10xf32> {
+  %0 = "onnx.Softmax"(%arg0) {axis = -1 : si64} : (tensor<1x10xf32>) -> tensor<1x10xf32>
+  "func.return"(%0) : (tensor<1x10xf32>) -> ()
+}
+
+// CHECK-LABEL: func.func @test_softmax_2d
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x10xf32>) -> tensor<1x10xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = shape.const_shape [1, 1] : tensor<2xindex>
+// CHECK-DAG: [[VAR_1_:%.+]] = shape.const_shape [1, 10] : tensor<2xindex>
+// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
+// CHECK: [[VAR_4_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_3_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
+// CHECK-DAG: [[VAR_5_:%.+]] = stablehlo.dynamic_reshape [[VAR_4_]], [[VAR_0_]] : (tensor<1xf32>, tensor<2xindex>) -> tensor<1x1xf32>
+// CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x10xf32>, tensor<2xindex>) -> tensor<1x10xf32>
+// CHECK: [[VAR_7_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_5_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<1x10xf32>
+// CHECK: [[VAR_8_:%.+]] = stablehlo.subtract [[VAR_6_]], [[VAR_7_]] : tensor<1x10xf32>
+// CHECK: [[VAR_9_:%.+]] = stablehlo.exponential [[VAR_8_]] : tensor<1x10xf32>
+// CHECK: [[VAR_10_:%.+]] = stablehlo.reduce([[VAR_9_]] init: [[VAR_2_]]) applies stablehlo.add across dimensions = [1] : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
+// CHECK-DAG: [[VAR_11_:%.+]] = stablehlo.dynamic_reshape [[VAR_10_]], [[VAR_0_]] : (tensor<1xf32>, tensor<2xindex>) -> tensor<1x1xf32>
+// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_9_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x10xf32>, tensor<2xindex>) -> tensor<1x10xf32>
+// CHECK: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_11_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<1x10xf32>
+// CHECK: [[VAR_14_:%.+]] = stablehlo.divide [[VAR_12_]], [[VAR_13_]] : tensor<1x10xf32>
+// CHECK: return [[VAR_14_]] : tensor<1x10xf32>
+// CHECK: }
diff --git a/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir b/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir
index 0da75f096a..3fe15a13d1 100644
--- a/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir
+++ b/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir
@@ -1,101 +1,31 @@
-// RUN: onnx-mlir-opt --decompose-onnx="target=stablehlo" --convert-onnx-to-stablehlo %s --canonicalize -split-input-file | FileCheck %s
+// RUN: onnx-mlir-opt --convert-onnx-to-stablehlo %s --canonicalize -split-input-file | FileCheck %s
 
-func.func @test_softmax(%arg0 : tensor<10x20x30xf32>) -> tensor<10x20x30xf32> {
-  %0 = "onnx.Softmax"(%arg0) {axis = 1: si64} : (tensor<10x20x30xf32>) -> tensor<10x20x30xf32>
-  "func.return"(%0) : (tensor<10x20x30xf32>) -> ()
+func.func @test_softmax_bf16(%arg0 : tensor<10x20x30xbf16>) -> tensor<10x20x30xbf16> {
+  %0 = "onnx.Softmax"(%arg0) {axis = 1: si64} : (tensor<10x20x30xbf16>) -> tensor<10x20x30xbf16>
+  "func.return"(%0) : (tensor<10x20x30xbf16>) -> ()
 }
 
-// CHECK-LABEL: func.func @test_softmax
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x20x30xf32>) -> tensor<10x20x30xf32> {
-// CHECK-DAG: [[VAR_0_:%.+]] = shape.const_shape [10, 1, 30] : tensor<3xindex>
-// CHECK-DAG: [[VAR_1_:%.+]] = shape.const_shape [10, 20, 30] : tensor<3xindex>
-// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
-// CHECK: [[VAR_4_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_3_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<10x20x30xf32>, tensor<f32>) -> tensor<10x30xf32>
-// CHECK-DAG: [[VAR_5_:%.+]] = stablehlo.dynamic_reshape [[VAR_4_]], [[VAR_0_]] : (tensor<10x30xf32>, tensor<3xindex>) -> tensor<10x1x30xf32>
-// CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x20x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32>
-// CHECK: [[VAR_7_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_5_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x1x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32>
-// CHECK: [[VAR_8_:%.+]] = stablehlo.subtract [[VAR_6_]], [[VAR_7_]] : tensor<10x20x30xf32>
-// CHECK: [[VAR_9_:%.+]] = stablehlo.exponential [[VAR_8_]] : tensor<10x20x30xf32>
-// CHECK: [[VAR_10_:%.+]] = stablehlo.reduce([[VAR_9_]] init: [[VAR_2_]]) applies stablehlo.add across dimensions = [1] : (tensor<10x20x30xf32>, tensor<f32>) -> tensor<10x30xf32>
-// CHECK-DAG: [[VAR_11_:%.+]] = stablehlo.dynamic_reshape [[VAR_10_]], [[VAR_0_]] : (tensor<10x30xf32>, tensor<3xindex>) -> tensor<10x1x30xf32>
-// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_9_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x20x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32>
-// CHECK: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_11_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x1x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32>
-// CHECK: [[VAR_14_:%.+]] = stablehlo.divide [[VAR_12_]], [[VAR_13_]] : tensor<10x20x30xf32>
-// CHECK: return [[VAR_14_]] : tensor<10x20x30xf32>
-// CHECK: }
+// CHECK-LABEL: func.func @test_softmax_bf16
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x20x30xbf16>) -> tensor<10x20x30xbf16> {
+// CHECK: [[CST:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<bf16>
+// CHECK-NEXT: [[EXP:%.+]] = stablehlo.exponential [[PARAM_0_]] : tensor<10x20x30xbf16>
+// CHECK-NEXT: [[REDUCE:%.+]] = stablehlo.reduce([[EXP]] init: [[CST]]) applies stablehlo.add across dimensions = [1] : (tensor<10x20x30xbf16>, tensor<bf16>) -> tensor<10x30xbf16>
+// CHECK-NEXT: [[DENOM:%.+]] = stablehlo.broadcast_in_dim [[REDUCE]], dims = [0, 2] : (tensor<10x30xbf16>) -> tensor<10x20x30xbf16>
+// CHECK-NEXT: [[RES:%.+]] = stablehlo.divide [[EXP]], [[DENOM]] : tensor<10x20x30xbf16>
+// CHECK-NEXT: return [[RES]] : tensor<10x20x30xbf16>
 
 // -----
 
-func.func @test_softmax_dynamic(%arg0 : tensor) -> tensor {
-  %0 = "onnx.Softmax"(%arg0) {axis = 1: si64} : (tensor) -> tensor
-  "func.return"(%0) : (tensor) -> ()
+func.func @test_softmax_f64(%arg0 : tensor<10x20x30xf64>) -> tensor<10x20x30xf64> {
+  %0 = "onnx.Softmax"(%arg0) {axis = -1: si64} : (tensor<10x20x30xf64>) -> tensor<10x20x30xf64>
+  "func.return"(%0) : (tensor<10x20x30xf64>) -> ()
 }
 
-// CHECK-LABEL: func.func @test_softmax_dynamic
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor {
-// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor
-// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
-// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
-// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor, tensor) -> tensor
-// CHECK-DAG: [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex>
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index
-// CHECK-DAG: [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index
-// CHECK: [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index
-// CHECK: [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex>
-// CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor, tensor<3xindex>) -> tensor
-// CHECK-DAG: [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex>
-// CHECK: [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor -> tensor<3xindex>
-// CHECK: [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
-// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor
-// CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor
-// CHECK: [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor
-// CHECK: [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor
-// CHECK-DAG: [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor, tensor) -> tensor
-// CHECK-DAG: [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex>
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index
-// CHECK-DAG: [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index
-// CHECK: [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index
-// CHECK: [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex>
-// CHECK-DAG: [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor, tensor<3xindex>) -> tensor
-// CHECK-DAG: [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex>
-// CHECK: [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor -> tensor<3xindex>
-// CHECK: [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
-// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor
-// CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor
-// CHECK: [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor
-// CHECK: return [[VAR_28_]] : tensor
-// CHECK: }
-
-
-// -----
-
-func.func @test_softmax_2d(%arg0 : tensor<1x10xf32>) -> tensor<1x10xf32> {
-  %0 = "onnx.Softmax"(%arg0) {axis = -1 : si64} : (tensor<1x10xf32>) -> tensor<1x10xf32>
-  "func.return"(%0) : (tensor<1x10xf32>) -> ()
-}
-
-// CHECK-LABEL: func.func @test_softmax_2d
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x10xf32>) -> tensor<1x10xf32> {
-// CHECK-DAG: [[VAR_0_:%.+]] = shape.const_shape [1, 1] : tensor<2xindex>
-// CHECK-DAG: [[VAR_1_:%.+]] = shape.const_shape [1, 10] : tensor<2xindex>
-// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
-// CHECK: [[VAR_4_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_3_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
-// CHECK-DAG: [[VAR_5_:%.+]] = stablehlo.dynamic_reshape [[VAR_4_]], [[VAR_0_]] : (tensor<1xf32>, tensor<2xindex>) -> tensor<1x1xf32>
-// CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x10xf32>, tensor<2xindex>) -> tensor<1x10xf32>
-// CHECK: [[VAR_7_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_5_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<1x10xf32>
-// CHECK: [[VAR_8_:%.+]] = stablehlo.subtract [[VAR_6_]], [[VAR_7_]] : tensor<1x10xf32>
-// CHECK: [[VAR_9_:%.+]] = stablehlo.exponential [[VAR_8_]] : tensor<1x10xf32>
-// CHECK: [[VAR_10_:%.+]] = stablehlo.reduce([[VAR_9_]] init: [[VAR_2_]]) applies stablehlo.add across dimensions = [1] : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
-// CHECK-DAG: [[VAR_11_:%.+]] = stablehlo.dynamic_reshape [[VAR_10_]], [[VAR_0_]] : (tensor<1xf32>, tensor<2xindex>) -> tensor<1x1xf32>
-// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_9_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x10xf32>, tensor<2xindex>) -> tensor<1x10xf32>
-// CHECK: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_11_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<1x10xf32>
-// CHECK: [[VAR_14_:%.+]] = stablehlo.divide [[VAR_12_]], [[VAR_13_]] : tensor<1x10xf32>
-// CHECK: return [[VAR_14_]] : tensor<1x10xf32>
-// CHECK: }
+// CHECK-LABEL: func.func @test_softmax_f64
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x20x30xf64>) -> tensor<10x20x30xf64> {
+// CHECK: [[CST:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f64>
+// CHECK-NEXT: [[EXP:%.+]] = stablehlo.exponential [[PARAM_0_]] : tensor<10x20x30xf64>
+// CHECK-NEXT: [[REDUCE:%.+]] = stablehlo.reduce([[EXP]] init: [[CST]]) applies stablehlo.add across dimensions = [2] : (tensor<10x20x30xf64>, tensor<f64>) -> tensor<10x20xf64>
+// CHECK-NEXT: [[DENOM:%.+]] = stablehlo.broadcast_in_dim [[REDUCE]], dims = [0, 1] : (tensor<10x20xf64>) -> tensor<10x20x30xf64>
+// CHECK-NEXT: [[RES:%.+]] = stablehlo.divide [[EXP]], [[DENOM]] : tensor<10x20x30xf64>
+// CHECK-NEXT: return [[RES]] : tensor<10x20x30xf64>