@@ -735,86 +735,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
735735 )
736736 return out
737737
def _var_attention_qkv(q, k, v, heads, skip_reshape):
    """Return ``(q, k, v, head_dim)`` with q/k/v laid out per attention head.

    When ``skip_reshape`` is true the tensors are passed through untouched and
    ``head_dim`` is read from the last dimension of ``q``.  Otherwise ``q`` must
    be 2D ``(tokens, embed_dim)``; ``head_dim`` is ``embed_dim // heads`` and
    each tensor is viewed as ``(its own token count, heads, head_dim)``.  The
    head size derived from ``q`` is applied to ``k`` and ``v`` as well.
    """
    if not skip_reshape:
        token_count, width = q.shape
        per_head = width // heads
        q = q.view(token_count, heads, per_head)
        k = k.view(k.shape[0], heads, per_head)
        v = v.view(v.shape[0], heads, per_head)
        return q, k, v, per_head
    return q, k, v, q.shape[-1]
749738
750-
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
    """Merge the head axes of ``out`` unless the caller asked to keep them.

    With ``skip_output_reshape`` true, ``out`` is returned unchanged; otherwise
    it is reshaped to two dimensions with ``heads * head_dim`` columns.
    """
    return out if skip_output_reshape else out.reshape(-1, heads * head_dim)
755-
756-
def _use_blackwell_attention():
    """Return True when the active torch device is CUDA with capability >= 12.0.

    The device comes from ``model_management.get_torch_device()``; any
    non-CUDA device yields False without querying CUDA.
    """
    device = model_management.get_torch_device()
    if device.type == "cuda":
        major, minor = torch.cuda.get_device_capability(device)
        # Tuple comparison orders by major version first, then minor.
        return (major, minor) >= (12, 0)
    return False
763-
764-
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
    """Check that ``cu_seqlens`` is a well-formed cumulative offset tensor.

    Args:
        name: Label used as the prefix of every error message.
        cu_seqlens: Tensor of cumulative sequence offsets to validate.
        token_count: Value the final offset must equal.

    Raises:
        ValueError: If the dtype is not int32/int64, the tensor is not 1D with
            at least two entries, the first offset is not 0, the offsets are
            not strictly increasing, or the last offset differs from
            ``token_count``.
    """
    if cu_seqlens.dtype not in (torch.int32, torch.int64):
        raise ValueError(f"{name} must use an integer dtype")
    if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
        raise ValueError(f"{name} must be a 1D tensor with at least two offsets")

    # Copy the offsets to the host once instead of issuing several separate
    # blocking ``.item()`` reads; the remaining checks are plain Python.
    offsets = cu_seqlens.tolist()
    if offsets[0] != 0:
        raise ValueError(f"{name} must start at 0")
    if any(nxt <= cur for cur, nxt in zip(offsets, offsets[1:])):
        raise ValueError(f"{name} must be strictly increasing")
    if offsets[-1] != token_count:
        raise ValueError(f"{name} does not match token count")
776-
777-
def _split_indices(cu_seqlens):
    """Return the interior offsets of ``cu_seqlens`` as a CPU int64 tensor.

    The first and last entries are dropped; what remains are the positions
    at which ``var_attention_optimized_split`` cuts the packed tensors.
    """
    interior = cu_seqlens[1:-1]
    return interior.to(dtype=torch.long, device="cpu")
780-
781-
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
    """Run variable-length attention by looping over the packed sequences.

    ``q``/``k``/``v`` hold several sequences concatenated along dim 0.  The
    cumulative offsets say where each sequence starts and ends; every sequence
    is sliced out, run through the module-level ``optimized_attention``, and
    the per-sequence outputs are concatenated back along dim 0.

    Args:
        q, k, v: Packed tensors.  With ``skip_reshape=False``, ``q`` must be 2D
            ``(tokens, embed_dim)``; otherwise the tensors are used as given
            and must be 3D (they are permuted with ``(1, 0, 2)`` below).
        heads: Number of attention heads.
        cu_seqlens_q: Cumulative offsets into ``q``.
        cu_seqlens_k: Cumulative offsets into ``k`` and ``v``.
        *args, **kwargs: Accepted and ignored.
        skip_reshape: Treat the inputs as already split per head.
        skip_output_reshape: Return the output with its head axis intact
            instead of merging it into ``heads * head_dim`` columns.

    Returns:
        The concatenated attention output, in the dtype of ``q``.

    Raises:
        ValueError: If either offset tensor is malformed, ``v`` has a different
            token count than ``k``, or the two offset tensors describe a
            different number of sequences.
    """
    q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)

    _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
    _validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
    # The validation above guarantees cu_seqlens_k[-1] == k.shape[0], so the
    # shapes can be compared directly without another blocking tensor read.
    if k.shape[0] != v.shape[0]:
        raise ValueError("cu_seqlens_k does not match v token count")
    # N offsets always produce N - 1 chunks, so a count mismatch can be
    # rejected before any splitting is done.
    if cu_seqlens_q.numel() != cu_seqlens_k.numel():
        raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")

    q_split_indices = _split_indices(cu_seqlens_q)
    k_split_indices = _split_indices(cu_seqlens_k)
    q_splits = torch.tensor_split(q, q_split_indices, dim=0)
    k_splits = torch.tensor_split(k, k_split_indices, dim=0)
    v_splits = torch.tensor_split(v, k_split_indices, dim=0)

    # Neither the output dtype nor the backend changes between sequences, so
    # decide once whether the inputs must be cast.
    # NOTE(review): presumably attention_sage only accepts fp16/bf16 inputs - confirm.
    out_dtype = q.dtype
    cast_for_sage = optimized_attention is attention_sage and out_dtype not in (torch.float16, torch.bfloat16)

    outputs = []
    for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
        # Swap the first two axes and add a leading batch axis of size 1.
        q_i = q_i.permute(1, 0, 2).unsqueeze(0)
        k_i = k_i.permute(1, 0, 2).unsqueeze(0)
        v_i = v_i.permute(1, 0, 2).unsqueeze(0)
        if cast_for_sage:
            q_i = q_i.to(torch.bfloat16)
            k_i = k_i.to(torch.bfloat16)
            v_i = v_i.to(torch.bfloat16)
        out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
        if out_i.dtype != out_dtype:
            out_i = out_i.to(out_dtype)
        # Undo the batch axis and the axis swap before collecting.
        outputs.append(out_i.squeeze(0).permute(1, 0, 2))

    out = torch.cat(outputs, dim=0)
    return _var_attention_output(out, heads, head_dim, skip_output_reshape)
815-
816-
817- optimized_var_attention = var_attention_optimized_split
818739optimized_attention = attention_basic
819740
820741if model_management .sage_attention_enabled ():
@@ -837,8 +758,6 @@ def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *a
837758 logging .info ("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention" )
838759 optimized_attention = attention_sub_quad
839760
840- logging .info ("Using optimized_attention split-loop for variable-length attention" )
841-
842761optimized_attention_masked = optimized_attention
843762
844763
@@ -854,7 +773,6 @@ def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *a
854773register_attention_function ("pytorch" , attention_pytorch )
855774register_attention_function ("sub_quad" , attention_sub_quad )
856775register_attention_function ("split" , attention_split )
857- register_attention_function ("var_attention_optimized_split" , var_attention_optimized_split )
858776
859777
860778def optimized_attention_for_device (device , mask = False , small_input = False ):
@@ -1291,3 +1209,5 @@ def forward(
12911209 x = self .proj_out (x )
12921210 out = x + x_in
12931211 return out
1212+
1213+
0 commit comments