From c83452072fd2c228ca66e9861533af55010db326 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Wed, 29 Jan 2025 20:32:05 -0800
Subject: [PATCH 1/6] Update

[ghstack-poisoned]
---
 torchao/prototype/mx_formats/mx_tensor.py | 60 ++++++++++++++++++++---
 1 file changed, 54 insertions(+), 6 deletions(-)

diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 8eeeaf8bfd..1581628d58 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -16,6 +16,7 @@
 * Zeros: N/A
 """
 
+from enum import Enum, auto
 from typing import Dict, Union
 
 import torch
@@ -53,11 +54,38 @@
     unpack_uint4,
 )
 
+# TODO(later): read from somewhere else?
+SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
+EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
+EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
+EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
+EBITS_F8_E4M3, MBITS_F8_E4M3 = 4, 3
+EBITS_F8_E5M2, MBITS_F8_E5M2 = 5, 2
+
+
+class ScaleCalculationMode(Enum):
+    """
+    Enum representing the different methods for calculating MX block scaling.
+    There are three methods available:
+    FLOOR: This method is recommended by the OCP MX Spec 1.0 and uses X = 2^floor(log2(max_abs(v))-max_exp).
+      It can result in overflow issues for large values and is a poor choice for gradient quantization.
+    CEIL: This method avoids overflow issues, but small values may shift to 0 due to a large scaling factor.
+      It uses X = 2^ceil(log2(max_abs(v))-max_exp).
+    EVEN: This method is a trade-off between FLOOR and CEIL. It uses X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)).
+      It provides better accuracy for MX4 training compared to FLOOR and CEIL.
+    By default, we use the FLOOR method, as recommended by the OCP MX Spec 1.0.
+    """
+
+    FLOOR = auto()
+    CEIL = auto()
+    EVEN = auto()
+
 
 def to_mx(
     data_hp: torch.Tensor,
     elem_dtype: Union[torch.dtype, str],
     block_size: int,
+    scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
 ):
     """
     Takes a high precision tensor and converts to MX scale and raw data, in
@@ -88,25 +116,45 @@ def to_mx(
     # where the values are zero.
     eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)
 
-    # Find largest power of 2 less than or equal to max_abs.
- largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs + eps)) - # Set X to be the largest power-of-two less than or equal to # max_abs(v), divided by the largest power of two representable - # in the element data type + # in the element data type, and get the mbits at the same time if elem_dtype == torch.float8_e4m3fn: target_max_pow2 = F8E4M3_MAX_POW2 + mbits = MBITS_F8_E4M3 elif elem_dtype == torch.float8_e5m2: target_max_pow2 = F8E5M2_MAX_POW2 + mbits = MBITS_F8_E5M2 elif elem_dtype == DTYPE_FP6_E2M3: target_max_pow2 = F6_E2M3_MAX_POW2 + mbits = MBITS_F6_E2M3 elif elem_dtype == DTYPE_FP6_E3M2: target_max_pow2 = F6_E3M2_MAX_POW2 + mbits = MBITS_F6_E3M2 elif elem_dtype == DTYPE_FP4: target_max_pow2 = F4_E2M1_MAX_POW2 + mbits = MBITS_F4_E2M1 else: - raise AssertionError("unsupported") - scale_e8m0_unbiased = largest_p2_lt_max_abs - target_max_pow2 + raise AssertionError("unsupported element dtype") + + # rounding before calculating the largest power of 2 + # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)) + if scaling_mode == ScaleCalculationMode.EVEN: + nan_mask = torch.isnan(max_abs) + max_abs = max_abs.to(torch.float32).view(torch.int32) + val_to_add = 1 << (MBITS_F32 - mbits - 1) + mask = ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32 + max_abs = (max_abs + val_to_add) & mask + max_abs = max_abs.view(torch.float32) + max_abs[nan_mask] = torch.tensor(float("nan"), device=max_abs.device) + + # Calculate the scale for different modes + if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN): + scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + eps)) - target_max_pow2 + elif scaling_mode == ScaleCalculationMode.CEIL: + scale_e8m0_unbiased = torch.ceil(torch.log2(max_abs + eps)) - target_max_pow2 + else: + raise AssertionError("unsupported scaling calculation mode") # Clamp to exponents that can be represented in e8m0 scale_e8m0_unbiased = torch.clamp( From 85da2973ba29d4c6978416dbe33896ef8ecf40fd Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 29 Jan 2025 20:53:42 -0800 Subject: [PATCH 2/6] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 16 ++++++++++++++-- torchao/prototype/mx_formats/mx_tensor.py | 11 +++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index ae87ee021e..2bad17a13d 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -18,6 +18,7 @@ from torchao.prototype.mx_formats.mx_tensor import ( E8M0_EXPONENT_NAN_VAL, MXTensor, + ScaleCalculationMode, to_dtype, ) from torchao.quantization.utils import compute_error @@ -43,8 +44,10 @@ def run_before_and_after_tests(): torch._dynamo.reset() -def _test_mx(data_hp, elem_dtype, block_size): - data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size) +def _test_mx( + data_hp, elem_dtype, block_size, scale_calculation_mode=ScaleCalculationMode.FLOOR +): + data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size, scale_calculation_mode) data_mx_dq = data_mx.to_dtype(data_hp.dtype) def assert_sqnr_gt_threshold(orig, new, threshold): @@ -70,6 +73,15 @@ def test_hello_world(elem_dtype): _test_mx(data, elem_dtype, block_size) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("scale_calculation_mode", [s for s in ScaleCalculationMode]) +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +def test_realistic_numerics(elem_dtype, scale_calculation_mode): + 
data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16) + block_size = 32 + _test_mx(data, elem_dtype, block_size, scale_calculation_mode) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_all_zeros(elem_dtype): diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 1581628d58..801f29ac3c 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -318,15 +318,17 @@ class ToMXConstrFunc(torch.autograd.Function): """ @staticmethod - def forward(ctx, data_hp, elem_dtype, block_size): - scale_e8m0_biased, data_lp = to_mx(data_hp, elem_dtype, block_size) + def forward(ctx, data_hp, elem_dtype, block_size, scaling_mode): + scale_e8m0_biased, data_lp = to_mx( + data_hp, elem_dtype, block_size, scaling_mode + ) return MXTensor( scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype ) @staticmethod def backward(ctx, g): - return g, None, None + return g, None, None, None @torch._dynamo.allow_in_graph @@ -440,8 +442,9 @@ def to_mx( data_hp: torch.Tensor, elem_dtype: Union[torch.dtype, str], block_size: int = BLOCK_SIZE_DEFAULT, + scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, ): - return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size) + return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size, scaling_mode) def __tensor_flatten__(self): ctx = { From 0220b1923b448a91a193e89e0331c6664fd7e1a0 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 10 Feb 2025 09:17:20 -0800 Subject: [PATCH 3/6] Update [ghstack-poisoned] --- third_party/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/cutlass b/third_party/cutlass index bf9da7b76c..b78588d163 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit bf9da7b76c766d7ee7d536afc77880a4ef1f1156 +Subproject commit b78588d1630aa6643bf021613717bafb705df4ef From e4b5dedac3006777d21b8466ebc38c7e25eaf3c7 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 10 Feb 2025 09:57:17 -0800 Subject: [PATCH 4/6] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 20 +++++---- torchao/prototype/mx_formats/config.py | 15 ++++++- torchao/prototype/mx_formats/mx_linear.py | 22 ++++++++-- torchao/prototype/mx_formats/mx_ops.py | 6 ++- torchao/prototype/mx_formats/mx_tensor.py | 46 +++++++++++++++++---- 5 files changed, 88 insertions(+), 21 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index ad718beb9c..9e97e1c32b 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -7,7 +7,6 @@ import pytest import torch -from torchao.prototype.mx_formats import config from torchao.prototype.mx_formats.constants import ( DTYPE_FP4, DTYPE_FP6_E2M3, @@ -139,8 +138,14 @@ def test_exponent_nan_out(elem_dtype): else: raise AssertionError("unsupported") block_size = 2 + use_fp4_custom_triton_dequant_kernel = False tensor_mx = MXTensor( - scale_e8m0_bits, data_bits, elem_dtype, block_size, torch.float + scale_e8m0_bits, + data_bits, + elem_dtype, + block_size, + torch.float, + use_fp4_custom_triton_dequant_kernel, ) tensor_hp = tensor_mx.to_dtype(torch.float) assert torch.all(torch.isnan(tensor_hp[0:1])) @@ -188,15 +193,16 @@ def test_transpose(elem_dtype, fp4_triton): M, K = 128, 256 block_size = 32 tensor_hp = torch.randn(M, K, device="cuda", 
dtype=torch.bfloat16) - tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size) - config.use_fp4_custom_triton_dequant_kernel = fp4_triton + tensor_mx = MXTensor.to_mx( + tensor_hp, + elem_dtype, + block_size, + use_fp4_custom_triton_dequant_kernel=fp4_triton, + ) tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t() - config.use_fp4_custom_triton_dequant_kernel = False tensor_mx_t = tensor_mx.t() - config.use_fp4_custom_triton_dequant_kernel = fp4_triton tensor_mx_t_dq = tensor_mx_t.to_dtype(tensor_hp.dtype) - config.use_fp4_custom_triton_dequant_kernel = False assert tensor_mx_dq_t.shape == tensor_mx_t_dq.shape torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 3e7e03d8f6..7b68b5b6a5 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -1,2 +1,13 @@ -# If True, uses a custom triton kernel for fp4 dequantize -use_fp4_custom_triton_dequant_kernel = False +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + + +@dataclass +class MXLinearConfig: + # If True, uses a custom triton kernel for fp4 dequantize + use_fp4_custom_triton_dequant_kernel: bool = False diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index d7aa744334..72c2b6ab39 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -8,11 +8,12 @@ Defines the prototype UX for converting a model to use mx weights """ -from typing import Any +from typing import Any, Optional import torch import torch.nn.functional as F +from torchao.prototype.mx_formats.config import MXLinearConfig from torchao.prototype.mx_formats.mx_tensor import MXTensor @@ -110,6 +111,8 @@ def from_float( elem_dtype_weight_override=None, elem_dtype_grad_output_override=None, *, + # TODO(next PR): move elem_dtype* and block size into config + config: MXLinearConfig = None, block_size=32, ): mod.__class__ = MXLinear @@ -117,6 +120,10 @@ def from_float( mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype mod.block_size = block_size + # TODO(next PR): fix this + if config is None: + config = MXLinearConfig() + mod.config = config return mod def forward(self, x): @@ -151,7 +158,9 @@ class MXInferenceLinear(torch.nn.Linear): @classmethod @torch.no_grad() - def from_float(cls, mod, elem_dtype, block_size): + def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig): + # TODO(next PR): move elem_dtype and block_size into config + with torch.device("meta"): super_kwargs = { "in_features": mod.in_features, @@ -166,6 +175,7 @@ def from_float(cls, mod, elem_dtype, block_size): ) new_mod.bias = mod.bias new_mod.elem_dtype = elem_dtype + new_mod.config = config return new_mod @torch.no_grad() @@ -207,6 +217,8 @@ def swap_linear_with_mx_linear( elem_dtype_weight_override=None, elem_dtype_grad_output_override=None, *, + # TODO(next PR): move elem_dtype* and block_size into config + config: Optional[MXLinearConfig] = None, block_size=32, filter_fn=None, ): @@ -225,6 +237,7 @@ def __fn(mod, fqn): elem_dtype, elem_dtype_weight_override, elem_dtype_grad_output_override, + config=config, block_size=block_size, ), combined_filter_fn, @@ -236,6 
+249,7 @@ def swap_linear_with_mx_inference_linear( elem_dtype, block_size, filter_fn=None, + config: Optional[MXLinearConfig] = None, ): if filter_fn is None: combined_filter_fn = _is_linear @@ -247,6 +261,8 @@ def __fn(mod, fqn): combined_filter_fn = __fn replace_with_custom_fn_if_matches_filter( model, - lambda mod: MXInferenceLinear.from_float(mod, elem_dtype, block_size), + lambda mod: MXInferenceLinear.from_float( + mod, elem_dtype, block_size, config=config + ), combined_filter_fn, ) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 57fb0d54b4..5fb3e8c6c0 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -54,6 +54,7 @@ def mx_desugar_op(aten_op, args, kwargs=None): old._elem_dtype, old._block_size, old._orig_dtype, + old._use_fp4_custom_triton_dequant_kernel, ) return new @@ -82,6 +83,7 @@ def mx_t(aten_op, args, kwargs=None): old._elem_dtype, old._block_size, old._orig_dtype, + old._use_fp4_custom_triton_dequant_kernel, ) return new @@ -120,6 +122,7 @@ def mx_view_op(aten_op, args, kwargs=None): args[0]._elem_dtype, args[0]._block_size, args[0]._orig_dtype, + args[0]._use_fp4_custom_triton_dequant_kernel, ) @@ -130,7 +133,6 @@ def autocast_to_copy(aten_op, args, kwargs=None): tensor. """ assert isinstance(args[0], MXTensor) - # print('before', args[0], args[0].dtype, args[0]._orig_dtype) assert ( len(kwargs) == 1 and "dtype" in kwargs ), "Only support dtype kwarg for autocast" @@ -144,6 +146,6 @@ def autocast_to_copy(aten_op, args, kwargs=None): args[0]._elem_dtype, args[0]._block_size, kwargs["dtype"], + args[0]._use_fp4_custom_triton_dequant_kernel, ) - # print('after', res, res.dtype, res._orig_dtype) return res diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 801f29ac3c..838ab2338c 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -21,7 +21,6 @@ import torch -import torchao.prototype.mx_formats.config as config from torchao.prototype.mx_formats.constants import ( BLOCK_SIZE_DEFAULT, DTYPE_FP4, @@ -239,7 +238,14 @@ def get_fp_scale(scale_e8m0): return s_fp -def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype): +def to_dtype( + data_lp, + scale_e8m0, + elem_dtype, + block_size, + target_dtype, + use_fp4_custom_triton_dequant_kernel, +): orig_shape = data_lp.shape is_transposed = not data_lp.is_contiguous() # if the underlying data is transposed, convert to row major before @@ -258,7 +264,7 @@ def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype): data_hp = f6_e3m2_unpacked_to_f32(data_lp) data_hp = data_hp.to(target_dtype) elif elem_dtype == DTYPE_FP4: - if config.use_fp4_custom_triton_dequant_kernel: + if use_fp4_custom_triton_dequant_kernel: data_hp_rescaled = triton_f4_to_scaled_bf16( data_lp, scale_e8m0, @@ -318,17 +324,29 @@ class ToMXConstrFunc(torch.autograd.Function): """ @staticmethod - def forward(ctx, data_hp, elem_dtype, block_size, scaling_mode): + def forward( + ctx, + data_hp, + elem_dtype, + block_size, + scaling_mode, + use_fp4_custom_triton_dequant_kernel, + ): scale_e8m0_biased, data_lp = to_mx( data_hp, elem_dtype, block_size, scaling_mode ) return MXTensor( - scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype + scale_e8m0_biased, + data_lp, + elem_dtype, + block_size, + data_hp.dtype, + use_fp4_custom_triton_dequant_kernel, ) @staticmethod def backward(ctx, g): - return g, None, None, None + return g, 
None, None, None, None @torch._dynamo.allow_in_graph @@ -345,6 +363,7 @@ def forward(ctx, tensor_lp, target_dtype): tensor_lp._elem_dtype, tensor_lp._block_size, target_dtype, + tensor_lp._use_fp4_custom_triton_dequant_kernel, ) @staticmethod @@ -360,6 +379,7 @@ def __new__( elem_dtype, block_size, orig_dtype, + use_fp4_custom_triton_dequant_kernel, ): new_size = data_bits.size() if elem_dtype == DTYPE_FP4: @@ -417,6 +437,9 @@ def __new__( self._elem_dtype = elem_dtype self._block_size = block_size self._orig_dtype = orig_dtype + self._use_fp4_custom_triton_dequant_kernel = ( + use_fp4_custom_triton_dequant_kernel + ) return self def __repr__(self): @@ -443,14 +466,22 @@ def to_mx( elem_dtype: Union[torch.dtype, str], block_size: int = BLOCK_SIZE_DEFAULT, scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, + use_fp4_custom_triton_dequant_kernel: bool = False, ): - return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size, scaling_mode) + return ToMXConstrFunc.apply( + data_hp, + elem_dtype, + block_size, + scaling_mode, + use_fp4_custom_triton_dequant_kernel, + ) def __tensor_flatten__(self): ctx = { "_elem_dtype": self._elem_dtype, "_block_size": self._block_size, "_orig_dtype": self._orig_dtype, + "_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel, } return ["_scale_e8m0", "_data"], ctx @@ -467,6 +498,7 @@ def __tensor_unflatten__( metadata["_elem_dtype"], metadata["_block_size"], metadata["_orig_dtype"], + metadata["_use_fp4_custom_triton_dequant_kernel"], ) # Do not force the MXTensor type on the returned tensor From 7c1166e399c5a4e70fa643bedba0c573f8d0da26 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 10 Feb 2025 11:36:29 -0800 Subject: [PATCH 5/6] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 9e97e1c32b..2a15961586 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -264,12 +264,14 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): to_dtype_c = torch.compile(to_dtype, fullgraph=True) + use_fp4_custom_triton_dequant_kernel = False x_mx_dq = to_dtype( x_mx._data, x_mx._scale_e8m0, x_mx._elem_dtype, x_mx._block_size, hp_dtype, # noqa: E501 + use_fp4_custom_triton_dequant_kernel, ) x_mx_c_dq = to_dtype_c( x_mx_c._data, @@ -277,5 +279,6 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): x_mx_c._elem_dtype, x_mx_c._block_size, hp_dtype, + use_fp4_custom_triton_dequant_kernel, ) torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0) From 8819b28642c36876cbb191d0e33ec94f7f27099b Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 10 Feb 2025 11:36:29 -0800 Subject: [PATCH 6/6] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 47 +++++++--------- torchao/prototype/mx_formats/README.md | 11 ++-- torchao/prototype/mx_formats/config.py | 31 +++++++++++ torchao/prototype/mx_formats/mx_linear.py | 60 +++++++-------------- 4 files changed, 74 insertions(+), 75 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 17a76a750d..c2eb66960f 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -11,6 +11,7 @@ import torch import torch.nn as nn +from torchao.prototype.mx_formats.config import MXLinearConfig from 
torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES from torchao.prototype.mx_formats.mx_linear import ( MXInferenceLinear, @@ -59,8 +60,13 @@ def test_linear_eager(elem_dtype, bias, input_shape): nn.Linear(8, 6, bias=bias, device="cuda"), ) m_mx = copy.deepcopy(m) - block_size = 2 - swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size) + config = MXLinearConfig( + block_size=2, + elem_dtype=elem_dtype[0], + elem_dtype_weight_override=elem_dtype[1], + elem_dtype_grad_output_override=elem_dtype[2], + ) + swap_linear_with_mx_linear(m_mx, config=config) x_ref = torch.randn(*input_shape, device="cuda").requires_grad_() x = copy.deepcopy(x_ref) @@ -97,8 +103,8 @@ def test_activation_checkpointing(): nn.Linear(4, 6, bias=True, device="cuda"), nn.Linear(6, 6, bias=True, device="cuda"), ) - block_size = 2 - swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size) + config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype) + swap_linear_with_mx_linear(m, config=config) x = torch.randn(*input_shape, device="cuda").requires_grad_() g = torch.randn(*grad_shape, device="cuda") @@ -133,8 +139,8 @@ def test_linear_compile(elem_dtype, bias, use_autocast): m_mx = nn.Sequential( nn.Linear(K, N, bias=bias, device="cuda"), ) - block_size = 2 - swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size) + config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype) + swap_linear_with_mx_linear(m_mx, config=config) m_mx_c = copy.deepcopy(m_mx) m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor") @@ -181,8 +187,8 @@ def test_inference_linear(elem_dtype, bias, input_shape): m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16)) m = m.cuda() m_mx = copy.deepcopy(m) - block_size = 2 - swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size) + config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype) + swap_linear_with_mx_inference_linear(m_mx, config=config) x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16) y_ref = m(x) @@ -209,8 +215,8 @@ def test_inference_compile_simple(elem_dtype): m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16)) m = m.cuda() m_mx = copy.deepcopy(m) - block_size = 2 - swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size) + config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype) + swap_linear_with_mx_inference_linear(m_mx, config=config) m_mx = torch.compile(m_mx, fullgraph="true") x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16) @@ -223,20 +229,6 @@ def test_inference_compile_simple(elem_dtype): assert sqnr >= 13.5 -def test_mx_linear_input_weight_gradient_dtypes(): - m = nn.Sequential(nn.Linear(32, 32)) - swap_linear_with_mx_linear(m, *SUPPORTED_ELEM_DTYPES[:3], block_size=32) - assert m[0].in_elem_dtype == SUPPORTED_ELEM_DTYPES[0] - assert m[0].w_elem_dtype == SUPPORTED_ELEM_DTYPES[1] - assert m[0].grad_elem_dtype == SUPPORTED_ELEM_DTYPES[2] - - m = nn.Sequential(nn.Linear(32, 32)) - swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32) - assert m[0].in_elem_dtype == torch.float8_e4m3fn - assert m[0].w_elem_dtype == torch.float8_e4m3fn - assert m[0].grad_elem_dtype == torch.float8_e4m3fn - - def test_filter_fn(): m1 = nn.Sequential( nn.Linear(32, 32), @@ -245,12 +237,11 @@ def test_filter_fn(): m2 = copy.deepcopy(m1) filter_fn = lambda mod, fqn: fqn != "1" # noqa: E731 - swap_linear_with_mx_linear( - m1, torch.float8_e4m3fn, block_size=32, filter_fn=filter_fn - ) + config = MXLinearConfig(block_size=32) + 
swap_linear_with_mx_linear(m1, config=config, filter_fn=filter_fn) assert type(m1[0]) == MXLinear assert type(m1[1]) == torch.nn.Linear - swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn) # noqa: E501 + swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn) # noqa: E501 assert type(m2[0]) == MXInferenceLinear assert type(m2[1]) == torch.nn.Linear diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index 32f45e3755..09e7563ebb 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -41,10 +41,11 @@ This is a module to do MX training, the MX matmul is currently emulated. ```python from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear +from torchao.prototype.mx_formats.config import MXLinearConfig m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() -elem_dtype = torch.float8_e4m3fn -swap_linear_with_mx_linear(m, elem_dtype, block_size=32) +config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32) +swap_linear_with_mx_linear(m, config=config) # training loop (not shown) ``` @@ -55,11 +56,11 @@ This is a module to do MX inference, weights are in MX and matmul is in high pre ```python from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear +from torchao.prototype.mx_formats.config import MXLinearConfig m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() -elem_dtype = torch.float8_e4m3fn -block_size = 32 -swap_linear_with_mx_inference_linear(m, elem_dtype, block_size) +config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32) +swap_linear_with_mx_inference_linear(m, config=config) # do inference (not shown) ``` diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 7b68b5b6a5..7cdf2d4e58 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -5,9 +5,40 @@ # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass +from typing import Any, Optional + +import torch + +from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES @dataclass class MXLinearConfig: + # block size for scaling, default is 32 to match + # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, + # section 5.2 + block_size: int = 32 + + # element dtype, used for activations, weights and gradients + elem_dtype: Any = torch.float8_e4m3fn + + # overrides for element dtype for weights and gradients + # TODO(future PR): refactor to make this cleaner + elem_dtype_weight_override: Optional[Any] = None + elem_dtype_grad_output_override: Optional[Any] = None + # If True, uses a custom triton kernel for fp4 dequantize use_fp4_custom_triton_dequant_kernel: bool = False + + def __post_init__(self): + assert ( + self.elem_dtype in SUPPORTED_ELEM_DTYPES + ), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" + if self.elem_dtype_weight_override is not None: + assert ( + self.elem_dtype_weight_override in SUPPORTED_ELEM_DTYPES + ), f"elem_dtype_weight_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" + if self.elem_dtype_grad_output_override is not None: + assert ( + self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES + ), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index 72c2b6ab39..a38a8c5499 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -107,22 +107,11 @@ class MXLinear(torch.nn.Linear): def from_float( cls, mod, - elem_dtype, - elem_dtype_weight_override=None, - elem_dtype_grad_output_override=None, - *, - # TODO(next PR): move elem_dtype* and block size into config - config: MXLinearConfig = None, - block_size=32, + config: Optional[MXLinearConfig] = MXLinearConfig(), ): + # TODO(before land): remove this + assert isinstance(config, MXLinearConfig) mod.__class__ = MXLinear - mod.in_elem_dtype = elem_dtype - mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype - mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype - mod.block_size = block_size - # TODO(next PR): fix this - if config is None: - config = MXLinearConfig() mod.config = config return mod @@ -135,13 +124,14 @@ def forward(self, x): else: w = self.weight + config = self.config y = mx_mm.apply( x, w, - self.in_elem_dtype, - self.w_elem_dtype, - self.grad_elem_dtype, - self.block_size, + config.elem_dtype, + config.elem_dtype_weight_override or config.elem_dtype, + config.elem_dtype_grad_output_override or config.elem_dtype, + config.block_size, ) if self.bias is not None: y = y + self.bias @@ -158,9 +148,11 @@ class MXInferenceLinear(torch.nn.Linear): @classmethod @torch.no_grad() - def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig): - # TODO(next PR): move elem_dtype and block_size into config - + def from_float( + cls, + mod, + config: Optional[MXLinearConfig] = MXLinearConfig(), + ): with torch.device("meta"): super_kwargs = { "in_features": mod.in_features, @@ -171,10 +163,9 @@ def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig): # TODO(future PR): set to new_mod.weight directly, will need to work # through some errors new_mod.weight_mx = MXTensor.to_mx( - mod.weight, elem_dtype, block_size=block_size + mod.weight, config.elem_dtype, block_size=config.block_size ) 
new_mod.bias = mod.bias - new_mod.elem_dtype = elem_dtype new_mod.config = config return new_mod @@ -213,13 +204,8 @@ def _is_linear(mod, fqn): def swap_linear_with_mx_linear( model, - elem_dtype, - elem_dtype_weight_override=None, - elem_dtype_grad_output_override=None, *, - # TODO(next PR): move elem_dtype* and block_size into config config: Optional[MXLinearConfig] = None, - block_size=32, filter_fn=None, ): if filter_fn is None: @@ -232,24 +218,16 @@ def __fn(mod, fqn): combined_filter_fn = __fn replace_with_custom_fn_if_matches_filter( model, - lambda mod: MXLinear.from_float( - mod, - elem_dtype, - elem_dtype_weight_override, - elem_dtype_grad_output_override, - config=config, - block_size=block_size, - ), + lambda mod: MXLinear.from_float(mod, config=config), combined_filter_fn, ) def swap_linear_with_mx_inference_linear( model, - elem_dtype, - block_size, - filter_fn=None, + *, config: Optional[MXLinearConfig] = None, + filter_fn=None, ): if filter_fn is None: combined_filter_fn = _is_linear @@ -261,8 +239,6 @@ def __fn(mod, fqn): combined_filter_fn = __fn replace_with_custom_fn_if_matches_filter( model, - lambda mod: MXInferenceLinear.from_float( - mod, elem_dtype, block_size, config=config - ), + lambda mod: MXInferenceLinear.from_float(mod, config=config), combined_filter_fn, )
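
For reference, a minimal usage sketch of the scaling-mode knob introduced in patches 1-2. This is not part of the patches above; it assumes a CUDA device and the prototype module paths and signatures exactly as they appear in the diffs (MXTensor.to_mx and ScaleCalculationMode in torchao.prototype.mx_formats.mx_tensor), and the MX matmul itself is still emulated:

import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode

# Quantize a bf16 tensor to MX fp8 (e4m3) with 32-element scaling blocks.
# FLOOR is the default scaling mode; EVEN rounds max_abs before taking
# floor(log2(.)), trading FLOOR's overflow risk against CEIL's flush-to-zero.
x_hp = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(
    x_hp,
    torch.float8_e4m3fn,
    block_size=32,
    scaling_mode=ScaleCalculationMode.EVEN,
)

# Dequantize back to bf16 to inspect the round-trip error.
x_mx_dq = x_mx.to_dtype(torch.bfloat16)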