Matmul CPU performance regression #3072

Merged: 24 commits, Feb 13, 2025
2 changes: 1 addition & 1 deletion src/Compiler/CompilerPasses.cpp
@@ -214,7 +214,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,

void addKrnlToAffinePasses(mlir::PassManager &pm) {
pm.addNestedPass<func::FuncOp>(
onnx_mlir::krnl::createConvertKrnlToAffinePass());
onnx_mlir::krnl::createConvertKrnlToAffinePass(enableParallel));
}

void addKrnlToLLVMPasses(
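For reference, a minimal sketch of how the new overload is meant to be used from a pipeline-construction routine; it assumes enableParallel is the existing onnx-mlir parallelization option visible in CompilerPasses.cpp (the free-standing function below is hypothetical, not part of the patch):

void buildKrnlToAffine(mlir::PassManager &pm, bool enableParallel) {
  // Thread the flag into the nested pass so that the KrnlMatMulOp lowering
  // knows whether its temporaries must be thread-private heap buffers.
  pm.addNestedPass<mlir::func::FuncOp>(
      onnx_mlir::krnl::createConvertKrnlToAffinePass(enableParallel));
}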
21 changes: 18 additions & 3 deletions src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
@@ -844,11 +844,21 @@ struct ConvertKrnlToAffinePass
: public PassWrapper<ConvertKrnlToAffinePass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertKrnlToAffinePass);

ConvertKrnlToAffinePass() = default;
ConvertKrnlToAffinePass(const ConvertKrnlToAffinePass &pass)
: PassWrapper<ConvertKrnlToAffinePass, OperationPass<func::FuncOp>>() {}
ConvertKrnlToAffinePass(bool parallelEnabled) {
this->parallelEnabled = parallelEnabled;
}

StringRef getArgument() const override { return "convert-krnl-to-affine"; }

StringRef getDescription() const override { return "Lower Krnl dialect."; }

void runOnOperation() final;

Option<bool> parallelEnabled{*this, "parallel-enabled",
llvm::cl::desc("Enable parallelization"), llvm::cl::init(false)};
};

void ConvertKrnlToAffinePass::runOnOperation() {
@@ -1008,7 +1018,7 @@ void ConvertKrnlToAffinePass::runOnOperation() {
RewritePatternSet patterns(ctx);
AffineTypeConverter typeConverter;

populateKrnlToAffineConversion(typeConverter, patterns, ctx);
populateKrnlToAffineConversion(typeConverter, patterns, ctx, parallelEnabled);

// Create list for recording the <loop, unroll factor> pairs associated with
// this function.
@@ -1046,16 +1056,21 @@ std::unique_ptr<Pass> createConvertKrnlToAffinePass() {
return std::make_unique<ConvertKrnlToAffinePass>();
}

std::unique_ptr<Pass> createConvertKrnlToAffinePass(bool parallelEnabled) {
return std::make_unique<ConvertKrnlToAffinePass>(parallelEnabled);
}

void populateKrnlToAffineConversion(TypeConverter &typeConverter,
RewritePatternSet &patterns, MLIRContext *ctx) {
RewritePatternSet &patterns, MLIRContext *ctx, bool parallelEnabled) {
krnl::populateLoweringKrnlCopyFromBufferOpPattern(
typeConverter, patterns, ctx);
krnl::populateLoweringKrnlCopyToBufferOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlLoadOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlStoreOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlGetLinearOffsetIndexOpPattern(
typeConverter, patterns, ctx);
krnl::populateLoweringKrnlMatmultOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlMatmultOpPattern(
typeConverter, patterns, ctx, parallelEnabled);
krnl::populateLoweringKrnlMemsetOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlPrefetchOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlTerminatorOpPattern(typeConverter, patterns, ctx);
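The constructor/Option pairing above is the usual MLIR idiom: the Option exposes a textual pass flag (parallel-enabled, default false), the extra constructor lets C++ callers such as createConvertKrnlToAffinePass(bool) set it directly, and the user-defined copy constructor is needed because pass options are not copyable. A self-contained sketch of the same idiom, using a hypothetical pass name:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"

struct ExampleLoweringPass
    : public mlir::PassWrapper<ExampleLoweringPass,
          mlir::OperationPass<mlir::func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ExampleLoweringPass);

  ExampleLoweringPass() = default;
  // Pass options are not copyable; re-initialize them instead of copying.
  ExampleLoweringPass(const ExampleLoweringPass &)
      : PassWrapper<ExampleLoweringPass,
            mlir::OperationPass<mlir::func::FuncOp>>() {}
  ExampleLoweringPass(bool parallelEnabled) {
    this->parallelEnabled = parallelEnabled;
  }

  llvm::StringRef getArgument() const override { return "example-lowering"; }
  llvm::StringRef getDescription() const override {
    return "Illustrative lowering pass.";
  }

  void runOnOperation() final {
    // A real pass would populate its rewrite patterns here, forwarding
    // parallelEnabled just as populateKrnlToAffineConversion does above.
  }

  Option<bool> parallelEnabled{*this, "parallel-enabled",
      llvm::cl::desc("Enable parallelization"), llvm::cl::init(false)};
};

With both in place, the flag can be set either from the textual pipeline (example-lowering{parallel-enabled=true}) or from a factory like the createConvertKrnlToAffinePass(bool) overload added in this PR.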
6 changes: 4 additions & 2 deletions src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp
@@ -56,7 +56,8 @@ using UnrollAndJamList = llvm::SmallVector<UnrollAndJamRecord, 4>;
using UnrollAndJamMap = std::map<mlir::Operation *, UnrollAndJamList *>;

void populateKrnlToAffineConversion(mlir::TypeConverter &typeConverter,
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx,
bool enableParallel);

void populateLoweringKrnlCopyFromBufferOpPattern(
mlir::TypeConverter &typeConverter, mlir::RewritePatternSet &patterns,
@@ -77,7 +78,8 @@ void populateLoweringKrnlGetLinearOffsetIndexOpPattern(
mlir::MLIRContext *ctx);

void populateLoweringKrnlMatmultOpPattern(mlir::TypeConverter &typeConverter,
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx,
bool parallelEnabled);

void populateLoweringKrnlMemsetOpPattern(mlir::TypeConverter &typeConverter,
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
158 changes: 115 additions & 43 deletions src/Conversion/KrnlToAffine/KrnlMatmul.cpp
@@ -42,9 +42,12 @@ extern std::mutex unrollAndJamMutex;
class KrnlMatmulLowering : public ConversionPattern {
public:
explicit KrnlMatmulLowering(
TypeConverter &typeConverter, MLIRContext *context)
TypeConverter &typeConverter, MLIRContext *context, bool parallelEnabled)
: ConversionPattern(
typeConverter, KrnlMatMulOp::getOperationName(), 1, context) {}
typeConverter, KrnlMatMulOp::getOperationName(), 1, context) {
this->parallelEnabled = parallelEnabled;
}
bool parallelEnabled = false;

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
@@ -221,22 +224,33 @@
if (simdize) {
// SIMD code generator.
if (matVectorProduct) {
// Alloc of temp outside of inner if/then/else.
Value TmpSimdProd = allocForGenSimdMatVect(create.affineKMem,
elementType, iComputeTileSize, jComputeTileSize, kComputeTileSize,
vectorLen, fullUnrollAndJam);
Value TmpScalarProd = allocForGenScalar(create.affineKMem, elementType,
iTrip, jTrip, kTrip, /*unroll*/ false);
// clang-format off
create.affineKMem.ifThenElseIE(indexScope, allFullTiles,
/* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
genSimdMatVect(createAffine, matmulOp, elementType, aStart, bStart,
genSimdMatVect(createAffine, matmulOp, TmpSimdProd, elementType, aStart, bStart,
cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize,
vectorLen, fullUnrollAndJam);
}, /* else has partial tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
genScalar(createAffine, matmulOp, TmpScalarProd, elementType, aStart, bStart, cStart,
iTrip, jTrip, kTrip, /*unroll*/ false);
});
// clang-format on
} else {
Value TmpSimdC = allocForGenSimdMatMat(create.affineKMem, elementType,
iComputeTileSize, jComputeTileSize, kComputeTileSize, vectorLen,
fullUnrollAndJam);
Value TmpScalarC = allocForGenScalar(create.affineKMem, elementType,
iTrip, jPartialTrip, kTrip, /*unroll*/ false);
// clang-format off
create.affineKMem.ifThenElseIE(indexScope, allFullTiles,
/* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart,
genSimdMatMat(createAffine, matmulOp, TmpSimdC, elementType, aStart, bStart,
cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize,
vectorLen, fullUnrollAndJam);
},
@@ -246,33 +260,30 @@
// Test if SIMD dim (M) is full.
createAffine.ifThenElseIE(indexScope, jFullTiles,
/* full SIMD */ [&](const AffineBuilderKrnlMem &createAffine) {
genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart,
genSimdMatMat(createAffine, matmulOp, TmpSimdC, elementType, aStart, bStart,
cStart, iTrip, jComputeTileSize, kTrip, vectorLen, /*unroll*/ false);
}, /* else partial SIMD */ [&](const AffineBuilderKrnlMem &createAffine) {
// TODO: evaluate if get performance from partial SIMD
if (false && jPartialTrip.isLiteral() && jPartialTrip.getLiteral() >=2) {
// has a known trip count along the simd dimension of at least 2
// elements, use simd again.
genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart,
cStart, iTrip, jPartialTrip, kTrip, vectorLen, /*unroll*/ false);
} else {
genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
iTrip, jPartialTrip, kTrip, /*unroll*/ false);
}
genScalar(createAffine, matmulOp, TmpScalarC, elementType, aStart, bStart, cStart,
iTrip, jPartialTrip, kTrip, /*unroll*/ false);
});
});
// clang-format on
}
} else {
// Scalar code generator.
Value TmpThenC =
allocForGenScalar(create.affineKMem, elementType, iComputeTileSize,
jComputeTileSize, kComputeTileSize, fullUnrollAndJam);
Value TmpElseC = allocForGenScalar(
create.affineKMem, elementType, iTrip, jTrip, kTrip, false);
// clang-format off
create.affineKMem.ifThenElseIE(indexScope, allFullTiles,
/* then full */ [&](const AffineBuilderKrnlMem &createAffine) {
genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
genScalar(createAffine, matmulOp, TmpThenC, elementType, aStart, bStart, cStart,
iComputeTileSize, jComputeTileSize, kComputeTileSize,
fullUnrollAndJam);
}, /* else partial */ [&](const AffineBuilderKrnlMem &createAffine) {
genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
genScalar(createAffine, matmulOp, TmpElseC, elementType, aStart, bStart, cStart,
iTrip, jTrip, kTrip, false);
});
// clang-format on
@@ -282,21 +293,32 @@
}

private:
void genScalar(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
Value allocForGenScalar(const AffineBuilderKrnlMem &createAffine,
Type elementType, IndexExpr I, IndexExpr J, IndexExpr K,
bool unrollJam) const {
// Get operands.
MemRefBuilder createMemRef(createAffine);
int64_t unrollFactor = (unrollJam && J.isLiteral()) ? J.getLiteral() : 1;
// Have to privatize CTmpType by unroll factor (1 if none).
MemRefType CTmpType = MemRefType::get({unrollFactor}, elementType);
assert(BUFFER_ALIGN >= gDefaultAllocAlign);
// Heap buffer when parallel (thread-private); hoisted alloca otherwise.
if (parallelEnabled)
return createMemRef.alignedAlloc(CTmpType, BUFFER_ALIGN);
return createMemRef.alignedAlloca(CTmpType, BUFFER_ALIGN);
}
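The unrollFactor sizing above reflects unroll-and-jam: every jammed j iteration needs its own accumulator, so the temporary holds one element per unrolled lane (and a single element when there is no unrolling). A plain C++ analogy, not taken from the generated code, using a hypothetical 4-way jam:

// One accumulator per jammed lane; with no unroll-and-jam the array
// collapses to a single element, matching unrollFactor = 1 above.
void jammedDotProducts(const float *a, const float *b, float *c, int K) {
  float tmp[4] = {0.0f, 0.0f, 0.0f, 0.0f}; // privatized per unrolled lane
  for (int k = 0; k < K; ++k)
    for (int j = 0; j < 4; ++j)
      tmp[j] += a[k] * b[k * 4 + j];
  for (int j = 0; j < 4; ++j)
    c[j] += tmp[j];
}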

void genScalar(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
Value TmpC, Type elementType, ArrayRef<IndexExpr> aStart,
ArrayRef<IndexExpr> bStart, ArrayRef<IndexExpr> cStart, IndexExpr I,
IndexExpr J, IndexExpr K, bool unrollJam) const {
// Get operands.
KrnlMatMulOpAdaptor operandAdaptor(op);
MemRefBuilder createMemRef(createAffine);

Value A(operandAdaptor.getA()), B(operandAdaptor.getB()),
C(operandAdaptor.getC());
int64_t unrollFactor = (unrollJam && J.isLiteral()) ? J.getLiteral() : 1;
// Have to privatize CTmpType by unroll factor (1 if none).
MemRefType CTmpType = MemRefType::get({unrollFactor}, elementType);
assert(BUFFER_ALIGN >= gDefaultAllocAlign);
Value TmpC = createMemRef.alignedAlloc(CTmpType, BUFFER_ALIGN);

// For i, j loops.
LiteralIndexExpr zeroIE(0);
@@ -342,11 +364,46 @@
}
}

Value allocForGenSimdMatVect(const AffineBuilderKrnlMem &createAffine,
Type elementType, IndexExpr I, IndexExpr J, IndexExpr K,
IndexExpr vectorLen, bool unrollJam) const {
// can simdize only if I & K is compile time
assert(I.isLiteral() && K.isLiteral() && vectorLen.isLiteral() &&
"can only simdize with compile time "
"blocking factor on simd axis");
MultiDialectBuilder<VectorBuilder, MemRefBuilder> create(createAffine);
int64_t iLit(I.getLiteral()), VL(vectorLen.getLiteral());
int64_t archVL = create.vec.getArchVectorLength(elementType);

// Generate the vector type conversions.
assert(VL == archVL && "vector length and VL must be identical for now");
VectorType vecType = VectorType::get({VL}, elementType);
int64_t iUnrollFactor = iLit;
assert(iUnrollFactor % VL == 0 && "i blocking should be a multiple of VL");

// Have to privatize CTmpType by unroll factor.
MemRefType CTmpType = MemRefType::get({iUnrollFactor}, vecType);
assert(BUFFER_ALIGN >= gDefaultAllocAlign &&
"alignment of buffers cannot be smaller than the default alignment "
"(which is set for SIMD correctness");
// Ok to use an alloca here because hoisting will take it out of the loop:
// the alloca is now generated before the scf.if that previously prevented
// it from being moved outside the loops.

// But at this time, if parallel is enabled, the alloca would be stuck inside
// the parallel loop, which is not great. TODO: migrate the alloca from inside
// the parallel loop to the OMP parallel region before the loop.
// Grep for this pattern in all 3 instances of "parallelEnabled".
if (parallelEnabled)
return create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN);
return create.mem.alignedAlloca(CTmpType, BUFFER_ALIGN);
}
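The alloc/alloca split mirrors the trade-off described in the comments: a stack buffer is fine once it can be hoisted above a sequential loop, but under parallelization the per-tile temporary currently stays inside the parallel loop, so a heap allocation is used to keep each thread's buffer private. A rough plain C++/OpenMP analogy of the three situations (illustration only, not the generated IR; the third function shows the intent of the TODO):

#include <vector>

// Sequential: the temporary is hoisted above the loop and reused, which is
// what the alloca becomes.
void sequentialTiles(int numTiles, int tileSize) {
  std::vector<float> tmp(tileSize);
  for (int t = 0; t < numTiles; ++t) {
    // ... compute into tmp ...
  }
}

// Parallel today: a thread-private buffer per tile (the alignedAlloc path),
// correct but paid on every iteration.
void parallelTiles(int numTiles, int tileSize) {
#pragma omp parallel for
  for (int t = 0; t < numTiles; ++t) {
    std::vector<float> tmp(tileSize);
    // ... compute into tmp ...
  }
}

// TODO from the comments above: hoist the buffer to the parallel region so
// each thread allocates it only once.
void parallelTilesHoisted(int numTiles, int tileSize) {
#pragma omp parallel
  {
    std::vector<float> tmp(tileSize);
#pragma omp for
    for (int t = 0; t < numTiles; ++t) {
      // ... compute into tmp ...
    }
  }
}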

// Initially, simdize with full K vector length.
void genSimdMatVect(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
IndexExpr vectorLen, bool unrollJam) const {
Value TmpProd, Type elementType, ArrayRef<IndexExpr> aStart,
ArrayRef<IndexExpr> bStart, ArrayRef<IndexExpr> cStart, IndexExpr I,
IndexExpr J, IndexExpr K, IndexExpr vectorLen, bool unrollJam) const {
// can simdize only if I & K is compile time
assert(I.isLiteral() && K.isLiteral() && vectorLen.isLiteral() &&
"can only simdize with compile time "
@@ -367,12 +424,6 @@
int64_t iUnrollFactor = iLit;
assert(iUnrollFactor % VL == 0 && "i blocking should be a multiple of VL");

// Have to privatize CTmpType by unroll factor.
MemRefType CTmpType = MemRefType::get({iUnrollFactor}, vecType);
assert(BUFFER_ALIGN >= gDefaultAllocAlign &&
"alignment of buffers cannot be smaller than the default alignment "
"(which is set for SIMD correctness");
Value TmpProd = create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN);
// Init with zero.
Value fZero = create.math.constant(elementType, 0);
Value vFZero = create.vec.broadcast(vecType, fZero);
@@ -427,11 +478,36 @@
}
}

Value allocForGenSimdMatMat(const AffineBuilderKrnlMem &createAffine,
Type elementType, IndexExpr I, IndexExpr J, IndexExpr K,
IndexExpr vectorLen, bool unrollJam) const {
// can simdize only if K is compile time
MultiDialectBuilder<MemRefBuilder> create(createAffine);

// Generate the vector type conversions.
int64_t VL = vectorLen.getLiteral();
VectorType vecType = VectorType::get({VL}, elementType);
int64_t unrollFactor = (unrollJam && I.isLiteral()) ? I.getLiteral() : 1;
// Have to privatize CTmpType by unroll factor (1 if none).
MemRefType CTmpType = MemRefType::get({unrollFactor}, vecType);
assert(BUFFER_ALIGN >= gDefaultAllocAlign);
// Ok to use an alloca here because hoisting will take it out of the loop:
// the alloca is now generated before the scf.if that previously prevented
// it from being moved outside the loops.

// But at this time, if parallel is enabled, the alloca would be stuck inside
// the parallel loop, which is not great. TODO: migrate the alloca from inside
// the parallel loop to the OMP parallel region before the loop.
if (parallelEnabled)
return create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN);
return create.mem.alignedAlloca(CTmpType, BUFFER_ALIGN);
}

// Simdize along J / memory rows in B and C.
void genSimdMatMat(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
IndexExpr vectorLen, bool unrollJam) const {
Value TmpC, Type elementType, ArrayRef<IndexExpr> aStart,
ArrayRef<IndexExpr> bStart, ArrayRef<IndexExpr> cStart, IndexExpr I,
IndexExpr J, IndexExpr K, IndexExpr vectorLen, bool unrollJam) const {
// can simdize only if K is compile time
assert(J.isLiteral() &&
"can only simdize with compile time blocking factor on simd axis");
@@ -446,10 +522,6 @@
int64_t VL = vectorLen.getLiteral();
VectorType vecType = VectorType::get({VL}, elementType);
int64_t unrollFactor = (unrollJam && I.isLiteral()) ? I.getLiteral() : 1;
// Have to privatize CTmpType by unroll factor (1 if none).
MemRefType CTmpType = MemRefType::get({unrollFactor}, vecType);
assert(BUFFER_ALIGN >= gDefaultAllocAlign);
Value TmpC = create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN);

// Iterates over the I indices (j are simd dim).
Value iSaved, kSaved;
@@ -547,8 +619,8 @@
}; // namespace krnl

void populateLoweringKrnlMatmultOpPattern(TypeConverter &typeConverter,
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<KrnlMatmulLowering>(typeConverter, ctx);
RewritePatternSet &patterns, MLIRContext *ctx, bool parallelEnabled) {
patterns.insert<KrnlMatmulLowering>(typeConverter, ctx, parallelEnabled);
}

} // namespace krnl
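Taken together, the KrnlMatmul.cpp changes hoist the temporary tile buffers out of the per-tile if/then/else and, when parallelization is off, turn them back into alignedAlloca calls. Assuming the regression came from heap allocations being emitted inside the tiled loops, which is what this restructuring suggests, the cost being removed looks roughly like the following stand-alone comparison (illustration only, not the actual generated code):

#include <cstring>
#include <vector>

static void computeTile(float *tmp, int tileSize) {
  std::memset(tmp, 0, tileSize * sizeof(float)); // stand-in for the kernel
}

// Pattern before the fix: a fresh buffer per tile, paid on every iteration.
void perTileAllocation(int numTiles, int tileSize) {
  for (int t = 0; t < numTiles; ++t) {
    std::vector<float> tmp(tileSize);
    computeTile(tmp.data(), tileSize);
  }
}

// Pattern after the fix: one buffer, allocated once and reused by all tiles.
void hoistedAllocation(int numTiles, int tileSize) {
  std::vector<float> tmp(tileSize);
  for (int t = 0; t < numTiles; ++t)
    computeTile(tmp.data(), tileSize);
}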
4 changes: 4 additions & 0 deletions src/Conversion/ONNXToKrnl/Math/Gemm.cpp
@@ -248,6 +248,8 @@ struct ONNXGemmOpLowering : public OpConversionPattern<GemmOp> {
[&](const KrnlBuilder &createKrnl, ValueRange i1_j1_indices) {
Value i1(i1_j1_indices[0]), j1(i1_j1_indices[1]);
// If parallel, will stay inside, otherwise will migrate out.
// Since they are not in an if structure, migration out is not an
// issue.
Value aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
Value bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN);
Value rBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
@@ -313,6 +315,8 @@
[&](const KrnlBuilder &createKrnl, ValueRange j1_k1_indices) {
Value j1(j1_k1_indices[0]), k1(j1_k1_indices[1]);
// If parallel, it will stay inside, otherwise it will migrate out.
// Since allocs are not in an if structure, migration is not an
// issue.
Value aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
Value bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN);
if (bTrans)
1 change: 1 addition & 0 deletions src/Pass/Passes.hpp
@@ -108,6 +108,7 @@ std::unique_ptr<mlir::Pass> createElideConstGlobalValuePass();
namespace krnl {
/// Pass for lowering frontend dialects to Krnl IR dialect.
std::unique_ptr<mlir::Pass> createConvertKrnlToAffinePass();
std::unique_ptr<mlir::Pass> createConvertKrnlToAffinePass(bool parallelEnabled);

/// Pass for lowering Seq in Krnl dialect.
std::unique_ptr<mlir::Pass> createConvertSeqToMemrefPass();