Skip to content

Commit

Permalink
Upgrading llvm and stablehlo hash (#3053)
Browse files Browse the repository at this point in the history
* upgrading llvm and stablehlo hash. Fixing mlir tests

Signed-off-by: Christopher Munoz <[email protected]>

* fixing vector shapecast bug introduced by upgraded llvm

Signed-off-by: Christopher Munoz <[email protected]>

---------

Signed-off-by: Christopher Munoz <[email protected]>
  • Loading branch information
christopherlmunoz authored Feb 14, 2025
1 parent 21d4759 commit df601dc
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 54 deletions.
2 changes: 1 addition & 1 deletion docs/BuildOnLinuxOSX.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
cd llvm-project && git checkout e86910337f98e57f5b9253f7d80d5b916eb1d97e && cd ..
cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd ..
```

[same-as-file]: <> (utils/build-mlir.sh)
Expand Down
2 changes: 1 addition & 1 deletion docs/BuildOnWindows.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project):
```shell
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
cd llvm-project && git checkout e86910337f98e57f5b9253f7d80d5b916eb1d97e && cd ..
cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd ..
```

[same-as-file]: <> (utils/build-mlir.cmd)
Expand Down
13 changes: 12 additions & 1 deletion src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,23 @@ void populateAffineAndKrnlToLLVMConversion(RewritePatternSet &patterns,
// They run it in two steps, and add additional lowerings.

vector::populateVectorToVectorCanonicalizationPatterns(patterns);
vector::populateVectorBitCastLoweringPatterns(patterns);
vector::populateVectorBroadcastLoweringPatterns(patterns);
vector::populateVectorContractLoweringPatterns(
patterns, vector::VectorTransformsOptions());
vector::populateVectorMaskOpLoweringPatterns(patterns);
vector::populateVectorShapeCastLoweringPatterns(patterns);
vector::populateVectorInterleaveLoweringPatterns(patterns);
vector::populateVectorTransposeLoweringPatterns(
patterns, vector::VectorTransformsOptions());
vector::populateVectorShapeCastLoweringPatterns(patterns);
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
vector::populateVectorTransferLoweringPatterns(
patterns, /*maxTransferRank=*/1);
vector::populateVectorMaskMaterializationPatterns(
patterns, /*force32BitVectorIndices*/ false);
vector::populateVectorInsertExtractStridedSliceTransforms(patterns);
vector::populateVectorStepLoweringPatterns(patterns);
vector::populateVectorRankReducingFMAPattern(patterns);

populateAffineToStdConversionPatterns(patterns);
populateSCFToControlFlowConversionPatterns(patterns);
Expand Down
12 changes: 8 additions & 4 deletions src/Conversion/ONNXToTOSA/Math/Conv2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op,
const llvm::ArrayRef<int64_t> weightShape, Value &newInput,
Value &newWeight, Value &bias, const int64_t groups,
DenseI64ArrayAttr &pads, DenseI64ArrayAttr &strides,
DenseI64ArrayAttr &dilations) {
DenseI64ArrayAttr &dilations, TypeAttr &accType) {
// Set up constants outside of loop
const int64_t sizeOfSliceInput = weightShape[1];
const int64_t sizeOfSliceKernel = weightShape[0] / groups;
Expand Down Expand Up @@ -72,7 +72,7 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op,
mlir::cast<ShapedType>(resultType).getElementType());
Value tempConv2D = tosa::CreateOpAndInfer<mlir::tosa::Conv2DOp>(rewriter,
op->getLoc(), newConvOutputType, newSliceInput, newSliceWeight,
newSliceBias, pads, strides, dilations);
newSliceBias, pads, strides, dilations, accType);
// Add value to vector
sliceValues.push_back(tempConv2D);
}
Expand Down Expand Up @@ -147,6 +147,10 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern {
DenseI64ArrayAttr newPads =
rewriter.getDenseI64ArrayAttr({pads[0], pads[2], pads[1], pads[3]});

Type convType =
(resultType.isF16()) ? rewriter.getF16Type() : rewriter.getF32Type();
TypeAttr accType = mlir::TypeAttr::get(convType);

// Handle group parameter by creating multiple convs
const int64_t group = adaptor.getGroup();
Value conv2D = NULL;
Expand All @@ -157,11 +161,11 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern {

conv2D = tosa::CreateOpAndInfer<mlir::tosa::Conv2DOp>(rewriter,
convOp->getLoc(), newConvOutputType, newInput, newWeight, bias,
newPads, strides, dilations);
newPads, strides, dilations, accType);
} else {
conv2D = createConvInGroups(rewriter, convOp, tosaBuilder, resultType,
weightShape, newInput, newWeight, bias, group, newPads, strides,
dilations);
dilations, accType);
}

// Convert output [N,OH,OW,OC] -> [N,OC,OH,OW]
Expand Down
8 changes: 3 additions & 5 deletions src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Value buildOnnxToTosaPaddingConstOp(mlir::PatternRewriter &rewriter,

// Create a new pad vec in the right format
// ONNX : [b1, b2, b3, b4, e1, e2, e3, e4]
// TOSA :[[b1, e1], [b2, e2], [b3, e3], [b4, e4]]
// TOSA :[b1, e1, b2, e2, b3, e3, b4, e4]

// Adds any initial or last vals, not included in onnxPads.
llvm::SmallVector<int64_t, 8> tosaPads{initialVals};
Expand All @@ -59,11 +59,9 @@ Value buildOnnxToTosaPaddingConstOp(mlir::PatternRewriter &rewriter,
tosaPads.push_back(onnxPads[i + dimSize]);
}
tosaPads.insert(tosaPads.end(), lastVals.begin(), lastVals.end());

// TOSA format groups dimensions by 2.
const unsigned int numberOfDims = tosaPads.size() / 2;
TosaBuilder tosaBuilder(rewriter, loc);
return tosaBuilder.getConst(tosaPads, {numberOfDims, 2});
return tosaBuilder.getConst(
tosaPads, {static_cast<int64_t>(tosaPads.size())});
}

} // namespace tosa
Expand Down
33 changes: 8 additions & 25 deletions test/mlir/conversion/krnl_to_llvm/krnl_math_function_lowering.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ func.func @test_krnl_erf_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf32>

// CHECK-LABEL: test_krnl_erf_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ERF_RES:%.+]] = llvm.call @erff([[SCALAR_IN]]) : (f32) -> f32
Expand All @@ -41,9 +39,7 @@ func.func @test_krnl_acos_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf32

// CHECK-LABEL: test_krnl_acos_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ACOS_RES:%.+]] = llvm.call @acosf([[SCALAR_IN]]) : (f32) -> f32
Expand All @@ -66,9 +62,7 @@ func.func @test_krnl_acosh_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf3

// CHECK-LABEL: test_krnl_acosh_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ACOS_RES:%.+]] = llvm.call @acoshf([[SCALAR_IN]]) : (f32) -> f32
Expand All @@ -91,9 +85,7 @@ func.func @test_krnl_asin_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf32

// CHECK-LABEL: test_krnl_asin_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ACOS_RES:%.+]] = llvm.call @asinf([[SCALAR_IN]]) : (f32) -> f32
Expand All @@ -116,9 +108,7 @@ func.func @test_krnl_asinh_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf3

// CHECK-LABEL: test_krnl_asinh_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ACOS_RES:%.+]] = llvm.call @asinhf([[SCALAR_IN]]) : (f32) -> f32
Expand All @@ -141,9 +131,7 @@ func.func @test_krnl_atan_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf32

// CHECK-LABEL: test_krnl_atan_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ACOS_RES:%.+]] = llvm.call @atanf([[SCALAR_IN]]) : (f32) -> f32
Expand All @@ -165,9 +153,7 @@ func.func @test_krnl_atanh_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf3

// CHECK-LABEL: test_krnl_atanh_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ACOS_RES:%.+]] = llvm.call @atanhf([[SCALAR_IN]]) : (f32) -> f32
Expand All @@ -189,12 +175,9 @@ func.func @test_krnl_tan_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf32>

// CHECK-LABEL: test_krnl_tan_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ACOS_RES:%.+]] = llvm.call @tanf([[SCALAR_IN]]) : (f32) -> f32
// CHECK: [[DATA_OUT:%.+]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: llvm.store [[ACOS_RES]], [[DATA_OUT]] : f32, !llvm.ptr

Loading

0 comments on commit df601dc

Please sign in to comment.