language-demo / app.py
sheonhan's picture
clean up
dfe3477
raw
history blame
2.82 kB
import requests
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import torch
# Heading shown at the top of the Gradio interface.
title = "Community Tab Language Detection & Translation"

# Text passed to the interface as its description. The value begins and ends
# with a newline.
description = (
    "\n"
    "When comments are created in the community tab, detect the language of the content.\n"
    "Then, if the detected language is different from the user's language, display an option to translate it.\n"
)
# URL that translate_from_api() posts to.
TRANSLATION_API_URL = "https://api-inference.huggingface.co/models/t5-base"
# URL that query() posts to for language identification.
LANG_ID_API_URL = "https://noe30ht5sav83xm1.us-east-1.aws.endpoints.huggingface.cloud"

# The access token comes from the environment only. A literal token must never
# be committed to source control; the value that used to be hard-coded here
# should be treated as leaked and revoked.
ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")

# Send an Authorization header only when a token is configured, instead of
# the meaningless "Bearer None" that an unset variable would otherwise produce.
headers = {"Authorization": f"Bearer {ACCESS_TOKEN}"} if ACCESS_TOKEN else {}
# One checkpoint name is used for both the model and its tokenizer.
_NLLB_CHECKPOINT = "facebook/nllb-200-distilled-600M"
model = AutoModelForSeq2SeqLM.from_pretrained(_NLLB_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_NLLB_CHECKPOINT)

# Device index later handed to pipeline(): 0 when CUDA is available, else -1.
device = -1
if torch.cuda.is_available():
    device = 0
print(f"Is CUDA available: {torch.cuda.is_available()}")
# Maps the language names offered in the UI to the codes that translate()
# passes to the translation pipeline as src_lang / tgt_lang.
language_code_map = dict(
    English="eng_Latn",
    French="fra_Latn",
    German="deu_Latn",
    Spanish="spa_Latn",
    Korean="kor_Hang",
    Japanese="jpn_Jpan",
)
def translate_from_api(text, timeout=120):
    """Translate *text* by posting it to ``TRANSLATION_API_URL``.

    Args:
        text: String sent as the request's ``inputs``.
        timeout: Seconds to wait for the HTTP response before giving up.
            Without a timeout a stalled connection would block forever.

    Returns:
        The ``translation_text`` field of the first item in the JSON response.

    Raises:
        requests.HTTPError: If the API answers with an error status code.
        requests.Timeout: If no response arrives within *timeout* seconds.
    """
    payload = {
        "inputs": text,
        # The Inference API documents these flags under "options"; sent as
        # top-level keys they are not part of the documented request shape.
        "options": {"wait_for_model": True, "use_cache": True},
    }
    response = requests.post(
        TRANSLATION_API_URL, headers=headers, json=payload, timeout=timeout)
    # Fail with a clear HTTP error instead of an opaque KeyError/TypeError
    # when the body is an error object rather than a list of translations.
    response.raise_for_status()
    return response.json()[0]['translation_text']
def translate(text, src_lang, tgt_lang):
    """Translate *text* with the locally loaded model and tokenizer.

    ``src_lang`` and ``tgt_lang`` are looked up in ``language_code_map``; a
    name missing from the map raises ``KeyError``. A translation pipeline is
    built for the language pair on every call.

    Returns:
        The ``translation_text`` field of the first pipeline result.
    """
    src_lang_code, tgt_lang_code = (
        language_code_map[src_lang],
        language_code_map[tgt_lang],
    )
    print(f"src: {src_lang_code} tgt: {tgt_lang_code}")

    pipeline_kwargs = dict(
        model=model,
        tokenizer=tokenizer,
        src_lang=src_lang_code,
        tgt_lang=tgt_lang_code,
        device=device,
    )
    translator = pipeline("translation", **pipeline_kwargs)
    outputs = translator(text)
    return outputs[0]['translation_text']
def query(text, src_lang, tgt_lang):
    """Translate *text* and ask the language-identification endpoint about it.

    Args:
        text: The input string.
        src_lang: Source language name, forwarded to ``translate``.
        tgt_lang: Target language name, forwarded to ``translate``.

    Returns:
        A two-item list ``[lang_id, translation]`` where ``lang_id`` is the
        first element of the endpoint's JSON response and ``translation`` is
        the string returned by ``translate``.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status code.
        requests.Timeout: If the endpoint does not respond in time.
    """
    translation = translate(text, src_lang, tgt_lang)
    lang_id_response = requests.post(
        LANG_ID_API_URL,
        headers=headers,
        # NOTE(review): the two flags are sent as top-level keys, unchanged
        # from before; confirm the endpoint actually honours them there.
        json={"inputs": text, "wait_for_model": True, "use_cache": True},
        # Bound the wait so an unresponsive endpoint cannot hang the handler.
        timeout=120,
    )
    # Surface HTTP failures directly rather than as an opaque KeyError or
    # TypeError when indexing an error body below.
    lang_id_response.raise_for_status()
    lang_id = lang_id_response.json()[0]
    return [lang_id, translation]
# Sample rows for the interface, each ordered as (text, source, target).
_example_rows = (
    ("Hello, world", "English", "French"),
    ("Can I have a cheeseburger?", "English", "German"),
    ("Hasta la vista", "Spanish", "German"),
    ("동경에 휴가를 간다", "Korean", "Japanese"),
)
# Stored as a list of lists, one inner list per example row.
examples = [list(row) for row in _example_rows]
# Input widgets, in the order query() receives them: text, source, target.
demo_inputs = [
    gr.Textbox(lines=2),
    gr.Radio(["English", "Spanish", "Korean"], value="English", label="Source Language"),
    gr.Radio(["French", "German", "Japanese"], value="French", label="Target Language"),
]

# Output widgets, matching the two-item list returned by query().
demo_outputs = [
    gr.Textbox(lines=3, label="Detected Language"),
    gr.Textbox(lines=3, label="Translation"),
]

# Assemble the interface and start it.
demo = gr.Interface(
    fn=query,
    inputs=demo_inputs,
    outputs=demo_outputs,
    title=title,
    description=description,
    examples=examples,
)
demo.launch()