Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -40,30 +40,53 @@ from faster_whisper.transcribe import BatchedInferencePipeline
|
|
40 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
41 |
|
42 |
def download_audio(url, method_choice):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
parsed_url = urlparse(url)
|
44 |
logging.info(f"Downloading audio from URL: {url} using method: {method_choice}")
|
45 |
try:
|
46 |
-
if
|
|
|
47 |
audio_file = download_youtube_audio(url, method_choice)
|
48 |
else:
|
|
|
49 |
audio_file = download_direct_audio(url, method_choice)
|
50 |
if not audio_file or not os.path.exists(audio_file):
|
51 |
raise Exception(f"Failed to download audio from {url}")
|
52 |
-
return audio_file
|
53 |
except Exception as e:
|
54 |
logging.error(f"Error downloading audio: {str(e)}")
|
55 |
-
return f"Error: {str(e)}"
|
56 |
|
57 |
def download_youtube_audio(url, method_choice):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
methods = {
|
59 |
'yt-dlp': youtube_dl_method,
|
60 |
'pytube': pytube_method,
|
61 |
'youtube-dl': youtube_dl_classic_method,
|
62 |
'yt-dlp-alt': youtube_dl_alternative_method,
|
63 |
-
'ffmpeg': ffmpeg_method,
|
64 |
-
'aria2': aria2_method
|
65 |
}
|
66 |
-
method = methods.get(method_choice
|
|
|
|
|
|
|
67 |
try:
|
68 |
logging.info(f"Attempting to download YouTube audio using {method_choice}")
|
69 |
return method(url)
|
@@ -157,19 +180,31 @@ def aria2_method(url):
|
|
157 |
return output_file
|
158 |
|
159 |
def download_direct_audio(url, method_choice):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
logging.info(f"Downloading direct audio from: {url} using method: {method_choice}")
|
161 |
if method_choice == 'wget':
|
162 |
return wget_method(url)
|
163 |
else:
|
164 |
try:
|
165 |
-
response = requests.get(url)
|
166 |
if response.status_code == 200:
|
167 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
|
168 |
-
|
169 |
-
|
170 |
-
|
|
|
|
|
171 |
else:
|
172 |
-
raise Exception(f"Failed to download audio from {url}")
|
173 |
except Exception as e:
|
174 |
logging.error(f"Error downloading direct audio: {str(e)}")
|
175 |
return None
|
@@ -183,56 +218,108 @@ def wget_method(url):
|
|
183 |
return output_file
|
184 |
|
185 |
def trim_audio(audio_path, start_time, end_time):
|
186 |
-
|
187 |
-
audio
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
def save_transcription(transcription):
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
|
217 |
def get_model_options(pipeline_type):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
if pipeline_type == "faster-batched":
|
219 |
-
return ["cstr/whisper-large-v3-turbo-int8_float32", "
|
220 |
elif pipeline_type == "faster-sequenced":
|
221 |
-
return ["
|
222 |
elif pipeline_type == "transformers":
|
223 |
-
return ["openai/whisper-large-v3", "openai/whisper-large-
|
224 |
else:
|
225 |
return []
|
226 |
|
227 |
loaded_models = {}
|
228 |
|
229 |
def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time=None, end_time=None, verbose=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
try:
|
231 |
if verbose:
|
232 |
logging.getLogger().setLevel(logging.INFO)
|
233 |
else:
|
234 |
logging.getLogger().setLevel(logging.WARNING)
|
235 |
-
|
236 |
logging.info(f"Transcription parameters: pipeline_type={pipeline_type}, model_id={model_id}, dtype={dtype}, batch_size={batch_size}, download_method={download_method}")
|
237 |
verbose_messages = f"Starting transcription with parameters:\nPipeline Type: {pipeline_type}\nModel ID: {model_id}\nData Type: {dtype}\nBatch Size: {batch_size}\nDownload Method: {download_method}\n"
|
238 |
|
@@ -240,21 +327,25 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
|
|
240 |
yield verbose_messages, "", None
|
241 |
|
242 |
# Determine if input_source is a URL or file
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
elif input_source
|
253 |
-
#
|
|
|
|
|
|
|
|
|
254 |
audio_path = input_source.name
|
255 |
-
|
256 |
else:
|
257 |
-
yield "No audio source provided.", "", None
|
258 |
return
|
259 |
|
260 |
# Convert start_time and end_time to float or None
|
@@ -262,8 +353,8 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
|
|
262 |
end_time = float(end_time) if end_time else None
|
263 |
|
264 |
if start_time is not None or end_time is not None:
|
265 |
-
|
266 |
-
|
267 |
verbose_messages += f"Audio trimmed from {start_time} to {end_time}\n"
|
268 |
if verbose:
|
269 |
yield verbose_messages, "", None
|
@@ -276,10 +367,9 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
|
|
276 |
else:
|
277 |
if pipeline_type == "faster-batched":
|
278 |
model = WhisperModel(model_id, device=device, compute_type=dtype)
|
279 |
-
|
280 |
elif pipeline_type == "faster-sequenced":
|
281 |
-
|
282 |
-
pipeline = model.transcribe
|
283 |
elif pipeline_type == "transformers":
|
284 |
torch_dtype = torch.float16 if dtype == "float16" else torch.float32
|
285 |
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
@@ -287,7 +377,7 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
|
|
287 |
)
|
288 |
model.to(device)
|
289 |
processor = AutoProcessor.from_pretrained(model_id)
|
290 |
-
|
291 |
"automatic-speech-recognition",
|
292 |
model=model,
|
293 |
tokenizer=processor.tokenizer,
|
@@ -300,7 +390,7 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
|
|
300 |
)
|
301 |
else:
|
302 |
raise ValueError("Invalid pipeline type")
|
303 |
-
loaded_models[model_key] = model_or_pipeline # Cache the model
|
304 |
|
305 |
start_time_perf = time.time()
|
306 |
if pipeline_type == "faster-batched":
|
@@ -343,11 +433,9 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
|
|
343 |
|
344 |
finally:
|
345 |
# Clean up temporary files
|
346 |
-
if audio_path and os.path.exists(audio_path):
|
347 |
os.remove(audio_path)
|
348 |
-
if '
|
349 |
-
os.remove(trimmed_audio_path)
|
350 |
-
if 'transcription_file' in locals() and os.path.exists(transcription_file):
|
351 |
os.remove(transcription_file)
|
352 |
|
353 |
with gr.Blocks() as iface:
|
@@ -390,6 +478,15 @@ with gr.Blocks() as iface:
|
|
390 |
transcription_file = gr.File(label="Download Transcription")
|
391 |
|
392 |
def update_model_dropdown(pipeline_type):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
393 |
try:
|
394 |
model_choices = get_model_options(pipeline_type)
|
395 |
logging.info(f"Model choices for {pipeline_type}: {model_choices}")
|
|
|
40 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
41 |
|
42 |
def download_audio(url, method_choice):
    """
    Downloads audio from a given URL using the specified method.

    URLs whose host is youtube.com / youtu.be (or a subdomain of either) are
    handed to download_youtube_audio; every other URL is treated as a direct
    link and handed to download_direct_audio.

    Args:
        url (str): The URL of the audio.
        method_choice (str): The method to use for downloading audio.

    Returns:
        tuple: (path to the downloaded audio file, True) on success, or
            ("Error: <message>", False) on failure. Failures are reported
            through the return value rather than raised.
    """
    logging.info(f"Downloading audio from URL: {url} using method: {method_choice}")
    try:
        # Parse inside the try block so that a malformed URL is reported through
        # the same ("Error: ...", False) channel as every other failure.
        hostname = (urlparse(url).hostname or "").lower()

        # Match the host exactly or as a subdomain. A plain substring test on
        # the netloc would also accept unrelated hosts such as "notyoutube.com"
        # or "youtube.com.example.org" and would miss mixed-case hosts.
        is_youtube = any(
            hostname == domain or hostname.endswith("." + domain)
            for domain in ("youtube.com", "youtu.be")
        )

        if is_youtube:
            # Use YouTube download methods
            audio_file = download_youtube_audio(url, method_choice)
        else:
            # Use direct download methods
            audio_file = download_direct_audio(url, method_choice)

        # The helpers signal failure with a falsy result; also make sure the
        # file they reported is really on disk before handing it on.
        if not audio_file or not os.path.exists(audio_file):
            raise Exception(f"Failed to download audio from {url}")
        return audio_file, True  # The file is a temporary file
    except Exception as e:
        logging.error(f"Error downloading audio: {str(e)}")
        return f"Error: {str(e)}", False
|
68 |
|
69 |
def download_youtube_audio(url, method_choice):
|
70 |
+
"""
|
71 |
+
Downloads audio from a YouTube URL using the specified method.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
url (str): The YouTube URL.
|
75 |
+
method_choice (str): The method to use for downloading ('yt-dlp', 'pytube', 'youtube-dl').
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
str: Path to the downloaded audio file, or None if failed.
|
79 |
+
"""
|
80 |
methods = {
|
81 |
'yt-dlp': youtube_dl_method,
|
82 |
'pytube': pytube_method,
|
83 |
'youtube-dl': youtube_dl_classic_method,
|
84 |
'yt-dlp-alt': youtube_dl_alternative_method,
|
|
|
|
|
85 |
}
|
86 |
+
method = methods.get(method_choice)
|
87 |
+
if method is None:
|
88 |
+
logging.warning(f"Invalid download method for YouTube: {method_choice}. Defaulting to 'yt-dlp'.")
|
89 |
+
method = youtube_dl_method
|
90 |
try:
|
91 |
logging.info(f"Attempting to download YouTube audio using {method_choice}")
|
92 |
return method(url)
|
|
|
180 |
return output_file
|
181 |
|
182 |
def download_direct_audio(url, method_choice, timeout=30):
    """
    Downloads audio from a direct URL using the specified method.

    Args:
        url (str): The direct URL of the audio file.
        method_choice (str): The method to use for downloading. 'wget'
            delegates to wget_method; any other value streams the file
            with requests.
        timeout (float or tuple, optional): Timeout in seconds passed to
            requests.get so that an unresponsive server cannot block the
            caller forever. Defaults to 30.

    Returns:
        str: Path to the downloaded audio file, or None if failed.
    """
    logging.info(f"Downloading direct audio from: {url} using method: {method_choice}")
    if method_choice == 'wget':
        return wget_method(url)

    temp_path = None
    try:
        # The context manager releases the streamed connection on every exit
        # path, including a non-200 response or an interrupted transfer.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                raise Exception(f"Failed to download audio from {url} with status code {response.status_code}")
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
                temp_path = temp_file.name
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip empty chunks
                        temp_file.write(chunk)
        logging.info(f"Downloaded direct audio to: {temp_path}")
        return temp_path
    except Exception as e:
        logging.error(f"Error downloading direct audio: {str(e)}")
        # The temp file is created with delete=False, so a partially written
        # file would otherwise be left behind when the transfer fails midway.
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError as cleanup_error:
                logging.warning(f"Could not remove partial download {temp_path}: {cleanup_error}")
        return None
|
|
|
218 |
return output_file
|
219 |
|
220 |
def trim_audio(audio_path, start_time, end_time):
    """
    Trims an audio file to the specified start and end times.

    Args:
        audio_path (str): Path to the audio file.
        start_time (float): Start time in seconds. None means the beginning
            of the audio.
        end_time (float): End time in seconds. None, or a value beyond the
            audio duration, means the end of the audio.

    Returns:
        str: Path to the trimmed audio file (a temporary .wav file that the
            caller is responsible for deleting).

    Raises:
        gr.Error: If invalid start or end times are provided, or if loading,
            slicing or exporting the audio fails.
    """
    try:
        logging.info(f"Trimming audio from {start_time} to {end_time}")
        audio = AudioSegment.from_file(audio_path)
        audio_duration = len(audio) / 1000  # Duration in seconds

        # Default start and end times if None
        if start_time is None:
            start_time = 0
        if end_time is None or end_time > audio_duration:
            end_time = audio_duration

        # Validate times
        if start_time < 0 or end_time <= 0:
            raise gr.Error("Start time and end time must be positive.")
        # This check must run before the start/end comparison: end_time has
        # already been capped to the duration, so a start beyond the end of
        # the audio would otherwise always be reported as "End time must be
        # greater than start time" and this message could never be shown.
        if start_time > audio_duration:
            raise gr.Error("Start time exceeds audio duration.")
        if start_time >= end_time:
            raise gr.Error("End time must be greater than start time.")

        trimmed_audio = audio[start_time * 1000:end_time * 1000]
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio_file:
            trimmed_audio.export(temp_audio_file.name, format="wav")
            logging.info(f"Trimmed audio saved to: {temp_audio_file.name}")
            return temp_audio_file.name
    except gr.Error as e:
        # Validation errors already carry a user-facing message; pass them
        # through unchanged instead of re-wrapping them as a trimming failure.
        logging.error(f"Error trimming audio: {str(e)}")
        raise
    except Exception as e:
        logging.error(f"Error trimming audio: {str(e)}")
        raise gr.Error(f"Error trimming audio: {str(e)}") from e
|
262 |
|
263 |
def save_transcription(transcription):
    """
    Writes the transcription text to a new temporary .txt file.

    The file is UTF-8 encoded and is not deleted automatically; removing it
    is left to the caller.

    Args:
        transcription (str): The transcription text.

    Returns:
        str: The path to the transcription file.
    """
    out_file = tempfile.NamedTemporaryFile(
        mode='w',
        suffix='.txt',
        encoding='utf-8',
        delete=False,
    )
    with out_file:
        out_file.write(transcription)
        saved_path = out_file.name
        logging.info(f"Transcription saved to: {saved_path}")
    return saved_path
|
277 |
|
278 |
def get_model_options(pipeline_type):
    """
    Looks up the model IDs offered for a pipeline type.

    Args:
        pipeline_type (str): The type of pipeline ('faster-batched', 'faster-sequenced', 'transformers').

    Returns:
        list: A new list of model IDs; empty when the pipeline type is not recognised.
    """
    # (pipeline name, model IDs) pairs, checked in order with plain equality.
    catalogue = (
        (
            "faster-batched",
            (
                "cstr/whisper-large-v3-turbo-int8_float32",
                "SYSTRAN/faster-whisper-large-v1",
                "GalaktischeGurke/primeline-whisper-large-v3-german-ct2",
            ),
        ),
        (
            "faster-sequenced",
            (
                "SYSTRAN/faster-whisper-large-v1",
                "GalaktischeGurke/primeline-whisper-large-v3-german-ct2",
            ),
        ),
        (
            "transformers",
            (
                "openai/whisper-large-v3",
                "openai/whisper-large-v2",
            ),
        ),
    )
    for known_type, model_ids in catalogue:
        if pipeline_type == known_type:
            # Hand back a fresh list so callers may mutate it freely.
            return list(model_ids)
    return []
|
296 |
|
297 |
loaded_models = {}
|
298 |
|
299 |
def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time=None, end_time=None, verbose=False):
|
300 |
+
"""
|
301 |
+
Transcribes audio from a given source using the specified pipeline and model.
|
302 |
+
|
303 |
+
Args:
|
304 |
+
input_source (str or file): URL of audio, path to local file, or uploaded file object.
|
305 |
+
pipeline_type (str): Type of pipeline to use ('faster-batched', 'faster-sequenced', or 'transformers').
|
306 |
+
model_id (str): The ID of the model to use.
|
307 |
+
dtype (str): Data type for model computations ('int8', 'float16', or 'float32').
|
308 |
+
batch_size (int): Batch size for transcription.
|
309 |
+
download_method (str): Method to use for downloading audio.
|
310 |
+
start_time (float, optional): Start time in seconds for trimming audio.
|
311 |
+
end_time (float, optional): End time in seconds for trimming audio.
|
312 |
+
verbose (bool, optional): Whether to output verbose logging.
|
313 |
+
|
314 |
+
Yields:
|
315 |
+
Tuple[str, str, str or None]: Metrics and messages, transcription text, path to transcription file.
|
316 |
+
"""
|
317 |
try:
|
318 |
if verbose:
|
319 |
logging.getLogger().setLevel(logging.INFO)
|
320 |
else:
|
321 |
logging.getLogger().setLevel(logging.WARNING)
|
322 |
+
|
323 |
logging.info(f"Transcription parameters: pipeline_type={pipeline_type}, model_id={model_id}, dtype={dtype}, batch_size={batch_size}, download_method={download_method}")
|
324 |
verbose_messages = f"Starting transcription with parameters:\nPipeline Type: {pipeline_type}\nModel ID: {model_id}\nData Type: {dtype}\nBatch Size: {batch_size}\nDownload Method: {download_method}\n"
|
325 |
|
|
|
327 |
yield verbose_messages, "", None
|
328 |
|
329 |
# Determine if input_source is a URL or file
|
330 |
+
audio_path = None
|
331 |
+
is_temp_file = False
|
332 |
+
|
333 |
+
if isinstance(input_source, str) and (input_source.startswith('http://') or input_source.startswith('https://')):
|
334 |
+
# Input source is a URL
|
335 |
+
audio_path, is_temp_file = download_audio(input_source, download_method)
|
336 |
+
if not audio_path or audio_path.startswith("Error"):
|
337 |
+
yield f"Error downloading audio: {audio_path}", "", None
|
338 |
+
return
|
339 |
+
elif isinstance(input_source, str) and os.path.exists(input_source):
|
340 |
+
# Input source is a local file path
|
341 |
+
audio_path = input_source
|
342 |
+
is_temp_file = False
|
343 |
+
elif hasattr(input_source, 'name'):
|
344 |
+
# Input source is an uploaded file object
|
345 |
audio_path = input_source.name
|
346 |
+
is_temp_file = False
|
347 |
else:
|
348 |
+
yield "No valid audio source provided.", "", None
|
349 |
return
|
350 |
|
351 |
# Convert start_time and end_time to float or None
|
|
|
353 |
end_time = float(end_time) if end_time else None
|
354 |
|
355 |
if start_time is not None or end_time is not None:
|
356 |
+
audio_path = trim_audio(audio_path, start_time, end_time)
|
357 |
+
is_temp_file = True # The trimmed audio is a temporary file
|
358 |
verbose_messages += f"Audio trimmed from {start_time} to {end_time}\n"
|
359 |
if verbose:
|
360 |
yield verbose_messages, "", None
|
|
|
367 |
else:
|
368 |
if pipeline_type == "faster-batched":
|
369 |
model = WhisperModel(model_id, device=device, compute_type=dtype)
|
370 |
+
model_or_pipeline = BatchedInferencePipeline(model=model)
|
371 |
elif pipeline_type == "faster-sequenced":
|
372 |
+
model_or_pipeline = WhisperModel(model_id, device=device, compute_type=dtype)
|
|
|
373 |
elif pipeline_type == "transformers":
|
374 |
torch_dtype = torch.float16 if dtype == "float16" else torch.float32
|
375 |
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
|
|
377 |
)
|
378 |
model.to(device)
|
379 |
processor = AutoProcessor.from_pretrained(model_id)
|
380 |
+
model_or_pipeline = pipeline(
|
381 |
"automatic-speech-recognition",
|
382 |
model=model,
|
383 |
tokenizer=processor.tokenizer,
|
|
|
390 |
)
|
391 |
else:
|
392 |
raise ValueError("Invalid pipeline type")
|
393 |
+
loaded_models[model_key] = model_or_pipeline # Cache the model or pipeline
|
394 |
|
395 |
start_time_perf = time.time()
|
396 |
if pipeline_type == "faster-batched":
|
|
|
433 |
|
434 |
finally:
|
435 |
# Clean up temporary files
|
436 |
+
if audio_path and is_temp_file and os.path.exists(audio_path):
|
437 |
os.remove(audio_path)
|
438 |
+
if 'transcription_file' in locals() and transcription_file and os.path.exists(transcription_file):
|
|
|
|
|
439 |
os.remove(transcription_file)
|
440 |
|
441 |
with gr.Blocks() as iface:
|
|
|
478 |
transcription_file = gr.File(label="Download Transcription")
|
479 |
|
480 |
def update_model_dropdown(pipeline_type):
|
481 |
+
"""
|
482 |
+
Updates the model dropdown choices based on the selected pipeline type.
|
483 |
+
|
484 |
+
Args:
|
485 |
+
pipeline_type (str): The selected pipeline type.
|
486 |
+
|
487 |
+
Returns:
|
488 |
+
gr.update: Updated model dropdown component.
|
489 |
+
"""
|
490 |
try:
|
491 |
model_choices = get_model_options(pipeline_type)
|
492 |
logging.info(f"Model choices for {pipeline_type}: {model_choices}")
|