[NVIDIA] Support FP8 quantization for MOE layers #1221

Open · wants to merge 2 commits into base: main
4 changes: 3 additions & 1 deletion MaxText/layers/linears.py
@@ -581,7 +581,9 @@ def get_einsum(self, rhs_mesh_axes: Tuple[Optional[str], ...] = (), einsum_name=

      def aqt_einsum(*args, **kwargs):
        # simply skip kwargs, since aqt einsum doesn't support any kwargs like precision
-       return self.quant.einsum(rhs_mesh_axes)(*args)
+       is_aqt = not isinstance(self.quant, quantizations.Fp8Quantization)
+       kw = {"mesh_axes": rhs_mesh_axes} if is_aqt else {"dtype": self.dtype}
+       return self.quant.einsum(**kw)(*args)

      einsum_op = aqt_einsum
    else:
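For orientation, the dispatch introduced above can be read in isolation as the sketch below. This is an illustrative snippet, not part of the PR: the helper name make_quant_einsum, the import path, and the default dtype are assumptions, while the mesh_axes/dtype keyword split and quantizations.Fp8Quantization come from the change itself.

```python
# Illustrative sketch only (not in the PR): how the einsum op is selected
# once Fp8Quantization exposes einsum(dtype=...) alongside the AQT path.
import jax.numpy as jnp
from MaxText.layers import quantizations  # assumed import path


def make_quant_einsum(quant, rhs_mesh_axes=(), dtype=jnp.bfloat16):
  """Returns a callable einsum for either an AQT or an FP8 quantizer."""
  if isinstance(quant, quantizations.Fp8Quantization):
    # New in this PR: returns an Fp8Einsum module parameterized by dtype.
    return quant.einsum(dtype=dtype)
  # Existing AQT path: configured with the rhs sharding axes.
  return quant.einsum(mesh_axes=rhs_mesh_axes)
```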
72 changes: 72 additions & 0 deletions MaxText/layers/quantizations.py
@@ -25,6 +25,8 @@
from aqt.jax.v2 import calibration
import common_types
from dataclasses import dataclass
from flax.linen import fp8_ops
from flax.linen import initializers as flax_initializers
import flax.linen as nn
import jax
import jax.numpy as jnp
@@ -45,6 +47,7 @@

Array = common_types.Array
Config = common_types.Config
DType = common_types.DType
AxisIdxes = common_types.AxisIdxes
AxisNames = common_types.AxisNames
CACHE_HEADS = common_types.CACHE_HEADS
@@ -60,6 +63,10 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
"""Placeholder for dot_general implementation in subclasses."""
pass

def einsum(self, dtype: DType = jnp.float32):

Collaborator: we'll prob want to generify the arguments for this abstract method later, but this is good for now!

"""Placeholder for einsum implementation in subclasses."""
pass


def _tiling_fn(lhs, rhs, dimension_numbers, tile_size):
  del lhs, rhs
@@ -201,6 +208,71 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
"""Returns dot_general configured with aqt params."""
return nn.Fp8DotGeneralOp

def einsum(self, dtype: DType = jnp.float32):
return Fp8Einsum(dtype=dtype)

Collaborator: This looks great! One nit: could you add einsum as an abstract method to the Quantization class? For consistency? Thanks!

Author: Done. PTAL.


class Fp8Einsum(nn.Module):

Collaborator: Nit: we already have the Quantization superclass above. What if we added an einsum method to that and then added this logic as part of the einsum implementation within Fp8Quantization? We could avoid some of the branching we need in linears.py above that way.

Author: Hmm, I tried to move this logic to Fp8Quantization, but the tricky part is that we need to maintain some variables, so I used this inherited nn.Module. When I moved the logic into Fp8Quantization I kept getting errors like TypeError: Can't call __hash__ on modules that hold variables.

Collaborator: Oh, I think we can keep this Fp8Einsum class. I was thinking we could just have the einsum method on Fp8Quantization return an instance of this Fp8Einsum class, instead of the if/else logic that is currently present in linears.py. Not a big deal if it doesn't work; I don't want to hold up this feature on stylistic code cleanup that we can do later.

Author: Yes, this can be done. PTAL.

"""An fp8 einsum op.

Attributes:
amax_history_length: size of the amax history.
e4m3_dtype: e4m3 variants, e.g., e4m3fn, e4m3fnuz.
e5m2_dtype: e5m2 variants, e.g., e5m2, e5m2fnuz.
dtype: computation dtype.
"""

amax_history_length: int = 1024
e4m3_dtype: DType = jnp.float8_e4m3fn
e5m2_dtype: DType = jnp.float8_e5m2
dtype: DType = jnp.float32

def setup(self) -> None:
scale_args = (
flax_initializers.ones_init(),
jax.random.PRNGKey(0),
(1,),
jnp.float32,
)
amax_history_args = (
flax_initializers.zeros_init(),
jax.random.PRNGKey(0),
(self.amax_history_length,),
jnp.float32,
)

OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
self.input_amax_history = self.variable(OVERWRITE_WITH_GRADIENT, "input_amax_history", *amax_history_args)
self.kernel_amax_history = self.variable(OVERWRITE_WITH_GRADIENT, "kernel_amax_history", *amax_history_args)
self.output_grad_amax_history = self.variable(OVERWRITE_WITH_GRADIENT, "output_grad_amax_history", *amax_history_args)

self.input_scale = self.variable(OVERWRITE_WITH_GRADIENT, "input_scale", *scale_args)
self.kernel_scale = self.variable(OVERWRITE_WITH_GRADIENT, "kernel_scale", *scale_args)
self.output_grad_scale = self.variable(OVERWRITE_WITH_GRADIENT, "output_grad_scale", *scale_args)

def __call__(self, eqn, *args, **kwargs):
assert len(args) == 2
x = args[0]
k = args[1]

comp_dtype = self.dtype
k = jnp.asarray(k, comp_dtype)
x = jnp.asarray(x, comp_dtype)

x_qdq = fp8_ops.in_qdq(comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value)
k_qdq = fp8_ops.in_qdq(comp_dtype, self.e4m3_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value)

y_qdq = jnp.einsum(eqn, x_qdq, k_qdq, _dot_general=fp8_ops.dot_general_with_precision)

y = fp8_ops.out_qdq(
comp_dtype,
self.e5m2_dtype,
y_qdq,
self.output_grad_scale.value,
self.output_grad_amax_history.value,
)
return y


def _get_int8_quant_config(config):
  drhs_bits = None
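As a closing note, the new Fp8Einsum module can be exercised on its own. The following is a minimal, hypothetical usage sketch; the import path, shapes, and einsum equation are assumptions, while the __call__(eqn, x, k) signature and the "_overwrite_with_gradient" collection come from the diff above.

```python
# Hypothetical standalone usage of Fp8Einsum (not part of the PR).
import jax
import jax.numpy as jnp
from MaxText.layers.quantizations import Fp8Einsum  # assumed import path

einsum = Fp8Einsum(dtype=jnp.bfloat16)
x = jnp.ones((4, 16), jnp.bfloat16)  # activations
k = jnp.ones((16, 8), jnp.bfloat16)  # kernel

# init creates the FP8 scales and amax histories in the
# "_overwrite_with_gradient" collection, separate from "params".
variables = einsum.init(jax.random.PRNGKey(0), "bd,df->bf", x, k)

# Forward pass: inputs and kernel are quantize-dequantized to e4m3 with
# delayed scaling; the e5m2 out_qdq only affects the output gradient path.
y = einsum.apply(variables, "bd,df->bf", x, k)
print(y.shape)  # (4, 8)
```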