Skip to content

Commit ff6c8ac

Browse files
author
maxtext authors
committed
Merge pull request #1805 from Cjkkkk:cudnn_sdpa_checkpoint
PiperOrigin-RevId: 770211092
2 parents 1c7f964 + a653c03 commit ff6c8ac

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

MaxText/layers/attentions.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,7 @@ def cudnn_jax_flash_attention(
 873 873        value: Array,
 874 874        decoder_segment_ids: Array | None,
 875 875        model_mode: str = MODEL_MODE_TRAIN,
 876     -  ) -> Array:
     876 +  ) -> tuple[Array, Array]:
 877 877    """CUDNN Flash Attention with JAX SDPA API.
 878 878    """
 879 879    # These imports are only meant to work in a GPU build.
@@ -888,7 +888,7 @@ def cudnn_jax_flash_attention(
 888 888    if model_mode == MODEL_MODE_AUTOREGRESSIVE:
 889 889      lengths = jnp.sum(decoder_segment_ids, axis=-1)
 890 890
 891     -    return dot_product_attention(
     891 +    output, lse = dot_product_attention(
 892 892        query,
 893 893        key,
 894 894        value,
@@ -901,7 +901,7 @@ def cudnn_jax_flash_attention(
 901 901        return_residual=True
 902 902      )
 903 903    else:
 904     -    return dot_product_attention(
     904 +    output, lse = dot_product_attention(
 905 905        query,
 906 906        key,
 907 907        value,
@@ -911,6 +911,9 @@ def cudnn_jax_flash_attention(
 911 911        qkv_layout="BTNH",
 912 912        return_residual=True
 913 913      )
     914 +  output = checkpoint_name(output, "context")
     915 +  lse = checkpoint_name(lse, "context")
     916 +  return output, lse
 914 917
 915 918  def compute_local_attention(
 916 919      self, attn_weights: Array, value: Array | KVTensor, q_seq_len: int, model_mode: str

0 commit comments

Comments
 (0)