Skip to content

Commit

Permalink
Java8 unstick issue (#2961)
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandre Eichenberger <[email protected]>
  • Loading branch information
AlexandreEichenberger authored Oct 2, 2024
1 parent ec314b7 commit 56a610c
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 62 deletions.
89 changes: 55 additions & 34 deletions src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
IndexExpr T1 = outputDims[E1].ceilDiv(64);
ubs[E1] = T1; // E1 dim is over tiles.

// Predicates used to avoid creating code that is never used.
bool neverHas64 = outputDims[E1].isLiteralAndSmallerThan(64);
bool neverHas8 = outputDims[E1].isLiteralAndSmallerThan(8);
bool hasOnly64 =
outputDims[E1].isLiteral() && (outputDims[E1].getLiteral() % 64 == 0);

// Parallel...
if (enableParallel) {
int64_t parId;
Expand Down Expand Up @@ -184,10 +190,16 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {

// I may process here up to [e1 ... e1 + m*64); make sure it is
// not going out of bounds, i.e. beyond outputDims[E1].
IndexExpr isFullLogical;
IndexExpr ub1 = SymIE(outputDims[E1]);
IndexExpr lit64Bis = LitIE(64);
IndexExpr isFull = create.krnlIE.isTileFull(e1, lit64, ub1);
IndexExpr isFullLogical = isFull >= 0;
if (hasOnly64) {
isFullLogical = PredIE(true);
} else if (neverHas64) {
isFullLogical = PredIE(false);
} else {
IndexExpr isFull = create.krnlIE.isTileFull(e1, lit64, ub1);
isFullLogical = isFull >= 0;
}
create.scf.ifThenElse(
// Condition
isFullLogical.getValue(),
Expand All @@ -198,6 +210,9 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
const int64_t unrollVL = 4;
const int64_t totVL = unrollVL * archVL;
assert(totVL <= 64 && "bad unroll");
if (neverHas64)
return; // Nothing to do here.

create.scf.forLoop(litZero.getValue(), lit64.getValue(), totVL,
[&](const SCFBuilder b, ValueRange loopInd) {
MDBuilder create(b);
Expand All @@ -206,7 +221,8 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
IndexExpr l = DimIE(loopIndex);
Value vecF16[unrollVL], vecF32H[unrollVL],
vecF32L[unrollVL];
// Load f16 values from input via reinterpreted data tile.
// Load f16 values from input via reinterpreted data
// tile.
for (int64_t i = 0; i < unrollVL; ++i) {
vecF16[i] = create.vec.loadIE(vecF16Type, inputAsTx64,
{SymIE(inputTileOffset), l + (i * archVL)}, {});
Expand All @@ -231,40 +247,45 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
}
});
},
// else, we don't have a full (64 e1) tile.
// Else, we don't have a full (64 e1) tile.
[&](SCFBuilder b) {
MDBuilder create(b);
IndexExprScope middleScope(b, &outerScope);
IndexExpr tripCount = SymIE(ub1) - SymIE(e1);
// Note: if we only have multiples of VL, the loop below will handle
// them all as we subtract (VL-1). E.g., if VL=8 and tripCount = 16,
// tripCountWithoutPartialLastVL is 16 - 7 = 9. Thus we iterate
// over i=0 & i=8 as both are < 9.
IndexExpr tripCountWithoutPartialLastVL =
tripCount - (archVL - 1);
create.scf.forLoop(litZero.getValue(),
tripCountWithoutPartialLastVL.getValue(), archVL,
[&](SCFBuilder b, ValueRange loopInd) {
MDBuilder create(b);
IndexExprScope innerScope(b, &middleScope);
Value loopIndex = loopInd[0];
IndexExpr l = DimIE(loopIndex);
// Load f16 values from input via reinterpreted data tile.
Value vecF16 = create.vec.loadIE(vecF16Type, inputAsTx64,
{SymIE(inputTileOffset), l}, {});
// Convert back to f32.
auto convertOp =
rewriter.create<ZLowConvertDLF16ToF32VectorOp>(
loc, vecF16);
Value vecF32H = convertOp.getResult(0);
Value vecF32L = convertOp.getResult(1);
// Store f32 values back to the (normal layout) output.
DimsExpr outputAF = SymListIE(inputAF);
outputAF[E1] = outputAF[E1] + l;
create.vec.storeIE(vecF32H, alloc, outputAF);
create.vec.storeIE(
vecF32L, alloc, outputAF, {litArchVLHalf.getValue()});
});
if (hasOnly64)
return;
if (!neverHas8) {
// Note: if we only have multiples of VL, the loop below will
// handle them all as we subtract (VL-1). E.g., if VL=8 and tripCount
// = 16, tripCountWithoutPartialLastVL is 16 - 7 = 9. Thus we
// iterate over i=0 & i=8 as both are < 9.
IndexExpr tripCountWithoutPartialLastVL =
tripCount - (archVL - 1);
create.scf.forLoop(litZero.getValue(),
tripCountWithoutPartialLastVL.getValue(), archVL,
[&](SCFBuilder b, ValueRange loopInd) {
MDBuilder create(b);
IndexExprScope innerScope(b, &middleScope);
Value loopIndex = loopInd[0];
IndexExpr l = DimIE(loopIndex);
// Load f16 values from input via reinterpreted data
// tile.
Value vecF16 = create.vec.loadIE(vecF16Type,
inputAsTx64, {SymIE(inputTileOffset), l}, {});
// Convert back to f32.
auto convertOp =
rewriter.create<ZLowConvertDLF16ToF32VectorOp>(
loc, vecF16);
Value vecF32H = convertOp.getResult(0);
Value vecF32L = convertOp.getResult(1);
// Store f32 values back to the (normal layout) output.
DimsExpr outputAF = SymListIE(inputAF);
outputAF[E1] = outputAF[E1] + l;
create.vec.storeIE(vecF32H, alloc, outputAF);
create.vec.storeIE(vecF32L, alloc, outputAF,
{litArchVLHalf.getValue()});
});
}
// Deal with the last values: compute f32 using simd.
IndexExpr remainingScalarValues = tripCount % archVL;
IndexExpr lastL = tripCount - remainingScalarValues;
Expand Down
1 change: 1 addition & 0 deletions src/Dialect/Mlir/IndexExpr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,7 @@ class SymbolIndexExpr : public IndexExpr {
//===----------------------------------------------------------------------===//

// Short aliases for the index expression classes; they keep call sites
// compact, e.g. LitIE(64) or PredIE(true).
using LitIE = LiteralIndexExpr;
using PredIE = PredicateIndexExpr;
using SymIE = SymbolIndexExpr;
using DimIE = DimIndexExpr;

Expand Down
Loading

0 comments on commit 56a610c

Please sign in to comment.