5 bits for exponent? #21

Open · bjarthur opened this issue Feb 6, 2023 · 12 comments
Labels: enhancement (New feature or request)

bjarthur commented Feb 6, 2023

The new H100 from Nvidia has 8-bit floats in two flavors: one with 4 exponent bits, like Float8s.jl's Float8_4, and one with 5. Scroll down to "NVIDIA Hopper FP8 data format" here: https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/

Have you considered adding this type to Float8s.jl? Currently I'm using https://github.com/goualard-f/MicroFloatingPoints.jl to simulate it and see whether that many exponent bits is better (than 4 or 3), and it is painfully slow.

milankl (Member) commented Feb 6, 2023

I'd be happy if you want to add it, and I'm happy to assist with implementing it. In principle there are only two things needed:

  • Float8_5 to Float32 conversion (that's easy, it's just a table lookup, check src/float8_to_float32.jl)
  • Float32 to Float8_5 conversion; that's trickier, because unfortunately you cannot just round the mantissa bits and check for under/overflow, you also have to account for the subnormals of Float8_5, which are normals in Float32. @JeffreySarnoff wrote this code back in 2020 for 3 exponent bits, you might be able to adapt it though! A brute-force reference for this direction is sketched below.
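
As a crude reference for that second bullet (not what the package should ship, just something to test a proper bit-twiddling implementation against), one could brute-force the conversion against a 256-entry Float8_5-to-Float32 table, analogous to the existing ones in src/float8_to_float32.jl but for 5 exponent bits. Names here are placeholders, and ties-to-even, signed zero and overflow to Inf are ignored:

function float32_to_float8_5_bits(x::Float32, table::Vector{Float32})
    isnan(x) && return 0x7d                 # pick one NaN bit pattern (placeholder choice)
    best_i, best_d = 0x00, Inf32
    for i in 0x00:0xff
        v = table[Int(i) + 1]
        isnan(v) && continue                # a non-NaN input never maps to a NaN entry
        v == x && return i                  # exact hit, also catches ±Inf
        d = abs(v - x)
        d < best_d && ((best_i, best_d) = (i, d))
    end
    return best_i                           # nearest value; saturates at floatmax instead of rounding to Inf
end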

Having said that, is there any more information on how they define Float8 on the H100? As I currently don't see any other use for it, I'd be happy to stick to Nvidia's specs. For example, they may not have NaNs and may redefine what an all-ones exponent means, and similarly for subnormals. I'm asking because I've seen a bunch of very low-precision floats that redefine the floating-point standard somewhat. Not that Float8 was ever standardised, but I mean its logical extension down to 8 bits.

milankl added the enhancement (New feature or request) label Feb 6, 2023
milankl added this to the v0.2 milestone Feb 6, 2023
bjarthur (Author) commented Feb 6, 2023

The 5-exponent-bit version is IEEE compliant. The 4-exponent-bit version has no Infs (neither + nor -) and only one NaN; see https://arxiv.org/pdf/2209.05433.pdf

It might make more sense to be consistent with whatever you've already done for the 4- and 3-exponent-bit versions.
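
E.g., just sanity-checking the paper's numbers for the 4-exponent-bit format (bias 7, no Infs, only the all-ones pattern S.1111.111 reclaimed as NaN), not package code:

julia> bias = 7                                 # E4M3 exponent bias per the paper
julia> 2.0^(0b1111 - bias) * (1 + 0b110/2^3)    # largest finite value S.1111.110
448.0
julia> 2.0^(1 - bias) * (0b001/2^3)             # smallest subnormal S.0000.001
0.001953125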

bjarthur (Author) commented Feb 8, 2023

I started looking into the "easy" lookup table (your first bullet above) to convert Float8 to Float32, and I'm wondering how the two existing tables for 3 and 4 exponent bits were created. Can they be generated programmatically? If so, it might be simplest, for me at least, to modify that code for 5 exponent bits.

milankl (Member) commented Feb 8, 2023

From my thesis, with $n_e$ the number of exponent bits and $n_m$ the number of mantissa bits, a bit pattern with sign bit $s$ and exponent field $e$ has the value

$$x = \begin{cases} (-1)^s \, 2^{1-bias} \, f & e = 0 \text{ (subnormal, incl. zero)} \\ (-1)^s \, 2^{e-bias} \, (1+f) & 1 \le e \le 2^{n_e}-2 \text{ (normal)} \end{cases}$$

(the all-ones exponent is reserved for Inf and NaN), with $bias = 2^{n_e-1} - 1$ and the fraction $$f = \sum_{i=1}^{n_m}m_i2^{-i}$$ (big endian).

Then

julia> ne = 5
julia> nm = 2
julia> bias = 2^(ne-1) - 1
15     # same as Base.exponent_bias(Float16) btw

# the fraction f = Σ m_i 2^-i from the nm mantissa bits (the lowest bits of ui)
function fraction(ui::UInt8,nm=2)
    Σ = 0
    for i in 1:nm
        mask = 0x80 >> (8-nm+i-1)   # select the i-th mantissa bit (big endian)
        mi = (ui & mask) >> (nm-i)  # shift it down to 0 or 1
        Σ += mi*2.0^(-i)
    end
    Σ
end

julia> f = [fraction(UInt8(mi),2) for mi in 0:2^nm-1]
4-element Vector{Float64}:
 0.0
 0.25
 0.5
 0.75

julia> subnormals = Float32[2.0^(1-bias)*fi for fi in f]     # including 0
4-element Vector{Float32}:
 0.0
 1.5258789f-5
 3.0517578f-5
 4.5776367f-5

julia> normals = Float32[2.0^(e-bias)*(1+fi) for e in 1:2^ne-2 for fi in f]
120-element Vector{Float32}:
     6.1035156f-5
     7.6293945f-5
     9.1552734f-5
    ...

We can then concatenate all representable Float8 values as

julia> cat(subnormals,normals,Inf32,repeat([NaN32],2^nm-1),-subnormals,-normals,-Inf32,repeat([NaN32],2^nm-1),dims=1)
256-element Vector{Float32}:
      0.0
      1.5258789f-5
      3.0517578f-5
      4.5776367f-5
      6.1035156f-5
      7.6293945f-5
      9.1552734f-5
      0.00010681152
      0.00012207031
      0.00015258789
      0.00018310547
      0.00021362305
      0.00024414062
      0.00030517578
      0.00036621094
      0.0004272461
      0.00048828125
      0.00061035156
      0.0007324219
      0.0008544922
      0.0009765625
      0.0012207031
      
  -3072.0
  -3584.0
  -4096.0
  -5120.0
  -6144.0
  -7168.0
  -8192.0
 -10240.0
 -12288.0
 -14336.0
 -16384.0
 -20480.0
 -24576.0
 -28672.0
 -32768.0
 -40960.0
 -49152.0
 -57344.0
    -Inf
    NaN
    NaN
    NaN
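
Side note: since the 5-exponent-bit format shares Float16's exponent width, I believe the same 256 values (in bit-pattern order 0x00:0xff, which is exactly the order of the cat above) can also be obtained by padding each byte out to a Float16, which makes a handy cross-check:

julia> table = [Float32(reinterpret(Float16, UInt16(b) << 8)) for b in 0x00:0xff];    # s eeeee mm -> s eeeee mm00000000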

milankl (Member) commented Feb 8, 2023

Given that we may want 3 different Float8 formats in this package, it might be worth not hardcoding the tables as we did before but creating them dynamically. It's more readable, more reproducible, and serves as some form of documentation. (As I seemingly cannot remember how we created these tables some years ago!!!)

JeffreySarnoff (Member) commented Feb 11, 2023

I generated them. An advantage of hardcoding is that it makes the relative timing of compound calculations more consistent than it would otherwise be. That helps when investigating the impact of providing certain other ops hardcoded (in hardware) or via e.g. multiple polynomial approximations.

bjarthur (Author) commented

How about hardcoding them with metaprogramming?

JeffreySarnoff (Member) commented

It is easier to use some subsidiary functions with a few outer "put it all together" functions.
With 8-bit floating-point representations, the assignment of specific UInt8 bit patterns to e.g. NaN, +/-Inf and -0 (where the given representation uses them) is an important feature, and hardcoding handles this best.

milankl (Member) commented Feb 11, 2023

Sorry, maybe to be a bit clearer: what I had in mind was something like

abstract type AbstractFloat8 <: AbstractFloat end
primitive type Float8_e3m4 <: AbstractFloat8 8 end        # 3 exponent bits, IEEE compliant
primitive type Float8_e4m3 <: AbstractFloat8 8 end        # 4 exponent bits, like Nvidia's E4M3
primitive type Float8_e5m2 <: AbstractFloat8 8 end        # 5 exponent bits, IEEE compliant

# define one of the above as default
Float8 = ...

function representable_float8s(::Type{T}) where {T<:AbstractFloat8}
    ne = Base.exponent_bits(T)
    nm = Base.significand_bits(T)
    ...
    # call normals, subnormals and concatenate accordingly
    if T == Float8_e3m4    # somehow distinguish between the different formats (or use dispatch for that)
        all_float8s = cat(subnormals,normals,...)
    ...
    end

    return all_float8s
end

# the following is then executed on every using/import
const float8_e3m4_to_float32 = representable_float8s(Float8_e3m4)
const float8_e4m3_to_float32 = representable_float8s(Float8_e4m3)
const ...

# and conversion defined as
Base.Float32(x::Float8_e3m4) = @inbounds float8_e3m4_to_float32[reinterpret(UInt8,x) + 1]   # +1 as Int to avoid UInt8 wrap-around at 0xff
...
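
For concreteness, a runnable version along these lines could look like the following (only a sketch with placeholder trait functions; the Nvidia-style e4m3 with no Infs and a single NaN would need its own method):

abstract type AbstractFloat8 <: AbstractFloat end
primitive type Float8_e3m4 <: AbstractFloat8 8 end
primitive type Float8_e5m2 <: AbstractFloat8 8 end

# placeholder traits, since Base.exponent_bits isn't defined for new types
exponent_bits(::Type{Float8_e3m4}) = 3
exponent_bits(::Type{Float8_e5m2}) = 5
significand_bits(::Type{T}) where {T<:AbstractFloat8} = 7 - exponent_bits(T)

# IEEE-style table in bit-pattern order 0x00:0xff
function representable_float8s(::Type{T}) where {T<:AbstractFloat8}
    ne, nm = exponent_bits(T), significand_bits(T)
    bias = 2^(ne-1) - 1
    f = [mi/2^nm for mi in 0:2^nm-1]                    # fractions 0, 2^-nm, ..., 1-2^-nm
    subnormals = Float32[2.0^(1-bias)*fi for fi in f]   # includes +0
    normals = Float32[2.0^(e-bias)*(1+fi) for e in 1:2^ne-2 for fi in f]
    positives = vcat(subnormals, normals, Inf32, fill(NaN32, 2^nm-1))
    return vcat(positives, -positives)                  # negation just flips the sign bit; -NaN is still NaN
end

const float8_e5m2_to_float32 = representable_float8s(Float8_e5m2)
Base.Float32(x::Float8_e5m2) = @inbounds float8_e5m2_to_float32[reinterpret(UInt8, x) + 1]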

I find Float8_e3m4, while precise, somewhat annoying because it's so technical. We could also name them

  • Float8, i.e. (e=3, m=4), just because that's what we currently have
  • AFloat8, i.e. (e=4, m=3), see below
  • BFloat8, i.e. (e=5, m=2), similar to BFloat16 in that it uses the number of exponent bits of the next more precise IEEE float: BFloat16 uses e=8 as does Float32, so BFloat8 would use e=5 as does Float16...?

JeffreySarnoff (Member) commented

FYI, papers are using this way of indicating FP8 exponent and significand bits
(it is presumed that there is a sign bit and that it is the msb of the byte):

abstract type AbstractFloat8 <: AbstractFloat end

primitive type E3M4 <: AbstractFloat8 8 end        # 3 exponent bits, IEEE compliant
primitive type E4M3 <: AbstractFloat8 8 end        # 4 exponent bits, like Nvidia's E4M3
primitive type E5M2 <: AbstractFloat8 8 end        # 5 exponent bits, IEEE compliant

(they should be using S rather than M, as IEEE 754 specifies that terminology)

And there are still other specifiable parameters (for the formats at hand you have detailed these; a reason not to use e.g. E4M3 without some prefix or suffix):

  • the exponent bias (k, and then 2^k or 2^k - 1),
  • the count of NaNs (0, 1, >1),
  • and other, more representation-related choices, e.g. how the bytes are valued (are they ordered, or made of ordered subsequences?) and, if there is a NaN and/or a -0, what the encodings are.

bjarthur (Author) commented

MicroFloatingPoints.jl uses parametric types to signify the bit partitioning: Floatmu{4,3}, for example, is an 8-bit float with 4 exponent bits and 3 fraction bits. Everything is IEEE compliant there.

Perhaps we could parameterize the exponent bias and the NaN/Inf counts similarly: maybe Float8{E,B,N,I}, where E is the number of exponent bits, B the bias, N the number of NaNs, and I the number of Infs.
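
Parametric primitive types are allowed in Julia (Ptr{T} in Base is one), so mechanically something like this would work; the name and the parameter values below are just illustrative:

primitive type Float8p{E,B,N,I} <: AbstractFloat 8 end   # placeholder name

exponent_bits(::Type{Float8p{E,B,N,I}}) where {E,B,N,I} = E
exponent_bias(::Type{Float8p{E,B,N,I}}) where {E,B,N,I} = B
n_nans(::Type{Float8p{E,B,N,I}}) where {E,B,N,I} = N
n_infs(::Type{Float8p{E,B,N,I}}) where {E,B,N,I} = I

const E5M2like = Float8p{5,15,6,2}   # IEEE-style: 6 NaN bit patterns, ±Inf
const E4M3like = Float8p{4,7,1,0}    # Nvidia-style: a single NaN, no Infs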

JeffreySarnoff (Member) commented

It is cleaner and clearer to reduce the parameter count, at least at first.
The IEEE has a Standards-track effort on Arithmetic Formats for Machine Learning.
We could work with 1 NaN for ML 8-bit float types. I am of the view that there is use for 2 Infs (+/-) and for 0 Infs (saturating arithmetic); other counts are much less compelling imo.
People like subnormals, even where there are just a few of them.

Sometimes 1 (unsigned) zero is proper, other times ( Complex{Float8} ) 2 signed zeros is proper.

Maybe all are signed floats with two Infs:
Float8{Ebits,Bias,Neg0} == Float8{E,B,1,2}, with -0 only if Neg0 == true.

Although this drops some parametric flexibility, it would be easier to get right; introducing more flexibility could then proceed with care. I ran into this design dilemma with early versions of DoubleFloats.jl. Pushing in more parameters is not the best way to start.

For the more blue-sky numerics: I would like to experiment with a signed Huge (a nonspecific finite value that exceeds the exact finite value floatmax(T) and, where a mathematical perspective is needed, is considered much greater than it) and a signed Tiny (a nonspecific finite value that is nonzero and, where a mathematical perspective is needed, is considered much less than the exact finite value nextfloat(zero(T))). So as we augment parameters, a way to accommodate this is sought.
