# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import warnings
-from typing import Callable, Optional, Union
+import math
+from typing import Optional, Union

import torch
import torch.nn.functional as F
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.deprecation import deprecate_kwarg
@@ -231,44 +230,6 @@ def forward(self, x, position_ids):
        return freqs_cis


-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
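# A quick, illustrative check of the equivalence the removed docstring claims
# (hypothetical sizes; torch is imported at the top of this module):
_x = torch.randn(2, 4, 16, 8)  # (batch, num_key_value_heads, seqlen, head_dim)
_n_rep = 3
_expanded = _x[:, :, None, :, :].expand(2, 4, _n_rep, 16, 8).reshape(2, 4 * _n_rep, 16, 8)
assert torch.equal(_expanded, torch.repeat_interleave(_x, repeats=_n_rep, dim=1))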
-def eager_attention_forward(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    scaling: float,
-    dropout: float = 0.0,
-    **kwargs: Unpack[TransformersKwargs],
-):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + causal_mask
-
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attn_weights
-
-
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
@@ -285,123 +246,154 @@ def apply_rotary_emb(
    return xq_out, xk_out


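# Illustrative sketch of what apply_rotary_emb computes (its body is elided from this
# hunk): adjacent channel pairs are viewed as complex numbers and rotated by the
# precomputed complex exponentials in freqs_cis. Shapes below are assumptions.
def _rotary_emb_sketch(xq, xk, freqs_cis):
    # xq, xk: (bsz, seqlen, n_heads, head_dim); freqs_cis: complex, (seqlen, head_dim // 2)
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs = freqs_cis.view(1, freqs_cis.shape[0], 1, freqs_cis.shape[1])
    xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)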
-class DeepseekV32Attention(nn.Module):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-    def __init__(self, config: DeepseekV32Config, layer_idx: Optional[int] = None):
+class DeepseekV32Indexer(nn.Module):
+    def __init__(self, config: DeepseekV32Config):
        super().__init__()
-        self.config = config
-        self.layer_idx = layer_idx
-        self.attention_dropout = config.attention_dropout
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = config.head_dim
-        self.max_position_embeddings = config.max_position_embeddings
-        self.rope_theta = config.rope_theta
-        self.q_lora_rank = config.q_lora_rank
-        self.qk_rope_head_dim = config.qk_rope_head_dim
-        self.kv_lora_rank = config.kv_lora_rank
-        self.v_head_dim = config.v_head_dim
-        self.qk_nope_head_dim = config.qk_nope_head_dim
-        self.qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
-        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
-
-        self.is_causal = True
-
-        if self.q_lora_rank is None:
-            self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
-        else:
-            self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
-            self.q_a_layernorm = DeepseekV32RMSNorm(config.q_lora_rank)
-            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
-
-        self.kv_a_proj_with_mqa = nn.Linear(
-            self.hidden_size,
-            config.kv_lora_rank + config.qk_rope_head_dim,
-            bias=config.attention_bias,
+        self.dim: int = config.dim
+        self.n_heads: int = config.index_n_heads
+        self.n_local_heads = config.index_n_heads  # // world_size
+        self.head_dim: int = config.index_head_dim
+        self.rope_head_dim: int = config.qk_rope_head_dim
+        self.index_topk: int = config.index_topk
+        self.q_lora_rank: int = config.q_lora_rank
+        self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.head_dim)
+        self.wk = nn.Linear(self.dim, self.head_dim)
+        self.k_norm = nn.LayerNorm(self.head_dim)
+        self.weights_proj = nn.Linear(self.dim, self.n_heads, dtype=torch.get_default_dtype())
+        self.softmax_scale = self.head_dim**-0.5
+        self.scale_fmt = config.scale_fmt
+        self.k_scale_head_dim = self.head_dim
+        self.register_buffer(
+            "k_cache",
+            torch.zeros(config.max_batch_size, config.max_seq_len, self.head_dim, dtype=torch.float8_e4m3fn),
+            persistent=False,
        )
-        self.kv_a_layernorm = DeepseekV32RMSNorm(config.kv_lora_rank)
-        self.kv_b_proj = nn.Linear(
-            config.kv_lora_rank,
-            self.num_heads * (self.qk_head_dim - self.qk_rope_head_dim + self.v_head_dim),
-            bias=False,
+        self.register_buffer(
+            "k_scale_cache",
+            torch.zeros(config.max_batch_size, config.max_seq_len, self.head_dim, dtype=torch.float32),
+            persistent=False,
        )

-        self.o_proj = nn.Linear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=config.attention_bias,
+    def forward(
+        self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]
+    ):
+        bsz, seqlen, _ = x.size()
+        end_pos = start_pos + seqlen
+        q = self.wq_b(qr)
+        q = q.reshape(bsz, seqlen, -1, self.head_dim)
+        q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
+        q_pe = apply_rotary_pos_emb(q_pe, freqs_cis)
+        q = torch.cat([q_pe, q_nope], dim=-1)
+        k = self.wk(x)
+        k = self.k_norm(k)
+        k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
+        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2)
+        k = torch.cat([k_pe, k_nope], dim=-1)
+        q_fp8, q_scale = act_quant(q, block_size, self.scale_fmt)
+        k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
+        self.k_cache[:bsz, start_pos:end_pos] = k_fp8
+        self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
+        weights = self.weights_proj(x) * self.n_heads**-0.5
+        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
+        index_score = fp8_index(
+            q_fp8.contiguous(),
+            weights,
+            self.k_cache[:bsz, :end_pos].contiguous(),
+            self.k_scale_cache[:bsz, :end_pos].contiguous(),
        )
+        if mask is not None:
+            index_score += mask
+        topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1]
+        topk_indices_ = topk_indices.clone()
+        assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}"
+        return topk_indices
+
+
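# Plain-precision sketch of the index selection above (illustrative): act_quant,
# block_size and fp8_index are external FP8 helpers assumed by this file; ignoring
# quantization, the score is a per-head gated dot product reduced over heads,
# followed by a top-k over key positions.
def _index_topk_sketch(q, k, head_gate, softmax_scale, index_topk, mask=None):
    # q: (bsz, seqlen, n_heads, head_dim), k: (bsz, kv_len, head_dim)
    # head_gate: (bsz, seqlen, n_heads), e.g. weights_proj(x) * n_heads ** -0.5
    scores = torch.einsum("bshd,btd->bsht", q, k) * softmax_scale
    scores = (scores * head_gate.unsqueeze(-1)).sum(dim=2)  # (bsz, seqlen, kv_len)
    if mask is not None:
        scores = scores + mask
    return scores.topk(min(index_topk, k.shape[1]), dim=-1).indices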
+class DeepseekV32Attention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config, layer_idx):
+        self.softmax_scale = self.qk_head_dim**-0.5
+        if config.max_seq_len > config.original_seq_len:
+            mscale = 0.1 * config.mscale * math.log(config.rope_factor) + 1.0
+            self.softmax_scale = self.softmax_scale * mscale * mscale

-        self.scaling = self.qk_head_dim ** (-0.5)
+        self.indexer = DeepseekV32Indexer(config)
+
+        self.register_buffer(
+            "kv_cache", torch.zeros(config.max_batch_size, config.max_seq_len, self.kv_lora_rank), persistent=False
+        )
+        self.register_buffer(
+            "pe_cache", torch.zeros(config.max_batch_size, config.max_seq_len, self.qk_rope_head_dim), persistent=False
+        )
+        self.dequant_wkv_b = None

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        **kwargs,
+        self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-        batch_size, seq_length = hidden_states.shape[:-1]
-        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
-        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
-
-        if self.q_lora_rank is None:
-            q = self.q_proj(hidden_states)
-        else:
-            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
-        q = q.view(query_shape).transpose(1, 2)
-        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-
-        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
-        k_nope, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        k_nope = self.kv_b_proj(self.kv_a_layernorm(k_nope)).view(key_shape).transpose(1, 2)
-        k_nope, value_states = torch.split(k_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-
-        k_pe = k_pe.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
-        q_pe, k_pe = apply_rotary_emb(q_pe, k_pe, position_embeddings.to(q_pe.device))
-
-        k_pe = k_pe.expand(*k_nope.shape[:-1], -1)
-        query_states = torch.cat((q_nope, q_pe), dim=-1)
-        key_states = torch.cat((k_nope, k_pe), dim=-1)
-
-        if past_key_values is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"cache_position": cache_position}
-            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
-            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
-
-        attention_interface: Callable = eager_attention_forward
-        if self.config._attn_implementation != "eager":
-            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
-        attn_output, attn_weights = attention_interface(
-            self,
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scaling,
-            **kwargs,
-        )
335+ """
336+ Forward pass for the Multi-Head Latent Attention (MLA) Layer.
398337
399- if self .config ._attn_implementation == "flash_attention_2" and self .qk_head_dim != self .v_head_dim :
400- attn_output = attn_output [:, :, :, : self .v_head_dim ]
338+ config.
339+ x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
340+ start_pos (int): Starting position in the sequence for caching.
341+ freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
342+ mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
401343
402- attn_output = attn_output .reshape (batch_size , seq_length , - 1 ).contiguous ()
403- attn_output = self .o_proj (attn_output )
404- return attn_output , attn_weights
344+ Returns:
345+ torch.Tensor: Output tensor with the same shape as the input.
346+ """
+        bsz, seqlen, _ = x.size()
+        end_pos = start_pos + seqlen
+        qr = self.q_norm(self.wq_a(x))
+        q = self.wq_b(qr)
+        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
+        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        kv = self.wkv_a(x)
+        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        kv = self.kv_norm(kv)
+        q_pe, k_pe = apply_rotary_emb(q_pe, k_pe.unsqueeze(2), freqs_cis)
+
+        self.kv_cache[:bsz, start_pos:end_pos] = kv
+        self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
+        if mask is not None:  # MHA prefill
+            q = torch.cat([q_nope, q_pe], dim=-1)
+            kv = self.wkv_b(kv)
+            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
+            scores = torch.einsum("bshd,bthd->bsht", q.float(), k.float()) * self.softmax_scale
+
+            # indexer
+            topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
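+            # The indexer returns, per query token, the key positions it may attend to;
+            # they become 0 in the additive mask below while every other position gets -inf.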
+            index_mask = torch.full((bsz, seqlen, seqlen), float("-inf"), device=x.device).scatter_(
+                -1, topk_indices, 0
+            )
+            index_mask += mask
+            scores += index_mask.unsqueeze(2)
+
+            scores = scores.softmax(dim=-1, dtype=torch.float32)
+            x = torch.einsum("bsht,bthd->bshd", scores.type_as(x), v)
+        else:  # MHA decode
+            wkv_b = self.wkv_b.weight
+            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
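+            # Weight absorption: fold the nope slice of wkv_b into the query so scores
+            # can be computed directly against the compressed kv_cache latents
+            # (a small worked example of this identity follows the class).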
+            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, : self.qk_nope_head_dim])
+            scores = (
+                torch.einsum("bshc,btc->bsht", q_nope.float(), self.kv_cache[:bsz, :end_pos].float())
+                + torch.einsum("bshr,btr->bsht", q_pe.float(), self.pe_cache[:bsz, :end_pos].float())
+            ) * self.softmax_scale
+
+            # indexer
+            topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
+            index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
+            scores += index_mask.unsqueeze(2)
+
+            scores = scores.softmax(dim=-1, dtype=torch.float32)
+            x = torch.einsum("bsht,btc->bshc", scores.type_as(x), self.kv_cache[:bsz, :end_pos])
+            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim :])
+        x = self.wo(x.flatten(2))
+        return x
 
 
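# Illustrative check (made-up sizes) of the identity the decode branch relies on:
# scoring an absorbed query against the latent cache equals scoring the full query
# against the up-projected keys, i.e. q_nope @ (W_k @ c) == (W_k^T @ q_nope) @ c.
_h, _d, _c = 4, 16, 32
_q_nope = torch.randn(1, 1, _h, _d)        # one decode-step query
_cache = torch.randn(1, 5, _c)             # latent kv cache for 5 tokens
_w_k = torch.randn(_h, _d, _c)             # per-head nope slice of wkv_b

_k_nope = torch.einsum("btc,hdc->bthd", _cache, _w_k)                 # materialize keys
_scores_full = torch.einsum("bshd,bthd->bsht", _q_nope, _k_nope)
_q_absorbed = torch.einsum("bshd,hdc->bshc", _q_nope, _w_k)           # absorb W_k into q
_scores_absorbed = torch.einsum("bshc,btc->bsht", _q_absorbed, _cache)
assert torch.allclose(_scores_full, _scores_absorbed, atol=1e-5)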
class DeepseekV32DecoderLayer(GradientCheckpointingLayer):