
jax.jit slows down the code a lot on function with simple array operations and jnp.roll() #24373

Open
pmocz opened this issue Oct 17, 2024 · 4 comments
Labels: bug (Something isn't working), XLA

Comments


pmocz commented Oct 17, 2024

Description

I get a significant (~4x) slowdown in my JAX code when I add @jax.jit to my main update function, which manipulates large arrays with element-wise math and jnp.roll().

A minimal reproducer is included below; removing the @jax.jit around the update() function (the line marked with the comment # XXX) speeds up the code considerably. The slowdown is not due to compile-time overhead. I'm quite puzzled by this behavior and think it may be a bug in JAX or XLA. What is the best way to get to the bottom of this issue? To reproduce, run python euler.py with and without the jit decorator around update():

import jax
import jax.numpy as jnp
import time

# simulation parameters
N = 1024
boxsize = 1.0
dx = boxsize / N
vol = dx**2
dt = 0.0001


@jax.jit
def get_conserved(rho, vx, vy, P):
    """Calculate the conserved variables from the primitive variables"""

    Mass = rho * vol
    Momx = rho * vx * vol
    Momy = rho * vy * vol
    Energy = (P / (5 / 3 - 1) + 0.5 * rho * (vx**2 + vy**2)) * vol

    return Mass, Momx, Momy, Energy


@jax.jit
def get_primitive(Mass, Momx, Momy, Energy):
    """Calculate the primitive variable from the conserved variables"""

    rho = Mass / vol
    vx = Momx / rho / vol
    vy = Momy / rho / vol
    P = (Energy / vol - 0.5 * rho * (vx**2 + vy**2)) * (5 / 3 - 1)

    return rho, vx, vy, P


@jax.jit
def get_gradient(f):
    """Calculate the gradients of a field"""

    f_dx = (jnp.roll(f, -1, axis=0) - jnp.roll(f, 1, axis=0)) / (2 * dx)
    f_dy = (jnp.roll(f, -1, axis=1) - jnp.roll(f, 1, axis=1)) / (2 * dx)

    return f_dx, f_dy


@jax.jit
def extrapolate_to_face(f, f_dx, f_dy):
    """Extrapolate the field from face centers to faces using gradients"""

    f_XL = f - f_dx * dx / 2
    f_XL = jnp.roll(f_XL, -1, axis=0)
    f_XR = f + f_dx * dx / 2

    f_YL = f - f_dy * dx / 2
    f_YL = jnp.roll(f_YL, -1, axis=1)
    f_YR = f + f_dy * dx / 2

    return f_XL, f_XR, f_YL, f_YR


@jax.jit
def apply_fluxes(F, flux_F_X, flux_F_Y):
    """Apply fluxes to conserved variables to update solution state"""

    F += -dt * dx * flux_F_X
    F += dt * dx * jnp.roll(flux_F_X, 1, axis=0)
    F += -dt * dx * flux_F_Y
    F += dt * dx * jnp.roll(flux_F_Y, 1, axis=1)

    return F


@jax.jit
def get_flux(rho_L, rho_R, vx_L, vx_R, vy_L, vy_R, P):
    """Calculate fluxes between 2 states"""

    # left and right energies
    en_L = P / (5 / 3 - 1) + 0.5 * rho_L * (vx_L**2 + vy_L**2)
    en_R = P / (5 / 3 - 1) + 0.5 * rho_R * (vx_R**2 + vy_R**2)

    # compute star (averaged) states
    rho_star = 0.5 * (rho_L + rho_R)
    momx_star = 0.5 * (rho_L * vx_L + rho_R * vx_R)
    momy_star = 0.5 * (rho_L * vy_L + rho_R * vy_R)
    en_star = 0.5 * (en_L + en_R)

    P_star = (5 / 3 - 1) * (en_star - 0.5 * (momx_star**2 + momy_star**2) / rho_star)

    flux_Mass = momx_star
    flux_Momx = momx_star**2 / rho_star + P_star
    flux_Momy = momx_star * momy_star / rho_star
    flux_Energy = (en_star + P_star) * momx_star / rho_star

    # add stabilizing diffusive term
    flux_Mass -= 0.5 * 0.5 * (rho_L - rho_R)
    flux_Momx -= 0.5 * 0.5 * (rho_L * vx_L - rho_R * vx_R)
    flux_Momy -= 0.5 * 0.5 * (rho_L * vy_L - rho_R * vy_R)
    flux_Energy -= 0.5 * 0.5 * (en_L - en_R)

    return flux_Mass, flux_Momx, flux_Momy, flux_Energy


@jax.jit  # <---  XXX Adding this line slows down the code a lot!!
def update(Mass, Momx, Momy, Energy):
    """Take a simulation timestep"""

    rho, vx, vy, P = get_primitive(Mass, Momx, Momy, Energy)

    rho_dx, rho_dy = get_gradient(rho)
    vx_dx, vx_dy = get_gradient(vx)
    vy_dx, vy_dy = get_gradient(vy)

    rho_XL, rho_XR, rho_YL, rho_YR = extrapolate_to_face(rho, rho_dx, rho_dy)
    vx_XL, vx_XR, vx_YL, vx_YR = extrapolate_to_face(vx, vx_dx, vx_dy)
    vy_XL, vy_XR, vy_YL, vy_YR = extrapolate_to_face(vy, vy_dx, vy_dy)

    flux_Mass_X, flux_Momx_X, flux_Momy_X, flux_Energy_X = get_flux(
        rho_XL, rho_XR, vx_XL, vx_XR, vy_XL, vy_XR, P
    )
    # note the x/y swap: for the Y-direction flux, vy plays the role of vx
    flux_Mass_Y, flux_Momy_Y, flux_Momx_Y, flux_Energy_Y = get_flux(
        rho_YL, rho_YR, vy_YL, vy_YR, vx_YL, vx_YR, P
    )

    Mass = apply_fluxes(Mass, flux_Mass_X, flux_Mass_Y)
    Momx = apply_fluxes(Momx, flux_Momx_X, flux_Momx_Y)
    Momy = apply_fluxes(Momy, flux_Momy_X, flux_Momy_Y)
    Energy = apply_fluxes(Energy, flux_Energy_X, flux_Energy_Y)

    return Mass, Momx, Momy, Energy


def main():
    """Finite Volume simulation"""

    # Setup
    xlin = jnp.linspace(0.5 * dx, boxsize - 0.5 * dx, N)
    X, Y = jnp.meshgrid(xlin, xlin, indexing="ij")

    rho = 1.0 + (jnp.abs(Y - 0.5) < 0.25)
    vx = -0.5 + (jnp.abs(Y - 0.5) < 0.25)
    vy = 0.1 * jnp.sin(4 * jnp.pi * X)
    P = 2.5 * jnp.ones(X.shape)

    Mass, Momx, Momy, Energy = get_conserved(rho, vx, vy, P)

    # Main Loop
    tic = time.time()
    for n_iter in range(40):

        Mass, Momx, Momy, Energy = jax.block_until_ready(
            update(Mass, Momx, Momy, Energy)
        )

        cell_updates = X.shape[0] * X.shape[1] * (n_iter + 1)  # steps completed so far
        total_time = time.time() - tic
        mcups = cell_updates / (1e6 * total_time)
        print("  million cell updates / second: ", mcups)

    print("Total time: ", total_time)


if __name__ == "__main__":
    main()
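
To rule out compile-time overhead, the timing can also be done with a warm-up call before the timer starts (a sketch, replacing the main loop above; the gap persists either way):

# sketch: trigger compilation once before timing
Mass, Momx, Momy, Energy = jax.block_until_ready(update(Mass, Momx, Momy, Energy))

tic = time.time()
for _ in range(40):
    Mass, Momx, Momy, Energy = jax.block_until_ready(update(Mass, Momx, Momy, Energy))
print("Steady-state time: ", time.time() - tic)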

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.33
jaxlib: 0.4.33
numpy:  2.1.2
python: 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:35:20) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='C916PXT6XW', release='23.6.0', version='Darwin Kernel Version 23.6.0: Wed Jul 31 20:50:00 PDT 2024; root:xnu-10063.141.1.700.5~1/RELEASE_ARM64_T6031', machine='arm64')
@pmocz pmocz added the bug Something isn't working label Oct 17, 2024
@pmocz pmocz changed the title jax.jit slows down the code a lot on function with simple array operations and np.roll() jax.jit slows down the code a lot on function with simple array operations and jnp.roll() Oct 17, 2024

pmocz commented Oct 18, 2024

I've simplified the code to highlight the problem:

import jax
import jax.numpy as jnp
import time


@jax.jit
def get_gradient(f):
    """Calculate the gradients of a field"""

    f_dx = jnp.roll(f, -1, axis=0) - jnp.roll(f, 1, axis=0)
    f_dy = jnp.roll(f, -1, axis=1) - jnp.roll(f, 1, axis=1)

    return f_dx, f_dy


@jax.jit
def extrapolate_to_face(f, f_dx, f_dy):
    """Extrapolate the field from face centers to faces using gradients"""

    f_XL = f - f_dx
    f_XL = jnp.roll(f_XL, -1, axis=0)
    f_XR = f + f_dx

    f_YL = f - f_dy
    f_YL = jnp.roll(f_YL, -1, axis=1)
    f_YR = f + f_dy

    return f_XL, f_XR, f_YL, f_YR


@jax.jit
def apply_fluxes(F, flux_F_X, flux_F_Y):
    """Apply fluxes to conserved variables to update solution state"""

    F += -flux_F_X
    F += jnp.roll(flux_F_X, 1, axis=0)
    F += -flux_F_Y
    F += jnp.roll(flux_F_Y, 1, axis=1)

    return F


@jax.jit
def get_flux(A_L, A_R, B_L, B_R):
    """Calculate fluxes between 2 states"""

    A_star = 0.5 * (A_L + A_R)
    B_star = 0.5 * (B_L + B_R)

    flux_A = B_star
    flux_B = B_star**2 / A_star

    flux_A -= 0.1 * (A_L - A_R)
    flux_B -= 0.1 * (B_L - B_R)

    return flux_A, flux_B


# @jax.jit  # <---  XXX Adding this line slows down the code a lot!!
def update(A, B):
    """Take a simulation timestep"""

    A_dx, A_dy = get_gradient(A)
    B_dx, B_dy = get_gradient(B)

    A_XL, A_XR, A_YL, A_YR = extrapolate_to_face(A, A_dx, A_dy)
    B_XL, B_XR, B_YL, B_YR = extrapolate_to_face(B, B_dx, B_dy)

    flux_A_X, flux_B_X = get_flux(A_XL, A_XR, B_XL, B_XR)
    flux_A_Y, flux_B_Y = get_flux(A_YL, A_YR, B_YL, B_YR)

    A = apply_fluxes(A, flux_A_X, flux_A_Y)
    B = apply_fluxes(B, flux_B_X, flux_B_Y)

    return A, B


@jax.jit
def update_compiled_SLOW(A, B):
    return update(A, B)


def main():

    N = 1024

    A = jnp.ones((N, N))
    B = jnp.ones((N, N))
    tic = time.time()
    for _ in range(200):
        A, B = update(A, B)
    print("Total time not compiled: ", time.time() - tic)

    A = jnp.ones((N, N))
    B = jnp.ones((N, N))
    tic = time.time()
    for _ in range(200):
        A, B = update_compiled_SLOW(A, B)
    print("Total time compiled: ", time.time() - tic)


if __name__ == "__main__":
    main()

gives:

Total time not compiled:  0.6709847450256348
Total time compiled:  2.1534647941589355
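
One way to dig further (a sketch using JAX's ahead-of-time lowering API; I haven't compared the dumps in detail) is to print the optimized HLO that XLA:CPU produces for the fully jitted update, assuming A and B are the arrays from main() above:

# sketch: dump optimized HLO for the slow fully-jitted version
lowered = jax.jit(update).lower(A, B)
print(lowered.compile().as_text())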


jakevdp commented Oct 18, 2024

Thanks for the report! This is definitely unexpected, and points to some compiler issue.

I updated your timing to separate out the first call, use block_until_ready to avoid issues due to asynchronous dispatch, and use IPython's %timeit syntax for better fidelity:

_ = jax.block_until_ready(update(A, B))
%timeit jax.block_until_ready(update(A, B))

_ = jax.block_until_ready(update_compiled_SLOW(A, B))
%timeit jax.block_until_ready(update_compiled_SLOW(A, B))

This is the result on Colab CPU:

44.1 ms ± 7.17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
165 ms ± 27.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

and this is the result on a Colab T4 GPU:

2.72 ms ± 1.46 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.21 ms ± 11.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

So it seems this issue is specific to the XLA:CPU compiler. It may be worth reporting it upstream at https://github.com/openxla/xla, though it would be useful to try to reduce the repro even further first.
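
For example, a further-reduced candidate (hypothetical; I haven't verified that it still shows the gap) would keep only jnp.roll plus a few element-wise ops:

import jax
import jax.numpy as jnp

def step(A, B):
    # rolls plus simple element-wise arithmetic, mirroring the repro above
    A = A + jnp.roll(B, 1, axis=0) - jnp.roll(B, -1, axis=0) - 0.1 * (A - B)
    B = B + jnp.roll(A, 1, axis=1) - jnp.roll(A, -1, axis=1)
    return A, B

step_jit = jax.jit(step)  # time step vs. step_jit as in the repro above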

@jakevdp jakevdp added the XLA label Oct 18, 2024

pmocz commented Oct 18, 2024

Thanks for taking a look at this @jakevdp, and for pinpointing that this seems to be a CPU-only issue. Definitely unexpected. What's really weird, too, is that if I comment out the simple diffusive terms in the get_flux function, flux_A -= 0.1 * (A_L - A_R) and flux_B -= 0.1 * (B_L - B_R), then the issue goes away.
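
For reference, this is the variant I mean (a sketch; only the two diffusive lines change):

@jax.jit
def get_flux(A_L, A_R, B_L, B_R):
    """Calculate fluxes between 2 states"""

    A_star = 0.5 * (A_L + A_R)
    B_star = 0.5 * (B_L + B_R)

    flux_A = B_star
    flux_B = B_star**2 / A_star

    # with these two terms commented out, jitted and unjitted timings match again
    # flux_A -= 0.1 * (A_L - A_R)
    # flux_B -= 0.1 * (B_L - B_R)

    return flux_A, flux_B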

I will file an issue with the XLA team as well.


pmocz commented Oct 18, 2024

The XLA issue is filed here: openxla/xla#18478
