1 change: 1 addition & 0 deletions requirements.txt

@@ -21,6 +21,7 @@ streamlit-aggrid~=0.3.4
 # openai API:
 openai==1.60.1
 ffmpeg-python>=0.2.0
+joblib>=1.5.2
 
 # whisper timestamped:
 whisper-timestamped @ git+https://github.com/linto-ai/whisper-timestamped
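The new joblib dependency powers the parallel chunk transcription added in whisper_api_model.py below. As background — this is the general joblib pattern, not code from the PR — Parallel with the threading backend fans I/O-bound calls out across worker threads and returns results in dispatch order:

from joblib import Parallel, delayed

def fetch(x):
    # stand-in for an I/O-bound call, e.g. one Whisper API request per chunk
    return x * x

# The threading backend suits I/O-bound work: workers spend most of their time
# waiting on the network, so the GIL is not a bottleneck and nothing is pickled.
results = Parallel(n_jobs=4, backend="threading")(delayed(fetch)(x) for x in range(8))
print(results)  # [0, 1, 4, 9, 16, 25, 36, 49]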
109 changes: 78 additions & 31 deletions src/subsai/models/whisper_api_model.py

@@ -15,6 +15,7 @@
 from openai import OpenAI
 from pysubs2 import SSAFile
 from pydub import AudioSegment
+from joblib import Parallel, delayed
 
 TMPDIR = tempfile.gettempdir()
 OPENAI_API_SIZE_LIMIT_MB = 24
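OPENAI_API_SIZE_LIMIT_MB stays at 24, presumably to keep each upload safely under OpenAI's 25 MB file limit. The collapsed body of chunk_audio derives chunk_duration_ms from this constant; the exact expression is hidden in this diff, so the following is a plausible reconstruction, purely illustrative:

import os

def estimate_chunk_duration_ms(audio_file_path: str, total_duration_ms: int) -> float:
    # Illustrative only: scale total duration by the ratio of the API size
    # limit to the file size, so each exported chunk lands under the cap.
    size_mb = os.path.getsize(audio_file_path) / (1024 * 1024)
    return total_duration_ms * min(1.0, OPENAI_API_SIZE_LIMIT_MB / size_mb)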
@@ -81,8 +82,14 @@ class WhisperAPIModel(AbstractModel):
             'description': "The base URL for the API. Useful if you're already self-hosting Whisper, for example.",
             'options': None,
             'default': "https://api.openai.com/v1/"
-        }
+        },
+        "n_jobs": {
+            "type": int,
+            "description": "Number of API calls to run in parallel (1 disables parallel calls)",
+            "options": None,
+            "default": 1,
+        }
     }
 
     def __init__(self, model_config):
         # config
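For callers, n_jobs travels through model_config like any other schema key. A minimal sketch of a config enabling four concurrent requests (the api_key value is a placeholder and the sibling keys are illustrative; only n_jobs is defined by this hunk):

model_config = {
    "api_key": "sk-...",                       # placeholder, not a real key
    "base_url": "https://api.openai.com/v1/",  # the schema default shown above
    "n_jobs": 4,                               # up to 4 chunk transcriptions at once
}
model = WhisperAPIModel(model_config)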
@@ -94,10 +101,11 @@ def __init__(self, model_config):
         self.base_url = _load_config('base_url', model_config, self.config_schema)
         if not self.base_url.endswith("/"):
             self.base_url += "/"
+        self.n_jobs = _load_config("n_jobs", model_config, self.config_schema)
 
         self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
 
-    def chunk_audio(self,audio_file_path) -> list:
+    def chunk_audio(self, audio_file_path) -> list:
         # Load the audio file
         audio = AudioSegment.from_mp3(audio_file_path)
 
@@ -114,48 +122,87 @@ def chunk_audio(self, audio_file_path) -> list:
             # Calculate the end of the current chunk
             end_ms = current_ms + chunk_duration_ms
             # Create a chunk from the current position to the end position
-            chunk = audio[current_ms:int(end_ms)]
+            chunk = audio[current_ms : int(end_ms)]
             # Add the chunk to the list of chunks and include offset
-            chunks.append((chunk,current_ms))
+            chunks.append((chunk, current_ms))
             # Update the current position
             current_ms = end_ms
 
         return chunks
 
-    def transcribe(self, media_file) -> str:
+    def _transcribe_chunk(self, chunk_data):
+        """
+        Transcribe a single audio chunk using the OpenAI Whisper API.
+
+        Parameters
+        ----------
+        chunk_data : tuple
+            Tuple containing (chunk_index, chunk, offset)
+
+        Returns
+        -------
+        tuple
+            Tuple containing (chunk_index, transcription_result, offset)
+        """
+        i, chunk, offset = chunk_data
+        chunk_path = os.path.join(TMPDIR, f"chunk_{i}.mp3")
+
+        try:
+            print("Transcribing audio chunk {}".format(i))
+            chunk.export(chunk_path, format="mp3")
+
+            with open(chunk_path, "rb") as audio_file:
+                # Use OpenAI Whisper API
+                result = self.client.audio.transcriptions.create(
+                    model=self.model_type,
+                    language=self.language,
+                    prompt=self.prompt,
+                    temperature=self.temperature,
+                    file=audio_file,
+                    response_format="srt",
+                )
+
+            # Clean up the temporary chunk file
+            os.remove(chunk_path)
+
+            return (i, result, offset)
+
+        except Exception as e:
+            # Clean up the temporary chunk file in case of error
+            if os.path.exists(chunk_path):
+                os.remove(chunk_path)
+            raise e
+
+    def transcribe(self, media_file: str) -> str:
 
         audio_file_path = convert_video_to_audio_ffmpeg(media_file)
 
         chunks = self.chunk_audio(audio_file_path)
 
-        results = ''
-
-        for i, (chunk,offset) in enumerate(chunks):
-            chunk_path = os.path.join(TMPDIR,f'chunk_{i}.mp3')
-            print('Transcribing audio chunk {}/{}'.format(i,len(chunks)))
-            chunk.export(chunk_path, format='mp3')
-            audio_file = open(chunk_path, "rb")
-
-            # Use OpenAI Whisper API
-            result = self.client.audio.transcriptions.create(
-                model=self.model_type,
-                language=self.language,
-                prompt=self.prompt,
-                temperature=self.temperature,
-                file=audio_file,
-                response_format="srt"
-            )
-
-            with open(chunk_path+'.srt','w') as f:
-                f.write(result)
-
-            # shift subtitles by offset
-            result = SSAFile.from_string(result)
-            result.shift(ms=offset)
-            results += result.to_string('srt')
-
-        results = ''.join(results)
+        print(f"Processing {len(chunks)} audio chunks with {self.n_jobs} parallel jobs")
+
+        # Prepare chunk data for parallel processing
+        chunk_data = [(i, chunk, offset) for i, (chunk, offset) in enumerate(chunks)]
+
+        # Use parallel processing if n_jobs > 1, otherwise process sequentially
+        if self.n_jobs > 1:
+            # Use threading backend since API calls are I/O-bound
+            parallel_results = Parallel(n_jobs=self.n_jobs, backend="threading")(
+                delayed(self._transcribe_chunk)(data) for data in chunk_data
+            )
+        else:
+            # Sequential processing for n_jobs=1
+            parallel_results = [self._transcribe_chunk(data) for data in chunk_data]
+
+        # Sort results by chunk index to maintain order
+        parallel_results.sort(key=lambda x: x[0])
+
+        # Process results and apply time offsets
+        results = ""
+        for i, result_text, offset in parallel_results:
+            # Shift subtitles by offset
+            result = SSAFile.from_string(result_text)
+            result.shift(ms=offset)
+            results += result.to_string("srt")
 
         return SSAFile.from_string(results)
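Put together, the new path can be exercised as below — a minimal sketch assuming a local media file, with unspecified keys falling back to their schema defaults:

model = WhisperAPIModel({"api_key": "sk-...", "n_jobs": 4})  # placeholder key
subs = model.transcribe("interview.mp4")  # pysubs2 SSAFile, offsets already applied
subs.save("interview.srt")                # format inferred from the extension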
