Migrate adding alloca_scope to scf passes (for parallel with NNPA) (#2684)

Signed-off-by: Alexandre Eichenberger <[email protected]>
Co-authored-by: Tung D. Le <[email protected]>
AlexandreEichenberger and tungld authored Jan 22, 2024
1 parent 9060f61 commit 592618f
Showing 7 changed files with 77 additions and 83 deletions.
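
The net effect of the migrated pass, mirroring the Input/Output documented in ProcessScfParallelPrivate.cpp below (loop bounds are illustrative only, not taken from this commit):

  // Before: allocations in the body may later be hoisted out of the loop
  // and end up shared by all threads.
  scf.parallel (%arg1) = (%c0) to (%c16384) step (%c32) {
    body
  }

  // After: the alloca scope keeps allocations in the body private to the
  // thread executing the iteration.
  scf.parallel (%arg1) = (%c0) to (%c16384) step (%c32) {
    memref.alloca_scope {
      body
    }
  }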
22 changes: 12 additions & 10 deletions src/Compiler/CompilerPasses.cpp
@@ -200,16 +200,6 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
void addKrnlToAffinePasses(mlir::PassManager &pm) {
pm.addNestedPass<func::FuncOp>(
onnx_mlir::krnl::createConvertKrnlToAffinePass());
if (enableParallel) {
// Pass to ensure that memory allocated by parallel loops stays inside the
// parallel region (privatization of memory). Otherwise, all threads would
// end up sharing the same temporary data. This pass works on affine
// parallel operations, and must be executed (in the presence of OMP
// parallelism) before bufferization. In practical terms, this pass adds
// memref.alloca_scope inside each parallel for.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(onnx_mlir::createProcessAffineParallelPrivatePass());
}
}

void addKrnlToLLVMPasses(
@@ -227,6 +217,18 @@ void addKrnlToLLVMPasses(
// After affine is lowered, KrnlRegion for affine scope can be removed.
pm.addNestedPass<func::FuncOp>(krnl::createLowerKrnlRegionPass());

if (enableParallel) {
// Pass to ensure that memory allocated by parallel loops stays inside the
// parallel region (privatization of memory). Otherwise, all threads would
// end up sharing the same temporary data. This pass works on scf parallel
// operations, and must be executed (in the presence of OMP parallelism)
// before bufferization. In practical terms, this pass adds
// memref.alloca_scope inside each parallel for.
pm.addPass(onnx_mlir::createProcessScfParallelPrivatePass());
// No canonicalization passes are allowed between the pass above and the
// buffer management passes.
}

// Hoist allocations out of loop nests to avoid stack overflow.
pm.addPass(bufferization::createBufferLoopHoistingPass());
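
The placement above is the point of the change: the privatization pass now runs after the Krnl regions are lowered and immediately before the buffer management passes, with no canonicalizer in between, per the comment in the hunk above. A sketch of the hazard being prevented, with a hypothetical thread-local buffer (not code from this commit):

  scf.parallel (%i) = (%c0) to (%cN) step (%c1) {
    // Without an enclosing memref.alloca_scope, buffer-loop hoisting may
    // move this alloca above the scf.parallel, so all threads would share
    // one buffer; the scope confines the allocation to a single iteration.
    %tmp = memref.alloca() : memref<32xf32>
    scf.yield
  }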

2 changes: 1 addition & 1 deletion src/Pass/Passes.hpp
@@ -87,7 +87,7 @@ std::unique_ptr<mlir::Pass> createLowerToKrnlPass(bool enableTiling,
bool enableSIMD, bool enableParallel, std::string opsForCall);
void configureOnnxToKrnlLoweringPass(bool reportOnParallel,
bool parallelIsEnabled, bool reportOnSimd, bool simdIsEnabled);
std::unique_ptr<mlir::Pass> createProcessAffineParallelPrivatePass();
std::unique_ptr<mlir::Pass> createProcessScfParallelPrivatePass();

#ifdef ONNX_MLIR_ENABLE_STABLEHLO
/// Add pass for lowering to Stablehlo IR.
2 changes: 1 addition & 1 deletion src/Tools/onnx-mlir-opt/RegisterPasses.cpp
@@ -93,7 +93,7 @@ void registerOMPasses(int optLevel) {
});

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return createProcessAffineParallelPrivatePass();
return createProcessScfParallelPrivatePass();
});

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
4 changes: 2 additions & 2 deletions src/Transform/CMakeLists.txt
@@ -34,8 +34,8 @@ add_onnx_mlir_library(OMLowerKrnlRegion
MLIRTransformUtils
)

add_onnx_mlir_library(OMAffineParallelPrivateRegion
ProcessAffineParallelPrivate.cpp
add_onnx_mlir_library(OMScfParallelPrivateRegion
ProcessScfParallelPrivate.cpp

LINK_LIBS PUBLIC
OMSupport
src/Transform/{ProcessAffineParallelPrivate.cpp → ProcessScfParallelPrivate.cpp}
@@ -2,7 +2,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

//===-- ProcessAffineParallelPrivate.cpp - handle parallel private data ---===//
//===-- ProcessScfParallelPrivate.cpp - handle parallel private data ---===//
//
// Copyright 2023-2024 The IBM Research Authors.
//
@@ -12,25 +12,22 @@
// shared among all threads.
//
// Input:
// affine.parallel (%arg1) = (0) to (16384) step (32) {
// scf.parallel (%arg1) = (0) to (16384) step (32) {
// body
// }
//
// Output:
// affine.parallel (%arg1) = (0) to (16384) step (32) {
// scf.parallel (%arg1) = (0) to (16384) step (32) {
// memref.alloca_scope {
// body
// }
// }
//
// TODO: if we use scf.parallel, then the same optimization should be added as
// for the affine.parallel construct.
//===----------------------------------------------------------------------===//

#include "src/Transform/ProcessAffineParallelPrivate.hpp"
#include "src/Transform/ProcessScfParallelPrivate.hpp"
#include "src/Pass/Passes.hpp"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -43,47 +40,38 @@

#include "src/Support/TypeUtilities.hpp"

#define DEBUG_TYPE "affine-parallel-private"
#define DEBUG_TYPE "scf-parallel-private"

using namespace mlir;

namespace {

struct ProcessAffineParallelWithoutScopePattern
: public OpRewritePattern<affine::AffineParallelOp> {
using OpRewritePattern<affine::AffineParallelOp>::OpRewritePattern;
struct ProcessScfParallelWithoutScopePattern
: public OpRewritePattern<scf::ParallelOp> {
using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;

static bool matchParallelForWithAllocScope(
affine::AffineParallelOp parForOp) {
if (parForOp.getRegion().empty()) {
static bool matchParallelForWithAllocScope(scf::ParallelOp parForOp) {
if (parForOp.getRegion().empty())
// Ignore empty parallel regions (side effects of optimization).
return true;
}
Block *loopBody = parForOp.getBody();
Operation &firstOp = loopBody->front();
if (!isa<memref::AllocaScopeOp>(&firstOp)) {
if (!isa<memref::AllocaScopeOp>(&firstOp))
// Found a parallel region without an alloca scope; need to add one.
return false;
}
// Found a parallel region WITH an alloca scope, we are good.
return true;
}

LogicalResult matchAndRewrite(affine::AffineParallelOp parForOp,
PatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(
scf::ParallelOp parForOp, PatternRewriter &rewriter) const final {
Location loc = parForOp.getLoc();
assert(!matchParallelForWithAllocScope(parForOp) &&
"expected par for without alloca here");
// Create a copy of the parallel for op, as this pass requires new ops.
SmallVector<Type, 4> resultTypes;
for (auto t : parForOp.getResults()) {
resultTypes.emplace_back(t.getType());
}
auto newParForOp = rewriter.create<affine::AffineParallelOp>(loc,
resultTypes, parForOp.getReductionsAttr(), parForOp.getLowerBoundsMap(),
parForOp.getLowerBoundsGroupsAttr(), parForOp.getUpperBoundsMap(),
parForOp.getUpperBoundsGroupsAttr(), parForOp.getSteps(),
parForOp.getMapOperands());
auto newParForOp =
rewriter.create<scf::ParallelOp>(loc, parForOp.getLowerBound(),
parForOp.getUpperBound(), parForOp.getStep(), parForOp.getInits());
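// Drop the body block auto-created by the scf::ParallelOp builder; the
// original loop body (with its induction variables) is moved in below via
// takeBody.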
rewriter.eraseBlock(newParForOp.getBody());
newParForOp.getRegion().takeBody(parForOp.getRegion());
// Create a body that is surrounded by an alloca scope.
// Code inspired from SCFToOpenMP.cpp, in ParallelOpLowering struct, line
@@ -93,13 +81,12 @@ struct ProcessAffineParallelWithoutScopePattern
// Create a block containing the ops in the loop body.
Block *ops = rewriter.splitBlock(&*newParForOp.getRegion().begin(),
newParForOp.getRegion().begin()->begin());
auto oldYield = cast<affine::AffineYieldOp>(ops->getTerminator());

auto oldYield = cast<scf::YieldOp>(ops->getTerminator());
// Insertion point at the top of the loop.
rewriter.setInsertionPointToStart(&*newParForOp.getRegion().begin());
// Create scope and affine yield.
// Create scope and scf yield.
auto scope = rewriter.create<memref::AllocaScopeOp>(loc, TypeRange());
rewriter.create<affine::AffineYieldOp>(loc, oldYield.getOperands());
rewriter.create<scf::YieldOp>(loc, oldYield.getOperands());
// Move the ops of the loop body into the alloca scope.
Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
rewriter.mergeBlocks(ops, scopeBlock);
@@ -109,67 +96,66 @@ struct ProcessAffineParallelWithoutScopePattern
oldYield, oldYield->getOperands());
}
rewriter.replaceOp(parForOp, newParForOp);

return success();
}
};

struct ProcessAffineParallelPrivatePass
: public PassWrapper<ProcessAffineParallelPrivatePass,
struct ProcessScfParallelPrivatePass
: public PassWrapper<ProcessScfParallelPrivatePass,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ProcessAffineParallelPrivatePass)
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ProcessScfParallelPrivatePass)

ProcessAffineParallelPrivatePass() {}
ProcessAffineParallelPrivatePass(const ProcessAffineParallelPrivatePass &pass)
: mlir::PassWrapper<ProcessAffineParallelPrivatePass,
ProcessScfParallelPrivatePass() {}
ProcessScfParallelPrivatePass(const ProcessScfParallelPrivatePass &pass)
: mlir::PassWrapper<ProcessScfParallelPrivatePass,
OperationPass<func::FuncOp>>() {}

StringRef getArgument() const override { return "affine-parallel-private"; }
StringRef getArgument() const override { return "scf-parallel-private"; }

StringRef getDescription() const override {
return "Process affine parallel for op to support private variables.";
return "Process scf parallel for op to support private variables.";
}

void runOnOperation() final;

typedef PassWrapper<ProcessAffineParallelPrivatePass,
typedef PassWrapper<ProcessScfParallelPrivatePass,
OperationPass<func::FuncOp>>
BaseType;
};

void ProcessAffineParallelPrivatePass::runOnOperation() {
void ProcessScfParallelPrivatePass::runOnOperation() {
func::FuncOp function = getOperation();
MLIRContext *context = &getContext();

ConversionTarget target(getContext());
target.addLegalDialect<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
target.addLegalDialect<mlir::scf::SCFDialect, mlir::arith::ArithDialect,
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
mlir::vector::VectorDialect, mlir::scf::SCFDialect>();
mlir::vector::VectorDialect>();

// Locate parallel for without scope
target.addDynamicallyLegalOp<affine::AffineParallelOp>(
[](affine::AffineParallelOp op) {
return ProcessAffineParallelWithoutScopePattern::
matchParallelForWithAllocScope(op);
});
target.addDynamicallyLegalOp<scf::ParallelOp>([](scf::ParallelOp op) {
return ProcessScfParallelWithoutScopePattern::
matchParallelForWithAllocScope(op);
});
RewritePatternSet patterns(context);
onnx_mlir::getParallelPrivateAffineToAffinePatterns(patterns);
onnx_mlir::getParallelPrivateScfToScfPatterns(patterns);

if (failed(applyPartialConversion(function, target, std::move(patterns))))
signalPassFailure();
}

} // namespace

void onnx_mlir::getParallelPrivateAffineToAffinePatterns(
void onnx_mlir::getParallelPrivateScfToScfPatterns(
mlir::RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.insert<ProcessAffineParallelWithoutScopePattern>(context);
patterns.insert<ProcessScfParallelWithoutScopePattern>(context);
}

/*!
 * Create a ProcessScfParallelPrivate pass.
 */
std::unique_ptr<mlir::Pass>
onnx_mlir::createProcessAffineParallelPrivatePass() {
return std::make_unique<ProcessAffineParallelPrivatePass>();
std::unique_ptr<mlir::Pass> onnx_mlir::createProcessScfParallelPrivatePass() {
return std::make_unique<ProcessScfParallelPrivatePass>();
}
src/Transform/{ProcessAffineParallelPrivate.hpp → ProcessScfParallelPrivate.hpp}
@@ -21,7 +21,6 @@ namespace onnx_mlir {

// Exports the ProcessScfParallelPrivate patterns. They are all plain rewrite
// patterns that can be used with any PatternRewriter, not conversion patterns.
void getParallelPrivateAffineToAffinePatterns(
mlir::RewritePatternSet &patterns);
void getParallelPrivateScfToScfPatterns(mlir::RewritePatternSet &patterns);

} // namespace onnx_mlir
@@ -1,50 +1,57 @@
// RUN: onnx-mlir-opt -O3 --march=x86-64 --affine-parallel-private %s -split-input-file | FileCheck %s
// RUN: onnx-mlir-opt -O3 --march=x86-64 --scf-parallel-private %s -split-input-file | FileCheck %s

// -----


func.func @add_with_par(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) attributes {llvm.emit_c_interface} {
func.func @add_with_par(%arg0: memref<16x8x128xf32>) -> (memref<16x8x128xf32>) {
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c16384 = arith.constant 16384 : index
%alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32>
%alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
affine.store %c16384, %alloc_0[0] : memref<1xindex>
memref.store %c16384, %alloc_0[%c0] : memref<1xindex>
%reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
%alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
affine.store %c16384, %alloc_1[0] : memref<1xindex>
memref.store %c16384, %alloc_1[%c0] : memref<1xindex>
%reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
%alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex>
affine.store %c16384, %alloc_3[0] : memref<1xindex>
memref.store %c16384, %alloc_3[%c0] : memref<1xindex>
%reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
affine.parallel (%arg1) = (0) to (16384) step (32) {
scf.parallel (%arg1) = (%c0) to (%c16384) step (%c32) {
%0 = vector.load %reshape[%arg1] : memref<16384xf32>, vector<32xf32>
%1 = vector.load %reshape_2[%arg1] : memref<16384xf32>, vector<32xf32>
%2 = arith.addf %0, %1 : vector<32xf32>
vector.store %2, %reshape_4[%arg1] : memref<16384xf32>, vector<32xf32>
scf.yield
}
return %alloc : memref<16x8x128xf32>

// mlir2FileCheck.py
// CHECK-LABEL: func.func @add_with_par
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) attributes {llvm.emit_c_interface} {
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32>) -> memref<16x8x128xf32> {
// CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_16384_:%.+]] = arith.constant 16384 : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x8x128xf32>
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
// CHECK: affine.store [[CST_16384_]], [[RES_1_]][0] : memref<1xindex>
// CHECK: memref.store [[CST_16384_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<1xindex>
// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_1_]]) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
// CHECK: affine.store [[CST_16384_]], [[RES_2_]][0] : memref<1xindex>
// CHECK: memref.store [[CST_16384_]], [[RES_2_]]{{.}}[[CST_0_]]{{.}} : memref<1xindex>
// CHECK-DAG: [[VAR_reshape_2_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_2_]]) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
// CHECK: affine.store [[CST_16384_]], [[RES_3_]][0] : memref<1xindex>
// CHECK: memref.store [[CST_16384_]], [[RES_3_]]{{.}}[[CST_0_]]{{.}} : memref<1xindex>
// CHECK: [[VAR_reshape_4_:%.+]] = memref.reshape [[RES_]]([[RES_]]_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32>
// CHECK: affine.parallel ([[arg1_:%.+]]) = (0) to (16384) step (32) {
// CHECK: scf.parallel ([[arg1_:%.+]]) = ([[CST_0_]]) to ([[CST_16384_]]) step ([[CST_32_]]) {
// CHECK: memref.alloca_scope {
// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[arg1_]]{{.}} : memref<16384xf32>, vector<32xf32>
// CHECK-DAG: [[LOAD_VAR_reshape_2_MEM_:%.+]] = vector.load [[VAR_reshape_2_]]{{.}}[[arg1_]]{{.}} : memref<16384xf32>, vector<32xf32>
// CHECK: [[VAR_2_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_]], [[LOAD_VAR_reshape_2_MEM_]] : vector<32xf32>
// CHECK: vector.store [[VAR_2_]], [[VAR_reshape_4_]]{{.}}[[arg1_]]{{.}} : memref<16384xf32>, vector<32xf32>
// CHECK: }
// CHECK: scf.yield
// CHECK: }
// CHECK: return [[RES_]] : memref<16x8x128xf32>
// CHECK: }
}
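
For contrast, a parallel loop whose body already begins with an alloca scope passes the legality test above and is left untouched by the pass; a minimal sketch (bounds hypothetical, not part of this test file):

  scf.parallel (%i) = (%c0) to (%c16384) step (%c32) {
    memref.alloca_scope {
      // thread-private temporaries would live here
    }
    scf.yield
  }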
