diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp
index 5561955304..efcb10d829 100644
--- a/src/Compiler/CompilerPasses.cpp
+++ b/src/Compiler/CompilerPasses.cpp
@@ -214,7 +214,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
 
 void addKrnlToAffinePasses(mlir::PassManager &pm) {
   pm.addNestedPass<func::FuncOp>(
-      onnx_mlir::krnl::createConvertKrnlToAffinePass());
+      onnx_mlir::krnl::createConvertKrnlToAffinePass(enableParallel));
 }
 
 void addKrnlToLLVMPasses(
diff --git a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
index 73609c2f14..c3b466930c 100644
--- a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
+++ b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
@@ -844,11 +844,21 @@ struct ConvertKrnlToAffinePass
     : public PassWrapper<ConvertKrnlToAffinePass, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertKrnlToAffinePass);
 
+  ConvertKrnlToAffinePass() = default;
+  ConvertKrnlToAffinePass(const ConvertKrnlToAffinePass &pass)
+      : PassWrapper<ConvertKrnlToAffinePass, OperationPass<func::FuncOp>>() {}
+  ConvertKrnlToAffinePass(bool parallelEnabled) {
+    this->parallelEnabled = parallelEnabled;
+  }
+
   StringRef getArgument() const override { return "convert-krnl-to-affine"; }
 
   StringRef getDescription() const override { return "Lower Krnl dialect."; }
 
   void runOnOperation() final;
+
+  Option<bool> parallelEnabled{*this, "parallel-enabled",
+      llvm::cl::desc("Enable parallelization"), llvm::cl::init(false)};
 };
 
 void ConvertKrnlToAffinePass::runOnOperation() {
@@ -1008,7 +1018,7 @@ void ConvertKrnlToAffinePass::runOnOperation() {
   RewritePatternSet patterns(ctx);
   AffineTypeConverter typeConverter;
 
-  populateKrnlToAffineConversion(typeConverter, patterns, ctx);
+  populateKrnlToAffineConversion(typeConverter, patterns, ctx, parallelEnabled);
 
   // Create list for recording the <loop, unroll factor> pairs associated with
   // this function.
@@ -1046,8 +1056,12 @@ std::unique_ptr<Pass> createConvertKrnlToAffinePass() {
   return std::make_unique<ConvertKrnlToAffinePass>();
 }
 
+std::unique_ptr<Pass> createConvertKrnlToAffinePass(bool parallelEnabled) {
+  return std::make_unique<ConvertKrnlToAffinePass>(parallelEnabled);
+}
+
 void populateKrnlToAffineConversion(TypeConverter &typeConverter,
-    RewritePatternSet &patterns, MLIRContext *ctx) {
+    RewritePatternSet &patterns, MLIRContext *ctx, bool parallelEnabled) {
   krnl::populateLoweringKrnlCopyFromBufferOpPattern(
       typeConverter, patterns, ctx);
   krnl::populateLoweringKrnlCopyToBufferOpPattern(typeConverter, patterns, ctx);
@@ -1055,7 +1069,8 @@ void populateKrnlToAffineConversion(TypeConverter &typeConverter,
   krnl::populateLoweringKrnlStoreOpPattern(typeConverter, patterns, ctx);
   krnl::populateLoweringKrnlGetLinearOffsetIndexOpPattern(
       typeConverter, patterns, ctx);
-  krnl::populateLoweringKrnlMatmultOpPattern(typeConverter, patterns, ctx);
+  krnl::populateLoweringKrnlMatmultOpPattern(
+      typeConverter, patterns, ctx, parallelEnabled);
   krnl::populateLoweringKrnlMemsetOpPattern(typeConverter, patterns, ctx);
   krnl::populateLoweringKrnlPrefetchOpPattern(typeConverter, patterns, ctx);
   krnl::populateLoweringKrnlTerminatorOpPattern(typeConverter, patterns, ctx);
diff --git a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp
index 2bc0fd3aae..c1c222a293 100644
--- a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp
+++ b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp
@@ -56,7 +56,8 @@ using UnrollAndJamList = llvm::SmallVector<UnrollAndJamRecord, 4>;
 using UnrollAndJamMap = std::map<mlir::Operation *, UnrollAndJamList *>;
 
 void populateKrnlToAffineConversion(mlir::TypeConverter &typeConverter,
-    mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx,
+    bool enableParallel);
 
 void populateLoweringKrnlCopyFromBufferOpPattern(
     mlir::TypeConverter &typeConverter, mlir::RewritePatternSet &patterns,
@@ -77,7 +78,8 @@ void populateLoweringKrnlGetLinearOffsetIndexOpPattern(
     mlir::MLIRContext *ctx);
 
 void populateLoweringKrnlMatmultOpPattern(mlir::TypeConverter &typeConverter,
-    mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx,
+    bool parallelEnabled);
 
 void populateLoweringKrnlMemsetOpPattern(mlir::TypeConverter &typeConverter,
     mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
diff --git a/src/Conversion/KrnlToAffine/KrnlMatmul.cpp b/src/Conversion/KrnlToAffine/KrnlMatmul.cpp
index 8ab9ef7b1a..6b42177457 100644
--- a/src/Conversion/KrnlToAffine/KrnlMatmul.cpp
+++ b/src/Conversion/KrnlToAffine/KrnlMatmul.cpp
@@ -42,9 +42,12 @@ extern std::mutex unrollAndJamMutex;
 class KrnlMatmulLowering : public ConversionPattern {
 public:
   explicit KrnlMatmulLowering(
-      TypeConverter &typeConverter, MLIRContext *context)
+      TypeConverter &typeConverter, MLIRContext *context, bool parallelEnabled)
       : ConversionPattern(
-            typeConverter, KrnlMatMulOp::getOperationName(), 1, context) {}
+            typeConverter, KrnlMatMulOp::getOperationName(), 1, context) {
+    this->parallelEnabled = parallelEnabled;
+  }
+  bool parallelEnabled = false;
 
   LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
@@ -221,22 +224,33 @@ class KrnlMatmulLowering : public ConversionPattern {
     if (simdize) {
       // SIMD code generator.
       if (matVectorProduct) {
+        // Alloc of temp outside of inner if/then/else.
+        Value TmpSimdProd = allocForGenSimdMatVect(create.affineKMem,
+            elementType, iComputeTileSize, jComputeTileSize, kComputeTileSize,
+            vectorLen, fullUnrollAndJam);
+        Value TmpScalarProd = allocForGenScalar(create.affineKMem, elementType,
+            iTrip, jTrip, kTrip, /*unroll*/ false);
         // clang-format off
         create.affineKMem.ifThenElseIE(indexScope, allFullTiles,
           /* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
-          genSimdMatVect(createAffine, matmulOp, elementType, aStart, bStart,
+          genSimdMatVect(createAffine, matmulOp, TmpSimdProd, elementType, aStart, bStart,
            cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize,
            vectorLen, fullUnrollAndJam);
         }, /* else has partial tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
-          genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
+          genScalar(createAffine, matmulOp, TmpScalarProd, elementType, aStart, bStart, cStart,
            iTrip, jTrip, kTrip, /*unroll*/ false);
         });
         // clang-format on
       } else {
+        Value TmpSimdC = allocForGenSimdMatMat(create.affineKMem, elementType,
+            iComputeTileSize, jComputeTileSize, kComputeTileSize, vectorLen,
+            fullUnrollAndJam);
+        Value TmpScalarC = allocForGenScalar(create.affineKMem, elementType,
+            iTrip, jPartialTrip, kTrip, /*unroll*/ false);
         // clang-format off
         create.affineKMem.ifThenElseIE(indexScope, allFullTiles,
           /* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
-          genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart,
+          genSimdMatMat(createAffine, matmulOp, TmpSimdC, elementType, aStart, bStart,
            cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize,
            vectorLen, fullUnrollAndJam);
         },
@@ -246,33 +260,30 @@ class KrnlMatmulLowering : public ConversionPattern {
           // Test if SIMD dim (M) is full.
           createAffine.ifThenElseIE(indexScope, jFullTiles,
             /* full SIMD */ [&](const AffineBuilderKrnlMem &createAffine) {
-            genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart,
+            genSimdMatMat(createAffine, matmulOp, TmpSimdC, elementType, aStart, bStart,
              cStart, iTrip, jComputeTileSize, kTrip, vectorLen, /*unroll*/ false);
           }, /* else partial SIMD */ [&](const AffineBuilderKrnlMem &createAffine) {
-            // TODO: evaluate if get performance from partial SIMD
-            if (false && jPartialTrip.isLiteral() && jPartialTrip.getLiteral() >=2) {
-              // has a known trip count along the simd dimension of at least 2
-              // elements, use simd again.
-              genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart,
-                cStart, iTrip, jPartialTrip, kTrip, vectorLen, /*unroll*/ false);
-            } else {
-              genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
-                iTrip, jPartialTrip, kTrip, /*unroll*/ false);
-            }
+            genScalar(createAffine, matmulOp, TmpScalarC, elementType, aStart, bStart, cStart,
+              iTrip, jPartialTrip, kTrip, /*unroll*/ false);
           });
         });
         // clang-format on
       }
     } else {
      // Scalar code generator.
+      Value TmpThenC =
+          allocForGenScalar(create.affineKMem, elementType, iComputeTileSize,
+              jComputeTileSize, kComputeTileSize, fullUnrollAndJam);
+      Value TmpElseC = allocForGenScalar(
+          create.affineKMem, elementType, iTrip, jTrip, kTrip, false);
      // clang-format off
      create.affineKMem.ifThenElseIE(indexScope, allFullTiles,
        /* then full */ [&](const AffineBuilderKrnlMem &createAffine) {
-        genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
+        genScalar(createAffine, matmulOp, TmpThenC, elementType, aStart, bStart, cStart,
          iComputeTileSize, jComputeTileSize, kComputeTileSize, fullUnrollAndJam);
      }, /* else partial */ [&](const AffineBuilderKrnlMem &createAffine) {
-        genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
+        genScalar(createAffine, matmulOp, TmpElseC, elementType, aStart, bStart, cStart,
          iTrip, jTrip, kTrip, false);
      });
      // clang-format on
@@ -282,21 +293,32 @@ class KrnlMatmulLowering : public ConversionPattern {
   }
 
 private:
-  void genScalar(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
-      Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
-      ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
+  Value allocForGenScalar(const AffineBuilderKrnlMem &createAffine,
+      Type elementType, IndexExpr I, IndexExpr J, IndexExpr K,
       bool unrollJam) const {
     // Get operands.
+    MemRefBuilder createMemRef(createAffine);
+    int64_t unrollFactor = (unrollJam && J.isLiteral()) ? J.getLiteral() : 1;
+    // Have to privatize CTmpType by unroll factor (1 if none).
+    MemRefType CTmpType = MemRefType::get({unrollFactor}, elementType);
+    assert(BUFFER_ALIGN >= gDefaultAllocAlign);
+    //
+    if (parallelEnabled)
+      return createMemRef.alignedAlloc(CTmpType, BUFFER_ALIGN);
+    return createMemRef.alignedAlloca(CTmpType, BUFFER_ALIGN);
+  }
+
+  void genScalar(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
+      Value TmpC, Type elementType, ArrayRef<IndexExpr> aStart,
+      ArrayRef<IndexExpr> bStart, ArrayRef<IndexExpr> cStart, IndexExpr I,
+      IndexExpr J, IndexExpr K, bool unrollJam) const {
+    // Get operands.
     KrnlMatMulOpAdaptor operandAdaptor(op);
     MemRefBuilder createMemRef(createAffine);
     Value A(operandAdaptor.getA()), B(operandAdaptor.getB()),
         C(operandAdaptor.getC());
     int64_t unrollFactor = (unrollJam && J.isLiteral()) ? J.getLiteral() : 1;
-    // Have to privatize CTmpType by unroll factor (1 if none).
-    MemRefType CTmpType = MemRefType::get({unrollFactor}, elementType);
-    assert(BUFFER_ALIGN >= gDefaultAllocAlign);
-    Value TmpC = createMemRef.alignedAlloc(CTmpType, BUFFER_ALIGN);
 
     // For i, j loops.
     LiteralIndexExpr zeroIE(0);
@@ -342,11 +364,46 @@ class KrnlMatmulLowering : public ConversionPattern {
     }
   }
 
+  Value allocForGenSimdMatVect(const AffineBuilderKrnlMem &createAffine,
+      Type elementType, IndexExpr I, IndexExpr J, IndexExpr K,
+      IndexExpr vectorLen, bool unrollJam) const {
+    // can simdize only if I & K is compile time
+    assert(I.isLiteral() && K.isLiteral() && vectorLen.isLiteral() &&
+           "can only simdize with compile time "
+           "blocking factor on simd axis");
+    MultiDialectBuilder<VectorBuilder, MemRefBuilder> create(createAffine);
+    int64_t iLit(I.getLiteral()), VL(vectorLen.getLiteral());
+    int64_t archVL = create.vec.getArchVectorLength(elementType);
+
+    // Generate the vector type conversions.
+    assert(VL == archVL && "vector length and VL must be identical for now");
+    VectorType vecType = VectorType::get({VL}, elementType);
+    int64_t iUnrollFactor = iLit;
+    assert(iUnrollFactor % VL == 0 && "i blocking should be a multiple of VL");
+
+    // Have to privatize CTmpType by unroll factor.
+    MemRefType CTmpType = MemRefType::get({iUnrollFactor}, vecType);
+    assert(BUFFER_ALIGN >= gDefaultAllocAlign &&
+           "alignment of buffers cannot be smaller than the default alignment "
+           "(which is set for SIMD correctness");
+    // Ok to use an alloca here because hoisting will take it out of the loop,
+    // as it is now generated before the scf.if which precluded the migration to
+    // outside the loops.
+
+    // But at this time, if parallel is enabled, alloca would be stuck inside of
+    // the parallel loop, which is not great. TODO: migrate alloca from inside
+    // the parallel loop to the OMP parallel region before the loop.
+    // Grep for this pattern in all 3 instances of "parallelEnabled".
+    if (parallelEnabled)
+      return create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN);
+    return create.mem.alignedAlloca(CTmpType, BUFFER_ALIGN);
+  }
+
   // Initially, simdize with full K vector length.
   void genSimdMatVect(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
-      Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
-      ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
-      IndexExpr vectorLen, bool unrollJam) const {
+      Value TmpProd, Type elementType, ArrayRef<IndexExpr> aStart,
+      ArrayRef<IndexExpr> bStart, ArrayRef<IndexExpr> cStart, IndexExpr I,
+      IndexExpr J, IndexExpr K, IndexExpr vectorLen, bool unrollJam) const {
     // can simdize only if I & K is compile time
     assert(I.isLiteral() && K.isLiteral() && vectorLen.isLiteral() &&
            "can only simdize with compile time "
           "blocking factor on simd axis");
@@ -367,12 +424,6 @@ class KrnlMatmulLowering : public ConversionPattern {
     int64_t iUnrollFactor = iLit;
     assert(iUnrollFactor % VL == 0 && "i blocking should be a multiple of VL");
 
-    // Have to privatize CTmpType by unroll factor.
-    MemRefType CTmpType = MemRefType::get({iUnrollFactor}, vecType);
-    assert(BUFFER_ALIGN >= gDefaultAllocAlign &&
-        "alignment of buffers cannot be smaller than the default alignment "
-        "(which is set for SIMD correctness");
-    Value TmpProd = create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN);
     // Init with zero.
     Value fZero = create.math.constant(elementType, 0);
     Value vFZero = create.vec.broadcast(vecType, fZero);
@@ -427,11 +478,36 @@ class KrnlMatmulLowering : public ConversionPattern {
     }
   }
 
+  Value allocForGenSimdMatMat(const AffineBuilderKrnlMem &createAffine,
+      Type elementType, IndexExpr I, IndexExpr J, IndexExpr K,
+      IndexExpr vectorLen, bool unrollJam) const {
+    // can simdize only if K is compile time
+    MultiDialectBuilder<MemRefBuilder> create(createAffine);
+
+    // Generate the vector type conversions.
+    int64_t VL = vectorLen.getLiteral();
+    VectorType vecType = VectorType::get({VL}, elementType);
+    int64_t unrollFactor = (unrollJam && I.isLiteral()) ? I.getLiteral() : 1;
+    // Have to privatize CTmpType by unroll factor (1 if none).
+    MemRefType CTmpType = MemRefType::get({unrollFactor}, vecType);
+    assert(BUFFER_ALIGN >= gDefaultAllocAlign);
+    // Ok to use an alloca here because hoisting will take it out of the loop,
+    // as it is now generated before the scf.if which precluded the migration to
+    // outside the loops.
+
+    // But at this time, if parallel is enabled, alloca would be stuck inside of
+    // the parallel loop, which is not great. TODO: migrate alloca from inside
+    // the parallel loop to the OMP parallel region before the loop.
+    if (parallelEnabled)
+      return create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN);
+    return create.mem.alignedAlloca(CTmpType, BUFFER_ALIGN);
+  }
+
   // Simdize along J / memory rows in B and C.
   void genSimdMatMat(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
-      Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
-      ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
-      IndexExpr vectorLen, bool unrollJam) const {
+      Value TmpC, Type elementType, ArrayRef<IndexExpr> aStart,
+      ArrayRef<IndexExpr> bStart, ArrayRef<IndexExpr> cStart, IndexExpr I,
+      IndexExpr J, IndexExpr K, IndexExpr vectorLen, bool unrollJam) const {
     // can simdize only if K is compile time
     assert(J.isLiteral() &&
            "can only simdize with compile time blocking factor on simd axis");
@@ -446,10 +522,6 @@ class KrnlMatmulLowering : public ConversionPattern {
     int64_t VL = vectorLen.getLiteral();
     VectorType vecType = VectorType::get({VL}, elementType);
     int64_t unrollFactor = (unrollJam && I.isLiteral()) ? I.getLiteral() : 1;
-    // Have to privatize CTmpType by unroll factor (1 if none).
-    MemRefType CTmpType = MemRefType::get({unrollFactor}, vecType);
-    assert(BUFFER_ALIGN >= gDefaultAllocAlign);
-    Value TmpC = create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN);
 
     // Iterates over the I indices (j are simd dim).
     Value iSaved, kSaved;
@@ -547,8 +619,8 @@ class KrnlMatmulLowering : public ConversionPattern {
 }; // namespace krnl
 
 void populateLoweringKrnlMatmultOpPattern(TypeConverter &typeConverter,
-    RewritePatternSet &patterns, MLIRContext *ctx) {
-  patterns.insert<KrnlMatmulLowering>(typeConverter, ctx);
+    RewritePatternSet &patterns, MLIRContext *ctx, bool parallelEnabled) {
+  patterns.insert<KrnlMatmulLowering>(typeConverter, ctx, parallelEnabled);
 }
 
 } // namespace krnl
diff --git a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp
index af0724c446..4f110bd6bd 100644
--- a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp
+++ b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp
@@ -248,6 +248,8 @@ struct ONNXGemmOpLowering : public OpConversionPattern<GemmOp> {
         [&](const KrnlBuilder &createKrnl, ValueRange i1_j1_indices) {
           Value i1(i1_j1_indices[0]), j1(i1_j1_indices[1]);
           // If parallel, will stay inside, otherwise will migrate out.
+          // Since they are not in an if structure, migration out is not an
+          // issue.
           Value aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
           Value bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN);
           Value rBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
@@ -313,6 +315,8 @@ struct ONNXGemmOpLowering : public OpConversionPattern<GemmOp> {
         [&](const KrnlBuilder &createKrnl, ValueRange j1_k1_indices) {
           Value j1(j1_k1_indices[0]), k1(j1_k1_indices[1]);
           // If parallel, it will stay inside, otherwise it will migrate out.
+          // Since allocs are not in an if structure, migration is not an
+          // issue.
           Value aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
           Value bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN);
           if (bTrans)
diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp
index f22cbf2595..070fe3d671 100644
--- a/src/Pass/Passes.hpp
+++ b/src/Pass/Passes.hpp
@@ -108,6 +108,7 @@ std::unique_ptr<mlir::Pass> createElideConstGlobalValuePass();
 namespace krnl {
 
 /// Pass for lowering frontend dialects to Krnl IR dialect.
 std::unique_ptr<mlir::Pass> createConvertKrnlToAffinePass();
+std::unique_ptr<mlir::Pass> createConvertKrnlToAffinePass(bool parallelEnabled);
 
 /// Pass for lowering Seq in Krnl dialect.
 std::unique_ptr<mlir::Pass> createConvertSeqToMemrefPass();
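
For context, a minimal usage sketch (not part of the patch above). It shows how the new createConvertKrnlToAffinePass(bool) overload can be driven from a standalone PassManager; the helper name lowerKrnlToAffine and the include paths are illustrative assumptions, and only the pass-creation call itself comes from this diff.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

#include "src/Pass/Passes.hpp" // declares createConvertKrnlToAffinePass(bool)

// Lower Krnl to affine on every func.func in the module. When `parallel` is
// true, the KrnlMatMulOp lowering privatizes its temporary tiles with
// alignedAlloc instead of alignedAlloca (see allocForGenScalar,
// allocForGenSimdMatVect, and allocForGenSimdMatMat in the patch), so each
// parallel iteration gets its own buffer.
mlir::LogicalResult lowerKrnlToAffine(mlir::ModuleOp module, bool parallel) {
  mlir::PassManager pm(module.getContext());
  // Nest on func.func, matching addKrnlToAffinePasses in CompilerPasses.cpp.
  pm.addNestedPass<mlir::func::FuncOp>(
      onnx_mlir::krnl::createConvertKrnlToAffinePass(parallel));
  return pm.run(module);
}

With a tool that registers this pass (e.g. onnx-mlir-opt), the same switch should also be reachable through the registered pass option, along the lines of --pass-pipeline="builtin.module(func.func(convert-krnl-to-affine{parallel-enabled=true}))".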