Parallelization of ConstProp compilation (#3042)
To reduce compilation time, this PR parallelizes the compilation of ConstProp using `parallelFor()`. This mainly improves constant propagation for reduction computations. When the input tensor is small, the computation runs sequentially, without this parallelization, to avoid the parallelization overhead.
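
As a rough illustration of the pattern this PR applies, here is a minimal, self-contained sketch (not code from the diff): the function name `scaleAll` and the threshold value are illustrative, and `parallelFor()` is assumed to be MLIR's helper from `mlir/IR/Threading.h`.

    #include "llvm/ADT/ArrayRef.h"
    #include "mlir/IR/MLIRContext.h"
    #include "mlir/IR/Threading.h"

    // Double each element: sequentially for small inputs, in parallel otherwise.
    void scaleAll(mlir::MLIRContext *ctx, llvm::MutableArrayRef<double> data) {
      constexpr size_t minCount = 1000; // illustrative threshold
      auto work = [&](size_t i) { data[i] *= 2.0; };
      if (data.size() < minCount) {
        // Small input: a plain loop avoids the thread-dispatch overhead.
        for (size_t i = 0, e = data.size(); i < e; ++i)
          work(i);
      } else {
        // Large input: let MLIR's thread pool split the index range.
        mlir::parallelFor(ctx, 0, data.size(), work);
      }
    }

Unlike this per-element sketch, the diff below iterates over thread numbers and tiles the batch manually, which amortizes the per-task dispatch cost across many elements.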

---------

Signed-off-by: Haruki Imai <[email protected]>
imaihal authored Feb 6, 2025
1 parent d8de38c commit ab75f99
Showing 3 changed files with 141 additions and 20 deletions.
86 changes: 70 additions & 16 deletions src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp
@@ -9,8 +9,8 @@
//===----------------------------------------------------------------------===//

#include "src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp"

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Threading.h"
#include "llvm/ADT/STLExtras.h"

#include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp"
@@ -849,6 +849,8 @@
  if (axes.empty())
    return elms;

  Type elementType = elms.getElementType();
  MLIRContext *ctx = elementType.getContext();
  SmallVector<unsigned, 4> sortedAxes(axes);
  std::sort(sortedAxes.begin(), sortedAxes.end());
  assert(
@@ -885,22 +887,74 @@

  ShapedType reducedType = type.clone(reducedShape);
  return fromWideNums(reducedType, [&](MutableArrayRef<WideNum> dstNums) {
    StridesRange<1> sRange(reducedShape, {reducedStrides});
    StridesRange<1> axesRange(axesShape, {axesStrides});
    SmallVector<std::pair<int64_t, uint64_t>, 4> batch;
    for (auto &idxoffs : sRange)
      batch.emplace_back(idxoffs.flattenedIndex, idxoffs[0]);

    auto fetchBatch = [&](size_t threadNumber, bool parallel) {
      // Return all data without splitting for sequential execution.
      if (!parallel)
        return llvm::make_range(batch.begin(), batch.end());
      // Each thread fetches the same tile size; the leftovers go to the
      // threads with the smallest thread numbers.
      size_t tileSize = batch.size() / ctx->getNumThreads();
      size_t leftovers = batch.size() % ctx->getNumThreads();
      int beginOffset;
      if (threadNumber < leftovers) {
        // For the first `leftovers` threads, the tile size is larger by one.
        tileSize++;
        beginOffset = threadNumber * tileSize;
      } else {
        // For the remaining threads, the start is shifted by `leftovers`.
        beginOffset = threadNumber * tileSize + leftovers;
      }
      int endOffset = beginOffset + tileSize;
      return llvm::make_range(
          batch.begin() + beginOffset, batch.begin() + endOffset);
    };

    auto work = [&](size_t threadNumber, bool parallel = true) {
      auto tile = fetchBatch(threadNumber, parallel);
      // Traverse and populate each element d in dstNums.
      for (auto b : tile) {
        WideNum &d = dstNums[b.first];
        int64_t srcPos = b.second;
        // Traverse all the elements that reduce together into d.
        // srcNums elements may be repeated if there are zeros in axesStrides.
        auto axesIter = axesRange.begin();
        auto axesEnd = axesRange.end();
        assert(axesIter->at(0) == 0 && "initial src offset must be zero");
        d = srcNums.get()[srcPos];
        while (++axesIter != axesEnd) {
          int64_t srcOffset = axesIter->at(0);
          d = reducer(d, srcNums.get()[srcPos + srcOffset]);
        }
      }
    };
    // Using 'parallelFor()' introduces a large overhead. The following are
    // actual measurements on IBM z16, used to decide 'minCount'. We measured
    // 'onnx.ReduceSum()' in 'test/mlir/onnx/onnx_constprop_parallel.mlir' with
    // several input sizes. Based on these results, we use 2000 as 'minCount'.
    //
    // inputCounts | Sequential | Parallel with 2 threads
    //             | (work())   | (parallelFor())
    //             | (msec)     | (msec)
    // ----------------------------------------------------
    //        400  |   0.065    |   0.153
    //        800  |   0.115    |   0.164
    //       1200  |   0.175    |   0.201
    //       1600  |   0.226    |   0.228
    //       2000  |   0.282    |   0.258
    //       2400  |   0.336    |   0.284
    constexpr size_t minCount = 2000;
    size_t inputCount = batch.size() * axesRange.size();
    if (inputCount < minCount)
      work(0, /*parallel=*/false);
    else
      parallelFor(ctx, 0, ctx->getNumThreads(), work);
  });
}
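
The tiling in fetchBatch is easiest to check with a small worked example. The standalone sketch below mirrors the same arithmetic; the helper name `tileBounds` is hypothetical and not part of the diff.

    #include <cstddef>
    #include <utility>

    // Hypothetical helper mirroring the fetchBatch arithmetic above: returns
    // the half-open [begin, end) range of batch indices owned by one thread.
    std::pair<size_t, size_t> tileBounds(
        size_t threadNumber, size_t numThreads, size_t size) {
      size_t tileSize = size / numThreads;
      size_t leftovers = size % numThreads;
      size_t begin;
      if (threadNumber < leftovers) {
        tileSize++; // the first `leftovers` threads each take one extra element
        begin = threadNumber * tileSize;
      } else {
        begin = threadNumber * tileSize + leftovers;
      }
      return {begin, begin + tileSize};
    }

For size = 10 and numThreads = 4 this yields [0, 3), [3, 6), [6, 8), [8, 10): contiguous, non-overlapping tiles that together cover all ten elements.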

45 changes: 41 additions & 4 deletions src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp
@@ -10,6 +10,7 @@

#ifndef ONNX_MLIR_ELEM_ATTR_BUILDER_H
#define ONNX_MLIR_ELEM_ATTR_BUILDER_H
#include "mlir/IR/Threading.h"

#include "src/Dialect/ONNX/ElementsAttr/BType.hpp"
#include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp"
@@ -244,10 +245,46 @@
  // Constructs a transformer that changes every element to the result of
  // applying the given function to the element.
  template <typename Function = WideNum (*)(WideNum)>
  inline Transformer functionTransformer(Function fun) {
    mlir::MLIRContext *ctx = disposablePool.getContext();
    return [fun = std::move(fun), ctx](
               llvm::MutableArrayRef<WideNum> data) -> void {
      auto fetchBatch = [&](size_t threadNumber, bool parallel) {
        // Return all data without splitting for sequential execution.
        if (!parallel)
          return llvm::make_range(data.begin(), data.end());
        // Each thread fetches the same data size; the leftovers go to the
        // threads with the smallest thread numbers.
        size_t tileSize = data.size() / ctx->getNumThreads();
        size_t leftovers = data.size() % ctx->getNumThreads();
        int beginOffset;
        if (threadNumber < leftovers) {
          // For the first `leftovers` threads, the tile size is larger by one.
          tileSize++;
          beginOffset = threadNumber * tileSize;
        } else {
          // For the remaining threads, the start is shifted by `leftovers`.
          beginOffset = threadNumber * tileSize + leftovers;
        }
        int endOffset = beginOffset + tileSize;
        return llvm::make_range(
            data.begin() + beginOffset, data.begin() + endOffset);
      };

      auto work = [&](size_t threadNumber, bool parallel = true) {
        auto tile = fetchBatch(threadNumber, parallel);
        for (WideNum &n : tile)
          n = fun(n);
      };
      // Using 'parallelFor()' introduces a large overhead. To avoid it, call
      // work() directly when the input size is less than 'minCount'.
      constexpr size_t minCount = 1000;
      if (data.size() < minCount)
        work(0, /*parallel=*/false);
      else
        parallelFor(ctx, 0, ctx->getNumThreads(), work);
    };
  }
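
Since functionTransformer now returns a stateful closure (capturing the context) rather than being a static helper, a tiny MLIR-free sketch of the same closure pattern may help; the names `makeTransformer` and `Transformer` below are illustrative stand-ins, not the builder's actual types.

    #include <cstdio>
    #include <functional>
    #include <utility>
    #include <vector>

    // Illustrative stand-in for the builder's Transformer machinery.
    using Transformer = std::function<void(std::vector<double> &)>;

    template <typename Fn>
    Transformer makeTransformer(Fn fun) {
      // Capture the element-wise function and return a closure that maps it
      // over a whole buffer, as functionTransformer does above.
      return [fun = std::move(fun)](std::vector<double> &data) {
        for (double &n : data)
          n = fun(n);
      };
    }

    int main() {
      std::vector<double> v{1, 2, 3};
      Transformer square = makeTransformer([](double x) { return x * x; });
      square(v); // v becomes {1, 4, 9}
      std::printf("%g %g %g\n", v[0], v[1], v[2]);
      return 0;
    }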

