[NVIDIA] Support FP8 quantization for MOE layers #1221
@@ -25,6 +25,8 @@
from aqt.jax.v2 import calibration
import common_types
from dataclasses import dataclass
from flax.linen import fp8_ops
from flax.linen import initializers as flax_initializers
import flax.linen as nn
import jax
import jax.numpy as jnp
@@ -45,6 +47,7 @@

Array = common_types.Array
Config = common_types.Config
DType = common_types.DType
AxisIdxes = common_types.AxisIdxes
AxisNames = common_types.AxisNames
CACHE_HEADS = common_types.CACHE_HEADS
@@ -60,6 +63,10 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Placeholder for dot_general implementation in subclasses."""
    pass

  def einsum(self, dtype: DType = jnp.float32):
    """Placeholder for einsum implementation in subclasses."""
    pass


def _tiling_fn(lhs, rhs, dimension_numbers, tile_size):
  del lhs, rhs
@@ -201,6 +208,71 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns dot_general configured with aqt params."""
    return nn.Fp8DotGeneralOp

  def einsum(self, dtype: DType = jnp.float32):
    return Fp8Einsum(dtype=dtype)
Review thread on this change:
Reviewer: This looks great! One nit: could you add
Author: Done. PTAL.
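For context, here is a hedged sketch of how a layer could consume this new hook; the toy module, the einsum equation, the kernel shape, and the fallback to jnp.einsum are illustrative assumptions, not code from this PR:

# Hedged sketch (not from this PR): a toy Flax layer that dispatches to the
# quantized einsum when its Quantization object provides one.
from typing import Any

class ToyExpertLayer(nn.Module):
  quant: Any = None            # e.g. an Fp8Quantization instance, or None
  features: int = 1024
  dtype: DType = jnp.bfloat16

  @nn.compact
  def __call__(self, x):
    kernel = self.param(
        "kernel",
        flax_initializers.lecun_normal(),
        (x.shape[-1], self.features),
        self.dtype,
    )
    eqn = "btd,df->btf"  # made-up equation for a dense projection
    if self.quant is not None:
      einsum_op = self.quant.einsum(dtype=self.dtype)  # Fp8Einsum for Fp8Quantization
      if einsum_op is not None:
        return einsum_op(eqn, x, kernel)
    return jnp.einsum(eqn, x, kernel)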
Review thread on the Fp8Einsum class:
Reviewer: Nit: we already have the Quantization superclass above. What if we added a
Author: Emm, I tried to move this logic to Fp8Quantization, but the tricky part is that we need to maintain some variables, so I use this inherited nn.Module. When moving the logic to Fp8Quantization, I always got errors like
Reviewer: Oh, I think we can keep this. Not a big deal if it doesn't work - I don't want to hold up this feature on stylistic code stuff that we can clean up later.
Author: Yes, this can be done. PTAL.

class Fp8Einsum(nn.Module):
  """An fp8 einsum op.

  Attributes:
    amax_history_length: size of the amax history.
    e4m3_dtype: e4m3 variants, e.g., e4m3fn, e4m3fnuz.
    e5m2_dtype: e5m2 variants, e.g., e5m2, e5m2fnuz.
    dtype: computation dtype.
  """

  amax_history_length: int = 1024
  e4m3_dtype: DType = jnp.float8_e4m3fn
  e5m2_dtype: DType = jnp.float8_e5m2
  dtype: DType = jnp.float32

  def setup(self) -> None:
    scale_args = (
        flax_initializers.ones_init(),
        jax.random.PRNGKey(0),
        (1,),
        jnp.float32,
    )
    amax_history_args = (
        flax_initializers.zeros_init(),
        jax.random.PRNGKey(0),
        (self.amax_history_length,),
        jnp.float32,
    )

    OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
    self.input_amax_history = self.variable(OVERWRITE_WITH_GRADIENT, "input_amax_history", *amax_history_args)
    self.kernel_amax_history = self.variable(OVERWRITE_WITH_GRADIENT, "kernel_amax_history", *amax_history_args)
    self.output_grad_amax_history = self.variable(OVERWRITE_WITH_GRADIENT, "output_grad_amax_history", *amax_history_args)

    self.input_scale = self.variable(OVERWRITE_WITH_GRADIENT, "input_scale", *scale_args)
    self.kernel_scale = self.variable(OVERWRITE_WITH_GRADIENT, "kernel_scale", *scale_args)
    self.output_grad_scale = self.variable(OVERWRITE_WITH_GRADIENT, "output_grad_scale", *scale_args)

  def __call__(self, eqn, *args, **kwargs):
    assert len(args) == 2
    x = args[0]
    k = args[1]

    comp_dtype = self.dtype
    k = jnp.asarray(k, comp_dtype)
    x = jnp.asarray(x, comp_dtype)

    x_qdq = fp8_ops.in_qdq(comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value)
    k_qdq = fp8_ops.in_qdq(comp_dtype, self.e4m3_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value)

    y_qdq = jnp.einsum(eqn, x_qdq, k_qdq, _dot_general=fp8_ops.dot_general_with_precision)

    y = fp8_ops.out_qdq(
        comp_dtype,
        self.e5m2_dtype,
        y_qdq,
        self.output_grad_scale.value,
        self.output_grad_amax_history.value,
    )
    return y
def _get_int8_quant_config(config):
  drhs_bits = None
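Stepping back from the diff: the new Fp8Einsum module can also be exercised on its own as a quick sanity check. This is a hedged sketch with a made-up equation and shapes, assuming the imports and class from the diff above are in scope:

# Hedged standalone sketch: initialize and apply Fp8Einsum directly.
fp8_einsum = Fp8Einsum(dtype=jnp.bfloat16)
x = jnp.ones((4, 128, 512), dtype=jnp.bfloat16)   # activations
k = jnp.ones((512, 1024), dtype=jnp.bfloat16)     # kernel
variables = fp8_einsum.init(jax.random.PRNGKey(0), "btd,df->btf", x, k)
y = fp8_einsum.apply(variables, "btd,df->btf", x, k)
# The fp8 scales and amax histories live in the "_overwrite_with_gradient"
# collection rather than "params", so a training loop must carry that
# collection along; flax.linen.fp8_ops updates its leaves by overwriting
# them with their "gradients", which encode the new scale/amax values.
print(jax.tree_util.tree_map(jnp.shape, variables))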
Reviewer comment on the abstract einsum method: we'll probably want to generify the arguments for this abstract method later, but this is good for now!
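A purely illustrative sketch of one direction that generalization could take; the keyword-argument passthrough is an assumption, not something proposed in this PR:

  # Illustrative only: the base-class placeholder could later accept
  # op-specific settings so subclasses can grow options without changing callers.
  def einsum(self, dtype: DType = jnp.float32, **op_kwargs):
    """Placeholder for einsum implementation in subclasses."""
    pass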