[NVIDIA] Support FP8 quantization for MOE layers #1221
Conversation
LGTM! Do you have any end-to-end tests for the perf change?
cc @michelle-yooh we could have some internal tests in the future.
@@ -202,6 +205,84 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    return nn.Fp8DotGeneralOp


class Fp8Einsum(nn.Module):
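For readers who only see the truncated hunk above, a minimal sketch of what an FP8 einsum wrapper of this shape can look like is below. It uses per-tensor fake quantization and assumed names (Fp8EinsumSketch, fp8_params, input_scale, kernel_scale); the class under review likely also tracks amax history and delayed-scaling state, which is omitted here.

import jax.numpy as jnp
from flax import linen as nn


def _quantize_dequantize(x, q_dtype, scale):
  # Fake-quantize x to q_dtype with a per-tensor scale, then cast back.
  dtype_max = float(jnp.finfo(q_dtype).max)
  clipped = jnp.clip(x / scale, -dtype_max, dtype_max)
  return clipped.astype(q_dtype).astype(x.dtype) * scale


class Fp8EinsumSketch(nn.Module):
  """Illustrative einsum that fake-quantizes both operands to FP8."""
  dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self, eqn: str, lhs, rhs):
    # Per-tensor scales kept as mutable variables so they could be updated
    # from amax statistics during training (the update logic is omitted).
    input_scale = self.variable(
        "fp8_params", "input_scale", lambda: jnp.ones((), self.dtype))
    kernel_scale = self.variable(
        "fp8_params", "kernel_scale", lambda: jnp.ones((), self.dtype))
    q_lhs = _quantize_dequantize(lhs, jnp.float8_e4m3fn, input_scale.value)
    q_rhs = _quantize_dequantize(rhs, jnp.float8_e4m3fn, kernel_scale.value)
    return jnp.einsum(eqn, q_lhs, q_rhs).astype(self.dtype)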
Nit: we already have the Quantization superclass above. What if we added an einsum method to it and then added this logic as the einsum implementation within Fp8Quantization? We could avoid some of the branching we need in linears up above that way.
Hmm, I tried to move this logic into Fp8Quantization, but the tricky part is that we need to maintain some variables, so I used this nn.Module subclass instead. Whenever I moved the logic into Fp8Quantization, I got errors like TypeError: Can't call __hash__ on modules that hold variables.
Oh, I think we can keep this Fp8Einsum class. I was thinking we could just have the einsum method on Fp8Quantization return an instance of this Fp8Einsum class, instead of the if/else logic that is currently present in linears.py.
Not a big deal if it doesn't work - I don't want to hold up this feature on stylistic code stuff that we can clean up later.
Yes, this can be done. PTAL.
Yes, we have tested on
MaxText/layers/linears.py (Outdated)
@@ -583,7 +583,11 @@ def aqt_einsum(*args, **kwargs):
    # simply skip kwargs, since aqt einsum doesn't support any kwargs like precision
    return self.quant.einsum(rhs_mesh_axes)(*args)

  einsum_op = aqt_einsum
  # We need a separate way to retrieve the quantized einsum for fp8 config.
  if isinstance(self.quant, quantizations.Fp8Quantization):
So what I'm thinking here is we could instead have

def quant_einsum(*args, **kwargs):
  # or something similar, depending on how we structure the einsum method on Quantization
  return self.quant.einsum(rhs_mesh_axes)(*args)

einsum_op = quant_einsum
The Fp8Quantization impl of einsum can just return Fp8Einsum.
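To make the proposal concrete, here is a rough, self-contained sketch of the dispatch being discussed; the class and method names follow the conversation, but the bodies are placeholders rather than the code in this PR.

import jax.numpy as jnp
from flax import linen as nn


class Quantization:
  """Base quantization config; subclasses provide the quantized ops."""

  def einsum(self, dtype=jnp.float32):
    raise NotImplementedError("Subclasses provide a quantized einsum.")


class Fp8Einsum(nn.Module):
  """Stand-in for the real module, which also holds FP8 scaling variables."""
  dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self, eqn, lhs, rhs):
    return jnp.einsum(eqn, lhs, rhs).astype(self.dtype)


class Fp8Quantization(Quantization):
  def einsum(self, dtype=jnp.float32):
    return Fp8Einsum(dtype=dtype)

With that in place, the call site in linears.py only needs something like einsum_op = self.quant.einsum(self.dtype) whenever a quantization config is set, with no isinstance check.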
+1 to Anfal's comments. That's a great idea!
We can do it in a followup PR if we want to get this feature in soon.
@@ -201,6 +204,87 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns dot_general configured with aqt params."""
    return nn.Fp8DotGeneralOp

  def einsum(self, dtype: DType = jnp.float32):
    return Fp8Einsum(dtype=dtype)
This looks great! One nit: could you add einsum as an abstract method to the Quantization class, for consistency? Thanks!
Done. PTAL.
@@ -60,6 +63,10 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Placeholder for dot_general implementation in subclasses."""
    pass

  def einsum(self, dtype: DType = jnp.float32):
We'll probably want to generalize the arguments of this abstract method later, but this is good for now!
@anfals Anything else we can do on our side to help merge?
Can you squash all your commits into 1? Then I can kick off the workflows and @yangyuwei or @gobbleturk can give code owner approval :). Thank you!
Done. PTAL. @anfals
Hi @kaixih, I tried running a mixtral 8x7b job with quantization: fp8 off this PR and faced
@michelle-yooh Can you try it with
@anfals The rebase is done. Can you help with the merge?
@michelle-yooh Once you have verified things on your end, can you ping me? I can help with the merge. Thanks!
Gentle ping. Any update? @anfals @michelle-yooh
I started the workflow checks. @michelle-yooh if you verified this fixes the issues you were facing, I can merge this today. Thanks!
We talked about this in the sync and @michelle-yooh said to merge it in; she'll run some tests when she gets bandwidth and follow up. Looks like the linter is failing. @kaixih could you try running
I'll see if I can update the description to get that other check working. Thanks!
Minimize branches Add abstract method Format
Manually fixed the format issues. (Surprised that we now use 125 as the line-length max instead of 80.) Hope the tests look good now.
Description
This PR enables MoE with Fp8 quantization for NVIDIA GPUs.
cc. @kocchop @nouiz
Tests
Tested on H100
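For reference, a quick numerical sanity check of an FP8 fake-quantized einsum (not one of the tests actually run for this PR; the shapes, names, and tolerance below are assumptions) could look like:

import jax
import jax.numpy as jnp


def fake_quant(x, q_dtype=jnp.float8_e4m3fn):
  # Per-tensor scale from the tensor's max magnitude; round-trip through FP8.
  scale = jnp.max(jnp.abs(x)) / float(jnp.finfo(q_dtype).max)
  return (x / scale).astype(q_dtype).astype(x.dtype) * scale


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (4, 128, 256), jnp.float32)
w = jax.random.normal(key_w, (256, 512), jnp.float32)

ref = jnp.einsum("bld,df->blf", x, w)
out = jnp.einsum("bld,df->blf", fake_quant(x), fake_quant(w))

rel_err = jnp.linalg.norm(out - ref) / jnp.linalg.norm(ref)
assert rel_err < 0.1, f"relative error too high: {rel_err}"
print("relative error:", rel_err)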
Checklist
Before submitting this PR, please make sure (put X in square brackets):