Skip to content

Commit

Permalink
Instrumentation cleanup when operation was removed (#3061)
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandre Eichenberger <[email protected]>
  • Loading branch information
AlexandreEichenberger authored Feb 3, 2025
1 parent a7653c1 commit 06ad7fb
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ void addKrnlToLLVMPasses(
// pm.addNestedPass<func::FuncOp>(krnl::createConvertSeqToMemrefPass());

pm.addPass(mlir::memref::createFoldMemRefAliasOpsPass());

if (profileIR)
pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentCleanupPass());

if (enableBoundCheck)
pm.addPass(mlir::createGenerateRuntimeVerificationPass());
pm.addPass(krnl::createConvertKrnlToLLVMPass(verifyInputTensors,
Expand Down
2 changes: 2 additions & 0 deletions src/Pass/Passes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ std::unique_ptr<mlir::Pass> createConstPropONNXToONNXPass();
std::unique_ptr<mlir::Pass> createInstrumentPass();
std::unique_ptr<mlir::Pass> createInstrumentPass(
const std::string &ops, unsigned actions);
/// Pass for instrument cleanup.
std::unique_ptr<mlir::Pass> createInstrumentCleanupPass();

/// Passes for instrumenting the ONNX ops to print their operand type
/// signatures at runtime.
Expand Down
4 changes: 4 additions & 0 deletions src/Tools/onnx-mlir-opt/RegisterPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ void registerOMPasses(int optLevel) {
mlir::registerPass(
[]() -> std::unique_ptr<mlir::Pass> { return createInstrumentPass(); });

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return createInstrumentCleanupPass();
});

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return createInstrumentONNXSignaturePass("NONE");
});
Expand Down
1 change: 1 addition & 0 deletions src/Transform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ add_onnx_mlir_library(OMScfParallelPrivateRegion

add_onnx_mlir_library(OMInstrument
InstrumentPass.cpp
InstrumentCleanupPass.cpp

INCLUDE_DIRS PUBLIC
${ONNX_MLIR_SRC_ROOT}/include
Expand Down
113 changes: 113 additions & 0 deletions src/Transform/InstrumentCleanupPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===------- InstrumentCleanupPass.cpp - Instrumentation -----------------===//
//
// Copyright 2025 The IBM Research Authors.
//
// =============================================================================
//
// This file implements a Function level pass that remove consecutive
// instrumentation operations (first with "before" tag and second with "after")
// as they do not measure anything.
//
//===----------------------------------------------------------------------===//

#include <regex>
#include <set>
#include <string>

#include "onnx-mlir/Compiler/OMCompilerRuntimeTypes.h"
#include "onnx-mlir/Compiler/OMCompilerTypes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/raw_ostream.h"

#include "src/Compiler/OptionUtils.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Interface/ShapeInferenceOpInterface.hpp"
#include "src/Pass/Passes.hpp"

using namespace mlir;

namespace onnx_mlir {

/*!
* This pass insert KrnlInstrumentOp before and after each ops
*/

class InstrumentCleanupPass : public mlir::PassWrapper<InstrumentCleanupPass,
OperationPass<func::FuncOp>> {

public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InstrumentCleanupPass)

InstrumentCleanupPass(){};
InstrumentCleanupPass(const InstrumentCleanupPass &pass)
: mlir::PassWrapper<InstrumentCleanupPass,
OperationPass<func::FuncOp>>() {}

private:
public:
StringRef getArgument() const override { return "instrument-cleanup"; }

StringRef getDescription() const override {
return "instrument cleanup on ops.";
}

void runOnOperation() override {
llvm::SmallVector<Operation *> eraseOpList;
bool skipNext = false;

// Iterate on the operations nested in this function
getOperation().walk([&](mlir::Operation *op) -> WalkResult {
if (skipNext) {
skipNext = false;
return WalkResult::advance();
}
KrnlInstrumentOp firstInstrOp = mlir::dyn_cast<KrnlInstrumentOp>(op);
// Check if we have a first instrumentation op with instr before.
if (!firstInstrOp)
return WalkResult::advance();
uint64_t firstTag = firstInstrOp.getTag();
// skip if not before, or if this call initializes the instrumentation.
if (!IS_INSTRUMENT_BEFORE_OP(firstTag) || IS_INSTRUMENT_INIT(firstTag))
return WalkResult::advance();
// Check if we have a second instrumentation op with instr after.
Operation *nextOp = op->getNextNode();
if (!nextOp)
return WalkResult::advance();
KrnlInstrumentOp secondInstrOp = mlir::dyn_cast<KrnlInstrumentOp>(nextOp);
if (!secondInstrOp)
return WalkResult::advance();
uint64_t secondTag = secondInstrOp.getTag();
// skip if not after, or if this call initializes the instrumentation.
if (!IS_INSTRUMENT_AFTER_OP(secondTag) || IS_INSTRUMENT_INIT(secondTag))
return WalkResult::advance();
// Could check opName but we already have a before/after pair, it can only
// be of the same op.
// Schedule both instrumentation to be removed as there is nothing between
// the start and the stop of the instrumentation.
eraseOpList.emplace_back(op);
eraseOpList.emplace_back(nextOp);
skipNext = true;
return WalkResult::advance();
});
// Remove ops.
for (Operation *op : eraseOpList)
op->erase();
}
};
} // namespace onnx_mlir

/*!
* Create an instrumentation pass.
*/
std::unique_ptr<mlir::Pass> onnx_mlir::createInstrumentCleanupPass() {
return std::make_unique<InstrumentCleanupPass>();
}

0 comments on commit 06ad7fb

Please sign in to comment.