
Using the Triton version of flash attention hinders the overall optimization benefits of persistent caching #287

Open · MelodicDrumstep opened this issue Feb 4, 2025 · 4 comments
Labels
question Further information is requested

Comments

MelodicDrumstep commented Feb 4, 2025

Hello! I have just started working with JAX and am currently doing a GPU inference optimization project related to AlphaFold 3. I tried to optimize the program using the persistent compilation cache and added the following configuration (no other modifications were made):

jax.config.update("jax_compilation_cache_dir", my_cache_dir)
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

With the XLA-based flash attention implementation, the second run of the program after adding the persistent cache configuration shows a significant performance boost (around 2.5 times faster than the first run). However, when switching to the Triton-based flash attention implementation with the same persistent cache configuration, the second run showed no performance improvement over the first, although the corresponding cache files were still generated in the cache directory.

I wonder whether this indicates that the persistent cache implementation inside JAX is incompatible with Triton. If not, what could be the reasons for it being ineffective? I am not sure whether this is instead related to the internal implementation of AlphaFold 3.

System info:

  • Python version: 3.11.11
  • JAX version: 0.4.34
  • CUDA version: V12.4.99
  • AlphaFold 3 version: 3.0.0
@MelodicDrumstep MelodicDrumstep changed the title Rewriting the kernel in Triton hinders the overall optimization benefits of persistent caching Using the Triton version of flash attention hinders the overall optimization benefits of persistent caching Feb 4, 2025

JoyBoy1021 commented Feb 4, 2025

JAX caching is based on computation graphs (XLA HLO); I think Triton has its own JIT.
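For what it's worth, Triton keeps its own on-disk kernel cache, separate from JAX's persistent compilation cache. A sketch for pointing it at an inspectable location (the environment variable is standard Triton; the path is just an example, and to my knowledge the default is ~/.triton/cache):

import os

# Redirect Triton's JIT kernel cache (normally under ~/.triton/cache) so you
# can check whether compiled kernels survive across runs.
os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache"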

MelodicDrumstep (Author) commented

> JAX caching is based on computation graphs (XLA HLO); I think Triton has its own JIT.

I think you have a point! I'm not very familiar with the internal mechanism of persistent caching, but I will try to write some minimal experimental examples of Triton with persistent caching, perhaps something like the sketch below.
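A first sketch of such an experiment, under the assumption that jax.experimental.pallas lowers its kernels to Triton on the GPU backend; run it twice in separate processes and compare the startup times:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Same persistent-cache configuration as above (example path).
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

def add_kernel(x_ref, y_ref, o_ref):
    # Elementwise add; executed as a Triton kernel on GPU.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.arange(1024, dtype=jnp.float32)
print(add(x, x)[:4])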

@Augustin-Zidek Augustin-Zidek added the question Further information is requested label Feb 5, 2025
@google-deepmind google-deepmind deleted a comment from JoyBoy1021 Feb 5, 2025
Augustin-Zidek (Collaborator) commented Feb 5, 2025

Could you try doing your experiment with this XLA flag set?

XLA_FLAGS=--xla_gpu_autotune_level=0

This disables the XLA autotuner, which might be trying to find the hyperparameters for the Triton kernel call from scratch every time. It is not very likely that this will help, though...
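For example (a sketch; as far as I know, the flag has to be set before JAX initializes its backends, i.e. before the first JAX computation runs):

import os
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"  # set before JAX starts up

import jax  # the backend will now see the flag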

Some ideas might also be suggested in jax-ml/jax#26304.

MelodicDrumstep (Author) commented

@Augustin-Zidek Thank you for your reply! I tried adding

import os
os.environ['XLA_FLAGS'] = '--xla_gpu_autotune_level=0'

and running the Triton version, but the situation seems to be exactly the same as before.
