
[NVIDIA] Support FP8 quantization for MOE layers #1221

Open

kaixih wants to merge 2 commits into main

Conversation


kaixih commented Jan 31, 2025

Description

This PR enables FP8 quantization for MoE layers on NVIDIA GPUs.

cc. @kocchop @nouiz

Tests

Tested on H100

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

Collaborator

RissyRan left a comment


LGTM! Do you have any end-to-end tests about perf change?

cc @michelle-yooh we could have some internal tests in the future.

@@ -202,6 +205,84 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
  return nn.Fp8DotGeneralOp


class Fp8Einsum(nn.Module):
Collaborator

Nit: we already have the Quantization superclass above. What if we added an einsum method to it and made this logic part of the einsum implementation within Fp8Quantization? That way we could avoid some of the branching we need in linears.py above.

Author

Emm, I tried to move this logic into Fp8Quantization, but the tricky part is that we need to maintain some variables, so I used this inherited nn.Module instead. Whenever I moved the logic into Fp8Quantization, I got errors like TypeError: Can't call __hash__ on modules that hold variables.
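For context, here is a minimal, hypothetical sketch (not this PR's actual code; the names and the scaling recipe are illustrative) of why the FP8 einsum wants to be its own nn.Module: the FP8 recipe has to carry scale state across steps, that state lives in Flax variables, and a module that holds variables can no longer be hashed or stored as a plain attribute on the quantization config.

# Hedged sketch only; illustrative names, not the PR's implementation.
from typing import Any

import jax
import jax.numpy as jnp
from flax import linen as nn


class Fp8EinsumSketch(nn.Module):
  """Einsum with both operands fake-quantized to FP8 via stateful per-tensor scales."""

  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, eqn: str, lhs: jax.Array, rhs: jax.Array) -> jax.Array:
    # The state that forces this to be a module: scales live in a Flax variable
    # collection so they persist across steps (a real delayed-scaling setup would
    # also track an amax history and update these scales from it).
    lhs_scale = self.variable("fp8_params", "lhs_scale", lambda: jnp.ones((), self.dtype))
    rhs_scale = self.variable("fp8_params", "rhs_scale", lambda: jnp.ones((), self.dtype))

    # Quantize-dequantize around the einsum; XLA can pattern-match this into FP8 matmuls.
    e4m3 = jnp.float8_e4m3fn
    lhs_q = (lhs / lhs_scale.value).astype(e4m3).astype(self.dtype) * lhs_scale.value
    rhs_q = (rhs / rhs_scale.value).astype(e4m3).astype(self.dtype) * rhs_scale.value
    return jnp.einsum(eqn, lhs_q, rhs_q)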

Collaborator

Oh I think we can keep this Fp8Einsum class. I was thinking we could just have the einsum method on Fp8Quantization return an instance of this Fp8Einsum class, instead of the if/else logic that is currently present in linears.py.

Not a big deal if it doesn't work; I don't want to hold up this feature on stylistic code stuff that we can clean up later.
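To make that concrete, here is a rough, hypothetical sketch of the dispatch (stand-in names only; the real classes are Quantization and Fp8Quantization in quantizations.py, and it reuses the Fp8EinsumSketch module from the sketch above):

# Hedged sketch of the proposed factory pattern; not the PR's exact code.
import jax.numpy as jnp


class QuantizationSketch:
  """Stand-in for the Quantization base class: every backend exposes einsum()."""

  def einsum(self, dtype=jnp.float32):
    raise NotImplementedError


class Fp8QuantizationSketch(QuantizationSketch):
  """Stand-in for Fp8Quantization: hand back the stateful FP8 einsum module."""

  def einsum(self, dtype=jnp.float32):
    # Fp8EinsumSketch is the Flax module from the earlier sketch in this thread.
    return Fp8EinsumSketch(dtype=dtype)


# The calling code in linears.py could then pick its einsum without
# isinstance() branching, along the lines of:
#   einsum_op = self.quant.einsum(self.dtype) if self.quant else jnp.einsum

The point of the factory is that linears.py only ever asks self.quant for an einsum, and each Quantization subclass decides which callable to hand back.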

Author

Yes, this can be done. PTAL.

Author

kaixih commented Feb 5, 2025

> LGTM! Do you have any end-to-end tests about perf change?
>
> cc @michelle-yooh we could have some internal tests in the future.

Yes, we have tested with MaxText/configs/models/mixtral-8x7b.yml on one 8xH100 GPU node (with base_num_decoder_layers reduced to 4 to fit); it shows about a 30-40% perf benefit.

@@ -583,7 +583,11 @@ def aqt_einsum(*args, **kwargs):
  # simply skip kwargs, since aqt einsum doesn't support any kwargs like precision
  return self.quant.einsum(rhs_mesh_axes)(*args)

einsum_op = aqt_einsum
# We need a separate way to retrieve the quantized einsum for fp8 config.
if isinstance(self.quant, quantizations.Fp8Quantization):
Collaborator

So what I'm thinking here is we could instead have:

def quant_einsum(*args, **kwargs):
  # or something similar depending on how we structure the einsum method on Quantization
  return self.quant.einsum(rhs_mesh_axes)(*args)

einsum_op = quant_einsum

Collaborator

The Fp8Quantization impl of einsum can just return Fp8Einsum

Collaborator

+1 to Anfal's comments. That's a great idea!

We can do it in a followup PR if we want to get this feature in soon.

@@ -201,6 +204,87 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
"""Returns dot_general configured with aqt params."""
return nn.Fp8DotGeneralOp

def einsum(self, dtype: DType = jnp.float32):
return Fp8Einsum(dtype=dtype)
Collaborator

This looks great! One nit: could you add einsum as an abstract method to the Quantization class? For consistency? Thanks!

Author

Done. PTAL.

@@ -60,6 +63,10 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
"""Placeholder for dot_general implementation in subclasses."""
pass

def einsum(self, dtype: DType = jnp.float32):
Collaborator

We'll probably want to generify the arguments for this abstract method later, but this is good for now!
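Purely as a hedged illustration of what that later generification might look like (the mesh_axes parameter and **kwargs below are hypothetical, not part of this PR; mesh_axes just mirrors the existing dot_general_cls signature):

# Illustrative sketch only; the extra parameters are hypothetical additions.
from abc import ABC, abstractmethod
from typing import Tuple

import jax.numpy as jnp


class QuantizationBaseSketch(ABC):
  """Stand-in for the Quantization base class."""

  @abstractmethod
  def einsum(self, dtype=jnp.float32, mesh_axes: Tuple[str, ...] = (), **kwargs):
    """Return an einsum-like callable configured for this quantization backend."""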

Author

kaixih commented Feb 12, 2025

@anfals Anything else we can do on our side to help merge?

Collaborator

anfals commented Feb 12, 2025

> @anfals Anything else we can do on our side to help merge?

Can you squash all your commits into 1? Then I can kick off the workflows and @yangyuwei or @gobbleturk can give code owner approval :). Thank you!

Author

kaixih commented Feb 12, 2025

Done. PTAL. @anfals

@michelle-yooh
Collaborator

Hi @kaixih, I tried running a mixtral 8x7b job with quantization: fp8 off this PR and faced AttributeError: 'Fp8Quantization' object has no attribute 'quant_dg'
Is this some missing attribute that needs to be added to the PR?

Author

kaixih commented Feb 12, 2025

@michelle-yooh Can you try it with sparse_matmul: False (in MaxText/configs/models/mixtral-8x7b.yml or command line)?

Author

kaixih commented Feb 13, 2025

@anfals Rebase is done. Can you help with the merge?

Collaborator

anfals commented Feb 13, 2025

> @anfals Rebase is done. Can you help with the merge?

@michelle-yooh Once you have verified things on your end, can you ping me? I can help with the merge. Thanks!

Author

kaixih commented Feb 20, 2025

Gentle ping. Any update? @anfals @michelle-yooh

Collaborator

anfals commented Feb 20, 2025

> Gentle ping. Any update? @anfals @michelle-yooh

I started up the workflow checks. @michelle-yooh, if you've verified this fixes the issues you were facing, I can merge this today. Thanks!

Collaborator

anfals commented Feb 20, 2025

> Gentle ping. Any update? @anfals @michelle-yooh
>
> I started up the workflow checks. @michelle-yooh, if you've verified this fixes the issues you were facing, I can merge this today. Thanks!

We talked about this in the sync, and @michelle-yooh said to merge it in; she'll run some tests and follow up when she gets bandwidth.

Looks like the linter is failing. @kaixih could you try running

pyink MaxText --diff --color --pyink-indentation=2 --line-length=125 or something like that?

I'll see if I can update the description to get that other check working. Thanks!

Minimize branches

Add abstract method

Format
Author

kaixih commented Feb 20, 2025

Manually fixed the format issues. (Surprised that we now use 125 as the line-length max instead of 80). Hope the tests look good now.
