import os
import tempfile

import torch
import gradio as gr
from faster_whisper import WhisperModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import yt_dlp as youtube_dl

# Disable tokenizer parallelism to avoid fork-related warnings from the fast tokenizers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Parameters
FILE_LIMIT_MB = 1000
YT_LENGTH_LIMIT_S = 3600  # 1-hour cap for YouTube videos
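# Note: neither limit is enforced below. A minimal guard (a sketch; assumes it
# runs right before transcription, with `audio_file` being a local path) could be:
#   if os.path.getsize(audio_file) > FILE_LIMIT_MB * 1024 * 1024:
#       raise gr.Error(f"File exceeds the {FILE_LIMIT_MB} MB limit.")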

# Load the FLORES-200 language codes used by NLLB
from flores200_codes import flores_codes
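# flores_codes presumably maps display names to FLORES-200 codes, e.g.
# {'French': 'fra_Latn', 'English': 'eng_Latn', ...}; the exact keys come
# from the local flores200_codes.py module.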

# Pick the compute device: CUDA if available, otherwise CPU.
def set_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = set_device()

# Load the NLLB translation model and tokenizer once at startup.
model_dict = {}
def load_models():
    global model_dict
    if not model_dict:
        model_name_dict = {
            'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
            # Larger checkpoints, disabled to keep memory usage down:
            #'nllb-1.3B': 'facebook/nllb-200-1.3B',
            #'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
            #'nllb-3.3B': 'facebook/nllb-200-3.3B',
        }
        for call_name, real_name in model_name_dict.items():
            model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
            tokenizer = AutoTokenizer.from_pretrained(real_name)
            model_dict[call_name + '_model'] = model
            model_dict[call_name + '_tokenizer'] = tokenizer

load_models()

# Transcribe an audio file with faster-whisper, returning (timestamp, text) pairs.
# Note: the Whisper model is reloaded on every call; for a long-running app it
# would be better to load it once at module level, like the NLLB models above.
def transcribe_audio(audio_file):
    model_size = "large-v2"
    model = WhisperModel(model_size)
    # model = WhisperModel(model_size, device="cuda" if torch.cuda.is_available() else "cpu", compute_type="int8")
    segments, _ = model.transcribe(audio_file, beam_size=1)
    transcriptions = [("[%.2fs -> %.2fs]" % (seg.start, seg.end), seg.text) for seg in segments]
    return transcriptions
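
# Usage sketch (hypothetical file name):
#   for timestamp, text in transcribe_audio("interview.mp3"):
#       print(timestamp, text)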

# Translate a single text segment with NLLB via the transformers pipeline.
# Note: the pipeline object is rebuilt on every call; caching one instance
# would speed up long transcripts.
def traduction(text, source_lang, target_lang):
    model_name = "nllb-distilled-600M"
    model = model_dict[model_name + "_model"]
    tokenizer = model_dict[model_name + "_tokenizer"]
    translator = pipeline("translation", model=model, tokenizer=tokenizer)
    return translator(text, src_lang=flores_codes[source_lang], tgt_lang=flores_codes[target_lang])[0]["translation_text"]
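
# Usage sketch (assumes 'French' and 'English' are keys in flores_codes):
#   traduction("Bonjour tout le monde", "French", "English")  # -> "Hello everyone"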

# Main pipeline: fetch the audio if a URL was given, transcribe, then translate
# each timestamped segment.
def full_transcription_and_translation(audio_file, source_lang, target_lang):
    if audio_file.startswith("http"):
        audio_file = download_yt_audio(audio_file)
    transcriptions = transcribe_audio(audio_file)
    translations = [(timestamp, traduction(text, source_lang, target_lang)) for timestamp, text in transcriptions]
    return transcriptions, translations
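
# Usage sketch (hypothetical path):
#   transcriptions, translations = full_transcription_and_translation(
#       "meeting.wav", "French", "English")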

# Download the audio track of a YouTube video as a local MP3 file.
def download_yt_audio(yt_url):
    # NamedTemporaryFile would delete the file as soon as its `with` block
    # exits, invalidating the returned path; download into a temp directory
    # that outlives this function instead.
    tmp_dir = tempfile.mkdtemp()
    out_path = os.path.join(tmp_dir, "audio")
    ydl_opts = {
        'format': 'bestaudio/best',
        'outtmpl': out_path,
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
            'preferredquality': '192',
        }],
    }
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        ydl.download([yt_url])
    # FFmpegExtractAudio appends the target codec's extension to the output.
    return out_path + ".mp3"
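
# Sketch: enforcing YT_LENGTH_LIMIT_S before downloading (assumes the yt-dlp
# metadata includes 'duration', which it does for regular videos):
#   with youtube_dl.YoutubeDL({}) as ydl:
#       info = ydl.extract_info(yt_url, download=False)
#   if info.get("duration", 0) > YT_LENGTH_LIMIT_S:
#       raise gr.Error("Video exceeds the 1-hour limit.")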

# Language names offered in the UI dropdowns.
lang_codes = list(flores_codes.keys())

# Gradio callback: run the full pipeline and format both outputs as text.
# (URL handling happens inside full_transcription_and_translation, so the
# check is not repeated here.)
def gradio_interface(audio_file, source_lang, target_lang):
    transcriptions, translations = full_transcription_and_translation(audio_file, source_lang, target_lang)
    transcribed_text = '\n'.join(f"{timestamp}: {text}" for timestamp, text in transcriptions)
    translated_text = '\n'.join(f"{timestamp}: {text}" for timestamp, text in translations)
    return transcribed_text, translated_text

iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Audio(type="filepath"),
        gr.Dropdown(lang_codes, value='French', label='Source Language'),
        gr.Dropdown(lang_codes, value='English', label='Target Language'),
    ],
    outputs=[
        gr.Textbox(label="Transcribed Text"),
        gr.Textbox(label="Translated Text")
    ]
)

iface.launch()
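
# Running this script starts a local Gradio server (http://127.0.0.1:7860 by
# default); pass share=True to launch() for a temporary public URL.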