
Commit 837d041

workaround till we solve mathjax
1 parent 2100c4b commit 837d041

File tree

3 files changed: +60 -70 lines changed

_posts/2025-01-09-transformer-showdown.md

Lines changed: 0 additions & 44 deletions
This file was deleted.

_posts/2025-01-22-transformer-showdown.md

Lines changed: 60 additions & 25 deletions
@@ -19,18 +19,27 @@ So without further ado, lets get comparing.
- ### Multi Head Attention (MHA)
The standard attention introduced in Attention Is All You Need, and very commonly used in models before 2023. Each layer has an equal number of query, key and value heads: if a layer has `h` heads, we'd have `h` queries, `h` keys and `h` values.

-$$ Q_i = W_{q_i} X, \quad K_i = W_{k_i} X, \quad V_i = W_{v_i} X \quad \text {where X is the input}$$
+$$
+Q_i = W_{q_i} X, \quad K_i = W_{k_i} X, \quad V_i = W_{v_i} X \quad \text {where X is the input}
+$$

-$$ Q = [Q_1, Q_2,..., Q_h], \quad K = [K_1, K_2,..., K_h], \quad V = [V_1, V_2,..., V_h] \quad \text {where [ ] is concatenation}$$
+$$
+Q = [Q_1, Q_2,..., Q_h], \quad K = [K_1, K_2,..., K_h], \quad V = [V_1, V_2,..., V_h] \quad \text {where [ ] is concatenation}
+$$

-$$ A_i = \text{Attention}(Q_i, K_i, V_i) = softmax(\frac{Q_iK_i^T}{\sqrt{d_k}})V_i $$
-$$ A = [A_1, A_2, ..., A_h] \space @ \space Wo $$
+$$
+A_i = \text{Attention}(Q_i, K_i, V_i) = softmax(\frac{Q_iK_i^T}{\sqrt{d_k}})V_i
+$$
+
+$$
+A = [A_1, A_2, ..., A_h] \space @ \space Wo
+$$
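To make the shapes concrete, here is a minimal PyTorch sketch of the equations above (toy sizes chosen for illustration, not taken from any particular model):

```python
import torch

# Toy sizes (assumed): t tokens, h heads, d dims per head
t, h, d = 10, 8, 64
X = torch.randn(t, h * d)                 # input activations, one row per token

Wq = torch.randn(h, h * d, d)             # one (h*d, d) projection per head
Wk = torch.randn(h, h * d, d)
Wv = torch.randn(h, h * d, d)
Wo = torch.randn(h * d, h * d)            # output projection

# Q_i = X @ Wq_i per head, then concatenate: Q is (t, h*d); same for K and V
Q = torch.cat([X @ Wq[i] for i in range(h)], dim=-1)
K = torch.cat([X @ Wk[i] for i in range(h)], dim=-1)
V = torch.cat([X @ Wv[i] for i in range(h)], dim=-1)

# A_i = softmax(Q_i K_i^T / sqrt(d)) V_i per head, then concatenate and apply Wo
heads = []
for i in range(h):
    Qi, Ki, Vi = (M[:, i * d:(i + 1) * d] for M in (Q, K, V))
    scores = torch.softmax(Qi @ Ki.T / d**0.5, dim=-1)   # (t, t) attention scores
    heads.append(scores @ Vi)                            # (t, d)
A = torch.cat(heads, dim=-1) @ Wo                        # (t, h*d)
print(A.shape)  # torch.Size([10, 512])
```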

**Config**: `n` layers. Each layer has `h` heads. Each head has `d` dimensions. Total token count `t`.

-**Parameters**: $W_{q_i}$ is of shape $(h*d, d)$, so it has $h*d^2$ parameters per head. The same goes for $W_{k_i}, W_{v_i}$. So $W_q + W_k + W_v$ contribute a total of $3*h*(h*d^2)$ parameters. $W_o$ is of size $(h*d, h*d)$, so $h^2*d^2$ parameters. Total of $4*n*h^2*d^2$.
+**Parameters**: $$W_{q_i}$$ is of shape $$(h*d, d)$$, so it has $$h*d^2$$ parameters per head. The same goes for $$W_{k_i}, W_{v_i}$$. So $$W_q + W_k + W_v$$ contribute a total of $$3*h*(h*d^2)$$ parameters. $$W_o$$ is of size $$(h*d, h*d)$$, so $$h^2*d^2$$ parameters. Total of $$4*n*h^2*d^2$$.

-For example, [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json) has [32 attention heads](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L14) and [32 key value heads](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L16), so Llama 2 7B uses MHA. It has a [hidden_size of 4096](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L9). This means each head has a head_dim (d) of **128**. So the algebra tells us that $W_{q_i}$ would be of shape $(128*32, 128) = (4096, 128)$. Each Q (similarly K, V) would be of shape $(4096, 128*32) = (4096, 4096)$, contributing $128^2 * 32^2 = 16,777,216$ parameters. Executing the code below gives the same result. Voila.
+For example, [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json) has [32 attention heads](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L14) and [32 key value heads](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L16), so Llama 2 7B uses MHA. It has a [hidden_size of 4096](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L9). This means each head has a head_dim (d) of **128**. So the algebra tells us that $$W_{q_i}$$ would be of shape $$(128*32, 128) = (4096, 128)$$. Each Q (similarly K, V) would be of shape $$(4096, 128*32) = (4096, 4096)$$, contributing $$128^2 * 32^2 = 16,777,216$$ parameters. Executing the code below gives the same result. Voila.

```python
llama2 = AutoModelForCausalLM.from_pretrained(
@@ -64,28 +73,38 @@ So without further ado, lets get comparing.
Wo shape is torch.Size([4096, 4096]) contributes to 16777216 parameters
```

-**Activations**: Each token's query is of size `d` (per head). The same goes for keys and values, hence a total of $3*n*h*d$ per token. The final output is of the same shape as well. The attention scores, one per pair of input tokens, form a matrix of size $t*t$, hence $t^2$. So a total of $4*n*h*d*t + n*h*t^2$.
+**Activations**: Each token's query is of size `d` (per head). The same goes for keys and values, hence a total of $$3*n*h*d$$ per token. The final output is of the same shape as well. The attention scores, one per pair of input tokens, form a matrix of size $$t*t$$, hence $$t^2$$. So a total of $$4*n*h*d*t + n*h*t^2$$.

-**KVCache**: Each key and value is of size `d` per token per head per layer. Hence a total of $n*h*d*t$ for keys (and the same for values), so the size of the KV Cache is $2*n*h*d*t$.
+**KVCache**: Each key and value is of size `d` per token per head per layer. Hence a total of $$n*h*d*t$$ for keys (and the same for values), so the size of the KV Cache is $$2*n*h*d*t$$.
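Plugging Llama-2-7B-like numbers into that formula as a rough sanity check (a sketch: fp16 precision and a 4096-token context are assumptions, not values from the post):

```python
# MHA KV cache: 2 * n * h * d * t elements
n, h, d = 32, 32, 128        # layers, KV heads, head_dim (Llama-2-7B-style)
t = 4096                     # context length (assumed)
bytes_per_elem = 2           # fp16 (assumed)

kv_cache_bytes = 2 * n * h * d * t * bytes_per_elem
print(f"{kv_cache_bytes / 2**30:.1f} GiB")  # 2.0 GiB
```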

- ### Multi Query Attention (MQA)
A small modification of MHA. Instead of having one key and value per query head, we have only a single key and a single value per token, and all the query heads compute similarity against that. This results in each layer having `h` queries, `1` key and `1` value. The advantage is that if you're saving the KVCache to speed up inference, the KVCache shrinks by a factor of `h`. But this is not as performant as MHA, since we reduce the information stored in the keys and values to a single vector.

-$$ A_i = \text{Attention}(Q_i, K, V) = softmax(\frac{Q_iK^T}{\sqrt{d_k}})V $$
-$$ A = [A_1, A_2, ..., A_h] \space @ \space Wo $$
+$$
+A_i = \text{Attention}(Q_i, K, V) = softmax(\frac{Q_iK^T}{\sqrt{d_k}})V
+$$
+
+$$
+A = [A_1, A_2, ..., A_h] \space @ \space Wo
+$$

-- **Parameters**: $W_{q_i}$ is of shape $(h*d, d)$, so the `h` query heads have $h^2*d^2$ parameters in total. $W_{k_i}, W_{v_i}$ each output a single vector shared by all heads, so they are of shape $(h*d, d)$ and hence $h*d^2$ parameters each. $W_o$ is of size $(h*d, h*d)$, so $h^2*d^2$ parameters. Total of $2*n*h^2*d^2 + 2*n*h*d^2$.
-- **Activations**: Each token's query is of size `d` (per head), so $h*d$ in total. There is only one key and one value shared across all the heads, hence only $2*d$ for both. Hence a total of $n*h*d*t + 2*n*d*t$. The final output is of the same shape as the query. The attention scores form a matrix of size $t*t$, hence $t^2$. So a total of $2*n*h*d*t + 2*n*d*t + n*h*t^2$.
-- **KVCache**: Each key and value is of size `d` per token per layer. Hence a total of $2*n*d*t$. A compression of `h` times compared to MHA.
+- **Parameters**: $$W_{q_i}$$ is of shape $$(h*d, d)$$, so the `h` query heads have $$h^2*d^2$$ parameters in total. $$W_{k_i}, W_{v_i}$$ each output a single vector shared by all heads, so they are of shape $$(h*d, d)$$ and hence $$h*d^2$$ parameters each. $$W_o$$ is of size $$(h*d, h*d)$$, so $$h^2*d^2$$ parameters. Total of $$2*n*h^2*d^2 + 2*n*h*d^2$$.
+- **Activations**: Each token's query is of size `d` (per head), so $$h*d$$ in total. There is only one key and one value shared across all the heads, hence only $$2*d$$ for both. Hence a total of $$n*h*d*t + 2*n*d*t$$. The final output is of the same shape as the query. The attention scores form a matrix of size $$t*t$$, hence $$t^2$$. So a total of $$2*n*h*d*t + 2*n*d*t + n*h*t^2$$.
+- **KVCache**: Each key and value is of size `d` per token per layer. Hence a total of $$2*n*d*t$$. A compression of `h` times compared to MHA.
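A quick way to see the `h`-fold saving is to evaluate the two KVCache formulas side by side (same assumed Llama-2-7B-like config as before):

```python
# Per the formulas above: MHA stores 2*n*h*d*t values, MQA stores 2*n*d*t
n, h, d, t = 32, 32, 128, 4096   # assumed: layers, heads, head_dim, tokens

mha = 2 * n * h * d * t
mqa = 2 * n * d * t
print(mha // mqa)  # 32, i.e. exactly h times smaller
```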

- ### Grouped Query Attention (GQA)
This acts as a middle ground between MHA and MQA. Instead of 1 key and value catering to all the queries, we have 1 key and value catering to a group of queries. So we'd have `h` queries, `k` keys and `k` values, where `k` divides `h`. So for each layer, you'd be storing `2 * k` embedding vectors per token. You'll find a lot of models that use this architecture. Generally speaking, a single key/value head caters to `4` to `8` query heads. You can identify whether a model uses GQA when you see [`num_attention_heads`](https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json#L16)`≠`[`num_kv_heads`](https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json#L18) in the model's config.json, as in the snippet below.
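One way to check this from code rather than by reading config.json by hand (a sketch; it assumes the `transformers` library and access to the Llama 3.1 model repo):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B")
# GQA whenever the number of KV heads is smaller than the number of attention heads
print(cfg.num_attention_heads, cfg.num_key_value_heads)  # expected: 32 8
print("GQA" if cfg.num_key_value_heads < cfg.num_attention_heads else "MHA/MQA")
```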

-$$ A_i = \text{Attention}(Q_i, K_{i//g}, V_{i//g}) = softmax(\frac{Q_iK_{i//g}^T}{\sqrt{d_k}})V_{i//g} $$
-$$ A = [A_1, A_2, ..., A_h] \space @ \space Wo $$
+$$
+A_i = \text{Attention}(Q_i, K_{i//g}, V_{i//g}) = softmax(\frac{Q_iK_{i//g}^T}{\sqrt{d_k}})V_{i//g}
+$$
+
+$$
+A = [A_1, A_2, ..., A_h] \space @ \space Wo
+$$

-**Parameters**: $W_{q_i}$ is of shape $(h*d, d)$, so the `h` query heads have $h^2*d^2$ parameters in total. $W_{k_i}, W_{v_i}$ each output one vector per group of heads, so they are of shape $(h*d, g*d)$ and hence $g*h*d^2$ parameters each. $W_o$ is of size $(h*d, h*d)$, so $h^2*d^2$ parameters. Total of $2*n*h^2*d^2 + 2*n*g*h*d^2$.
+**Parameters**: $$W_{q_i}$$ is of shape $$(h*d, d)$$, so the `h` query heads have $$h^2*d^2$$ parameters in total. $$W_{k_i}, W_{v_i}$$ each output one vector per group of heads, so they are of shape $$(h*d, g*d)$$ and hence $$g*h*d^2$$ parameters each. $$W_o$$ is of size $$(h*d, h*d)$$, so $$h^2*d^2$$ parameters. Total of $$2*n*h^2*d^2 + 2*n*g*h*d^2$$.

```python
llama3 = AutoModelForCausalLM.from_pretrained(
@@ -119,9 +138,9 @@ So without further ado, lets get comparing.
```

-**Activations**: Each token's query is of size `d` (per head), resulting in an $h*d$ sized tensor. There is one key per group of heads, hence only $d*g$ each for keys and values, which together add up to $2*d*g$ per token. Hence a total of $n*h*d*t + 2*n*g*d*t$. The final output is of the same shape as the query. The attention scores form a matrix of size $t*t$, hence $t^2$. So a total of $2*n*h*d*t + 2*n*g*d*t + n*h*t^2$.
+**Activations**: Each token's query is of size `d` (per head), resulting in an $$h*d$$ sized tensor. There is one key per group of heads, hence only $$d*g$$ each for keys and values, which together add up to $$2*d*g$$ per token. Hence a total of $$n*h*d*t + 2*n*g*d*t$$. The final output is of the same shape as the query. The attention scores form a matrix of size $$t*t$$, hence $$t^2$$. So a total of $$2*n*h*d*t + 2*n*g*d*t + n*h*t^2$$.

-**KVCache**: Each key and value is of size `d` per token per layer per group. Hence a total of $2*n*g*d*t$. A compression of `h/g` times compared to MHA.
+**KVCache**: Each key and value is of size `d` per token per layer per group. Hence a total of $$2*n*g*d*t$$. A compression of `h/g` times compared to MHA.
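With Llama-3.1-8B-style numbers (32 query heads, 8 KV heads, so `g = 8`), the `h/g` compression works out to 4x. The fp16 precision and 4096-token context below are assumptions for illustration:

```python
# GQA KV cache: 2 * n * g * d * t elements vs MHA's 2 * n * h * d * t
n, h, g, d, t = 32, 32, 8, 128, 4096   # assumed Llama-3.1-8B-like values
bytes_per_elem = 2                     # fp16 (assumed)

gqa = 2 * n * g * d * t * bytes_per_elem
mha = 2 * n * h * d * t * bytes_per_elem
print(f"GQA: {gqa / 2**30:.2f} GiB, MHA: {mha / 2**30:.2f} GiB, ratio {mha // gqa}x")
# GQA: 0.50 GiB, MHA: 2.00 GiB, ratio 4x
```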

- ### MultiHead Latent Attention (MLA)
A new architecture found in the DeepSeek V2 family of models. Here, we compress the keys and values into a latent space and decompress them back to the original space when inference takes place. The idea is to get the advantages of MHA while saving on the KVCache, which scales linearly with context length. Each key and value is compressed from `d` dimensions down to a `c`-dimensional space.
@@ -132,25 +151,39 @@ So without further ado, lets get comparing.
_Share of eigenvalues contributing to 90% in weight_

$$
-c_t^{KV} = W^{DKV} X, \quad \text{ where } c_t^{KV} \in \R^{c} \quad \text { is down projection of keys }\\
-k_t^C = W^{UK} c_t^{KV} \quad \text { up projection of keys} \\
+c_t^{KV} = W^{DKV} X, \quad \text{ where } c_t^{KV} \in \mathbb{R}^{c} \quad \text { is down projection of keys }
+$$
+
+$$
+k_t^C = W^{UK} c_t^{KV} \quad \text { up projection of keys}
+$$
+
+$$
v_t^C = W^{UV} c_t^{KV} \quad \text { up projection of values}
$$

$$
-c_t^{Q} = W^{DQ} X \quad \text{ where } c_t^{Q} \in \R^{c} \\
+c_t^{Q} = W^{DQ} X \quad \text{ where } c_t^{Q} \in \mathbb{R}^{c}
+$$
+
+$$
q_t^C = W^{UQ} c_t^{Q}
$$
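A minimal PyTorch sketch of the down/up projections above (the dimensions are assumptions for illustration, and the extra RoPE branch of the real implementation is omitted):

```python
import torch

d_model, c = 4096, 512                 # model width and latent size (assumed)
X = torch.randn(10, d_model)           # 10 tokens

W_dkv = torch.randn(d_model, c)        # W^{DKV}: down-project to the shared KV latent
W_uk  = torch.randn(c, d_model)        # W^{UK}: up-project latent back to key space
W_uv  = torch.randn(c, d_model)        # W^{UV}: up-project latent back to value space

c_kv = X @ W_dkv                       # (10, c) -- this is all the KV cache needs to store
k = c_kv @ W_uk                        # (10, d_model) keys, recovered at attention time
v = c_kv @ W_uv                        # (10, d_model) values
print(c_kv.shape, k.shape, v.shape)
```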

-$$ A_i = \text{Attention}(Q_i, K_i, V_i) = softmax(\frac{Q_i K_i^T}{\sqrt{d_k}})V_i $$
-$$ A = [A_1, A_2, ..., A_h] \space @ \space Wo $$
+$$
+A_i = \text{Attention}(Q_i, K_i, V_i) = softmax(\frac{Q_i K_i^T}{\sqrt{d_k}})V_i
+$$
+
+$$
+A = [A_1, A_2, ..., A_h] \space @ \space Wo
+$$

![Multi Latent Attention formulae](assets/img/blogs/transformer_showdown/mla.png)
_Multi Latent Attention formulae_

-**KVCache**: Each compressed vector is of size `c` per token per layer per group. Hence a total of $n*g*c*t$. Keys and values are inferred by decompressing this ($k_t^C, v_t^C$). A compression of `2*d/c` times compared to MHA. Note that in the final implementation there's a nuance of additional heads (and hence keys and values) for RoPE, which adds a little more overhead, so the compression ratio essentially becomes $2*d/(c+r)$, where r is the RoPE key dimension.
+**KVCache**: Each compressed vector is of size `c` per token per layer per group. Hence a total of $$n*g*c*t$$. Keys and values are inferred by decompressing this ($$k_t^C, v_t^C$$). A compression of `2*d/c` times compared to MHA. Note that in the final implementation there's a nuance of additional heads (and hence keys and values) for RoPE, which adds a little more overhead, so the compression ratio essentially becomes $$2*d/(c+r)$$, where r is the RoPE key dimension.
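Evaluating that ratio with assumed illustrative values (treating `d` as the full key/value width being compressed, `c` as the latent size and `r` as the RoPE key dimension; the numbers below are not from the post):

```python
# Compression vs MHA per the formula above: 2*d / (c + r)
d, c, r = 4096, 512, 64     # assumed illustrative values
print(2 * d / (c + r))      # ~14.2x smaller KV cache
```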

This image from the DeepSeek V2 paper gives a crisp view of the above-mentioned architectures.
@@ -177,6 +210,8 @@ _MHA vs GQA vs MQA vs MLA_

- Apart from more normalisations, there isn't much that would meaningfully contribute to parameters, activations or KVCache compared to GQA.

+## Results and Findings
+
So now that the introductions are out of the way, the burning question is do the changes contribute to any meaningful differences in the final performance of the models?
181216

182217
Well the answer is nuanced. Let's see how they stack up.
@@ -194,5 +229,5 @@ _Train losses on minipile dataset_
194229
On the [minipile dataset](https://huggingface.co/datasets/jeankaddour/minipile) which is approximately 10x larger than the wiki data, I saw that there isn't much to choose between MLA, MHA, GQA and DiffAttention. Which is great since GQA uses 4x less keys and values resulting in 4x less KVCache. Surprisingly, nGPT's losses seem to go down as low as 0.2 when the others hover around 3. I tried to repeat the experiement multiple times with multiple configs only to find a similar loss curve. I also checked validation loss for all the models, they look very similar to train loss curves so there isn't much value in plotting those. We will have to look into why this is the case but it definitely is fascinating.
195230

196231

197-
### Conclusion
232+
## Conclusion
198233
All in all, GQA offers a very good alternative to MHA, sometimes even outperforming it while also using 4-8x less space for KVCache. MLA builds upon that by compressing the Keys and values even further. Turns out, this also acts as regularisation. Normalisation is the king of all. Given that normalisation is a key component in deep learning, it is no surprise that making it explicit for every operation. This opens up new paths to LLM training. We will explore the down stream capabilities of the models in a future write up. Until then, Ciao.

assets/js/data/mathjax.js

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
---
layout: compress
-# WARNING: Don't use '//' to comment out code, use '{% comment %}' and '{% endcomment %}' instead.
---

{%- comment -%}
