|
3 | 3 | __author__ = "Jérôme Louradour" |
4 | 4 | __credits__ = ["Jérôme Louradour"] |
5 | 5 | __license__ = "GPLv3" |
6 | | -__version__ = "1.15.5" |
| 6 | +__version__ = "1.15.6" |
7 | 7 |
|
8 | 8 | # Set some environment variables |
9 | 9 | import os |
@@ -899,8 +899,9 @@ def hook_output_logits(layer, ins, outs): |
899 | 899 | if compute_word_confidence or no_speech_threshold is not None: |
900 | 900 | all_hooks.append(model.decoder.ln.register_forward_hook(hook_output_logits)) |
901 | 901 |
|
902 | | - with disable_sdpa(): |
903 | | - transcription = model.transcribe(audio, **whisper_options) |
| 902 | + with torch.no_grad(): |
| 903 | + with disable_sdpa(): |
| 904 | + transcription = model.transcribe(audio, **whisper_options) |
904 | 905 |
|
905 | 906 | finally: |
906 | 907 |
|
@@ -1062,8 +1063,9 @@ def hook_output_logits(layer, ins, outs): |
1062 | 1063 |
|
1063 | 1064 | try: |
1064 | 1065 | model.alignment_heads = alignment_heads # Avoid exception "AttributeError: 'WhisperUntied' object has no attribute 'alignment_heads'. Did you mean: 'set_alignment_heads'?"
1065 | | - with disable_sdpa(): |
1066 | | - transcription = model.transcribe(audio, **whisper_options) |
| 1066 | + with torch.no_grad(): |
| 1067 | + with disable_sdpa(): |
| 1068 | + transcription = model.transcribe(audio, **whisper_options) |
1067 | 1069 | finally: |
1068 | 1070 | for hook in all_hooks: |
1069 | 1071 | hook.remove() |
@@ -1238,8 +1240,9 @@ def hook(layer, ins, outs, index=j): |
1238 | 1240 | i_start = len(sot_sequence) |
1239 | 1241 |
|
1240 | 1242 | with torch.no_grad(): |
1241 | | - logprobs = model(mfcc, torch.Tensor(tokens).int().to(model.device).unsqueeze(0)) |
1242 | | - logprobs = F.log_softmax(logprobs, dim=-1) |
| 1243 | + with disable_sdpa(): |
| 1244 | + logprobs = model(mfcc, torch.Tensor(tokens).int().to(model.device).unsqueeze(0)) |
| 1245 | + logprobs = F.log_softmax(logprobs, dim=-1) |
1243 | 1246 |
|
1244 | 1247 | end_token = tokenizer.timestamp_begin + round(min(N_FRAMES * HOP_LENGTH, end_sample - start_sample) // AUDIO_SAMPLES_PER_TOKEN) |
1245 | 1248 | tokens = tokens[i_start:] + [end_token] |
|
0 commit comments