Spaces:
Running
Running
import gradio as gr | |
import os | |
import time | |
import sys | |
import io | |
import tempfile | |
import subprocess | |
import requests | |
from urllib.parse import urlparse | |
from pydub import AudioSegment | |
import logging | |
import torch | |
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline | |
import yt_dlp | |
class LogCapture(io.StringIO):
    """In-memory text stream that also forwards every write to a callback.

    Text is buffered like a normal ``io.StringIO`` (available via
    ``getvalue()``), and each chunk passed to ``write`` is handed to
    ``callback`` as well.
    """

    def __init__(self, callback):
        """
        Args:
            callback (Callable[[str], Any]): Invoked with each string passed to ``write``.
        """
        super().__init__()
        self.callback = callback

    def write(self, s):
        """Buffer ``s``, forward it to the callback, and return the characters written.

        Returning the count keeps the ``io.TextIOBase.write`` contract that the
        overridden ``StringIO.write`` provides.
        """
        written = super().write(s)
        self.callback(s)
        return written
logging.basicConfig(level=logging.INFO)

# Clone and install faster-whisper from GitHub.
# The clone is skipped when the checkout already exists (e.g. the app was
# restarted with its working directory intact); otherwise `git clone` would
# fail on the non-empty destination and the app would exit below.
_FASTER_WHISPER_DIR = "./faster-whisper"
try:
    if not os.path.isdir(_FASTER_WHISPER_DIR):
        subprocess.run(
            ["git", "clone", "https://github.com/SYSTRAN/faster-whisper.git", _FASTER_WHISPER_DIR],
            check=True,
        )
    # Use the running interpreter's pip so the package lands in this environment.
    subprocess.run([sys.executable, "-m", "pip", "install", "-e", _FASTER_WHISPER_DIR], check=True)
except (subprocess.CalledProcessError, FileNotFoundError) as e:
    # FileNotFoundError: `git` is not installed.
    logging.error(f"Error during faster-whisper installation: {e}")
    sys.exit(1)

sys.path.append(_FASTER_WHISPER_DIR)
from faster_whisper import WhisperModel
from faster_whisper.transcribe import BatchedInferencePipeline

# Run models on the first CUDA GPU when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
def download_audio(url, method_choice):
    """
    Fetch the audio behind ``url`` with the requested download method.

    URLs whose network location contains ``youtube.com`` or ``youtu.be`` go to
    ``download_youtube_audio``; every other URL goes to ``download_direct_audio``.

    Args:
        url (str): Location of the audio.
        method_choice (str): Name of the download method to use.

    Returns:
        tuple: ``(audio_path, True)`` on success (the file is temporary and the
        caller should delete it), or ``("Error: <message>", False)`` on failure.
    """
    host = urlparse(url).netloc
    logging.info(f"Downloading audio from URL: {url} using method: {method_choice}")
    try:
        is_youtube = any(domain in host for domain in ('youtube.com', 'youtu.be'))
        downloader = download_youtube_audio if is_youtube else download_direct_audio
        audio_file = downloader(url, method_choice)
        if audio_file and os.path.exists(audio_file):
            return audio_file, True
        raise Exception(f"Failed to download audio from {url}")
    except Exception as e:
        logging.error(f"Error downloading audio: {str(e)}")
        return f"Error: {str(e)}", False
def download_youtube_audio(url, method_choice):
    """
    Download the audio track of a YouTube video.

    Args:
        url (str): YouTube video URL.
        method_choice (str): One of 'yt-dlp', 'pytube', 'youtube-dl' or
            'yt-dlp-alt'. Any other value falls back to 'yt-dlp' with a warning.

    Returns:
        str: Path to the downloaded audio file, or None if the chosen method
        raised (some methods also return None themselves on failure).
    """
    downloaders = {
        'yt-dlp': youtube_dl_method,
        'pytube': pytube_method,
        'youtube-dl': youtube_dl_classic_method,
        'yt-dlp-alt': youtube_dl_alternative_method,
    }
    if method_choice in downloaders:
        downloader = downloaders[method_choice]
    else:
        logging.warning(f"Invalid download method for YouTube: {method_choice}. Defaulting to 'yt-dlp'.")
        downloader = youtube_dl_method
    try:
        logging.info(f"Attempting to download YouTube audio using {method_choice}")
        return downloader(url)
    except Exception as e:
        logging.error(f"Error downloading using {method_choice}: {str(e)}")
        return None
def youtube_dl_method(url):
    """
    Download a YouTube video's best audio with yt-dlp and extract it to MP3.

    The output is written to the current working directory as ``<video id>.mp3``.

    Args:
        url (str): YouTube video URL.

    Returns:
        str: Name of the MP3 file, or None if yt-dlp raised.
    """
    logging.info("Using yt-dlp method")
    mp3_postprocessor = {
        'key': 'FFmpegExtractAudio',
        'preferredcodec': 'mp3',
        'preferredquality': '192',
    }
    options = {
        'format': 'bestaudio/best',
        'postprocessors': [mp3_postprocessor],
        'outtmpl': '%(id)s.%(ext)s',
    }
    try:
        with yt_dlp.YoutubeDL(options) as downloader:
            video_info = downloader.extract_info(url, download=True)
            mp3_path = f"{video_info['id']}.mp3"
            logging.info(f"Downloaded YouTube audio: {mp3_path}")
            return mp3_path
    except Exception as e:
        logging.error(f"Error in youtube_dl_method: {str(e)}")
        return None
def pytube_method(url):
    """
    Download a YouTube video's audio with pytube and transcode it to MP3.

    pytube saves the stream in its native container (typically MP4/WebM), so
    the file is transcoded with ffmpeg rather than just renamed to ``.mp3``.

    Args:
        url (str): YouTube video URL.

    Returns:
        str: Path of the MP3 file, placed next to the downloaded stream.

    Raises:
        Exception: If no audio-only stream exists, or if pytube/ffmpeg fail
            (``FileNotFoundError`` when ffmpeg is not installed,
            ``subprocess.CalledProcessError`` when it exits non-zero).
    """
    logging.info("Using pytube method")
    from pytube import YouTube
    yt = YouTube(url)
    audio_stream = yt.streams.filter(only_audio=True).first()
    if audio_stream is None:
        raise Exception(f"No audio-only stream available for {url}")
    out_file = audio_stream.download()
    base, ext = os.path.splitext(out_file)
    if ext.lower() == '.mp3':
        # Already MP3; nothing to transcode.
        logging.info(f"Downloaded audio to: {out_file}")
        return out_file
    new_file = base + '.mp3'
    try:
        # '-y' overwrites a stale MP3 from an earlier download of the same video.
        subprocess.run(
            ['ffmpeg', '-y', '-i', out_file, '-vn', '-acodec', 'libmp3lame', '-q:a', '2', new_file],
            check=True,
            capture_output=True,
        )
    finally:
        # The source stream is not needed once transcoding has finished or failed.
        if os.path.exists(out_file):
            os.remove(out_file)
    logging.info(f"Downloaded and converted audio to: {new_file}")
    return new_file
def youtube_dl_classic_method(url):
    """
    Download a YouTube video's audio as MP3 (implemented with yt-dlp).

    The output is written to the current working directory as ``<video id>.mp3``.
    Errors from yt-dlp propagate to the caller.

    Args:
        url (str): YouTube video URL.

    Returns:
        str: Name of the MP3 file.
    """
    logging.info("Using youtube-dl classic method")
    options = dict(
        format='bestaudio/best',
        postprocessors=[dict(key='FFmpegExtractAudio', preferredcodec='mp3', preferredquality='192')],
        outtmpl='%(id)s.%(ext)s',
    )
    with yt_dlp.YoutubeDL(options) as downloader:
        video_id = downloader.extract_info(url, download=True)['id']
        logging.info(f"Downloaded YouTube audio: {video_id}.mp3")
        return f"{video_id}.mp3"
def youtube_dl_alternative_method(url):
    """
    Quiet yt-dlp variant: download a YouTube video's audio as MP3.

    The output is written to the current working directory as ``<video id>.mp3``.
    Errors from yt-dlp propagate to the caller.

    Args:
        url (str): YouTube video URL.

    Returns:
        str: Name of the MP3 file.
    """
    logging.info("Using yt-dlp alternative method")
    options = dict(
        format='bestaudio/best',
        postprocessors=[dict(key='FFmpegExtractAudio', preferredcodec='mp3', preferredquality='192')],
        outtmpl='%(id)s.%(ext)s',
        no_warnings=True,
        quiet=True,
        # NOTE(review): yt-dlp documents this option as 'nocheckcertificate';
        # 'no_check_certificate' does not appear to be a recognised key, so TLS
        # verification is presumably still on. Confirm intent before renaming,
        # since disabling certificate checks is a security trade-off.
        no_check_certificate=True,
        prefer_insecure=True,
    )
    with yt_dlp.YoutubeDL(options) as downloader:
        video_id = downloader.extract_info(url, download=True)['id']
        logging.info(f"Downloaded YouTube audio: {video_id}.mp3")
        return f"{video_id}.mp3"
def ffmpeg_method(url):
    """
    Fetch/convert audio from ``url`` into a temporary MP3 file using ffmpeg.

    Args:
        url (str): Any input ffmpeg can read (e.g. an HTTP(S) URL or a path).

    Returns:
        str: Path of the temporary MP3 file; the caller must delete it.

    Raises:
        FileNotFoundError: If ffmpeg is not installed.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    logging.info("Using ffmpeg method")
    # mkstemp creates the file atomically, unlike the deprecated, race-prone mktemp.
    fd, output_file = tempfile.mkstemp(suffix='.mp3')
    os.close(fd)
    # '-y' lets ffmpeg overwrite the placeholder instead of prompting on stdin.
    command = ['ffmpeg', '-y', '-i', url, '-vn', '-acodec', 'libmp3lame', '-q:a', '2', output_file]
    try:
        subprocess.run(command, check=True, capture_output=True)
    except Exception:
        # Don't leave an empty/partial temp file behind on failure.
        if os.path.exists(output_file):
            os.remove(output_file)
        raise
    logging.info(f"Downloaded and converted audio to: {output_file}")
    return output_file
def aria2_method(url):
    """
    Download ``url`` into a temporary ``.mp3`` file with aria2c.

    The file is saved as-is (no transcoding), using 4 parallel connections.

    Args:
        url (str): Direct URL of the audio file.

    Returns:
        str: Path of the downloaded temporary file; the caller must delete it.

    Raises:
        FileNotFoundError: If aria2c is not installed.
        subprocess.CalledProcessError: If aria2c exits with a non-zero status.
    """
    logging.info("Using aria2 method")
    # mkstemp creates the file atomically, unlike the deprecated, race-prone mktemp.
    fd, output_file = tempfile.mkstemp(suffix='.mp3')
    os.close(fd)
    # aria2c treats --out as relative to --dir, so pass them separately.
    # --allow-overwrite replaces the placeholder instead of auto-renaming the download.
    command = [
        'aria2c', '--split=4', '--max-connection-per-server=4',
        '--allow-overwrite=true',
        '--dir', os.path.dirname(output_file),
        '--out', os.path.basename(output_file),
        url,
    ]
    try:
        subprocess.run(command, check=True, capture_output=True)
    except Exception:
        # Don't leave an empty/partial temp file behind on failure.
        if os.path.exists(output_file):
            os.remove(output_file)
        raise
    logging.info(f"Downloaded audio to: {output_file}")
    return output_file
def download_direct_audio(url, method_choice):
    """
    Downloads audio from a direct (non-YouTube) URL using the specified method.

    Args:
        url (str): The direct URL of the audio file.
        method_choice (str): 'wget', 'ffmpeg' or 'aria2' use the matching
            external tool; any other value downloads with ``requests``.

    Returns:
        str: Path to the downloaded temporary audio file, or None if the
        ``requests`` download failed. Errors raised by the external-tool
        methods propagate to the caller.
    """
    logging.info(f"Downloading direct audio from: {url} using method: {method_choice}")
    # These choices are offered in the UI; previously only 'wget' was dispatched
    # and 'ffmpeg'/'aria2' silently fell through to requests.
    external_methods = {
        'wget': wget_method,
        'ffmpeg': ffmpeg_method,
        'aria2': aria2_method,
    }
    if method_choice in external_methods:
        return external_methods[method_choice](url)

    temp_path = None
    try:
        # The timeout (connect, read) keeps a stalled server from hanging the
        # request forever; the context manager releases the connection.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            if response.status_code != 200:
                raise Exception(f"Failed to download audio from {url} with status code {response.status_code}")
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
                temp_path = temp_file.name
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        temp_file.write(chunk)
        logging.info(f"Downloaded direct audio to: {temp_path}")
        return temp_path
    except Exception as e:
        logging.error(f"Error downloading direct audio: {str(e)}")
        # Remove a partially written file so failed downloads don't leak temp files.
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
        return None
def wget_method(url):
    """
    Download ``url`` into a temporary ``.mp3`` file with wget.

    The file is saved as-is (no transcoding).

    Args:
        url (str): Direct URL of the audio file.

    Returns:
        str: Path of the downloaded temporary file; the caller must delete it.

    Raises:
        FileNotFoundError: If wget is not installed.
        subprocess.CalledProcessError: If wget exits with a non-zero status.
    """
    logging.info("Using wget method")
    # mkstemp creates the file atomically, unlike the deprecated, race-prone mktemp.
    fd, output_file = tempfile.mkstemp(suffix='.mp3')
    os.close(fd)
    # `wget -O` truncates and overwrites the placeholder created above.
    command = ['wget', '-O', output_file, url]
    try:
        subprocess.run(command, check=True, capture_output=True)
    except Exception:
        # Remove the placeholder/partial file so failures don't leak temp files.
        if os.path.exists(output_file):
            os.remove(output_file)
        raise
    logging.info(f"Downloaded audio to: {output_file}")
    return output_file
def trim_audio(audio_path, start_time, end_time):
    """
    Trims an audio file to the specified start and end times.

    Args:
        audio_path (str): Path to the audio file.
        start_time (float or None): Start time in seconds; None means 0.
        end_time (float or None): End time in seconds; None, or a value past
            the end of the audio, means the end of the audio.

    Returns:
        str: Path to a new temporary WAV file with the trimmed audio; the
        caller is responsible for deleting it.

    Raises:
        gr.Error: If invalid start or end times are provided, or if loading or
            exporting the audio fails.
    """
    try:
        logging.info(f"Trimming audio from {start_time} to {end_time}")
        audio = AudioSegment.from_file(audio_path)
        audio_duration = len(audio) / 1000  # pydub lengths are in milliseconds
        if start_time is None:
            start_time = 0
        # Validate the requested values before end_time is clamped to the duration.
        if start_time < 0 or (end_time is not None and end_time <= 0):
            raise gr.Error("Start time and end time must be positive.")
        # Checked before clamping: once end_time is capped at the duration, a
        # start past the end could only reach the generic start>=end error.
        if start_time >= audio_duration:
            raise gr.Error("Start time exceeds audio duration.")
        if end_time is None or end_time > audio_duration:
            end_time = audio_duration
        if start_time >= end_time:
            raise gr.Error("End time must be greater than start time.")
        trimmed_audio = audio[start_time * 1000:end_time * 1000]
        # Close the temp file before exporting so it can be reopened by name on any platform.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio_file:
            trimmed_path = temp_audio_file.name
        trimmed_audio.export(trimmed_path, format="wav")
        logging.info(f"Trimmed audio saved to: {trimmed_path}")
        return trimmed_path
    except gr.Error:
        # Validation errors already carry a user-facing message; don't re-wrap them.
        raise
    except Exception as e:
        logging.error(f"Error trimming audio: {str(e)}")
        raise gr.Error(f"Error trimming audio: {str(e)}") from e
def save_transcription(transcription):
    """
    Write the transcription text to a new UTF-8 ``.txt`` temporary file.

    The file is created with ``delete=False``, so it outlives this call and
    the caller owns it.

    Args:
        transcription (str): Text to write.

    Returns:
        str: Path of the written file.
    """
    with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', suffix='.txt', delete=False) as handle:
        saved_path = handle.name
        handle.write(transcription)
        logging.info(f"Transcription saved to: {saved_path}")
    return saved_path
def get_model_options(pipeline_type):
    """
    List the model IDs offered for a pipeline type.

    Args:
        pipeline_type (str): 'faster-batched', 'faster-sequenced' or 'transformers'.

    Returns:
        list: A fresh list of model IDs; empty for an unknown pipeline type.
    """
    catalog = (
        ("faster-batched", (
            "cstr/whisper-large-v3-turbo-int8_float32",
            "SYSTRAN/faster-whisper-large-v1",
            "GalaktischeGurke/primeline-whisper-large-v3-german-ct2",
        )),
        ("faster-sequenced", (
            "SYSTRAN/faster-whisper-large-v1",
            "GalaktischeGurke/primeline-whisper-large-v3-german-ct2",
        )),
        ("transformers", (
            "openai/whisper-large-v3",
            "openai/whisper-large-v2",
        )),
    )
    for name, models in catalog:
        if pipeline_type == name:
            return list(models)
    return []
# Cache of loaded models/pipelines keyed by (pipeline_type, model_id, dtype),
# so repeated requests reuse an already-initialised model.
loaded_models = {}


def _load_model_or_pipeline(pipeline_type, model_id, dtype, batch_size):
    """Return the cached model/pipeline for these settings, loading and caching it on first use."""
    model_key = (pipeline_type, model_id, dtype)
    if model_key in loaded_models:
        logging.info("Loaded model from cache")
        return loaded_models[model_key]
    if pipeline_type == "faster-batched":
        model = WhisperModel(model_id, device=device, compute_type=dtype)
        model_or_pipeline = BatchedInferencePipeline(model=model)
    elif pipeline_type == "faster-sequenced":
        model_or_pipeline = WhisperModel(model_id, device=device, compute_type=dtype)
    elif pipeline_type == "transformers":
        torch_dtype = torch.float16 if dtype == "float16" else torch.float32
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
        )
        model.to(device)
        processor = AutoProcessor.from_pretrained(model_id)
        model_or_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            chunk_length_s=30,
            batch_size=batch_size,
            return_timestamps=True,
            torch_dtype=torch_dtype,
            device=device,
        )
    else:
        raise ValueError("Invalid pipeline type")
    loaded_models[model_key] = model_or_pipeline
    return model_or_pipeline


def _format_timestamp(seconds):
    """Render a segment boundary as e.g. '12.34s', or '?' when no timestamp was given."""
    return "?" if seconds is None else f"{seconds:.2f}s"


def _format_segment(pipeline_type, segment):
    """Format one segment as '[start -> end] text' followed by a newline."""
    if pipeline_type in ("faster-batched", "faster-sequenced"):
        start, end, text = segment.start, segment.end, segment.text
    else:
        # transformers chunks are dicts with a (start, end) 'timestamp' and 'text'.
        # The end may be None (e.g. when Whisper predicts no ending timestamp).
        (start, end), text = segment['timestamp'], segment['text']
    return f"[{_format_timestamp(start)} -> {_format_timestamp(end)}] {text}\n"


def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time=None, end_time=None, verbose=False):
    """
    Transcribes audio from a given source using the specified pipeline and model.

    Args:
        input_source (str or file): URL of audio, path to local file, or uploaded file object.
        pipeline_type (str): Type of pipeline to use ('faster-batched', 'faster-sequenced', or 'transformers').
        model_id (str): The ID of the model to use.
        dtype (str): Data type for model computations ('int8', 'float16', or 'float32').
        batch_size (int): Batch size for transcription.
        download_method (str): Method to use for downloading audio.
        start_time (float, optional): Start time in seconds for trimming audio.
        end_time (float, optional): End time in seconds for trimming audio.
        verbose (bool, optional): Whether to output verbose logging and progressive updates.

    Yields:
        Tuple[str, str, str or None]: Metrics and messages, transcription text, path to transcription file.
    """
    # Temporary files created for this request (downloads, trimmed copies).
    # Initialised before `try` so `finally` can always clean up, even if the
    # generator is closed at its first yield.
    temp_files = []
    transcription_file = None
    try:
        if verbose:
            logging.getLogger().setLevel(logging.INFO)
        else:
            logging.getLogger().setLevel(logging.WARNING)
        logging.info(f"Transcription parameters: pipeline_type={pipeline_type}, model_id={model_id}, dtype={dtype}, batch_size={batch_size}, download_method={download_method}")
        verbose_messages = f"Starting transcription with parameters:\nPipeline Type: {pipeline_type}\nModel ID: {model_id}\nData Type: {dtype}\nBatch Size: {batch_size}\nDownload Method: {download_method}\n"
        if verbose:
            yield verbose_messages, "", None

        # Determine if input_source is a URL or file
        if isinstance(input_source, str) and input_source.startswith(('http://', 'https://')):
            audio_path, is_temp_file = download_audio(input_source, download_method)
            if not audio_path or audio_path.startswith("Error"):
                yield f"Error downloading audio: {audio_path}", "", None
                return
            if is_temp_file:
                temp_files.append(audio_path)
        elif isinstance(input_source, str) and os.path.exists(input_source):
            audio_path = input_source  # local file: never deleted
        elif hasattr(input_source, 'name'):
            audio_path = input_source.name  # uploaded file object: never deleted
        else:
            yield "No valid audio source provided.", "", None
            return

        # Convert start_time and end_time to float or None (0/empty means "not set")
        start_time = float(start_time) if start_time else None
        end_time = float(end_time) if end_time else None
        if start_time is not None or end_time is not None:
            # Any downloaded original stays in temp_files, so it is cleaned up too.
            audio_path = trim_audio(audio_path, start_time, end_time)
            temp_files.append(audio_path)
            verbose_messages += f"Audio trimmed from {start_time} to {end_time}\n"
            if verbose:
                yield verbose_messages, "", None

        model_or_pipeline = _load_model_or_pipeline(pipeline_type, model_id, dtype, batch_size)

        start_time_perf = time.time()
        if pipeline_type == "faster-batched":
            segments, _ = model_or_pipeline.transcribe(audio_path, batch_size=batch_size)
        elif pipeline_type == "faster-sequenced":
            segments, _ = model_or_pipeline.transcribe(audio_path)
        else:
            # Pass batch_size per call: the cached pipeline keeps the value it was built with.
            result = model_or_pipeline(audio_path, batch_size=batch_size)
            segments = result["chunks"]

        # faster-whisper yields segments lazily (decoding happens during
        # iteration), so the timer is stopped only after the loop.
        transcription = ""
        for segment in segments:
            transcription += _format_segment(pipeline_type, segment)
            if verbose:
                yield verbose_messages, transcription, None
        transcription_time = time.time() - start_time_perf

        audio_file_size = os.path.getsize(audio_path) / (1024 * 1024)
        metrics_output = (
            f"Transcription time: {transcription_time:.2f} seconds\n"
            f"Audio file size: {audio_file_size:.2f} MB\n"
        )
        transcription_file = save_transcription(transcription)
        yield verbose_messages + metrics_output, transcription, transcription_file
    except Exception as e:
        logging.error(f"An error occurred during transcription: {str(e)}")
        yield f"An error occurred: {str(e)}", "", None
    finally:
        # Clean up temporary files
        for path in temp_files:
            if os.path.exists(path):
                os.remove(path)
        if transcription_file and os.path.exists(transcription_file):
            os.remove(transcription_file)
with gr.Blocks() as iface:
    gr.Markdown("# Multi-Pipeline Transcription")
    gr.Markdown("Transcribe audio using multiple pipelines and models.")

    with gr.Row():
        # A Textbox, not a file upload: it takes a direct/YouTube URL or the
        # path of an audio file that exists on the server.
        input_source = gr.Textbox(label="Audio Source (enter a URL/YouTube URL or a local file path)")
        pipeline_type = gr.Dropdown(
            choices=["faster-batched", "faster-sequenced", "transformers"],
            label="Pipeline Type",
            value="faster-batched",
        )
        default_models = get_model_options("faster-batched")
        model_id = gr.Dropdown(
            label="Model",
            choices=default_models,
            value=default_models[0],
        )

    with gr.Row():
        dtype = gr.Dropdown(choices=["int8", "float16", "float32"], label="Data Type", value="int8")
        batch_size = gr.Slider(minimum=1, maximum=32, step=1, value=16, label="Batch Size")
        download_method = gr.Dropdown(
            choices=["yt-dlp", "pytube", "youtube-dl", "yt-dlp-alt", "ffmpeg", "aria2", "wget"],
            label="Download Method",
            value="yt-dlp",
        )

    with gr.Row():
        start_time = gr.Number(label="Start Time (seconds)", value=None, minimum=0)
        end_time = gr.Number(label="End Time (seconds)", value=None, minimum=0)
        verbose = gr.Checkbox(label="Verbose Output", value=True)  # Set to True by default

    transcribe_button = gr.Button("Transcribe")

    with gr.Row():
        metrics_output = gr.Textbox(label="Transcription Metrics and Verbose Messages", lines=10)
        transcription_output = gr.Textbox(label="Transcription", lines=10)
    transcription_file = gr.File(label="Download Transcription")

    def update_model_dropdown(pipeline_type):
        """
        Updates the model dropdown choices based on the selected pipeline type.

        Args:
            pipeline_type (str): The selected pipeline type.

        Returns:
            gr.update: Updated model dropdown component (hidden when no models exist).
        """
        try:
            model_choices = get_model_options(pipeline_type)
            logging.info(f"Model choices for {pipeline_type}: {model_choices}")
            if model_choices:
                return gr.update(choices=model_choices, value=model_choices[0], visible=True)
            return gr.update(choices=["No models available"], value=None, visible=False)
        except Exception as e:
            logging.error(f"Error in update_model_dropdown: {str(e)}")
            return gr.update(choices=["Error"], value="Error", visible=True)

    # Refresh the model list whenever the pipeline type changes.
    pipeline_type.change(update_model_dropdown, inputs=[pipeline_type], outputs=[model_id])

    def transcribe_with_progress(*args):
        """Stream (metrics, transcription, file) updates from transcribe_audio to the UI."""
        yield from transcribe_audio(*args)

    transcribe_button.click(
        transcribe_with_progress,
        inputs=[input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time, end_time, verbose],
        outputs=[metrics_output, transcription_output, transcription_file],
    )

    gr.Examples(
        examples=[
            ["https://www.youtube.com/watch?v=daQ_hqA6HDo", "faster-batched", "cstr/whisper-large-v3-turbo-int8_float32", "int8", 16, "yt-dlp", None, None, True],
            ["https://mcdn.podbean.com/mf/web/dir5wty678b6g4vg/HoP_453_-_The_Price_is_Right_-_Law_and_Economics_in_the_Second_Scholastic5yxzh.mp3", "faster-sequenced", "deepdml/faster-whisper-large-v3-turbo-ct2", "float16", 1, "ffmpeg", 0, 300, True],
            ["path/to/local/audio.mp3", "transformers", "openai/whisper-large-v3", "float16", 16, "yt-dlp", 60, 180, True],
        ],
        inputs=[input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time, end_time, verbose],
    )

iface.launch()