From 34d70f5fae524f0f0647a3909d6d1bcc4d8fa991 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Wed, 8 Jan 2025 10:15:55 +0800 Subject: [PATCH] [QNN] MatMul Op Builder to Handle All Cases of ONNX's MatMul (#22639) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ONNX's MatMul is same as numpy.matmul, which supports input tensors with rank >= 1. But QNN's MatMul can only support input tensors with rank >= 2. This PR is to add MatMulOpBuilder for QNN EP to build QNN graph to support all possible cases of ONNX's MatMul, by adding Reshape nodes if necessary, e.g., if Reshape 1D input to 2D if exists, and Reshape output to expected shape at the end.   This PR also tries to use FullyConnected Op for MatMul if 2nd input is 2D initializer or 1D tensor because FullyConnected is faster than MatMul on QNN EP. If 2nd input is 2D tensor, we require it an initializer because FullyConnected requires 2nd input in [n, k] shape, we can transpose it when graph building if it's an initializer (we don't want to add extra Transpose node). Use swin_base model as example, which contains several MatMul nodes with 2nd input is 2D initializer (not followed by Add), running on Gen3 mobile device, before the change, it takes 34.8876 ms, after this change, it's 27.0639 ms. --- .../qnn/builder/op_builder_factory.cc | 5 +- .../qnn/builder/op_builder_factory.h | 2 + .../builder/opbuilder/matmul_op_builder.cc | 227 ++++++++++ .../qnn/builder/qnn_model_wrapper.cc | 48 +- .../providers/qnn/builder/qnn_model_wrapper.h | 11 + .../test/providers/qnn/matmul_test.cpp | 417 +++++++----------- 6 files changed, 436 insertions(+), 274 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index 6ef17b40d274b..e411c2a6bf536 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -51,7 +51,6 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateSimpleOpBuilder("Sub", *this); CreateSimpleOpBuilder("Tanh", *this); - CreateSimpleOpBuilder("MatMul", *this); CreateSimpleOpBuilder("Concat", *this); CreateSimpleOpBuilder("QuantizeLinear", *this); @@ -170,6 +169,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() { { CreateExpandOpBuilder("Expand", *this); } + + { + CreateMatMulOpBuilder("MatMul", *this); + } } const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) { diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index 1d3faba6bc69a..e11eae84341fe 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -96,5 +96,7 @@ void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateHardSigmoidOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + +void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc new file mode 100644 index 0000000000000..bac08f1993f47 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc @@ -0,0 +1,227 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/common.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace qnn { + +/** + * ONNX's MatMul supports 1D tensor as input on both size, but neither QNN's MatMul nor FullyConnected supports it. + * So we need to add Reshape Ops if necessary. + * In two cases, FullyConnected (input_1's shape is [n, k]) is used instead of MatMul without extra Transpose Op: + * 1. input_1 is 2D initializer. + * 2. input_1 is 1D tensor. + */ +class MatMulOpBuilder : public BaseOpBuilder { + public: + MatMulOpBuilder() : BaseOpBuilder("MatMulOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MatMulOpBuilder); + + protected: + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger, + std::vector& input_names, bool do_op_validation) const override ORT_MUST_USE_RESULT; + + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, + std::vector&& input_names, const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; +}; + +namespace { + +Status CheckInputs(const QnnModelWrapper& qnn_model_wrapper, const NodeUnitIODef& input_def_0, + const NodeUnitIODef& input_def_1, TensorInfo& input_info_0, TensorInfo& input_info_1, + bool& use_fully_connected) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input_def_0, input_info_0)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input_def_1, input_info_1)); + + // Use FullyConnected if 2nd input is 2D initializer or 1D tensor. + // FullyConnected cannot pass the Op validation if keep_dims is true, so if input_0 is per-channel quantized tensor + // with rank > 2, it's not easy to set the quantization parameters for the output reshaped 2D tensor. + // In this case, we will not use FullyConnected. + use_fully_connected = + (input_info_1.shape.size() == 2 && input_info_1.is_initializer) || input_info_1.shape.size() == 1; + use_fully_connected = + use_fully_connected && !(input_info_0.quant_param.IsPerChannel() && input_info_0.shape.size() > 2); + return Status::OK(); +} + +} // namespace + +Status MatMulOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, + const logging::Logger& logger, std::vector& input_names, + bool do_op_validation) const { + const auto& inputs = node_unit.Inputs(); + TensorInfo input_info_0{}; + TensorInfo input_info_1{}; + bool use_fully_connected = false; + ORT_RETURN_IF_ERROR( + CheckInputs(qnn_model_wrapper, inputs[0], inputs[1], input_info_0, input_info_1, use_fully_connected)); + bool reshape_input_0 = input_info_0.shape.size() == 1; + bool reshape_input_1 = input_info_1.shape.size() == 1; + + // Process input 0. + const std::string& org_input_0_name = inputs[0].node_arg.Name(); + std::string input_0_name = org_input_0_name; + if (reshape_input_0) { + input_0_name = org_input_0_name + "_ort_qnn_ep_reshape"; + std::vector shape_2d{1, input_info_0.shape[0]}; + QnnQuantParamsWrapper quant_param_2d = input_info_0.quant_param.Copy(); + ORT_RETURN_IF_ERROR(quant_param_2d.HandleUnsqueeze(input_info_0.shape, shape_2d)); + + // If input_0 is initializer, unpack it and add the tensor with new quantization parameter and shape. + // Otherwise, add a Reshape node. + if (input_info_0.is_initializer) { + std::vector unpacked_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info_0.initializer_tensor, unpacked_tensor)); + Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(org_input_0_name); + QnnTensorWrapper input_tensorwrapper(input_0_name, tensor_type, input_info_0.qnn_data_type, + std::move(quant_param_2d), std::move(shape_2d), std::move(unpacked_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + } else { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(org_input_0_name, input_0_name, input_info_0.shape, shape_2d, + input_info_0.qnn_data_type, input_info_0.quant_param, + quant_param_2d, do_op_validation, + qnn_model_wrapper.IsGraphInput(org_input_0_name), false)); + } + } else { + if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_0_name)) { + LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_0_name; + } else { + QnnTensorWrapper input_0_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[0], input_0_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_0_tensor)), "Failed to add tensor."); + } + } + input_names.emplace_back(input_0_name); + + // Process input 1. + const std::string& org_input_1_name = inputs[1].node_arg.Name(); + std::string input_1_name = org_input_1_name; + if (reshape_input_1 || use_fully_connected) { + std::vector shape_2d; + QnnQuantParamsWrapper quant_param_2d = input_info_1.quant_param.Copy(); + if (reshape_input_1) { + // Input is 1D tensor. + input_1_name = org_input_1_name + "_ort_qnn_ep_reshape"; + if (use_fully_connected) { + // FullyConnected requires input_1's shape to be [n, k]. + shape_2d = {1, input_info_1.shape[0]}; + } else { + shape_2d = {input_info_1.shape[0], 1}; + } + ORT_RETURN_IF_ERROR(quant_param_2d.HandleUnsqueeze(input_info_1.shape, shape_2d)); + } else { + input_1_name = org_input_1_name + "_ort_qnn_ep_transpose"; + shape_2d = {input_info_1.shape[1], input_info_1.shape[0]}; + ORT_RETURN_IF_ERROR(quant_param_2d.HandleTranspose(std::vector({1, 0}))); + } + + // If input_1 is initializer, unpack it and add the tensor with new quantization parameter and shape. + // Otherwise, add a Reshape node. + if (input_info_1.is_initializer) { + std::vector unpacked_tensor; + if (use_fully_connected && !reshape_input_1) { + // 2D initializer should be transposed to [n, k]. + ORT_RETURN_IF_ERROR(TwoDimensionTranspose(qnn_model_wrapper, input_info_1.shape, + *input_info_1.initializer_tensor, unpacked_tensor)); + } else { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info_1.initializer_tensor, unpacked_tensor)); + } + + Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(org_input_1_name); + QnnTensorWrapper input_tensorwrapper(input_1_name, tensor_type, input_info_1.qnn_data_type, + std::move(quant_param_2d), std::move(shape_2d), std::move(unpacked_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + } else { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(org_input_1_name, input_1_name, input_info_1.shape, shape_2d, + input_info_1.qnn_data_type, input_info_1.quant_param, + quant_param_2d, do_op_validation, + qnn_model_wrapper.IsGraphInput(org_input_1_name), false)); + } + } else { + if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_1_name)) { + LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_1_name; + } else { + QnnTensorWrapper input_1_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[1], input_1_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_1_tensor)), "Failed to add tensor."); + } + } + input_names.emplace_back(input_1_name); + + return Status::OK(); +} + +Status MatMulOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& /*logger*/, bool do_op_validation) const { + const auto& inputs = node_unit.Inputs(); + TensorInfo input_info_0{}; + TensorInfo input_info_1{}; + bool use_fully_connected = false; + ORT_RETURN_IF_ERROR( + CheckInputs(qnn_model_wrapper, inputs[0], inputs[1], input_info_0, input_info_1, use_fully_connected)); + bool reshape_input_0 = input_info_0.shape.size() == 1; + bool reshape_input_1 = input_info_1.shape.size() == 1; + bool reshape_output = reshape_input_0 || reshape_input_1 || (use_fully_connected && input_info_0.shape.size() > 2); + + const std::string& org_output_name = node_unit.Outputs()[0].node_arg.Name(); + std::string op_output_name = org_output_name; + TensorInfo output_info{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info)); + std::vector op_output_shape = output_info.shape; + QnnQuantParamsWrapper op_output_quant_param = output_info.quant_param.Copy(); + if (reshape_output) { + op_output_name = org_output_name + "_ort_qnn_ep_reshape"; + if (use_fully_connected && input_info_0.shape.size() > 2) { + op_output_shape = {std::accumulate(input_info_0.shape.begin(), input_info_0.shape.end() - 1, + static_cast(1), std::multiplies()), + reshape_input_1 ? 1 : input_info_1.shape.back()}; + ORT_ENFORCE(!op_output_quant_param.IsPerChannel()); + } else { + // If both inputs are 1D tensors, the output shape is [1] instead of scalar. So if both inputs are 1D tensors, + // we only need to add one "1" to the op_output_shape. + if (reshape_input_1) { + op_output_shape.emplace_back(1); + } else if (reshape_input_0) { + op_output_shape.insert(op_output_shape.end() - 1, 1); + } + ORT_RETURN_IF_ERROR(op_output_quant_param.HandleUnsqueeze(output_info.shape, op_output_shape)); + } + } + + const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(org_output_name); + const bool is_op_output_graph_output = is_graph_output && !reshape_output; + Qnn_TensorType_t op_output_tensor_type = + is_op_output_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + QnnTensorWrapper op_output_tensor_wrapper(op_output_name, op_output_tensor_type, output_info.qnn_data_type, + op_output_quant_param.Copy(), std::vector(op_output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(op_output_tensor_wrapper)), + "Failed to add output tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, + use_fully_connected ? QNN_OP_FULLY_CONNECTED : QNN_OP_MAT_MUL, + std::move(input_names), {op_output_name}, {}, do_op_validation), + "Failed to add fused Matmul node."); + + if (reshape_output) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode( + op_output_name, org_output_name, op_output_shape, output_info.shape, output_info.qnn_data_type, + op_output_quant_param, output_info.quant_param, do_op_validation, false, is_graph_output)); + } + + return Status::OK(); +} + +void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 2c7f3c8b22ddd..129a015164ad4 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -495,49 +495,45 @@ Status QnnModelWrapper::GetTensorInfo(const NodeUnitIODef& input, TensorInfo& te return Status::OK(); } -Status QnnModelWrapper::AddReshapeNode(const std::string& input_name, - const std::string& output_name, +Status QnnModelWrapper::AddReshapeNode(const std::string& input_name, const std::string& output_name, const std::vector& input_shape, const std::vector& output_shape, const Qnn_DataType_t& tensor_data_type, - const QnnQuantParamsWrapper& quantize_param, - bool do_op_validation, - bool is_for_input, - bool is_for_output) { - // Do not allow QNN EP to insert Reshape nodes with per-channel quantization on dynamic tensors. - // We could technically support this by shifting the quantization param's axis value, but - // we don't need this right now. - ORT_RETURN_IF(quantize_param.IsPerChannel(), - "Do not support inserted Reshape nodes with per-channel quantization"); - QnnTensorWrapper input_tensorwrapper(input_name, - is_for_input ? QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_NATIVE, - tensor_data_type, - quantize_param.Copy(), + const QnnQuantParamsWrapper& input_quantize_param, + const QnnQuantParamsWrapper& output_quantize_param, bool do_op_validation, + bool is_for_input, bool is_for_output) { + QnnTensorWrapper input_tensorwrapper(input_name, is_for_input ? QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_NATIVE, + tensor_data_type, input_quantize_param.Copy(), std::vector(input_shape)); ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(input_tensorwrapper)), "QNN EP: Failed to add input tensor for inserted Reshape."); Qnn_TensorType_t tensor_type = is_for_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; - QnnTensorWrapper output_tensorwrapper(output_name, - tensor_type, - tensor_data_type, - quantize_param.Copy(), + QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, tensor_data_type, output_quantize_param.Copy(), std::vector(output_shape)); ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(output_tensorwrapper)), "QNN EP: Failed to add output tensor for inserted Reshape."); - ORT_RETURN_IF_NOT(CreateQnnNode(output_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_RESHAPE, - {input_name}, - {output_name}, - {}, - do_op_validation), + ORT_RETURN_IF_NOT(CreateQnnNode(output_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_RESHAPE, {input_name}, + {output_name}, {}, do_op_validation), "QNN EP: Failed to create manually inserted Qnn Reshape node."); return Status::OK(); } +Status QnnModelWrapper::AddReshapeNode(const std::string& input_name, const std::string& output_name, + const std::vector& input_shape, + const std::vector& output_shape, + const Qnn_DataType_t& tensor_data_type, + const QnnQuantParamsWrapper& quantize_param, bool do_op_validation, + bool is_for_input, bool is_for_output) { + // Do not allow QNN EP to insert Reshape nodes with per-channel quantization on dynamic tensors + // if only one quantization param is provided. + ORT_RETURN_IF(quantize_param.IsPerChannel(), "Do not support inserted Reshape nodes with per-channel quantization"); + return AddReshapeNode(input_name, output_name, input_shape, output_shape, tensor_data_type, quantize_param, + quantize_param, do_op_validation, is_for_input, is_for_output); +} + Status QnnModelWrapper::AddTransposeNode(NodeIndex node_index, const std::string& input_name, const std::string& output_name, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index f3e52050e79e0..19d0f058116a4 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -141,6 +141,17 @@ class QnnModelWrapper { Status GetTensorInfo(const NodeUnitIODef& input, TensorInfo& input_info) const; + Status AddReshapeNode(const std::string& input_name, + const std::string& output_name, + const std::vector& input_shape, + const std::vector& output_shape, + const Qnn_DataType_t& tensor_data_type, + const QnnQuantParamsWrapper& input_quantize_param, + const QnnQuantParamsWrapper& output_quantize_param, + bool do_op_validation, + bool is_for_input = true, + bool is_for_output = false); + Status AddReshapeNode(const std::string& input_name, const std::string& output_name, const std::vector& input_shape, diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index 5c6967761b1db..74edc25939e00 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -26,6 +26,31 @@ static GetTestModelFn BuildMatMulOpTestCase(const TestInputDef& input1_de }; } +static void RunMatMulOpTest(bool is_htp_backend, const std::vector& shape_0, + const std::vector& shape_1, bool is_initializer_0, bool is_initializer_1, + ExpectedEPNodeAssignment expected_ep_assignment = ExpectedEPNodeAssignment::All, + int opset = 18, float f32_abs_err = 1e-4f) { + ProviderOptions provider_options; + if (is_htp_backend) { +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + } else { +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + } + + RunQnnModelTest(BuildMatMulOpTestCase( + TestInputDef(shape_0, is_initializer_0, GetSequentialFloatData(shape_0, 0.01f, 0.02f)), + TestInputDef(shape_1, is_initializer_1, GetSequentialFloatData(shape_1, 0.02f, 0.02f))), + provider_options, opset, expected_ep_assignment, f32_abs_err); +} + // Returns a function that creates a graph with a QDQ MatMul operator. template static GetTestQDQModelFn BuildMatMulOpQDQTestCase(const TestInputDef& input0_def, @@ -36,13 +61,13 @@ static GetTestQDQModelFn BuildMatMulOpQDQTestCase(const TestInputDe // input1 -> Q -> DQ -> NodeArg* input0 = MakeTestInput(builder, input0_def); QuantParams input0_qparams = GetTestInputQuantParams(input0_def); - auto* input0_qdq = AddQDQNodePair(builder, input0, input0_qparams.scale, input0_qparams.zero_point, - use_contrib_qdq); + auto* input0_qdq = + AddQDQNodePair(builder, input0, input0_qparams.scale, input0_qparams.zero_point, use_contrib_qdq); // input1 -> Q -> DQ -> NodeArg* input1 = MakeTestInput(builder, input1_def); QuantParams input1_qparams = GetTestInputQuantParams(input1_def); - auto* input1_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, input1_qparams.zero_point, - use_contrib_qdq); + auto* input1_qdq = + AddQDQNodePair(builder, input1, input1_qparams.scale, input1_qparams.zero_point, use_contrib_qdq); // MatMul auto* op_output = builder.MakeIntermediate(); @@ -59,16 +84,15 @@ static GetTestQDQModelFn BuildQDQPerChannelMatMulTestCase(const Tes const TestInputDef& weights_def, int64_t weight_quant_axis, bool use_contrib_qdq = false) { - return [input_def, weights_def, weight_quant_axis, - use_contrib_qdq](ModelTestBuilder& builder, - std::vector>& output_qparams) { + return [input_def, weights_def, weight_quant_axis, use_contrib_qdq]( + ModelTestBuilder& builder, std::vector>& output_qparams) { std::vector matmul_inputs; // input -> Q/DQ -> auto* input = MakeTestInput(builder, input_def); QuantParams input_qparams = GetTestInputQuantParams(input_def); - auto* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, - use_contrib_qdq); + auto* input_qdq = + AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, use_contrib_qdq); matmul_inputs.push_back(input_qdq); // Quantized(weights) -> DQ -> @@ -89,14 +113,13 @@ static GetTestQDQModelFn BuildQDQPerChannelMatMulTestCase(const Tes num_weight_storage_elems = Int4x2::CalcNumInt4Pairs(weights_shape.Size()); } quantized_weights.resize(num_weight_storage_elems); - QuantizeValues(weights_def.GetRawData(), quantized_weights, weights_shape, - weight_scales, weight_zero_points, pos_weight_quant_axis); + QuantizeValues(weights_def.GetRawData(), quantized_weights, weights_shape, weight_scales, + weight_zero_points, pos_weight_quant_axis); NodeArg* weights_initializer = builder.MakeInitializer(weights_def.GetShape(), quantized_weights); NodeArg* weights_dq = builder.MakeIntermediate(); - Node& weights_dq_node = builder.AddDequantizeLinearNode(weights_initializer, weight_scales, - weight_zero_points, weights_dq, - nullptr, use_contrib_qdq); + Node& weights_dq_node = builder.AddDequantizeLinearNode( + weights_initializer, weight_scales, weight_zero_points, weights_dq, nullptr, use_contrib_qdq); weights_dq_node.AddAttribute("axis", weight_quant_axis); matmul_inputs.push_back(weights_dq); @@ -108,17 +131,11 @@ static GetTestQDQModelFn BuildQDQPerChannelMatMulTestCase(const Tes }; } -// Runs a QDQ per-channel MatMul model on the QNN HTP backend. Checks the graph node assignment, and that the -// QDQ model is accurate on QNN EP (compared to CPU EP). -template -static void RunQDQPerChannelMatMulOpOpTest(const TestInputDef& input_def, - const TestInputDef& weights_def, - int64_t weight_quant_axis, - ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 21, - bool use_contrib_qdq = false, - QDQTolerance tolerance = QDQTolerance(), - bool enable_fp16_precision = true) { +template +static void RunQDQMatMulOpTest(const std::vector& shape_0, const std::vector& shape_1, + bool is_initializer_0, bool is_initializer_1, + ExpectedEPNodeAssignment expected_ep_assignment = ExpectedEPNodeAssignment::All, + int opset = 21, bool use_contrib_qdq = false) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -126,52 +143,29 @@ static void RunQDQPerChannelMatMulOpOpTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif - if (enable_fp16_precision) { - provider_options["enable_htp_fp16_precision"] = "1"; - } else { - provider_options["enable_htp_fp16_precision"] = "0"; - } - - TestQDQModelAccuracy(BuildMatMulOpTestCase(input_def, weights_def), - BuildQDQPerChannelMatMulTestCase(input_def, - weights_def, - weight_quant_axis, - use_contrib_qdq), - provider_options, - opset, - expected_ep_assignment, - tolerance); -} - -// Runs an MatMul model on the QNN CPU backend. Checks the graph node assignment, and that inference -// outputs for QNN and CPU match. -static void RunMatMulOpOpTest(const TestInputDef& input1_def, - const TestInputDef& input2_def, - ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 13, - float f32_abs_err = 1e-4f) { - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnCpu.dll"; -#else - provider_options["backend_path"] = "libQnnCpu.so"; -#endif - - RunQnnModelTest(BuildMatMulOpTestCase(input1_def, input2_def), - provider_options, - opset, - expected_ep_assignment, - f32_abs_err); + TestInputDef input0_def( + shape_0, is_initializer_0, + GetFloatDataInRange(-0.1f, 0.1f, + static_cast(std::accumulate(shape_0.begin(), shape_0.end(), static_cast(1), + std::multiplies())))); + TestInputDef input1_def( + shape_1, is_initializer_1, + GetFloatDataInRange(-0.1f, 0.1f, + static_cast(std::accumulate(shape_1.begin(), shape_1.end(), static_cast(1), + std::multiplies())))); + + TestQDQModelAccuracy( + BuildMatMulOpTestCase(input0_def, input1_def), + BuildMatMulOpQDQTestCase(input0_def, input1_def, use_contrib_qdq), + provider_options, opset, expected_ep_assignment); } -// Runs a QDQ MatMul model on the QNN HTP backend. Checks the graph node assignment, and that the -// QDQ model is accurate on QNN EP (compared to CPU EP). -template -static void RunQDQMatMulOpOpTest(const TestInputDef& input1_def, - const TestInputDef& input2_def, - ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 18, - bool use_contrib_qdq = false) { +template +static void RunQDQPerChannelMatMulOpTest( + const std::vector& shape_input, const std::vector& shape_weight, int64_t weight_quant_axis, + QDQTolerance tolerance = QDQTolerance(), + ExpectedEPNodeAssignment expected_ep_assignment = ExpectedEPNodeAssignment::All, int opset = 21, + bool use_contrib_qdq = false, bool enable_fp16_precision = true) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -179,197 +173,126 @@ static void RunQDQMatMulOpOpTest(const TestInputDef& input1_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildMatMulOpTestCase(input1_def, input2_def), - BuildMatMulOpQDQTestCase(input1_def, input2_def, - use_contrib_qdq), - provider_options, - opset, - expected_ep_assignment); + if (enable_fp16_precision) { + provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; + } + + TestInputDef input_def( + shape_input, false, + GetFloatDataInRange(-0.1f, 0.1f, + static_cast(std::accumulate(shape_input.begin(), shape_input.end(), + static_cast(1), std::multiplies())))); + TestInputDef weight_def( + shape_weight, true, + GetFloatDataInRange(-0.1f, 0.1f, + static_cast(std::accumulate(shape_weight.begin(), shape_weight.end(), + static_cast(1), std::multiplies())))); + + TestQDQModelAccuracy(BuildMatMulOpTestCase(input_def, weight_def), + BuildQDQPerChannelMatMulTestCase( + input_def, weight_def, weight_quant_axis, use_contrib_qdq), + provider_options, opset, expected_ep_assignment, tolerance); } // // CPU tests: // - -// TODO: Crashes during QNN CPU execution (QNN SDK 2.22) -TEST_F(QnnCPUBackendTests, DISABLED_MatMulOp) { - RunMatMulOpOpTest(TestInputDef({2, 3}, false, {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}), - TestInputDef({3, 2}, false, {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}), - ExpectedEPNodeAssignment::All, 18); -} - -// Test MatMul broadcasting -// Failed randomly on Linux -// Value of: expected_tensor.DataAsSpan() -// Expected: contains 896 values, where each value and its corresponding value in 16-byte object -// <80-03 00-00 00-00 00-00 40-B8 53-08 CC-7F 00-00> are an almost-equal pair -// Actual: 16-byte object <80-03 00-00 00-00 00-00 C0-B7 43-08 CC-7F 00-00>, where the value pair -// (-5.19657087, 0) at index #29 don't match, which is 5.19657 from -5.19657 -TEST_F(QnnCPUBackendTests, DISABLED_MatMulOp_Broadcast) { - // Create two matrices with element values in the range [-10.0, 10.0]. - std::vector input_a = GetFloatDataInRange(-10.0f, 10.0f, 28 * 64); - std::vector input_b = GetFloatDataInRange(-10.0f, 10.0f, 64 * 32); - - RunMatMulOpOpTest(TestInputDef({28, 1, 64}, false, input_a), - TestInputDef({64, 32}, false, input_b), - ExpectedEPNodeAssignment::All, 18, 0.0004f); -} - -// TODO: Crashes during QNN CPU execution (QNN SDK 2.22) -TEST_F(QnnCPUBackendTests, DISABLED_MatMulOp_PaddingAndBroadcast_BLargerThanA) { - std::vector input0_shape = {2, 3, 2}; - std::vector input1_shape = {3, 2, 2, 1}; - RunMatMulOpOpTest(TestInputDef(input0_shape, false, GetSequentialFloatData(input0_shape)), - TestInputDef(input1_shape, false, GetSequentialFloatData(input1_shape)), - ExpectedEPNodeAssignment::All, 7); +TEST_F(QnnCPUBackendTests, MatMulOp) { + // RunMatMulOpTest(is_htp_backend, shape_0, shape_1, is_initializer_0, is_initializer_1) + RunMatMulOpTest(false, {2, 3}, {3, 2}, false, false); + RunMatMulOpTest(false, {2, 3}, {3, 2}, false, true); + RunMatMulOpTest(false, {2, 3}, {3, 2}, true, false); + RunMatMulOpTest(false, {2, 3}, {3, 2}, true, true); // constant folding + RunMatMulOpTest(false, {2, 3}, {2, 3, 2}, false, false); + RunMatMulOpTest(false, {3, 3, 3}, {3, 2}, true, false); + RunMatMulOpTest(false, {2, 3, 3, 3}, {3, 2}, false, true); + RunMatMulOpTest(false, {2, 3, 3, 3}, {2, 3, 3, 2}, false, true); + RunMatMulOpTest(false, {2, 1, 2, 3}, {3, 3, 2}, false, false); + RunMatMulOpTest(false, {3}, {3}, false, false); + RunMatMulOpTest(false, {3}, {3}, false, true); + RunMatMulOpTest(false, {3}, {3}, true, false); + RunMatMulOpTest(false, {3}, {3, 2}, false, false); + RunMatMulOpTest(false, {3}, {3, 2}, false, true); + RunMatMulOpTest(false, {3}, {3, 3, 2}, true, false); + RunMatMulOpTest(false, {2, 3}, {3}, false, false); + RunMatMulOpTest(false, {2, 3}, {3}, true, false); + RunMatMulOpTest(false, {2, 3, 3, 3}, {3}, false, false); + + // Failed randomly on Linux + // Expected: contains 36 values, where each value and its corresponding value in 16-byte object + // <24-00 00-00 00-00 00-00 40-4A 47-42 4D-56 00-00> are an almost-equal pair + // Actual: 16-byte object <24-00 00-00 00-00 00-00 80-39 2B-42 4D-56 00-00>, where the value pair (0.104199991, 0) + // at index #18 don't match, which is -0.1042 from 0.1042 + // RunMatMulOpTest(false, {2, 3, 3, 3}, {3, 2}, true, false); } #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + // // HTP tests: // - -TEST_F(QnnHTPBackendTests, MatMulOp_HTP_u8) { - std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; - std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; - RunQDQMatMulOpOpTest(TestInputDef({2, 3}, false, input0_data), - TestInputDef({3, 2}, false, input1_data), - ExpectedEPNodeAssignment::All, 18); -} - -// Test QDQ MatMul with 16-bit act, 8-bit weights (static) -TEST_F(QnnHTPBackendTests, MatMulOp_HTP_A16_W8Static) { - std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; - std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; - RunQDQMatMulOpOpTest(TestInputDef({2, 3}, false, input0_data), - TestInputDef({3, 2}, true, input1_data), - ExpectedEPNodeAssignment::All, - 18, - true); // Use com.microsoft Q/DQ ops -} - -// Test QDQ per-channel MatMul with 16-bit act, signed 4-bit weights (static) -TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightInt4) { - std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; - std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; - RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), - TestInputDef({1, 1, 3, 2}, true, input1_data), - 1, // quantization axis - ExpectedEPNodeAssignment::All, - 21, - false); -} - -// Test QDQ per-channel MatMul with 16-bit act, unsigned 4-bit weights (static) -TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightUInt4) { - std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; - std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; - RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), - TestInputDef({1, 1, 3, 2}, true, input1_data), - 1, // quantization axis - ExpectedEPNodeAssignment::All, - 21, - false); -} - -// Test QDQ per-channel MatMul with int8 act, int4 weights (static) -// QNN 2.27 regression. Also fails on QNN 2.28.2. -// Failed to finalize QNN graph. Error code: 1002 -TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_PerChannel_AS8_WeightInt4) { - std::vector input0_data = GetFloatDataInRange(-5.0f, 5.0f, 6); - std::vector input1_data = {-2.0f, -1.0f, -0.5f, 0.0f, 1.0f, 2.0f}; - RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), - TestInputDef({1, 1, 3, 2}, true, input1_data), - 1, // quantization axis - ExpectedEPNodeAssignment::All, - 21, - false, - QDQTolerance(0.007f), - false); -} - -// Test QDQ per-channel MatMul with 16-bit act, int8 weights (static) -TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightInt8) { - std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; - std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; - RunQDQPerChannelMatMulOpOpTest(TestInputDef({1, 1, 2, 3}, false, input0_data), - TestInputDef({1, 1, 3, 2}, true, input1_data), - 1, // quantization axis - ExpectedEPNodeAssignment::All, - 21, - false); -} - -// Test QDQ MatMul with uint16 activation uint16 weights, both dynamic -// Inaccuracy detected for output 'output_0', element 1. -// Output quant params: scale=0.0015259021893143654, zero_point=0. -// Expected val: 40 -// QNN QDQ val: 39.681087493896484 (err 0.31891250610351562) -// CPU QDQ val: 39.99847412109375 (err 0.00152587890625) -TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_HTP_A16_W16Dynamic) { - std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; - std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; - RunQDQMatMulOpOpTest(TestInputDef({2, 3}, false, input0_data), - TestInputDef({3, 2}, false, input1_data), - ExpectedEPNodeAssignment::All, - 18, - true); // Use com.microsoft Q/DQ ops -} - -// Test QDQ MatMul with uint16 activation uint16 weights, both dynamic -// Inaccuracy detected for output 'output_0', element 1. -// Output quant params: scale=0.71908456087112427, zero_point=1. -// Expected val: 46848.41015625 -// QNN QDQ val: 46844.04296875 (err 4.3671875) -// CPU QDQ val: 46848.359375 (err 0.05078125) -TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_HTP_A16_W16DynamicLarge) { - std::vector input0_data = GetFloatDataInRange(-10.0f, 10.0f, 12 * 96 * 512); - std::vector input1_data = GetFloatDataInRange(-10.0f, 10.0f, 12 * 96 * 512); - RunQDQMatMulOpOpTest(TestInputDef({1, 12, 96, 512}, false, input0_data), - TestInputDef({1, 12, 512, 96}, false, input1_data), - ExpectedEPNodeAssignment::All, - 18, - true); // Use com.microsoft Q/DQ ops -} - -// Test 16-bit QDQ MatMul with static weights -// TODO: Inaccuracy detected for output 'output', element 0. -// Output quant params: scale=0.0015259021893143654, zero_point=0. -// Expected val: 98 -// QNN QDQ val: 0.65461206436157227 (err 97.345390319824219) -// CPU QDQ val: 98.002593994140625 (err 0.002593994140625) -TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_HTP_A16_W16) { - std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; - std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; - RunQDQMatMulOpOpTest(TestInputDef({2, 3}, false, input0_data), - TestInputDef({3, 2}, true, input1_data), - ExpectedEPNodeAssignment::All, - 18, - true); // Use com.microsoft Q/DQ ops -} - -// Test 8-bit QDQ MatMul broadcasting -TEST_F(QnnHTPBackendTests, MatMulOp_Broadcast) { - RunQDQMatMulOpOpTest(TestInputDef({28, 1, 64}, false, -10.0f, 10.0f), - TestInputDef({64, 32}, false, -10.0f, 10.0f), - ExpectedEPNodeAssignment::All, 18); +TEST_F(QnnHTPBackendTests, MatMulOp) { + // RunMatMulOpTest(is_htp_backend, shape_0, shape_1, is_initializer_0, is_initializer_1, expected_ep_assignment, + // opset, f32_abs_err) + RunMatMulOpTest(true, {2, 3}, {3, 2}, false, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {2, 3}, {3, 2}, false, true, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {2, 3}, {3, 2}, true, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {2, 3}, {3, 2}, true, true); // constant folding + RunMatMulOpTest(true, {2, 3}, {2, 3, 2}, false, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {2, 3, 3, 3}, {3, 2}, true, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {2, 3, 3, 3}, {3, 2}, false, true, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {2, 3, 3, 3}, {2, 3, 3, 2}, false, true, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {2, 1, 2, 3}, {3, 3, 2}, false, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {3}, {3}, false, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {3}, {3}, false, true, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {3}, {3}, true, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {3}, {3, 2}, false, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {3}, {3, 2}, false, true, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {3}, {3, 3, 2}, true, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {2, 3}, {3}, false, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {2, 3}, {3}, true, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest(true, {2, 3, 3, 3}, {3}, false, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); + + // Failed randomly on Linux + // Expected: contains 18 values, where each value and its corresponding value in 16-byte object + // <12-00 00-00 00-00 00-00 40-3D CC-A5 5A-7A 00-00> are an almost-equal pair + // Actual: 16-byte object <12-00 00-00 00-00 00-00 80-E8 CF-8F 5B-7A 00-00>, where the value pair + // (0.0393999927, 98304.0078) at index #6 don't match, which is 98304 from 0.0394 + // RunMatMulOpTest(true, {3, 3, 3}, {3, 2}, true, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); } -// Test 16-bit QDQ MatMul broadcasting -// TODO: Inaccuracy detected for output 'output', element 0. -// Output quant params: scale=0.0028538699261844158, zero_point=6050. -// Expected val: 169.76341247558594 -// QNN QDQ val: -16.675161361694336 (err 186.43856811523438) -// CPU QDQ val: 169.762451171875 (err 0.0009613037109375) -TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_Broadcast_A16_W16) { - std::vector input_a = GetFloatDataInRange(-10.0f, 10.0f, 28 * 64); - std::vector input_b = GetFloatDataInRange(-10.0f, 10.0f, 64 * 32); - - RunQDQMatMulOpOpTest(TestInputDef({28, 1, 64}, false, input_a), - TestInputDef({64, 32}, true, input_b), - ExpectedEPNodeAssignment::All, - 18, - true); // Use com.microsoft Q/DQ ops +TEST_F(QnnHTPBackendTests, MatMulOp_QDQ) { + // UINT8 + // RunQDQMatMulOpTest(shape_0, shape_1, is_initializer_0, is_initializer_1, expected_ep_assignment, opset, + // use_contrib_qdq) + RunQDQMatMulOpTest({2, 3}, {3, 2}, false, false); + RunQDQMatMulOpTest({2, 3}, {3, 2}, false, true); + RunQDQMatMulOpTest({2, 2, 3}, {3, 2}, true, false, ExpectedEPNodeAssignment::All, 18, + true); + RunQDQMatMulOpTest({2, 1, 3, 3}, {3, 3, 2}, false, true); + RunQDQMatMulOpTest({3}, {3}, false, false); + RunQDQMatMulOpTest({2, 3}, {3}, true, false); + + // UINT16, UINT8 + RunQDQMatMulOpTest({2, 3}, {3, 2}, false, false); + RunQDQMatMulOpTest({2, 3}, {3, 2}, false, true, ExpectedEPNodeAssignment::All, 18, true); + RunQDQMatMulOpTest({2, 3, 3, 3}, {3, 2}, true, false); + RunQDQMatMulOpTest({3}, {3, 2}, false, true); + RunQDQMatMulOpTest({2, 3, 3, 3}, {3}, false, false); + + // UINT16, per-channel signed 4-bit weight + // RunQDQPerChannelMatMulOpTest(shape_input, shape_weight, weight_quant_axis, tolerance, expected_ep_assignment, + // opset, use_contrib_qdq, enable_fp16_precision) + RunQDQPerChannelMatMulOpTest({2, 3}, {3, 2}, 1); + RunQDQPerChannelMatMulOpTest({2, 3, 3, 3}, {3, 2}, -1, QDQTolerance(), + ExpectedEPNodeAssignment::All, 18, true); + + // // UINT16, per-channel INT8 weight + RunQDQPerChannelMatMulOpTest({2, 3}, {3, 2}, 1, QDQTolerance(), + ExpectedEPNodeAssignment::All, 21, false, false); + RunQDQPerChannelMatMulOpTest({2, 3, 3}, {3}, -1); } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)