Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexandreEichenberger committed Feb 12, 2025
2 parents 959daa2 + d6da8b2 commit 3a1b1a0
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 6 deletions.
6 changes: 3 additions & 3 deletions .buildbot/jenkins_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ def strtobool(s: str) -> bool:
def compute_file_sha1(file_name):
"""Compute sha1 of a file."""

sha1sum = hashlib.sha1()
sha3_256sum = hashlib.sha3_256()
try:
with open(file_name, "rb") as f:
for data in iter(lambda: f.read(READ_CHUNK_SIZE), b""):
sha1sum.update(data)
return sha1sum.hexdigest()
sha3_256sum.update(data)
return sha3_256sum.hexdigest()
except:
return ""

Expand Down
11 changes: 11 additions & 0 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ std::string opsForCall; // common for both
bool disableKrnlOpFusion; // common for both
bool disableQuantZeroPoint; // common for both
bool enableKrnlBufferReuse; // common for both
bool enableSafeCodeGen; // common for both
bool disableMemRefPrefetch; // common for both
uint64_t compilationNumThreads; // common for both
EmissionTargetType emissionTarget; // onnx-mlir only
Expand Down Expand Up @@ -245,6 +246,16 @@ static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(
llvm::cl::location(enableKrnlBufferReuse), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> enableSafeCodeGenOpt("enable-safe-code-gen",
llvm::cl::desc("enable extra runtime check to be created in code gen. "
"Such check will have cost at runtime, and is not needed if"
"the model and the data are correct."
"Failure of check will trigger assertion error."
"(default=false).\n"
"Set to 'true' if you want to enable the check."),
llvm::cl::location(enableSafeCodeGen), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
"disable-memref-prefetch",
llvm::cl::desc("Disable generation of memref.prefetch (default=false).\n"
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ extern std::string opsForCall; // common for both
extern bool disableKrnlOpFusion; // common for both
extern bool disableQuantZeroPoint; // common for both
extern bool enableKrnlBufferReuse; // common for both
extern bool enableSafeCodeGen; // common for both
extern bool disableMemRefPrefetch; // common for both
extern uint64_t compilationNumThreads; // common for both
extern EmissionTargetType emissionTarget; // onnx-mlir only
Expand Down
3 changes: 2 additions & 1 deletion src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// Krnl IR and standard operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
Expand Down Expand Up @@ -367,7 +368,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
target.addLegalDialect<KrnlDialect, affine::AffineDialect,
arith::ArithDialect, func::FuncDialect, linalg::LinalgDialect,
math::MathDialect, vector::VectorDialect, memref::MemRefDialect,
shape::ShapeDialect, scf::SCFDialect>();
shape::ShapeDialect, scf::SCFDialect, cf::ControlFlowDialect>();
// Needed to support unsigned int computations. To be removed if we use a
// scheme that does not rely on the UnrealizedConversionCastOp.
target.addLegalOp<::mlir::UnrealizedConversionCastOp>();
Expand Down
42 changes: 41 additions & 1 deletion src/Conversion/ONNXToKrnl/Tensor/Gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"

Expand All @@ -37,7 +41,8 @@ struct ONNXGatherOpLowering : public OpConversionPattern<ONNXGatherOp> {
Location loc = ONNXLoc<ONNXGatherOp>(op);
ValueRange operands = adaptor.getOperands();

MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl, MemRefBuilder>
MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl, MemRefBuilder,
MathBuilder>
create(rewriter, loc);

// Get shape.
Expand Down Expand Up @@ -122,6 +127,41 @@ struct ONNXGatherOpLowering : public OpConversionPattern<ONNXGatherOp> {
if (indicesMayBeNegative)
index = index.selectOrSelf(index < zeroIE, index + axisDim);

// The Gather op is data dependent: the value of index should be
// within the input data size.
// Add runtime check if enableSafeCodeGen is set true
// Implementation comments vs. createGenerateRuntimeVerificationPass
// This check is according to onnx op semantics, not general bound
// check for memref. Implementation of RuntimeVerification could be
// borrowed. Slightly difference is that onnx semenatics check is for
// each dimension independently, not the final address is within
// the memref bound.
if (enableSafeCodeGen) {
// From onnx document:
// All index values are expected to be within bounds [-s, s-1]
// along axis of size s. It is an error if any of the index values
// are out of bounds.
// After the negative correction, the range should be [0, s-1]
Value upperBound = create.mem.dim(data, axisLit);
Value compareUpperBound =
create.math.slt(index.getValue(), upperBound);
// Report onnx_node_name if the op has the attribute
std::string nodeNameStr = op->getName().getStringRef().str() + " ";
StringAttr nodeName =
op->getAttrOfType<mlir::StringAttr>("onnx_node_name");
if (nodeName && !nodeName.getValue().empty()) {
nodeNameStr = nodeNameStr + nodeName.getValue().str();
}
rewriter.create<cf::AssertOp>(loc, compareUpperBound,
nodeNameStr +
" indices of GatherOp is larger than the upper bound");
Value compareLowerBound =
create.math.sge(index.getValue(), zeroIE.getValue());
rewriter.create<cf::AssertOp>(loc, compareLowerBound,
nodeNameStr +
" indices of GatherOp is less than the lower bound");
}

// Compute access function of data: data[ii + (indices[jj],) + kk]
SmallVector<IndexExpr, 4> dataAccessFct;
// First add indices iis
Expand Down
33 changes: 32 additions & 1 deletion src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"

Expand All @@ -31,7 +35,8 @@ struct ONNXGatherElementsOpLowering
Location loc = ONNXLoc<ONNXGatherElementsOp>(op);
ValueRange operands = adaptor.getOperands();

MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl, MemRefBuilder>
MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl, MemRefBuilder,
MathBuilder>
create(rewriter, loc);

// Get shape.
Expand Down Expand Up @@ -93,6 +98,32 @@ struct ONNXGatherElementsOpLowering
index = index.selectOrSelf(index < zero, index + axisDim);
}

// Check the dynamic requirement of GatherElement Op
// Refer to the comments in Gather.cpp
if (enableSafeCodeGen) {
// From onnx document:
// All index values are expected to be within bounds [-s, s-1]
// along axis of size s. It is an error if any of the index values
// are out of bounds.
// After the negative correction, the range should be [0, s-1]
Value upperBound = create.mem.dim(data, axis);
Value compareUpperBound =
create.math.slt(index.getValue(), upperBound);
std::string nodeNameStr = op->getName().getStringRef().str() + " ";
StringAttr nodeName =
op->getAttrOfType<mlir::StringAttr>("onnx_node_name");
if (nodeName && !nodeName.getValue().empty()) {
nodeNameStr = nodeNameStr + nodeName.getValue().str();
}
rewriter.create<cf::AssertOp>(loc, compareUpperBound,
"indices of GatherOp is larger than the upper bound");
LiteralIndexExpr zero(0);
Value compareLowerBound =
create.math.sge(index.getValue(), zero.getValue());
rewriter.create<cf::AssertOp>(loc, compareLowerBound,
"indices of GatherOp is less than the lower bound");
}

// Access function for the 'data' tensor.
DimsExpr dataAccessFct;
for (int64_t i = 0; i < dataRank; ++i)
Expand Down

0 comments on commit 3a1b1a0

Please sign in to comment.