Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use causal mask type for training or prefill with fixed seqlen when #1248

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions MaxText/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,13 @@ def cudnn_flash_attention(
sliding_window_size = [self.sliding_window_size, 0]
mask_type = "causal" # SWA only works with causal masking
attn_mask = None
elif model_mode != common_types.MODEL_MODE_AUTOREGRESSIVE and decoder_segment_ids is None:
# For training or prefill cases where the sequence lengths are fixed, it is more efficient to take the
# shortcut of using causal mask type without passing an attention mask. This reduces the overhead of
# generating the attention mask in Maxtext and extracting the accumulative seqlen from the attention mask
# in Transformer Engine.
mask_type = "causal"
attn_mask = None
else:
# generate attn_mask
mask_type = "padding_causal" # only padding_causal mask type can take a created mask
Expand Down