Support MXFP6 packing and fused unpack-dequantise kernel #1687
Good afternoon! Following recent developments and increased support for MXFP formats, it would be useful to support efficient packing for `MXFP6` to benefit from the decrease in memory consumption and bandwidth requirements vs (MX)`FP8`. `MXFP6` has been shown to perform similarly to `MXFP8` in LLM inference tasks, and with sufficient QAT even on par with `float32`, e.g. in the MXFP paper.

This PR packs the bits representing the `FP6` values in a `4+2` fashion, as is done in the FP6-LLM paper, and supports both the `E2M3` and `E3M2` variants. Packing is done via a standalone Triton kernel, while unpacking and dequantisation are performed via a fused kernel for better performance.

Tests have been added in `test_custom_cast.py` and `test_mx_tensor.py` to cover accuracy of the quantise-pack-unpack-dequantise round trip with various FP6 values (min/max norm, min/max subnorm, `-0.0`, etc. for both the `E2M3` and `E3M2` variants), as well as to check the packed tensor dimensions.

Note: due to the `4+2` packing scheme, this requires the packing dimension to be a multiple of `4`, since the packed dimension will be 3/4 of it. However, the typical MX block size is `32` (→ `24` when packed), and HW implementations tend to require dims to be multiples of `16` or `32`, so this should not be a problem. The relevant test case dimensions have been changed from `6` to `8`, and the MX block sizes from `2` to `4`, where applicable, in order to accommodate this requirement.
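For reviewers, here is a minimal pure-Python sketch of what the kernels compute (function names are hypothetical, not the ones in this PR): the `4+2` split packs four 6-bit codes into three bytes, the four 2-bit low fragments sharing one byte and the 4-bit high fragments pairing up into the other two; the `E2M3` decode assumes the standard layout of 1 sign, 2 exponent, and 3 mantissa bits with bias 1.

```python
def pack_fp6_4plus2(codes):
    """Pack 6-bit integer codes: every 4 values -> 3 bytes (3/4 the size)."""
    assert len(codes) % 4 == 0, "packing dim must be a multiple of 4"
    out = bytearray()
    for i in range(0, len(codes), 4):
        a, b, c, d = (v & 0x3F for v in codes[i:i + 4])
        # byte 0: the four 2-bit low fragments
        out.append(((a & 3) << 6) | ((b & 3) << 4) | ((c & 3) << 2) | (d & 3))
        # bytes 1-2: the 4-bit high fragments, two per byte
        out.append(((a >> 2) << 4) | (b >> 2))
        out.append(((c >> 2) << 4) | (d >> 2))
    return bytes(out)

def unpack_fp6_4plus2(packed):
    """Inverse of pack_fp6_4plus2: every 3 bytes -> 4 six-bit codes."""
    codes = []
    for i in range(0, len(packed), 3):
        lo, hi01, hi23 = packed[i:i + 3]
        codes += [
            ((hi01 >> 4) << 2) | ((lo >> 6) & 3),
            ((hi01 & 0xF) << 2) | ((lo >> 4) & 3),
            ((hi23 >> 4) << 2) | ((lo >> 2) & 3),
            ((hi23 & 0xF) << 2) | (lo & 3),
        ]
    return codes

def fp6_e2m3_to_float(code):
    """Dequantise one E2M3 code (1 sign, 2 exp, 3 mantissa bits, bias 1)."""
    sign = -1.0 if (code >> 5) & 1 else 1.0
    e = (code >> 3) & 0x3
    m = code & 0x7
    if e == 0:                              # subnormal: 2^(1-bias) * m/8
        return sign * (m / 8.0)             # covers -0.0 when m == 0
    return sign * 2.0 ** (e - 1) * (1.0 + m / 8.0)
```

For example, `[0b011111, 0b000001, 0b100000, 0b000000]` (max norm 7.5, min subnorm 0.125, `-0.0`, `+0.0` in `E2M3`) packs to exactly 3 bytes and round-trips through unpack; the actual Triton kernels fuse `unpack` and the decode into one pass.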