Use causal mask type for training or prefill with fixed seqlen when calling CuDNN flash attention. #1248
Description
This PR refines the mask types and attention masks used when calling cuDNN flash attention. The existing code only specifies the mask type for sliding-window attention; for all other cases it uses mask type `padding_causal` together with an attention mask, regardless of whether the sequence length is fixed or variable. While there is nothing wrong with this, it would be more efficient to pass mask type `causal` with no attention mask for the fixed-seqlen cases in training and prefill. There are two reasons why this would be better (see the sketch after this list):

1. With `padding_causal`, MaxText has to generate an attention mask and pass it to Transformer Engine, and Transformer Engine then extracts the sequence-length information from that mask into another data structure for cuDNN to consume. With the change in this PR, the overhead of generating the attention mask in MaxText and of extracting information from it in Transformer Engine is saved.
2. The `padding_causal` mask type requires calling the flash attention backend for variable seqlen, while the `causal` mask type calls the flash attention backend for fixed seqlen. The variable-seqlen backend is not yet enabled in Transformer Engine ROCm, so this PR also enables running flash attention in MaxText on AMD GPUs. And even if the variable-seqlen backend is enabled in Transformer Engine ROCm in the future, calling the fixed-seqlen backend is still expected to be more efficient for such cases.
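Below is a minimal, hypothetical sketch of the selection logic described above. The function and argument names are not the actual MaxText or Transformer Engine APIs, and the mask construction in the variable-seqlen branch is only illustrative; the point is the split this PR makes: `causal` with no attention mask when the sequence length is fixed, `padding_causal` with an explicit mask otherwise.

```python
from typing import Optional, Tuple

import jax.numpy as jnp


def select_attn_mask(
    fixed_seqlen: bool,
    segment_ids: Optional[jnp.ndarray] = None,
) -> Tuple[str, Optional[jnp.ndarray]]:
    """Return (attn_mask_type, attention_mask) for the cuDNN flash attention call.

    Hypothetical helper for illustration only; not the MaxText implementation.
    """
    if fixed_seqlen:
        # Training / prefill with fixed seqlen: a plain causal mask type suffices,
        # so no attention-mask tensor has to be built here or parsed downstream.
        return "causal", None
    # Variable seqlen: keep padding_causal and supply an explicit attention mask.
    # Toy [batch, 1, q_len, kv_len] mask built from padding segment ids (0 = pad).
    assert segment_ids is not None, "variable seqlen needs padding information"
    valid = segment_ids != 0
    attention_mask = valid[:, None, :, None] & valid[:, None, None, :]
    return "padding_causal", attention_mask
```

In this sketch the fixed-seqlen branch is what training and prefill would take, avoiding the mask round-trip described in reason 1, while ragged or padded batches would fall through to the `padding_causal` branch as before.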
Tests

Please describe how you tested this change, and include any instructions and/or commands to reproduce.
Checklist
Before submitting this PR, please make sure (put X in square brackets):