Add support for ONNX.shape with permutation pattern (#3066)
* Add an ONNX dialect builder to create Shape without the output type, and also enable a permutation of the dims, as required by some newer optimizations.

---------

Signed-off-by: Alexandre Eichenberger <[email protected]>
AlexandreEichenberger authored Feb 4, 2025
1 parent 06ad7fb commit f44085b
Showing 11 changed files with 104 additions and 63 deletions.
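For reference, the OnnxBuilder::shape calls that the updated patterns below rely on can be sketched as follows (a minimal illustration, not code from the patch; `rewriter`, `loc`, and the tensor value `x` are assumed):

onnx_mlir::OnnxBuilder create(rewriter, loc);
// Shape of x, with the result type inferred by the builder.
mlir::Value s0 = create.shape(x);
// Shape of x with its last two dims swapped.
mlir::Value s1 = create.shape(x, {0, 2, 1});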
16 changes: 0 additions & 16 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.td
@@ -58,20 +58,4 @@ def ACT_RELUAttr: NativeCodeCall<"$_builder.getStringAttr(\"ACT_RELU\")">;

def GetTypeOf : NativeCodeCall<"$0.getType()" >;

def GetNullAttr : NativeCodeCall<"Attribute()">;

def GetZeroI64Attr: NativeCodeCall<
"IntegerAttr::get($_builder.getIntegerType(64, /*isSigned=*/true), APInt(64, 0, /*isSigned=*/true))"
>;

def IsCompatibleWithNNPALevelArch14: Constraint<
CPred<"isCompatibleWithNNPALevel(NNPALevel::M14)">,
"Input level is compatible with NNPA level"
>;

def IsCompatibleWithNNPALevelArch15: Constraint<
CPred<"isCompatibleWithNNPALevel(NNPALevel::M15)">,
"Input level is compatible with NNPA level"
>;

#endif // ONNX_TO_ZHIGH_COMMON
@@ -114,7 +114,7 @@ def expandConstantOperandForAddOp1: Pat<
def expandConstantOperandForAddOp2: Pat<
(ONNXAddOp $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
(ONNXAddOp $x, (ONNXExpandOp $c,
(CreateShapeOp (GetShapeTypeOf $x), $x),
(CreateShapeOp $x),
(returnType $x))),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
>;
@@ -126,15 +126,15 @@ def expandConstantOperandForAddOp2: Pat<
def expandConstantOperandForDivOp1: Pat<
(ONNXDivOp $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
(ONNXDivOp $x, (ONNXExpandOp $c,
(CreateShapeOp (GetShapeTypeOf $x), $x),
(CreateShapeOp $x),
(returnType $x))),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
>;

def expandConstantOperandForDivOp2: Pat<
(ONNXDivOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x),
(ONNXDivOp (ONNXExpandOp $c,
(CreateShapeOp (GetShapeTypeOf $x), $x),
(CreateShapeOp $x),
(returnType $x)),
$x),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
@@ -153,7 +153,7 @@ def expandConstantOperandForMulOp1: Pat<
def expandConstantOperandForMulOp2: Pat<
(ONNXMulOp $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
(ONNXMulOp $x, (ONNXExpandOp $c,
(CreateShapeOp (GetShapeTypeOf $x), $x),
(CreateShapeOp $x),
(returnType $x))),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
>;
@@ -165,15 +165,15 @@ def expandConstantOperandForMulOp2: Pat<
def expandConstantOperandForSubOp1: Pat<
(ONNXSubOp $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
(ONNXSubOp $x, (ONNXExpandOp $c,
(CreateShapeOp (GetShapeTypeOf $x), $x),
(CreateShapeOp $x),
(returnType $x))),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
>;

def expandConstantOperandForSubOp2: Pat<
(ONNXSubOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x),
(ONNXSubOp (ONNXExpandOp $c,
(CreateShapeOp (GetShapeTypeOf $x), $x),
(CreateShapeOp $x),
(returnType $x)),
$x),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
@@ -378,7 +378,7 @@ def rewriteSoftmaxNDto3D: Pat<
(ReshapeTo3D:$r $input),
(SoftmaxAxisMinusOne),
(returnType $r)),
(CreateShapeOp (GetShapeResultType $input), $input),
(CreateShapeOp $input),
(GetZeroI64Attr)),
[(HasRankGT<3> $input)]
>;
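The pattern changes in this file drop the explicit result type from CreateShapeOp. Judging from the old and new NativeCodeCall definitions in OpHelper.td further down, the generated C++ changes roughly as follows (a sketch; `rewriter`, `loc`, `x`, and `shapeType` are assumed):

// Before: the pattern had to compute and pass the shape tensor type explicitly:
//   Value s = rewriter.create<mlir::ONNXShapeOp>(loc, shapeType, x, IntegerAttr(), 0);
// After: the OnnxBuilder helper infers the 1-D i64 result type itself:
mlir::Value s = onnx_mlir::OnnxBuilder(rewriter, loc).shape(x);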
9 changes: 9 additions & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
@@ -1209,6 +1209,15 @@ def ZHighReshapeOp:ZHigh_Op<"Reshape", [Pure,
OptionalAttr<StrAttr>:$layout); // Layout of output Z Tensor, default same as input.
let results = (outs AnyTypeOf<[AnyZTensor]>:$result);

let builders = [
// Copied from matmul, needed for DDR.
OpBuilder<(ins "::mlir::Value":$source, "::mlir::Value":$shape, "::mlir::StringAttr":$layout), [{
Type elementType = mlir::cast<ShapedType>(source.getType()).getElementType();
UnrankedTensorType resType = UnrankedTensorType::get(elementType);
build($_builder, $_state, resType, source, shape, layout);
}]>
];

let extraClassDefinition = [{
onnx_mlir::ONNXOpShapeHelper * ZHighReshapeOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef<mlir::Value> oper,
onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
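The builder added above lets generated rewrite code create a ZHighReshapeOp without spelling out its result type; it derives an unranked tensor type from the source's element type. A hedged usage sketch (`rewriter`, `loc`, `source`, `shapeVal`, and `layoutAttr` are assumed):

// The three-operand form resolves to the new builder, which sets the result
// type to an UnrankedTensorType carrying source's element type.
auto reshaped = rewriter.create<onnx_mlir::zhigh::ZHighReshapeOp>(
    loc, source, shapeVal, layoutAttr);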
20 changes: 0 additions & 20 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp
@@ -636,25 +636,5 @@ IntegerAttr getDefaultSaturation(PatternRewriter &rewriter) {
return IntegerAttr();
}

// Create an array tensor to contain three dimensions of layout 3DS.
// The tensor is created from 4DS's shape by removing the value 1 at axis 1.
// e.g. 4DS tensor: tensor<3, 1, 4, 5>,
// this function returns a tensor: tensor<3xi64> = [3, 4, 5]
Value create3DSShapeFrom4DS(OpBuilder &builder, Location loc, Value val4DS) {
OnnxBuilder create(builder, loc);
ArrayRef<int64_t> shape4DS = getShape(val4DS.getType());
assert(shape4DS.size() == 4 && "The tensor must have rank of 4");
assert(shape4DS[1] == 1 && "The second dim must be 1");
if (hasStaticShape(val4DS.getType())) {
return create.constantInt64(
ArrayRef<int64_t>{shape4DS[0], shape4DS[2], shape4DS[3]});
}
Value dim0 = create.dim(val4DS, 0);
Value dim1 = create.dim(val4DS, 2);
Value dim2 = create.dim(val4DS, 3);
return create.concat(
RankedTensorType::get({3}, builder.getI64Type()), {dim0, dim1, dim2}, 0);
}

} // namespace zhigh
} // namespace onnx_mlir
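The create3DSShapeFrom4DS helper removed above is superseded by the permuting OnnxBuilder::shape overload added to src/Dialect/ONNX/DialectBuilder.cpp later in this commit; its TableGen wrapper in OpHelper.td now expands to roughly the following (a sketch; `builder`, `loc`, and `val4DS` are assumed):

// Drop the size-1 axis 1 of a 4DS shape by keeping dims {0, 2, 3}.
mlir::Value shape3DS = onnx_mlir::OnnxBuilder(builder, loc).shape(val4DS, {0, 2, 3});
// A static tensor<3x1x4x5xf32> folds to the i64 constant [3, 4, 5]; dynamic
// shapes instead emit onnx.Dim ops concatenated into a tensor<3xi64>.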
7 changes: 0 additions & 7 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp
@@ -110,13 +110,6 @@ bool hasNNPAUse(mlir::Value v);
/// Get saturation settings.
mlir::IntegerAttr getDefaultSaturation(mlir::PatternRewriter &rewriter);

/// Create an array tensor to contain three dimensions of layout 3DS.
/// The tensor is created from 4DS's shape by removing the value 1 at axis 1.
/// e.g. 4DS tensor: tensor<3, 1, 4, 5>,
/// this function returns a tensor: tensor<3xi64> = [3, 4, 5]
mlir::Value create3DSShapeFrom4DS(
mlir::OpBuilder &builder, mlir::Location loc, mlir::Value val3DS);

} // namespace zhigh
} // namespace onnx_mlir
#endif
43 changes: 36 additions & 7 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td
@@ -41,7 +41,11 @@ def IsNoneType : Constraint<CPred<"mlir::isa<NoneType>(($_self).getType())">>;

// Create an ONNX Shape Op with type
def CreateShapeOp: NativeCodeCall<
"$_builder.create<mlir::ONNXShapeOp>($_loc, $0, $1, IntegerAttr(), 0)"
"::onnx_mlir::OnnxBuilder($_builder, $_loc).shape($0);"
>;

def Create3DShapePermuteRightmostOp: NativeCodeCall<
"::onnx_mlir::OnnxBuilder($_builder, $_loc).shape($0, {0, 2, 1});"
>;

// Get a type for a tensor that stores the shape of another tensor.
@@ -248,14 +252,39 @@ def GetDefaultSaturation : NativeCodeCall<
class IsConstOf<int v>: Constraint<
CPred<"onnx_mlir::isConstOf($0, " # v # ")">,
"Value is a scalar constant of v"
>;

// Create an array tensor to contain three dimensions of layout 3DS.
// The tensor is created from 4DS's shape by removing the value 1 at axis 1.
// e.g. 4DS tensor: tensor<3, 1, 4, 5>,
// this function returns a tensor: tensor<3xi64> = [3, 4, 5]
def Create3DSShapeFrom4DS: NativeCodeCall<
"::onnx_mlir::OnnxBuilder($_builder, $_loc).shape($0, {0, 2, 3});"
>;

def GetNullAttr : NativeCodeCall<"Attribute()">;

def GetZeroI64Attr: NativeCodeCall<
"IntegerAttr::get($_builder.getIntegerType(64, /*isSigned=*/true), APInt(64, 0, /*isSigned=*/true))"
>;

def GetOneI64Attr: NativeCodeCall<
"IntegerAttr::get($_builder.getIntegerType(64, /*isSigned=*/true), APInt(64, 1, /*isSigned=*/true))"
>;

class IsInt64NAttr<int n> : Constraint<
CPred<"$0.getValue().getSExtValue() == " # n>,
"The signed extended int64 attribute equal to N"
>;

def IsCompatibleWithNNPALevelArch14: Constraint<
CPred<"isCompatibleWithNNPALevel(NNPALevel::M14)">,
"Input level is compatible with NNPA level"
>;

// Create an array tensor to contain three dimensions of layout 3DS.
// The tensor is created from 4DS's shape by removing the value 1 at axis 1.
// e.g. 4DS tensor: tensor<3, 1, 4, 5>,
// this function returns a tensor: tensor<3xi64> = [3, 4, 5]
def Create3DSShapeFrom4DS: NativeCodeCall<
"::onnx_mlir::zhigh::create3DSShapeFrom4DS($_builder, $_loc, $0)"
>;

def IsCompatibleWithNNPALevelArch15: Constraint<
CPred<"isCompatibleWithNNPALevel(NNPALevel::M15)">,
"Input level is compatible with NNPA level"
>;

#endif // OP_HELPER
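For illustration, the new Create3DShapePermuteRightmostOp above maps to the same builder overload with the rightmost two dims swapped (a sketch; `rewriter`, `loc`, and a 3-D value `v` are assumed):

// For a static v of type tensor<3x4x5xf32> this folds to the constant [3, 5, 4].
mlir::Value swapped = onnx_mlir::OnnxBuilder(rewriter, loc).shape(v, {0, 2, 1});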
2 changes: 2 additions & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp
@@ -12,7 +12,9 @@
//===----------------------------------------------------------------------===//

#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp"
#include "src/Accelerators/NNPA/Support/NNPALimit.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"

using namespace mlir;
11 changes: 7 additions & 4 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/ZHighStick.td
@@ -183,10 +183,11 @@ def ReshapeTransposeReshapeRoberta3DSWPattern1 : Pat<
$shape2, $_),
$layout3DS, $saturation),
// Output: initial X value unchanged, but transformed with the new compatible shape.
(ZHighReshapeOp $X, (CreateShapeOp (GetShapeTypeOf $stick), $stick), (GetLayout $stick)),
(ZHighReshapeOp $X, (CreateShapeOp $stick), (GetLayout $stick)),
// Conditions.
[(TensorHas3DSLayout $X), (Is3DSLayout $layout3DS), // Input/output are 3DS.
(IsStaticShapeTensor $X), (IsStaticShapeTensor $unstick), // Static shapes only.
(IsStaticShapeTensor $X), (IsStaticShapeTensor $unstick),
// Static shapes only.
(IsStaticShapeTensor $reshape1), (IsStaticShapeTensor $transpose),
(IsStaticShapeTensor $reshape2),(IsStaticShapeTensor $stick),
(IsShapeDimMultipleOf32<1> $X), // Second dim of input is a multiple of 32.
@@ -222,10 +223,11 @@
$shape2, $_),
$layout3DS, $saturation),
// Output: initial X value unchanged, but transformed with the compatible shape.
(ZHighReshapeOp $X, (CreateShapeOp (GetShapeTypeOf $stick), $stick), (GetLayout $stick)),
(ZHighReshapeOp $X, (CreateShapeOp $stick), (GetLayout $stick)),
// Conditions.
[(TensorHas3DSLayout $X), (Is3DSLayout $layout3DS), // Input/output are 3DS.
(IsStaticShapeTensor $X), (IsStaticShapeTensor $unstick), // Static shapes only.
(IsStaticShapeTensor $X), (IsStaticShapeTensor $unstick),
// Static shapes only.
(IsStaticShapeTensor $reshape1), (IsStaticShapeTensor $transpose),
(IsStaticShapeTensor $reshape2),(IsStaticShapeTensor $stick),
(IsShapeDimMultipleOf32<1> $X), // Second dim of input is a multiple of 32.
@@ -317,6 +319,7 @@ def ReshapeTransposeReshape3DSTo2DPattern : Pat<
]
>;


// Pattern in the CCFD model.
// 4DS and 3DS have exactly same data values when the second dim of 4DS is 1.
// (The second dim of 4DS indicates unidirectional or bidirectional LSTM/GRU/RNN,
34 changes: 34 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.cpp
@@ -376,6 +376,40 @@ Value OnnxBuilder::shape(
toTensor(outputType), toTensor(input), endAttr, startAttr);
}

// Get the shape of an input and perform a permutation on it. Perm values are
// in the range [0, rank(input)). The result type is inferred. For dynamic
// inputs, the operation gets the dimensions using onnx.Dim and uses
// onnx.Concat to place the right value at the right position.
Value OnnxBuilder::shape(Value input, mlir::ArrayRef<int64_t> perm) const {
ShapedType inputType = mlir::cast<ShapedType>(input.getType());
int64_t inputRank = inputType.getRank();
auto inputShape = inputType.getShape();
int64_t permRank = perm.size();
bool isStatic = llvm::none_of(
inputShape, [](int64_t d) { return ShapedType::isDynamic(d); });
if (isStatic) {
// Static, no need to create dims. Gather shapes into a constant array.
llvm::SmallVector<int64_t, 4> permutedShapes;
for (int64_t p = 0; p < permRank; ++p) {
int64_t d = perm[p];
assert(d >= 0 && d < inputRank &&
"perm values expected in [0..rank(input))");
permutedShapes.emplace_back(inputShape[d]);
}
return constantInt64(permutedShapes);
}
// Dynamic shape: create the dims as needed and gather values in a concat.
llvm::SmallVector<Value, 4> permutedDims;
for (int64_t p = 0; p < permRank; ++p) {
int64_t d = perm[p];
assert(
d >= 0 && d < inputRank && "perm values expected in [0..rank(input))");
permutedDims.emplace_back(dim(input, d));
}
Type outputType = RankedTensorType::get({permRank}, b().getI64Type());
return concat(outputType, permutedDims, 0);
}

Value OnnxBuilder::slice(Type outputType, Value input, Value starts, Value ends,
Value axes, Value steps) const {
return createTypedOpAndInferShapes<ONNXSliceOp>(toTensor(outputType),
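A short usage sketch of the dynamic path of the overload implemented above (the value `xDyn` and its type are assumptions for illustration):

// For xDyn of type tensor<?x3x?xf32>, the builder emits one onnx.Dim per
// requested axis and concatenates them into a tensor<3xi64> holding
// [dim 0, dim 2, dim 1] of xDyn.
mlir::Value dynShape = onnx_mlir::OnnxBuilder(rewriter, loc).shape(xDyn, {0, 2, 1});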
8 changes: 7 additions & 1 deletion src/Dialect/ONNX/DialectBuilder.hpp
@@ -184,12 +184,18 @@ struct OnnxBuilder : DialectBuilder {

// ONNXShapeOp (start is inclusive, default 0; end is exclusive, default
// nullptr means all)
mlir::Value shape(mlir::Value input) const;
mlir::Value shape(mlir::Value input) const; // Infer the type.
mlir::Value shape(mlir::Type outputType, mlir::Value input) const;
mlir::Value shape(
mlir::Type outputType, mlir::Value input, int64_t start) const;
mlir::Value shape(mlir::Type outputType, mlir::Value input, int64_t start,
int64_t end) const;
// Get the shape of an input and permute the positions of its shape dims. Perm
// values are in the range [0, rank(input)). Say a 4D input has dims (d0, d1,
// d2, d3): a call to shape(input, {0, 1, 3, 2}) produces a tensor with values
// [d0, d1, d3, d2], and a call to shape(input, {0, 2, 3}) produces a shape of
// reduced dimensionality (4D->3D) with dims [d0, d2, d3].
mlir::Value shape(mlir::Value input, mlir::ArrayRef<int64_t> perm) const;

// ONNXSliceOp
mlir::Value slice(mlir::Type outputType, mlir::Value input,
Expand Down
3 changes: 2 additions & 1 deletion src/Dialect/ONNX/ONNXOps/OpHelper.cpp
@@ -18,6 +18,7 @@
#include "llvm/Support/Path.h"

#include "src/Dialect/Mlir/IndexExpr.hpp"
#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXLayoutHelper.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
@@ -877,7 +878,7 @@ std::string getNodeNameInPresenceOfOpt(Operation *op, bool useFileLine) {
// Support for DenseElementsAttr.
//===----------------------------------------------------------------------===//

bool isElementAttrUninitializedDenseResource(mlir::ElementsAttr elementsAttr) {
bool isElementAttrUninitializedDenseResource(ElementsAttr elementsAttr) {
const auto denseResourceElementsAttr =
mlir::dyn_cast<DenseResourceElementsAttr>(elementsAttr);
return denseResourceElementsAttr &&