1616from ..decode import decode_stable
1717from ..stabilization import NonSpeechPredictor
1818from ..timing import add_word_timestamps_stable
19- from ..utils import safe_print , isolate_useful_options , update_options , exact_div
19+ from ..utils import safe_print , isolate_useful_options , update_options , exact_div , format_timestamp
2020from ..whisper_compatibility import warn_compatibility_issues , get_tokenizer
2121from ..default import get_min_word_dur , get_prepend_punctuations , get_append_punctuations
2222
@@ -73,6 +73,7 @@ def transcribe_stable(
7373 extra_models : Optional [List ["Whisper" ]] = None ,
7474 dynamic_heads : Optional [Union [bool , int , str ]] = None ,
7575 clip_timestamps : Optional [Union [str , List [float ]]] = None ,
76+ resume : Union [WhisperResult , str , dict , list ] = None ,
7677 ** decode_options ) \
7778 -> WhisperResult :
7879 """
@@ -201,6 +202,9 @@ def transcribe_stable(
201202 clip_timestamps : str or list of float
202203 Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
203204 The last end timestamp defaults to the end of the file.
205+ resume : stable_whisper.result.WhisperResult or str or dict or list
206+ Path/data of an unfinished transcription output to continue transcription from.
207+ Use "+" as suffix of the path to resume from the end of the second-to-last segment (e.g. "output-UNFINISHED.json+").
204208 decode_options
205209 Keyword arguments to construct class:`whisper.decode.DecodingOptions` instances.
206210
@@ -436,6 +440,28 @@ def new_segment(
436440
437441 with tqdm (total = initial_duration , unit = 'sec' , disable = verbose is not False , desc = task .title ()) as tqdm_pbar :
438442
443+ if resume is not None :
444+ remove_last_seg = False
445+ if not isinstance (resume , WhisperResult ):
446+ if isinstance (resume , str ) and resume .endswith ('+' ):
447+ resume = resume [:- 1 ]
448+ remove_last_seg = True
449+ resume = WhisperResult (resume )
450+ if resume and remove_last_seg :
451+ del resume [- 1 ]
452+ resume .unfinished_start = - 1.0
453+ if resume .unfinished_start == - 1.0 :
454+ resume_start = resume [- 1 ].end if resume else 0.0
455+ else :
456+ resume_start = resume .unfinished_start
457+ seek_sample = round (resume_start * SAMPLE_RATE )
458+ tqdm_pbar .write (f'Resuming from { format_timestamp (resume_start )} ' )
459+ decode_options ["language" ] = resume .language
460+
461+ interrupted_time = - 1.0
462+ segment_samples : int = 0
463+ mel_segment : torch .Tensor = torch .zeros (0 )
464+
439465 def update_pbar (curr_total_duration = None ):
440466 nonlocal audio_features
441467 audio_features = None
@@ -460,10 +486,11 @@ def fast_forward():
460486 update_seek ()
461487 update_pbar ()
462488
463- while True :
489+ def inner_transcribe ():
490+ nonlocal seek_sample , segment_samples , prompt_reset_since , mel_segment
464491 audio_segment , new_seek = audio .next_valid_chunk (seek_sample , N_SAMPLES )
465492 if audio_segment is None :
466- break
493+ return 1
467494 if new_seek != seek_sample :
468495 seek_sample = new_seek
469496 update_pbar ()
@@ -478,7 +505,7 @@ def fast_forward():
478505
479506 if is_silent_segment :
480507 fast_forward ()
481- continue
508+ return
482509
483510 if nonspeech_skip and silence_preds ['timings' ] is not None :
484511 silence_starts = silence_preds ['timings' ][0 ] - time_offset
@@ -490,7 +517,7 @@ def fast_forward():
490517 if silence_starts [skip_idx ] < min_word_dur or int (silence_starts [skip_idx ] * SAMPLE_RATE ) == 0 :
491518 segment_samples = round (silence_ends [skip_idx ] * SAMPLE_RATE )
492519 fast_forward ()
493- continue
520+ return
494521 audio_segment = audio_segment [..., :int (silence_starts [skip_idx ] * SAMPLE_RATE )]
495522 segment_samples = audio_segment .shape [- 1 ]
496523 segment_duration = segment_samples / SAMPLE_RATE
@@ -513,7 +540,7 @@ def fast_forward():
513540
514541 if should_skip :
515542 fast_forward ()
516- continue
543+ return
517544
518545 current_segments = []
519546
@@ -644,7 +671,7 @@ def fast_forward():
644671
645672 if len (current_segments ) == 0 :
646673 fast_forward ()
647- continue
674+ return
648675
649676 if segment_silence_timing is not None :
650677 for seg_i , segment in enumerate (current_segments ):
@@ -677,8 +704,21 @@ def fast_forward():
677704
678705 fast_forward ()
679706
707+ while True :
708+ try :
709+ if inner_transcribe () is not None :
710+ break
711+ except KeyboardInterrupt :
712+ if all_segments :
713+ interrupted_time = all_segments [- 1 ]['end' ]
714+ curr_seek_time = seek_sample / SAMPLE_RATE
715+ if curr_seek_time > interrupted_time :
716+ interrupted_time = curr_seek_time
717+ tqdm_pbar .write (f'Interrupted at { format_timestamp (seek_sample / SAMPLE_RATE )} ' )
718+ break
719+
680720 # final update
681- update_pbar (seek_sample / SAMPLE_RATE )
721+ update_pbar (( seek_sample / SAMPLE_RATE ) if interrupted_time == - 1 else None )
682722
683723 if model .device != torch .device ('cpu' ):
684724 torch .cuda .empty_cache ()
@@ -696,18 +736,37 @@ def fast_forward():
696736 ),
697737 force_order = not word_timestamps
698738 )
699- if word_timestamps and regroup :
700- final_result .regroup (regroup )
701739
702740 if time_scale is not None :
703741 final_result .rescale_time (1 / time_scale )
704742
743+ final_nonspeech_timings = nonspeech_predictor .nonspeech_timings if suppress_silence else None
744+
745+ if resume is not None :
746+ if resume :
747+ if final_result :
748+ resume .fill_in_gaps (final_result , verbose = False )
749+ if final_nonspeech_timings :
750+ resume .update_nonspeech_sections (* final_nonspeech_timings , overwrite = False )
751+ final_result = resume
752+ else :
753+ ns_starts = [sect ['start' ] for sect in resume .nonspeech_sections ]
754+ ns_ends = [sect ['end' ] for sect in resume .nonspeech_sections ]
755+ if final_nonspeech_timings :
756+ ns_starts .extend (final_nonspeech_timings [0 ])
757+ ns_ends .extend (final_nonspeech_timings [1 ])
758+ final_result .update_nonspeech_sections (ns_starts , ns_ends , overwrite = True )
759+ elif final_nonspeech_timings :
760+ final_result .update_nonspeech_sections (* final_nonspeech_timings , overwrite = True )
761+
762+ if word_timestamps and regroup :
763+ final_result .regroup (regroup )
764+
765+ final_result .unfinished_start = interrupted_time
766+
705767 if len (final_result .text ) == 0 :
706768 warnings .warn (f'Failed to { task } audio. Result contains no text. ' )
707769
708- if suppress_silence and (final_nonspeech_timings := nonspeech_predictor .nonspeech_timings ):
709- final_result .update_nonspeech_sections (* final_nonspeech_timings )
710-
711770 return final_result
712771
713772
0 commit comments