Skip to content

Commit

Permalink
Option to set the number of threads for parallel compilation (#3048)
Browse files Browse the repository at this point in the history
* Add -j option for specifying the number of threads for compilation.

Signed-off-by: Haruki Imai <[email protected]>
  • Loading branch information
imaihal authored Jan 30, 2025
1 parent 2e4a46a commit 6d2b1d4
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 3 deletions.
9 changes: 8 additions & 1 deletion src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//===------------------------ CompilerOptions.cpp -------------------------===//
//
// Copyright 2022, 2024 The IBM Research Authors.
// Copyright 2022-2025 The IBM Research Authors.
//
// =============================================================================
//
Expand Down Expand Up @@ -46,6 +46,7 @@ bool disableKrnlOpFusion; // common for both
bool disableQuantZeroPoint; // common for both
bool enableKrnlBufferReuse; // common for both
bool disableMemRefPrefetch; // common for both
uint64_t compilationNumThreads; // common for both
EmissionTargetType emissionTarget; // onnx-mlir only
bool invokeOnnxVersionConverter; // onnx-mlir only
bool preserveLocations; // onnx-mlir only
Expand Down Expand Up @@ -617,6 +618,12 @@ static llvm::cl::opt<bool, true> disableConstantPropOpt("disable-constant-prop",
llvm::cl::location(disableConstantProp), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<uint64_t, true> compilation_num_threads("j",
llvm::cl::desc("Use <int> threads for compilation. The default value is "
"0, which spawns threads for all available CPUs.\n"),
llvm::cl::location(compilationNumThreads), llvm::cl::init(0),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::list<std::string, std::vector<std::string>> extraLibPathsOpt(
"L",
llvm::cl::desc(
Expand Down
3 changes: 2 additions & 1 deletion src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//===------------------------ CompilerOptions.hpp -------------------------===//
//
// Copyright 2022-2024 The IBM Research Authors.
// Copyright 2022-2025 The IBM Research Authors.
//
// =============================================================================
//
Expand Down Expand Up @@ -92,6 +92,7 @@ extern bool disableKrnlOpFusion; // common for both
extern bool disableQuantZeroPoint; // common for both
extern bool enableKrnlBufferReuse; // common for both
extern bool disableMemRefPrefetch; // common for both
extern uint64_t compilationNumThreads; // common for both
extern EmissionTargetType emissionTarget; // onnx-mlir only
extern bool invokeOnnxVersionConverter; // onnx-mlir only
extern bool preserveLocations; // onnx-mlir only
Expand Down
11 changes: 11 additions & 0 deletions src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <mlir/IR/AsmState.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/Threading.h>
#include <mlir/InitAllPasses.h>
#include <mlir/Interfaces/ViewLikeInterface.h>
#include <mlir/Pass/Pass.h>
Expand Down Expand Up @@ -184,8 +185,18 @@ int main(int argc, char **argv) {
for (auto *accel : accel::Accelerator::getAccelerators())
accel->configurePasses();

std::unique_ptr<llvm::ThreadPoolInterface> threadPoolPtr = nullptr;
auto passManagerSetupFn = [&](PassManager &pm) {
MLIRContext *ctx = pm.getContext();
// Set number of threads in the MLIRContext
if (compilationNumThreads > 0)
ctx->disableMultithreading();
if (compilationNumThreads > 1) {
threadPoolPtr = std::make_unique<llvm::DefaultThreadPool>(
llvm::hardware_concurrency(compilationNumThreads));
ctx->setThreadPool(*threadPoolPtr);
}

// MlirOptMain constructed ctx with our registry so we just load all our
// already registered dialects.
ctx->loadAllAvailableDialects();
Expand Down
16 changes: 15 additions & 1 deletion src/onnx-mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//===------------------ onnx-mlir.cpp - Compiler Driver ------------------===//
//
// Copyright 2019-2022 The IBM Research Authors.
// Copyright 2019-2025 The IBM Research Authors.
//
// =============================================================================
// Main function for onnx-mlir.
Expand All @@ -15,6 +15,7 @@
#include <regex>

#include "mlir/IR/AsmState.h"
#include "mlir/IR/Threading.h"
#include "mlir/Support/Timing.h"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Compiler/CompilerUtils.hpp"
Expand Down Expand Up @@ -68,7 +69,20 @@ int main(int argc, char *argv[]) {
}

// Create context after MLIRContextCLOptions are registered and parsed.
// The multi-threading in MLIRContext is enabled by default. It must be
// disabled to control the number of threads. To use single thread, simply
// disable it. To use a specific number of threads, disable it once and then
// set a new thread pool.
mlir::MLIRContext context;
std::unique_ptr<llvm::ThreadPoolInterface> threadPoolPtr = nullptr;
if (compilationNumThreads > 0)
context.disableMultithreading();
if (compilationNumThreads > 1) {
threadPoolPtr = std::make_unique<llvm::DefaultThreadPool>(
llvm::hardware_concurrency(compilationNumThreads));
context.setThreadPool(*threadPoolPtr);
}

if (!context.isMultithreadingEnabled()) {
assert(context.getNumThreads() == 1 && "1 thread if no multithreading");
LLVM_DEBUG(llvm::dbgs() << "multithreading is disabled\n");
Expand Down

0 comments on commit 6d2b1d4

Please sign in to comment.