Support fp8 quant for MoE layer
Minimize branches

Add abstract method
kaixih committed Feb 12, 2025
1 parent d0270c7 commit 46cfefa
Showing 2 changed files with 91 additions and 1 deletion.
4 changes: 3 additions & 1 deletion MaxText/layers/linears.py
@@ -581,7 +581,9 @@ def get_einsum(self, rhs_mesh_axes: Tuple[Optional[str], ...] = (), einsum_name=

       def aqt_einsum(*args, **kwargs):
         # Simply skip kwargs, since the AQT einsum doesn't support kwargs like precision.
-        return self.quant.einsum(rhs_mesh_axes)(*args)
+        is_aqt = not isinstance(self.quant, quantizations.Fp8Quantization)
+        kw = {"mesh_axes": rhs_mesh_axes} if is_aqt else {"dtype": self.dtype}
+        return self.quant.einsum(**kw)(*args)

       einsum_op = aqt_einsum
     else:
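To make the dispatch above concrete, here is a standalone sketch (the two tiny classes are hypothetical stand-ins, not the real AqtQuantization/Fp8Quantization): AQT einsum factories are keyed by sharding mesh axes, while the fp8 factory only needs a compute dtype, so the keyword arguments differ per backend.

import jax.numpy as jnp

class _AqtLike:
  def einsum(self, mesh_axes=()):
    # AQT-style factory: the returned einsum is configured by mesh axes.
    return lambda eqn, x, k: jnp.einsum(eqn, x, k)

class _Fp8Like:
  def einsum(self, dtype=jnp.float32):
    # Fp8-style factory: the returned einsum is configured by compute dtype.
    return lambda eqn, x, k: jnp.einsum(eqn, x.astype(dtype), k.astype(dtype))

def pick_einsum(quant, rhs_mesh_axes=(), dtype=jnp.bfloat16):
  # Mirrors the branch in get_einsum: choose kwargs by quantization backend.
  is_aqt = not isinstance(quant, _Fp8Like)
  kw = {"mesh_axes": rhs_mesh_axes} if is_aqt else {"dtype": dtype}
  return quant.einsum(**kw)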
88 changes: 88 additions & 0 deletions MaxText/layers/quantizations.py
Expand Up @@ -25,6 +25,8 @@
 from aqt.jax.v2 import calibration
 import common_types
 from dataclasses import dataclass
+from flax.linen import fp8_ops
+from flax.linen import initializers as flax_initializers
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -45,6 +47,7 @@

 Array = common_types.Array
 Config = common_types.Config
+DType = common_types.DType
 AxisIdxes = common_types.AxisIdxes
 AxisNames = common_types.AxisNames
 CACHE_HEADS = common_types.CACHE_HEADS
@@ -60,6 +63,10 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
"""Placeholder for dot_general implementation in subclasses."""
pass

def einsum(self, dtype: DType = jnp.float32):
"""Placeholder for einsum implementation in subclasses."""
pass


 def _tiling_fn(lhs, rhs, dimension_numbers, tile_size):
   del lhs, rhs
@@ -201,6 +208,87 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
"""Returns dot_general configured with aqt params."""
return nn.Fp8DotGeneralOp

+  def einsum(self, dtype: DType = jnp.float32):
+    return Fp8Einsum(dtype=dtype)


+class Fp8Einsum(nn.Module):
+  """An fp8 einsum op.
+
+  Attributes:
+    amax_history_length: size of the amax history.
+    e4m3_dtype: e4m3 variants, e.g., e4m3fn, e4m3fnuz.
+    e5m2_dtype: e5m2 variants, e.g., e5m2, e5m2fnuz.
+    dtype: computation dtype.
+  """
+
+  amax_history_length: int = 1024
+  e4m3_dtype: DType = jnp.float8_e4m3fn
+  e5m2_dtype: DType = jnp.float8_e5m2
+  dtype: DType = jnp.float32
+
+  def setup(self) -> None:
+    scale_args = (
+        flax_initializers.ones_init(),
+        jax.random.PRNGKey(0),
+        (1,),
+        jnp.float32,
+    )
+    amax_history_args = (
+        flax_initializers.zeros_init(),
+        jax.random.PRNGKey(0),
+        (self.amax_history_length,),
+        jnp.float32,
+    )
+
+    OVERWRITE_WITH_GRADIENT = '_overwrite_with_gradient'
+    self.input_amax_history = self.variable(
+        OVERWRITE_WITH_GRADIENT, 'input_amax_history', *amax_history_args
+    )
+    self.kernel_amax_history = self.variable(
+        OVERWRITE_WITH_GRADIENT, 'kernel_amax_history', *amax_history_args
+    )
+    self.output_grad_amax_history = self.variable(
+        OVERWRITE_WITH_GRADIENT, 'output_grad_amax_history', *amax_history_args
+    )
+
+    self.input_scale = self.variable(
+        OVERWRITE_WITH_GRADIENT, 'input_scale', *scale_args
+    )
+    self.kernel_scale = self.variable(
+        OVERWRITE_WITH_GRADIENT, 'kernel_scale', *scale_args
+    )
+    self.output_grad_scale = self.variable(
+        OVERWRITE_WITH_GRADIENT, 'output_grad_scale', *scale_args
+    )
+
+  def __call__(self, eqn, *args, **kwargs):
+    assert len(args) == 2
+    x = args[0]
+    k = args[1]
+
+    comp_dtype = self.dtype
+    k = jnp.asarray(k, comp_dtype)
+    x = jnp.asarray(x, comp_dtype)
+
+    x_qdq = fp8_ops.in_qdq(
+        comp_dtype, self.e4m3_dtype, x, self.input_scale.value,
+        self.input_amax_history.value
+    )
+    k_qdq = fp8_ops.in_qdq(
+        comp_dtype, self.e4m3_dtype, k, self.kernel_scale.value,
+        self.kernel_amax_history.value
+    )
+
+    y_qdq = jnp.einsum(
+        eqn, x_qdq, k_qdq, _dot_general=fp8_ops.dot_general_with_precision
+    )
+
+    y = fp8_ops.out_qdq(
+        comp_dtype, self.e5m2_dtype, y_qdq, self.output_grad_scale.value,
+        self.output_grad_amax_history.value,
+    )
+    return y


 def _get_int8_quant_config(config):
   drhs_bits = None
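For readers trying the new module outside the diff, a minimal usage sketch follows. The shapes, the eqn string, and the harness are my own illustration of a MoE-style einsum, not code from this commit:

import jax
import jax.numpy as jnp
# Assumes Fp8Einsum from MaxText/layers/quantizations.py above is importable.

einsum = Fp8Einsum(dtype=jnp.bfloat16)
x = jnp.ones((4, 16, 32), jnp.bfloat16)   # activations: (batch, seq, model)
k = jnp.ones((8, 32, 64), jnp.bfloat16)   # expert weights: (expert, model, hidden)
variables = einsum.init(jax.random.PRNGKey(0), "bsm,emh->bseh", x, k)
y = einsum.apply(variables, "bsm,emh->bseh", x, k)  # -> shape (4, 16, 8, 64)
# Note: the scales and amax histories live in the '_overwrite_with_gradient'
# collection, so a training loop must carry that collection alongside the
# usual 'params' and overwrite it with its "gradients" rather than apply
# optimizer updates to it.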

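For intuition about what fp8_ops.in_qdq does with the scale and amax-history variables, here is a rough sketch of the delayed-scaling recipe (my own illustration of the general technique, not Flax's exact implementation):

import jax.numpy as jnp

def delayed_scaling_qdq(x, amax_history, q_dtype=jnp.float8_e4m3fn):
  # Derive a scale from the recorded amax history (delayed scaling).
  dtype_max = float(jnp.finfo(q_dtype).max)
  amax = jnp.max(amax_history)
  scale = jnp.maximum(amax, 2.0 ** -10) / dtype_max
  # Quantize-dequantize: snap values onto the fp8 grid, then rescale back.
  x_q = jnp.clip(x / scale, -dtype_max, dtype_max).astype(q_dtype)
  x_qdq = x_q.astype(x.dtype) * scale
  # Record the current amax into the rolling history for the next step.
  new_history = jnp.roll(amax_history, 1).at[0].set(jnp.max(jnp.abs(x)))
  return x_qdq, scale, new_history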