Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Native WebGPU EP] Add packedQKV and do_rotary attribute support to GroupQueryAttention operator #23386

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
44 changes: 31 additions & 13 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k) {

Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
if (!is_packed_qkv_) {
shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
}
if (feed_past_key_) {
shader.AddInput("past_key", ShaderUsage::UseUniform);
}
Expand All @@ -102,13 +104,21 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "let m = workgroup_id.y * TILE_SIZE;\n"
<< "let n = workgroup_id.x * TILE_SIZE;\n"
<< "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
<< "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n"
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.N;\n";
if (is_packed_qkv_) {
  // Packed QKV: Q, K and V are concatenated in a single input tensor, so both the query
  // and key tiles are read from "q" using per-batch / per-head offsets into that buffer.
  // NOTE(review): this assumes the packed tensor is head-major per batch, i.e.
  // (batch, num_heads + 2 * kv_num_heads, sequence, head_size) — confirm against the
  // GroupQueryAttention packed-QKV layout before merging.
  shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n"
                            // Number of K/V heads, derived from num_heads and n_reps so no
                            // extra uniform is required (mirrors the VxAttentionScore shader).
                            << "let kv_num_heads = uniforms.num_heads / " << n_reps_ << ";\n"
                            // K/V head shared by this query head (grouped-query attention).
                            << "let kv_head_idx = head_idx / " << n_reps_ << ";\n"
                            << "let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;\n"
                            // Include the tile-row offset (m * K), matching the non-packed path:
                            // the tile-load code below expects qOffset to already point at row m.
                            << "let qOffset = batch_idx * packed_batch_stride + head_idx * uniforms.M * uniforms.K + m * uniforms.K;\n"
                            << "let kOffset = batch_idx * packed_batch_stride + (uniforms.num_heads + kv_head_idx) * uniforms.kv_sequence_length * uniforms.K;\n";
} else {
  // Separate Q/K inputs: one head per workgroup_id.z; n_reps maps query heads onto K/V heads.
  shader.MainFunctionBody() << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n"
                            << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n";
}
std::ostringstream oss;
InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str();
shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n";
if (has_present_key_) {
shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.present_sequence_length * uniforms.K;\n";
}
Expand All @@ -126,11 +136,11 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n"
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
<< " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
<< " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
<< " tileK[idx] = " << (is_packed_qkv_ ? "q" : "key") << "[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
<< " }\n";
} else {
shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length) {\n"
" tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
" tileK[idx] = " << (is_packed_qkv_ ? "q" : "key") << "[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
" }\n";
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length) {\n"
" tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
" tileK[idx] = " << (is_packed_qkv_ ? "q" : "key") << "[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
" }\n";
}
shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length) {\n"
" tileK[idx] = "
<< (is_packed_qkv_ ? "q" : "key") << "[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
" }\n";
}


Expand Down Expand Up @@ -181,9 +191,11 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);

AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
{K, ProgramTensorMetadataDependency::TypeAndRank, components}});
components, parameters.is_first_prompt_, parameters.is_packed_qkv_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
program.AddInput({Q, ProgramTensorMetadataDependency::TypeAndRank, components});
if (K != nullptr) {
program.AddInput({K, ProgramTensorMetadataDependency::TypeAndRank, components});
}
if (feed_past_key) {
program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components});
}
Expand All @@ -203,7 +215,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
(parameters.sequence_length_ + tile_size - 1) / tile_size,
parameters.batch_size_ * parameters.num_heads_)
.SetWorkgroupSize(tile_size, tile_size)
.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_)
.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_, parameters.is_packed_qkv_)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
{static_cast<uint32_t>(vectorized_head_size)},
{static_cast<uint32_t>(total_sequence_length)},
Expand Down Expand Up @@ -331,7 +343,13 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
std::ostringstream oss;
InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str();
shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n";
if (is_packed_qkv_) {
  // Packed QKV: V lives in the same tensor as Q/K, so compute the per-batch / per-head
  // offset into the packed buffer. Batch and head indices are computed inline (instead of
  // re-declaring batch_idx/head_idx with `let`) so this emitted code cannot clash with any
  // declarations earlier in the generated shader.
  // NOTE(review): packed_batch_stride is expressed in M*K elements while the per-head V
  // stride uses N*kv_sequence_length — confirm both describe the same packed layout.
  shader.MainFunctionBody() << "let kv_num_heads = uniforms.num_heads / " << n_reps_ << ";\n"
                            << "let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;\n"
                            // K/V head shared by this query head (grouped-query attention).
                            << "let v_head_idx = (workgroup_id.z % uniforms.num_heads) / " << n_reps_ << ";\n"
                            << "let vOffset = (workgroup_id.z / uniforms.num_heads) * packed_batch_stride + (uniforms.num_heads + kv_num_heads + v_head_idx) * uniforms.N * uniforms.kv_sequence_length + n;\n";
} else {
  // Separate V input: one head per workgroup_id.z; n_reps maps query heads onto K/V heads.
  shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n";
}
if (has_present_value_) {
shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.present_sequence_length + n;\n";
}
Expand Down Expand Up @@ -399,7 +417,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
const bool has_present_value = output_count > 1 && past_value != nullptr;
constexpr int tile_size = 12;

VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.is_packed_qkv_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
{V, ProgramTensorMetadataDependency::TypeAndRank}});
if (feed_past_value) {
Expand All @@ -416,7 +434,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size,
(parameters.sequence_length_ + tile_size - 1) / tile_size,
parameters.batch_size_ * parameters.num_heads_)
.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_)
.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_, parameters.is_packed_qkv_)
.SetWorkgroupSize(tile_size, tile_size)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
{static_cast<uint32_t>(total_sequence_length)},
Expand Down Expand Up @@ -451,7 +469,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_));

ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, parameters.is_packed_qkv_ ? Q : V, past_value, output, present_value,
parameters, past_sequence_length, total_sequence_length, seqlen_k));

return Status::OK();
Expand Down
10 changes: 6 additions & 4 deletions onnxruntime/contrib_ops/webgpu/bert/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
public:
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, bool is_packed_qkv, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_packed_qkv_(is_packed_qkv) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
Expand Down Expand Up @@ -64,6 +64,7 @@
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
bool is_packed_qkv_;
};

class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
Expand All @@ -90,8 +91,8 @@

class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
public:
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, bool is_packed_qkv, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)

Check warning on line 94 in onnxruntime/contrib_ops/webgpu/bert/attention.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/contrib_ops/webgpu/bert/attention.h:94: Add #include <string> for string [build/include_what_you_use] [4]
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_packed_qkv_(is_packed_qkv) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
Expand All @@ -118,6 +119,7 @@
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
bool is_packed_qkv_;
};

} // namespace webgpu
Expand Down
Loading
Loading