# Hugging Face Space: side-by-side NLLB translation demo
# (Kinyarwanda <-> English; base NLLB-200 vs. a DigitalUmuganda fine-tune)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import torch
# NLLB language codes offered in the UI dropdowns:
# Kinyarwanda (Latin script) and English (Latin script).
LANGS = ["kin_Latn","eng_Latn"]
# Hugging Face pipeline task name used for both models.
TASK = "translation"
# CKPT = "DigitalUmuganda/Finetuned-NLLB"
# Checkpoints compared side by side: the distilled base NLLB model and a
# fine-tuned variant of it.
MODELS = ["facebook/nllb-200-distilled-600M","DigitalUmuganda/Finetuned-NLLB"]
# model = AutoModelForSeq2SeqLM.from_pretrained(CKPT)
# tokenizer = AutoTokenizer.from_pretrained(CKPT)
# First GPU when available, else CPU (-1), per the HF pipeline `device` API.
device = 0 if torch.cuda.is_available() else -1
# Load both checkpoints once at startup (downloads weights on first run).
fb_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
du_model = AutoModelForSeq2SeqLM.from_pretrained("DigitalUmuganda/Finetuned-NLLB")
# NOTE(review): a single tokenizer (from the base checkpoint) is shared by
# both pipelines — assumes the fine-tune kept the base vocabulary; confirm.
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
# Name -> loaded model lookup; not referenced elsewhere in this file.
models = {"facebook/nllb-200-distilled-600M":fb_model,"DigitalUmuganda/Finetuned-NLLB":du_model}
# def translate(text, src_lang, tgt_lang, max_length=400):
def translate_fb(text, src_lang, tgt_lang, max_length=400):
    """Translate ``text`` with the base NLLB-200 distilled model.

    Args:
        text: Source text to translate.
        src_lang: NLLB code of the source language (e.g. ``"kin_Latn"``).
        tgt_lang: NLLB code of the target language (e.g. ``"eng_Latn"``).
        max_length: Maximum length of the generated translation.

    Returns:
        The translated text as a string.
    """
    # Building a pipeline is expensive; the original rebuilt it on every
    # request. Cache one pipeline per (src, tgt, max_length) combination
    # on the function object instead.
    cache = translate_fb.__dict__.setdefault("_pipelines", {})
    key = (src_lang, tgt_lang, max_length)
    if key not in cache:
        cache[key] = pipeline(TASK,
                              tokenizer=tokenizer,
                              src_lang=src_lang,
                              tgt_lang=tgt_lang,
                              model=fb_model,
                              max_length=max_length,
                              device=device)
    result = cache[key](text)
    return result[0]['translation_text']
def translate_du(text, src_lang, tgt_lang, CKPT="DigitalUmuganda/Finetuned-NLLB", max_length=400):
    """Translate ``text`` with the fine-tuned NLLB model.

    Args:
        text: Source text to translate.
        src_lang: NLLB code of the source language (e.g. ``"kin_Latn"``).
        tgt_lang: NLLB code of the target language (e.g. ``"eng_Latn"``).
        CKPT: Checkpoint name. Unused by the body (``du_model`` is loaded at
            module level); defaulted so that Gradio — which wires only three
            input components to this function — can call it without raising
            ``TypeError`` for a missing positional argument.
        max_length: Maximum length of the generated translation.

    Returns:
        The translated text as a string.
    """
    # Cache one pipeline per (src, tgt, max_length) combination instead of
    # rebuilding it on every request.
    cache = translate_du.__dict__.setdefault("_pipelines", {})
    key = (src_lang, tgt_lang, max_length)
    if key not in cache:
        cache[key] = pipeline(TASK,
                              tokenizer=tokenizer,
                              src_lang=src_lang,
                              tgt_lang=tgt_lang,
                              model=du_model,
                              max_length=max_length,
                              device=device)
    result = cache[key](text)
    return result[0]['translation_text']
# Interface for the base (non-fine-tuned) NLLB model.
_fb_inputs = [
    gr.components.Textbox(label="Text"),
    gr.components.Dropdown(label="Source Language", choices=LANGS),
    gr.components.Dropdown(label="Target Language", choices=LANGS),
]
gr_fb = gr.Interface(
    fn=translate_fb,
    inputs=_fb_inputs,
    outputs=['text'],
    cache_examples=False,
    title="nllb-200-distilled-600M",
)
# Interface for the DigitalUmuganda fine-tuned NLLB model.
_du_inputs = [
    gr.components.Textbox(label="Text"),
    gr.components.Dropdown(label="Source Language", choices=LANGS),
    gr.components.Dropdown(label="Target Language", choices=LANGS),
]
gr_du = gr.Interface(
    fn=translate_du,
    inputs=_du_inputs,
    outputs=['text'],
    cache_examples=False,
    title="nllb-200-distilled-600M-Finetuned",
)
# Show both translators side by side so their outputs can be compared
# on the same input.
demo = gr.Parallel(gr_fb, gr_du)
demo.launch()