
Commit 8df1d5b

Chunk and local attention fix for auto-regressive generation

1 parent ff6c8ac

File tree

1 file changed: 4 additions, 0 deletions

MaxText/layers/attentions.py

Lines changed: 4 additions & 0 deletions
@@ -874,8 +874,12 @@ def cudnn_jax_flash_attention(
     decoder_segment_ids: Array | None,
     model_mode: str = MODEL_MODE_TRAIN,
 ) -> tuple[Array, Array]:
+<<<<<<< HEAD
   """CUDNN Flash Attention with JAX SDPA API.
   """
+=======
+  """CUDNN Flash Attention with JAX SDPA API."""
+>>>>>>> 2560bdc6 (Chunk and local attention fix for auto-regressive generation)
   # These imports are only meant to work in a GPU build.
   # pylint: disable=import-outside-toplevel
   from jax._src.cudnn.fused_attention_stablehlo import (
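Note that the four added lines are unresolved merge conflict markers (`<<<<<<< HEAD`, `=======`, `>>>>>>> 2560bdc6`) committed directly into attentions.py. Python cannot parse these markers, so importing the module would fail until the conflict is resolved. A minimal standalone sketch (using a hypothetical snippet, not the actual MaxText source) confirming that conflict markers are a syntax error:

```python
# Demonstrates why committing merge conflict markers breaks a Python
# module: the parser rejects them outright. The conflicted source below
# is a hypothetical stand-in for illustration only.
conflicted_source = """\
def f():
<<<<<<< HEAD
    return 1
=======
    return 2
>>>>>>> other-branch
"""

try:
    # compile() parses the source the same way `import` would.
    compile(conflicted_source, "<conflicted>", "exec")
    result = "parsed"
except SyntaxError:
    result = "SyntaxError"

print(result)
```

Resolving the conflict here means keeping exactly one of the two docstring variants (they differ only in whether the closing `"""` sits on its own line).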
