[webgpu] Restore MatMulNBits workgroup size for Phi-3.5 #23349

Open · wants to merge 1 commit into main
Conversation

@daijh (Contributor) commented Jan 14, 2025

Description

This change restores the MatMulNBits workgroup size from (8, 8, 1) back to (16, 8, 1) to resolve a performance regression observed on Intel iGPUs during token generation (M=1).

Motivation and Context

As above.

Signed-off-by: Jianhui Dai <[email protected]>
@daijh (Contributor Author) commented Jan 14, 2025

@[email protected] @jchen10

Please take a look before broader review.

@jchen10 (Contributor) commented Jan 14, 2025

Do you have more specific information on how much this regresses on what GPU platforms?

@guschmue added the `ep:WebGPU` (ort-web webgpu provider) label Jan 14, 2025
@daijh (Contributor Author) commented Jan 15, 2025

Phi-3.5 decoding performance on LNL (Lunar Lake) decreased from approximately 28 tps to 26 tps.

Detailed profiling results are below; MatMulNBits-1x8192x3072 appears to be the primary contributor to the performance difference.

**workgroup_size (8, 8, 1)**

| Op | Count | Duration (us) | Percent |
| --- | ---: | ---: | ---: |
| MatMulNBits-1x3072x8192 | 14528 | 2336019 | 33.97% |
| MatMulNBits-1x3072x3072 | 29056 | 2070941 | 30.12% |
| MatMulNBits-1x8192x3072 | 7264 | 1547771 | 22.51% |
| MultiHeadAttention | 43584 | 494532 | 7.19% |
| SkipSimplifiedLayerNormalization | 14528 | 154671 | 2.25% |
| MatMulNBits-1x3072x32064 | 227 | 127522 | 1.85% |
| RotaryEmbedding | 14528 | 55148 | 0.80% |
| Mul | 14528 | 35564 | 0.52% |
| SimplifiedLayerNormalization | 227 | 29672 | 0.43% |
| Sigmoid | 7264 | 15966 | 0.23% |
| Where | 454 | 3294 | 0.05% |
| Cast | 454 | 1172 | 0.02% |
| Gather | 227 | 1164 | 0.02% |
| Tile | 227 | 926 | 0.01% |
| Sub | 227 | 673 | 0.01% |
| Add | 227 | 644 | 0.01% |
| Expand | 227 | 516 | 0.01% |
| ALL | 147777 | 6876195 | 100.00% |

**workgroup_size (16, 8, 1)**

| Op | Count | Duration (us) | Percent |
| --- | ---: | ---: | ---: |
| MatMulNBits-1x3072x8192 | 14528 | 2274354 | 34.67% |
| MatMulNBits-1x3072x3072 | 29056 | 1990174 | 30.34% |
| MatMulNBits-1x8192x3072 | 7264 | 1366201 | 20.83% |
| MultiHeadAttention | 43584 | 487508 | 7.43% |
| SkipSimplifiedLayerNormalization | 14528 | 159851 | 2.44% |
| MatMulNBits-1x3072x32064 | 227 | 123921 | 1.89% |
| RotaryEmbedding | 14528 | 61867 | 0.94% |
| Mul | 14528 | 35728 | 0.54% |
| SimplifiedLayerNormalization | 227 | 30158 | 0.46% |
| Sigmoid | 7264 | 21388 | 0.33% |
| Where | 454 | 3533 | 0.05% |
| Gather | 227 | 1217 | 0.02% |
| Cast | 454 | 1175 | 0.02% |
| Tile | 227 | 971 | 0.01% |
| Sub | 227 | 689 | 0.01% |
| Add | 227 | 632 | 0.01% |
| Expand | 227 | 520 | 0.01% |
| ALL | 147777 | 6559887 | 100.00% |

```diff
@@ -583,7 +583,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
       program.CacheHint("T_M" + std::to_string(tile_m) + "Subgroup" + std::to_string(use_subgroup));
     } else if (block_size == 32) {
       components = 1;
-      constexpr uint32_t workgroup_size = 64;
+      // TODO: Tune the workgroup size when `M=1`.
+      constexpr uint32_t workgroup_size = 128;
```
Contributor commented on the diff:
Thanks for catching it. I changed it by accident. Maybe just restore it as before:

```cpp
constexpr uint32_t workgroup_size = 128;
const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4 : 1;
```

Contributor Author replied:

We tested various combinations of workgroup_size_x, workgroup_size_y, and output_number (e.g. (32, 8, 1), (32, 16, 1), (16, 16, 1), ...). Initial results show a 5-15% performance improvement at M=1 compared to (16, 8, 1), with results varying across devices.

This TODO tracks that upcoming tuning work.

Labels: ep:WebGPU ort-web webgpu provider
4 participants