`_posts/2025-01-22-transformer-showdown.md`
layout: post
title: Transformer showdown MHA vs MLA vs nGPT vs Differential Transformer
description: Comparing various transformer architectures like MHA, GQA, Multi Latent Attention, nGPT, Differential Transformer.
date: 2025-01-22 20:27 +0530
categories: [Transformer, Architectures]
tags: [MLA, MHA, GQA, Multi Latent Attention, nGPT, Differential Transformer]
render_with_liquid: false
So without further ado, let's get comparing.
**Parameters**: $W_{q_i}$ is of shape $(h*d, d)$, so it has $h*d^2$ parameters per head. The same holds for $W_{k_i}, W_{v_i}$. So $W_q + W_k + W_v$ contribute a total of $3*h*(h*d^2)$ parameters. $W_o$ is of size $(h*d, h*d)$, so $h^2*d^2$ parameters. That makes a total of $4*n*h^2*d^2$ across $n$ layers.
For example, [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json) has [32 attention heads](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L14) and [32 key value heads](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L16), so Llama 2 7B uses MHA. It has a [hidden_size of 4096](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L9), which means each head has a head_dim (d) of **128**. The algebra tells us that $W_{q_i}$ would be of shape $(128*32, 128) = (4096, 128)$. Each Q (similarly K, V) projection would be of shape $(4096, 128*32) = (4096, 4096)$, contributing $128^2 * 32^2 = 16,777,216$ parameters. Executing the code below gives the same result. Voila.
```python
from transformers import AutoModelForCausalLM

llama2 = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
```
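As a quick sanity check, here is a minimal sketch of how one might verify that count (the attribute path `model.layers[0].self_attn.q_proj` assumes the standard Hugging Face Llama module layout):

```python
# Inspect a single query projection and count its parameters.
q_proj = llama2.model.layers[0].self_attn.q_proj
print(q_proj.weight.shape)    # torch.Size([4096, 4096])
print(q_proj.weight.numel())  # 16777216 == 128**2 * 32**2
```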
### MultiHead Latent Attention (MLA)
A new architecture introduced in the DeepSeek V2 family of models. Here, we compress the keys and values into a latent space and decompress them back to the original space at inference time. The idea is to get the advantages of MHA while saving on KVCache, since the cache scales linearly with context length. Each key and value is compressed from `d` dimensions down to a `c` dimensional space.
Last year, I was playing around with the Llama 2 and Mistral families of models. I tried to understand why some models perform better than others from a mathematical perspective, fiddling with the eigenvalues of each of the weight matrices. What I observed was very interesting: all the models exhibit some sort of low rank behaviour, where 50% of the eigenvalues explain 90% of the variance ([read for reference](https://stats.stackexchange.com/a/171599)). So essentially, we can compress the keys and values (even queries) to at least half their original size without losing much information. This can be thought of as an explanation of why MLA might work. The compression ratio there is higher than 2:1, but you get the idea.
_Share of eigenvalues contributing to 90% of the variance in each weight matrix_
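Something along these lines reproduces the check (a minimal sketch; singular values stand in for the eigen-spectrum, and the helper name is made up):

```python
import torch

def frac_explaining_90pct(weight: torch.Tensor) -> float:
    """Fraction of singular values needed to cover 90% of the
    squared singular value mass (variance) of a weight matrix."""
    s = torch.linalg.svdvals(weight.detach().float())
    cum = torch.cumsum(s**2, dim=0) / (s**2).sum()
    k = int(torch.searchsorted(cum, torch.tensor(0.9)).item()) + 1
    return k / s.numel()

# e.g. roughly 0.5 for many projections, per the observation above:
# frac_explaining_90pct(llama2.model.layers[0].self_attn.k_proj.weight)
```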
$$
c_t^{KV} = W^{DKV} x_t, \quad \text{where } c_t^{KV} \in \mathbb{R}^{c} \text{ is the down-projection of keys and values} \\
k_t^C = W^{UK} c_t^{KV} \quad \text{up-projection of keys} \\
v_t^C = W^{UV} c_t^{KV} \quad \text{up-projection of values}
$$
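In code, the idea looks roughly like this (a toy sketch, not DeepSeek's actual implementation; layer names and sizes are illustrative):

```python
import torch
import torch.nn as nn

d_model, c = 1536, 384                     # hidden size and latent size (illustrative)

W_dkv = nn.Linear(d_model, c, bias=False)  # down-projection: x_t -> c_t^{KV}
W_uk  = nn.Linear(c, d_model, bias=False)  # up-projection back to keys
W_uv  = nn.Linear(c, d_model, bias=False)  # up-projection back to values

x = torch.randn(1, 10, d_model)            # (batch, seq, hidden)
c_kv = W_dkv(x)                            # only this (batch, seq, c) tensor goes in the KVCache
k, v = W_uk(c_kv), W_uv(c_kv)              # decompressed at attention time
```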
**KVCache**: Each compressed vector is of size `c` per token per layer per group, hence a total of $n*g*c*t$. Keys and values are inferred by decompressing this ($k_t^C, v_t^C$). That is a compression of `2*d/c` times compared to MHA. Note that in the final implementation there's a nuance of additional heads (and hence keys and values) for RoPE, which adds a little more overhead. So the compression ratio essentially becomes $2*d/(c+r)$, where r is the RoPE key dimension.
_MHA vs GQA vs MQA vs MLA_
### Differential Transformer
Introduced in a [2024 paper from Microsoft](https://arxiv.org/abs/2410.05258). The main motivation is that attention scores carry a lot of noise. So if we have two subnetworks calculating attention, subtracting one from the other acts like subtracting random noise from a noisy signal. This helps control the attention scores and logits (outliers are of smaller magnitude), which is said to improve convergence. We discussed this in great detail in one of our other blogs on Substack, [check it out.](https://datta0.substack.com/i/150138108/differential-transformer)
- Here, owing to having two attention units, the number of parameters, activations and KVCache requirements each go up by a factor of 2 as compared to GQA.
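As a rough picture of the mechanism (a toy, single-head sketch with made-up shapes and a fixed lambda; the paper learns lambda and adds per-head normalisation):

```python
import torch
import torch.nn.functional as F

def diff_attention(q1, k1, q2, k2, v, lam=0.8):
    """Toy single-head differential attention: the difference of two
    softmax attention maps, scaled by lambda, applied to the values."""
    d = q1.shape[-1]
    a1 = F.softmax(q1 @ k1.transpose(-1, -2) / d**0.5, dim=-1)
    a2 = F.softmax(q2 @ k2.transpose(-1, -2) / d**0.5, dim=-1)
    return (a1 - lam * a2) @ v

t, d = 10, 64
q1, k1, q2, k2 = (torch.randn(t, d) for _ in range(4))
v = torch.randn(t, d)
out = diff_attention(q1, k1, q2, k2, v)   # (t, d)
```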
### nGPT

Introduced in a [2024 paper from NVIDIA](https://arxiv.org/abs/2410.01131). The main idea is: if normalisation layers are so important to the performance of deep networks and LLMs, why not make normalisation mathematically implicit to the network? With this in place, at every step we make sure we're interacting with normalised vectors, and only normalised vectors are passed on after every step. This too is said to improve convergence. We discussed this in great detail in one of our other blogs on Substack, [check it out.](https://datta0.substack.com/i/151875954/ngpt-normalized-transformer)

- Apart from the extra normalisations, there isn't much that meaningfully adds to parameters, activations or KVCache as compared to GQA.
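A toy sketch of the flavour of update this implies (the step size `alpha` is learnable in the paper; here it is just a placeholder constant):

```python
import torch

def l2_normalize(x, eps=1e-8):
    return x / (x.norm(dim=-1, keepdim=True) + eps)

def ngpt_style_update(h, block_out, alpha=0.1):
    """Keep hidden states on the unit hypersphere: fold in a block's output
    with a small step and re-normalise."""
    return l2_normalize(h + alpha * (l2_normalize(block_out) - h))

h = l2_normalize(torch.randn(4, 1536))   # (tokens, hidden), already unit-norm
block_out = torch.randn(4, 1536)         # e.g. attention or MLP output
h = ngpt_style_update(h, block_out)
print(h.norm(dim=-1))                    # all approximately 1.0
```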
So now that the introductions are out of the way, the burning question is: do these changes contribute to any meaningful differences in the final performance of the models?

Well, the answer is nuanced. Let's see how they stack up.
_Train losses on wikipedia dataset_
I started with a model that has `16` layers and a hidden size of `1536`. The MHA variant had `16` attention heads (and hence 16 key value heads), while the GQA variant had `4` key value heads. The MLP block had an intermediate size of `2048`. I used the [GPT2 tokenizer from NeelNanda](https://huggingface.co/NeelNanda/gpt-neox-tokenizer-digits), which is modified to treat numbers as individual tokens.
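For reference, the two baseline configs translate to something like this (a sketch assuming Llama-style blocks in `transformers`; the MLA, nGPT and DiffAttention variants need custom modules and are not captured by this config):

```python
from transformers import LlamaConfig

common = dict(
    num_hidden_layers=16,
    hidden_size=1536,
    intermediate_size=2048,
    num_attention_heads=16,
)
mha_config = LlamaConfig(**common, num_key_value_heads=16)  # MHA: kv heads == attention heads
gqa_config = LlamaConfig(**common, num_key_value_heads=4)   # GQA: 4 kv heads
```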
Looks like nGPT outperforms the rest by a decent margin on a [100k sample of the wikipedia dataset](https://huggingface.co/datasets/imdatta0/wikipedia_en_sample).
_Train losses on minipile dataset_
On the [minipile dataset](https://huggingface.co/datasets/jeankaddour/minipile), which is approximately 10x larger than the wiki data, I saw that there isn't much to choose between MLA, MHA, GQA and DiffAttention. Which is great, since GQA uses 4x fewer keys and values, resulting in 4x less KVCache. Surprisingly, nGPT's loss seems to go down as low as 0.2 while the others hover around 3. I repeated the experiment multiple times with multiple configs, only to find a similar loss curve. I also checked validation loss for all the models; the curves look very similar to the train loss curves, so there isn't much value in plotting those. We will have to look into why this is the case, but it definitely is fascinating.
### Conclusion
All in all, GQA offers a very good alternative to MHA, sometimes even outperforming it while using 4-8x less space for KVCache. MLA builds on that by compressing the keys and values even further, and it turns out this also acts as regularisation. Normalisation is the king of them all: given that normalisation is a key component in deep learning, it is no surprise that making it explicit in every operation pays off, and it opens up new paths for LLM training. We will explore the downstream capabilities of these models in a future write-up. Until then, ciao.