MX: move block_size and elem_dtype into MXLinearConfig #1689

Merged · 11 commits · Feb 14, 2025
Changes from 8 commits
47 changes: 19 additions & 28 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -11,6 +11,7 @@
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES
from torchao.prototype.mx_formats.mx_linear import (
MXInferenceLinear,
@@ -59,8 +60,13 @@ def test_linear_eager(elem_dtype, bias, input_shape):
nn.Linear(8, 6, bias=bias, device="cuda"),
)
m_mx = copy.deepcopy(m)
block_size = 2
swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size)
config = MXLinearConfig(
block_size=2,
elem_dtype=elem_dtype[0],
elem_dtype_weight_override=elem_dtype[1],
elem_dtype_grad_output_override=elem_dtype[2],
)
swap_linear_with_mx_linear(m_mx, config=config)

x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
x = copy.deepcopy(x_ref)
@@ -97,8 +103,8 @@ def test_activation_checkpointing():
nn.Linear(4, 6, bias=True, device="cuda"),
nn.Linear(6, 6, bias=True, device="cuda"),
)
block_size = 2
swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
swap_linear_with_mx_linear(m, config=config)

x = torch.randn(*input_shape, device="cuda").requires_grad_()
g = torch.randn(*grad_shape, device="cuda")
@@ -133,8 +139,8 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
m_mx = nn.Sequential(
nn.Linear(K, N, bias=bias, device="cuda"),
)
block_size = 2
swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
swap_linear_with_mx_linear(m_mx, config=config)
m_mx_c = copy.deepcopy(m_mx)
m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")

@@ -181,8 +187,8 @@ def test_inference_linear(elem_dtype, bias, input_shape):
m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16))
m = m.cuda()
m_mx = copy.deepcopy(m)
block_size = 2
swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
swap_linear_with_mx_inference_linear(m_mx, config=config)

x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
y_ref = m(x)
@@ -209,8 +215,8 @@ def test_inference_compile_simple(elem_dtype):
m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16))
m = m.cuda()
m_mx = copy.deepcopy(m)
block_size = 2
swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
swap_linear_with_mx_inference_linear(m_mx, config=config)
m_mx = torch.compile(m_mx, fullgraph="true")

x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16)
@@ -223,20 +229,6 @@ def test_inference_compile_simple(elem_dtype):
assert sqnr >= 13.5


def test_mx_linear_input_weight_gradient_dtypes():
m = nn.Sequential(nn.Linear(32, 32))
swap_linear_with_mx_linear(m, *SUPPORTED_ELEM_DTYPES[:3], block_size=32)
assert m[0].in_elem_dtype == SUPPORTED_ELEM_DTYPES[0]
assert m[0].w_elem_dtype == SUPPORTED_ELEM_DTYPES[1]
assert m[0].grad_elem_dtype == SUPPORTED_ELEM_DTYPES[2]

m = nn.Sequential(nn.Linear(32, 32))
swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)
assert m[0].in_elem_dtype == torch.float8_e4m3fn
assert m[0].w_elem_dtype == torch.float8_e4m3fn
assert m[0].grad_elem_dtype == torch.float8_e4m3fn


def test_filter_fn():
m1 = nn.Sequential(
nn.Linear(32, 32),
@@ -245,12 +237,11 @@ def test_filter_fn():
m2 = copy.deepcopy(m1)
filter_fn = lambda mod, fqn: fqn != "1" # noqa: E731

swap_linear_with_mx_linear(
m1, torch.float8_e4m3fn, block_size=32, filter_fn=filter_fn
)
config = MXLinearConfig(block_size=32)
swap_linear_with_mx_linear(m1, config=config, filter_fn=filter_fn)
assert type(m1[0]) == MXLinear
assert type(m1[1]) == torch.nn.Linear

swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn) # noqa: E501
swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn) # noqa: E501
assert type(m2[0]) == MXInferenceLinear
assert type(m2[1]) == torch.nn.Linear
41 changes: 31 additions & 10 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -7,7 +7,6 @@
import pytest
import torch

from torchao.prototype.mx_formats import config
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP6_E2M3,
@@ -18,6 +17,7 @@
from torchao.prototype.mx_formats.mx_tensor import (
E8M0_EXPONENT_NAN_VAL,
MXTensor,
ScaleCalculationMode,
to_dtype,
)
from torchao.quantization.utils import compute_error
@@ -47,8 +47,10 @@ def run_before_and_after_tests():
torch._dynamo.reset()


def _test_mx(data_hp, elem_dtype, block_size):
data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size)
def _test_mx(
data_hp, elem_dtype, block_size, scale_calculation_mode=ScaleCalculationMode.FLOOR
):
data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size, scale_calculation_mode)
data_mx_dq = data_mx.to_dtype(data_hp.dtype)

def assert_sqnr_gt_threshold(orig, new, threshold):
@@ -61,7 +63,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
assert sqnr >= threshold

if elem_dtype is torch.float8_e4m3fn:
assert_sqnr_gt_threshold(data_hp, data_mx_dq, 20.0)
assert_sqnr_gt_threshold(data_hp, data_mx_dq, 18.0)
else:
assert_sqnr_gt_threshold(data_hp, data_mx_dq, 14.0)

@@ -74,6 +76,15 @@ def test_hello_world(elem_dtype):
_test_mx(data, elem_dtype, block_size)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("scale_calculation_mode", [s for s in ScaleCalculationMode])
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_realistic_numerics(elem_dtype, scale_calculation_mode):
data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
block_size = 32
_test_mx(data, elem_dtype, block_size, scale_calculation_mode)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_all_zeros(elem_dtype):
@@ -127,8 +138,14 @@ def test_exponent_nan_out(elem_dtype):
else:
raise AssertionError("unsupported")
block_size = 2
use_fp4_custom_triton_dequant_kernel = False
tensor_mx = MXTensor(
scale_e8m0_bits, data_bits, elem_dtype, block_size, torch.float
scale_e8m0_bits,
data_bits,
elem_dtype,
block_size,
torch.float,
use_fp4_custom_triton_dequant_kernel,
)
tensor_hp = tensor_mx.to_dtype(torch.float)
assert torch.all(torch.isnan(tensor_hp[0:1]))
@@ -176,15 +193,16 @@ def test_transpose(elem_dtype, fp4_triton):
M, K = 128, 256
block_size = 32
tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
config.use_fp4_custom_triton_dequant_kernel = fp4_triton
tensor_mx = MXTensor.to_mx(
tensor_hp,
elem_dtype,
block_size,
use_fp4_custom_triton_dequant_kernel=fp4_triton,
)
tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t()
config.use_fp4_custom_triton_dequant_kernel = False

tensor_mx_t = tensor_mx.t()
config.use_fp4_custom_triton_dequant_kernel = fp4_triton
tensor_mx_t_dq = tensor_mx_t.to_dtype(tensor_hp.dtype)
config.use_fp4_custom_triton_dequant_kernel = False

assert tensor_mx_dq_t.shape == tensor_mx_t_dq.shape
torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0)
@@ -246,18 +264,21 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):

to_dtype_c = torch.compile(to_dtype, fullgraph=True)

use_fp4_custom_triton_dequant_kernel = False
x_mx_dq = to_dtype(
x_mx._data,
x_mx._scale_e8m0,
x_mx._elem_dtype,
x_mx._block_size,
hp_dtype, # noqa: E501
use_fp4_custom_triton_dequant_kernel,
)
x_mx_c_dq = to_dtype_c(
x_mx_c._data,
x_mx_c._scale_e8m0,
x_mx_c._elem_dtype,
x_mx_c._block_size,
hp_dtype,
use_fp4_custom_triton_dequant_kernel,
)
torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)
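For context on the API change exercised in this file, here is a minimal sketch (mirroring the updated `test_transpose` call, and assuming a CUDA device is available) of quantizing to MX and dequantizing with the fp4 Triton dequant choice passed explicitly per call, rather than through the old module-level `config.use_fp4_custom_triton_dequant_kernel` flag:

```python
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor

# High-precision input, same shape/dtype as test_transpose (assumes CUDA).
x_hp = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)

# The dequant-kernel choice is now an explicit argument carried by the
# tensor, instead of mutable module-level state.
x_mx = MXTensor.to_mx(
    x_hp,
    torch.float8_e4m3fn,
    32,  # block_size
    use_fp4_custom_triton_dequant_kernel=False,
)
x_dq = x_mx.to_dtype(torch.bfloat16)  # dequantize back to bfloat16
```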
11 changes: 6 additions & 5 deletions torchao/prototype/mx_formats/README.md
@@ -41,10 +41,11 @@ This is a module to do MX training, the MX matmul is currently emulated.

```python
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
from torchao.prototype.mx_formats.config import MXLinearConfig

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
elem_dtype = torch.float8_e4m3fn
swap_linear_with_mx_linear(m, elem_dtype, block_size=32)
config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
swap_linear_with_mx_linear(m, config=config)

# training loop (not shown)
```
@@ -55,11 +56,11 @@ This is a module to do MX inference, weights are in MX and matmul is in high precision.

```python
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear
from torchao.prototype.mx_formats.config import MXLinearConfig

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
elem_dtype = torch.float8_e4m3fn
block_size = 32
swap_linear_with_mx_inference_linear(m, elem_dtype, block_size)
config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
swap_linear_with_mx_inference_linear(m, config=config)

# do inference (not shown)
```
46 changes: 44 additions & 2 deletions torchao/prototype/mx_formats/config.py
@@ -1,2 +1,44 @@
# If True, uses a custom triton kernel for fp4 dequantize
use_fp4_custom_triton_dequant_kernel = False
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Optional

import torch

from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES


@dataclass
class MXLinearConfig:
# block size for scaling, default is 32 to match
# https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
# section 5.2
block_size: int = 32

# element dtype, used for activations, weights and gradients
elem_dtype: Any = torch.float8_e4m3fn

# overrides for element dtype for weights and gradients
# TODO(future PR): refactor to make this cleaner
elem_dtype_weight_override: Optional[Any] = None
elem_dtype_grad_output_override: Optional[Any] = None

# If True, uses a custom triton kernel for fp4 dequantize
use_fp4_custom_triton_dequant_kernel: bool = False

Contributor: you think that we will want to keep this public?

Contributor Author: unlikely, but IMO we can punt that until later

def __post_init__(self):
assert (
self.elem_dtype in SUPPORTED_ELEM_DTYPES
), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
if self.elem_dtype_weight_override is not None:
assert (
self.elem_dtype_weight_override in SUPPORTED_ELEM_DTYPES
), f"elem_dtype_weight_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
if self.elem_dtype_grad_output_override is not None:
assert (
self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES
), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"