[WebNN] Add op support validation for decomposed WebNN ops
- Some ONNX ops are supported by decomposed WebNN ops; define a decomposed_op_map
  that maps each such op to its list of decomposed WebNN ops.
- WebNN ops have various first input names such as 'a', 'input', 'inputs', etc.
  Define a webnn_op_first_input_name_map to record the special names other than
  'input', and a GetWebNNOpFirstInputName function to retrieve the first input
  name of a WebNN op.
- Check that the input and output data types are supported by each decomposed
  WebNN op.
- Remove the unnecessary CheckSingleOp function; WebNN's OpSupportLimits already
  covers the op support check.
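
Taken together, support validation for a decomposed ONNX op walks its WebNN op
list and checks the candidate data type against each op's OpSupportLimits entry.
A minimal sketch of that flow, assuming the decomposed_op_map,
GetWebNNOpFirstInputName, and IsDataTypeSupportedByWebNNOp helpers introduced
below (simplified and hypothetical; the real code also maps ONNX data types to
WebNN ones and logs details):

// Sketch only: mirrors the pattern this commit introduces, not the exact code.
bool IsInputTypeSupportedForDecomposedOp(const std::string& onnx_op_type,
                                         int32_t onnx_data_type,
                                         const emscripten::val& wnn_limits,
                                         const logging::Logger& logger) {
  auto it = decomposed_op_map.find(onnx_op_type);
  if (it == decomposed_op_map.end())
    return false;  // Not a decomposed op; callers use the regular op_map path.
  for (const auto& webnn_op_type : it->second) {
    // Each WebNN op may name its first input 'a', 'input', 'inputs', etc.
    const std::string webnn_input_name = GetWebNNOpFirstInputName(webnn_op_type);
    if (!IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type,
                                      wnn_limits, webnn_input_name, "input", logger)) {
      return false;  // One unsupported decomposed op rejects the whole node.
    }
  }
  return true;
}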
Honry committed Jan 15, 2025
1 parent b67983c commit 66e40bd
Showing 7 changed files with 203 additions and 52 deletions.
13 changes: 5 additions & 8 deletions onnxruntime/core/providers/webnn/builders/helper.cc
@@ -111,11 +111,7 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
for (size_t i = 0; i < node_indices.size(); i++) {
auto node_idx = node_indices[i];
const auto* node(graph_viewer.GetNode(node_idx));
-    bool supported = false;
-    // Firstly check if platform supports the WebNN op.
-    if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
-      supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger);
-    }
+    const bool supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger);

LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType()
<< "] index: [" << node_idx
@@ -154,7 +150,7 @@ bool AreInputDataTypesSame(const std::string& op_type,
return true;
}

- bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) {
+ bool IsSupportedDataType(const int32_t& onnx_data_type, const emscripten::val& webnn_supported_data_types) {
auto it = onnx_to_webnn_data_type_map.find(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_data_type));
if (it == onnx_to_webnn_data_type_map.end())
return false;
@@ -169,7 +165,7 @@ bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& we

// Check if the input or output data type of ONNX node is supported by the WebNN operator.
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
-                              const int32_t onnx_data_type,
+                              const int32_t& onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
@@ -184,7 +180,7 @@ bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,

bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
const std::string& webnn_op_type,
-                                   const int32_t onnx_data_type,
+                                   const int32_t& onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
@@ -193,6 +189,7 @@ bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] is not supported for now";
return false;
}

if (wnn_limits[webnn_op_type][webnn_input_output_name].isUndefined()) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] doesn't have parameter ["
<< webnn_input_output_name << "]";
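For orientation: wnn_limits is the JavaScript object returned by WebNN's
MLContext.opSupportLimits(), reached through Embind, so each lookup above is a
property access on a live JS value. A hypothetical standalone probe, assuming the
per-op { a: { dataTypes: [...] } } shape the WebNN spec defines ('add', 'a', and
'dataTypes' are illustrative names, not from this diff):

#include <emscripten/val.h>

// Sketch: does WebNN's 'add' op accept float16 on its first input ('a')?
// This mirrors the kind of check IsSupportedDataType performs after mapping
// the ONNX data type through onnx_to_webnn_data_type_map.
bool AddSupportsFloat16(const emscripten::val& wnn_limits) {
  emscripten::val data_types = wnn_limits["add"]["a"]["dataTypes"];
  if (data_types.isUndefined())
    return false;  // Op or parameter not reported by this browser.
  return data_types.call<emscripten::val>("includes", emscripten::val("float16")).as<bool>();
}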
60 changes: 42 additions & 18 deletions onnxruntime/core/providers/webnn/builders/helper.h
@@ -194,8 +194,15 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
const WebnnDeviceType device_type,
const emscripten::val& wnn_limits,
const logging::Logger& logger);
- // TODO(@Honry): Some ONNX ops are supported by decomposed WebNN ops,
- // we need to check the support of the decomposed ops.
+
+ // Some ONNX ops are supported by decomposed WebNN ops.
+ static const InlinedHashMap<std::string, std::vector<std::string>> decomposed_op_map = {
+     {"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}},
+     {"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "split"}},
+     {"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
+     {"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
+ };
+ // ONNX op type to WebNN op type mapping.
static const InlinedHashMap<std::string, std::string> op_map = {
{"Abs", "abs"},
{"Add", "add"},
@@ -247,7 +254,6 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Log", "log"},
{"LpPool", "l2Pool2d"},
{"LSTM", "lstm"},
{"LRN", "averagePool2d"},
{"MatMul", "matmul"},
{"MatMulInteger", "matmulInteger"},
{"Max", "max"},
@@ -275,17 +281,14 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Relu", "relu"},
{"Reshape", "reshape"},
{"Resize", "resample2d"},
{"RotaryEmbedding", "gather"},
{"ScatterElements", "scatterElements"},
{"ScatterND", "scatterND"},
{"Shape", "slice"},
{"Sigmoid", "sigmoid"},
{"Sign", "sign"},
{"SimplifiedLayerNormalization", "layerNormalization"},
{"Softplus", "softplus"},
{"Softsign", "softsign"},
{"Sin", "sin"},
{"SkipSimplifiedLayerNormalization", "layerNormalization"},
{"Slice", "slice"},
{"Softmax", "softmax"},
{"Split", "split"},
@@ -302,16 +305,37 @@
{"Xor", "logicalXor"},
};

- inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder,
-                           const WebnnDeviceType device_type) {
-   auto op_map_entry = op_map.find(op_type);
-   // Returns false if the op_type is not listed in the op_map or
-   // if the WebNN op has not been implemented in MLGraphBuilder in current browser.
-   if (op_map_entry == op_map.end() || !wnn_builder[op_map_entry->second].as<bool>()) {
-     return false;
-   }
+ // WebNN op name to its first input name mapping, only record the name that is different from "input".
+ // This map is used to determine the first input name of a WebNN op and is utilized by OpSupportLimits.
+ static const InlinedHashMap<std::string, std::string> webnn_op_first_input_name_map = {
+     {"add", "a"},
+     {"concat", "inputs"},
+     {"div", "a"},
+     {"equal", "a"},
+     {"gemm", "a"},
+     {"greater", "a"},
+     {"greaterOrEqual", "a"},
+     {"lesser", "a"},
+     {"lesserOrEqual", "a"},
+     {"logicalAnd", "a"},
+     {"logicalNot", "a"},
+     {"logicalOr", "a"},
+     {"logicalXor", "a"},
+     {"matmul", "a"},
+     {"max", "a"},
+     {"min", "a"},
+     {"mul", "a"},
+     {"pow", "a"},
+     {"sub", "a"},
+     {"where", "condition"},
+ };

-   return true;
+ // Retrieve the first input name of a WebNN op used for validating supported input data types.
+ // WebNN ops have various first input names such as 'a', 'input', 'inputs', etc.
+ // Special names other than 'input' are recorded in the webnn_op_first_input_name_map.
+ inline std::string GetWebNNOpFirstInputName(const std::string& webnn_op_type) {
+   auto it = webnn_op_first_input_name_map.find(webnn_op_type);
+   return (it != webnn_op_first_input_name_map.end()) ? it->second : "input";
}

inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_type) {
@@ -341,16 +365,16 @@ static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> o
bool AreInputDataTypesSame(const std::string& op_type,
gsl::span<const int32_t> input_types,
const logging::Logger& logger);
- bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types);
+ bool IsSupportedDataType(const int32_t& onnx_data_type, const emscripten::val& webnn_supported_data_types);
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
-                              const int32_t onnx_data_type,
+                              const int32_t& onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const logging::Logger& logger);
bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
const std::string& webnn_op_type,
-                                   const int32_t onnx_data_type,
+                                   const int32_t& onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
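Note that the lookup falls back to "input" for any WebNN op not listed in the map.
A small usage sketch; the expected values follow directly from the map above:

#include <cassert>

void CheckFirstInputNames() {
  assert(GetWebNNOpFirstInputName("add") == "a");          // binary ops use 'a'
  assert(GetWebNNOpFirstInputName("concat") == "inputs");  // concat takes a sequence
  assert(GetWebNNOpFirstInputName("where") == "condition");
  assert(GetWebNNOpFirstInputName("relu") == "input");     // default for unlisted ops
}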
onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc
@@ -62,8 +62,13 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializ
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
+   std::string webnn_op_type;

[cpplint warning on base_op_builder.cc:65, GitHub Actions / Optional Lint C++] Add #include <string> for string [build/include_what_you_use] [4]
+   if (!GetWebNNOpType(op_type, webnn_op_type))
+     return false;

-   return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
+   const auto webnn_input_name = GetWebNNOpFirstInputName(webnn_op_type);
+   return IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, input_type, wnn_limits,
+                                       webnn_input_name, "input", logger);
}

bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits,
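The cpplint note above presumably just calls for the missing include at the top
of base_op_builder.cc, since the new code names std::string directly. A one-line
follow-up, not shown in this commit:

#include <string>  // presumed fix for the build/include_what_you_use warning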
24 changes: 0 additions & 24 deletions onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
@@ -18,11 +18,6 @@ class CastOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

-   // Operator support related.
-  private:
-   bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
-                               const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Add operator related.
@@ -85,25 +80,6 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
return Status::OK();
}

- // Operator support related.
- bool CastOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
-                                            const emscripten::val& wnn_limits, const logging::Logger& logger) const {
-   const auto& input_defs = node.InputDefs();
-   const auto& op_type = node.OpType();
-   int32_t input_type;
-
-   if (!GetType(*input_defs[0], input_type, logger))
-     return false;
-
-   if (!IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "input", logger))
-     return false;
-
-   NodeAttrHelper helper(node);
-   // Check cast to type.
-   const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
-   return IsDataTypeSupportedByOp(op_type, to_type, wnn_limits, "output", "to", logger);
- }

void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<CastOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
45 changes: 45 additions & 0 deletions onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc
@@ -23,6 +23,10 @@ class LRNOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
+   bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
+                               const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
+   bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
+                                const logging::Logger& logger) const override;
};

Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
@@ -142,6 +146,47 @@ bool LRNOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
}

+ bool LRNOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
+                                           const emscripten::val& wnn_limits, const logging::Logger& logger) const {
+   const auto& input_defs = node.InputDefs();
+   const auto& op_type = node.OpType();
+   int32_t input_type = 0;
+   if (!GetType(*input_defs[0], input_type, logger)) {
+     return false;
+   }
+
+   // Check if the input data type is supported by each decomposed WebNN op.
+   // Decomposed ops include: "add", "averagePool2d", "div", "mul", "pad", "pow" and "transpose".
+   for (const auto& webnn_op_type : decomposed_op_map.at(op_type)) {
+     const auto webnn_input_name = GetWebNNOpFirstInputName(webnn_op_type);
+     if (!IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, input_type, wnn_limits, webnn_input_name, "X", logger)) {
+       return false;
+     }
+   }
+
+   return true;
+ }
+
+ bool LRNOpBuilder::HasSupportedOutputsImpl(const Node& node,
+                                            const emscripten::val& wnn_limits,
+                                            const logging::Logger& logger) const {
+   const auto& output_defs = node.OutputDefs();
+   const auto& op_type = node.OpType();
+   int32_t output_type = 0;
+   if (!GetType(*output_defs[0], output_type, logger)) {
+     return false;
+   }
+
+   // Check if the output data type is supported by every decomposed WebNN op.
+   for (const auto& webnn_op_type : decomposed_op_map.at(op_type)) {
+     if (!IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, output_type, wnn_limits, "output", "Y", logger)) {
+       return false;
+     }
+   }
+
+   return true;
+ }
+
void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<LRNOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
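For reference, the decomposed op list for LRN tracks the operator's definition.
The standard ONNX LRN formula, stated here for orientation only (not part of the
commit):

% LRN normalizes each element by a sum of squares over a local window of
% 'size' channels: pow/mul form the squares, transpose + pad + averagePool2d
% realize the windowed sum, add applies the bias, and div finishes the scale.
y = \frac{x}{\left(\mathrm{bias} + \frac{\alpha}{\mathrm{size}} \sum_{\text{window}} x^{2}\right)^{\beta}}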
onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc
@@ -27,6 +27,8 @@ class NormalizationOpBuilder : public BaseOpBuilder {
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
+   bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
+                                const logging::Logger& logger) const override;
};

Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
Expand Down Expand Up @@ -305,7 +307,44 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet&
return false;
}

-   return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
+   if (op_type == "SimplifiedLayerNormalization" || op_type == "SkipSimplifiedLayerNormalization") {
+     // SkipSimplifiedLayerNormalization and SimplifiedLayerNormalization are supported by decomposed WebNN ops.
+     // Check if the input data type is supported by each decomposed WebNN op.
+     // Decomposed ops include: "add", "div", "mul", "pow", "reduceMean" and "sqrt".
+     for (const auto& webnn_op_type : decomposed_op_map.at(op_type)) {
+       const auto webnn_input_name = GetWebNNOpFirstInputName(webnn_op_type);
+       if (!IsDataTypeSupportedByWebNNOp(
+               op_type, webnn_op_type, input0_type, wnn_limits, webnn_input_name, "input", logger)) {
+         return false;
+       }
+     }
+     return true;
+   } else {
+     return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
+   }
}

+ bool NormalizationOpBuilder::HasSupportedOutputsImpl(const Node& node,
+                                                      const emscripten::val& wnn_limits,
+                                                      const logging::Logger& logger) const {
+   const auto& output_defs = node.OutputDefs();
+   const auto& op_type = node.OpType();
+   int32_t output_type = 0;
+   if (!GetType(*output_defs[0], output_type, logger)) {
+     return false;
+   }
+
+   if (op_type == "SimplifiedLayerNormalization" || op_type == "SkipSimplifiedLayerNormalization") {
+     // Check if the output data type is supported by every decomposed WebNN op.
+     for (const auto& webnn_op_type : decomposed_op_map.at(op_type)) {
+       if (!IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, output_type, wnn_limits, "output", "output", logger)) {
+         return false;
+       }
+     }
+     return true;
+   } else {
+     return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "Output", logger);
+   }
+ }

void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
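Similarly, SimplifiedLayerNormalization (and its Skip variant) is an RMSNorm-style
op, which explains its decomposed list: each listed WebNN op appears in the
formula below, stated here for orientation only (not part of the commit):

% pow -> x^2, reduceMean -> the mean over the normalized axes, add -> + epsilon
% (and the skip/bias addition in the Skip variant), sqrt -> the root,
% div -> the normalization, mul -> the scale gamma.
y = \frac{x}{\sqrt{\mathrm{mean}(x^{2}) + \epsilon}} \cdot \gamma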