So without further ado, let's get comparing.

### Multi Head Attention (MHA)

The standard attention introduced in Attention Is All You Need, and very commonly used in models before 2023. Each layer has an equal number of query, key and value heads: if a layer has `h` heads, we'd have `h` queries, `h` keys and `h` values.

$$
Q_i = W_{q_i} X, \quad K_i = W_{k_i} X, \quad V_i = W_{v_i} X \quad \text{where X is the input}
$$

$$
Q = [Q_1, Q_2, ..., Q_h], \quad K = [K_1, K_2, ..., K_h], \quad V = [V_1, V_2, ..., V_h] \quad \text{where [ ] is concatenation}
$$

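To make the shapes concrete, here is a minimal PyTorch sketch of these per-head projections for a single layer (sizes are hypothetical and chosen for illustration; masking and batching are omitted):

```python
import torch
import torch.nn as nn

h, d, t = 8, 64, 16                  # heads, head dim, tokens
hidden = h * d                       # model (hidden) size

# one (hidden -> d) projection per head for Q, K and V, as in the equations above
W_q = nn.ModuleList([nn.Linear(hidden, d, bias=False) for _ in range(h)])
W_k = nn.ModuleList([nn.Linear(hidden, d, bias=False) for _ in range(h)])
W_v = nn.ModuleList([nn.Linear(hidden, d, bias=False) for _ in range(h)])
W_o = nn.Linear(hidden, hidden, bias=False)

X = torch.randn(t, hidden)
heads = []
for i in range(h):
    Q_i, K_i, V_i = W_q[i](X), W_k[i](X), W_v[i](X)   # (t, d) each
    scores = (Q_i @ K_i.T) / d**0.5                   # (t, t) attention scores
    heads.append(scores.softmax(dim=-1) @ V_i)        # (t, d)

out = W_o(torch.cat(heads, dim=-1))                   # concatenate heads -> (t, h*d)
```
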
**Config**: `n` layers. Each layer has `h` heads. Each head has `d` dimensions. Total token count `t`.

**Parameters**: $$W_{q_i}$$ is of shape $$(h*d, d)$$, so it has $$h*d^2$$ parameters per head. Same for $$W_{k_i}, W_{v_i}$$. So $$W_q + W_k + W_v$$ contribute a total of $$3*h*(h*d^2)$$ parameters. $$W_o$$ is of size $$(h*d, h*d)$$, so $$h^2*d^2$$ parameters. Total of $$4*n*h^2*d^2$$.

For example, [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json) has [32 attention heads](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L14) and [32 key value heads](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L16), so Llama 2 7B uses MHA. It has a [hidden_size of 4096](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L9), which means each head has a head_dim (`d`) of **128**. So the algebra tells us that $$W_{q_i}$$ would be of shape $$(128*32, 128) = (4096, 128)$$. Each $$W_q$$ (similarly $$W_k, W_v$$) would be of shape $$(4096, 128*32) = (4096, 4096)$$, contributing $$128^2 * 32^2 = 16,777,216$$ parameters. Executing the code below gives the same result. Voila.

```python
llama2 = AutoModelForCausalLM.from_pretrained(
...
Wo shape is torch.Size([4096, 4096]) contributes to 16777216 parameters
```

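For reference, a minimal sketch of how those shapes can be inspected (the `q_proj`/`k_proj`/`v_proj`/`o_proj` attribute names follow the Hugging Face `transformers` Llama implementation; the loop itself is illustrative rather than the original snippet):

```python
from transformers import AutoModelForCausalLM

llama2 = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

attn = llama2.model.layers[0].self_attn   # attention block of the first decoder layer
for name in ("q_proj", "k_proj", "v_proj", "o_proj"):
    w = getattr(attn, name).weight
    print(f"{name} shape is {tuple(w.shape)} contributes to {w.numel()} parameters")
```
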
**Activations**: Each token's query is of size `d` (per head). The same holds for keys and values, hence a total of $$3*n*h*d$$ per token. The final output is of the same shape as well. The attention scores, one per pair of input tokens, form a matrix of size $$t*t$$, hence $$t^2$$. So a total of $$4*n*h*d*t + n*h*t^2$$.

**KVCache**: Each key and value is of size `d` per token per head per layer. Hence a total of $$n*h*d*t$$ for keys (and the same for values), so the size of the KV Cache is $$2*n*h*d*t$$.

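Plugging in the Llama-2-7B numbers from above, a rough back-of-the-envelope estimate, assuming fp16 and a 4096-token context:

```python
n, h, d, t = 32, 32, 128, 4096   # layers, heads, head dim, context length
fp16_bytes = 2

kv_cache_bytes = 2 * n * h * d * t * fp16_bytes
print(kv_cache_bytes / 2**30, "GiB")   # -> 2.0 GiB
```
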
### Multi Query Attention (MQA)

A small modification of MHA. Instead of having one key and value per query head, we only have a single key (and value) per token, and all the query heads compute similarity against it. This results in each layer having `h` queries, `1` key and `1` value. The advantage is that if you're saving a KVCache to speed up inference, the cache is reduced by `h` times. But this is not as performant as MHA, as we're reducing the scope of information stored in the keys to a single vector.

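Shape-wise, the single key/value head is simply broadcast across all the query heads. A small illustrative sketch (hypothetical sizes, no masking or batching):

```python
import torch

t, h, d = 16, 8, 64                      # tokens, query heads, head dim
q = torch.randn(t, h, d)                 # h query heads
k = torch.randn(t, 1, d)                 # a single key head ...
v = torch.randn(t, 1, d)                 # ... and a single value head

k, v = k.expand(t, h, d), v.expand(t, h, d)              # broadcast to every query head
scores = torch.einsum("qhd,khd->hqk", q, k) / d**0.5     # (h, t, t)
out = torch.einsum("hqk,khd->qhd", scores.softmax(-1), v).reshape(t, h * d)
```
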
**Parameters**: $$W_{q_i}$$ is of shape $$(h*d, d)$$, so across the `h` heads $$W_q$$ has $$h^2*d^2$$ parameters. As for $$W_{k_i}, W_{v_i}$$, they output a single vector shared by all heads, so each is of shape $$(h*d, d)$$ and hence has $$h*d^2$$ parameters. $$W_o$$ is of size $$(h*d, h*d)$$, so $$h^2*d^2$$ parameters. Total of $$2*n*h^2*d^2 + 2*n*h*d^2$$.

**Activations**: Each token's query is of size `d` (per head), so $$h*d$$. There is only one key shared across all the heads, hence only $$2*d$$ (key and value). Hence a total of $$n*h*d*t + 2*n*d*t$$. The final output is of the same shape as the query. The attention scores form a matrix of size $$t*t$$, hence $$t^2$$. So a total of $$2*n*h*d*t + 2*n*d*t + n*h*t^2$$.

**KVCache**: Each key and value is of size `d` per token per layer. Hence a total of $$2*n*d*t$$. A compression of `h` times compared to MHA.

### Grouped Query Attention (GQA)

This acts as a middle ground between MHA and MQA. Instead of one key and value catering to all the queries, we have one key and value catering to a group of queries. So we'd have `h` queries, `g` keys and `g` values, where `g` divides `h`; for each layer, you'd be storing `2*g` embedding vectors per token. You'll find a lot of models that use this architecture. Generally speaking, a single key/value head caters to `4` to `8` query heads. You can identify whether a model uses this when you see [`num_attention_heads`](https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json#L16) `≠` [`num_kv_heads`](https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json#L18) in the model's config.json.

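A quick, illustrative way to check this from the config (assuming the usual Hugging Face field names and that you have access to the repo):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B")
print(cfg.num_attention_heads, cfg.num_key_value_heads)   # 32 vs 8 -> GQA, 4 query heads per KV head
```
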
**Parameters**: $$W_{q_i}$$ is of shape $$(h*d, d)$$, so $$W_q$$ has $$h^2*d^2$$ parameters. As for $$W_{k_i}, W_{v_i}$$, they output one vector per group of heads, so each is of shape $$(h*d, g*d)$$ and hence has $$g*h*d^2$$ parameters. $$W_o$$ is of size $$(h*d, h*d)$$, so $$h^2*d^2$$ parameters. Total of $$2*n*h^2*d^2 + 2*n*g*h*d^2$$.

```python
llama3 = AutoModelForCausalLM.from_pretrained(
...
```

**Activations**: Each token's query is of size `d` (per head), resulting in an $$h*d$$ sized tensor. There is one key (and value) per group of heads, hence only $$d*g$$ each, which together add up to $$2*d*g$$ per token. Hence a total of $$n*h*d*t + 2*n*g*d*t$$. The final output is of the same shape as the query. The attention scores form a matrix of size $$t*t$$, hence $$t^2$$. So a total of $$2*n*h*d*t + 2*n*g*d*t + n*h*t^2$$.

**KVCache**: Each key and value is of size `d` per token per layer per group. Hence a total of $$2*n*g*d*t$$. A compression of `h/g` times compared to MHA.

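Putting the three cache formulas side by side for a Llama-2-7B-sized config (an illustrative estimate: fp16, a 4096-token context, and `g = 8` are assumptions for the sake of the comparison, not Llama 2's actual setting):

```python
n, h, d, t, g = 32, 32, 128, 4096, 8   # layers, query heads, head dim, tokens, kv heads
fp16_bytes = 2

mha = 2 * n * h * d * t * fp16_bytes   # every head caches its own K and V
mqa = 2 * n * d * t * fp16_bytes       # one K/V head shared by all query heads
gqa = 2 * n * g * d * t * fp16_bytes   # one K/V head per group of h/g query heads

print(f"MHA {mha / 2**20:.0f} MiB, MQA {mqa / 2**20:.0f} MiB, GQA {gqa / 2**20:.0f} MiB")
# -> MHA 2048 MiB, MQA 64 MiB, GQA 512 MiB
```
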
### Multi Head Latent Attention (MLA)

A new architecture found in the DeepSeek V2 family of models. Here, we compress the keys and values into a latent space and decompress them back to the original space when inference takes place. The idea is to get the advantages of MHA while saving on KVCache, which scales linearly with context length. Each key and value is compressed from `d` dimensions down to a `c` dimensional space.

_Share of eigen values contributing to 90% in weight_

$$
c_t^{KV} = W^{DKV} X, \quad \text{where } c_t^{KV} \in \mathbb{R}^{c} \quad \text{is the down projection of keys}
$$

$$
k_t^C = W^{UK} c_t^{KV} \quad \text{up projection of keys}
$$

$$
v_t^C = W^{UV} c_t^{KV} \quad \text{up projection of values}
$$

$$
c_t^{Q} = W^{DQ} X \quad \text{where } c_t^{Q} \in \mathbb{R}^{c}
$$

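A minimal PyTorch sketch of the key/value path described by these equations (dimensions are hypothetical, and the decoupled RoPE branch mentioned below is left out). The point is that only the latent `c_kv` needs to be cached per token:

```python
import torch
import torch.nn as nn

hidden, c, h, d, t = 1024, 128, 8, 64, 16   # model dim, latent dim, heads, head dim, tokens

W_dkv = nn.Linear(hidden, c, bias=False)    # down projection shared by keys and values
W_uk  = nn.Linear(c, h * d, bias=False)     # up projection for keys
W_uv  = nn.Linear(c, h * d, bias=False)     # up projection for values

X = torch.randn(t, hidden)
c_kv = W_dkv(X)                             # (t, c) -- this is all the KV cache has to store
k = W_uk(c_kv).view(t, h, d)                # keys reconstructed on the fly at inference
v = W_uv(c_kv).view(t, h, d)                # values reconstructed on the fly at inference
```
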
**KVCache**: Each compressed vector is of size `c` per token per layer per group. Hence a total of $$n*g*c*t$$. Keys and values are inferred by decompressing this ($$k_t^C, v_t^C$$). A compression of `2*d/c` times compared to MHA. Note that in the final implementation there's a nuance of additional heads (and hence keys and values) for RoPE, which adds a little more overhead. So the compression ratio essentially becomes $$2*d/(c+r)$$, where `r` is the RoPE key dimension.

This image from the DeepSeek V2 paper gives a crisp view of the above mentioned architectures.

_MHA vs GQA vs MQA vs MLA_

- Apart from more normalisations, there isn't much that would meaningfully contribute to parameters, activations or KVCache compared to GQA.

## Results and Findings

So now that the introductions are out of the way, the burning question is: do the changes contribute to any meaningful differences in the final performance of the models?

Well, the answer is nuanced. Let's see how they stack up.

_Train losses on minipile dataset_

On the [minipile dataset](https://huggingface.co/datasets/jeankaddour/minipile), which is approximately 10x larger than the wiki data, I saw that there isn't much to choose between MLA, MHA, GQA and DiffAttention. Which is great, since GQA uses 4x fewer keys and values, resulting in a 4x smaller KVCache. Surprisingly, nGPT's losses seem to go down as low as 0.2 while the others hover around 3. I tried to repeat the experiment multiple times with multiple configs, only to find a similar loss curve. I also checked validation loss for all the models; the curves look very similar to the train loss curves, so there isn't much value in plotting those. We will have to look into why this is the case, but it definitely is fascinating.

## Conclusion

All in all, GQA offers a very good alternative to MHA, sometimes even outperforming it while also using 4-8x less space for the KVCache. MLA builds upon that by compressing the keys and values even further; it turns out this also acts as regularisation. Normalisation is the king of them all: given that normalisation is a key component in deep learning, it is no surprise that making it explicit for every operation pays off. This opens up new paths for LLM training. We will explore the downstream capabilities of the models in a future write-up. Until then, ciao.