import functools
import os
import tempfile

import gradio as gr
import torch
import yt_dlp as youtube_dl
from faster_whisper import WhisperModel
from gradio.components import Audio, Dropdown, Radio, Textbox
from pydub import AudioSegment
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers.pipelines.audio_utils import ffmpeg_read
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Settings
FILE_LIMIT_MB = 1_000  # upload size limit in megabytes (not referenced in the code shown here)
YT_LENGTH_LIMIT_S = 60 * 60  # maximum YouTube video length: one hour, in seconds

# Charger les codes de langue
from flores200_codes import flores_codes

# Fonction pour déterminer le device
def set_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = set_device()

# Registry of translation models, filled once at import time.
# Keys are "<short name>_model" and "<short name>_tokenizer".
model_dict = {}


def load_models():
    """Load every configured NLLB checkpoint into ``model_dict``.

    Each checkpoint contributes a ``<short name>_model`` entry and a
    ``<short name>_tokenizer`` entry. Does nothing if ``model_dict`` is already
    populated, so repeated calls never reload.
    """
    if model_dict:
        return
    # Alternative checkpoints (disabled), short name -> Hugging Face id:
    #   'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B'
    #   'nllb-1.3B':           'facebook/nllb-200-1.3B'
    #   'nllb-3.3B':           'facebook/nllb-200-3.3B'
    checkpoints = {
        'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
    }
    for short_name, hub_id in checkpoints.items():
        # The model is loaded before the tokenizer, as in the original ordering.
        model_dict.update({
            f"{short_name}_model": AutoModelForSeq2SeqLM.from_pretrained(hub_id),
            f"{short_name}_tokenizer": AutoTokenizer.from_pretrained(hub_id),
        })


load_models()

# Transcription
@functools.lru_cache(maxsize=1)
def _get_whisper_model(model_size):
    """Load a faster-whisper model once and reuse it for later calls with the same size."""
    # device / compute_type are left at faster-whisper's defaults, as in the original code.
    # An explicit variant would be: WhisperModel(model_size, device="cuda", compute_type="int8").
    return WhisperModel(model_size)


def transcribe_audio(audio_file, model_size="large-v2", beam_size=1):
    """Transcribe an audio file into timestamped segments.

    Args:
        audio_file: Path to the audio file to transcribe.
        model_size: faster-whisper model name (default "large-v2").
        beam_size: Beam width used for decoding (default 1, i.e. greedy).

    Returns:
        A list of ``(timestamp_label, text)`` tuples, one per segment, where the
        label looks like ``"[0.00s -> 2.50s]"``.
    """
    # Previously the multi-GB model was reloaded on every request; it is now cached.
    model = _get_whisper_model(model_size)
    segments, _info = model.transcribe(audio_file, beam_size=beam_size)
    # `segments` is lazy: iterating it here is what actually runs the transcription.
    # Whisper segment text usually starts with a space, so it is stripped.
    return [
        (f"[{seg.start:.2f}s -> {seg.end:.2f}s]", seg.text.strip())
        for seg in segments
    ]

# Translation
@functools.lru_cache(maxsize=None)
def _get_translator(model_name):
    """Build the translation pipeline for a preloaded model once, then reuse it.

    ``model_name`` must have been loaded into ``model_dict`` as
    ``<model_name>_model`` / ``<model_name>_tokenizer``.
    """
    return pipeline(
        "translation",
        model=model_dict[model_name + "_model"],
        tokenizer=model_dict[model_name + "_tokenizer"],
    )


def traduction(text, source_lang, target_lang, model_name="nllb-distilled-600M"):
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    Args:
        text: Text to translate (typically one transcription segment).
        source_lang: Language name; a key of ``flores_codes``.
        target_lang: Language name; a key of ``flores_codes``.
        model_name: Short name of a model preloaded into ``model_dict``.

    Returns:
        The translated string. Empty or whitespace-only input returns ``""``
        without invoking the model.

    Raises:
        KeyError: if a language name is not in ``flores_codes`` or the model is
            not loaded.
    """
    if not text or not text.strip():
        return ""
    # This function runs once per segment, so the pipeline is cached rather
    # than rebuilt on every call.
    translator = _get_translator(model_name)
    result = translator(
        text,
        src_lang=flores_codes[source_lang],
        tgt_lang=flores_codes[target_lang],
    )
    return result[0]["translation_text"]

# Main pipeline
def full_transcription_and_translation(audio_file, source_lang, target_lang):
    """Transcribe an audio file (or YouTube URL) and translate every segment.

    Args:
        audio_file: Local audio path, or an http(s) URL, which is downloaded
            first with ``download_yt_audio``.
        source_lang: Source language name (a key of ``flores_codes``).
        target_lang: Target language name (a key of ``flores_codes``).

    Returns:
        ``(transcriptions, translations)``: two lists of ``(timestamp_label, text)``
        tuples, aligned segment by segment.

    Raises:
        ValueError: if ``audio_file`` is empty or ``None`` (e.g. nothing was submitted).
    """
    if not audio_file:
        raise ValueError("No audio file or URL was provided.")
    # Match the URL scheme only, so a local file named "http..." is not mistaken for a URL.
    if audio_file.startswith(("http://", "https://")):
        audio_file = download_yt_audio(audio_file)
    transcriptions = transcribe_audio(audio_file)
    translations = [
        (timestamp, traduction(text, source_lang, target_lang))
        for timestamp, text in transcriptions
    ]
    return transcriptions, translations

# YouTube audio download
def download_yt_audio(yt_url):
    """Download the audio track of a YouTube video as MP3 and return its path.

    Args:
        yt_url: URL of the video.

    Returns:
        Path to the downloaded ``audio.mp3`` inside a fresh temporary directory.
        The caller owns the file; it is not deleted automatically.

    Raises:
        gr.Error: if the video is longer than ``YT_LENGTH_LIMIT_S`` seconds.
    """
    # The previous NamedTemporaryFile approach was broken in two ways:
    # - the file was deleted when its `with` block exited, so the returned path
    #   did not exist;
    # - yt-dlp saw a pre-existing empty file and could skip the download.
    # A fresh directory with an %(ext)s template avoids both problems.
    # TODO(review): these temp directories are never cleaned up.
    out_dir = tempfile.mkdtemp(prefix="yt_audio_")
    ydl_opts = {
        'format': 'bestaudio/best',
        'outtmpl': os.path.join(out_dir, 'audio.%(ext)s'),
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
            'preferredquality': '192',
        }],
    }
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        # Enforce the configured length limit before downloading anything.
        info = ydl.extract_info(yt_url, download=False)
        duration = info.get("duration") if info else None
        if duration is not None and duration > YT_LENGTH_LIMIT_S:
            raise gr.Error(
                f"Video is too long: {duration:.0f}s (maximum allowed: {YT_LENGTH_LIMIT_S}s)."
            )
        ydl.download([yt_url])
    # FFmpegExtractAudio re-encodes to mp3, so the final file is audio.mp3.
    return os.path.join(out_dir, 'audio.mp3')

# Language names shown in the dropdowns: the keys of flores_codes, in insertion order.
lang_codes = list(flores_codes)

# Interface Gradio
def gradio_interface(audio_file, source_lang, target_lang):
    if audio_file.startswith("http"):
        audio_file = download_yt_audio(audio_file)
    transcriptions, translations = full_transcription_and_translation(audio_file, source_lang, target_lang)
    transcribed_text = '\n'.join([f"{timestamp}: {text}" for timestamp, text in transcriptions])
    translated_text = '\n'.join([f"{timestamp}: {text}" for timestamp, text in translations])
    return transcribed_text, translated_text

# Build the UI: an audio input plus source/target language pickers, and two
# text outputs (timestamped transcription and its translation). Then launch it.
_ui_inputs = [
    gr.Audio(type="filepath"),
    gr.Dropdown(lang_codes, value='French', label='Source Language'),
    gr.Dropdown(lang_codes, value='English', label='Target Language'),
]
_ui_outputs = [
    gr.Textbox(label="Transcribed Text"),
    gr.Textbox(label="Translated Text"),
]
iface = gr.Interface(fn=gradio_interface, inputs=_ui_inputs, outputs=_ui_outputs)

iface.launch()