Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use causal mask type for training or prefill with fixed seqlen when #1248

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

wenchenvincent
Copy link

@wenchenvincent wenchenvincent commented Feb 7, 2025

calling CuDNN flash attention.

Description

This PR refines the mask types and attention masks when calling CuDNN Flash Attention. The existing code only specifies the mask type for sliding window and uses mask type padding_causal with an attention mask for other cases, regardless using fixed seqlen or variable seqlen. While there is nothing wrong with this, it would be more efficient to pass the mask type causal with no attention mask for the cases with fixed seqlen in training and prefill. There are two reasons why this would be better:

  • The current code would use padding_causal and generate an attention mask and pass them to Transformer Engine. Then Transformer Engine would extract the sequence length information from the attention mask into another data structure for cuDNN to consume. With the change in this PR, the overhead of generating the attention mask in Maxtext and extracting information from the attention mask in Transformer Engine could be saved.
  • In addition to the above overhead, when Maxtext is run on the ROCm SW stack on AMDGPU, padding_causal mask type would require calling the flash attention backend for variable seqlen, while causal mask type would call the flash attention backend for fixed seqlen. Currently, the variable seqlen backend is not yet enabled in Transformer Engine ROCm, thus this PR would enable running flash attention in Maxtext on AMDGPU. And in the future, even if the backend for fixed seqlen is enabled in Transformer Engine ROCm, we would still expect that calling the backend for fixed seqlen would be more efficient for such cases.

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant