Skip to content

Commit ee35e7c

Browse files
committed
Fixes #212: workaround that disables SDPA attention in the latest version of openai-whisper (20240930), which prevents access to the attention weights
1 parent b6f035f commit ee35e7c

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

whisper_timestamped/transcribe.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
__author__ = "Jérôme Louradour"
44
__credits__ = ["Jérôme Louradour"]
55
__license__ = "GPLv3"
6-
__version__ = "1.15.4"
6+
__version__ = "1.15.5"
77

88
# Set some environment variables
99
import os
@@ -46,6 +46,20 @@
4646
AUDIO_TIME_PER_TOKEN = AUDIO_SAMPLES_PER_TOKEN / SAMPLE_RATE # 0.02 (sec)
4747
SEGMENT_DURATION = N_FRAMES * HOP_LENGTH / SAMPLE_RATE # 30.0 (sec)
4848

49+
# Recent openai-whisper releases (20240930 and later) use SDPA attention, which
# prevents access to the attention weights. Those releases provide
# ``whisper.model.disable_sdpa`` to turn it off, so every ``model.transcribe``
# call in this module is wrapped in ``with disable_sdpa():``.
# Probe for the helper itself instead of comparing version strings: a string
# comparison is lexicographic and misjudges builds whose ``__version__`` is not
# a plain YYYYMMDD date.
try:
    from whisper.model import disable_sdpa
except ImportError:
    from contextlib import contextmanager

    @contextmanager
    def disable_sdpa():
        """No-op stand-in used when whisper does not provide ``disable_sdpa``.

        Lets callers write ``with disable_sdpa():`` unconditionally; exceptions
        raised inside the ``with`` body propagate unchanged.
        """
        yield
62+
4963
# Logs
5064
import logging
5165
logger = logging.getLogger("whisper_timestamped")
@@ -885,7 +899,8 @@ def hook_output_logits(layer, ins, outs):
885899
if compute_word_confidence or no_speech_threshold is not None:
886900
all_hooks.append(model.decoder.ln.register_forward_hook(hook_output_logits))
887901

888-
transcription = model.transcribe(audio, **whisper_options)
902+
with disable_sdpa():
903+
transcription = model.transcribe(audio, **whisper_options)
889904

890905
finally:
891906

@@ -1047,7 +1062,8 @@ def hook_output_logits(layer, ins, outs):
10471062

10481063
try:
10491064
model.alignment_heads = alignment_heads # Avoid exception "AttributeError: 'WhisperUntied' object has no attribute 'alignment_heads'. Did you mean: 'set_alignment_heads'?""
1050-
transcription = model.transcribe(audio, **whisper_options)
1065+
with disable_sdpa():
1066+
transcription = model.transcribe(audio, **whisper_options)
10511067
finally:
10521068
for hook in all_hooks:
10531069
hook.remove()

0 commit comments

Comments
 (0)