[WebNN EP] Optimize model partitioning (#23332)
### Description

The old `GetCapability` function of the WebNN EP performs only a very simple
search for groups of nodes that can be handled. This does not work well for
the following example graph: if A and D can be handled by the EP but B sits
between them in the topological order, you get two single-node capabilities.
On the other hand, it can also be advantageous if C and E can be handled by
the EP, since they would then be combined with D even though they are not
connected to it.
```
    A  B  C
    | /   |
    D     E
    |     |
```
Therefore, we improve the partitioning results by reusing
`utils::CreateSupportedPartitions`, which walks the edges of each node the EP
can handle as the nodes are iterated in topological order. This guarantees
that all connected nodes that can be handled are grouped together.
Correspondingly, we modify `webnn::GetSupportedNodes` to return the set of
supported nodes instead of groups of supported partitions.
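
For orientation, below is a condensed sketch of the new `GetCapability` flow. It is simplified from the diff further down (the per-partition `MetaDef` rewriting for control-flow subgraph initializers and the summary logging are omitted), so treat it as illustrative rather than the exact implementation.

```cpp
// Condensed sketch of WebNNExecutionProvider::GetCapability after this change.
// Names and arguments mirror the diff below; MetaDef input/output rewriting,
// control-flow subgraph initializer handling, and logging are omitted.
std::vector<std::unique_ptr<ComputeCapability>>
WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
                                      const IKernelLookup& /*kernel_registries*/) const {
  const auto& logger = *GetLogger();

  emscripten::val wnn_builder = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
  if (!wnn_builder.as<bool>()) {
    ORT_THROW("Failed to create WebNN builder.");
  }

  // Group QDQ nodes into NodeUnits so a whole QDQ group stays in one partition.
  std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
  std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
  std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);

  // Collect every node the WebNN backend can handle, regardless of adjacency.
  const auto supported_nodes =
      webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger);

  // Deterministic partition names derived from the model hash.
  const auto gen_metadef_name = [&]() {
    HashValue model_hash;
    int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
    return MakeString(WEBNN, "_", model_hash, "_", metadef_id);
  };

  // The shared utility walks the edges of supported nodes in topological order
  // and merges all connected supported nodes into a single partition.
  auto result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {},
                                                 gen_metadef_name, WEBNN, kWebNNExecutionProvider,
                                                 &node_unit_map, /*drop_constant_initializers*/ true);

  // The WebNN builder is only needed while checking operator support.
  wnn_builder = emscripten::val::undefined();

  return result;
}
```

In the actual change, a remaining loop then adjusts each partition's `MetaDef` (domain, inputs/outputs, and the constant initializers required by control-flow subgraphs) before the capabilities are returned.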

### Motivation and Context

Co-authored-by: Dwayne Robinson <[email protected]>
peishenyan and fdwr authored Jan 16, 2025
1 parent 5735e1b commit 80f686e
Showing 3 changed files with 61 additions and 117 deletions.
44 changes: 15 additions & 29 deletions onnxruntime/core/providers/webnn/builders/helper.cc
@@ -99,44 +99,30 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n
return true;
}

std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const emscripten::val& wnn_limits,
const logging::Logger& logger) {
std::vector<std::vector<size_t>> supported_node_groups;
std::vector<size_t> supported_node_group;
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();

for (size_t i = 0; i < node_indices.size(); i++) {
auto node_idx = node_indices[i];
const auto* node(graph_viewer.GetNode(node_idx));
std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const emscripten::val& wnn_limits,
const logging::Logger& logger) {
std::unordered_set<const Node*> supported_nodes;

for (const auto& node : graph_viewer.Nodes()) {
bool supported = false;
// Firstly check if platform supports the WebNN op.
if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger);
if (CheckSingleOp(node.OpType(), wnn_builder, device_type)) {
supported = IsNodeSupported(node, graph_viewer, device_type, wnn_limits, logger);
}

LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType()
<< "] index: [" << node_idx
<< "] name: [" << node->Name()
LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType()
<< "] index: [" << node.Index()
<< "] name: [" << node.Name()
<< "] supported: [" << supported
<< "]";
if (supported) {
supported_node_group.push_back(node_idx);
} else {
if (!supported_node_group.empty()) {
supported_node_groups.push_back(supported_node_group);
supported_node_group.clear();
}
supported_nodes.insert(&node);
}
}

if (!supported_node_group.empty()) {
supported_node_groups.push_back(supported_node_group);
}

return supported_node_groups;
return supported_nodes;
}

bool AreInputDataTypesSame(const std::string& op_type,
12 changes: 6 additions & 6 deletions onnxruntime/core/providers/webnn/builders/helper.h
@@ -188,12 +188,12 @@ inline bool TensorExists(const ConstPointerContainer<std::vector<NodeArg*>>& def
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
const logging::Logger& logger, bool allow_empty_input = false);

// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const emscripten::val& wnn_limits,
const logging::Logger& logger);
// Get a set of nodes supported by WebNN EP.
std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const emscripten::val& wnn_limits,
const logging::Logger& logger);
// TODO(@Honry): Some ONNX ops are supported by decomposed WebNN ops,
// we need to check the support of the decomposed ops.
static const InlinedHashMap<std::string, std::string> op_map = {
122 changes: 40 additions & 82 deletions onnxruntime/core/providers/webnn/webnn_execution_provider.cc
@@ -13,13 +13,18 @@
#include "core/common/safeint.h"
#include "core/providers/webnn/allocator.h"
#include "core/providers/webnn/data_transfer.h"
#include "core/providers/partitioning_utils.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"

#include "builders/model.h"
#include "builders/helper.h"
#include "builders/model_builder.h"

namespace onnxruntime {

constexpr const char* WEBNN = "WEBNN";

WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags)
: IExecutionProvider{
onnxruntime::kWebNNExecutionProvider,
@@ -51,8 +56,6 @@ WebNNExecutionProvider::~WebNNExecutionProvider() {}
std::vector<std::unique_ptr<ComputeCapability>>
WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const IKernelLookup& /*kernel_registries*/) const {
std::vector<std::unique_ptr<ComputeCapability>> result;

// For subgraph which is the attribute of the control flow nodes, part of its initializers are stored in its
// ancestor graphs as common initializers shared for other subgraphs. We need to collect all of them used for
// identifying the required initializer names and storing into 'meta_def->constant_initializers'.
@@ -64,67 +67,44 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
all_initializers = webnn::CollectAllInitializedTensors(graph_viewer);
}

/*
Very basic search for groups of nodes that can be handled by the EP.
This doesn't work perfectly if you have a scenario like the following where A and D could be handled by the EP
but B is between them in the topological sort as you'll get two single node capabilities. However if can also
be advantageous if C and E could be handled by the EP as they would be combined with D even though not connected.
Not sure how often each of these scenarios happens.
A B C
| / |
D E
| |
Would probably be better to walk the edges for each node the EP can handle as they are iterated in topological order,
accumulating nodes (and saving which ones have been taken) until you run out. This would guarantee all
connected nodes that can be handled are grouped together.
*/

const auto& logger = *GetLogger();

emscripten::val wnn_builder = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
if (!wnn_builder.as<bool>()) {
ORT_THROW("Failed to create WebNN builder.");
}

const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger);
wnn_builder = emscripten::val::undefined();
// Get all the NodeUnits in the graph_viewer
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;

if (node_groups.empty()) {
return result;
}
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);

const auto& graph_output_list = graph_viewer.GetOutputs();
InlinedHashSet<const NodeArg*> graph_outputs(graph_output_list.cbegin(), graph_output_list.cend());
const auto supported_nodes = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger);

size_t num_of_supported_nodes = 0;
for (const auto& group : node_groups) {
if (group.empty())
continue;
const auto gen_metadef_name = [&]() {
HashValue model_hash;
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
return MakeString(WEBNN, "_", model_hash, "_", metadef_id);
};

num_of_supported_nodes += group.size();
LOGS(logger, VERBOSE) << "WebNNExecutionProvider::GetCapability, current supported node group size: "
<< group.size();
auto result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {},
gen_metadef_name, WEBNN, kWebNNExecutionProvider,
&node_unit_map, /*drop_constant_initializers*/ true);

InlinedHashSet<NodeIndex> node_set;
node_set.reserve(group.size());
for (const auto& index : group) {
node_set.insert(index);
}
// Release wnn_builder
wnn_builder = emscripten::val::undefined();

std::unique_ptr<IndexedSubGraph> sub_graph = std::make_unique<IndexedSubGraph>();
const auto& graph_output_list = graph_viewer.GetOutputs();
InlinedHashSet<const NodeArg*> graph_outputs(graph_output_list.cbegin(), graph_output_list.cend());

for (auto& capability : result) {
auto& sub_graph = capability->sub_graph;
if (sub_graph->nodes.empty())
continue;

std::vector<std::string> subgraph_initializers;
InlinedHashSet<const NodeArg*> node_outputs;
InlinedHashSet<const NodeArg*> subgraph_inputs;
InlinedHashSet<const NodeArg*> subgraph_outputs;
std::vector<const NodeArg*> ordered_subgraph_inputs;
// Output should be unique. It may be produced as graph output and subgraph output.
InlinedHashSet<const NodeArg*> ordered_subgraph_outputs;

for (const auto& index : group) {
sub_graph->nodes.push_back(index);
for (const auto& index : sub_graph->nodes) {
const auto* node = graph_viewer.GetNode(index);

for (const auto* input : node->InputDefs()) {
@@ -136,39 +116,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
if (is_subgraph && Contains(all_initializers, input->Name())) {
subgraph_initializers.push_back(input->Name());
}
// If the node input was not produced by this subgraph, add it to the subgraph inputs.
if (node_outputs.count(input) == 0) {
if (subgraph_inputs.count(input) == 0) {
subgraph_inputs.insert(input);
ordered_subgraph_inputs.push_back(input);
}
}
}

const auto& output_defs = node->OutputDefs();
for (const auto* output_def : output_defs) {
node_outputs.insert(output_def);
// if output is overall graph output we need to produce it.
if (graph_outputs.count(output_def) != 0) {
ordered_subgraph_outputs.insert(output_def);
}
}

// if output connects to a node not in this subgraph we need to produce it.
for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) {
if (node_set.count(it->GetNode().Index()) == 0) {
const auto* output_def = output_defs[it->GetSrcArgIndex()];
if (subgraph_outputs.count(output_def) == 0) {
subgraph_outputs.insert(output_def);
ordered_subgraph_outputs.insert(output_def);
}
}
}
}

// Assign inputs and outputs to subgraph's meta_def.
uint64_t model_hash;
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
const auto meta_def_old = sub_graph->GetMetaDef();
auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
meta_def->name = "WEBNN_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id);
meta_def->domain = kMSDomain;
@@ -181,20 +135,24 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
}
}

for (const auto& input : ordered_subgraph_inputs) {
meta_def->inputs.push_back(input->Name());
for (const auto& input : meta_def_old->inputs) {
meta_def->inputs.push_back(input);
}

for (const auto& output : ordered_subgraph_outputs) {
meta_def->outputs.push_back(output->Name());
for (const auto& output : meta_def_old->outputs) {
meta_def->outputs.push_back(output);
}

sub_graph->SetMetaDef(std::move(meta_def));

result.push_back(std::make_unique<ComputeCapability>(std::move(sub_graph)));
}

auto num_of_partitions = result.size();
const auto num_of_partitions = result.size();
const auto num_of_supported_nodes = std::accumulate(
result.begin(), result.end(), size_t{0},
[](const auto& acc, const auto& partition) -> size_t {
return acc + (partition && partition->sub_graph ? partition->sub_graph->nodes.size() : 0);
});

const auto summary_msg = MakeString(
"WebNNExecutionProvider::GetCapability,",
" number of partitions supported by WebNN: ", num_of_partitions,
