
Commit 5816fb0

Transformer showdown finalise
1 parent 6419ac7 commit 5816fb0

File tree

10 files changed: +63 -6 lines changed


_data/share.yml

Lines changed: 1 addition & 1 deletion
@@ -8,4 +8,4 @@ platforms:
 - type: Linkedin
   icon: "fab fa-linkedin"
-  link: "https://www.linkedin.com/shareArticle?mini=true&url=URL"
+  link: "https://www.linkedin.com/sharing/share-offsite/?url=URL"

_posts/2025-01-09-transformer-showdown.md renamed to _posts/2025-01-22-transformer-showdown.md

Lines changed: 33 additions & 3 deletions
@@ -2,7 +2,7 @@
 layout: post
 title: Transformer showdown MHA vs MLA vs nGPT vs Differential Transformer
 description: Comparing various transformer architectures like MHA, GQA, Multi Latent Attention, nGPT, Differential Transformer.
-date: 2025-01-09 00:27 +0530
+date: 2025-01-22 20:27 +0530
 categories: [Transformer, Architectures]
 tags: [MLA, MHA, GQA, Multi Latent Attention, nGPT, Differential Transformer]
 render_with_liquid: false
@@ -30,7 +30,7 @@ So without further ado, lets get comparing.

   **Parameters**: $W_{q_i}$ is of shape $(h*d, d)$, so it has $h*d^2$ parameters per head. Same for $W_{k_i}, W_{v_i}$. So $W_q + W_k + W_v$ contributes a total of $3*h*(h*d^2)$ parameters. $W_o$ is of size $(h*d, h*d)$, so $h^2*d^2$ parameters. Total of $4*n*h^2*d^2$.

-  For example, [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json) has [32 attention heads](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L14) and [32 key value heads](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L16). So Llama 2 7B uses MHA. It has a [hidden_size of 4096](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L9). This means each head has a head_dim (d) of **128**. So the algebra tells us that $W_{q_i}$ would be of shape $(128*32, 128) = (4096, 128)$. Each Q (similarly K, V) would be of shape $(4096, 128*32) = (4096, 4096)$, contributing $128^2 * 32^2 = 1,6,777,216$ parameters. Executing the below code would give you the same result. Voila.
+  For example, [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json) has [32 attention heads](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L14) and [32 key value heads](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L16). So Llama 2 7B uses MHA. It has a [hidden_size of 4096](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L9). This means each head has a head_dim (d) of **128**. So the algebra tells us that $W_{q_i}$ would be of shape $(128*32, 128) = (4096, 128)$. Each Q (similarly K, V) would be of shape $(4096, 128*32) = (4096, 4096)$, contributing $128^2 * 32^2 = 16,777,216$ parameters. Executing the below code would give you the same result. Voila.

   ```python
   llama2 = AutoModelForCausalLM.from_pretrained(
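# The hunk cuts the post's own snippet off here. Purely to illustrate the arithmetic
# above, a minimal sketch of my own (assuming the `transformers` library and access to
# the gated Llama 2 config; this is not the post's exact code):
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")   # gated repo, needs HF access
h, n = cfg.num_attention_heads, cfg.num_hidden_layers          # 32 heads, 32 layers
d = cfg.hidden_size // h                                        # head_dim = 4096 / 32 = 128

per_projection = (h * d) ** 2      # one of W_q, W_k, W_v, W_o per layer: (4096, 4096)
print(per_projection)              # 16,777,216
print(4 * n * h**2 * d**2)         # 4*n*h^2*d^2 total attention parameters across layers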
@@ -126,6 +126,11 @@ So without further ado, lets get comparing.
 - ### MultiHead Latent Attention (MLA)
   A new architecture found in the DeepSeek V2 family of models. Here, we compress the Keys and Values into a latent space and uncompress them back to the original space when inference takes place. The idea is to get the advantages of MHA while saving on KVCache, which scales linearly with context length. Each key and value is compressed from `d` dimensions to a `c`-dimensional space.

+  Last year, I was playing around with the llama 2 and mistral families of models. I tried to understand why some models perform better than others from a mathematical perspective. I was fiddling with the eigen values of each of the weight matrices, and what I observed was very interesting. All the models exhibit some sort of low rank behaviour, where 50% of the eigen values explain 90% of the variance ([read for reference](https://stats.stackexchange.com/a/171599)). So essentially, we can compress the keys and values (even queries) to at least half their original size without losing much information. This can be thought of as an explanation of why MLA might work. The compression ratio is higher than 2:1, but you get the idea. A rough sketch of this check follows below.
+
+  ![Share of eigen values contributing to 90% in weight](assets/img/blogs/transformer_showdown/llama_eigen.png)
+  _Share of eigen values contributing to 90% in weight_
+
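As a rough illustration of that low-rank check (my own sketch, assuming `torch` and `transformers` and access to the gated Llama 2 weights; not the code behind the plot), one can take the singular value spectrum of an attention weight matrix and count how many values are needed to cover 90% of the squared-spectrum energy:

```python
# Rough low-rank check: squared singular values play the role of "explained variance".
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             torch_dtype=torch.float32)
W = model.model.layers[0].self_attn.k_proj.weight.detach()   # (4096, 4096) key projection

s = torch.linalg.svdvals(W)                          # singular values, descending
energy = torch.cumsum(s**2, dim=0) / (s**2).sum()    # cumulative explained variance
k = int((energy < 0.90).sum().item()) + 1
print(f"{k} of {len(s)} singular values explain 90% of the variance")
```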
  $$
  c_t^{KV} = W^{DKV} X, \quad \text{ where } c_t^{KV} \in \mathbb{R}^{c} \quad \text{ is down projection of keys } \\
  k_t^C = W^{UK} c_t^{KV} \quad \text{ up projection of keys } \\
@@ -142,6 +147,9 @@ So without further ado, lets get comparing.
   $$ A_i = \text{Attention}(Q_i, K_i, V_i) = softmax(\frac{Q_i K_i^T}{\sqrt{d_k}})V_i $$
   $$ A = [A_1, A_2, ..., A_h] \space @ \space W_o $$

+  ![Multi Latent Attention formulae](assets/img/blogs/transformer_showdown/mla.png)
+  _Multi Latent Attention formulae_
+
   **KVCache**: Each compressed vector is of size `c` per token per layer per group, hence a total of $n*g*c*t$. Keys and values are inferred by decompressing this ($k_t^C, v_t^C$). That is a compression of `2*d/c` times compared to MHA. Note that in the final implementation there's a nuance of additional heads (and hence keys and values) for RoPE, which adds a little more overhead. So the compression ratio essentially becomes $2*d/(c+r)$, where r is the RoPE key dimension. A back-of-the-envelope sketch of this arithmetic follows below.

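To make that cache bookkeeping concrete, here is a back-of-the-envelope sketch of my own, following the post's $n*g*c*t$ and $2*d/(c+r)$ expressions. The dimensions are made-up toy values, not any real model's config:

```python
# Toy KVCache arithmetic following the post's bookkeeping (n layers, t cached tokens).
def mha_cache(n, t, h, d):
    return 2 * n * h * d * t            # full keys and values: 2 * h * d per token per layer

def mla_cache(n, t, g, c, r=0):
    return n * g * (c + r) * t          # compressed latent of size c (+ r RoPE dims) per group

n, t = 16, 4096                          # toy: 16 layers, 4096 cached tokens
mha = mha_cache(n, t, h=16, d=96)
mla = mla_cache(n, t, g=16, c=24, r=8)   # toy latent: c = d/4, plus a small RoPE dimension
print(mha / mla)                         # 2*d/(c+r) = 192/32 = 6x smaller cache
```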
@@ -156,13 +164,35 @@ _MHA vs GQA vs MQA vs MLA_
 - ### Differential Transformer
   Introduced in a [2024 paper from Microsoft](https://arxiv.org/abs/2410.05258). The main motivation is that attention scores have a lot of noise. So if we have two subnetworks calculating attention, subtracting one from the other acts like subtracting random noise from noise-corrupted information. This helps control the attention scores and logits (outliers are of lesser magnitude) and is said to improve convergence. We discussed this in great detail in one of our other blogs on Substack, [check it out.](https://datta0.substack.com/i/150138108/differential-transformer) A toy sketch of the subtraction follows below.

-  - Here, owing to having two attention units, the number of parameters, activations and KVCache requirements goes up by a factor of 2 each as compared to GQA
+  ![Differential Transformer](assets/img/blogs/transformer_showdown/diff_transformer.png)
+  _Differential Transformer_
+
+  - Here, owing to having two attention units, the number of parameters, activations and KVCache requirements goes up by a factor of 2 each as compared to GQA.

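As promised, a toy sketch of the core subtraction (my own illustration; it omits the paper's lambda re-parameterisation, per-head GroupNorm and multi-head plumbing):

```python
# Toy differential attention: two attention maps, the second subtracted to cancel noise.
import torch
import torch.nn.functional as F

def diff_attention(q1, k1, q2, k2, v, lam=0.5):
    """Two softmax attention maps; subtracting the second suppresses common-mode noise."""
    d = q1.shape[-1]
    a1 = F.softmax(q1 @ k1.transpose(-1, -2) / d**0.5, dim=-1)
    a2 = F.softmax(q2 @ k2.transpose(-1, -2) / d**0.5, dim=-1)
    return (a1 - lam * a2) @ v

t, d = 8, 16                                    # toy sequence length and head dim
q1, k1, q2, k2, v = (torch.randn(t, d) for _ in range(5))
print(diff_attention(q1, k1, q2, k2, v).shape)  # torch.Size([8, 16])
```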
 - ### nGPT
   Introduced in a [2024 paper from NVIDIA](https://arxiv.org/abs/2410.01131). The main idea is: if normalisation layers are so important to the performance of deep networks and LLMs, why not make normalisation mathematically implicit to the network. Given this premise, at every step we make sure we're interacting with normalised vectors, and only normalised vectors are passed on after every step. This too is said to improve convergence. We discussed this in great detail in one of our other blogs on Substack, [check it out.](https://datta0.substack.com/i/151875954/ngpt-normalized-transformer) A toy sketch of the normalised update follows below.

+  ![nGPT formulae](assets/img/blogs/transformer_showdown/ngpt.png)
+  _nGPT formulae_
+
   - Apart from more normalisations, there isn't much that would meaningfully contribute to parameters, activations or KVCache as compared to GQA.

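And the promised toy sketch of the normalised update, which is my reading of the paper's residual step rather than a faithful implementation: `alpha` is a fixed scalar here, whereas the paper learns per-dimension "eigen learning rates".

```python
# Toy nGPT-style update: hidden states are kept on the unit hypersphere, and each
# block output is blended in via a normalised interpolation.
import torch

def l2_normalize(x, eps=1e-8):
    return x / (x.norm(dim=-1, keepdim=True) + eps)

def ngpt_update(h, block_out, alpha=0.1):
    h = l2_normalize(h)                  # state lives on the unit sphere
    block_out = l2_normalize(block_out)  # block output is normalised too
    return l2_normalize(h + alpha * (block_out - h))  # step towards the output, renormalise

h = torch.randn(4, 1536)
out = ngpt_update(h, torch.randn(4, 1536))
print(out.norm(dim=-1))                  # ~1.0 for every token
```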
 So now that the introductions are out of the way, the burning question is: do the changes contribute to any meaningful differences in the final performance of the models?

 Well, the answer is nuanced. Let's see how they stack up.
+
+![Train losses on wikipedia dataset](assets/img/blogs/transformer_showdown/wiki_train_loss.png)
+_Train losses on wikipedia dataset_
+
+I started with a model that has `16` layers and a hidden size of `1536`. The MHA variant had `16` attention heads (and hence 16 key value heads), while the GQA variant had `4` key value heads. The MLP block had an intermediate size of `2048`. I used the [GPT2 tokenizer from NeelNanda](https://huggingface.co/NeelNanda/gpt-neox-tokenizer-digits), which is modified to treat numbers as individual tokens. A sketch of the corresponding config is below.
+
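For reference, the GQA variant described above corresponds to roughly the following shape. This is a sketch using `transformers`' `LlamaConfig` purely as a stand-in; the post does not say which model implementation was actually used.

```python
# Sketch of the GQA baseline's shape (the MHA variant would set num_key_value_heads=16).
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

tok = AutoTokenizer.from_pretrained("NeelNanda/gpt-neox-tokenizer-digits")
cfg = LlamaConfig(
    vocab_size=len(tok),
    num_hidden_layers=16,
    hidden_size=1536,
    intermediate_size=2048,
    num_attention_heads=16,
    num_key_value_heads=4,     # GQA: 4 KV heads shared across 16 query heads
)
model = LlamaForCausalLM(cfg)
print(sum(p.numel() for p in model.parameters()))  # total parameter count
```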
+Looks like nGPT outperforms the rest by a decent margin on a [100k sample of the wikipedia dataset](https://huggingface.co/datasets/imdatta0/wikipedia_en_sample).
+
+![Train losses on minipile dataset](assets/img/blogs/transformer_showdown/minipile_train_loss.png)
+_Train losses on minipile dataset_
+
+On the [minipile dataset](https://huggingface.co/datasets/jeankaddour/minipile), which is approximately 10x larger than the wiki data, I saw that there isn't much to choose between MLA, MHA, GQA and DiffAttention. This is great, since GQA uses 4x fewer keys and values, resulting in a 4x smaller KVCache. Surprisingly, nGPT's losses seem to go down as low as 0.2 while the others hover around 3. I repeated the experiment multiple times with multiple configs, only to find a similar loss curve. I also checked validation loss for all the models; the curves look very similar to the train loss curves, so there isn't much value in plotting those. We will have to look into why this is the case, but it definitely is fascinating.
+
+### Conclusion
+All in all, GQA offers a very good alternative to MHA, sometimes even outperforming it while also using 4-8x less space for KVCache. MLA builds on that by compressing the Keys and Values even further; it turns out this also acts as regularisation. Normalisation is the king of them all: given that normalisation is a key component in deep learning, it is no surprise that making it explicit for every operation helps. This opens up new paths for LLM training. We will explore the downstream capabilities of the models in a future write-up. Until then, Ciao.

_tabs/about.md

Lines changed: 29 additions & 2 deletions
@@ -4,5 +4,32 @@ icon: fas fa-info-circle
 order: 4
 ---

-> Add Markdown syntax content to file `_tabs/about.md`{: .filepath } and it will show up on this page.
-{: .prompt-tip }
+### Hi there 👋, I'm Datta Nimmaturi
+
+<img src="{{ '/assets/img/me.jpg' | relative_url }}" alt="Datta Nimmaturi" style="width: 360px; float: right">
+
+- 🎓 I hold a Bachelor's degree in Mathematics and Computer Science
+- 💻 Currently working at Nutanix.
+- 🤖 Have a huge interest in Maths, Deep Learning and Deep Reinforcement Learning
+- ❤️ Also love Chess, Cricket and tinkering with Operating Systems.
+- 📝 I also write another blog on Substack summarising ML research papers. [Check it out](https://datta0.substack.com) and here's a preview...
+
+<br>
+
+<div id="substack-feed-embed"></div>
+<script>
+  window.SubstackFeedWidget = {
+    substackUrl: "datta0.substack.com",
+    posts: 4,
+    filter: "top",
+    layout: "center",
+    hidden: ["subtitle", "author", "date"],
+    colors: {
+      primary: "#FFF9F9",
+      secondary: "#808080",
+      background: "#000000",
+    }
+  };
+</script>
+<script src="https://substackapi.com/embeds/feed.js" async></script>
Binary image files changed: +324 KB, +138 KB, +295 KB, +196 KB, +114 KB, +257 KB added; one binary file removed (-11.9 KB, not shown).
