Skip to content

Commit

Permalink
Add support of return for KrnlCallOp (#2949)
Browse files Browse the repository at this point in the history
* change op define

Signed-off-by: chentong319 <[email protected]>

* lower to llvm

Signed-off-by: chentong319 <[email protected]>

* test

Signed-off-by: chentong319 <[email protected]>

---------

Signed-off-by: chentong319 <[email protected]>
  • Loading branch information
chentong319 authored Sep 23, 2024
1 parent bf905d1 commit 087f069
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 5 deletions.
26 changes: 21 additions & 5 deletions src/Conversion/KrnlToLLVM/KrnlCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,27 @@ class KrnlCallOpLowering : public ConversionPattern {
rewriter, op, namedAttr.getValue(), parameterTypeList, parameterList);
}

FlatSymbolRefAttr callRef =
create.llvm.getOrInsertSymbolRef(module, krnlCallOp.getFuncName(),
LLVM::LLVMVoidType::get(module.getContext()), parameterTypeList);
create.llvm.call({}, callRef, parameterList);
ValueRange returns = op->getResults();
if (returns.size() == 0) {
// There is no return
FlatSymbolRefAttr callRef =
create.llvm.getOrInsertSymbolRef(module, krnlCallOp.getFuncName(),
LLVM::LLVMVoidType::get(module.getContext()), parameterTypeList);
create.llvm.call({}, callRef, parameterList);

rewriter.eraseOp(op);
} else {
assert(returns.size() == 1 &&
"Only one return value is allowed for krnl.call now");
Type llvmReturnType =
llvmTypeConverter->convertType(returns[0].getType());

FlatSymbolRefAttr callRef = create.llvm.getOrInsertSymbolRef(
module, krnlCallOp.getFuncName(), llvmReturnType, parameterTypeList);
auto llvmCall =
create.llvm.call({llvmReturnType}, callRef, parameterList);
rewriter.replaceOp(op, llvmCall.getDefiningOp()->getResults()[0]);
}

// Destroy OMTensor wrappers of parameters.
const auto &apiRegistry =
Expand All @@ -81,7 +98,6 @@ class KrnlCallOpLowering : public ConversionPattern {
rewriter, loc, apiRegistry, RuntimeAPI::API::DESTROY_OMTENSOR, {omt});
}

rewriter.eraseOp(op);
return success();
}

Expand Down
9 changes: 9 additions & 0 deletions src/Dialect/Krnl/Krnl.td
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,22 @@ def KrnlCallOp : Op<Krnl_Dialect, "call",
DefaultValuedAttr<SI64Attr, "1">:$numOfOutput,
Variadic<AnyType>:$parameters);

// Return Value for the Call.
// No return if the type is NoneType (void in llvm)
// Only scalar type is supported now.
// In future, return of memref can be supported with pointer of OMTensor.
// The returned memref will be created inside the call.
let results = (outs Variadic<AnyTypeOf<[AnyFloat, AnyInteger]>>:$returnValue);

// builders to build KrnlCallOp from op and operands, helping conversion from
// onnx to krnl.
// The name of function can be determined by the op name and elemnt type of
// the return, or given to builder if the simple rule does not work.
// Attributes of the op will be propagated to KrnlCallOp if the copyAttrs is
// true. Or the attribute names can be specified.
let builders = [
OpBuilder<(ins "std::string":$funcNameStr, "int64_t":$numOfOutput, "mlir::ValueRange":$operands)>,
OpBuilder<(ins "mlir::StringAttr":$funcNameStr, "IntegerAttr":$numOfOutput, "mlir::ValueRange":$operands)>,
OpBuilder<(ins "std::string":$funcNameStr, "mlir::ValueRange":$results, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "std::vector<std::string>":$attributeNames)>,
OpBuilder<(ins "mlir::ValueRange":$results, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "bool":$copyAttrs)>,
OpBuilder<(ins "std::string":$funcNameStr, "mlir::ValueRange":$results, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "std::vector<std::string>":$attributeNames)>,
Expand Down
10 changes: 10 additions & 0 deletions src/Dialect/Krnl/KrnlOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@ void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState,
build(builder, odsState, funcNameStr, resultVals, op, operands, copyAttrs);
}

void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState,
std::string funcName, int64_t numOfOutput, ValueRange operands) {
build(builder, odsState, {}, funcName, numOfOutput, operands);
}

void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState,
StringAttr funcName, IntegerAttr numOfOutput, ValueRange operands) {
build(builder, odsState, {}, funcName, numOfOutput, operands);
}

void KrnlCallOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
Expand Down
10 changes: 10 additions & 0 deletions test/mlir/conversion/krnl_to_llvm/call_with_return.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: onnx-mlir-opt --convert-krnl-to-llvm %s -split-input-file | FileCheck %s

func.func private @test_krnl_call_with_return(%arg0: memref<2x3xi32>) -> i32 {
%1 = "krnl.call"() {funcName = "get_omp_num_thread", numOfOutput = 0 : si64} : () -> (i32)
func.return %1: i32
// CHECK: llvm.func @get_omp_num_thread() -> i32
// CHECK: llvm.func @test_krnl_call_with_return
// CHECK: [[VAR_0_:%.+]] = llvm.call @get_omp_num_thread() : () -> i32
// CHECK: llvm.return [[VAR_0_]] : i32
}

0 comments on commit 087f069

Please sign in to comment.