Hello! I have just started working with JAX and am currently doing a GPU inference optimization project related to AlphaFold 3. I tried to optimize the program using the persistent compilation cache and added the following configuration (no other modifications were made):
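The exact configuration was not captured in this thread; a typical JAX persistent-cache setup (option names from the JAX documentation, values illustrative) looks like:

```python
import jax

# Directory where compiled XLA executables are stored across runs.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# Also cache small and fast compilations, which the defaults skip.
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
```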
In the XLA-based flash attention version, adding the persistent cache configuration makes the second run of the program significantly faster (around 2.5x compared to the first run). However, when switching to the Triton-based Flash Attention implementation with the same persistent cache configuration, the second run shows no performance improvement over the first (although the corresponding cache files are still generated in the cache directory).
Does this indicate that JAX's persistent-cache implementation is incompatible with Triton? If not, what could be the reasons for it being ineffective? I am not sure whether this is instead related to the internal implementation of AlphaFold 3.
System info:
Python version: 3.11.11
JAX version: 0.4.34
CUDA version: V12.4.99
AlphaFold 3 version: 3.0.0
MelodicDrumstep changed the title from "Rewriting the kernel in Triton hinders the overall optimization benefits of persistent caching" to "Using the Triton version of flash attention hinders the overall optimization benefits of persistent caching" on Feb 4, 2025.
JAX's persistent cache is keyed on the compiled computation graph (XLA HLO); I think Triton has its own JIT.
I think you have a point! I'm not very familiar with the internal mechanism of persistent caching, but I will try to write some minimal experimental examples of Triton with persistent caching.
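A minimal experiment along these lines could start from a small Pallas kernel, since Pallas is JAX's path to Triton on GPU. The snippet below is an illustrative sketch (not code from this thread), run in interpret mode so it also works without a GPU; dropping `interpret=True` lowers it to Triton on a GPU backend:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Each ref is a block of the corresponding input/output array.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # remove to lower to Triton on GPU
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))  # [ 0.  2.  4.  6.  8. 10. 12. 14.]
```

Running this twice with the persistent cache enabled (and timing the first call of each run) would show whether the Triton lowering is served from the cache.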
Could you try doing your experiment with this XLA flag set?
XLA_FLAGS=--xla_gpu_autotune_level=0
This disables the XLA autotuner, which might be trying to find the hyperparameters for the Triton kernel call from scratch on every run. It is not very likely this will help, though...
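For example, the flag can be set per invocation like this (the script name is illustrative, not the actual AlphaFold 3 entry point):

```shell
# Disable XLA GPU autotuning for this run only
XLA_FLAGS=--xla_gpu_autotune_level=0 python your_inference_script.py
```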