1 parent ff6c8ac · commit 8df1d5b
MaxText/layers/attentions.py
@@ -874,8 +874,12 @@ def cudnn_jax_flash_attention(
     decoder_segment_ids: Array | None,
     model_mode: str = MODEL_MODE_TRAIN,
 ) -> tuple[Array, Array]:
+<<<<<<< HEAD
   """CUDNN Flash Attention with JAX SDPA API.
   """
+=======
+  """CUDNN Flash Attention with JAX SDPA API."""
+>>>>>>> 2560bdc6 (Chunk and local attention fix for auto-regressive generation)
   # These imports are only meant to work in a GPU build.
   # pylint: disable=import-outside-toplevel
   from jax._src.cudnn.fused_attention_stablehlo import (
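Note that the hunk leaves unresolved Git conflict markers in the file; whichever side of the docstring wins, the function wraps JAX's scaled-dot-product-attention entry point. Below is a minimal sketch of that public API (`jax.nn.dot_product_attention`, assuming a recent JAX release), not the MaxText wrapper itself; the shapes and window size are illustrative.

```python
# Minimal sketch of the JAX SDPA API named in the docstring above.
# implementation="cudnn" needs a GPU build of jaxlib, so it is left
# unset here and JAX selects a backend on its own.
import jax

B, T, N, H = 2, 8, 4, 64  # batch, sequence length, heads, head dim
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, T, N, H))
k = jax.random.normal(kk, (B, T, N, H))
v = jax.random.normal(kv, (B, T, N, H))

# Causal attention, as used for auto-regressive generation.
out = jax.nn.dot_product_attention(q, k, v, is_causal=True)

# Local (sliding-window) attention, the case the incoming branch in
# the conflict addresses: 4 tokens of lookback, 0 of lookahead.
local = jax.nn.dot_product_attention(
    q, k, v, is_causal=True, local_window_size=(4, 0))

print(out.shape, local.shape)  # (2, 8, 4, 64) (2, 8, 4, 64)
```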