From 56a610cb425255b464800e27f8f054429cd17ad1 Mon Sep 17 00:00:00 2001
From: Alexandre Eichenberger
Date: Wed, 2 Oct 2024 19:51:52 -0400
Subject: [PATCH] Java8 unstick issue (#2961)

Signed-off-by: Alexandre Eichenberger
---
 .../Transform/ZLow/ZLowStickExpansion.cpp |  89 +++++++-----
 src/Dialect/Mlir/IndexExpr.hpp            |   1 +
 .../zlow-stick-unstick-expansion.mlir     | 137 ++++++++++++++----
 3 files changed, 165 insertions(+), 62 deletions(-)

diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
index 12d1cf0fbc..5c607a0ece 100644
--- a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
+++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
@@ -125,6 +125,12 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
     IndexExpr T1 = outputDims[E1].ceilDiv(64);
     ubs[E1] = T1; // E1 dim is over tiles.
 
+    // Predicates used to avoid creating code that is never used.
+    bool neverHas64 = outputDims[E1].isLiteralAndSmallerThan(64);
+    bool neverHas8 = outputDims[E1].isLiteralAndSmallerThan(8);
+    bool hasOnly64 =
+        outputDims[E1].isLiteral() && (outputDims[E1].getLiteral() % 64 == 0);
+
     // Parallel...
     if (enableParallel) {
       int64_t parId;
@@ -184,10 +190,16 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
 
       // I may process here up to [e1 ... e1 + m*64), make sure it's
      // not going out of bounds, i.e., beyond outputDims[E1];
+      IndexExpr isFullLogical;
       IndexExpr ub1 = SymIE(outputDims[E1]);
-      IndexExpr lit64Bis = LitIE(64);
-      IndexExpr isFull = create.krnlIE.isTileFull(e1, lit64, ub1);
-      IndexExpr isFullLogical = isFull >= 0;
+      if (hasOnly64) {
+        isFullLogical = PredIE(true);
+      } else if (neverHas64) {
+        isFullLogical = PredIE(false);
+      } else {
+        IndexExpr isFull = create.krnlIE.isTileFull(e1, lit64, ub1);
+        isFullLogical = isFull >= 0;
+      }
       create.scf.ifThenElse(
           // Condition
           isFullLogical.getValue(),
@@ -198,6 +210,9 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
             const int64_t unrollVL = 4;
             const int64_t totVL = unrollVL * archVL;
             assert(totVL <= 64 && "bad unroll");
+            if (neverHas64)
+              return; // Nothing to do here.
+
             create.scf.forLoop(litZero.getValue(), lit64.getValue(), totVL,
                 [&](const SCFBuilder b, ValueRange loopInd) {
                   MDBuilder create(b);
@@ -206,7 +221,8 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
                   IndexExpr l = DimIE(loopIndex);
                   Value vecF16[unrollVL], vecF32H[unrollVL], vecF32L[unrollVL];
-                  // Load f16 values from input via reinterpreted data tile.
+                  // Load f16 values from input via reinterpreted data
+                  // tile.
                   for (int64_t i = 0; i < unrollVL; ++i) {
                     vecF16[i] = create.vec.loadIE(vecF16Type, inputAsTx64,
                         {SymIE(inputTileOffset), l + (i * archVL)}, {});
@@ -231,40 +247,45 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
                 }
               });
             },
-            // else, we don't have a full (64 e1) tile.
+            // Else, we don't have a full (64 e1) tile.
             [&](SCFBuilder b) {
              MDBuilder create(b);
               IndexExprScope middleScope(b, &outerScope);
               IndexExpr tripCount = SymIE(ub1) - SymIE(e1);
-              // Note: if we only have multiple of VL, loop below will handle
-              // all as we subtract (VL-1). Aka if VL=8 and tripCount = 16,
-              // tripCountWithoutPartialLastVL is 16 - 7 = 9. Thus we iterate
-              // over i=0 & i=8 as both are < 9.
-              IndexExpr tripCountWithoutPartialLastVL =
-                  tripCount - (archVL - 1);
-              create.scf.forLoop(litZero.getValue(),
-                  tripCountWithoutPartialLastVL.getValue(), archVL,
-                  [&](SCFBuilder b, ValueRange loopInd) {
-                    MDBuilder create(b);
-                    IndexExprScope innerScope(b, &middleScope);
-                    Value loopIndex = loopInd[0];
-                    IndexExpr l = DimIE(loopIndex);
-                    // Load f16 values from input via reinterpreted data tile.
-                    Value vecF16 = create.vec.loadIE(vecF16Type, inputAsTx64,
-                        {SymIE(inputTileOffset), l}, {});
-                    // Convert back to f32.
-                    auto convertOp =
-                        rewriter.create<ZLowConvertDLF16ToF32VectorOp>(
-                            loc, vecF16);
-                    Value vecF32H = convertOp.getResult(0);
-                    Value vecF32L = convertOp.getResult(1);
-                    // Store f32 values back to the (normal layout) output.
-                    DimsExpr outputAF = SymListIE(inputAF);
-                    outputAF[E1] = outputAF[E1] + l;
-                    create.vec.storeIE(vecF32H, alloc, outputAF);
-                    create.vec.storeIE(
-                        vecF32L, alloc, outputAF, {litArchVLHalf.getValue()});
-                  });
+              if (hasOnly64)
+                return;
+              if (!neverHas8) {
+                // Note: if we only have multiple of VL, loop below will
+                // handle all as we subtract (VL-1). Aka if VL=8 and tripCount
+                // = 16, tripCountWithoutPartialLastVL is 16 - 7 = 9. Thus we
+                // iterate over i=0 & i=8 as both are < 9.
+                IndexExpr tripCountWithoutPartialLastVL =
+                    tripCount - (archVL - 1);
+                create.scf.forLoop(litZero.getValue(),
+                    tripCountWithoutPartialLastVL.getValue(), archVL,
+                    [&](SCFBuilder b, ValueRange loopInd) {
+                      MDBuilder create(b);
+                      IndexExprScope innerScope(b, &middleScope);
+                      Value loopIndex = loopInd[0];
+                      IndexExpr l = DimIE(loopIndex);
+                      // Load f16 values from input via reinterpreted data
+                      // tile.
+                      Value vecF16 = create.vec.loadIE(vecF16Type,
+                          inputAsTx64, {SymIE(inputTileOffset), l}, {});
+                      // Convert back to f32.
+                      auto convertOp =
+                          rewriter.create<ZLowConvertDLF16ToF32VectorOp>(
+                              loc, vecF16);
+                      Value vecF32H = convertOp.getResult(0);
+                      Value vecF32L = convertOp.getResult(1);
+                      // Store f32 values back to the (normal layout) output.
+                      DimsExpr outputAF = SymListIE(inputAF);
+                      outputAF[E1] = outputAF[E1] + l;
+                      create.vec.storeIE(vecF32H, alloc, outputAF);
+                      create.vec.storeIE(vecF32L, alloc, outputAF,
+                          {litArchVLHalf.getValue()});
+                    });
+              }
               // Deal with the last values: compute f32 using simd.
               IndexExpr remainingScalarValues = tripCount % archVL;
               IndexExpr lastL = tripCount - remainingScalarValues;
diff --git a/src/Dialect/Mlir/IndexExpr.hpp b/src/Dialect/Mlir/IndexExpr.hpp
index e1cd247cb7..678fb664ea 100644
--- a/src/Dialect/Mlir/IndexExpr.hpp
+++ b/src/Dialect/Mlir/IndexExpr.hpp
@@ -828,6 +828,7 @@ class SymbolIndexExpr : public IndexExpr {
 //===----------------------------------------------------------------------===//
 
 using LitIE = LiteralIndexExpr;
+using PredIE = PredicateIndexExpr;
 using SymIE = SymbolIndexExpr;
 using DimIE = DimIndexExpr;
 
diff --git a/test/mlir/accelerators/nnpa/transform/zlow-stick-unstick-expansion.mlir b/test/mlir/accelerators/nnpa/transform/zlow-stick-unstick-expansion.mlir
index e3761e8bd6..36d805f267 100644
--- a/test/mlir/accelerators/nnpa/transform/zlow-stick-unstick-expansion.mlir
+++ b/test/mlir/accelerators/nnpa/transform/zlow-stick-unstick-expansion.mlir
@@ -193,23 +193,18 @@ func.func @test_unstick_expansion(%arg0: memref<16x8x128xf16, #map>) -> memref<1
 // CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0) -> (d0 + 16)>
 // CHECK-DAG: [[MAP_5_:#.+]] = affine_map<(d0) -> (d0 + 24)>
 // CHECK-DAG: [[MAP_6_:#.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-DAG: [[MAP_7_:#.+]] = affine_map<()[s0] -> (-s0 + 121)>
-// CHECK-DAG: [[MAP_8_:#.+]] = affine_map<()[s0] -> ((-s0) mod 8)>
-// CHECK-DAG: [[MAP_9_:#.+]] = affine_map<()[s0] -> (-s0 - (-s0) mod 8 + 128)>
-// CHECK-DAG: [[MAP_10_:#.+]] = affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>
 // CHECK-LABEL: func.func @test_unstick_expansion
 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf16, #map>) -> memref<16x8x128xf32> {
-// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
 // CHECK-DAG: [[CST_28_:%.+]] = arith.constant 28 : index
 // CHECK-DAG: [[CST_24_:%.+]] = arith.constant 24 : index
 // CHECK-DAG: [[CST_20_:%.+]] = arith.constant 20 : index
 // CHECK-DAG: [[CST_16_:%.+]] = arith.constant 16 : index
 // CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
 // CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index
+// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
 // CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index
 // CHECK-DAG: [[VAR_true_:%.+]] = arith.constant true
 // CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : index
-// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
 // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
 // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x8x128xf32>
 // CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3
@@ -232,52 +227,138 @@ func.func @test_unstick_expansion(%arg0: memref<16x8x128xf16, #map>) -> memref<1
 // CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_2_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_8_]]{{.}} : memref<2x64xf16>, vector<8xf16>
 // CHECK-DAG: [[VAR_10_:%.+]] = affine.apply [[MAP_5_]]([[I_3_]])
 // CHECK: [[LOAD_VAR_reinterpret_cast_MEM_3_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_10_]]{{.}} : memref<2x64xf16>, vector<8xf16>
-// CHECK: [[output1_:%.+]], [[VAR_output2_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
-// CHECK: [[output1_0_:%.+]], [[VAR_output2_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_1_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
-// CHECK: [[output1_2_:%.+]], [[VAR_output2_3_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_2_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
-// CHECK: [[output1_4_:%.+]], [[VAR_output2_5_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_3_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
+// CHECK: [[VAR_output1_:%.+]], [[VAR_output2_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
+// CHECK: [[VAR_output1_0_:%.+]], [[VAR_output2_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_1_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
+// CHECK: [[VAR_output1_2_:%.+]], [[VAR_output2_3_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_2_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
+// CHECK: [[VAR_output1_4_:%.+]], [[VAR_output2_5_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_3_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
 // CHECK: [[VAR_12_:%.+]] = affine.apply [[MAP_6_]]([[I_3_]]){{.}}[[VAR_2_]]{{.}}
-// CHECK: vector.store [[output1_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]2] : memref<16x8x128xf32>, vector<4xf32>
+// CHECK: vector.store [[VAR_output1_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]2] : memref<16x8x128xf32>, vector<4xf32>
 // CHECK: [[VAR_13_:%.+]] = arith.addi [[VAR_12_]], [[CST_4_]] : index
 // CHECK: vector.store [[VAR_output2_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]3] : memref<16x8x128xf32>, vector<4xf32>
 // CHECK: [[VAR_14_:%.+]] = arith.addi [[VAR_12_]], [[CST_8_]] : index
-// CHECK: vector.store [[output1_0_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]4] : memref<16x8x128xf32>, vector<4xf32>
+// CHECK: vector.store [[VAR_output1_0_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]4] : memref<16x8x128xf32>, vector<4xf32>
 // CHECK: [[VAR_15_:%.+]] = arith.addi [[VAR_12_]], [[CST_12_]] : index
 // CHECK: vector.store [[VAR_output2_1_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]5] : memref<16x8x128xf32>, vector<4xf32>
 // CHECK: [[VAR_16_:%.+]] = arith.addi [[VAR_12_]], [[CST_16_]] : index
-// CHECK: vector.store [[output1_2_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]6] : memref<16x8x128xf32>, vector<4xf32>
+// CHECK: vector.store [[VAR_output1_2_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]6] : memref<16x8x128xf32>, vector<4xf32>
 // CHECK: [[VAR_17_:%.+]] = arith.addi [[VAR_12_]], [[CST_20_]] : index
 // CHECK: vector.store [[VAR_output2_3_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]7] : memref<16x8x128xf32>, vector<4xf32>
 // CHECK: [[VAR_18_:%.+]] = arith.addi [[VAR_12_]], [[CST_24_]] : index
-// CHECK: vector.store [[output1_4_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]8] : memref<16x8x128xf32>, vector<4xf32>
+// CHECK: vector.store [[VAR_output1_4_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]8] : memref<16x8x128xf32>, vector<4xf32>
 // CHECK: [[VAR_19_:%.+]] = arith.addi [[VAR_12_]], [[CST_28_]] : index
 // CHECK: vector.store [[VAR_output2_5_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]9] : memref<16x8x128xf32>, vector<4xf32>
 // CHECK: }
 // CHECK: } else {
-// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_4_:%.+]] = affine.apply [[MAP_7_]](){{.}}[[VAR_2_]]{{.}}
+// CHECK: }
+// CHECK: }
+// CHECK: return [[RES_]] : memref<16x8x128xf32>
+// CHECK: }
+}
+
+// -----
+
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)>
+func.func @test_unstick_expansion_127(%arg0: memref<16x8x127xf16, #map>) -> memref<16x8x127xf32> {
+  %alloc = memref.alloc() {alignment = 4096 : i64} : memref<16x8x127xf32>
+  "zlow.unstick"(%arg0, %alloc) {layout = "3DS"} : (memref<16x8x127xf16, #map>, memref<16x8x127xf32>) -> ()
+  return %alloc : memref<16x8x127xf32>
+
+// mlir2FileCheck.py
+// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)>
+// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 * 64)>
+// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0)[s0] -> (s0 floordiv 64)>
+// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0)[s0] -> (d0 * -64 + 63)>
+// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0) -> (d0 + 8)>
+// CHECK-DAG: [[MAP_5_:#.+]] = affine_map<(d0) -> (d0 + 16)>
+// CHECK-DAG: [[MAP_6_:#.+]] = affine_map<(d0) -> (d0 + 24)>
+// CHECK-DAG: [[MAP_7_:#.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: [[MAP_8_:#.+]] = affine_map<()[s0] -> (-s0 + 120)>
+// CHECK-DAG: [[MAP_9_:#.+]] = affine_map<()[s0] -> ((-s0 + 127) mod 8)>
+// CHECK-DAG: [[MAP_10_:#.+]] = affine_map<()[s0] -> (-s0 - (-s0 + 127) mod 8 + 127)>
+// CHECK-DAG: [[MAP_11_:#.+]] = affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>
+// CHECK-LABEL: func.func @test_unstick_expansion_127
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x127xf16, #map>) -> memref<16x8x127xf32> {
+// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[CST_28_:%.+]] = arith.constant 28 : index
+// CHECK-DAG: [[CST_24_:%.+]] = arith.constant 24 : index
+// CHECK-DAG: [[CST_20_:%.+]] = arith.constant 20 : index
+// CHECK-DAG: [[CST_16_:%.+]] = arith.constant 16 : index
+// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
+// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index
+// CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index
+// CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : index
+// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x8x127xf32>
+// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3
+// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [2, 64], strides: [64, 1] : memref<16x8x127xf16, #map> to memref<2x64xf16>
+// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 16, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 8, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 2){
+// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
+// CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]#2)
+// CHECK: [[VAR_3_:%.+]] = krnl.get_linear_offset_index [[PARAM_0_]] at {{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}} : memref<16x8x127xf16, #map>
+// CHECK: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]#2){{.}}[[VAR_3_]]{{.}}
+// CHECK: krnl.prefetch [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}}, read, locality<1>, data : memref<16x8x127xf16, #map>
+// CHECK: krnl.prefetch [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}}, write, locality<1>, data : memref<16x8x127xf32>
+// CHECK: [[VAR_5_:%.+]] = affine.apply [[MAP_3_]]([[VAR_1_]]#2){{.}}[[VAR_3_]]{{.}}
+// CHECK: [[VAR_6_:%.+]] = arith.cmpi sge, [[VAR_5_]], [[CST_0_]] : index
+// CHECK: scf.if [[VAR_6_]] {
+// CHECK: scf.for [[I_3_:%.+]] = [[CST_0_]] to [[CST_64_]] step [[CST_32_]] {
+// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[I_3_]]{{.}} : memref<2x64xf16>, vector<8xf16>
+// CHECK-DAG: [[VAR_8_:%.+]] = affine.apply [[MAP_4_]]([[I_3_]])
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_1_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_8_]]{{.}} : memref<2x64xf16>, vector<8xf16>
+// CHECK-DAG: [[VAR_10_:%.+]] = affine.apply [[MAP_5_]]([[I_3_]])
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_2_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_10_]]{{.}} : memref<2x64xf16>, vector<8xf16>
+// CHECK-DAG: [[VAR_12_:%.+]] = affine.apply [[MAP_6_]]([[I_3_]])
+// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_3_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_12_]]{{.}} : memref<2x64xf16>, vector<8xf16>
+// CHECK: [[VAR_output1_:%.+]], [[VAR_output2_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
+// CHECK: [[VAR_output1_0_:%.+]], [[VAR_output2_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_1_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
+// CHECK: [[VAR_output1_2_:%.+]], [[VAR_output2_3_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_2_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
+// CHECK: [[VAR_output1_4_:%.+]], [[VAR_output2_5_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_3_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
+// CHECK: [[VAR_14_:%.+]] = affine.apply [[MAP_7_]]([[I_3_]]){{.}}[[VAR_2_]]{{.}}
+// CHECK: vector.store [[VAR_output1_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]4] : memref<16x8x127xf32>, vector<4xf32>
+// CHECK: [[VAR_15_:%.+]] = arith.addi [[VAR_14_]], [[CST_4_]] : index
+// CHECK: vector.store [[VAR_output2_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]5] : memref<16x8x127xf32>, vector<4xf32>
+// CHECK: [[VAR_16_:%.+]] = arith.addi [[VAR_14_]], [[CST_8_]] : index
+// CHECK: vector.store [[VAR_output1_0_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]6] : memref<16x8x127xf32>, vector<4xf32>
+// CHECK: [[VAR_17_:%.+]] = arith.addi [[VAR_14_]], [[CST_12_]] : index
+// CHECK: vector.store [[VAR_output2_1_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]7] : memref<16x8x127xf32>, vector<4xf32>
+// CHECK: [[VAR_18_:%.+]] = arith.addi [[VAR_14_]], [[CST_16_]] : index
+// CHECK: vector.store [[VAR_output1_2_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]8] : memref<16x8x127xf32>, vector<4xf32>
+// CHECK: [[VAR_19_:%.+]] = arith.addi [[VAR_14_]], [[CST_20_]] : index
+// CHECK: vector.store [[VAR_output2_3_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]9] : memref<16x8x127xf32>, vector<4xf32>
+// CHECK: [[VAR_20_:%.+]] = arith.addi [[VAR_14_]], [[CST_24_]] : index
+// CHECK: vector.store [[VAR_output1_4_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_20_]]{{.}} : memref<16x8x127xf32>, vector<4xf32>
+// CHECK: [[VAR_21_:%.+]] = arith.addi [[VAR_14_]], [[CST_28_]] : index
+// CHECK: vector.store [[VAR_output2_5_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_21_]]{{.}} : memref<16x8x127xf32>, vector<4xf32>
+// CHECK: }
+// CHECK: } else {
+// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_4_:%.+]] = affine.apply [[MAP_8_]](){{.}}[[VAR_2_]]{{.}}
 // CHECK: scf.for [[I_4_:%.+]] = [[CST_0_]] to [[LOAD_VAR_reinterpret_cast_MEM_4_]] step [[CST_8_]] {
 // CHECK: [[LOAD_VAR_reinterpret_cast_MEM_5_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[I_4_]]{{.}} : memref<2x64xf16>, vector<8xf16>
-// CHECK: [[output1_0_]], [[VAR_output2_1_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_5_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
+// CHECK: [[VAR_output1_0_1_:%.+]], [[VAR_output2_1_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_5_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
-// CHECK: [[VAR_10_1_:%.+]] = affine.apply [[MAP_6_]]([[I_4_]]){{.}}[[VAR_2_]]{{.}}
-// CHECK: vector.store [[output1_0_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]0] : memref<16x8x128xf32>, vector<4xf32>
-// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_3_:%.+]] = arith.addi [[VAR_10_1_]], [[CST_4_]] : index
-// CHECK: vector.store [[VAR_output2_1_1_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]1] : memref<16x8x128xf32>, vector<4xf32>
+// CHECK: [[VAR_12_1_:%.+]] = affine.apply [[MAP_7_]]([[I_4_]]){{.}}[[VAR_2_]]{{.}}
+// CHECK: vector.store [[VAR_output1_0_1_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]2] : memref<16x8x127xf32>, vector<4xf32>
+// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_3_:%.+]] = arith.addi [[VAR_12_1_]], [[CST_4_]] : index
+// CHECK: vector.store [[VAR_output2_1_1_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]3] : memref<16x8x127xf32>, vector<4xf32>
 // CHECK: }
-// CHECK-DAG: [[VAR_6_1_:%.+]] = affine.apply [[MAP_8_]](){{.}}[[VAR_2_]]{{.}}
-// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_1_:%.+]] = affine.apply [[MAP_9_]](){{.}}[[VAR_2_]]{{.}}
+// CHECK-DAG: [[VAR_8_1_:%.+]] = affine.apply [[MAP_9_]](){{.}}[[VAR_2_]]{{.}}
+// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_1_:%.+]] = affine.apply [[MAP_10_]](){{.}}[[VAR_2_]]{{.}}
 // CHECK: [[LOAD_VAR_reinterpret_cast_MEM_6_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[LOAD_VAR_reinterpret_cast_MEM_1_]]{{.}} : memref<2x64xf16>, vector<8xf16>
-// CHECK: [[output1_]], [[VAR_output2_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_6_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
+// CHECK: [[VAR_output1_1_:%.+]], [[VAR_output2_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_6_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>)
 // CHECK: [[RES_1_:%.+]] = memref.alloca() {{.*}}: memref<8xf32>
-// CHECK: vector.store [[output1_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<4xf32>
+// CHECK: vector.store [[VAR_output1_1_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<4xf32>
 // CHECK: vector.store [[VAR_output2_1_]], [[RES_1_]]{{.}}[[CST_4_]]{{.}} : memref<8xf32>, vector<4xf32>
-// CHECK: scf.for [[I_5_:%.+]] = [[CST_0_]] to [[VAR_6_1_]] step [[CST_1_]] {
+// CHECK: scf.for [[I_5_:%.+]] = [[CST_0_]] to [[VAR_8_1_]] step [[CST_1_]] {
 // CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_5_:%.+]] = krnl.load [[RES_1_]]{{.}}[[I_5_]]{{.}} : memref<8xf32>
-// CHECK-DAG: [[VAR_10_2_:%.+]] = affine.apply [[MAP_10_]]([[I_5_]]){{.}}[[VAR_2_]], [[LOAD_VAR_reinterpret_cast_MEM_1_]]{{.}}
-// CHECK: krnl.store [[LOAD_VAR_reinterpret_cast_MEM_5_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]0] : memref<16x8x128xf32>
+// CHECK-DAG: [[VAR_12_2_:%.+]] = affine.apply [[MAP_11_]]([[I_5_]]){{.}}[[VAR_2_]], [[LOAD_VAR_reinterpret_cast_MEM_1_]]{{.}}
+// CHECK: krnl.store [[LOAD_VAR_reinterpret_cast_MEM_5_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]2] : memref<16x8x127xf32>
 // CHECK: }
 // CHECK: }
 // CHECK: }
-// CHECK: return [[RES_]] : memref<16x8x128xf32>
+// CHECK: return [[RES_]] : memref<16x8x127xf32>
 // CHECK: }
 }