Add support for float8_e4m3 and float8_e3m4 types #16585

Closed

@apivovarov apivovarov commented Aug 28, 2024

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

f8E4M3 type follows IEEE 754 convention.

f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 - 7 = -6
- Precision specifies the total number of bits used for the significand (mantissa), 
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
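
The f8E4M3 bit patterns listed above can be checked with a small decoder. This is a minimal sketch, not code from XLA or ml_dtypes: `Pow2` and `DecodeF8E4M3` are hypothetical helper names, and the decoder only assumes the S.EEEE.MMM layout and bias 7 described above.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical helper: returns 2^e as a double (exact for small |e|).
constexpr double Pow2(int e) {
  double result = 1.0;
  for (int i = 0; i < e; ++i) result *= 2.0;
  for (int i = 0; i > e; --i) result /= 2.0;
  return result;
}

// Hypothetical decoder for one f8E4M3 byte laid out as S.EEEE.MMM, bias 7.
constexpr double DecodeF8E4M3(std::uint8_t bits) {
  const bool negative = (bits & 0x80) != 0;
  const int exponent = (bits >> 3) & 0xF;
  const int mantissa = bits & 0x7;
  if (exponent == 0xF) {
    // All-ones exponent: infinity if the mantissa is zero, NaN otherwise.
    if (mantissa != 0) return std::numeric_limits<double>::quiet_NaN();
    return negative ? -std::numeric_limits<double>::infinity()
                    : std::numeric_limits<double>::infinity();
  }
  // Zero exponent: zero or subnormal, 2^(-6) x (mantissa / 8).
  // Otherwise: normal, 2^(exponent - 7) x (1 + mantissa / 8).
  const double magnitude = exponent == 0
                               ? mantissa * Pow2(-9)
                               : (1.0 + mantissa / 8.0) * Pow2(exponent - 7);
  return negative ? -magnitude : magnitude;
}
```

Under this sketch, `DecodeF8E4M3(0x77)` (S.1110.111) evaluates to 240 and `DecodeF8E4M3(0x08)` (S.0001.000) to 2^(-6), matching the table.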

f8E3M4 type follows IEEE 754 convention

f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 - 3 = -2
- Precision specifies the total number of bits used for the significand (mantissa), 
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
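
The limits listed for both types follow from the exponent and mantissa widths alone. The sketch below derives them in closed form; `FloatLimits`, `IeeeLimits`, and `Pow2` are hypothetical names introduced for illustration, and the formulas assume the IEEE 754 convention stated above (all-ones exponent reserved for infinities and NaNs).

```cpp
// Hypothetical helper: returns 2^e as a double (exact for small |e|).
constexpr double Pow2(int e) {
  double result = 1.0;
  for (int i = 0; i < e; ++i) result *= 2.0;
  for (int i = 0; i > e; --i) result /= 2.0;
  return result;
}

// Limits of an IEEE-754-style binary format, for illustration only.
struct FloatLimits {
  int bias;
  int max_exponent;  // unbiased
  int min_exponent;  // unbiased
  double max_normal;
  double min_normal;
  double max_subnormal;
  double min_subnormal;
  int nan_encodings;  // counts both signs
};

constexpr FloatLimits IeeeLimits(int exponent_bits, int mantissa_bits) {
  const int bias = (1 << (exponent_bits - 1)) - 1;
  // The all-ones stored exponent is reserved, so the largest usable one is 2^E - 2.
  const int max_exponent = ((1 << exponent_bits) - 2) - bias;
  const int min_exponent = 1 - bias;
  const double ulp = Pow2(-mantissa_bits);  // weight of the last mantissa bit
  return FloatLimits{bias,
                     max_exponent,
                     min_exponent,
                     Pow2(max_exponent) * (2.0 - ulp),
                     Pow2(min_exponent),
                     Pow2(min_exponent) * (1.0 - ulp),
                     Pow2(min_exponent) * ulp,
                     2 * ((1 << mantissa_bits) - 1)};
}
```

For example, `IeeeLimits(3, 4)` gives bias 3, a max normal of 15.5, and a min subnormal of 2^(-6), while `IeeeLimits(4, 3)` gives bias 7 and a max normal of 240.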

Testing:

bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test

Related PRs:

  • LLVM PR-97179 [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
  • LLVM PR-97118 [MLIR] Add f8E4M3 IEEE 754 type (Merged)
  • LLVM PR-99698 [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
  • LLVM PR-101230 [MLIR] Add f8E3M4 IEEE 754 type (Merged)
  • StableHLO PR-2486 [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
  • StableHLO PR-2482 Add f8E4M3 and f8E3M4 types support (Merged)
  • ml_dtypes PR-161 Add float8_e4m3 (Merged)
  • ml_dtypes PR-171 Add float8_e3m4 (Merged)
  • XLA PR-17075 [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
  • XLA PR-3200 Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
  • JAX PR-23585 Add float8_e4m3 type support (in Review)

# LINT.ThenChange(Google-internal path)

tf_http_archive(
name = "stablehlo",
sha256 = STABLEHLO_SHA256,
strip_prefix = "stablehlo-{commit}".format(commit = STABLEHLO_COMMIT),
urls = tf_mirror_urls("https://github.com/openxla/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)),
urls = tf_mirror_urls("https://github.com/apivovarov/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)),
Member

Probably not intended to fetch it from your repo.

Member

If it's intended that you also update the deps in the same PR, could you split it in separate PRs?

Contributor Author

This PR is pending the merge of StableHLO openxla/stablehlo#2482 Add f8E4M3 and f8E3M4 types support (in Review).

Contributor Author

Hi Alexander, the change "Integrate StableHLO at openxla/stablehlo@4f31b2e7" was merged into XLA main today. It includes float8_e4m3 type support, so my temporary change in third_party/stablehlo/workspace.bzl was removed from this PR. @mooskagh

GleasonK pushed a commit to openxla/stablehlo that referenced this pull request Sep 3, 2024
### Summary
This is a proposal to add `Float8E4M3` and `Float8E3M4` floating point
types to StableHLO.
Feedback welcome, see [RFC: Float8E4M3 and
Float8E3M4](https://github.com/apivovarov/stablehlo/blob/rfc_f8E4M3_f8E3M4/rfcs/20240808-f8E4M3_f8E3M4.md)
for more details.

### References and Links
- LLVM [PR-97179](llvm/llvm-project#97179)
[APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118)
[MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698)
[APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM [PR-101230](llvm/llvm-project#101230)
[MLIR] Add f8E3M4 IEEE 754 type (Merged)
- [RFC: FP8 in
StableHLO](https://github.com/openxla/stablehlo/blob/main/rfcs/20221031-fp8.md)
- [RFC: Float8E4M3FNUZ and
Float8E5M2FNUZ](https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md)
- StableHLO [PR-2482](#2482)
Add f8E4M3 and f8E3M4 types support
- [Amazon EC2 Trn1
Instances](https://aws.amazon.com/ec2/instance-types/trn1/)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add
float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add
float8_e3m4 (Merged)
- XLA [PR-16585](openxla/xla#16585) Add support
for float8_e4m3
GleasonK pushed a commit to openxla/stablehlo that referenced this pull request Sep 4, 2024
This PR adds f8E4M3 and f8E3M4 types support.

f8E4M3 and f8E3M4 types follow IEEE 754 convention.

Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179)
[APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118)
[MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698)
[APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM [PR-101230](llvm/llvm-project#101230)
[MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](#2486)
[RFC] Add f8E4M3 and f8E3M4 types support
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add
float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add
float8_e3m4 (Merged)
- XLA [PR-16585](openxla/xla#16585) Add support
for float8_e4m3
@apivovarov
Contributor Author

Reed, David, could you please help review this PR? @reedwm @ddunl

@ddunl
Member

ddunl commented Sep 5, 2024

I think Reed is the best person to review. I think this will require a patch on our end due to tensorflow/third_party. Let me know when you approve and I can take care of the patch.

@loislo loislo removed their request for review September 9, 2024 11:14
@apivovarov
Contributor Author

Hi Reed,

This PR introduces support for the new f8E4M3 type, which adheres to the IEEE-754 convention. I've already added this type to the LLVM, MLIR, ml_dtypes, and StableHLO projects. This PR extends the support to XLA and includes a reference implementation for the CPU compiler.

Could you please help review this PR?

@reedwm

Member

@reedwm reedwm left a comment

Thanks for the change! Sorry for the delay in reviewing.

Normally, it's better to minimize the size of PRs, but I would prefer if E3M4 is also added in the same PR, since it touches most of the same files in the exact same way as E4M3, so it makes it easy to batch review both dtypes at once.

But if adding E3M4 to the same PR is inconvenient for you, I'm fine with this being done as a separate, future PR.

Comment on lines 320 to 321
{BF16, F16, F8E5M2, F8E4M3, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
Member

Since the type list here is growing, I would avoid hardcoding the list of FP8 types. One way to avoid this is to have DoWithUpcastToF32 take a should_upcast bool instead of the existing upcast_types list. Then you can pass something like should_upcast = BitWidth(b.GetShape(x).element_type) <= 16.

There are a lot of places where we list out all FP8 types, but every place we can remove listing these out will help when more types are added :)

Contributor Author

Opened PR #17130 - Add default upcasting behavior to DoWithUpcastToF32
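
The suggestion in this thread, replacing a hardcoded list of narrow types with a bit-width predicate, can be sketched as follows. Everything here is a stand-in written for illustration: the `PrimitiveType` enum models only a few members, and `BitWidth`, `ComputeType`, and `ComputeTypeFor` are hypothetical and do not reproduce the signatures of XLA's `DoWithUpcastToF32` or its bit-width utility.

```cpp
// Stand-in for xla::PrimitiveType; only a few floating-point members are modeled.
enum class PrimitiveType { F8E5M2, F8E4M3, F8E4M3FN, F8E3M4, BF16, F16, F32, F64 };

// Hypothetical bit-width lookup for the modeled types.
constexpr int BitWidth(PrimitiveType type) {
  switch (type) {
    case PrimitiveType::F8E5M2:
    case PrimitiveType::F8E4M3:
    case PrimitiveType::F8E4M3FN:
    case PrimitiveType::F8E3M4:
      return 8;
    case PrimitiveType::BF16:
    case PrimitiveType::F16:
      return 16;
    case PrimitiveType::F32:
      return 32;
    case PrimitiveType::F64:
      return 64;
  }
  return 0;  // only reached for values outside the enum
}

// The type a computation runs in: F32 when should_upcast is set, else unchanged.
constexpr PrimitiveType ComputeType(PrimitiveType type, bool should_upcast) {
  return should_upcast ? PrimitiveType::F32 : type;
}

// The caller derives the flag from the bit width instead of listing FP8 types.
constexpr PrimitiveType ComputeTypeFor(PrimitiveType type) {
  return ComputeType(type, /*should_upcast=*/BitWidth(type) <= 16);
}
```

With this shape, adding another 8-bit type only requires a `BitWidth` entry; the call sites that decide whether to upcast stay unchanged.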

@@ -111,6 +111,59 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest,
0x1.fffffffffffffp-127,
0x1.aaaaaaaaaaaaap-127));

TEST(FPDistanceTest, F8E4M3Distance) {
Member

Since this test is almost identical to F8E4M3FNDistance, can you merge them to avoid duplication?

One way is to create a type-parameterized test with TYPED_TEST_P. Another way would be to have a for-loop over the primitive types F8E4M3 and F8E4M3FN, and in the body use primitive_util::PrimitiveTypeSwitch with a lambda that does the CalculateDistanceInFloats calls. See here for an example of how to use PrimitiveTypeSwitch.

Contributor Author

Opened PR #17135 - Add TypeParam to FP8E4M3DistanceTest

} else if constexpr (std::is_integral_v<ElementwiseT>) {
if constexpr (std::is_signed_v<ElementwiseT>) {
if (rhs_el < static_cast<ElementwiseT>(0)) {
ElementWiseBinaryOp(
Member

Don't change the formatting here.

Contributor Author

restored

Contributor Author

The GitHub workflow runs pipx to check clang formatting. Opened PR #17234 - Format hlo_evaluator_typed_visitor.h

@@ -25,6 +25,63 @@ limitations under the License.
namespace xla {
namespace {

TEST(LiteralComparisonTest, F8E4M3CompareNear_Equal) {
Member

I think all the FP8 tests are duplicated for each FP8 type. Can you use TYPED_TEST_P to deduplicate them?

Contributor Author

Opened #17133 - Dedup LiteralComparisonTests

@@ -644,15 +648,19 @@ TEST_F(LiteralUtilTest, IsAll) {
// 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false
EXPECT_FALSE(LiteralUtil::CreateR1<tsl::float8_e5m2>({q16}).IsAll(9));

tsl::float8_e4m3fn r16(9); // Exactly representable in e4m3
tsl::float8_e4m3 e4m3(9); // Exactly representable in e4m3
Member

Not sure why there is a convention of using subsequent single letters to name the FP8 values (q16, r16, s16, etc) but you should follow it or change the convention. Either name this q16, renaming the above e5m2 to p16, or rename all the other FP8 variable names to something more descriptive, as you did for this one.

Contributor Author

renamed to q16
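
The test context quoted in this thread relies on 9 being exactly representable with 3 mantissa bits (E4M3, E4M3FN) but not with 2 (E5M2). A minimal sketch of that check, ignoring exponent range limits; `FitsSignificand` is a hypothetical helper, not an XLA function.

```cpp
// Hypothetical check: true if the non-negative integer v fits in a significand
// of (mantissa_bits + 1) bits, counting the implicit leading bit.
// Exponent range limits are deliberately ignored.
constexpr bool FitsSignificand(unsigned v, int mantissa_bits) {
  if (v == 0) return true;
  while ((v & 1u) == 0) v >>= 1;  // trailing zeros are absorbed by the exponent
  int width = 0;
  while (v != 0) {  // count the remaining significant bits
    ++width;
    v >>= 1;
  }
  return width <= mantissa_bits + 1;
}
```

Here 9 is 1001 in binary and needs 4 significant bits, so `FitsSignificand(9, 3)` is true while `FitsSignificand(9, 2)` is false; the E5M2 neighbors 8 and 10 both pass with 2 mantissa bits.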

Comment on lines 283 to 284
case xla::F8E4M3:
return absl::UnimplementedError("F8E4M3 not implemented");
Member

To avoid having to modify this every time a new FP8 type is added, remove all these FP8 cases and check IsF8Type(literal.shape().element_type()) before the switch statement.

Contributor Author

Opened PR #17170 Code dedup in execution_trace_utils LiteralToValue

Member

Just FYI, this was an intentional choice. Missing switch cases are a compiler error, so having a switch without a default case is preferable when possible. No big deal though.

Contributor Author

Using a switch statement without a default case makes it easier for the compiler to detect missing cases.
Opened PR #17279 - Use switch case without default in LiteralToValue
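
The trade-off discussed in this thread can be illustrated with an exhaustive switch. This is a sketch with hypothetical names (`ElementType`, `IsImplemented`), not the actual `LiteralToValue` code. GCC and Clang can diagnose an unhandled enumerator in such a switch through `-Wswitch`, and builds that promote warnings to errors then reject the code.

```cpp
// Hypothetical element-type enum; adding an enumerator here without touching
// IsImplemented() triggers -Wswitch in GCC and Clang.
enum class ElementType { kF32, kF8E5M2, kF8E4M3, kF8E3M4 };

constexpr bool IsImplemented(ElementType type) {
  // No default label: every enumerator must be listed explicitly.
  switch (type) {
    case ElementType::kF32:
      return true;
    case ElementType::kF8E5M2:
    case ElementType::kF8E4M3:
    case ElementType::kF8E3M4:
      return false;  // mirrors the "not implemented" branches in the thread
  }
  return false;  // only reached for values outside the enum
}
```

A `default:` label, or a type check placed before the switch, would silence that diagnostic, which is why the exhaustive form was described as an intentional choice.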

@@ -500,6 +500,36 @@ TEST_F(FloatNormalizationTest, DoNotChangeBitcastConvert) {
EXPECT_EQ(root->operand(0)->shape().element_type(), U16);
}

TEST_F(FloatNormalizationTest, ResolveIfUnsupportedF8e4m3) {
Member

Merge this with the existing test ResolveIfUnsupportedF8e5m2, either by looping over values (F8E4M3, F8E5M2) or by using a value-parameterized test with TEST_P.

Contributor Author

Opened PR #17177 - Parametrize FloatNormalizationF8Test ResolveIfUnsupportedF8


XlaBuilder builder(TestName());
auto c = ConstantR1<tsl::float8_e4m3>(&builder, constant);
// F8 outputs are not yet supported so convert to F32
Member

This comment is actually no longer true. We should change the other tests with this comment as well. The test OneCellF8e5m2fnuz does have an FP8 output, so you can use that as an example in modifying this test.

If you want, you can also change the two existing tests with the comment to have FP8 outputs as well.

Contributor Author

Opened PR #17182 - Parametrize ConstantsFloatTest OneCellFloat

Comment on lines 355 to 389
f16_reduced =
b->CreateOr(b->CreateAnd(f16_reduced, i16_const(0x9FFF)),
b->CreateLShr(b->CreateAnd(f16_reduced, i16_const(0x4000)),
i16_const(1)));
Member

This is effectively subtracting 8 from the exponent I think, as the difference in exponent bias is 8. Why not do that subtraction directly?

Contributor Author

In contrast to the EmitF16ToF8e4m3fn function, the EmitF16ToF8e4m3 function does not include special code to handle Inf and NaN cases. (code around constexpr int max_finite_value = 0x5F7F;)

If I use the subtract-8 approach, several tests in //xla/tests:convert_test_cpu fail, e.g.:

  • inf -> -1.0
  • nan -> -1.5

Example:

input is inf
EmitReducePrecisionIR returns
x = 0.11111.0000000000 (0x7C00)

Option1: minus 8
x -= 0.01000.0000000000
// x is 0.10111.0000000000
// Shift to convert to F8: x is 1.0111.000
// f8e4m3 Result is -1.0 (Wrong)

Option2: Right shift E5 exponent's leftmost bit
x = (x & 0b1001'1111'1111'1111) | ((x & 0b0100'0000'0000'0000) >> 1)
// x is 0.01111.0000000000
// Shift to convert to F8: x is 0.1111.000
// f8e4m3 Result is inf (Correct)
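
The two options in the example above can be reproduced with plain integer arithmetic. This is a sketch under stated assumptions, not the code in EmitF16ToF8e4m3: the function names are hypothetical, the input is assumed to be an F16 bit pattern already reduced to the e4m3 range, and the narrowing step simply keeps bits 14..7 of the 16-bit pattern, so the F16 sign bit is dropped the same way the worked example shows.

```cpp
#include <cstdint>

// Option 1 (hypothetical name): subtract 8 from the 5-bit exponent field.
// The exponent LSB is bit 10, so 8 in the exponent field is 0x2000.
constexpr std::uint16_t SubtractEightFromExponent(std::uint16_t x) {
  return static_cast<std::uint16_t>(x - 0x2000);
}

// Option 2 (hypothetical name): clear exponent bits 14 and 13, then move the
// old bit 14 into bit 13. This equals subtracting 8 for in-range exponents and
// maps the all-ones exponent 11111 to 01111.
constexpr std::uint16_t FoldExponentMsb(std::uint16_t x) {
  return static_cast<std::uint16_t>((x & 0x9FFF) | ((x & 0x4000) >> 1));
}

// Narrowing step assumed by this sketch: keep bits 14..7 as the 8-bit result.
constexpr std::uint8_t NarrowToF8(std::uint16_t x) {
  return static_cast<std::uint8_t>(x >> 7);
}
```

For the F16 infinity 0x7C00, option 1 narrows to 0xB8 (1.0111.000, which reads as -1.0) while option 2 narrows to 0x78 (0.1111.000, infinity). For an in-range value such as F16 1.0 (0x3C00), both options give 0x38, which is 1.0 in f8E4M3.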


// Set output exponent to 11111 if input exponent is 1111 (Inf or NaN)
// 0.1111.000 is 0x78
// 0.11111.000000000000 is 0x7C00
Member

0.11111.000000000000 has 12 zeros at the end, when it should have 10.

Contributor Author

fixed

@@ -220,6 +220,59 @@ absl::StatusOr<llvm::Value*> EmitReducePrecisionIR(
return result;
}

llvm::Value* handle_halfway_points_F16ToF8(llvm::Value* f16_abs_bits,
Member

Also, I think you might have to modify expand_float_ops.cc, which is used by the new MLIR emitters that replace the existing emitters on GPUs. But I'm not very familiar with these new emitters. @jreiffers, can you advise on what needs to be done here?

Member

We have generic support for all possible float conversions, but the emitted code might not be optimal, so it should be considered a fallback. I didn't look at these conversion routines here in detail, but if they're better, it would make sense to port them to the MLIR pipeline.

Contributor Author

I updated xla/service/gpu/fusions/transforms/expand_float_ops.cc and added f8E4M3 cases to:

  • IsInf()
  • IsNaN()
  • RewriteF8Cst::matchAndRewrite() // If we're comparing to +-0, compare the absolute values.

expand_float_ops.cc includes a specialized function for the f8e5m2 type - EmitF16ToF8e5m2(). This is because F16 is technically f16e5m10. The two types are similar, with the primary difference being that the mantissa in f8e5m2 is truncated to 2 bits.

f8E4M3 has a different number of exponent and mantissa bits. The conversion can be efficiently managed using the "generic support for all possible float conversions".
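
To illustrate why f8e5m2 is the special case: F16 and f8e5m2 share the sign bit and the 5-bit exponent, so the f8e5m2 pattern is the top byte of the F16 pattern. The sketch below uses hypothetical function names and plain truncation. It is not the logic of EmitF16ToF8e5m2; a real conversion also has to round and to keep NaNs from collapsing to infinity.

```cpp
#include <cstdint>

// Hypothetical: keep sign, 5 exponent bits, and the top 2 mantissa bits.
// Truncation only; no rounding and no NaN handling.
constexpr std::uint8_t TruncateF16ToF8E5M2(std::uint16_t f16_bits) {
  return static_cast<std::uint8_t>(f16_bits >> 8);
}

// Hypothetical: widening is exact, the dropped mantissa bits become zero.
constexpr std::uint16_t WidenF8E5M2ToF16(std::uint8_t f8_bits) {
  return static_cast<std::uint16_t>(f8_bits << 8);
}
```

For instance, F16 1.0 (0x3C00) truncates to 0x3C, which is 1.0 in f8e5m2. The F16 NaN 0x7C01 would truncate to 0x7C, which is infinity, showing why truncation alone is not sufficient.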

Tested xla on CUDA:

//xla/tests/...    799 tests: 799 tests pass
//xla/service/...  865 tests: 865 tests pass
//xla/client/...    77 tests: 77 tests pass
//xla/runtime/...    1 test: 1 test passes
//xla/ffi/...        6 tests: 6 tests pass
//xla/hlo/...       12 tests: 12 tests pass
//xla/mlir/...     141 tests: 141 tests pass
//xla/mlir_hlo/...  98 tests: 98 tests pass
//xla/pjrt/...      26 tests: 26 tests pass
//xla/tools/...     41 tests: 41 tests pass
//xla/translate/... 61 tests: 61 tests pass

@jreiffers
Member

Apologies for the delay, I'm OOO this week. Will take a look on Monday.

copybara-service bot pushed a commit that referenced this pull request Sep 13, 2024
ml_dtypes Updates:
Add float8_e4m3 and float8_e3m4 types support
Fix float divmod with zero denominator
Add int2 and uint2 types
ml_dtypes/commits

Related PRs
ml_dtypes PR jax-ml/ml_dtypes#161 Add float8_e4m3 (Merged)
XLA PR #16585 Add support for float8_e4m3 (In Review)

This closes #17075

PiperOrigin-RevId: 674396944
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 2, 2024
…oduleUnchangedNoShardingPerformed of the enum is unused, effectively making it a boolean. Also simplified away some dead code.

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6
PiperOrigin-RevId: 681506949
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 2, 2024
…erOutput> as the AutoShardingSolverResult::skip_auto_sharding is now dead after some recent changes.

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6
PiperOrigin-RevId: 678928364
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 2, 2024
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

PiperOrigin-RevId: 681551979
@apivovarov
Contributor Author

PR has been merged! Reed, thank you for your help, guidance, and support! @reedwm

@reedwm
Member

reedwm commented Oct 2, 2024

No problem, and thanks for the well-tested PR! Also thank you for all the test clean up PRs!

Note in merging, when converting to E3M4, I had to change the code to first convert to half to take into account we do not use an ml-dtypes version that includes jax-ml/ml_dtypes#205 yet. I added TODOs in the form of TODO(b/370786669) to all places where a conversion to half was added. These can be removed once we update ml-dtypes to a version that includes jax-ml/ml_dtypes#205. I'm happy to do this myself or for you to do this once ml-dtypes is updated.

copybara-service bot pushed a commit to google/tsl that referenced this pull request Oct 3, 2024
PR #16585: Add support for float8_e4m3 and float8_e3m4 types

Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 =...
```

***

PiperOrigin-RevId: 681876540
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 3, 2024
The commit hash is now 6f02f77c4fa624d8b467c36d1d959a9b49b07900, matching what is in TF at https://github.com/tensorflow/tensorflow/blob/master/third_party/py/ml_dtypes/workspace.bzl

This fixes a breakage caused by openxla/xla#16585

PiperOrigin-RevId: 681951673
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 3, 2024
The commit hash is now 6f02f77c4fa624d8b467c36d1d959a9b49b07900, matching what is in TF at https://github.com/tensorflow/tensorflow/blob/master/third_party/py/ml_dtypes/workspace.bzl

This fixes a breakage caused by openxla/xla#16585

PiperOrigin-RevId: 682038821
copybara-service bot pushed a commit that referenced this pull request Nov 14, 2024
Imported from GitHub PR #16775

I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests.

Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases.

Changes in this PR:
- Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h`
- Add `EmitReducePrecisionIR_F16ToF8e5m2` test
- Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test

Related PR:
- [PR-16585](#16585) Add support for float8_e4m3

Copybara import of the project:

--
5972205 by Alexander Pivovarov <[email protected]>:

Add test for EmitReducePrecisionIR

Merging this change closes #16775

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16775 from apivovarov:elemental_ir_emitter_test 5972205
PiperOrigin-RevId: 696646489
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 14, 2024
Imported from GitHub PR openxla/xla#16775

I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests.

Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases.

Changes in this PR:
- Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h`
- Add `EmitReducePrecisionIR_F16ToF8e5m2` test
- Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test

Related PR:
- [PR-16585](openxla/xla#16585) Add support for float8_e4m3

Copybara import of the project:

--
59722056e36e5a0bab7736b4ad3897446861de0f by Alexander Pivovarov <[email protected]>:

Add test for EmitReducePrecisionIR

Merging this change closes #16775

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16775 from apivovarov:elemental_ir_emitter_test 59722056e36e5a0bab7736b4ad3897446861de0f
PiperOrigin-RevId: 696646489
copybara-service bot pushed a commit that referenced this pull request Nov 15, 2024
Imported from GitHub PR #16775

I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests.

Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases.

Changes in this PR:
- Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h`
- Add `EmitReducePrecisionIR_F16ToF8e5m2` test
- Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test

Related PR:
- [PR-16585](#16585) Add support for float8_e4m3

Copybara import of the project:

--
5972205 by Alexander Pivovarov <[email protected]>:

Add test for EmitReducePrecisionIR

Merging this change closes #16775

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16775 from apivovarov:elemental_ir_emitter_test 5972205
PiperOrigin-RevId: 696646489
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 15, 2024
Imported from GitHub PR openxla/xla#16775

I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests.

Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases.

Changes in this PR:
- Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h`
- Add `EmitReducePrecisionIR_F16ToF8e5m2` test
- Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test

Related PR:
- [PR-16585](openxla/xla#16585) Add support for float8_e4m3

Copybara import of the project:

--
59722056e36e5a0bab7736b4ad3897446861de0f by Alexander Pivovarov <[email protected]>:

Add test for EmitReducePrecisionIR

Merging this change closes #16775

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16775 from apivovarov:elemental_ir_emitter_test 59722056e36e5a0bab7736b4ad3897446861de0f
PiperOrigin-RevId: 696646489
copybara-service bot pushed a commit that referenced this pull request Nov 15, 2024
Imported from GitHub PR #16775

I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests.

Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases.

Changes in this PR:
- Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h`
- Add `EmitReducePrecisionIR_F16ToF8e5m2` test
- Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test

Related PR:
- [PR-16585](#16585) Add support for float8_e4m3

Copybara import of the project:

--
5972205 by Alexander Pivovarov <[email protected]>:

Add test for EmitReducePrecisionIR

Merging this change closes #16775

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16775 from apivovarov:elemental_ir_emitter_test 5972205
PiperOrigin-RevId: 696730664
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 15, 2024
Imported from GitHub PR openxla/xla#16775

I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests.

Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases.

Changes in this PR:
- Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h`
- Add `EmitReducePrecisionIR_F16ToF8e5m2` test
- Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test

Related PR:
- [PR-16585](openxla/xla#16585) Add support for float8_e4m3

Copybara import of the project:

--
59722056e36e5a0bab7736b4ad3897446861de0f by Alexander Pivovarov <[email protected]>:

Add test for EmitReducePrecisionIR

Merging this change closes #16775

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16775 from apivovarov:elemental_ir_emitter_test 59722056e36e5a0bab7736b4ad3897446861de0f
PiperOrigin-RevId: 696730664