 from openai import OpenAI
 from pysubs2 import SSAFile
 from pydub import AudioSegment
+from joblib import Parallel, delayed
 
 TMPDIR = tempfile.gettempdir()
 OPENAI_API_SIZE_LIMIT_MB = 24
@@ -81,8 +82,14 @@ class WhisperAPIModel(AbstractModel):
             'description': "The base URL for the API. Useful if you're already self hosting whisper for example.",
             'options': None,
             'default': "https://api.openai.com/v1/"
+        },
+        "n_jobs": {
+            "type": int,
+            "description": "Number of API calls to make in parallel (1 disables parallel calls)",
+            "options": None,
+            "default": 1,
         }
-    }
+    }
 
     def __init__(self, model_config):
         # config
@@ -94,10 +101,11 @@ def __init__(self, model_config):
         self.base_url = _load_config('base_url', model_config, self.config_schema)
         if not self.base_url.endswith("/"):
             self.base_url += "/"
+        self.n_jobs = _load_config("n_jobs", model_config, self.config_schema)
 
         self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
 
-    def chunk_audio(self,audio_file_path) -> list:
+    def chunk_audio(self, audio_file_path) -> list:
         # Load the audio file
         audio = AudioSegment.from_mp3(audio_file_path)
 
@@ -114,48 +122,87 @@ def chunk_audio(self,audio_file_path) -> list:
             # Calculate the end of the current chunk
             end_ms = current_ms + chunk_duration_ms
             # Create a chunk from the current position to the end position
-            chunk = audio[current_ms:int(end_ms)]
+            chunk = audio[current_ms : int(end_ms)]
             # Add the chunk to the list of chunks and include offset
-            chunks.append((chunk,current_ms))
+            chunks.append((chunk, current_ms))
             # Update the current position
             current_ms = end_ms
 
         return chunks
-
+
+    def _transcribe_chunk(self, chunk_data):
+        """
+        Transcribe a single audio chunk using the OpenAI Whisper API.
+
+        Parameters
+        ----------
+        chunk_data : tuple
+            Tuple containing (chunk_index, chunk, offset)
+
+        Returns
+        -------
+        tuple
+            Tuple containing (chunk_index, transcription_result, offset)
+        """
+        i, chunk, offset = chunk_data
+        chunk_path = os.path.join(TMPDIR, f"chunk_{i}.mp3")
+
+        try:
+            print("Transcribing audio chunk {}".format(i))
+            chunk.export(chunk_path, format="mp3")
+
+            with open(chunk_path, "rb") as audio_file:
+                # Use OpenAI Whisper API
+                result = self.client.audio.transcriptions.create(
+                    model=self.model_type,
+                    language=self.language,
+                    prompt=self.prompt,
+                    temperature=self.temperature,
+                    file=audio_file,
+                    response_format="srt",
+                )
+
+            # Clean up the temporary chunk file
+            os.remove(chunk_path)
+
+            return (i, result, offset)
+
+        except Exception as e:
+            # Clean up the temporary chunk file in case of error
+            if os.path.exists(chunk_path):
+                os.remove(chunk_path)
+            raise e
 
     def transcribe(self, media_file: str) -> str:
 
         audio_file_path = convert_video_to_audio_ffmpeg(media_file)
 
         chunks = self.chunk_audio(audio_file_path)
 
-        results = ''
-
-        for i, (chunk,offset) in enumerate(chunks):
-            chunk_path = os.path.join(TMPDIR,f'chunk_{i}.mp3')
-            print('Transcribing audio chunk {}/{}'.format(i,len(chunks)))
-            chunk.export(chunk_path, format='mp3')
-            audio_file = open(chunk_path, "rb")
-
-            # Use OpenAI Whisper API
-            result = self.client.audio.transcriptions.create(
-                model=self.model_type,
-                language=self.language,
-                prompt=self.prompt,
-                temperature=self.temperature,
-                file=audio_file,
-                response_format="srt"
-            )
+        print(f"Processing {len(chunks)} audio chunks with {self.n_jobs} parallel jobs")
 
-            with open(chunk_path+'.srt','w') as f:
-                f.write(result)
+        # Prepare chunk data for parallel processing
+        chunk_data = [(i, chunk, offset) for i, (chunk, offset) in enumerate(chunks)]
 
-            # shift subtitles by offset
-            result = SSAFile.from_string(result)
+        # Use parallel processing if n_jobs > 1, otherwise process sequentially
+        if self.n_jobs > 1:
+            # Use threading backend since API calls are I/O-bound
+            parallel_results = Parallel(n_jobs=self.n_jobs, backend="threading")(
+                delayed(self._transcribe_chunk)(data) for data in chunk_data
+            )
+        else:
+            # Sequential processing for n_jobs=1
+            parallel_results = [self._transcribe_chunk(data) for data in chunk_data]
+
+        # Sort results by chunk index to maintain order
+        parallel_results.sort(key=lambda x: x[0])
+
+        # Process results and apply time offsets
+        results = ""
+        for i, result_text, offset in parallel_results:
+            # Shift subtitles by offset
+            result = SSAFile.from_string(result_text)
             result.shift(ms=offset)
-            results += result.to_string('srt')
-
-        results = ''.join(results)
+            results += result.to_string("srt")
 
         return SSAFile.from_string(results)
-
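
The core of the change is the `Parallel(n_jobs=..., backend="threading")` call. Below is a minimal, self-contained sketch of that joblib pattern, with a hypothetical `fake_transcribe` stub (and made-up timings and chunk payloads) standing in for the real Whisper API call:

```python
import time
from joblib import Parallel, delayed

def fake_transcribe(chunk_data):
    # stand-in for the real API call: unpack (index, payload, offset_ms)
    i, payload, offset = chunk_data
    time.sleep(0.1)  # simulates an I/O-bound network round trip
    return (i, f"transcript of {payload}", offset)

# eight fake chunks, each carrying its index and start offset in ms
chunks = [(i, f"chunk-{i}", i * 60_000) for i in range(8)]

# threading backend: threads overlap while blocked on I/O, so up to
# n_jobs API calls are in flight at once
results = Parallel(n_jobs=4, backend="threading")(
    delayed(fake_transcribe)(c) for c in chunks
)

# joblib already returns results in submission order; sorting on the
# carried index, as the patch does, makes that ordering explicit
results.sort(key=lambda r: r[0])
print([i for i, _, _ in results])  # -> [0, 1, 2, 3, 4, 5, 6, 7]
```

The threading backend is the right choice here: the workers spend almost all their time waiting on the network, so the GIL is not a bottleneck, and threads avoid the pickling and process-startup overhead of joblib's default `loky` backend.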
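Assuming `model_config` is the plain mapping that `_load_config` reads, with unset keys falling back to the defaults in `config_schema`, enabling four concurrent API calls might look like the sketch below (key values and file names are illustrative, and any remaining schema keys are assumed to have usable defaults):

```python
# Hypothetical configuration: key names follow config_schema above,
# the values are placeholders (the real api_key is elided).
model_config = {
    "api_key": "sk-...",
    "base_url": "https://api.openai.com/v1/",
    "n_jobs": 4,  # up to 4 chunk transcriptions in flight at once
}

model = WhisperAPIModel(model_config)
subs = model.transcribe("episode.mkv")  # returns a pysubs2 SSAFile
subs.save("episode.srt")
```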