From a03c7ee528f61ef476536b8c146d22b2e77ac5d4 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Wed, 12 Feb 2025 21:49:35 -0500 Subject: [PATCH] Remove the compile option -nnpa-clip-to-dlfloat-range Signed-off-by: Tung D. Le --- .../NNPA/Compiler/NNPACompilerOptions.cpp | 7 - .../NNPA/Compiler/NNPACompilerOptions.hpp | 1 - .../NNPA/Compiler/NNPACompilerUtils.cpp | 10 -- src/Accelerators/NNPA/NNPAAccelerator.cpp | 4 - src/Accelerators/NNPA/Pass/NNPAPasses.hpp | 3 - .../NNPA/Transform/ZHigh/CMakeLists.txt | 11 -- .../Transform/ZHigh/ZHighClipToDLFloat.cpp | 170 ------------------ .../zhigh-clip-to-dlfloat-range.mlir | 96 ---------- 8 files changed, 302 deletions(-) delete mode 100644 src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp delete mode 100644 test/mlir/accelerators/nnpa/transform/zhigh-clip-to-dlfloat-range.mlir diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp index 34457eafd8..29b48073ac 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp @@ -30,13 +30,6 @@ llvm::cl::opt nnpaEmissionTarget( clEnumVal(EmitZNONE, "Do not emit NNPA-related target (default)")), llvm::cl::init(EmitZNONE), llvm::cl::cat(OnnxMlirOptions)); -llvm::cl::opt nnpaClipToDLFloatRange("nnpa-clip-to-dlfloat-range", - llvm::cl::desc("Clip CPU tensors to dlfloat range before stickification to " - "avoid out-of-range. Only clip Softmax inputs at this " - "moment. Default is true. This option will be removed and " - "replaced by --nnpa-saturation in the future."), - llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions)); - llvm::cl::opt nnpaEnableZHighToOnnx("enable-zhigh-to-onnx", llvm::cl::desc( "Enabling this will convert a pattern `stick -> element-wise op -> " diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp index e6f7cf6aa7..1e07a92d7e 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp @@ -67,7 +67,6 @@ typedef enum { extern llvm::cl::OptionCategory OnnxMlirOptions; extern llvm::cl::OptionCategory OnnxMlirCommonOptions; extern llvm::cl::opt nnpaEmissionTarget; -extern llvm::cl::opt nnpaClipToDLFloatRange; extern llvm::cl::opt nnpaEnableZHighToOnnx; extern llvm::cl::opt nnpaEnableZHighDecomposeStickUnstick; extern llvm::cl::opt nnpaEnableCompilerStickUnstick; diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp index 45a9af09f8..141e733937 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp @@ -145,16 +145,6 @@ void addONNXToZHighPasses(mlir::PassManager &pm) { pm.addNestedPass(onnx_mlir::createShapeInferencePass()); pm.addPass(mlir::createCanonicalizerPass()); - // Clip zhigh.Stick inputs if required. This is to avoid out-of-range of - // dlfloat. Do constant propagation after clipping to remove ONNX ops used for - // clipping such as ONNXMax if applicable. - // This pass will be removed and replaced by nnpa-saturation in the future. - if (!nnpaEnableSaturation && nnpaClipToDLFloatRange) { - pm.addNestedPass( - onnx_mlir::zhigh::createZHighClipToDLFloatPass()); - pm.addNestedPass(onnx_mlir::createConstPropONNXToONNXPass()); - } - // One more call to ONNX shape inference/canonicalization/... to update shape // if possible. if (enableONNXHybridPass) { diff --git a/src/Accelerators/NNPA/NNPAAccelerator.cpp b/src/Accelerators/NNPA/NNPAAccelerator.cpp index 50ef2bf0ba..0fa76062b7 100644 --- a/src/Accelerators/NNPA/NNPAAccelerator.cpp +++ b/src/Accelerators/NNPA/NNPAAccelerator.cpp @@ -125,10 +125,6 @@ void NNPAAccelerator::registerPasses(int optLevel) const { return onnx_mlir::zhigh::createZHighLayoutPropagationPass(); }); - mlir::registerPass([]() -> std::unique_ptr { - return onnx_mlir::zhigh::createZHighClipToDLFloatPass(); - }); - mlir::registerPass([]() -> std::unique_ptr { return onnx_mlir::zhigh::createZHighDecomposeStickUnstickPass(); }); diff --git a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp index c23fb7f158..1c8d6b7012 100644 --- a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp +++ b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp @@ -55,9 +55,6 @@ std::unique_ptr createZHighConstPropagationPass(); std::unique_ptr createZHighScrubDisposablePass( bool closeAfter = true); -/// Pass for clipping values to dlfloat before stickification at ZHighIR. -std::unique_ptr createZHighClipToDLFloatPass(); - /// Pass for decomposing stick/unstick at ZHighIR. std::unique_ptr createZHighDecomposeStickUnstickPass(); diff --git a/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt b/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt index 30378c0f9e..3cae309723 100644 --- a/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt +++ b/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt @@ -38,17 +38,6 @@ add_onnx_mlir_library(OMZHighLayoutPropagation ${NNPA_INCLUDE_PATH} ) -add_onnx_mlir_rewriter(ZHighClipToDLFloat) -add_onnx_mlir_library(OMZHighClipToDLFloat - ZHighClipToDLFloat.cpp - - LINK_LIBS PUBLIC - MLIRRewrite - MLIRTransformUtils - OMZHighOps - OMONNXOps - ) - add_onnx_mlir_rewriter(ZHighDecomposeStickUnstick) add_onnx_mlir_library(OMZHighDecomposeStickUnstick ZHighDecomposeStickUnstick.cpp diff --git a/src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp b/src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp deleted file mode 100644 index 9006c36669..0000000000 --- a/src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp +++ /dev/null @@ -1,170 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - */ - -//===---------- ZHighClipToDLFloat.cpp - ZHigh High Level Optimizer -------===// -// -// Copyright 2023- The IBM Research Authors. -// -// ============================================================================= -// -// This file implements a set of rewritten rules to clip CPU numerical values -// before passing to ZHighStick, which avoids data range violation error due to -// the dlfloat range. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp" -#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp" -#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp" -#include "src/Accelerators/NNPA/Support/NNPALimit.hpp" -#include "src/Dialect/ONNX/DialectBuilder.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" -#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" -#include "src/Support/TypeUtilities.hpp" - -using namespace mlir; -using namespace onnx_mlir; -using namespace onnx_mlir::zhigh; - -namespace onnx_mlir { -namespace zhigh { - -namespace { - -/// Check if a value is from or transitively from a zTensor without value -/// modification. -bool valueFromZTensor(Value tensor) { - // Function arguments are always CPU tensors. - if (mlir::dyn_cast(tensor)) - return false; - - Operation *op = tensor.getDefiningOp(); - - // Base case: From a zTensor. - if (isa(op)) - return true; - - // Base case: ReluOp clipped the lowerbound to zero. - if (isa(op)) - return true; - - // Base case: Operations having no input, e.g., Constant, ConstantOfShape. - if (op->getOperands().size() == 0) - return false; - - // Recursion case: There are operations (e.g. transpose, reshape, etc.) that - // do not change the input precision. So we can consider that the input is - // already in the dlfloat range if it comes from zTensor. - - // Operations whose only the first input form the output. These ops may - // have additional inputs, but they are like attributes. - if (isa(op)) - return valueFromZTensor(op->getOperand(0)); - - // PadOp - if (auto padOp = mlir::dyn_cast(op)) { - Value padVal = padOp.getConstantValue(); - // Only support default constant value that is 0 at this moment. - if (isNoneValue(padVal)) - return valueFromZTensor(op->getOperand(0)); - } - - // For all remaining operations, do a conservative check. - return llvm::all_of( - op->getOperands(), [&](Value v) { return valueFromZTensor(v); }); -} - -class ZHighClipToDLFloatPattern : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite( - ZHighStickOp stickOp, PatternRewriter &rewriter) const override { - Operation *genericOp = stickOp.getOperation(); - Location loc = genericOp->getLoc(); - - Value input = stickOp.getIn(); - Value output = stickOp.getOut(); - Type inputElementType = getElementType(input.getType()); - - // Only clip if the input is in float > 16 bit. - auto floatType = mlir::dyn_cast(inputElementType); - if (!floatType) - return failure(); - if (floatType.getWidth() <= 16) - return failure(); - - // Only clip if the consummer is Softmax with which we have seen NaNs. - if (llvm::none_of(output.getUsers(), - [&](Operation *op) { return isa(op); })) - return failure(); - - // Do not clip if the input tensor is already in the dlfloat range. - // For example, the input was unstickified from a zTensor. - if (valueFromZTensor(input)) - return failure(); - - // Clip the input values if required since the values are potentially - // out-of-bound of dlfloat. - MultiDialectBuilder create(rewriter, loc); - DenseElementsAttr minAttr = DenseElementsAttr::get( - RankedTensorType::get({1}, inputElementType), DLF16_MIN); - Value minVal = create.onnx.constant(minAttr); - Value clippedVal = create.onnx.max({input, minVal}); - Value replacedVal = - rewriter.create(loc, stickOp.getOut().getType(), - clippedVal, stickOp.getLayoutAttr(), IntegerAttr()); - - rewriter.replaceOp(genericOp, replacedVal); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// ZHighClipToDLFloatPass -//===----------------------------------------------------------------------===// - -struct ZHighClipToDLFloatPass - : public PassWrapper> { - - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ZHighClipToDLFloatPass) - - StringRef getArgument() const override { return "zhigh-clip-to-dlfloat"; } - - StringRef getDescription() const override { - return "Clip stickification inputs at ZHighIR."; - } - - void runOnOperation() override { - auto function = getOperation(); - ConversionTarget target(getContext()); - RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); - - GreedyRewriteConfig config; - config.useTopDownTraversal = true; - /// Only pre-existing ops (that were were on the worklist at the very - /// beginning) enqueued. All other ops are excluded. - config.strictMode = GreedyRewriteStrictness::ExistingOps; - - if (failed(applyPatternsAndFoldGreedily( - function, std::move(patterns), config))) - signalPassFailure(); - } -}; -} // anonymous namespace - -std::unique_ptr createZHighClipToDLFloatPass() { - return std::make_unique(); -} - -} // namespace zhigh -} // namespace onnx_mlir diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-clip-to-dlfloat-range.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-clip-to-dlfloat-range.mlir deleted file mode 100644 index ee1e31f280..0000000000 --- a/test/mlir/accelerators/nnpa/transform/zhigh-clip-to-dlfloat-range.mlir +++ /dev/null @@ -1,96 +0,0 @@ -// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --zhigh-clip-to-dlfloat -split-input-file %s || FileCheck %s - -func.func @should_clip_stick(%arg0: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> { - %0 = "zhigh.Stick"(%arg0) {layout = "3DS"} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %1 = "zhigh.Softmax"(%0) {act_func = "ACT_NONE"} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %2 = "zhigh.Unstick"(%1) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf32> - return %2 : tensor<3x4x5xf32> - -// CHECK-LABEL: func.func @should_clip_stick -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> { -// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<-8.57315738E+9> : tensor<1xf32> -// CHECK: [[VAR_1_:%.+]] = "onnx.Max"([[PARAM_0_]], [[VAR_0_]]) : (tensor<3x4x5xf32>, tensor<1xf32>) -> tensor<3x4x5xf32> -// CHECK: [[VAR_2_:%.+]] = "zhigh.Stick"([[VAR_1_]]) {layout = "3DS"} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> -// CHECK: [[VAR_3_:%.+]] = "zhigh.Softmax"([[VAR_2_]]) {act_func = "ACT_NONE"} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> -// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf32> -// CHECK: return [[VAR_4_]] : tensor<3x4x5xf32> -// CHECK: } -} - -// ----- - -func.func @should_clip_transpose(%arg0: tensor<3x5x4xf32>) -> tensor<3x4x5xf32> { - %1 = "onnx.Transpose"(%arg0) { perm = [0, 2, 1]} : (tensor<3x5x4xf32>) -> tensor<3x4x5xf32> - %2 = "zhigh.Stick"(%1) {layout = "3DS"} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %3 = "zhigh.Softmax"(%2) {act_func = "ACT_NONE"} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %4 = "zhigh.Unstick"(%3) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf32> - return %4 : tensor<3x4x5xf32> - -// CHECK-LABEL: func.func @should_clip_transpose -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x5x4xf32>) -> tensor<3x4x5xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<-8.57315738E+9> : tensor<1xf32> -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [0, 2, 1]} : (tensor<3x5x4xf32>) -> tensor<3x4x5xf32> -// CHECK: [[VAR_2_:%.+]] = "onnx.Max"([[VAR_1_]], [[VAR_0_]]) : (tensor<3x4x5xf32>, tensor<1xf32>) -> tensor<3x4x5xf32> -// CHECK: [[VAR_3_:%.+]] = "zhigh.Stick"([[VAR_2_]]) {layout = "3DS"} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> -// CHECK: [[VAR_4_:%.+]] = "zhigh.Softmax"([[VAR_3_]]) {act_func = "ACT_NONE"} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> -// CHECK: [[VAR_5_:%.+]] = "zhigh.Unstick"([[VAR_4_]]) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf32> -// CHECK: return [[VAR_5_]] : tensor<3x4x5xf32> -// CHECK: } -} - -// ----- - -// Do not clip because the input comes from a zTensor via Unstick. -func.func @donot_clip_stick(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<3x4x5xf32> { - %0 = "zhigh.Unstick"(%arg0) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<3x4x5xf32> - %1 = "zhigh.Stick"(%0) {layout = "3DS"} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %2 = "zhigh.Softmax"(%1) {act_func = "ACT_NONE"} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %3 = "zhigh.Unstick"(%2) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf32> - return %3 : tensor<3x4x5xf32> - -// CHECK-LABEL: donot_clip_stick -// CHECK: zhigh.Unstick -// CHECK: zhigh.Stick -// CHECK: zhigh.Softmax -// CHECK: zhigh.Unstick -} - -// ----- - -// Do not clip because transpose does not change the zTensor. -func.func @donot_clip_stick_transpose(%arg0: tensor<3x5x4xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<3x4x5xf32> { - %0 = "zhigh.Unstick"(%arg0) : (tensor<3x5x4xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<3x5x4xf32> - %1 = "onnx.Transpose"(%0) { perm = [0, 2, 1]} : (tensor<3x5x4xf32>) -> tensor<3x4x5xf32> - %2 = "zhigh.Stick"(%1) {layout = "3DS"} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %3 = "zhigh.Softmax"(%2) {act_func = "ACT_NONE"} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %4 = "zhigh.Unstick"(%3) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf32> - return %4 : tensor<3x4x5xf32> - -// CHECK-LABEL: donot_clip_stick_transpose -// CHECK: zhigh.Unstick -// CHECK: onnx.Transpose. -// CHECK: zhigh.Stick -// CHECK: zhigh.Softmax -// CHECK: zhigh.Unstick -} - -// ----- - -// Do not clip because concat does not change the zTensor. -func.func @donot_clip_stick_concat(%arg0: tensor<3x2x5xf16, #zhigh.layout<{dataLayout = "3D"}>>, %arg1: tensor<3x2x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<3x4x5xf32> { - %0 = "zhigh.Unstick"(%arg0) : (tensor<3x2x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<3x2x5xf32> - %1 = "zhigh.Unstick"(%arg1) : (tensor<3x2x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<3x2x5xf32> - %2 = "onnx.Concat"(%0, %1) { axis = 1 : si64} : (tensor<3x2x5xf32>, tensor<3x2x5xf32>) -> tensor<3x4x5xf32> - %3 = "zhigh.Stick"(%2) {layout = "3DS"} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %4 = "zhigh.Softmax"(%3) {act_func = "ACT_NONE"} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %5 = "zhigh.Unstick"(%4) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf32> - return %5 : tensor<3x4x5xf32> - -// CHECK-LABEL: donot_clip_stick_concat -// CHECK: zhigh.Unstick -// CHECK: zhigh.Unstick -// CHECK: onnx.Concat. -// CHECK: zhigh.Stick -// CHECK: zhigh.Softmax -// CHECK: zhigh.Unstick -}