Skip to content

Commit

Permalink
[webgpu] Implement Split operator (#23198)
Browse files Browse the repository at this point in the history
Test: onnxruntime_test_all.exe --gtest_filter=SplitOperatorTest.*

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
jchen10 authored Jan 13, 2025
1 parent 377165f commit a9be6b7
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 5 deletions.
162 changes: 162 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/split.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/tensor/split.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

namespace {

// Helper function to calculate the output index based on the input index and the sizes of the splits.
// Emits the WGSL helper `calculate_output_index`: given a coordinate along the
// split axis, it scans the cumulative split sizes in
// `uniforms.sizes_in_split_axis` and returns the index of the output tensor
// that coordinate falls into (or `output_count` if past the end).
void CalculateOutputIndex(std::ostream& os, size_t output_count) {
  os << "fn calculate_output_index(index: u32) -> u32 {\n";
  os << " for (var i: u32 = 0u; i < " << output_count << "u; i += 1u ) {\n";
  os << " if (index < " << GetElementAt("uniforms.sizes_in_split_axis", "i", output_count) << ") {\n";
  os << " return i;\n";
  os << " }\n";
  os << " }\n";
  os << " return " << output_count << "u;\n";
  os << "}\n";
}

// Helper function to write the buffer data for each output.
// Emits the WGSL helper `write_buffer_data`, which copies the input element at
// `global_idx` into the output tensor selected by `output_number`, at the
// (already rebased) position `indices`. For multiple outputs the write is
// dispatched through an if / else-if / else chain on `output_number`; for a
// single output the value is written unconditionally.
//
// Fix: the original emitted the trailing " }\n" unconditionally, so with
// exactly one output the generated function contained two closing braces for
// its single opening brace — invalid WGSL. The inner "}" is now emitted only
// when a branch chain was actually opened.
void WriteBufferData(std::ostream& os, const ShaderVariableHelper& input,
                     gsl::span<const ShaderVariableHelper*> outputs) {
  os << "fn write_buffer_data(output_number: u32, global_idx: u32, indices: output_0_indices_t) {\n";
  for (size_t i = 0; i < outputs.size(); ++i) {
    const auto buffer_write = outputs[i]->SetByIndices("indices", input.GetByOffset("global_idx"));
    if (outputs.size() == 1) {
      // Single output: no dispatch on output_number is needed.
      os << " " << buffer_write << "\n";
    } else if (i == 0) {
      os << " if (output_number == 0u) {\n"
         << " " << buffer_write << "\n";
    } else if (i == outputs.size() - 1) {
      // Last output is the catch-all branch.
      os << " } else {\n"
         << " " << buffer_write << "\n";
    } else {
      os << " } else if (output_number == " << i << "u) {\n"
         << " " << buffer_write << "\n";
    }
  }
  // Close the branch chain only if one was opened (i.e. more than one output).
  if (outputs.size() > 1) {
    os << " }\n";
  }
  os << "}\n";
}

} // namespace

// Generates the WGSL compute shader for Split. Each invocation handles one
// input element: it maps global_idx to input indices, determines which output
// the element belongs to along the split axis, rebases the axis coordinate
// into that output's local space, and writes the value via write_buffer_data.
Status SplitProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);

// Declare one shader output variable per ONNX output tensor.
size_t output_count = Outputs().size();
std::vector<const ShaderVariableHelper*> outputs;
outputs.reserve(output_count);
for (size_t i = 0; i < output_count; ++i) {
outputs.push_back(
&shader.AddOutput("output_" + std::to_string(i), ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias));
}

// Add implementation of fn calculate_output_index.
CalculateOutputIndex(shader.AdditionalImplementation(), output_count);
// Add implementation of fn write_buffer_data.
WriteBufferData(shader.AdditionalImplementation(), input, outputs);

// Main body. sizes_in_split_axis holds cumulative sizes, so for any output
// after the first, subtracting the previous cumulative size converts the
// split-axis coordinate into an offset local to that output.
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.input_size")
<< " var indices = " << input.OffsetToIndices("global_idx") << ";\n"
<< " var index = indices[" << axis_ << "];\n"
<< " let output_number = calculate_output_index(index);\n"
<< " if (output_number != 0u) {\n"
<< " index -= uniforms.sizes_in_split_axis[output_number - 1u];\n"
<< " indices[" << axis_ << "] = index;\n"
<< " }\n"
<< " write_buffer_data(output_number, global_idx, indices);\n";

return Status::OK();
}

// Computes Split on WebGPU: resolves the split sizes (from the attribute, the
// optional 'split' input tensor, or an even division), allocates the output
// tensors, and dispatches one shader invocation per input element.
Status Split::ComputeInternal(ComputeContext& context) const {
const Tensor* input = context.Input<Tensor>(0);
auto& input_shape = input->Shape();
auto num_outputs = context.OutputCount();

// Work on local copies: PrepareForCompute mutates both axis and split_sizes.
int64_t axis = axis_;
std::vector<int64_t> split_sizes;

split_sizes.assign(split_sizes_.begin(), split_sizes_.end());
// Compute split_sizes from the 'split' input tensor.
// (Input 1 is registered with InputMemoryType CPU, so Data() is host-readable.)
if (split_sizes_.size() == 0 && context.InputCount() > 1) {
const Tensor* split_tensor = context.Input<Tensor>(1);
// Check if split_tensor is valid.
if (split_tensor != nullptr) {
ORT_ENFORCE(split_tensor->Shape().NumDimensions() == 1, "The split tensor must be a vector tensor.");
// Get split_sizes from the input tensor.
auto nDims = static_cast<size_t>(split_tensor->Shape()[0]);
const auto* data = split_tensor->Data<int64_t>();
split_sizes.assign(data, data + nDims);
}
}

// The variables below are not actually used in the current implementation.
int before_dims = 0;
int after_dims_including_split_axis = 0;
int after_dims_excluding_split = 0;
// This handles the case where the axis is negative. It also splits outputs evenly according to num_outputs if
// split_sizes is empty.
ORT_RETURN_IF_ERROR(PrepareForCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis,
after_dims_excluding_split, split_sizes));

// axis is non-negative after PrepareForCompute, so the narrow_cast is safe.
SplitProgram program{gsl::narrow_cast<uint32_t>(axis)};
program.AddInput({input, ProgramTensorMetadataDependency::TypeAndRank});

// Each output has the input's shape except along the split axis.
auto output_dimensions = input_shape.AsShapeVector();
for (int i = 0; i < num_outputs; ++i) {
// Update the size of dimension for axis we're splitting on.
auto split_size = narrow<int>(split_sizes[i]);
output_dimensions[narrow<size_t>(axis)] = split_size;

Tensor* output = context.Output(i, TensorShape{output_dimensions});
program.AddOutput({output, ProgramTensorMetadataDependency::Rank});
}

uint32_t input_size = gsl::narrow<uint32_t>(input_shape.Size());
// Early return if the input tensor is empty. Outputs are already allocated
// (empty) above, so no shader dispatch is needed.
if (input_size == 0) {
return Status::OK();
}

uint32_t previous_sum = 0;
std::vector<uint32_t> sizes_in_split_axis;
// sizes_in_split_axis are the cumulative sizes of the splits in the split axis.
for (auto split_size : split_sizes) {
previous_sum += gsl::narrow<uint32_t>(split_size);
sizes_in_split_axis.push_back(previous_sum);
}

// One thread per input element; the axis goes into the cache hint because it
// is baked into the generated shader source.
program
.SetDispatchGroupSize((input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.CacheHint(std::to_string(axis))
.AddUniformVariables(
{input_size, gsl::span<const uint32_t>(sizes_in_split_axis.data(), sizes_in_split_axis.size())});
return context.RunProgram(program);
}

// Registration helper for the current (unversioned-end) Split opset. Input 1
// (the optional 'split' sizes tensor) is pinned to CPU memory so ComputeInternal
// can read it on the host.
#define WEBGPU_SPLIT_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \
ONNX_OPERATOR_KERNEL_EX(OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \
KERNEL_CLASS);

// Registration helper for version-ranged Split opsets; same constraints as above.
#define WEBGPU_SPLIT_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX(OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \
KERNEL_CLASS);

// Register one kernel per ONNX opset range of Split.
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 1, 1, Split_1, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 2, 10, Split_2_10, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 11, 12, Split_11_12, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 13, 17, Split_13_17, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_KERNEL(Split, 18, Split_18, WebGpuSupportedNumberTypes());

} // namespace webgpu
} // namespace onnxruntime
61 changes: 61 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/split.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/common.h"
#include "core/providers/cpu/tensor/split.h"

namespace onnxruntime {
namespace webgpu {

// WebGPU program that implements the Split operator's compute shader.
// `axis` is the dimension being split; it is baked into the generated WGSL
// (see GenerateShaderCode), so it is also part of the program cache hint.
class SplitProgram final : public Program<SplitProgram> {
public:
SplitProgram(const uint32_t axis) : Program{"Split"}, axis_{axis} {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

// input_size: total element count of the input tensor.
// sizes_in_split_axis: cumulative (running-sum) split sizes along the axis.
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32},
{"sizes_in_split_axis", ProgramUniformVariableDataType::Uint32});

private:
uint32_t axis_;  // split axis; assumed already non-negative — see ComputeInternal
};

// WebGPU Split kernel. SplitBase (CPU provider) parses the axis/split
// attributes for the given opset; ComputeInternal runs the GPU program.
class Split : public WebGpuKernel, public SplitBase {
public:
Split(const OpKernelInfo& info, uint32_t opset) : WebGpuKernel(info), SplitBase(info, opset) {}

protected:
Status ComputeInternal(ComputeContext& context) const override;
};

// Thin per-opset wrappers: each forwards the opset version to SplitBase so the
// correct attribute/input conventions are used (e.g. 'split' moved from an
// attribute to an input in opset 13).

// Split opset 1.
class Split_1 final : public Split {
public:
Split_1(const OpKernelInfo& info) : Split(info, 1) {}
};

// Split opsets 2-10.
class Split_2_10 final : public Split {
public:
Split_2_10(const OpKernelInfo& info) : Split(info, 2) {}
};

// Split opsets 11-12.
class Split_11_12 final : public Split {
public:
Split_11_12(const OpKernelInfo& info) : Split(info, 11) {}
};

// Split opsets 13-17.
class Split_13_17 final : public Split {
public:
Split_13_17(const OpKernelInfo& info) : Split(info, 13) {}
};

// Split opset 18+.
class Split_18 final : public Split {
public:
Split_18(const OpKernelInfo& info) : Split(info, 18) {}
};

} // namespace webgpu
} // namespace onnxruntime
11 changes: 6 additions & 5 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -637,11 +637,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Concat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Concat)>,

// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 1, Split)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Split)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Split)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Split)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 1, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, Split)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 8, 12, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Expand)>,

Expand Down

0 comments on commit a9be6b7

Please sign in to comment.