
Commit d884971

[loss] Add AdaptiveLayerLoss; 2d Matryoshka loss modifiers (#2506)
* Initial draft for AdaptiveLayerLoss
* Re-add KL-divergence, now on embeddings after pooling (etc.)
* Remove dead code
* Add draft Matryoshka2dLoss shorthand
* Warn about not using CMNRL with AdaptiveLayerLoss
* Write docstrings
* Implement n_dims_per_step for MatryoshkaLoss
* Add KL-temperature
* Set 0.3 as the default KL-div temperature
* Update loss as recommended by Sean
* Add weights to allow users to tune AL & 2D Matryoshka better
* Introduce docstrings for Matryoshka2dLoss, add relations
* Add docs on AdaptiveLayerLoss + results for Matryoshka
* Run formatting
* Update performance figure
* Also describe 2d Matryoshka
* Update the AL default to 1 layer per step. This matches 2DMSE + it seems to work just as well (if not better)
* Add 2d Matryoshka training scripts
* Add AL to toctree
* Fix incorrect empty line
* Add new loss modifier section
* Link from the loss functions to documentation pages
1 parent 937be8c commit d884971

File tree

13 files changed, +1126 −3 lines changed


docs/package_reference/losses.md

Lines changed: 10 additions & 0 deletions
@@ -90,6 +90,16 @@ This allows our network to be fine-tuned to recognize the similarity of sentence
 .. autoclass:: sentence_transformers.losses.MatryoshkaLoss
 ```
 
+## Matryoshka2dLoss
+```eval_rst
+.. autoclass:: sentence_transformers.losses.Matryoshka2dLoss
+```
+
+## AdaptiveLayerLoss
+```eval_rst
+.. autoclass:: sentence_transformers.losses.AdaptiveLayerLoss
+```
+
 ## MegaBatchMarginLoss
 
 ```eval_rst

docs/training/loss_overview.md

Lines changed: 11 additions & 1 deletion
@@ -15,7 +15,17 @@ Loss functions play a critical role in the performance of your fine-tuned model.
 | `(anchor, positive/negative) pairs` | `1 if positive, 0 if negative` | <a href="../package_reference/losses.html#contrastiveloss">`ContrastiveLoss`</a><br><a href="../package_reference/losses.html#onlinecontrastiveloss">`OnlineContrastiveLoss`</a> |
 | `(sentence_A, sentence_B) pairs` | `float similarity score` | <a href="../package_reference/losses.html#cosentloss">`CoSENTLoss`</a><br><a href="../package_reference/losses.html#angleloss">`AnglELoss`</a><br><a href="../package_reference/losses.html#cosinesimilarityloss">`CosineSimilarityLoss`</a> |
 | `(anchor, positive, negative) triplets` | `none` | <a href="../package_reference/losses.html#cachedmultiplenegativesrankingloss">`CachedMultipleNegativesRankingLoss`</a><br><a href="../package_reference/losses.html#multiplenegativesrankingloss">`MultipleNegativesRankingLoss`</a><br><a href="../package_reference/losses.html#tripletloss">`TripletLoss`</a> |
-| `any` | `any` | <a href="../package_reference/losses.html#matryoshkaloss">`MatryoshkaLoss`</a> |
+
+## Loss modifiers
+
+These loss functions can be seen as *loss modifiers*: they work on top of standard loss functions, but apply those loss functions in different ways to try and instil useful properties into the trained embedding model.
+
+For example, models trained with <a href="../package_reference/losses.html#matryoshkaloss">`MatryoshkaLoss`</a> produce embeddings whose size can be truncated without notable losses in performance, and models trained with <a href="../package_reference/losses.html#adaptivelayerloss">`AdaptiveLayerLoss`</a> still perform well when you remove model layers for faster inference.
+
+| Texts | Labels | Appropriate Loss Functions |
+|-------|--------|----------------------------|
+| `any` | `any` | <a href="../package_reference/losses.html#matryoshkaloss">`MatryoshkaLoss`</a><br><a href="../package_reference/losses.html#adaptivelayerloss">`AdaptiveLayerLoss`</a><br><a href="../package_reference/losses.html#matryoshka2dloss">`Matryoshka2dLoss`</a> |
+
 
 ## Distillation
 These loss functions are specifically designed to be used when distilling the knowledge from one model into another.
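All three modifiers follow the same wrapping pattern: construct a standard loss and pass it into the modifier. As a minimal illustration of that pattern (using `MatryoshkaLoss`, the one modifier not shown in the Adaptive Layers documentation further down; the base model and dimension list here are arbitrary examples, not prescribed values):

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

model = SentenceTransformer("microsoft/mpnet-base")

# A standard loss, wrapped by a loss modifier that also trains truncated embeddings
base_loss = MultipleNegativesRankingLoss(model=model)
loss = MatryoshkaLoss(model=model, loss=base_loss, matryoshka_dims=[768, 512, 256, 128, 64])
```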
examples/training/adaptive_layer/README.md

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
# Adaptive Layers

Embedding models are often encoder models with numerous layers, such as 12 (e.g. [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)) or 6 (e.g. [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)). To get embeddings, every single one of these layers must be traversed. [2D Matryoshka Sentence Embeddings](https://arxiv.org/abs/2402.14776) (2DMSE) revisits this design by proposing an approach to train embedding models that still perform well when only a subset of the layers is used. This results in faster inference speeds at relatively low performance costs.

## Use Cases

The 2DMSE paper notes that using just a few layers of a larger model trained with Adaptive Layers and Matryoshka Representation Learning can outperform a smaller model that was trained as a standard embedding model.
## Results

Let's look at the performance that we can expect from an Adaptive Layer embedding model versus a regular embedding model. For this experiment, I have trained two models:

* [tomaarsen/mpnet-base-nli-adaptive-layer](https://huggingface.co/tomaarsen/mpnet-base-nli-adaptive-layer): Trained by running [adaptive_layer_nli.py](adaptive_layer_nli.py) with [microsoft/mpnet-base](https://huggingface.co/microsoft/mpnet-base).
* [tomaarsen/mpnet-base-nli](https://huggingface.co/tomaarsen/mpnet-base-nli): A nearly identical model to the former, but trained with only `MultipleNegativesRankingLoss` rather than `AdaptiveLayerLoss` on top of `MultipleNegativesRankingLoss`. It also uses [microsoft/mpnet-base](https://huggingface.co/microsoft/mpnet-base) as the base model.

Both of these models were trained on the AllNLI dataset, which is a concatenation of the [SNLI](https://huggingface.co/datasets/snli) and [MultiNLI](https://huggingface.co/datasets/multi_nli) datasets. I have evaluated these models on the [STSBenchmark](https://huggingface.co/datasets/mteb/stsbenchmark-sts) test set using multiple different numbers of model layers. The results are plotted in the following figure:

![adaptive_layer_results](https://huggingface.co/tomaarsen/mpnet-base-nli-adaptive-layer/resolve/main/adaptive_layer_results.png)

The first figure shows that the Adaptive Layer model retains much more of its performance as the number of layers in the model is reduced. This is also clearly shown in the second figure, which displays that 80% of the performance is preserved even when the number of layers is reduced all the way to 1.

Lastly, the third figure shows the expected speedup ratio for GPU & CPU devices in my tests. As you can see, removing half of the layers results in roughly a 2x speedup, at a cost of ~15% of the STSB performance (~86 -> ~75 Spearman correlation). When removing even more layers, the speedup grows larger still on CPU, and 5x to 10x speedups are very feasible with a ~20% loss in performance.
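If you want to run a similar comparison yourself, the sketch below evaluates the adaptive-layer model on the STSBenchmark test set at several layer counts. It is only a rough reproduction under stated assumptions (the `mteb/stsbenchmark-sts` columns `sentence1`/`sentence2`/`score`, and the layer-slicing trick described in the Inference section below), not the exact script used to produce the figures above:

```python
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

# STSBenchmark test split; gold scores are 0..5, so normalize them to 0..1
stsb = load_dataset("mteb/stsbenchmark-sts", split="test")
evaluator = EmbeddingSimilarityEvaluator(
    sentences1=stsb["sentence1"],
    sentences2=stsb["sentence2"],
    scores=[score / 5.0 for score in stsb["score"]],
    name="sts-test",
)

model = SentenceTransformer("tomaarsen/mpnet-base-nli-adaptive-layer")
full_layers = model[0].auto_model.encoder.layer
for num_layers in [12, 9, 6, 3, 1]:
    # Keep only the first `num_layers` transformer layers; slicing a ModuleList
    # returns a new ModuleList, so `full_layers` itself stays intact
    model[0].auto_model.encoder.layer = full_layers[:num_layers]
    print(f"{num_layers} layers:", evaluator(model))
```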
## Training

Training with Adaptive Layer support is quite elementary: rather than applying a loss function on only the last layer, we also apply that same loss function on the pooled embeddings from previous layers. Additionally, we employ a KL-divergence loss that aims to make the embeddings of the non-last layers match those of the last layer. This can be seen as a fascinating form of [knowledge distillation](../distillation/README.html#knowledge-distillation), but with the last layer as the teacher model and the prior layers as the student models.
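To make that description concrete, here is a rough, simplified sketch of the idea in plain PyTorch. It assumes you already have pooled embeddings per layer and a base loss that operates directly on embeddings; the actual `AdaptiveLayerLoss` plugs into the Sentence Transformers loss API and differs in its details, so treat this purely as an illustration:

```python
import torch.nn.functional as F


def adaptive_layer_style_loss(per_layer_embeddings, base_loss_fn, kl_temperature=0.3):
    """per_layer_embeddings: list of (batch, dim) pooled embeddings, one entry per
    layer, with the final layer's embeddings last. base_loss_fn maps a batch of
    embeddings to a scalar loss (a simplification of the real loss interface)."""
    final = per_layer_embeddings[-1]
    loss = base_loss_fn(final)

    # Teacher distribution: softmax over the final layer's pairwise similarities.
    # Detached here so that, in this sketch, the KL term only trains earlier layers.
    teacher = F.log_softmax((final @ final.T).detach() / kl_temperature, dim=-1)
    for embeddings in per_layer_embeddings[:-1]:
        # The same base loss, applied to an earlier layer's embeddings
        loss = loss + base_loss_fn(embeddings)
        # KL term pulling the earlier layer's similarity distribution
        # towards the final layer's distribution
        student = F.log_softmax(embeddings @ embeddings.T / kl_temperature, dim=-1)
        loss = loss + F.kl_div(student, teacher, log_target=True, reduction="batchmean")
    return loss
```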
For example, with the 12-layer [microsoft/mpnet-base](https://huggingface.co/microsoft/mpnet-base), the model will now be trained such that it produces meaningful embeddings after each of the 12 layers.

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CoSENTLoss, AdaptiveLayerLoss

model = SentenceTransformer("microsoft/mpnet-base")

base_loss = CoSENTLoss(model=model)
loss = AdaptiveLayerLoss(model=model, loss=base_loss)
```
* **Reference**: <a href="../../../docs/package_reference/losses.html#adaptivelayerloss"><code>AdaptiveLayerLoss</code></a>

Note that training with `AdaptiveLayerLoss` is not notably slower than training without it.
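The commit description also mentions several tuning knobs: weights for the last-layer, prior-layer, and KL-divergence terms, a KL temperature defaulting to 0.3, and sampling one layer per step by default. The sketch below shows how these might be passed; the parameter names are my reading of the commit message, so verify them against the `AdaptiveLayerLoss` docstring before relying on them:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import AdaptiveLayerLoss, MultipleNegativesRankingLoss

model = SentenceTransformer("microsoft/mpnet-base")
base_loss = MultipleNegativesRankingLoss(model=model)
loss = AdaptiveLayerLoss(
    model=model,
    loss=base_loss,
    n_layers_per_step=1,      # how many non-final layers to sample each training step
    last_layer_weight=1.0,    # weight of the base loss on the final layer
    prior_layers_weight=1.0,  # weight of the base loss on the sampled earlier layers
    kl_div_weight=1.0,        # weight of the KL-divergence distillation term
    kl_temperature=0.3,       # temperature used for the KL-divergence term
)
```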
Additionally, this can be combined with `MatryoshkaLoss` such that the resulting model can be reduced both in the number of layers and in the size of the output dimensions. See also the [Matryoshka Embeddings](../matryoshka/README.html) documentation for more information on reducing output dimensions. In Sentence Transformers, the combination of these two losses is called `Matryoshka2dLoss`, and a shorthand is provided for simpler training.

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CoSENTLoss, Matryoshka2dLoss

model = SentenceTransformer("microsoft/mpnet-base")

base_loss = CoSENTLoss(model=model)
loss = Matryoshka2dLoss(model=model, loss=base_loss, matryoshka_dims=[768, 512, 256, 128, 64])
```
* **Reference**: <a href="../../../docs/package_reference/losses.html#matryoshka2dloss"><code>Matryoshka2dLoss</code></a>
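Neither snippet above shows the actual training step. Below is a condensed sketch of how such a loss plugs into `model.fit`, loosely following the example scripts listed under Code Examples at the end of this page, but with a couple of toy pairs in place of the AllNLI data; treat it as an illustration rather than the exact training setup:

```python
from torch.utils.data import DataLoader

from sentence_transformers import InputExample, SentenceTransformer
from sentence_transformers.losses import AdaptiveLayerLoss, MultipleNegativesRankingLoss

model = SentenceTransformer("microsoft/mpnet-base")

# Toy (anchor, positive) pairs; in practice, load e.g. the AllNLI entailment pairs here
train_examples = [
    InputExample(texts=["A person is outdoors, on a horse.", "A person is riding a horse."]),
    InputExample(texts=["Children smiling and waving at a camera.", "There are children present."]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)

train_loss = AdaptiveLayerLoss(model=model, loss=MultipleNegativesRankingLoss(model=model))

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    warmup_steps=10,
    output_path="output/mpnet-base-adaptive-layer-toy",  # hypothetical output directory
)
```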
## Inference

After a model has been trained using the Adaptive Layer loss, you can truncate the model layers to your desired layer count. Note that this requires doing a bit of surgery on the model itself, and because each model is structured a bit differently, the steps differ slightly from model to model.

First of all, we will load the model & access the underlying `transformers` model like so:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("tomaarsen/mpnet-base-nli-adaptive-layer")

# We can access the underlying model with `model[0].auto_model`
print(model[0].auto_model)
```
```
MPNetModel(
  (embeddings): MPNetEmbeddings(
    (word_embeddings): Embedding(30527, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): MPNetEncoder(
    (layer): ModuleList(
      (0-11): 12 x MPNetLayer(
        (attention): MPNetAttention(
          (attn): MPNetSelfAttention(
            (q): Linear(in_features=768, out_features=768, bias=True)
            (k): Linear(in_features=768, out_features=768, bias=True)
            (v): Linear(in_features=768, out_features=768, bias=True)
            (o): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (intermediate): MPNetIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): MPNetOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (relative_attention_bias): Embedding(32, 12)
  )
  (pooler): MPNetPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
```
This output will differ depending on the model. We will look for the repeated layers in the encoder. For this MPNet model, these are stored under `model[0].auto_model.encoder.layer`. Then we can slice the model to keep only the first few layers, speeding it up considerably:

```python
new_num_layers = 3
model[0].auto_model.encoder.layer = model[0].auto_model.encoder.layer[:new_num_layers]
```
Then we can run inference with it using <a href="../../../docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"><code>SentenceTransformer.encode</code></a>:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

model = SentenceTransformer("tomaarsen/mpnet-base-nli-adaptive-layer")
new_num_layers = 3
model[0].auto_model.encoder.layer = model[0].auto_model.encoder.layer[:new_num_layers]

embeddings = model.encode(
    [
        "The weather is so nice!",
        "It's so sunny outside!",
        "He drove to the stadium.",
    ]
)
# Similarity of the first sentence with the other two
similarities = cos_sim(embeddings[0], embeddings[1:])
# => tensor([[0.7761, 0.1655]])
# compared to tensor([[ 0.7547, -0.0162]]) for the full model
```
As you can see, the similarity between the related sentences is much higher than that with the unrelated sentence, despite only using 3 layers. Feel free to copy this script locally, modify `new_num_layers`, and observe the difference in similarities.
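For instance, you can sweep over several layer counts in one script. This is just an illustrative extension of the snippet above; since slicing a `ModuleList` returns a new `ModuleList`, keeping a reference to the full layer stack lets you reuse it between runs:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

model = SentenceTransformer("tomaarsen/mpnet-base-nli-adaptive-layer")
full_layers = model[0].auto_model.encoder.layer

sentences = [
    "The weather is so nice!",
    "It's so sunny outside!",
    "He drove to the stadium.",
]
for new_num_layers in [12, 6, 3, 1]:
    # Keep only the first `new_num_layers` layers for this run
    model[0].auto_model.encoder.layer = full_layers[:new_num_layers]
    embeddings = model.encode(sentences)
    print(new_num_layers, "layers:", cos_sim(embeddings[0], embeddings[1:]))
```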
## Code Examples

See the following scripts as examples of how to apply <a href="../../../docs/package_reference/losses.html#adaptivelayerloss"><code>AdaptiveLayerLoss</code></a> in practice:

* **[adaptive_layer_nli.py](adaptive_layer_nli.py)**: This example uses `MultipleNegativesRankingLoss` with `AdaptiveLayerLoss` to train a strong embedding model using Natural Language Inference (NLI) data. It is an adaptation of the [NLI](../nli/README) documentation.
* **[adaptive_layer_sts.py](adaptive_layer_sts.py)**: This example uses `CoSENTLoss` with `AdaptiveLayerLoss` to train an embedding model on the training set of the STSBenchmark dataset. It is an adaptation of the [STS](../sts/README) documentation.

And see the following scripts for how to apply <a href="../../../docs/package_reference/losses.html#matryoshka2dloss"><code>Matryoshka2dLoss</code></a>:

* **[2d_matryoshka_nli.py](../matryoshka/2d_matryoshka_nli.py)**: This example uses `MultipleNegativesRankingLoss` with `Matryoshka2dLoss` to train a strong embedding model using Natural Language Inference (NLI) data. It is an adaptation of the [NLI](../nli/README) documentation.
* **[2d_matryoshka_sts.py](../matryoshka/2d_matryoshka_sts.py)**: This example uses `CoSENTLoss` with `Matryoshka2dLoss` to train an embedding model on the training set of the STSBenchmark dataset. It is an adaptation of the [STS](../sts/README) documentation.
