diff --git a/requirements.txt b/requirements.txt
index 59f02b5..8ec7ecd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,6 +21,7 @@ streamlit-aggrid~=0.3.4
 # openai API:
 openai==1.60.1
 ffmpeg-python>=0.2.0
+joblib>=1.5.2
 
 # whisper timestamped:
 whisper-timestamped @ git+https://github.com/linto-ai/whisper-timestamped
diff --git a/src/subsai/models/whisper_api_model.py b/src/subsai/models/whisper_api_model.py
index 68a0cfd..795364f 100644
--- a/src/subsai/models/whisper_api_model.py
+++ b/src/subsai/models/whisper_api_model.py
@@ -15,6 +15,7 @@
 from openai import OpenAI
 from pysubs2 import SSAFile
 from pydub import AudioSegment
+from joblib import Parallel, delayed
 
 TMPDIR = tempfile.gettempdir()
 OPENAI_API_SIZE_LIMIT_MB = 24
@@ -81,8 +82,14 @@ class WhisperAPIModel(AbstractModel):
             'description': "The base URL for the API. Useful if you're already self hosting whisper for example.",
             'options': None,
             'default': "https://api.openai.com/v1/"
+        },
+        "n_jobs": {
+            "type": int,
+            "description": "Number of API calls to run in parallel (1 disables parallel calls)",
+            "options": None,
+            "default": 1,
         }
-    }
+    }
 
     def __init__(self, model_config):
         # config
@@ -94,10 +101,11 @@ def __init__(self, model_config):
         self.api_key = _load_config('api_key', model_config, self.config_schema)
         self.base_url = _load_config('base_url', model_config, self.config_schema)
         if not self.base_url.endswith("/"):
             self.base_url += "/"
+        self.n_jobs = _load_config("n_jobs", model_config, self.config_schema)
 
         self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
 
-    def chunk_audio(self,audio_file_path) -> list:
+    def chunk_audio(self, audio_file_path) -> list:
         # Load the audio file
         audio = AudioSegment.from_mp3(audio_file_path)
@@ -114,48 +122,87 @@ def chunk_audio(self,audio_file_path) -> list:
             # Calculate the end of the current chunk
             end_ms = current_ms + chunk_duration_ms
             # Create a chunk from the current position to the end position
-            chunk = audio[current_ms:int(end_ms)]
+            chunk = audio[current_ms : int(end_ms)]
             # Add the chunk to the list of chunks and include offset
-            chunks.append((chunk,current_ms))
+            chunks.append((chunk, current_ms))
             # Update the current position
             current_ms = end_ms
 
         return chunks
 
-
-    def transcribe(self, media_file) -> str:
+    def _transcribe_chunk(self, chunk_data):
+        """
+        Transcribe a single audio chunk using OpenAI Whisper API.
+
+        Parameters
+        ----------
+        chunk_data : tuple
+            Tuple containing (chunk_index, chunk, offset)
+
+        Returns
+        -------
+        tuple
+            Tuple containing (chunk_index, transcription_result, offset)
+        """
+        i, chunk, offset = chunk_data
+        chunk_path = os.path.join(TMPDIR, f"chunk_{i}.mp3")
+
+        try:
+            print("Transcribing audio chunk {}".format(i))
+            chunk.export(chunk_path, format="mp3")
+
+            with open(chunk_path, "rb") as audio_file:
+                # Use OpenAI Whisper API
+                result = self.client.audio.transcriptions.create(
+                    model=self.model_type,
+                    language=self.language,
+                    prompt=self.prompt,
+                    temperature=self.temperature,
+                    file=audio_file,
+                    response_format="srt",
+                )
+
+            # Clean up the temporary chunk file
+            os.remove(chunk_path)
+
+            return (i, result, offset)
+
+        except Exception as e:
+            # Clean up the temporary chunk file in case of error
+            if os.path.exists(chunk_path):
+                os.remove(chunk_path)
+            raise e
+
+    def transcribe(self, media_file: str) -> str:
         audio_file_path = convert_video_to_audio_ffmpeg(media_file)
 
         chunks = self.chunk_audio(audio_file_path)
 
-        results = ''
-
-        for i, (chunk,offset) in enumerate(chunks):
-            chunk_path = os.path.join(TMPDIR,f'chunk_{i}.mp3')
-            print('Transcribing audio chunk {}/{}'.format(i,len(chunks)))
-            chunk.export(chunk_path, format='mp3')
-            audio_file = open(chunk_path, "rb")
-
-            # Use OpenAI Whisper API
-            result = self.client.audio.transcriptions.create(
-                model=self.model_type,
-                language=self.language,
-                prompt=self.prompt,
-                temperature=self.temperature,
-                file=audio_file,
-                response_format="srt"
-            )
+        print(f"Processing {len(chunks)} audio chunks with {self.n_jobs} parallel jobs")
 
-            with open(chunk_path+'.srt','w') as f:
-                f.write(result)
+        # Prepare chunk data for parallel processing
+        chunk_data = [(i, chunk, offset) for i, (chunk, offset) in enumerate(chunks)]
 
-            # shift subtitles by offset
-            result = SSAFile.from_string(result)
+        # Use parallel processing if n_jobs > 1, otherwise process sequentially
+        if self.n_jobs > 1:
+            # Use threading backend since API calls are I/O-bound
+            parallel_results = Parallel(n_jobs=self.n_jobs, backend="threading")(
+                delayed(self._transcribe_chunk)(data) for data in chunk_data
+            )
+        else:
+            # Sequential processing for n_jobs=1
+            parallel_results = [self._transcribe_chunk(data) for data in chunk_data]
+
+        # Sort results by chunk index to maintain order
+        parallel_results.sort(key=lambda x: x[0])
+
+        # Process results and apply time offsets
+        results = ""
+        for i, result_text, offset in parallel_results:
+            # Shift subtitles by offset
+            result = SSAFile.from_string(result_text)
             result.shift(ms=offset)
-            results += result.to_string('srt')
-
-        results = ''.join(results)
+            results += result.to_string("srt")
 
         return SSAFile.from_string(results)
-
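Note: the parallelization in this patch reduces to joblib's `Parallel`/`delayed` with the threading backend; threads are sufficient because each worker spends its time waiting on the Whisper API (I/O-bound), and they avoid pickling the OpenAI client and pydub objects. A minimal standalone sketch of that pattern, with a hypothetical `fake_api_call` standing in for one chunk transcription (not part of the patch):

```python
import time
from joblib import Parallel, delayed

def fake_api_call(index: int) -> tuple:
    # Hypothetical stand-in for one Whisper API request: sleep mimics network latency.
    time.sleep(0.5)
    return (index, f"srt fragment for chunk {index}")

n_jobs = 4  # mirrors the new 'n_jobs' option; 1 would fall back to a plain loop

# Threading backend: workers share memory and release the GIL while blocked on I/O.
results = Parallel(n_jobs=n_jobs, backend="threading")(
    delayed(fake_api_call)(i) for i in range(8)
)

# Parallel already returns results in input order; the explicit sort in the patch
# just makes that ordering guarantee obvious before the SRT fragments are concatenated.
results.sort(key=lambda r: r[0])
print(results)
```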