File size: 2,921 Bytes
476e166
 
a7f2f12
8b8b295
476e166
a7f2f12
 
 
 
 
 
 
 
476e166
8b8b295
476e166
 
 
a7f2f12
 
 
dfe3477
a7f2f12
 
 
 
 
 
 
 
 
 
8b8b295
 
 
 
 
 
8262899
8b8b295
 
 
ceaa373
a7f2f12
 
 
 
8b8b295
a7f2f12
 
 
 
 
 
 
 
 
 
 
8b8b295
 
a7f2f12
8b8b295
ceaa373
 
f54a85e
 
dfe3477
f54a85e
 
 
 
476e166
ceaa373
a7f2f12
 
f54a85e
 
a7f2f12
ceaa373
 
 
 
476e166
f54a85e
 
97e52ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import functools
import os

import fasttext
import gradio as gr
import requests
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Text shown at the top of the Gradio page.
title = "Community Tab Language Detection & Translation"
description = """
When comments are created in the community tab, detect the language of the content.
Then, if the detected language is different from the user's language, display an option to translate it.
"""

# Remote language-identification endpoint. The bearer token comes from the
# ACCESS_TOKEN environment variable; if it is unset the header value becomes
# the literal string "Bearer None".
LANG_ID_API_URL = "https://q5esh83u7boq5qwd.us-east-1.aws.endpoints.huggingface.cloud"
ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")
headers = {"Authorization": "Bearer " + str(ACCESS_TOKEN)}

# NLLB-200 distilled (600M) checkpoint used for translation. Weights and
# tokenizer are loaded once, at import time, and shared by every request.
_NLLB_CHECKPOINT = "facebook/nllb-200-distilled-600M"
model = AutoModelForSeq2SeqLM.from_pretrained(_NLLB_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_NLLB_CHECKPOINT)

# transformers pipeline() device convention: 0 = first CUDA device, -1 = CPU.
_cuda_ok = torch.cuda.is_available()
device = 0 if _cuda_ok else -1
print(f"Is CUDA available: {_cuda_ok}")

# UI display name -> language code in the form NLLB-200 expects
# (language + script, e.g. "eng_Latn"). Used for both source and target.
language_code_map = dict(
    English="eng_Latn",
    French="fra_Latn",
    German="deu_Latn",
    Spanish="spa_Latn",
    Korean="kor_Hang",
    Japanese="jpn_Jpan",
)

@functools.lru_cache(maxsize=1)
def _load_lang_id_model():
    """Load the fastText language-ID model ``lid218e.bin`` once and reuse it.

    The model file is looked up in the same directory as this script.
    """
    model_full_path = os.path.join(os.path.dirname(__file__), "lid218e.bin")
    return fasttext.load_model(model_full_path)


def identify_language(text):
    """Return the language code fastText predicts for *text*, e.g. ``"eng_Latn"``.

    The fastText model is loaded on first use and cached, rather than being
    re-read from disk on every request.

    fastText's ``predict`` refuses input containing newline characters (it
    processes one line at a time). The input textbox allows multiple lines,
    so newlines are replaced with spaces before prediction.
    """
    lang_id_model = _load_lang_id_model()
    single_line = text.replace("\n", " ")
    # predict returns e.g. (('__label__eng_Latn',), array([0.81148803]))
    labels, _scores = lang_id_model.predict(single_line, k=1)
    label = labels[0]
    prefix = "__label__"
    return label[len(prefix):] if label.startswith(prefix) else label
    


def translate(text, src_lang, tgt_lang):
    """Translate *text* from *src_lang* to *tgt_lang* using the NLLB model.

    *src_lang* and *tgt_lang* are display names (keys of ``language_code_map``);
    an unknown name raises ``KeyError``. Returns the translated string.
    """
    translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=language_code_map[src_lang],
        tgt_lang=language_code_map[tgt_lang],
        device=device,
    )
    outputs = translator(text)
    first_output = outputs[0]
    return first_output["translation_text"]


def query(text, src_lang, tgt_lang):
    """Gradio callback: detect the language of *text* and translate it.

    Returns ``[detected_language_code, translation]``, matching the two
    output textboxes of the interface.
    """
    translation = translate(text, src_lang, tgt_lang)
    # The detected language comes from the local fastText model. The remote
    # LANG_ID_API_URL endpoint is deliberately not called here: its response
    # was never used. Calling it only added a network round-trip per request
    # and a failure point (a bad token or non-list JSON reply raised) without
    # changing the result.
    language_code = identify_language(text)
    return [language_code, translation]


# Clickable example rows: [input text, source language, target language],
# in the same order as the Interface inputs.
examples = [
    list(row)
    for row in (
        ("Hello, world", "English", "French"),
        ("Can I have a cheeseburger?", "English", "German"),
        ("Hasta la vista", "Spanish", "German"),
        ("동경에 휴가를 간다", "Korean", "Japanese"),
    )
]

# Build the UI once and bind it to ``demo`` (the name Gradio's tooling looks
# for). Launch only when run as a script, so importing this module does not
# start a server.
demo = gr.Interface(
    fn=query,
    inputs=[
        gr.Textbox(lines=2),
        gr.Radio(["English", "Spanish", "Korean"], value="English", label="Source Language"),
        gr.Radio(["French", "German", "Japanese"], value="French", label="Target Language"),
    ],
    outputs=[
        gr.Textbox(lines=3, label="Detected Language"),
        gr.Textbox(lines=3, label="Translation"),
    ],
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    demo.launch()