cstr committed on
Commit
fa54222
1 Parent(s): 4347dae
Files changed (1) hide show
  1. app.py +79 -45
app.py CHANGED
@@ -45,11 +45,11 @@ logging.info(f"Using device: {device}")
45
  def download_audio(url, method_choice):
46
  """
47
  Downloads audio from a given URL using the specified method.
48
-
49
  Args:
50
  url (str): The URL of the audio.
51
  method_choice (str): The method to use for downloading audio.
52
-
53
  Returns:
54
  tuple: (path to the downloaded audio file, is_temp_file), or (error message, False).
55
  """
@@ -64,11 +64,14 @@ def download_audio(url, method_choice):
64
  audio_file = download_direct_audio(url, method_choice)
65
 
66
  if not audio_file or not os.path.exists(audio_file):
67
- raise Exception(f"Failed to download audio from {url}")
 
 
68
  return audio_file, True
69
  except Exception as e:
70
- logging.error(f"Error downloading audio: {str(e)}")
71
- return f"Error: {str(e)}", False
 
72
 
73
  def download_youtube_audio(url, method_choice):
74
  """
@@ -114,15 +117,20 @@ def yt_dlp_method(url):
114
  'preferredcodec': 'mp3',
115
  'preferredquality': '192',
116
  }],
117
- 'quiet': True,
118
  'no_warnings': True,
 
119
  }
120
- with yt_dlp.YoutubeDL(ydl_opts) as ydl:
121
- info = ydl.extract_info(url, download=True)
122
- output_file = ydl.prepare_filename(info)
123
- output_file = os.path.splitext(output_file)[0] + '.mp3'
124
- logging.info(f"Downloaded YouTube audio: {output_file}")
125
- return output_file
 
 
 
 
126
 
127
  def pytube_method(url):
128
  """
@@ -136,15 +144,24 @@ def pytube_method(url):
136
  """
137
  logging.info("Using pytube method")
138
  from pytube import YouTube
139
- yt = YouTube(url)
140
- audio_stream = yt.streams.filter(only_audio=True).first()
141
- temp_dir = tempfile.mkdtemp()
142
- out_file = audio_stream.download(output_path=temp_dir)
143
- base, ext = os.path.splitext(out_file)
144
- new_file = base + '.mp3'
145
- os.rename(out_file, new_file)
146
- logging.info(f"Downloaded and converted audio to: {new_file}")
147
- return new_file
 
 
 
 
 
 
 
 
 
148
 
149
  def download_rtsp_audio(url):
150
  """
@@ -173,11 +190,11 @@ def download_rtsp_audio(url):
173
  def download_direct_audio(url, method_choice):
174
  """
175
  Downloads audio from a direct URL using the specified method.
176
-
177
  Args:
178
  url (str): The direct URL of the audio file.
179
  method_choice (str): The method to use for downloading.
180
-
181
  Returns:
182
  str: Path to the downloaded audio file, or None if failed.
183
  """
@@ -191,9 +208,14 @@ def download_direct_audio(url, method_choice):
191
  }
192
  method = methods.get(method_choice, requests_method)
193
  try:
194
- return method(url)
 
 
 
 
 
195
  except Exception as e:
196
- logging.error(f"Error downloading direct audio: {str(e)}")
197
  return None
198
 
199
  def requests_method(url):
@@ -402,10 +424,10 @@ loaded_models = {}
402
 
403
  def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time=None, end_time=None, verbose=False):
404
  """
405
- Transcribes audio from a given URL using the specified pipeline and model.
406
 
407
  Args:
408
- input_source (str): URL of the audio.
409
  pipeline_type (str): Type of pipeline to use ('faster-batched', 'faster-sequenced', or 'transformers').
410
  model_id (str): The ID of the model to use.
411
  dtype (str): Data type for model computations ('int8', 'float16', or 'float32').
@@ -430,22 +452,36 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
430
  if verbose:
431
  yield verbose_messages, "", None
432
 
433
- # Input source is expected to be a URL
434
- if not input_source or not input_source.strip():
435
- yield "No audio URL provided.", "", None
436
- return
437
-
438
- # Download the audio from the URL
439
- audio_path, is_temp_file = download_audio(input_source, download_method)
440
- if not audio_path or audio_path.startswith("Error"):
441
- yield f"Error downloading audio: {audio_path}", "", None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
  return
443
 
444
  # Convert start_time and end_time to float or None
445
  start_time = float(start_time) if start_time else None
446
  end_time = float(end_time) if end_time else None
447
 
448
- # Trim the audio if start or end times are provided
449
  if start_time is not None or end_time is not None:
450
  audio_path = trim_audio(audio_path, start_time, end_time)
451
  is_temp_file = True # The trimmed audio is a temporary file
@@ -459,7 +495,6 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
459
  model_or_pipeline = loaded_models[model_key]
460
  logging.info("Loaded model from cache")
461
  else:
462
- # Load the appropriate model or pipeline based on the pipeline type
463
  if pipeline_type == "faster-batched":
464
  model = WhisperModel(model_id, device=device, compute_type=dtype)
465
  model_or_pipeline = BatchedInferencePipeline(model=model)
@@ -489,10 +524,11 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
489
  device=device,
490
  )
491
  else:
492
- raise ValueError("Invalid pipeline type")
 
 
493
  loaded_models[model_key] = model_or_pipeline # Cache the model or pipeline
494
 
495
- # Perform the transcription
496
  start_time_perf = time.time()
497
  if pipeline_type == "faster-batched":
498
  segments, info = model_or_pipeline.transcribe(audio_path, batch_size=batch_size)
@@ -503,7 +539,6 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
503
  segments = result["chunks"]
504
  end_time_perf = time.time()
505
 
506
- # Calculate metrics
507
  transcription_time = end_time_perf - start_time_perf
508
  audio_file_size = os.path.getsize(audio_path) / (1024 * 1024)
509
 
@@ -515,7 +550,6 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
515
  if verbose:
516
  yield verbose_messages + metrics_output, "", None
517
 
518
- # Compile the transcription text
519
  transcription = ""
520
 
521
  for segment in segments:
@@ -527,13 +561,13 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, d
527
  if verbose:
528
  yield verbose_messages + metrics_output, transcription, None
529
 
530
- # Save the transcription to a file
531
  transcription_file = save_transcription(transcription)
532
  yield verbose_messages + metrics_output, transcription, transcription_file
533
 
534
  except Exception as e:
535
- logging.error(f"An error occurred during transcription: {str(e)}")
536
- yield f"An error occurred: {str(e)}", "", None
 
537
 
538
  finally:
539
  # Clean up temporary audio files
 
45
  def download_audio(url, method_choice):
46
  """
47
  Downloads audio from a given URL using the specified method.
48
+
49
  Args:
50
  url (str): The URL of the audio.
51
  method_choice (str): The method to use for downloading audio.
52
+
53
  Returns:
54
  tuple: (path to the downloaded audio file, is_temp_file), or (error message, False).
55
  """
 
64
  audio_file = download_direct_audio(url, method_choice)
65
 
66
  if not audio_file or not os.path.exists(audio_file):
67
+ error_msg = f"Failed to download audio from {url} using method {method_choice}"
68
+ logging.error(error_msg)
69
+ return error_msg, False
70
  return audio_file, True
71
  except Exception as e:
72
+ error_msg = f"Error downloading audio from {url} using method {method_choice}: {str(e)}"
73
+ logging.error(error_msg)
74
+ return error_msg, False
75
 
76
  def download_youtube_audio(url, method_choice):
77
  """
 
117
  'preferredcodec': 'mp3',
118
  'preferredquality': '192',
119
  }],
120
+ 'quiet': False,
121
  'no_warnings': True,
122
+ 'logger': logging.getLogger(), # Capture yt-dlp logs
123
  }
124
+ try:
125
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
126
+ info = ydl.extract_info(url, download=True)
127
+ output_file = ydl.prepare_filename(info)
128
+ output_file = os.path.splitext(output_file)[0] + '.mp3'
129
+ logging.info(f"Downloaded YouTube audio: {output_file}")
130
+ return output_file
131
+ except Exception as e:
132
+ logging.error(f"Error in yt_dlp_method: {str(e)}")
133
+ raise Exception(f"yt-dlp failed to download audio: {str(e)}")
134
 
135
  def pytube_method(url):
136
  """
 
144
  """
145
  logging.info("Using pytube method")
146
  from pytube import YouTube
147
+ try:
148
+ yt = YouTube(url)
149
+ audio_stream = yt.streams.filter(only_audio=True).first()
150
+ if audio_stream is None:
151
+ error_msg = "No audio streams available with pytube."
152
+ logging.error(error_msg)
153
+ raise Exception(error_msg)
154
+ temp_dir = tempfile.mkdtemp()
155
+ out_file = audio_stream.download(output_path=temp_dir)
156
+ base, ext = os.path.splitext(out_file)
157
+ new_file = base + '.mp3'
158
+ os.rename(out_file, new_file)
159
+ logging.info(f"Downloaded and converted audio to: {new_file}")
160
+ return new_file
161
+ except Exception as e:
162
+ logging.error(f"Error in pytube_method: {str(e)}")
163
+ raise Exception(f"pytube failed to download audio: {str(e)}")
164
+
165
 
166
  def download_rtsp_audio(url):
167
  """
 
190
  def download_direct_audio(url, method_choice):
191
  """
192
  Downloads audio from a direct URL using the specified method.
193
+
194
  Args:
195
  url (str): The direct URL of the audio file.
196
  method_choice (str): The method to use for downloading.
197
+
198
  Returns:
199
  str: Path to the downloaded audio file, or None if failed.
200
  """
 
208
  }
209
  method = methods.get(method_choice, requests_method)
210
  try:
211
+ audio_file = method(url)
212
+ if not audio_file or not os.path.exists(audio_file):
213
+ error_msg = f"Failed to download direct audio from {url} using method {method_choice}"
214
+ logging.error(error_msg)
215
+ return None
216
+ return audio_file
217
  except Exception as e:
218
+ logging.error(f"Error downloading direct audio with {method_choice}: {str(e)}")
219
  return None
220
 
221
  def requests_method(url):
 
424
 
425
  def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time=None, end_time=None, verbose=False):
426
  """
427
+ Transcribes audio from a given source using the specified pipeline and model.
428
 
429
  Args:
430
+ input_source (str or file): URL of audio, path to local file, or uploaded file object.
431
  pipeline_type (str): Type of pipeline to use ('faster-batched', 'faster-sequenced', or 'transformers').
432
  model_id (str): The ID of the model to use.
433
  dtype (str): Data type for model computations ('int8', 'float16', or 'float32').
 
452
  if verbose:
453
  yield verbose_messages, "", None
454
 
455
+ # Determine if input_source is a URL or file
456
+ audio_path = None
457
+ is_temp_file = False
458
+
459
+ if isinstance(input_source, str) and (input_source.startswith('http://') or input_source.startswith('https://')):
460
+ # Input source is a URL
461
+ audio_path, is_temp_file = download_audio(input_source, download_method)
462
+ if not audio_path or audio_path.startswith("Error"):
463
+ error_msg = f"Error downloading audio: {audio_path}"
464
+ logging.error(error_msg)
465
+ yield error_msg, "", None
466
+ return
467
+ elif isinstance(input_source, str) and os.path.exists(input_source):
468
+ # Input source is a local file path
469
+ audio_path = input_source
470
+ is_temp_file = False
471
+ elif hasattr(input_source, 'name'):
472
+ # Input source is an uploaded file object
473
+ audio_path = input_source.name
474
+ is_temp_file = False
475
+ else:
476
+ error_msg = "No valid audio source provided."
477
+ logging.error(error_msg)
478
+ yield error_msg, "", None
479
  return
480
 
481
  # Convert start_time and end_time to float or None
482
  start_time = float(start_time) if start_time else None
483
  end_time = float(end_time) if end_time else None
484
 
 
485
  if start_time is not None or end_time is not None:
486
  audio_path = trim_audio(audio_path, start_time, end_time)
487
  is_temp_file = True # The trimmed audio is a temporary file
 
495
  model_or_pipeline = loaded_models[model_key]
496
  logging.info("Loaded model from cache")
497
  else:
 
498
  if pipeline_type == "faster-batched":
499
  model = WhisperModel(model_id, device=device, compute_type=dtype)
500
  model_or_pipeline = BatchedInferencePipeline(model=model)
 
524
  device=device,
525
  )
526
  else:
527
+ error_msg = "Invalid pipeline type"
528
+ logging.error(error_msg)
529
+ raise ValueError(error_msg)
530
  loaded_models[model_key] = model_or_pipeline # Cache the model or pipeline
531
 
 
532
  start_time_perf = time.time()
533
  if pipeline_type == "faster-batched":
534
  segments, info = model_or_pipeline.transcribe(audio_path, batch_size=batch_size)
 
539
  segments = result["chunks"]
540
  end_time_perf = time.time()
541
 
 
542
  transcription_time = end_time_perf - start_time_perf
543
  audio_file_size = os.path.getsize(audio_path) / (1024 * 1024)
544
 
 
550
  if verbose:
551
  yield verbose_messages + metrics_output, "", None
552
 
 
553
  transcription = ""
554
 
555
  for segment in segments:
 
561
  if verbose:
562
  yield verbose_messages + metrics_output, transcription, None
563
 
 
564
  transcription_file = save_transcription(transcription)
565
  yield verbose_messages + metrics_output, transcription, transcription_file
566
 
567
  except Exception as e:
568
+ error_msg = f"An error occurred during transcription: {str(e)}"
569
+ logging.error(error_msg)
570
+ yield error_msg, "", None
571
 
572
  finally:
573
  # Clean up temporary audio files