cstr committed on
Commit
516bec5
·
verified ·
1 Parent(s): 4b50bd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -67
app.py CHANGED
@@ -40,30 +40,53 @@ from faster_whisper.transcribe import BatchedInferencePipeline
40
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
41
 
42
  def download_audio(url, method_choice):
 
 
 
 
 
 
 
 
 
 
43
  parsed_url = urlparse(url)
44
  logging.info(f"Downloading audio from URL: {url} using method: {method_choice}")
45
  try:
46
- if parsed_url.netloc in ['www.youtube.com', 'youtu.be', 'youtube.com']:
 
47
  audio_file = download_youtube_audio(url, method_choice)
48
  else:
 
49
  audio_file = download_direct_audio(url, method_choice)
50
  if not audio_file or not os.path.exists(audio_file):
51
  raise Exception(f"Failed to download audio from {url}")
52
- return audio_file
53
  except Exception as e:
54
  logging.error(f"Error downloading audio: {str(e)}")
55
- return f"Error: {str(e)}"
56
 
57
  def download_youtube_audio(url, method_choice):
 
 
 
 
 
 
 
 
 
 
58
  methods = {
59
  'yt-dlp': youtube_dl_method,
60
  'pytube': pytube_method,
61
  'youtube-dl': youtube_dl_classic_method,
62
  'yt-dlp-alt': youtube_dl_alternative_method,
63
- 'ffmpeg': ffmpeg_method,
64
- 'aria2': aria2_method
65
  }
66
- method = methods.get(method_choice, youtube_dl_method)
 
 
 
67
  try:
68
  logging.info(f"Attempting to download YouTube audio using {method_choice}")
69
  return method(url)
@@ -157,19 +180,31 @@ def aria2_method(url):
157
  return output_file
158
 
159
  def download_direct_audio(url, method_choice):
 
 
 
 
 
 
 
 
 
 
160
  logging.info(f"Downloading direct audio from: {url} using method: {method_choice}")
161
  if method_choice == 'wget':
162
  return wget_method(url)
163
  else:
164
  try:
165
- response = requests.get(url)
166
  if response.status_code == 200:
167
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
168
- temp_file.write(response.content)
169
- logging.info(f"Downloaded direct audio to: {temp_file.name}")
170
- return temp_file.name
 
 
171
  else:
172
- raise Exception(f"Failed to download audio from {url}")
173
  except Exception as e:
174
  logging.error(f"Error downloading direct audio: {str(e)}")
175
  return None
@@ -183,56 +218,108 @@ def wget_method(url):
183
  return output_file
184
 
185
  def trim_audio(audio_path, start_time, end_time):
186
- logging.info(f"Trimming audio from {start_time} to {end_time}")
187
- audio = AudioSegment.from_file(audio_path)
188
- audio_duration = len(audio) / 1000 # Duration in seconds
189
-
190
- # Default start and end times if None
191
- if start_time is None:
192
- start_time = 0
193
- if end_time is None or end_time > audio_duration:
194
- end_time = audio_duration
195
-
196
- # Validate times
197
- if start_time < 0 or end_time < 0:
198
- raise gr.Error("Start time and end time must be non-negative.")
199
- if start_time >= end_time:
200
- raise gr.Error("End time must be greater than start time.")
201
- if start_time > audio_duration:
202
- raise gr.Error("Start time exceeds audio duration.")
203
-
204
- trimmed_audio = audio[start_time * 1000:end_time * 1000]
205
- trimmed_audio_path = tempfile.mktemp(suffix='.wav')
206
- trimmed_audio.export(trimmed_audio_path, format="wav")
207
- logging.info(f"Trimmed audio saved to: {trimmed_audio_path}")
208
- return trimmed_audio_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  def save_transcription(transcription):
211
- file_path = tempfile.mktemp(suffix='.txt')
212
- with open(file_path, 'w') as f:
213
- f.write(transcription)
214
- logging.info(f"Transcription saved to: {file_path}")
215
- return file_path
 
 
 
 
 
 
 
 
216
 
217
  def get_model_options(pipeline_type):
 
 
 
 
 
 
 
 
 
218
  if pipeline_type == "faster-batched":
219
- return ["cstr/whisper-large-v3-turbo-int8_float32", "deepdml/faster-whisper-large-v3-turbo-ct2", "Systran/faster-whisper-large-v3", "GalaktischeGurke/primeline-whisper-large-v3-german-ct2"]
220
  elif pipeline_type == "faster-sequenced":
221
- return ["cstr/whisper-large-v3-turbo-int8_float32", "deepdml/faster-whisper-large-v3-turbo-ct2", "Systran/faster-whisper-large-v3", "GalaktischeGurke/primeline-whisper-large-v3-german-ct2"]
222
  elif pipeline_type == "transformers":
223
- return ["openai/whisper-large-v3", "openai/whisper-large-v3-turbo", "primeline/whisper-large-v3-german"]
224
  else:
225
  return []
226
 
227
  loaded_models = {}
228
 
229
  def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time=None, end_time=None, verbose=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  try:
231
  if verbose:
232
  logging.getLogger().setLevel(logging.INFO)
233
  else:
234
  logging.getLogger().setLevel(logging.WARNING)
235
-
236
  logging.info(f"Transcription parameters: pipeline_type={pipeline_type}, model_id={model_id}, dtype={dtype}, batch_size={batch_size}, download_method={download_method}")
237
  verbose_messages = f"Starting transcription with parameters:\nPipeline Type: {pipeline_type}\nModel ID: {model_id}\nData Type: {dtype}\nBatch Size: {batch_size}\nDownload Method: {download_method}\n"
238
 
@@ -240,21 +327,25 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
240
  yield verbose_messages, "", None
241
 
242
  # Determine if input_source is a URL or file
243
- if isinstance(input_source, str):
244
- if input_source.startswith('http://') or input_source.startswith('https://'):
245
- audio_path = download_audio(input_source, download_method)
246
- if not audio_path or audio_path.startswith("Error"):
247
- yield f"Error: {audio_path}", "", None
248
- return
249
- else:
250
- # Assume it's a local file path
251
- audio_path = input_source
252
- elif input_source is not None:
253
- # Uploaded file object
 
 
 
 
254
  audio_path = input_source.name
255
- logging.info(f"Using uploaded audio file: {audio_path}")
256
  else:
257
- yield "No audio source provided.", "", None
258
  return
259
 
260
  # Convert start_time and end_time to float or None
@@ -262,8 +353,8 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
262
  end_time = float(end_time) if end_time else None
263
 
264
  if start_time is not None or end_time is not None:
265
- trimmed_audio_path = trim_audio(audio_path, start_time, end_time)
266
- audio_path = trimmed_audio_path
267
  verbose_messages += f"Audio trimmed from {start_time} to {end_time}\n"
268
  if verbose:
269
  yield verbose_messages, "", None
@@ -276,10 +367,9 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
276
  else:
277
  if pipeline_type == "faster-batched":
278
  model = WhisperModel(model_id, device=device, compute_type=dtype)
279
- pipeline = BatchedInferencePipeline(model=model)
280
  elif pipeline_type == "faster-sequenced":
281
- model = WhisperModel(model_id, device=device, compute_type=dtype)
282
- pipeline = model.transcribe
283
  elif pipeline_type == "transformers":
284
  torch_dtype = torch.float16 if dtype == "float16" else torch.float32
285
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
@@ -287,7 +377,7 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
287
  )
288
  model.to(device)
289
  processor = AutoProcessor.from_pretrained(model_id)
290
- pipeline = pipeline(
291
  "automatic-speech-recognition",
292
  model=model,
293
  tokenizer=processor.tokenizer,
@@ -300,7 +390,7 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
300
  )
301
  else:
302
  raise ValueError("Invalid pipeline type")
303
- loaded_models[model_key] = model_or_pipeline # Cache the model
304
 
305
  start_time_perf = time.time()
306
  if pipeline_type == "faster-batched":
@@ -343,11 +433,9 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
343
 
344
  finally:
345
  # Clean up temporary files
346
- if audio_path and os.path.exists(audio_path):
347
  os.remove(audio_path)
348
- if 'trimmed_audio_path' in locals() and os.path.exists(trimmed_audio_path):
349
- os.remove(trimmed_audio_path)
350
- if 'transcription_file' in locals() and os.path.exists(transcription_file):
351
  os.remove(transcription_file)
352
 
353
  with gr.Blocks() as iface:
@@ -390,6 +478,15 @@ with gr.Blocks() as iface:
390
  transcription_file = gr.File(label="Download Transcription")
391
 
392
  def update_model_dropdown(pipeline_type):
 
 
 
 
 
 
 
 
 
393
  try:
394
  model_choices = get_model_options(pipeline_type)
395
  logging.info(f"Model choices for {pipeline_type}: {model_choices}")
 
40
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
41
 
42
  def download_audio(url, method_choice):
43
+ """
44
+ Downloads audio from a given URL using the specified method.
45
+
46
+ Args:
47
+ url (str): The URL of the audio.
48
+ method_choice (str): The method to use for downloading audio.
49
+
50
+ Returns:
51
+ tuple: (path to the downloaded audio file, is_temp_file), or (error message, False).
52
+ """
53
  parsed_url = urlparse(url)
54
  logging.info(f"Downloading audio from URL: {url} using method: {method_choice}")
55
  try:
56
+ if 'youtube.com' in parsed_url.netloc or 'youtu.be' in parsed_url.netloc:
57
+ # Use YouTube download methods
58
  audio_file = download_youtube_audio(url, method_choice)
59
  else:
60
+ # Use direct download methods
61
  audio_file = download_direct_audio(url, method_choice)
62
  if not audio_file or not os.path.exists(audio_file):
63
  raise Exception(f"Failed to download audio from {url}")
64
+ return audio_file, True # The file is a temporary file
65
  except Exception as e:
66
  logging.error(f"Error downloading audio: {str(e)}")
67
+ return f"Error: {str(e)}", False
68
 
69
  def download_youtube_audio(url, method_choice):
70
+ """
71
+ Downloads audio from a YouTube URL using the specified method.
72
+
73
+ Args:
74
+ url (str): The YouTube URL.
75
+ method_choice (str): The method to use for downloading ('yt-dlp', 'pytube', 'youtube-dl').
76
+
77
+ Returns:
78
+ str: Path to the downloaded audio file, or None if failed.
79
+ """
80
  methods = {
81
  'yt-dlp': youtube_dl_method,
82
  'pytube': pytube_method,
83
  'youtube-dl': youtube_dl_classic_method,
84
  'yt-dlp-alt': youtube_dl_alternative_method,
 
 
85
  }
86
+ method = methods.get(method_choice)
87
+ if method is None:
88
+ logging.warning(f"Invalid download method for YouTube: {method_choice}. Defaulting to 'yt-dlp'.")
89
+ method = youtube_dl_method
90
  try:
91
  logging.info(f"Attempting to download YouTube audio using {method_choice}")
92
  return method(url)
 
180
  return output_file
181
 
182
  def download_direct_audio(url, method_choice):
183
+ """
184
+ Downloads audio from a direct URL using the specified method.
185
+
186
+ Args:
187
+ url (str): The direct URL of the audio file.
188
+ method_choice (str): The method to use for downloading ('wget', 'requests').
189
+
190
+ Returns:
191
+ str: Path to the downloaded audio file, or None if failed.
192
+ """
193
  logging.info(f"Downloading direct audio from: {url} using method: {method_choice}")
194
  if method_choice == 'wget':
195
  return wget_method(url)
196
  else:
197
  try:
198
+ response = requests.get(url, stream=True)
199
  if response.status_code == 200:
200
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
201
+ for chunk in response.iter_content(chunk_size=8192):
202
+ if chunk:
203
+ temp_file.write(chunk)
204
+ logging.info(f"Downloaded direct audio to: {temp_file.name}")
205
+ return temp_file.name
206
  else:
207
+ raise Exception(f"Failed to download audio from {url} with status code {response.status_code}")
208
  except Exception as e:
209
  logging.error(f"Error downloading direct audio: {str(e)}")
210
  return None
 
218
  return output_file
219
 
220
  def trim_audio(audio_path, start_time, end_time):
221
+ """
222
+ Trims an audio file to the specified start and end times.
223
+
224
+ Args:
225
+ audio_path (str): Path to the audio file.
226
+ start_time (float): Start time in seconds.
227
+ end_time (float): End time in seconds.
228
+
229
+ Returns:
230
+ str: Path to the trimmed audio file.
231
+
232
+ Raises:
233
+ gr.Error: If invalid start or end times are provided.
234
+ """
235
+ try:
236
+ logging.info(f"Trimming audio from {start_time} to {end_time}")
237
+ audio = AudioSegment.from_file(audio_path)
238
+ audio_duration = len(audio) / 1000 # Duration in seconds
239
+
240
+ # Default start and end times if None
241
+ if start_time is None:
242
+ start_time = 0
243
+ if end_time is None or end_time > audio_duration:
244
+ end_time = audio_duration
245
+
246
+ # Validate times
247
+ if start_time < 0 or end_time <= 0:
248
+ raise gr.Error("Start time and end time must be positive.")
249
+ if start_time >= end_time:
250
+ raise gr.Error("End time must be greater than start time.")
251
+ if start_time > audio_duration:
252
+ raise gr.Error("Start time exceeds audio duration.")
253
+
254
+ trimmed_audio = audio[start_time * 1000:end_time * 1000]
255
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio_file:
256
+ trimmed_audio.export(temp_audio_file.name, format="wav")
257
+ logging.info(f"Trimmed audio saved to: {temp_audio_file.name}")
258
+ return temp_audio_file.name
259
+ except Exception as e:
260
+ logging.error(f"Error trimming audio: {str(e)}")
261
+ raise gr.Error(f"Error trimming audio: {str(e)}")
262
 
263
  def save_transcription(transcription):
264
+ """
265
+ Saves the transcription text to a temporary file.
266
+
267
+ Args:
268
+ transcription (str): The transcription text.
269
+
270
+ Returns:
271
+ str: The path to the transcription file.
272
+ """
273
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.txt', mode='w', encoding='utf-8') as temp_file:
274
+ temp_file.write(transcription)
275
+ logging.info(f"Transcription saved to: {temp_file.name}")
276
+ return temp_file.name
277
 
278
  def get_model_options(pipeline_type):
279
+ """
280
+ Returns a list of model IDs based on the selected pipeline type.
281
+
282
+ Args:
283
+ pipeline_type (str): The type of pipeline ('faster-batched', 'faster-sequenced', 'transformers').
284
+
285
+ Returns:
286
+ list: A list of model IDs.
287
+ """
288
  if pipeline_type == "faster-batched":
289
+ return ["cstr/whisper-large-v3-turbo-int8_float32", "SYSTRAN/faster-whisper-large-v1", "GalaktischeGurke/primeline-whisper-large-v3-german-ct2"]
290
  elif pipeline_type == "faster-sequenced":
291
+ return ["SYSTRAN/faster-whisper-large-v1", "GalaktischeGurke/primeline-whisper-large-v3-german-ct2"]
292
  elif pipeline_type == "transformers":
293
+ return ["openai/whisper-large-v3", "openai/whisper-large-v2"]
294
  else:
295
  return []
296
 
297
  loaded_models = {}
298
 
299
  def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time=None, end_time=None, verbose=False):
300
+ """
301
+ Transcribes audio from a given source using the specified pipeline and model.
302
+
303
+ Args:
304
+ input_source (str or file): URL of audio, path to local file, or uploaded file object.
305
+ pipeline_type (str): Type of pipeline to use ('faster-batched', 'faster-sequenced', or 'transformers').
306
+ model_id (str): The ID of the model to use.
307
+ dtype (str): Data type for model computations ('int8', 'float16', or 'float32').
308
+ batch_size (int): Batch size for transcription.
309
+ download_method (str): Method to use for downloading audio.
310
+ start_time (float, optional): Start time in seconds for trimming audio.
311
+ end_time (float, optional): End time in seconds for trimming audio.
312
+ verbose (bool, optional): Whether to output verbose logging.
313
+
314
+ Yields:
315
+ Tuple[str, str, str or None]: Metrics and messages, transcription text, path to transcription file.
316
+ """
317
  try:
318
  if verbose:
319
  logging.getLogger().setLevel(logging.INFO)
320
  else:
321
  logging.getLogger().setLevel(logging.WARNING)
322
+
323
  logging.info(f"Transcription parameters: pipeline_type={pipeline_type}, model_id={model_id}, dtype={dtype}, batch_size={batch_size}, download_method={download_method}")
324
  verbose_messages = f"Starting transcription with parameters:\nPipeline Type: {pipeline_type}\nModel ID: {model_id}\nData Type: {dtype}\nBatch Size: {batch_size}\nDownload Method: {download_method}\n"
325
 
 
327
  yield verbose_messages, "", None
328
 
329
  # Determine if input_source is a URL or file
330
+ audio_path = None
331
+ is_temp_file = False
332
+
333
+ if isinstance(input_source, str) and (input_source.startswith('http://') or input_source.startswith('https://')):
334
+ # Input source is a URL
335
+ audio_path, is_temp_file = download_audio(input_source, download_method)
336
+ if not audio_path or audio_path.startswith("Error"):
337
+ yield f"Error downloading audio: {audio_path}", "", None
338
+ return
339
+ elif isinstance(input_source, str) and os.path.exists(input_source):
340
+ # Input source is a local file path
341
+ audio_path = input_source
342
+ is_temp_file = False
343
+ elif hasattr(input_source, 'name'):
344
+ # Input source is an uploaded file object
345
  audio_path = input_source.name
346
+ is_temp_file = False
347
  else:
348
+ yield "No valid audio source provided.", "", None
349
  return
350
 
351
  # Convert start_time and end_time to float or None
 
353
  end_time = float(end_time) if end_time else None
354
 
355
  if start_time is not None or end_time is not None:
356
+ audio_path = trim_audio(audio_path, start_time, end_time)
357
+ is_temp_file = True # The trimmed audio is a temporary file
358
  verbose_messages += f"Audio trimmed from {start_time} to {end_time}\n"
359
  if verbose:
360
  yield verbose_messages, "", None
 
367
  else:
368
  if pipeline_type == "faster-batched":
369
  model = WhisperModel(model_id, device=device, compute_type=dtype)
370
+ model_or_pipeline = BatchedInferencePipeline(model=model)
371
  elif pipeline_type == "faster-sequenced":
372
+ model_or_pipeline = WhisperModel(model_id, device=device, compute_type=dtype)
 
373
  elif pipeline_type == "transformers":
374
  torch_dtype = torch.float16 if dtype == "float16" else torch.float32
375
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
 
377
  )
378
  model.to(device)
379
  processor = AutoProcessor.from_pretrained(model_id)
380
+ model_or_pipeline = pipeline(
381
  "automatic-speech-recognition",
382
  model=model,
383
  tokenizer=processor.tokenizer,
 
390
  )
391
  else:
392
  raise ValueError("Invalid pipeline type")
393
+ loaded_models[model_key] = model_or_pipeline # Cache the model or pipeline
394
 
395
  start_time_perf = time.time()
396
  if pipeline_type == "faster-batched":
 
433
 
434
  finally:
435
  # Clean up temporary files
436
+ if audio_path and is_temp_file and os.path.exists(audio_path):
437
  os.remove(audio_path)
438
+ if 'transcription_file' in locals() and transcription_file and os.path.exists(transcription_file):
 
 
439
  os.remove(transcription_file)
440
 
441
  with gr.Blocks() as iface:
 
478
  transcription_file = gr.File(label="Download Transcription")
479
 
480
  def update_model_dropdown(pipeline_type):
481
+ """
482
+ Updates the model dropdown choices based on the selected pipeline type.
483
+
484
+ Args:
485
+ pipeline_type (str): The selected pipeline type.
486
+
487
+ Returns:
488
+ gr.update: Updated model dropdown component.
489
+ """
490
  try:
491
  model_choices = get_model_options(pipeline_type)
492
  logging.info(f"Model choices for {pipeline_type}: {model_choices}")