
Use causal mask type for training or prefill with fixed seqlen when calling CuDNN flash attention.
wenchenvincent committed Feb 7, 2025
1 parent 86e05f6 · commit 14e2e7b
Showing 1 changed file with 6 additions and 0 deletions.
MaxText/layers/attentions.py: 6 additions & 0 deletions
@@ -407,6 +407,12 @@ def cudnn_flash_attention(
       sliding_window_size = [self.sliding_window_size, 0]
       mask_type = "causal"  # SWA only works with causal masking
       attn_mask = None
+    elif model_mode != common_types.MODEL_MODE_AUTOREGRESSIVE and decoder_segment_ids is None:
+      # For training or prefill of the decoder-architecture without sequence packing
+      # it is more efficient to use causal mask type as it calls the fixed-seqlen backend instead
+      # of the variable-seqlen backend.
+      mask_type = "causal"
+      attn_mask = None
     else:
       # generate attn_mask
       mask_type = "padding_causal"  # only padding_causal mask type can take a created mask
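
For readers skimming the diff, the sketch below restates the branch logic in isolation. It is an illustration under assumptions, not MaxText code: select_mask_type is a hypothetical helper, and the string model-mode values stand in for the constants in common_types.

# Minimal sketch (not MaxText code) of the mask-type selection this commit adds.
# MODEL_MODE_AUTOREGRESSIVE stands in for the constant in common_types;
# select_mask_type is a hypothetical helper used only for illustration.

MODEL_MODE_AUTOREGRESSIVE = "autoregressive"

def select_mask_type(model_mode, decoder_segment_ids, use_sliding_window=False):
  """Return (mask_type, needs_explicit_mask) for the cuDNN flash attention call."""
  if use_sliding_window:
    # Sliding-window attention only works with causal masking.
    return "causal", False
  if model_mode != MODEL_MODE_AUTOREGRESSIVE and decoder_segment_ids is None:
    # Training or prefill without sequence packing: the plain causal mask type
    # lets cuDNN dispatch to the fixed-seqlen backend, which is faster than the
    # variable-seqlen backend that padding_causal requires.
    return "causal", False
  # Packed sequences need an explicit mask, and only the padding_causal
  # mask type can take a created mask.
  return "padding_causal", True

# Example: unpacked prefill picks the causal mask type (fixed-seqlen path).
print(select_mask_type(model_mode="prefill", decoder_segment_ids=None))  # ('causal', False)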
