Commit ef00dc2

Merge pull request #180 from thiswillbeyourgithub/parallel-api-calls
parallel api calls
2 parents: fdae104 + a9109d9

2 files changed: +79 / -31 lines

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ streamlit-aggrid~=0.3.4
 # openai API:
 openai==1.60.1
 ffmpeg-python>=0.2.0
+joblib>=1.5.2

 # whisper timestamped:
 whisper-timestamped @ git+https://github.com/linto-ai/whisper-timestamped
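
The new joblib dependency provides the Parallel and delayed helpers used in the model file below. A minimal sketch of the pattern, with a made-up slow_call standing in for an I/O-bound API request:

import time
from joblib import Parallel, delayed

def slow_call(i):
    # stand-in for an I/O-bound API request
    time.sleep(0.1)
    return i * 2

# The threading backend suits I/O-bound work: threads spend their time
# waiting on the network, so up to n_jobs requests are in flight at once.
results = Parallel(n_jobs=4, backend="threading")(
    delayed(slow_call)(i) for i in range(8)
)
print(results)  # [0, 2, 4, ..., 14] -- results come back in input order

Note that Parallel already returns results in the order the tasks were submitted, so the index sort in transcribe() below is defensive rather than strictly required.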

src/subsai/models/whisper_api_model.py

Lines changed: 78 additions & 31 deletions
@@ -15,6 +15,7 @@
 from openai import OpenAI
 from pysubs2 import SSAFile
 from pydub import AudioSegment
+from joblib import Parallel, delayed

 TMPDIR = tempfile.gettempdir()
 OPENAI_API_SIZE_LIMIT_MB = 24
@@ -81,8 +82,14 @@ class WhisperAPIModel(AbstractModel):
             'description': "The base URL for the API. Useful if you're already self hosting whisper for example.",
             'options': None,
             'default': "https://api.openai.com/v1/"
+        },
+        "n_jobs": {
+            "type": int,
+            "description": "Number of calls to do in parallel (1 to not use parallel call)",
+            "options": None,
+            "default": 1,
         }
-        }
+    }

     def __init__(self, model_config):
         # config
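
With the new schema entry, callers opt into parallel transcription through the model config. A hypothetical end-to-end usage, assuming subsai's usual create_model/transcribe flow, the 'API/openai/whisper' model name, and an 'api_key' config key (key and file names are placeholders):

from subsai import SubsAI

subs_ai = SubsAI()
model = subs_ai.create_model('API/openai/whisper', {
    'api_key': 'sk-...',  # placeholder key
    'n_jobs': 4,          # transcribe up to 4 chunks concurrently
})
subs = subs_ai.transcribe('media.mp4', model)  # placeholder media file
subs.save('media.srt')

The default of 1 keeps the previous sequential behaviour.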
@@ -94,10 +101,11 @@ def __init__(self, model_config):
         self.base_url = _load_config('base_url', model_config, self.config_schema)
         if not self.base_url.endswith("/"):
             self.base_url += "/"
+        self.n_jobs = _load_config("n_jobs", model_config, self.config_schema)

         self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

-    def chunk_audio(self,audio_file_path) -> list:
+    def chunk_audio(self, audio_file_path) -> list:
         # Load the audio file
         audio = AudioSegment.from_mp3(audio_file_path)

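For context, the slicing in the next hunk uses pydub's millisecond indexing: audio[a:b] returns the sub-segment between a ms and b ms. A tiny self-contained illustration (the generated tone is just to avoid needing an input file):

from pydub.generators import Sine

# three seconds of a 440 Hz tone, so the example needs no media file
audio = Sine(440).to_audio_segment(duration=3_000)

chunk = audio[1_000:2_000]  # pydub slices by milliseconds
print(len(chunk))           # 1000 -- len() is also in milliseconds
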
@@ -114,48 +122,87 @@ def chunk_audio(self,audio_file_path) -> list:
             # Calculate the end of the current chunk
             end_ms = current_ms + chunk_duration_ms
             # Create a chunk from the current position to the end position
-            chunk = audio[current_ms:int(end_ms)]
+            chunk = audio[current_ms : int(end_ms)]
             # Add the chunk to the list of chunks and include offset
-            chunks.append((chunk,current_ms))
+            chunks.append((chunk, current_ms))
             # Update the current position
             current_ms = end_ms

         return chunks
-

-    def transcribe(self, media_file) -> str:
+    def _transcribe_chunk(self, chunk_data):
+        """
+        Transcribe a single audio chunk using OpenAI Whisper API.
+
+        Parameters
+        ----------
+        chunk_data : tuple
+            Tuple containing (chunk_index, chunk, offset)
+
+        Returns
+        -------
+        tuple
+            Tuple containing (chunk_index, transcription_result, offset)
+        """
+        i, chunk, offset = chunk_data
+        chunk_path = os.path.join(TMPDIR, f"chunk_{i}.mp3")
+
+        try:
+            print("Transcribing audio chunk {}".format(i))
+            chunk.export(chunk_path, format="mp3")
+
+            with open(chunk_path, "rb") as audio_file:
+                # Use OpenAI Whisper API
+                result = self.client.audio.transcriptions.create(
+                    model=self.model_type,
+                    language=self.language,
+                    prompt=self.prompt,
+                    temperature=self.temperature,
+                    file=audio_file,
+                    response_format="srt",
+                )
+
+            # Clean up the temporary chunk file
+            os.remove(chunk_path)
+
+            return (i, result, offset)
+
+        except Exception as e:
+            # Clean up the temporary chunk file in case of error
+            if os.path.exists(chunk_path):
+                os.remove(chunk_path)
+            raise e
+
+    def transcribe(self, media_file: str) -> str:

         audio_file_path = convert_video_to_audio_ffmpeg(media_file)

         chunks = self.chunk_audio(audio_file_path)

-        results = ''
-
-        for i, (chunk,offset) in enumerate(chunks):
-            chunk_path = os.path.join(TMPDIR,f'chunk_{i}.mp3')
-            print('Transcribing audio chunk {}/{}'.format(i,len(chunks)))
-            chunk.export(chunk_path, format='mp3')
-            audio_file = open(chunk_path, "rb")
-
-            # Use OpenAI Whisper API
-            result = self.client.audio.transcriptions.create(
-                model=self.model_type,
-                language=self.language,
-                prompt=self.prompt,
-                temperature=self.temperature,
-                file=audio_file,
-                response_format="srt"
-            )
+        print(f"Processing {len(chunks)} audio chunks with {self.n_jobs} parallel jobs")

-            with open(chunk_path+'.srt','w') as f:
-                f.write(result)
+        # Prepare chunk data for parallel processing
+        chunk_data = [(i, chunk, offset) for i, (chunk, offset) in enumerate(chunks)]

-            # shift subtitles by offset
-            result = SSAFile.from_string(result)
+        # Use parallel processing if n_jobs > 1, otherwise process sequentially
+        if self.n_jobs > 1:
+            # Use threading backend since API calls are I/O-bound
+            parallel_results = Parallel(n_jobs=self.n_jobs, backend="threading")(
+                delayed(self._transcribe_chunk)(data) for data in chunk_data
+            )
+        else:
+            # Sequential processing for n_jobs=1
+            parallel_results = [self._transcribe_chunk(data) for data in chunk_data]
+
+        # Sort results by chunk index to maintain order
+        parallel_results.sort(key=lambda x: x[0])
+
+        # Process results and apply time offsets
+        results = ""
+        for i, result_text, offset in parallel_results:
+            # Shift subtitles by offset
+            result = SSAFile.from_string(result_text)
             result.shift(ms=offset)
-            results += result.to_string('srt')
-
-        results = ''.join(results)
+            results += result.to_string("srt")

         return SSAFile.from_string(results)
-
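
The merge step leans on pysubs2: each chunk comes back as an SRT string, its timestamps get shifted by the chunk's offset into the source audio, and the shifted pieces are concatenated and re-parsed. A standalone sketch of the shift (the SRT snippet is made up):

from pysubs2 import SSAFile

srt_chunk = """1
00:00:01,000 --> 00:00:02,500
hello from a chunk
"""

subs = SSAFile.from_string(srt_chunk)
subs.shift(ms=60_000)  # as if this chunk started 60 s into the audio
print(subs.to_string("srt"))
# the cue now reads 00:01:01,000 --> 00:01:02,500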
