
Commit 1b6cbbe

emapco and tomaarsen authored
[fix]: correct condition for restoring layer embeddings in TransformerDecorator/AdaptiveLayerLoss (#3560)
* fix: correct condition for restoring layer embeddings in TransformerDecorator/AdaptiveLayerLoss

  When training with AdaptiveLayerLoss, the `all_layer_embeddings` are deleted erroneously when `output_hidden_states` is True.

* fix: include all layer embeddings in call_use_cache if requested

* AdaptiveLayerLoss simplification

  We can rely on the original 'features' some more, as @emapco also proposed in his initial commit

* refactor: remove unused call_idx in TransformerDecorator

* Simplify call method in TransformerDecorator even further

---------

Co-authored-by: Tom Aarsen <[email protected]>
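For reference, a minimal sketch of the setup this fix targets: an AdaptiveLayerLoss wrapping a model whose underlying Hugging Face config already requests hidden states, so the Transformer module emits `all_layer_embeddings`. The checkpoint and the inner loss below are placeholders chosen for illustration, not taken from the commit.

```python
from sentence_transformers import SentenceTransformer, losses

# Placeholder checkpoint, chosen only for illustration.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# The bug surfaced when the underlying Hugging Face config already requests
# hidden states; the Transformer module then adds "all_layer_embeddings"
# to its output, which the old code deleted.
model[0].auto_model.config.output_hidden_states = True

# AdaptiveLayerLoss wraps an inner loss and additionally trains earlier layers.
inner_loss = losses.MultipleNegativesRankingLoss(model)
loss = losses.AdaptiveLayerLoss(model, inner_loss)
```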
1 parent 722a76e commit 1b6cbbe

File tree

1 file changed (+10, -31 lines)


sentence_transformers/losses/AdaptiveLayerLoss.py

Lines changed: 10 additions & 31 deletions
@@ -29,55 +29,34 @@ class TransformerDecorator:
     def __init__(self, transformer: Transformer, original_forward) -> None:
         self.transformer = transformer
         self.original_forward = original_forward
-        self.embeddings: list[tuple[Tensor]] = []
-        self.last_embeddings: list[Tensor] = []
-        self.features: list[dict[str, Tensor]] = []
         self.layer_idx = None
-        self.call_idx = 0

     def set_layer_idx(self, layer_idx) -> None:
         self.layer_idx = layer_idx
-        self.call_idx = 0
-
-    def get_layer_embeddings(self) -> Tensor:
-        return torch.concat([embedding[self.layer_idx] for embedding in self.embeddings], dim=1)

     def __call__(self, features) -> dict[str, Tensor]:
         if self.layer_idx is None:
-            output = self.call_grow_cache(features)
-        else:
-            output = self.call_use_cache(features)
-        self.call_idx += 1
-        return output
+            return self.call_grow_cache(features)
+        return self.call_use_cache(features)

     def call_grow_cache(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
         """
         Temporarily sets the output_hidden_states to True, runs the model, and then restores the original setting.
         Use the all_layer_embeddings to get the embeddings of all layers.
         """
         original_output_hidden_states = self.transformer.auto_model.config.output_hidden_states
-        self.transformer.auto_model.config.output_hidden_states = True
-
-        output = self.original_forward(features)
-        # We ignore the first layer, as it is the input embeddings
-        # and the last layer, as we already computed the loss over it
-        self.num_layers = len(output["all_layer_embeddings"]) - 1
-        self.embeddings.append(output["all_layer_embeddings"][1:-1])
-        self.last_embeddings.append(output["token_embeddings"])
-        self.features.append(
-            {key: value for key, value in output.items() if key not in ["all_layer_embeddings", "token_embeddings"]}
-        )
-
-        # Restore original setting
-        self.transformer.auto_model.config.output_hidden_states = original_output_hidden_states
-
-        if original_output_hidden_states:
-            del output["all_layer_embeddings"]
+        try:
+            self.transformer.auto_model.config.output_hidden_states = True
+            output = self.original_forward(features)
+            self.num_layers = len(output["all_layer_embeddings"])
+        finally:
+            # Restore original setting
+            self.transformer.auto_model.config.output_hidden_states = original_output_hidden_states

         return output

     def call_use_cache(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
-        return {**self.features[self.call_idx], "token_embeddings": self.embeddings[self.call_idx][self.layer_idx]}
+        return {**features, "token_embeddings": features["all_layer_embeddings"][self.layer_idx]}


 class ForwardDecorator:
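As a side note, the new `call_grow_cache` wraps the forward pass in `try`/`finally` so the config flag is restored even if the forward pass raises. A minimal standalone sketch of that pattern, with a placeholder checkpoint chosen only for illustration:

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Placeholder checkpoint, chosen only for illustration.
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

original = model.config.output_hidden_states
try:
    # Temporarily request hidden states from every layer.
    model.config.output_hidden_states = True
    with torch.no_grad():
        output = model(**tokenizer("example sentence", return_tensors="pt"))
    # hidden_states holds the input embeddings plus one tensor per layer.
    num_layers = len(output.hidden_states)
finally:
    # Restore the original setting even if the forward pass raised.
    model.config.output_hidden_states = original
```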
