diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index 17a76a750d..c2eb66960f 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -11,6 +11,7 @@
 import torch
 import torch.nn as nn
 
+from torchao.prototype.mx_formats.config import MXLinearConfig
 from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES
 from torchao.prototype.mx_formats.mx_linear import (
     MXInferenceLinear,
@@ -59,8 +60,13 @@ def test_linear_eager(elem_dtype, bias, input_shape):
         nn.Linear(8, 6, bias=bias, device="cuda"),
     )
     m_mx = copy.deepcopy(m)
-    block_size = 2
-    swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size)
+    config = MXLinearConfig(
+        block_size=2,
+        elem_dtype=elem_dtype[0],
+        elem_dtype_weight_override=elem_dtype[1],
+        elem_dtype_grad_output_override=elem_dtype[2],
+    )
+    swap_linear_with_mx_linear(m_mx, config=config)
 
     x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
     x = copy.deepcopy(x_ref)
@@ -97,8 +103,8 @@ def test_activation_checkpointing():
         nn.Linear(4, 6, bias=True, device="cuda"),
         nn.Linear(6, 6, bias=True, device="cuda"),
     )
-    block_size = 2
-    swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size)
+    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    swap_linear_with_mx_linear(m, config=config)
 
     x = torch.randn(*input_shape, device="cuda").requires_grad_()
     g = torch.randn(*grad_shape, device="cuda")
@@ -133,8 +139,8 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
     m_mx = nn.Sequential(
         nn.Linear(K, N, bias=bias, device="cuda"),
     )
-    block_size = 2
-    swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size)
+    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    swap_linear_with_mx_linear(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
 
@@ -181,8 +187,8 @@ def test_inference_linear(elem_dtype, bias, input_shape):
     m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16))
     m = m.cuda()
     m_mx = copy.deepcopy(m)
-    block_size = 2
-    swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)
+    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    swap_linear_with_mx_inference_linear(m_mx, config=config)
 
     x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
     y_ref = m(x)
@@ -209,8 +215,8 @@ def test_inference_compile_simple(elem_dtype):
     m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16))
     m = m.cuda()
     m_mx = copy.deepcopy(m)
-    block_size = 2
-    swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)
+    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    swap_linear_with_mx_inference_linear(m_mx, config=config)
     m_mx = torch.compile(m_mx, fullgraph="true")
 
     x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16)
@@ -223,20 +229,6 @@ def test_inference_compile_simple(elem_dtype):
     assert sqnr >= 13.5
 
 
-def test_mx_linear_input_weight_gradient_dtypes():
-    m = nn.Sequential(nn.Linear(32, 32))
-    swap_linear_with_mx_linear(m, *SUPPORTED_ELEM_DTYPES[:3], block_size=32)
-    assert m[0].in_elem_dtype == SUPPORTED_ELEM_DTYPES[0]
-    assert m[0].w_elem_dtype == SUPPORTED_ELEM_DTYPES[1]
-    assert m[0].grad_elem_dtype == SUPPORTED_ELEM_DTYPES[2]
-
-    m = nn.Sequential(nn.Linear(32, 32))
-    swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)
-    assert m[0].in_elem_dtype == torch.float8_e4m3fn
-    assert m[0].w_elem_dtype == torch.float8_e4m3fn
-    assert m[0].grad_elem_dtype == torch.float8_e4m3fn
-
-
 def test_filter_fn():
     m1 = nn.Sequential(
         nn.Linear(32, 32),
@@ -245,12 +237,11 @@ def test_filter_fn():
     m2 = copy.deepcopy(m1)
     filter_fn = lambda mod, fqn: fqn != "1"  # noqa: E731
 
-    swap_linear_with_mx_linear(
-        m1, torch.float8_e4m3fn, block_size=32, filter_fn=filter_fn
-    )
+    config = MXLinearConfig(block_size=32)
+    swap_linear_with_mx_linear(m1, config=config, filter_fn=filter_fn)
     assert type(m1[0]) == MXLinear
     assert type(m1[1]) == torch.nn.Linear
 
-    swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn)  # noqa: E501
+    swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn)  # noqa: E501
     assert type(m2[0]) == MXInferenceLinear
     assert type(m2[1]) == torch.nn.Linear
diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md
index 32f45e3755..09e7563ebb 100644
--- a/torchao/prototype/mx_formats/README.md
+++ b/torchao/prototype/mx_formats/README.md
@@ -41,10 +41,11 @@ This is a module to do MX training, the MX matmul is currently emulated.
 
 ```python
 from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
+from torchao.prototype.mx_formats.config import MXLinearConfig
 
 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
-elem_dtype = torch.float8_e4m3fn
-swap_linear_with_mx_linear(m, elem_dtype, block_size=32)
+config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
+swap_linear_with_mx_linear(m, config=config)
 
 # training loop (not shown)
 ```
@@ -55,11 +56,11 @@ This is a module to do MX inference, weights are in MX and matmul is in high pre
 
 ```python
 from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear
+from torchao.prototype.mx_formats.config import MXLinearConfig
 
 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
-elem_dtype = torch.float8_e4m3fn
-block_size = 32
-swap_linear_with_mx_inference_linear(m, elem_dtype, block_size)
+config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
+swap_linear_with_mx_inference_linear(m, config=config)
 
 # do inference (not shown)
 ```
diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py
index 7b68b5b6a5..7cdf2d4e58 100644
--- a/torchao/prototype/mx_formats/config.py
+++ b/torchao/prototype/mx_formats/config.py
@@ -5,9 +5,40 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
+from typing import Any, Optional
+
+import torch
+
+from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES
 
 
 @dataclass
 class MXLinearConfig:
+    # block size for scaling, default is 32 to match
+    # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
+    # section 5.2
+    block_size: int = 32
+
+    # element dtype, used for activations, weights and gradients
+    elem_dtype: Any = torch.float8_e4m3fn
+
+    # overrides for element dtype for weights and gradients
+    # TODO(future PR): refactor to make this cleaner
+    elem_dtype_weight_override: Optional[Any] = None
+    elem_dtype_grad_output_override: Optional[Any] = None
+
     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False
+
+    def __post_init__(self):
+        assert (
+            self.elem_dtype in SUPPORTED_ELEM_DTYPES
+        ), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
+        if self.elem_dtype_weight_override is not None:
+            assert (
+                self.elem_dtype_weight_override in SUPPORTED_ELEM_DTYPES
+            ), f"elem_dtype_weight_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
+        if self.elem_dtype_grad_output_override is not None:
+            assert (
+                self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES
+            ), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py
index 72c2b6ab39..a38a8c5499 100644
--- a/torchao/prototype/mx_formats/mx_linear.py
+++ b/torchao/prototype/mx_formats/mx_linear.py
@@ -107,22 +107,11 @@ class MXLinear(torch.nn.Linear):
     def from_float(
         cls,
         mod,
-        elem_dtype,
-        elem_dtype_weight_override=None,
-        elem_dtype_grad_output_override=None,
-        *,
-        # TODO(next PR): move elem_dtype* and block size into config
-        config: MXLinearConfig = None,
-        block_size=32,
+        config: Optional[MXLinearConfig] = MXLinearConfig(),
     ):
+        # TODO(before land): remove this
+        assert isinstance(config, MXLinearConfig)
         mod.__class__ = MXLinear
-        mod.in_elem_dtype = elem_dtype
-        mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype
-        mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype
-        mod.block_size = block_size
-        # TODO(next PR): fix this
-        if config is None:
-            config = MXLinearConfig()
         mod.config = config
         return mod
 
@@ -135,13 +124,14 @@ def forward(self, x):
         else:
             w = self.weight
 
+        config = self.config
         y = mx_mm.apply(
             x,
             w,
-            self.in_elem_dtype,
-            self.w_elem_dtype,
-            self.grad_elem_dtype,
-            self.block_size,
+            config.elem_dtype,
+            config.elem_dtype_weight_override or config.elem_dtype,
+            config.elem_dtype_grad_output_override or config.elem_dtype,
+            config.block_size,
         )
         if self.bias is not None:
             y = y + self.bias
@@ -158,9 +148,11 @@ class MXInferenceLinear(torch.nn.Linear):
 
     @classmethod
     @torch.no_grad()
-    def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig):
-        # TODO(next PR): move elem_dtype and block_size into config
-
+    def from_float(
+        cls,
+        mod,
+        config: Optional[MXLinearConfig] = MXLinearConfig(),
+    ):
         with torch.device("meta"):
             super_kwargs = {
                 "in_features": mod.in_features,
@@ -171,10 +163,9 @@ def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig):
         # TODO(future PR): set to new_mod.weight directly, will need to work
         # through some errors
         new_mod.weight_mx = MXTensor.to_mx(
-            mod.weight, elem_dtype, block_size=block_size
+            mod.weight, config.elem_dtype, block_size=config.block_size
        )
         new_mod.bias = mod.bias
-        new_mod.elem_dtype = elem_dtype
         new_mod.config = config
         return new_mod
 
@@ -213,13 +204,8 @@ def _is_linear(mod, fqn):
 
 def swap_linear_with_mx_linear(
     model,
-    elem_dtype,
-    elem_dtype_weight_override=None,
-    elem_dtype_grad_output_override=None,
     *,
-    # TODO(next PR): move elem_dtype* and block_size into config
     config: Optional[MXLinearConfig] = None,
-    block_size=32,
     filter_fn=None,
 ):
     if filter_fn is None:
@@ -232,24 +218,16 @@ def __fn(mod, fqn):
         combined_filter_fn = __fn
     replace_with_custom_fn_if_matches_filter(
         model,
-        lambda mod: MXLinear.from_float(
-            mod,
-            elem_dtype,
-            elem_dtype_weight_override,
-            elem_dtype_grad_output_override,
-            config=config,
-            block_size=block_size,
-        ),
+        lambda mod: MXLinear.from_float(mod, config=config),
         combined_filter_fn,
     )
 
 
 def swap_linear_with_mx_inference_linear(
     model,
-    elem_dtype,
-    block_size,
-    filter_fn=None,
+    *,
     config: Optional[MXLinearConfig] = None,
+    filter_fn=None,
 ):
     if filter_fn is None:
         combined_filter_fn = _is_linear
@@ -261,8 +239,6 @@ def __fn(mod, fqn):
         combined_filter_fn = __fn
     replace_with_custom_fn_if_matches_filter(
         model,
-        lambda mod: MXInferenceLinear.from_float(
-            mod, elem_dtype, block_size, config=config
-        ),
+        lambda mod: MXInferenceLinear.from_float(mod, config=config),
         combined_filter_fn,
     )