Skip to content

Commit

Permalink
test case (#2957)
Browse files Browse the repository at this point in the history
Signed-off-by: chentong319 <[email protected]>
  • Loading branch information
chentong319 authored Sep 30, 2024
1 parent e5901e2 commit f8e2466
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gru.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,34 @@ func.func @test_onnx_to_zhigh_gru0_bidir_dyn(%X: tensor<?x?x?xf32>, %W: tensor<2

// -----

func.func @gru_with_len(%arg0: tensor<2x2x1xf32>, %arg1: tensor<1x3x1xf32>, %arg2 : tensor<1x3x1xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%lens = onnx.Constant dense<[2, 1]> : tensor<2xi32>
%cst = "onnx.NoValue"() {value} : () -> none
%res:2 = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %lens, %cst) {layout = 0 : si64, linear_before_reset = 1 : si64}
: ( tensor<2x2x1xf32>, tensor<1x3x1xf32>, tensor<1x3x1xf32>, none, tensor<2xi32>, none) -> (tensor<*xf32>, tensor<*xf32>)
onnx.Return %res#0, %res#1 : tensor<*xf32>, tensor<*xf32>

// CHECK-LABEL: func.func @gru_with_len
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x2x1xf32>, [[PARAM_1_:%.+]]: tensor<1x3x1xf32>, [[PARAM_2_:%.+]]: tensor<1x3x1xf32>) -> (tensor<2x1x2x1xf32>, tensor<1x2x1xf32>) {
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 1]> : tensor<2xi32>
// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<2x2x1xf32>) -> tensor<2x2x1xf16, #zhigh.layout<{dataLayout = "3DS"}>>
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [0, 2, 1]} : (tensor<1x3x1xf32>) -> tensor<1x1x3xf32>
// CHECK: [[VAR_4_:%.+]]:3 = "onnx.SplitV11"([[VAR_3_]]) {axis = 2 : si64} : (tensor<1x1x3xf32>) -> (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>)
// CHECK-DAG: [[VAR_5_:%.+]] = "zhigh.StickForGRU"([[VAR_4_]]#0, [[VAR_4_]]#1, [[VAR_4_]]#2) : (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>) -> tensor<*xf16>
// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Transpose"([[PARAM_2_]]) {perm = [0, 2, 1]} : (tensor<1x3x1xf32>) -> tensor<1x1x3xf32>
// CHECK: [[VAR_7_:%.+]]:3 = "onnx.SplitV11"([[VAR_6_]]) {axis = 2 : si64} : (tensor<1x1x3xf32>) -> (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>)
// CHECK: [[VAR_8_:%.+]] = "zhigh.StickForGRU"([[VAR_7_]]#0, [[VAR_7_]]#1, [[VAR_7_]]#2) : (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>) -> tensor<*xf16>
// CHECK: [[VAR_9_:%.+]] = "zhigh.GRU"([[VAR_2_]], [[VAR_1_]], [[VAR_5_]], [[VAR_1_]], [[VAR_8_]], [[VAR_1_]]) {direction = "forward", hidden_size = 1 : si64, return_all_steps = -1 : si64} : (tensor<2x2x1xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none, tensor<*xf16>, none, tensor<*xf16>, none) -> tensor<*xf16>
// CHECK: [[VAR_10_:%.+]] = "zhigh.Unstick"([[VAR_9_]]) : (tensor<*xf16>) -> tensor<2x1x2x1xf32>
// CHECK-DAG: [[VAR_11_:%.+]] = "zhigh.FixGRUY"([[VAR_10_]], [[VAR_0_]], [[VAR_1_]]) : (tensor<2x1x2x1xf32>, tensor<2xi32>, none) -> tensor<2x1x2x1xf32>
// CHECK-DAG: [[VAR_12_:%.+]] = "zhigh.FixGRUYh"([[VAR_10_]], [[VAR_0_]]) : (tensor<2x1x2x1xf32>, tensor<2xi32>) -> tensor<1x2x1xf32>
// CHECK: onnx.Return [[VAR_11_]], [[VAR_12_]] : tensor<2x1x2x1xf32>, tensor<1x2x1xf32>
// CHECK: }
}

// -----

// COM : Maximum hidden_size in GRU is 10880. Not lowered when using 10881.

func.func @test_onnx_to_zhigh_gru_exceed_num_hidden(%X: tensor<7x2000x204xf32>, %W: tensor<1x16384x204xf32>, %R: tensor<1x16384x10881xf32>, %B: tensor<1x16386xf32>) -> (tensor<7x1x2000x10881xf32>, tensor<1x2000x10881xf32>) {
Expand Down

0 comments on commit f8e2466

Please sign in to comment.