|
3 | 3 | __author__ = "Jérôme Louradour" |
4 | 4 | __credits__ = ["Jérôme Louradour"] |
5 | 5 | __license__ = "GPLv3" |
6 | | -__version__ = "1.15.4" |
| 6 | +__version__ = "1.15.5" |
7 | 7 |
|
8 | 8 | # Set some environment variables |
9 | 9 | import os |
|
46 | 46 | AUDIO_TIME_PER_TOKEN = AUDIO_SAMPLES_PER_TOKEN / SAMPLE_RATE # 0.02 (sec) |
47 | 47 | SEGMENT_DURATION = N_FRAMES * HOP_LENGTH / SAMPLE_RATE # 30.0 (sec) |
48 | 48 |
|
| 49 | +# Access attention in latest versions... |
| 50 | +if whisper.__version__ >= "20240930": |
| 51 | + from whisper.model import disable_sdpa |
| 52 | +else: |
| 53 | + from contextlib import contextmanager |
| 54 | + |
| 55 | + # Dummy context manager that does nothing |
| 56 | + @contextmanager |
| 57 | + def disable_sdpa(): |
| 58 | + try: |
| 59 | + yield |
| 60 | + finally: |
| 61 | + pass |
| 62 | + |
49 | 63 | # Logs |
50 | 64 | import logging |
51 | 65 | logger = logging.getLogger("whisper_timestamped") |
@@ -885,7 +899,8 @@ def hook_output_logits(layer, ins, outs): |
885 | 899 | if compute_word_confidence or no_speech_threshold is not None: |
886 | 900 | all_hooks.append(model.decoder.ln.register_forward_hook(hook_output_logits)) |
887 | 901 |
|
888 | | - transcription = model.transcribe(audio, **whisper_options) |
| 902 | + with disable_sdpa(): |
| 903 | + transcription = model.transcribe(audio, **whisper_options) |
889 | 904 |
|
890 | 905 | finally: |
891 | 906 |
|
@@ -1047,7 +1062,8 @@ def hook_output_logits(layer, ins, outs): |
1047 | 1062 |
|
1048 | 1063 | try: |
1049 | 1064 | model.alignment_heads = alignment_heads # Avoid exception "AttributeError: 'WhisperUntied' object has no attribute 'alignment_heads'. Did you mean: 'set_alignment_heads'?"" |
1050 | | - transcription = model.transcribe(audio, **whisper_options) |
| 1065 | + with disable_sdpa(): |
| 1066 | + transcription = model.transcribe(audio, **whisper_options) |
1051 | 1067 | finally: |
1052 | 1068 | for hook in all_hooks: |
1053 | 1069 | hook.remove() |
|
0 commit comments