Skip to content

Commit 765a081

Browse files
committed
added resume to transcribe()
-added parameter, `resume`, to `transcribe()` (only for original models) -added `--save_unfinished`/`-su`, `--resume_input`/`-ri`, `--delete_resume`/`-dr` to CLI -updated `transcribe()` to return partially finished transcription if force stopped by `KeyboardInterrupt` (only for original models) -updated docstring README.md to reflect new parameters
1 parent 4fe6d22 commit 765a081

File tree

4 files changed

+123
-18
lines changed

4 files changed

+123
-18
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,12 @@ Docstrings:
261261
word-timestamp extraction. Specify the number of heads or `True` for default of 6 heads.
262262
To specify number of iterations for finding the optimal heads,
263263
use string with "," to separate heads and iterations (e.g. "8,3" for 8 heads and 3 iterations).
264+
clip_timestamps : str or list of float
265+
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
266+
The last end timestamp defaults to the end of the file.
267+
resume : stable_whisper.result.WhisperResult or str or dict or list
268+
Path/data of an unfinished transcription output to continue transciption from.
269+
Use "+" as suffix of the path to resume from the end of second last segment (e.g "output-UNFINISHED.json+").
264270
decode_options
265271
Keyword arguments to construct class:`whisper.decode.DecodingOptions` instances.
266272

stable_whisper/result.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,7 @@ def __init__(
938938
self.ori_dict = result.get('ori_dict') or result
939939
self.language = self.ori_dict.get('language')
940940
self._regroup_history = result.get('regroup_history', '')
941-
self._nonspeech_sections = result.get('nonspeech_sections', [])
941+
self._nonspeech_sections = result.get('nonspeech_sections') or []
942942
segments = (result.get('segments', self.ori_dict.get('segments')) or {}).copy()
943943
self.segments = [Segment(**s, ignore_unused_args=True) for s in segments] if segments else []
944944
self._forced_order = force_order
@@ -947,6 +947,7 @@ def __init__(
947947
self.raise_for_unsorted(check_sorted, show_unsorted)
948948
self.remove_no_word_segments(any(seg.has_words for seg in self.segments))
949949
self._ignore_special_periods = False
950+
self.unfinished_start: float = result.get('unfinished', -1.0)
950951

951952
def __getitem__(self, index: int) -> Segment:
952953
return self.segments[index]
@@ -1061,10 +1062,14 @@ def update_all_segs_with_words(self):
10611062
stacklevel=2)
10621063
self.reassign_ids()
10631064

1064-
def update_nonspeech_sections(self, silent_starts, silent_ends):
1065-
self._nonspeech_sections = [
1065+
def update_nonspeech_sections(self, silent_starts, silent_ends, overwrite: bool = True):
1066+
nonspeech_sections = [
10661067
dict(start=round(s, 3), end=round(e, 3)) for s, e in zip(silent_starts, silent_ends)
10671068
]
1069+
if overwrite:
1070+
self._nonspeech_sections = nonspeech_sections
1071+
else:
1072+
self._nonspeech_sections.extend(nonspeech_sections)
10681073

10691074
def add_segments(
10701075
self,
@@ -1397,7 +1402,8 @@ def to_dict(self, keep_orig: bool = True):
13971402
language=self.language,
13981403
ori_dict=ori_dict,
13991404
regroup_history=self._regroup_history,
1400-
nonspeech_sections=self._nonspeech_sections)
1405+
nonspeech_sections=self._nonspeech_sections,
1406+
unfinished=self.unfinished_start)
14011407

14021408
def segments_to_dicts(self, reverse_text: Union[bool, tuple] = False):
14031409
return [s.to_dict(reverse_text=reverse_text) for s in self.segments]

stable_whisper/whisper_word_level/cli.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,14 @@ def url_to_path(url: str):
129129
help="output filepaths(s);"
130130
"if not specified, auto-named output file(s) will be saved to "
131131
"[output_dir] or current dir if not specified.")
132+
parser.add_argument("--save_unfinished", "-su", action='store_true',
133+
help="whether to save unfinished outputs caused by KeyboardInterrupt; "
134+
"outputs are saved as JSON with suffix '-UNFINISHED.json'")
135+
parser.add_argument("--resume_input", "-ri", nargs="+", type=str,
136+
help="JSON of unfinished output filepaths(s) to continue transcription from end of last word; "
137+
"use '+' as suffix to redo the last segment (e.g 'output-UNFINISHED.json+')")
138+
parser.add_argument("--delete_resume", "-dr", action='store_true',
139+
help="whether to delete file(s) from '--resume_input'/'-ri' when transcription finishes")
132140
parser.add_argument("--model", '-m', default="base", type=str,
133141
help="name of the Whisper model to use")
134142
parser.add_argument("--model_dir", type=str, default=None,
@@ -439,10 +447,13 @@ def url_to_path(url: str):
439447
model_name: str = valid_model_name(args.pop("model"))
440448
model_dir: str = args.pop("model_dir")
441449
inputs: List[Union[str, torch.Tensor]] = args.pop("inputs")
450+
resume_files: List[str] = args.pop("resume_input")
442451
outputs: List[str] = args.pop("output")
443452
output_dir: str = args.pop("output_dir")
444453
output_format = args.pop("output_format")
445454
overwrite: bool = args.pop("overwrite")
455+
save_unfinished: bool = args.pop("save_unfinished")
456+
delete_resume: bool = args.pop("delete_resume")
446457
no_stream = use_deprecated_args('no_stream', 'mel_first', pop=True, expected_default=False)
447458
args['stream'] = None if not no_stream else False
448459
if overwrite:
@@ -468,6 +479,12 @@ def url_to_path(url: str):
468479
from .original_whisper import load_model as load_model_func
469480
model_name_kwarg = dict(name=model_name)
470481
else:
482+
if save_unfinished:
483+
raise NotImplementedError('--save_unfinished is only supported on vanilla Whisper models.')
484+
485+
if resume_files:
486+
raise NotImplementedError('--resume_input is currently only supported on vanilla Whisper models.')
487+
471488
if is_faster_whisper:
472489
model_type_name = 'Faster-Whisper'
473490
from .faster_whisper import load_faster_whisper as load_model_func
@@ -616,6 +633,10 @@ def finalize_outputs(input_file: str, _output: str = None, _alignment: str = Non
616633
if args['vad'] and args['vad_onnx']:
617634
args['vad'] = dict(onnx=args['vad_onnx'])
618635

636+
if resume_files and len(inputs) != len(resume_files):
637+
raise ValueError(f'--resume_input and inputs do not match in count. '
638+
f'Got {len(resume_files)} and {len(inputs)}')
639+
619640
if debug:
620641
print('Input(s) -> Outputs(s)')
621642
for i, (input_audio, output_paths, alignment) in enumerate(zip(inputs, final_outputs, alignments)):
@@ -627,7 +648,11 @@ def finalize_outputs(input_file: str, _output: str = None, _alignment: str = Non
627648
alignment = f' + text="{alignment}"'
628649
else:
629650
alignment = f' + "{alignment}"'
630-
print(f'"{input_audio}"{alignment} ->{dm_output} {output_paths}')
651+
if resume_files:
652+
resume_info = f' + "{resume_files[i]}"'
653+
else:
654+
resume_info = ''
655+
print(f'"{input_audio}"{resume_info}{alignment} ->{dm_output} {output_paths}')
631656
print('')
632657

633658
if show_curr_task:
@@ -679,6 +704,8 @@ def _load_model():
679704
model = _load_model()
680705
args['regroup'] = False
681706
args['audio'] = input_audio
707+
if resume_files:
708+
args['resume'] = resume_files[i]
682709
if denoiser_outputs:
683710
args['denoiser_options']['save_path'] = denoiser_outputs[i]
684711
transcribe_method = args.get('transcribe_method')
@@ -740,6 +767,13 @@ def _load_model():
740767
update_options_with_args('save_option', save_options)
741768
call_method_with_options(save_method, save_options)
742769

770+
if result.unfinished_start != -1:
771+
result.save_as_json(splitext(output_paths[0])[0] + '-UNFINISHED.json')
772+
break
773+
elif delete_resume and 'resume' in args and os.path.isfile(args['resume']):
774+
os.remove(args['resume'])
775+
print(f'Removed: {os.path.abspath(args["resume"])}')
776+
743777

744778
def cli(cmd: str = None):
745779
cache = {}

stable_whisper/whisper_word_level/original_whisper.py

Lines changed: 72 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ..decode import decode_stable
1717
from ..stabilization import NonSpeechPredictor
1818
from ..timing import add_word_timestamps_stable
19-
from ..utils import safe_print, isolate_useful_options, update_options, exact_div
19+
from ..utils import safe_print, isolate_useful_options, update_options, exact_div, format_timestamp
2020
from ..whisper_compatibility import warn_compatibility_issues, get_tokenizer
2121
from ..default import get_min_word_dur, get_prepend_punctuations, get_append_punctuations
2222

@@ -73,6 +73,7 @@ def transcribe_stable(
7373
extra_models: Optional[List["Whisper"]] = None,
7474
dynamic_heads: Optional[Union[bool, int, str]] = None,
7575
clip_timestamps: Optional[Union[str, List[float]]] = None,
76+
resume: Union[WhisperResult, str, dict, list] = None,
7677
**decode_options) \
7778
-> WhisperResult:
7879
"""
@@ -201,6 +202,9 @@ def transcribe_stable(
201202
clip_timestamps : str or list of float
202203
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
203204
The last end timestamp defaults to the end of the file.
205+
resume : stable_whisper.result.WhisperResult or str or dict or list
206+
Path/data of an unfinished transcription output to continue transciption from.
207+
Use "+" as suffix of the path to resume from the end of second last segment (e.g "output-UNFINISHED.json+").
204208
decode_options
205209
Keyword arguments to construct class:`whisper.decode.DecodingOptions` instances.
206210
@@ -436,6 +440,28 @@ def new_segment(
436440

437441
with tqdm(total=initial_duration, unit='sec', disable=verbose is not False, desc=task.title()) as tqdm_pbar:
438442

443+
if resume is not None:
444+
remove_last_seg = False
445+
if not isinstance(resume, WhisperResult):
446+
if isinstance(resume, str) and resume.endswith('+'):
447+
resume = resume[:-1]
448+
remove_last_seg = True
449+
resume = WhisperResult(resume)
450+
if resume and remove_last_seg:
451+
del resume[-1]
452+
resume.unfinished_start = -1.0
453+
if resume.unfinished_start == -1.0:
454+
resume_start = resume[-1].end if resume else 0.0
455+
else:
456+
resume_start = resume.unfinished_start
457+
seek_sample = round(resume_start * SAMPLE_RATE)
458+
tqdm_pbar.write(f'Resuming from {format_timestamp(resume_start)}')
459+
decode_options["language"] = resume.language
460+
461+
interrupted_time = -1.0
462+
segment_samples: int = 0
463+
mel_segment: torch.Tensor = torch.zeros(0)
464+
439465
def update_pbar(curr_total_duration=None):
440466
nonlocal audio_features
441467
audio_features = None
@@ -460,10 +486,11 @@ def fast_forward():
460486
update_seek()
461487
update_pbar()
462488

463-
while True:
489+
def inner_transcribe():
490+
nonlocal seek_sample, segment_samples, prompt_reset_since, mel_segment
464491
audio_segment, new_seek = audio.next_valid_chunk(seek_sample, N_SAMPLES)
465492
if audio_segment is None:
466-
break
493+
return 1
467494
if new_seek != seek_sample:
468495
seek_sample = new_seek
469496
update_pbar()
@@ -478,7 +505,7 @@ def fast_forward():
478505

479506
if is_silent_segment:
480507
fast_forward()
481-
continue
508+
return
482509

483510
if nonspeech_skip and silence_preds['timings'] is not None:
484511
silence_starts = silence_preds['timings'][0] - time_offset
@@ -490,7 +517,7 @@ def fast_forward():
490517
if silence_starts[skip_idx] < min_word_dur or int(silence_starts[skip_idx] * SAMPLE_RATE) == 0:
491518
segment_samples = round(silence_ends[skip_idx] * SAMPLE_RATE)
492519
fast_forward()
493-
continue
520+
return
494521
audio_segment = audio_segment[..., :int(silence_starts[skip_idx] * SAMPLE_RATE)]
495522
segment_samples = audio_segment.shape[-1]
496523
segment_duration = segment_samples / SAMPLE_RATE
@@ -513,7 +540,7 @@ def fast_forward():
513540

514541
if should_skip:
515542
fast_forward()
516-
continue
543+
return
517544

518545
current_segments = []
519546

@@ -644,7 +671,7 @@ def fast_forward():
644671

645672
if len(current_segments) == 0:
646673
fast_forward()
647-
continue
674+
return
648675

649676
if segment_silence_timing is not None:
650677
for seg_i, segment in enumerate(current_segments):
@@ -677,8 +704,21 @@ def fast_forward():
677704

678705
fast_forward()
679706

707+
while True:
708+
try:
709+
if inner_transcribe() is not None:
710+
break
711+
except KeyboardInterrupt:
712+
if all_segments:
713+
interrupted_time = all_segments[-1]['end']
714+
curr_seek_time = seek_sample / SAMPLE_RATE
715+
if curr_seek_time > interrupted_time:
716+
interrupted_time = curr_seek_time
717+
tqdm_pbar.write(f'Interrupted at {format_timestamp(seek_sample / SAMPLE_RATE)}')
718+
break
719+
680720
# final update
681-
update_pbar(seek_sample / SAMPLE_RATE)
721+
update_pbar((seek_sample / SAMPLE_RATE) if interrupted_time == -1 else None)
682722

683723
if model.device != torch.device('cpu'):
684724
torch.cuda.empty_cache()
@@ -696,18 +736,37 @@ def fast_forward():
696736
),
697737
force_order=not word_timestamps
698738
)
699-
if word_timestamps and regroup:
700-
final_result.regroup(regroup)
701739

702740
if time_scale is not None:
703741
final_result.rescale_time(1 / time_scale)
704742

743+
final_nonspeech_timings = nonspeech_predictor.nonspeech_timings if suppress_silence else None
744+
745+
if resume is not None:
746+
if resume:
747+
if final_result:
748+
resume.fill_in_gaps(final_result, verbose=False)
749+
if final_nonspeech_timings:
750+
resume.update_nonspeech_sections(*final_nonspeech_timings, overwrite=False)
751+
final_result = resume
752+
else:
753+
ns_starts = [sect['start'] for sect in resume.nonspeech_sections]
754+
ns_ends = [sect['end'] for sect in resume.nonspeech_sections]
755+
if final_nonspeech_timings:
756+
ns_starts.extend(final_nonspeech_timings[0])
757+
ns_ends.extend(final_nonspeech_timings[1])
758+
final_result.update_nonspeech_sections(ns_starts, ns_ends, overwrite=True)
759+
elif final_nonspeech_timings:
760+
final_result.update_nonspeech_sections(*final_nonspeech_timings, overwrite=True)
761+
762+
if word_timestamps and regroup:
763+
final_result.regroup(regroup)
764+
765+
final_result.unfinished_start = interrupted_time
766+
705767
if len(final_result.text) == 0:
706768
warnings.warn(f'Failed to {task} audio. Result contains no text. ')
707769

708-
if suppress_silence and (final_nonspeech_timings := nonspeech_predictor.nonspeech_timings):
709-
final_result.update_nonspeech_sections(*final_nonspeech_timings)
710-
711770
return final_result
712771

713772

0 commit comments

Comments
 (0)