Commit 06acab8

Commit message: updates
1 parent: 57ba98f

3 files changed (+395, -150 lines)


src/transformers/models/deepseek_v32/configuration_deepseek_v32.py

Lines changed: 6 additions & 1 deletion
@@ -18,7 +18,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from ...configuration_utils import PretrainedConfig
 from ...modeling_rope_utils import rope_config_validation

@@ -180,6 +179,9 @@ def __init__(
         num_experts_per_tok=None,
         norm_topk_prob=False,
         moe_intermediate_size=1407,
+        index_n_heads=64,
+        index_head_dim=128,
+        index_topk=2048,
         **kwargs,
     ):
         super().__init__(
@@ -233,6 +235,9 @@ def __init__(
         self.num_experts_per_tok = num_experts_per_tok
         self.norm_topk_prob = norm_topk_prob
         self.moe_intermediate_size = moe_intermediate_size
+        self.index_n_heads = index_n_heads
+        self.index_head_dim = index_head_dim
+        self.index_topk = index_topk


 __all__ = ["DeepseekV32Config"]

src/transformers/models/deepseek_v32/modeling_deepseek_v32.py

Lines changed: 139 additions & 147 deletions
@@ -18,9 +18,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import warnings
-from typing import Callable, Optional, Union
+import math
+from typing import Optional, Union

 import torch
 import torch.nn.functional as F
@@ -34,7 +33,7 @@
 from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.deprecation import deprecate_kwarg
@@ -231,44 +230,6 @@ def forward(self, x, position_ids):
         return freqs_cis


-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def eager_attention_forward(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    scaling: float,
-    dropout: float = 0.0,
-    **kwargs: Unpack[TransformersKwargs],
-):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + causal_mask
-
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attn_weights
-
-
 def apply_rotary_emb(
     xq: torch.Tensor,
     xk: torch.Tensor,
@@ -285,123 +246,154 @@ def apply_rotary_emb(
     return xq_out, xk_out


-class DeepseekV32Attention(nn.Module):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-    def __init__(self, config: DeepseekV32Config, layer_idx: Optional[int] = None):
+class DeepseekV32Indexer(nn.Module):
+    def __init__(self, config: DeepseekV32Config):
         super().__init__()
-        self.config = config
-        self.layer_idx = layer_idx
-        self.attention_dropout = config.attention_dropout
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = config.head_dim
-        self.max_position_embeddings = config.max_position_embeddings
-        self.rope_theta = config.rope_theta
-        self.q_lora_rank = config.q_lora_rank
-        self.qk_rope_head_dim = config.qk_rope_head_dim
-        self.kv_lora_rank = config.kv_lora_rank
-        self.v_head_dim = config.v_head_dim
-        self.qk_nope_head_dim = config.qk_nope_head_dim
-        self.qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
-        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
-
-        self.is_causal = True
-
-        if self.q_lora_rank is None:
-            self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
-        else:
-            self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
-            self.q_a_layernorm = DeepseekV32RMSNorm(config.q_lora_rank)
-            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
-
-        self.kv_a_proj_with_mqa = nn.Linear(
-            self.hidden_size,
-            config.kv_lora_rank + config.qk_rope_head_dim,
-            bias=config.attention_bias,
+        self.dim: int = config.dim
+        self.n_heads: int = config.index_n_heads
+        self.n_local_heads = config.index_n_heads  # // world_size
+        self.head_dim: int = config.index_head_dim
+        self.rope_head_dim: int = config.qk_rope_head_dim
+        self.index_topk: int = config.index_topk
+        self.q_lora_rank: int = config.q_lora_rank
+        self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.head_dim)
+        self.wk = nn.Linear(self.dim, self.head_dim)
+        self.k_norm = nn.LayerNorm(self.head_dim)
+        self.weights_proj = nn.Linear(self.dim, self.n_heads, dtype=torch.get_default_dtype())
+        self.softmax_scale = self.head_dim**-0.5
+        self.scale_fmt = config.scale_fmt
+        self.k_scale_head_dim = self.head_dim
+        self.register_buffer(
+            "k_cache",
+            torch.zeros(config.max_batch_size, config.max_seq_len, self.head_dim, dtype=torch.float8_e4m3fn),
+            persistent=False,
         )
-        self.kv_a_layernorm = DeepseekV32RMSNorm(config.kv_lora_rank)
-        self.kv_b_proj = nn.Linear(
-            config.kv_lora_rank,
-            self.num_heads * (self.qk_head_dim - self.qk_rope_head_dim + self.v_head_dim),
-            bias=False,
+        self.register_buffer(
+            "k_scale_cache",
+            torch.zeros(config.max_batch_size, config.max_seq_len, self.head_dim, dtype=torch.float32),
+            persistent=False,
         )

-        self.o_proj = nn.Linear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=config.attention_bias,
+    def forward(
+        self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]
+    ):
+        bsz, seqlen, _ = x.size()
+        end_pos = start_pos + seqlen
+        q = self.wq_b(qr)
+        q = q.reshape(bsz, seqlen, -1, self.head_dim)
+        q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
+        q_pe = apply_rotary_pos_emb(q_pe, freqs_cis)
+        q = torch.cat([q_pe, q_nope], dim=-1)
+        k = self.wk(x)
+        k = self.k_norm(k)
+        k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
+        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2)
+        k = torch.cat([k_pe, k_nope], dim=-1)
+        q_fp8, q_scale = act_quant(q, block_size, self.scale_fmt)
+        k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
+        self.k_cache[:bsz, start_pos:end_pos] = k_fp8
+        self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
+        weights = self.weights_proj(x) * self.n_heads**-0.5
+        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
+        index_score = fp8_index(
+            q_fp8.contiguous(),
+            weights,
+            self.k_cache[:bsz, :end_pos].contiguous(),
+            self.k_scale_cache[:bsz, :end_pos].contiguous(),
         )
+        if mask is not None:
+            index_score += mask
+        topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1]
+        topk_indices_ = topk_indices.clone()
+        assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}"
+        return topk_indices
+
+
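
Editor's note (not part of the diff): the indexer scores every cached position with a small set of heads, folds the heads together with learned per-token weights, and keeps only the top-k positions per query. Below is a plain-PyTorch toy of that selection logic; it ignores the FP8 path (act_quant, fp8_index, block_size), the k_cache/k_scale_cache buffers, and rotary embeddings, and all shapes and names are illustrative only.

import torch

# Toy sizes; the real module reads config.index_n_heads / index_head_dim / index_topk.
bsz, seqlen, n_heads, head_dim, topk = 1, 8, 4, 16, 4
q = torch.randn(bsz, seqlen, n_heads, head_dim)  # indexer queries (per head)
k = torch.randn(bsz, seqlen, head_dim)           # one shared indexer key per position
w = torch.randn(bsz, seqlen, n_heads)            # learned per-token head weights

scores = torch.einsum("bshd,btd->bsht", q, k)                  # per-head scores: (bsz, seq, heads, seq)
index_score = (w.unsqueeze(-1) * scores).sum(dim=2)            # fold heads -> (bsz, seq, seq)
causal = torch.full((seqlen, seqlen), float("-inf")).triu(1)   # block future positions
topk_indices = (index_score + causal).topk(min(topk, seqlen), dim=-1).indices
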
+class DeepseekV32Attention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config, layer_idx):
+        self.softmax_scale = self.qk_head_dim**-0.5
+        if config.max_seq_len > config.original_seq_len:
+            mscale = 0.1 * config.mscale * math.log(config.rope_factor) + 1.0
+            self.softmax_scale = self.softmax_scale * mscale * mscale

-        self.scaling = self.qk_head_dim ** (-0.5)
+        self.indexer = DeepseekV32Indexer(config)
+
+        self.register_buffer(
+            "kv_cache", torch.zeros(config.max_batch_size, config.max_seq_len, self.kv_lora_rank), persistent=False
+        )
+        self.register_buffer(
+            "pe_cache", torch.zeros(config.max_batch_size, config.max_seq_len, self.qk_rope_head_dim), persistent=False
+        )
+        self.dequant_wkv_b = None

     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        **kwargs,
+        self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-        batch_size, seq_length = hidden_states.shape[:-1]
-        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
-        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
-
-        if self.q_lora_rank is None:
-            q = self.q_proj(hidden_states)
-        else:
-            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
-        q = q.view(query_shape).transpose(1, 2)
-        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-
-        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
-        k_nope, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        k_nope = self.kv_b_proj(self.kv_a_layernorm(k_nope)).view(key_shape).transpose(1, 2)
-        k_nope, value_states = torch.split(k_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-
-        k_pe = k_pe.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
-        q_pe, k_pe = apply_rotary_emb(q_pe, k_pe, position_embeddings.to(q_pe.device))
-
-        k_pe = k_pe.expand(*k_nope.shape[:-1], -1)
-        query_states = torch.cat((q_nope, q_pe), dim=-1)
-        key_states = torch.cat((k_nope, k_pe), dim=-1)
-
-        if past_key_values is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"cache_position": cache_position}
-            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
-            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
-
-        attention_interface: Callable = eager_attention_forward
-        if self.config._attn_implementation != "eager":
-            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
-        attn_output, attn_weights = attention_interface(
-            self,
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scaling,
-            **kwargs,
-        )
+        """
+        Forward pass for the Multi-Head Latent Attention (MLA) Layer.

-        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
-            attn_output = attn_output[:, :, :, : self.v_head_dim]
+        Args:
+            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
+            start_pos (int): Starting position in the sequence for caching.
+            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
+            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

-        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
-        attn_output = self.o_proj(attn_output)
-        return attn_output, attn_weights
+        Returns:
+            torch.Tensor: Output tensor with the same shape as the input.
+        """
+        bsz, seqlen, _ = x.size()
+        end_pos = start_pos + seqlen
+        qr = self.q_norm(self.wq_a(x))
+        q = self.wq_b(qr)
+        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
+        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        kv = self.wkv_a(x)
+        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        kv = self.kv_norm(kv)
+        q_pe, k_pe = apply_rotary_emb(q_pe, k_pe.unsqueeze(2), freqs_cis)
+
+        self.kv_cache[:bsz, start_pos:end_pos] = kv
+        self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
+        if mask is not None:  # MHA prefill
+            q = torch.cat([q_nope, q_pe], dim=-1)
+            kv = self.wkv_b(kv)
+            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
+            scores = torch.einsum("bshd,bthd->bsht", q.float(), k.float()) * self.softmax_scale
+
+            # indexer
+            topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
+            index_mask = torch.full((bsz, seqlen, seqlen), float("-inf"), device=x.device).scatter_(
+                -1, topk_indices, 0
+            )
+            index_mask += mask
+            scores += index_mask.unsqueeze(2)
+
+            scores = scores.softmax(dim=-1, dtype=torch.float32)
+            x = torch.einsum("bsht,bthd->bshd", scores.type_as(x), v)
+        else:  # MHA decode
+            wkv_b = self.wkv_b.weight
+            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
+            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, : self.qk_nope_head_dim])
+            scores = (
+                torch.einsum("bshc,btc->bsht", q_nope.float(), self.kv_cache[:bsz, :end_pos].float())
+                + torch.einsum("bshr,btr->bsht", q_pe.float(), self.pe_cache[:bsz, :end_pos].float())
+            ) * self.softmax_scale
+
+            # indexer
+            topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
+            index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
+            scores += index_mask.unsqueeze(2)
+
+            scores = scores.softmax(dim=-1, dtype=torch.float32)
+            x = torch.einsum("bsht,btc->bshc", scores.type_as(x), self.kv_cache[:bsz, :end_pos])
+            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim :])
+        x = self.wo(x.flatten(2))
+        return x


 class DeepseekV32DecoderLayer(GradientCheckpointingLayer):
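
Editor's note (not part of the diff): once the indexer returns topk_indices, the attention forward builds an additive mask that is 0 at the selected positions and -inf everywhere else, via the torch.full(...).scatter_(-1, topk_indices, 0) pattern above, and adds it to the attention logits before the softmax. A simplified, runnable illustration with dense scores (no MLA weight absorption, no FP8, no KV caching); names and shapes are illustrative only.

import torch

bsz, seqlen, topk = 1, 8, 4
scores = torch.randn(bsz, seqlen, seqlen)                     # query-by-key attention logits
causal = torch.full((seqlen, seqlen), float("-inf")).triu(1)  # standard causal mask
topk_indices = (scores + causal).topk(topk, dim=-1).indices   # stand-in for the indexer output

# 0 at kept positions, -inf everywhere else; added to the logits like index_mask above.
index_mask = torch.full((bsz, seqlen, seqlen), float("-inf")).scatter_(-1, topk_indices, 0)
probs = (scores + causal + index_mask).softmax(dim=-1)        # weight lands only on selected, non-future tokens
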
