@@ -873,7 +873,7 @@ def cudnn_jax_flash_attention(
873
873
value : Array ,
874
874
decoder_segment_ids : Array | None ,
875
875
model_mode : str = MODEL_MODE_TRAIN ,
876
- ) -> Array :
876
+ ) -> tuple [ Array , Array ] :
877
877
"""CUDNN Flash Attention with JAX SDPA API.
878
878
"""
879
879
# These imports are only meant to work in a GPU build.
@@ -888,7 +888,7 @@ def cudnn_jax_flash_attention(
888
888
if model_mode == MODEL_MODE_AUTOREGRESSIVE :
889
889
lengths = jnp .sum (decoder_segment_ids , axis = - 1 )
890
890
891
- return dot_product_attention (
891
+ output , lse = dot_product_attention (
892
892
query ,
893
893
key ,
894
894
value ,
@@ -901,7 +901,7 @@ def cudnn_jax_flash_attention(
901
901
return_residual = True
902
902
)
903
903
else :
904
- return dot_product_attention (
904
+ output , lse = dot_product_attention (
905
905
query ,
906
906
key ,
907
907
value ,
@@ -911,6 +911,9 @@ def cudnn_jax_flash_attention(
911
911
qkv_layout = "BTNH" ,
912
912
return_residual = True
913
913
)
914
+ output = checkpoint_name (output , "context" )
915
+ lse = checkpoint_name (lse , "context" )
916
+ return output , lse
914
917
915
918
def compute_local_attention (
916
919
self , attn_weights : Array , value : Array | KVTensor , q_seq_len : int , model_mode : str
0 commit comments