Skip to content

Commit

Permalink
Optimization for Linear Quantization (#2954)
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandre Eichenberger <[email protected]>
Co-authored-by: Tung D. Le <[email protected]>
  • Loading branch information
AlexandreEichenberger and tungld authored Sep 26, 2024
1 parent fb9544d commit f7d5895
Show file tree
Hide file tree
Showing 10 changed files with 214 additions and 244 deletions.
39 changes: 7 additions & 32 deletions src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,8 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
DimsExpr outputAF;
outputAF.emplace_back(zero);

// faster than original loop on z16, takes 124us for 64k vals
// Allocate output buffers.
MemRefType flatBufferType = llvm::cast<MemRefType>(flatInput.getType());
Value flatBuffer = create.mem.alignedAlloc(flatBufferType, flatInputDims);
DimsExpr bufferAF;
bufferAF.emplace_back(zero);

create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
{flatInput}, {inputAF}, {flatBuffer}, {bufferAF},
{flatInput}, {inputAF}, {flatAlloc}, {outputAF},
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
MultiDialectBuilder<MathBuilder> create(kb);
Value x = inputVals[0];
Expand All @@ -95,29 +88,10 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
adjustX = roundX;
// Saturate: use max into a min.
Value saturateX = create.math.clip(adjustX, qMin, qMax);
// Old approach.
// return create.math.cast(quantizedElementType, saturateX);
return saturateX;
// Convert into quantized type.
return create.math.cast(quantizedElementType, saturateX);
}});

// A second loop that performs scalar float to int performs better than the
// compiler's attempt to generate SIMD conversion code. This might not hold
// with all data types, but is definitely noticeable with uint8.
//
// Investigate further: we might save the vector to a buffer on the fly
// (avoiding a second loop as below), and then reload each value as scalar and
// then saved them as scalar (thus avoiding the insert/extract SIMD operations
// that also do not perform well). We can have a SIMD buffer in memory for the
// non-quantized and quantized simd values, but then we also need to privatize
// it, which is also not easy in this scheme. So ignore this for now.
create.krnl.forLoopIE(simdLb, simdUb, 1, enableParallel,
[&](const KrnlBuilder &kb, ValueRange loopInd) {
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(kb);
Value buffVal = create.krnl.loadIE(flatBuffer, {zero}, {loopInd[0]});
Value res = create.math.cast(quantizedElementType, buffVal);
create.krnl.storeIE(res, flatAlloc, {zero}, {loopInd[0]});
});

if (totVL > 1)
onnxToKrnlSimdReport(op, /*successful*/ true, totVL,
simdLoopStaticTripCount, "quantizationLinear whole tensor");
Expand Down Expand Up @@ -202,9 +176,10 @@ struct ONNXQuantizeLinearOpLowering
hasZeroPoint = true;
}
if (disableQuantZeroPoint) {
// TODO: should we expect to disable hasZeroPoint forcefully, or generate
// an error if we had a zero point? Right now, just forcefully assert we
// have no zero point, i.e. ignore one even if we had a zero point.
// TODO: should we expect to disable hasZeroPoint forcefully, or
// generate an error if we had a zero point? Right now, just forcefully
// assert we have no zero point, i.e. ignore one even if we had a zero
// point.
hasZeroPoint = false;
}
emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType,
Expand Down
40 changes: 40 additions & 0 deletions src/Dialect/Mlir/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,46 @@ Value MathBuilder::cast(Type destType, Value src) const {
LLVM_DEBUG(llvm::dbgs() << "srcType: " << srcType << "\n";
llvm::dbgs() << "destType: " << destType << "\n";);

// Before we process with the actual cast, there is a special case that we
// want to handle here. Cast from float to int that have different width, llvm
// generate better patterns if we first cast from float to int of the same
// width, and then from int to a different size int.
// Skip that optimization if the result is a 1 bit (boolean).
if (mlir::isa<FloatType>(srcElemType) &&
mlir::isa<IntegerType>(destElemType) && bitTrunc && destElemWidth > 1) {
// Quantization: float to smaller int. First determine the intermediary
// type, same integer type as destination type, with the same type width as
// the source float type.
Type step1ElementType;
IntegerType destIntType = mlir::cast<IntegerType>(destElemType);
bool destIssSigned = destIntType.isSignless() || destIntType.isSigned();
if (destIssSigned)
step1ElementType = b().getIntegerType(srcElemWidth);
else
step1ElementType = b().getIntegerType(srcElemWidth, false);
// Perform (recursively) the 2 step conversion. Exceptionally ok here to use
// element type here as cast will promote it to a vector if src is a vector.
Value step1Val = cast(step1ElementType, src);
return cast(destType, step1Val);
}
if (mlir::isa<IntegerType>(srcElemType) &&
mlir::isa<FloatType>(destElemType) && bitExtend) {
// Dequantization: small int to a float. First determine the intermediary
// type, same integer type as source type, with the same type width as
// the destination float type.
Type step1ElementType;
IntegerType srcIntType = mlir::cast<IntegerType>(srcElemType);
bool srcIssSigned = srcIntType.isSignless() || srcIntType.isSigned();
if (srcIssSigned)
step1ElementType = b().getIntegerType(destElemWidth);
else
step1ElementType = b().getIntegerType(destElemWidth, false);
// Perform (recursively) the 2 step conversion. Exceptionally ok here to use
// element type here as cast will promote it to a vector if src is a vector.
Value step1Val = cast(step1ElementType, src);
return cast(destType, step1Val);
}

// Handle boolean first because they need special handling.
// Boolean to int/float conversions. Boolean are unsigned.
if (srcElemType.isInteger(1)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

// -----


func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor<f32>, %arg2: tensor<i8>) -> tensor<4xf32> {
%0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xi8>, tensor<f32>, tensor<i8>) -> tensor<4xf32>
return %0 : tensor<4xf32>

// mlir2FileCheck.py
// CHECK-LABEL: func.func @test_dequantizelinear_i8
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xi8>, [[PARAM_1_:%.+]]: memref<f32>, [[PARAM_2_:%.+]]: memref<i8>) -> memref<4xf32> {
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<4xf32>
Expand All @@ -18,12 +20,13 @@ func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor<f32>, %ar
// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<4xi8>
// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref<f32>
// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref<i8>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_5_:%.+]] = arith.sitofp [[LOAD_PARAM_2_MEM_]] : i8 to f32
// CHECK-DAG: [[VAR_6_:%.+]] = arith.sitofp [[LOAD_PARAM_0_MEM_]] : i8 to f32
// CHECK: [[VAR_7_:%.+]] = arith.subf [[VAR_6_]], [[VAR_5_]] : f32
// CHECK: [[VAR_8_:%.+]] = arith.mulf [[VAR_7_]], [[LOAD_PARAM_1_MEM_]] : f32
// CHECK: krnl.store [[VAR_8_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32>
// CHECK: [[VAR_5_:%.+]] = arith.extsi [[LOAD_PARAM_0_MEM_]] : i8 to i32
// CHECK-DAG: [[VAR_6_:%.+]] = arith.sitofp [[VAR_5_]] : i32 to f32
// CHECK-DAG: [[VAR_7_:%.+]] = arith.extsi [[LOAD_PARAM_2_MEM_]] : i8 to i32
// CHECK: [[VAR_8_:%.+]] = arith.sitofp [[VAR_7_]] : i32 to f32
// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_6_]], [[VAR_8_]] : f32
// CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[LOAD_PARAM_1_MEM_]] : f32
// CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32>
// CHECK: }
// CHECK: return [[RES_]] : memref<4xf32>
// CHECK: }
Expand All @@ -47,12 +50,14 @@ func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor<f32>, %
// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref<f32>
// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref<ui8>
// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8
// CHECK-DAG: [[VAR_6_:%.+]] = arith.uitofp [[VAR_5_]] : i8 to f32
// CHECK-DAG: [[VAR_7_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8
// CHECK: [[VAR_8_:%.+]] = arith.uitofp [[VAR_7_]] : i8 to f32
// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_6_]], [[VAR_8_]] : f32
// CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[LOAD_PARAM_1_MEM_]] : f32
// CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32>
// CHECK: [[VAR_6_:%.+]] = arith.extui [[VAR_5_]] : i8 to i32
// CHECK-DAG: [[VAR_7_:%.+]] = arith.uitofp [[VAR_6_]] : i32 to f32
// CHECK-DAG: [[VAR_8_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8
// CHECK: [[VAR_9_:%.+]] = arith.extui [[VAR_8_]] : i8 to i32
// CHECK: [[VAR_10_:%.+]] = arith.uitofp [[VAR_9_]] : i32 to f32
// CHECK: [[VAR_11_:%.+]] = arith.subf [[VAR_7_]], [[VAR_10_]] : f32
// CHECK: [[VAR_12_:%.+]] = arith.mulf [[VAR_11_]], [[LOAD_PARAM_1_MEM_]] : f32
// CHECK: krnl.store [[VAR_12_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32>
// CHECK: }
// CHECK: return [[RES_]] : memref<4xf32>
// CHECK: }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,19 +75,19 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor<?x2xf32>) -> (tensor<?x2xu
// CHECK-DAG: [[VAR_23_:%.+]] = arith.select [[VAR_21_]], [[VAR_22_]], [[VAR_12_]] : f32
// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_13_]], [[CST_5_dot_000000_]] : f32
// CHECK: [[VAR_25_:%.+]] = arith.select [[VAR_24_]], [[VAR_23_]], [[VAR_16_]] : f32
// CHECK: [[VAR_26_:%.+]] = arith.fptoui [[VAR_25_]] : f32 to i8
// CHECK: [[VAR_27_:%.+]] = builtin.unrealized_conversion_cast [[VAR_26_]] : i8 to ui8
// CHECK: [[VAR_26_:%.+]] = arith.fptoui [[VAR_25_]] : f32 to i32
// CHECK: [[VAR_27_:%.+]] = arith.trunci [[VAR_26_]] : i32 to i8
// CHECK: [[VAR_28_:%.+]] = builtin.unrealized_conversion_cast [[VAR_27_]] : i8 to ui8
// CHECK: krnl.store [[VAR_7_]], [[RES_1_]][] : memref<f32>
// CHECK: krnl.store [[VAR_27_]], [[RES_2_]][] : memref<ui8>
// CHECK-DAG: [[VAR_28_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK: krnl.store [[VAR_28_]], [[RES_2_]][] : memref<ui8>
// CHECK-DAG: [[VAR_29_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
// CHECK: affine.store [[VAR_28_]], [[RES_5_]][0] : memref<1xindex>
// CHECK: affine.store [[VAR_29_]], [[RES_5_]][0] : memref<1xindex>
// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_5_]]) : (memref<?x2xf32>, memref<1xindex>) -> memref<?xf32>
// CHECK-DAG: [[VAR_29_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK-DAG: [[VAR_30_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
// CHECK: affine.store [[VAR_29_]], [[RES_6_]][0] : memref<1xindex>
// CHECK: affine.store [[VAR_30_]], [[RES_6_]][0] : memref<1xindex>
// CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[RES_]]([[RES_]]_13) : (memref<?x2xui8>, memref<1xindex>) -> memref<?xui8>
// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc([[VAR_28_]]) {{.*}}: memref<?xf32>
// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){
// CHECK: [[VAR_32_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index
Expand All @@ -112,15 +112,10 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor<?x2xf32>) -> (tensor<?x2xu
// CHECK: [[VAR_49_:%.+]] = arith.addf [[VAR_48_]], [[VAR_25_]] : f32
// CHECK: [[VAR_50_:%.+]] = arith.maxnumf [[VAR_49_]], [[CST_0_dot_000000_]] : f32
// CHECK: [[VAR_51_:%.+]] = arith.minnumf [[VAR_50_]], [[CST_2_dot_550000_]] : f32
// CHECK: krnl.store [[VAR_51_]], [[RES_7_]]{{.}}[[VAR_32_2_]]{{.}} : memref<?xf32>
// CHECK: }
// CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){
// CHECK: [[VAR_32_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_32_3_]]{{.}} : memref<?xf32>
// CHECK: [[LOAD_RES_3_MEM_1_1_:%.+]] = arith.fptoui [[LOAD_PARAM_0_MEM_1_1_]] : f32 to i8
// CHECK: [[VAR_35_3_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_RES_3_MEM_1_1_]] : i8 to ui8
// CHECK: krnl.store [[VAR_35_3_]], [[VAR_reshape_14_]]{{.}}[[VAR_32_3_]]{{.}} : memref<?xui8>
// CHECK: [[VAR_52_:%.+]] = arith.fptoui [[VAR_51_]] : f32 to i32
// CHECK: [[VAR_53_:%.+]] = arith.trunci [[VAR_52_]] : i32 to i8
// CHECK: [[VAR_54_:%.+]] = builtin.unrealized_conversion_cast [[VAR_53_]] : i8 to ui8
// CHECK: krnl.store [[VAR_54_]], [[VAR_reshape_14_]]{{.}}[[VAR_32_2_]]{{.}} : memref<?xui8>
// CHECK: }
// CHECK: return [[RES_]], [[RES_]]_6, [[RES_]]_7 : memref<?x2xui8>, memref<f32>, memref<ui8>
// CHECK: }
Expand Down
Loading

0 comments on commit f7d5895

Please sign in to comment.