# NOTE: page-scrape residue, not code. The banner read "Spaces: Runtime error".
import torch | |
import gradio as gr | |
from faster_whisper import WhisperModel | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
from pydub import AudioSegment | |
import yt_dlp as youtube_dl | |
import tempfile | |
from transformers.pipelines.audio_utils import ffmpeg_read | |
from gradio.components import Audio, Dropdown, Radio, Textbox | |
import os | |
# Set TOKENIZERS_PARALLELISM to "false" for this process.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

# Settings
FILE_LIMIT_MB = 1_000
YT_LENGTH_LIMIT_S = 60 * 60  # 1-hour limit for YouTube videos
# Charger les codes de langue | |
from flores200_codes import flores_codes | |
# Helper that picks the compute device
def set_device():
    """Return the CUDA ``torch.device`` when CUDA is available, else the CPU one."""
    backend = "cpu"
    if torch.cuda.is_available():
        backend = "cuda"
    return torch.device(backend)


device = set_device()
# Load the translation models a single time
model_dict = {}


def load_models():
    """Populate ``model_dict`` with the translation model and tokenizer.

    For every configured checkpoint the model is stored under
    ``"<call_name>_model"`` and its tokenizer under ``"<call_name>_tokenizer"``.
    If ``model_dict`` already holds entries the function returns immediately,
    so nothing is loaded twice.
    """
    if model_dict:
        return

    # Only the distilled 600M checkpoint is enabled. The original listed
    # 'facebook/nllb-200-distilled-1.3B', 'facebook/nllb-200-1.3B' and
    # 'facebook/nllb-200-3.3B' as commented-out alternatives.
    checkpoints = {
        'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
    }
    for alias, repo_id in checkpoints.items():
        loaded_model = AutoModelForSeq2SeqLM.from_pretrained(repo_id)
        loaded_tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model_dict[alias + '_model'] = loaded_model
        model_dict[alias + '_tokenizer'] = loaded_tokenizer


load_models()
# Transcription
# Cache of already-constructed Whisper models, keyed by model size, so the
# weights are loaded once per process instead of once per request.
_whisper_models = {}


def _get_whisper_model(model_size):
    """Return the cached ``WhisperModel`` for ``model_size``, creating it on first use."""
    if model_size not in _whisper_models:
        _whisper_models[model_size] = WhisperModel(model_size)
    return _whisper_models[model_size]


def transcribe_audio(audio_file, model=None):
    """Transcribe ``audio_file`` and return timestamped segments.

    Args:
        audio_file: Value forwarded unchanged to ``model.transcribe``.
        model: Optional object exposing ``transcribe(audio_file, beam_size=...)``.
            When omitted, a cached ``WhisperModel("large-v2")`` is used.

    Returns:
        A list of ``(timestamp, text)`` tuples, where ``timestamp`` is
        formatted as ``"[<start>s -> <end>s]"`` with two decimals.
    """
    if model is None:
        model = _get_whisper_model("large-v2")
    segments, _ = model.transcribe(audio_file, beam_size=1)
    return [
        (f"[{seg.start:.2f}s -> {seg.end:.2f}s]", seg.text)
        for seg in segments
    ]
# Translation
# Lazily-built translation pipeline, shared by every call. This function is
# invoked once per transcribed segment, so rebuilding the pipeline each time
# was repeated, loop-invariant work.
_translator = None


def _get_translator():
    """Return the shared translation pipeline, building it on first use."""
    global _translator
    if _translator is None:
        model_name = "nllb-distilled-600M"
        _translator = pipeline(
            "translation",
            model=model_dict[model_name + "_model"],
            tokenizer=model_dict[model_name + "_tokenizer"],
        )
    return _translator


def traduction(text, source_lang, target_lang):
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    Args:
        text: Text to translate.
        source_lang: Key of ``flores_codes`` naming the source language.
        target_lang: Key of ``flores_codes`` naming the target language.

    Returns:
        The ``"translation_text"`` field of the first pipeline result.

    Raises:
        KeyError: If either language name is not a key of ``flores_codes``.
    """
    result = _get_translator()(
        text,
        src_lang=flores_codes[source_lang],
        tgt_lang=flores_codes[target_lang],
    )
    return result[0]["translation_text"]
# Main entry point
def full_transcription_and_translation(audio_file, source_lang, target_lang):
    """Transcribe an audio source and translate every segment.

    When ``audio_file`` starts with ``"http"`` it is first passed to
    ``download_yt_audio`` and the returned path is used instead.

    Returns:
        A ``(transcriptions, translations)`` pair; both are lists of
        ``(timestamp, text)`` tuples in the same order.
    """
    local_path = download_yt_audio(audio_file) if audio_file.startswith("http") else audio_file
    transcriptions = transcribe_audio(local_path)

    translations = []
    for timestamp, text in transcriptions:
        translated = traduction(text, source_lang, target_lang)
        translations.append((timestamp, translated))
    return transcriptions, translations
# YouTube audio download
def download_yt_audio(yt_url):
    """Download the audio of ``yt_url`` as MP3 and return the file path.

    The file is written into a fresh temporary directory that is NOT removed
    by this function; the caller owns it and may delete it once finished.

    Args:
        yt_url: URL handed to ``YoutubeDL.download``.

    Returns:
        Path of the extracted ``audio.mp3`` inside the temporary directory.
    """
    # A NamedTemporaryFile would be deleted as soon as its ``with`` block
    # exits, so the returned path would no longer exist. A pre-created empty
    # file at the output path can also make the downloader treat the media as
    # already downloaded. A dedicated directory avoids both problems.
    workdir = tempfile.mkdtemp(prefix="yt_audio_")
    ydl_opts = {
        'format': 'bestaudio/best',
        # Let the downloader choose the source extension; the post-processor
        # below then produces ``audio.mp3`` next to it.
        'outtmpl': os.path.join(workdir, 'audio.%(ext)s'),
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
            'preferredquality': '192',
        }],
    }
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        ydl.download([yt_url])
    return os.path.join(workdir, 'audio.mp3')
# Language names offered in the dropdowns (iterating a mapping yields its keys).
lang_codes = list(flores_codes)
# Gradio interface
def gradio_interface(audio_file, source_lang, target_lang):
    """Gradio callback: transcribe and translate, returning two text blobs.

    Args:
        audio_file: File path or URL of the audio, or ``None``/empty when the
            form was submitted without audio.
        source_lang: Source language name.
        target_lang: Target language name.

    Returns:
        ``(transcribed_text, translated_text)``, each with one
        ``"<timestamp>: <text>"`` line per segment. Both are empty strings
        when no audio was supplied.
    """
    # Submitting with no audio yields None; calling ``.startswith`` on it
    # would raise AttributeError, so return empty outputs instead.
    if not audio_file:
        return "", ""

    # URL inputs are resolved inside full_transcription_and_translation, so
    # no separate download step is needed here.
    transcriptions, translations = full_transcription_and_translation(
        audio_file, source_lang, target_lang
    )
    transcribed_text = '\n'.join(f"{timestamp}: {text}" for timestamp, text in transcriptions)
    translated_text = '\n'.join(f"{timestamp}: {text}" for timestamp, text in translations)
    return transcribed_text, translated_text
# Build the UI: one audio input, two language dropdowns, and two text outputs
# fed by ``gradio_interface``.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Audio(type="filepath"),
        gr.Dropdown(choices=lang_codes, value="French", label="Source Language"),
        gr.Dropdown(choices=lang_codes, value="English", label="Target Language"),
    ],
    outputs=[
        gr.Textbox(label="Transcribed Text"),
        gr.Textbox(label="Translated Text"),
    ],
)

# Start serving the interface.
iface.launch()