Commit e495276

Fixes #221: workaround that disables SDPA attention in the latest version of openai-whisper (20240930), which prevented attention weights from being accessed
1 parent ee35e7c commit e495276
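
For context: openai-whisper 20240930 switched its attention layers to PyTorch's scaled_dot_product_attention (SDPA), whose fused kernel never materializes the attention weight matrix, so the forward hooks that whisper-timestamped relies on to read attention weights stop seeing them. The fix below wraps inference in whisper's disable_sdpa() context manager. As a minimal sketch of what such a context manager looks like, assuming MultiHeadAttention exposes a use_sdpa class flag as in openai-whisper 20240930 (the names follow whisper.model, but treat this as an illustration rather than the shipped implementation):

from contextlib import contextmanager

from whisper.model import MultiHeadAttention


@contextmanager
def disable_sdpa():
    # Temporarily force whisper's manual attention path, which computes and
    # returns the attention matrix that the fused SDPA kernel never exposes.
    prev_state = MultiHeadAttention.use_sdpa
    try:
        MultiHeadAttention.use_sdpa = False
        yield
    finally:
        # Restore whatever was set before, even if transcription raised.
        MultiHeadAttention.use_sdpa = prev_state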

File tree: 1 file changed, +10 −7 lines changed

whisper_timestamped/transcribe.py

@@ -3,7 +3,7 @@
 __author__ = "Jérôme Louradour"
 __credits__ = ["Jérôme Louradour"]
 __license__ = "GPLv3"
-__version__ = "1.15.5"
+__version__ = "1.15.6"
 
 # Set some environment variables
 import os
@@ -899,8 +899,9 @@ def hook_output_logits(layer, ins, outs):
         if compute_word_confidence or no_speech_threshold is not None:
             all_hooks.append(model.decoder.ln.register_forward_hook(hook_output_logits))
 
-        with disable_sdpa():
-            transcription = model.transcribe(audio, **whisper_options)
+        with torch.no_grad():
+            with disable_sdpa():
+                transcription = model.transcribe(audio, **whisper_options)
 
     finally:
 
@@ -1062,8 +1063,9 @@ def hook_output_logits(layer, ins, outs):
 
     try:
         model.alignment_heads = alignment_heads # Avoid exception "AttributeError: 'WhisperUntied' object has no attribute 'alignment_heads'. Did you mean: 'set_alignment_heads'?"
-        with disable_sdpa():
-            transcription = model.transcribe(audio, **whisper_options)
+        with torch.no_grad():
+            with disable_sdpa():
+                transcription = model.transcribe(audio, **whisper_options)
     finally:
         for hook in all_hooks:
             hook.remove()
@@ -1238,8 +1240,9 @@ def hook(layer, ins, outs, index=j):
     i_start = len(sot_sequence)
 
     with torch.no_grad():
-        logprobs = model(mfcc, torch.Tensor(tokens).int().to(model.device).unsqueeze(0))
-        logprobs = F.log_softmax(logprobs, dim=-1)
+        with disable_sdpa():
+            logprobs = model(mfcc, torch.Tensor(tokens).int().to(model.device).unsqueeze(0))
+            logprobs = F.log_softmax(logprobs, dim=-1)
 
     end_token = tokenizer.timestamp_begin + round(min(N_FRAMES * HOP_LENGTH, end_sample - start_sample) // AUDIO_SAMPLES_PER_TOKEN)
     tokens = tokens[i_start:] + [end_token]
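
Why both context managers matter: the forward hooks registered on the decoder only ever see attention scores when the manual (non-SDPA) path runs, and torch.no_grad() keeps autograd from building a graph during pure inference. A hypothetical usage sketch (not code from this commit) of the pattern, assuming openai-whisper 20240930's whisper.model.disable_sdpa and that MultiHeadAttention.forward returns an (output, qk) pair where qk holds the attention scores on the manual path:

import torch
import whisper
from whisper.model import disable_sdpa  # assumed available in openai-whisper 20240930

model = whisper.load_model("tiny")
attention_scores = []

def save_cross_attention(layer, ins, outs):
    # On the manual attention path, MultiHeadAttention.forward returns
    # (output, qk); under SDPA the qk slot is None, which is what broke
    # attention-based word timestamps and motivated this workaround.
    attention_scores.append(outs[-1])

hooks = [
    block.cross_attn.register_forward_hook(save_cross_attention)
    for block in model.decoder.blocks
]
try:
    with torch.no_grad():
        with disable_sdpa():
            result = model.transcribe("audio.wav")  # hypothetical input file
finally:
    for hook in hooks:
        hook.remove()

The try/finally mirrors the diff above: hooks are always removed even if transcription fails, and the context manager restores the SDPA flag on exit.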
