Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandre Eichenberger <[email protected]>
  • Loading branch information
AlexandreEichenberger committed Feb 12, 2025
1 parent 642bb45 commit a1dd524
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 23 deletions.
64 changes: 41 additions & 23 deletions src/Conversion/KrnlToAffine/KrnlMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,28 +224,33 @@ class KrnlMatmulLowering : public ConversionPattern {
if (simdize) {
// SIMD code generator.
if (matVectorProduct) {
Value TmpProd = allocForGenSimdMatVect(create.affineKMem, elementType,
iComputeTileSize, jComputeTileSize, kComputeTileSize, vectorLen,
fullUnrollAndJam);
// Alloc of temp outside of inner if/then/else.
Value TmpSimdProd = allocForGenSimdMatVect(create.affineKMem,
elementType, iComputeTileSize, jComputeTileSize, kComputeTileSize,
vectorLen, fullUnrollAndJam);
Value TmpScalarProd = allocForGenScalar(create.affineKMem, elementType,
iTrip, jTrip, kTrip, /*unroll*/ false);
// clang-format off
create.affineKMem.ifThenElseIE(indexScope, allFullTiles,
/* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
genSimdMatVect(createAffine, matmulOp, TmpProd, elementType, aStart, bStart,
genSimdMatVect(createAffine, matmulOp, TmpSimdProd, elementType, aStart, bStart,
cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize,
vectorLen, fullUnrollAndJam);
}, /* else has partial tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
genScalar(createAffine, matmulOp, TmpScalarProd, elementType, aStart, bStart, cStart,
iTrip, jTrip, kTrip, /*unroll*/ false);
});
// clang-format on
} else {
Value TmpC = allocForGenSimdMat(create.affineKMem, elementType,
Value TmpSimdC = allocForGenSimdMatMat(create.affineKMem, elementType,
iComputeTileSize, jComputeTileSize, kComputeTileSize, vectorLen,
fullUnrollAndJam);
Value TmpScalarC = allocForGenScalar(create.affineKMem, elementType,
iTrip, jPartialTrip, kTrip, /*unroll*/ false);
// clang-format off
create.affineKMem.ifThenElseIE(indexScope, allFullTiles,
/* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
genSimdMatMat(createAffine, matmulOp, TmpC, elementType, aStart, bStart,
genSimdMatMat(createAffine, matmulOp, TmpSimdC, elementType, aStart, bStart,
cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize,
vectorLen, fullUnrollAndJam);
},
Expand All @@ -255,25 +260,30 @@ class KrnlMatmulLowering : public ConversionPattern {
// Test if SIMD dim (M) is full.
createAffine.ifThenElseIE(indexScope, jFullTiles,
/* full SIMD */ [&](const AffineBuilderKrnlMem &createAffine) {
genSimdMatMat(createAffine, matmulOp, TmpC, elementType, aStart, bStart,
genSimdMatMat(createAffine, matmulOp, TmpSimdC, elementType, aStart, bStart,
cStart, iTrip, jComputeTileSize, kTrip, vectorLen, /*unroll*/ false);
}, /* else partial SIMD */ [&](const AffineBuilderKrnlMem &createAffine) {
genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
genScalar(createAffine, matmulOp, TmpScalarC, elementType, aStart, bStart, cStart,
iTrip, jPartialTrip, kTrip, /*unroll*/ false);
});
});
// clang-format on
}
} else {
// Scalar code generator.
Value TmpThenC =
allocForGenScalar(create.affineKMem, elementType, iComputeTileSize,
jComputeTileSize, kComputeTileSize, fullUnrollAndJam);
Value TmpElseC = allocForGenScalar(
create.affineKMem, elementType, iTrip, jTrip, kTrip, false);
// clang-format off
create.affineKMem.ifThenElseIE(indexScope, allFullTiles,
/* then full */ [&](const AffineBuilderKrnlMem &createAffine) {
genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
genScalar(createAffine, matmulOp, TmpThenC, elementType, aStart, bStart, cStart,
iComputeTileSize, jComputeTileSize, kComputeTileSize,
fullUnrollAndJam);
}, /* else partial */ [&](const AffineBuilderKrnlMem &createAffine) {
genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
genScalar(createAffine, matmulOp, TmpElseC, elementType, aStart, bStart, cStart,
iTrip, jTrip, kTrip, false);
});
// clang-format on
Expand All @@ -283,21 +293,32 @@ class KrnlMatmulLowering : public ConversionPattern {
}

private:
void genScalar(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
// Allocate the private temporary buffer used by the scalar matmul code
// generator. The buffer holds one element per unroll-and-jam copy of the
// J loop (a single element when no unrolling is requested or J is not a
// compile-time literal).
Value allocForGenScalar(const AffineBuilderKrnlMem &createAffine,
    Type elementType, IndexExpr I, IndexExpr J, IndexExpr K,
    bool unrollJam) const {
  MemRefBuilder memBuilder(createAffine);
  // Privatize the temp by the unroll factor (1 when unrolling is off).
  int64_t privateCopies = 1;
  if (unrollJam && J.isLiteral())
    privateCopies = J.getLiteral();
  MemRefType CTmpType = MemRefType::get({privateCopies}, elementType);
  assert(BUFFER_ALIGN >= gDefaultAllocAlign);
  // With parallel enabled, use a heap alloc so the buffer is not pinned
  // inside the parallel loop; otherwise a stack alloca suffices.
  return parallelEnabled ? memBuilder.alignedAlloc(CTmpType, BUFFER_ALIGN)
                         : memBuilder.alignedAlloca(CTmpType, BUFFER_ALIGN);
}

void genScalar(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
Value TmpC, Type elementType, ArrayRef<IndexExpr> aStart,
ArrayRef<IndexExpr> bStart, ArrayRef<IndexExpr> cStart, IndexExpr I,
IndexExpr J, IndexExpr K, bool unrollJam) const {
// Get operands.
KrnlMatMulOpAdaptor operandAdaptor(op);
MemRefBuilder createMemRef(createAffine);

Value A(operandAdaptor.getA()), B(operandAdaptor.getB()),
C(operandAdaptor.getC());
int64_t unrollFactor = (unrollJam && J.isLiteral()) ? J.getLiteral() : 1;
// Have to privatize CTmpType by unroll factor (1 if none).
MemRefType CTmpType = MemRefType::get({unrollFactor}, elementType);
assert(BUFFER_ALIGN >= gDefaultAllocAlign);
Value TmpC = createMemRef.alignedAlloc(CTmpType, BUFFER_ALIGN);

// For i, j loops.
LiteralIndexExpr zeroIE(0);
Expand Down Expand Up @@ -372,8 +393,7 @@ class KrnlMatmulLowering : public ConversionPattern {
// But at this time, if parallel is enabled, alloca would be stuck inside of
// the parallel loop, which is not great. TODO: migrate alloca from inside
// the parallel loop to the OMP parallel region before the loop.

// Hi alex
// Grep for this pattern in all 3 instances of "parallelEnabled".
if (parallelEnabled)
return create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN);
return create.mem.alignedAlloca(CTmpType, BUFFER_ALIGN);
Expand Down Expand Up @@ -458,7 +478,7 @@ class KrnlMatmulLowering : public ConversionPattern {
}
}

Value allocForGenSimdMat(const AffineBuilderKrnlMem &createAffine,
Value allocForGenSimdMatMat(const AffineBuilderKrnlMem &createAffine,
Type elementType, IndexExpr I, IndexExpr J, IndexExpr K,
IndexExpr vectorLen, bool unrollJam) const {
// can simdize only if K is compile time
Expand All @@ -478,8 +498,6 @@ class KrnlMatmulLowering : public ConversionPattern {
// But at this time, if parallel is enabled, alloca would be stuck inside of
// the parallel loop, which is not great. TODO: migrate alloca from inside
// the parallel loop to the OMP parallel region before the loop.

// Hi alex
if (parallelEnabled)
return create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN);
return create.mem.alignedAlloca(CTmpType, BUFFER_ALIGN);
Expand Down
4 changes: 4 additions & 0 deletions src/Conversion/ONNXToKrnl/Math/Gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ struct ONNXGemmOpLowering : public OpConversionPattern<GemmOp> {
[&](const KrnlBuilder &createKrnl, ValueRange i1_j1_indices) {
Value i1(i1_j1_indices[0]), j1(i1_j1_indices[1]);
// If parallel, will stay inside, otherwise will migrate out.
// Since they are not in an if structure, migration out is not an
// issue.
Value aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
Value bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN);
Value rBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
Expand Down Expand Up @@ -313,6 +315,8 @@ struct ONNXGemmOpLowering : public OpConversionPattern<GemmOp> {
[&](const KrnlBuilder &createKrnl, ValueRange j1_k1_indices) {
Value j1(j1_k1_indices[0]), k1(j1_k1_indices[1]);
// If parallel, it will stay inside, otherwise it will migrate out.
// Since allocs are not in an if structure, migration is not an
// issue.
Value aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
Value bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN);
if (bTrans)
Expand Down

0 comments on commit a1dd524

Please sign in to comment.