Commit

Use causal mask type for training or prefill with fixed seqlen when calling CuDNN flash attention.
wenchenvincent committed Feb 9, 2025
1 parent 86e05f6 commit 249bacd
Showing 1 changed file with 7 additions and 0 deletions.
MaxText/layers/attentions.py: 7 additions, 0 deletions
@@ -407,6 +407,13 @@ def cudnn_flash_attention(
       sliding_window_size = [self.sliding_window_size, 0]
       mask_type = "causal"  # SWA only works with causal masking
       attn_mask = None
+    elif model_mode != common_types.MODEL_MODE_AUTOREGRESSIVE and decoder_segment_ids is None:
+      # For training or prefill cases where the sequence lengths are fixed, it is more efficient to take the
+      # shortcut of using the causal mask type without passing an attention mask. This reduces the overhead of
+      # generating the attention mask in MaxText and of extracting the cumulative seqlen from the attention mask
+      # in Transformer Engine.
+      mask_type = "causal"
+      attn_mask = None
     else:
       # generate attn_mask
       mask_type = "padding_causal"  # only padding_causal mask type can take a created mask
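For context, below is a minimal, standalone sketch of the mask-selection branching this commit introduces. It is an illustration, not the actual MaxText code: the constant MODEL_MODE_AUTOREGRESSIVE and the helper select_cudnn_mask_type are hypothetical stand-ins for the names used in attentions.py, and the real method passes the result on to the Transformer Engine cuDNN flash-attention call.

# Hypothetical, simplified sketch of the mask-type selection added by this commit.
# The names below mirror the diff but are stand-ins, not the real MaxText or
# Transformer Engine API.

MODEL_MODE_AUTOREGRESSIVE = "autoregressive"  # stand-in for common_types.MODEL_MODE_AUTOREGRESSIVE


def select_cudnn_mask_type(model_mode, decoder_segment_ids=None, sliding_window_size=None, attn_mask=None):
  """Return (mask_type, attn_mask) to use for the cuDNN flash-attention call."""
  if sliding_window_size is not None:
    # Sliding-window attention only works with causal masking; no explicit mask is passed.
    return "causal", None
  if model_mode != MODEL_MODE_AUTOREGRESSIVE and decoder_segment_ids is None:
    # Training or prefill with fixed sequence lengths: take the shortcut of the causal
    # mask type and skip building an attention mask entirely.
    return "causal", None
  # Otherwise an explicit mask is required; only the padding_causal mask type accepts one.
  return "padding_causal", attn_mask


# Example: a fixed-length training batch takes the new fast path.
assert select_cudnn_mask_type(model_mode="train") == ("causal", None)

The point of the shortcut is that when every sequence in the batch has the same, fixed length, the causal mask is fully determined by the sequence length, so neither MaxText nor Transformer Engine needs to materialize or inspect an explicit attention mask.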
