File size: 3,458 Bytes
d1fad9f
e53d944
 
d1fad9f
e53d944
 
45cc880
2c7f6bc
45cc880
 
e53d944
 
051ace5
 
 
e53d944
051ace5
45cc880
8b4c96d
45cc880
bed31b5
 
 
196cc2a
 
bed31b5
 
 
 
8b4c96d
bed31b5
 
 
 
8b4c96d
bed31b5
8b4c96d
bed31b5
8b4c96d
 
 
196cc2a
 
8b4c96d
 
 
 
f444bbb
8b4c96d
 
 
 
 
 
 
 
30e2223
 
 
 
 
 
8b4c96d
 
 
 
 
 
 
5d2ba07
8b4c96d
 
30e2223
 
 
 
 
 
8b4c96d
bed31b5
 
 
8b4c96d
2ea2408
8b4c96d
2ea2408
 
 
30e2223
 
 
 
 
 
bed31b5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import torch

# Language codes offered in the source/target dropdowns of the UI.
LANGS = ["kin_Latn", "eng_Latn"]
# Task name handed to transformers.pipeline().
TASK = "translation"

# Checkpoint ids, named once so the model list, the loaders and the lookup
# table below cannot drift apart.
_BASE_CKPT = "facebook/nllb-200-distilled-600M"
_FINETUNED_CKPT = "DigitalUmuganda/Finetuned-NLLB"
MODELS = [_BASE_CKPT, _FINETUNED_CKPT]

# Value passed as the pipelines' `device`: 0 when CUDA is available, else -1.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# Both models are loaded once at import time; the tokenizer comes from the
# base checkpoint and is shared by both translation functions.
fb_model = AutoModelForSeq2SeqLM.from_pretrained(_BASE_CKPT)
du_model = AutoModelForSeq2SeqLM.from_pretrained(_FINETUNED_CKPT)
tokenizer = AutoTokenizer.from_pretrained(_BASE_CKPT)

# Checkpoint id -> loaded model.
models = dict(zip(MODELS, (fb_model, du_model)))
# def translate(text, src_lang, tgt_lang, max_length=400):
def translate_fb(text, src_lang, tgt_lang, max_length=400):
    """
    Translate ``text`` from ``src_lang`` to ``tgt_lang`` using ``fb_model``.

    Args:
        text: The text to translate.
        src_lang: Source language code forwarded to the pipeline.
        tgt_lang: Target language code forwarded to the pipeline.
        max_length: Forwarded to the pipeline as ``max_length``.

    Returns:
        The ``'translation_text'`` entry of the first pipeline result.
    """
    print("fb src_lang: ", src_lang)
    print("fb dest_lang: ", tgt_lang)

    # The language pair is bound when the pipeline is constructed, so a
    # pipeline is built for every request.
    pipeline_options = {
        "model": fb_model,
        "tokenizer": tokenizer,
        "src_lang": src_lang,
        "tgt_lang": tgt_lang,
        "max_length": max_length,
        "device": device,
    }
    translator = pipeline(TASK, **pipeline_options)

    outputs = translator(text)
    first_output = outputs[0]
    return first_output['translation_text']

def translate_du(text, src_lang, tgt_lang, CKPT=None, max_length=400):
    """
    Translate the text from source lang to target lang using ``du_model``.

    Args:
        text: The text to translate.
        src_lang: Source language code forwarded to the pipeline.
        tgt_lang: Target language code forwarded to the pipeline.
        CKPT: Not read by this function; the model used is always
            ``du_model``. The parameter is kept (now optional) so existing
            callers that pass it keep working. It must be optional because
            the Gradio interface built on this function supplies only three
            inputs (text, source language, target language); a required
            fourth positional argument made every UI call fail with a
            ``TypeError``.
        max_length: Forwarded to the pipeline as ``max_length``.

    Returns:
        The ``'translation_text'`` entry of the first pipeline result.
    """
    print("du src_lang: ", src_lang)
    print("du tgt_lang: ", tgt_lang)

    # The language pair is bound when the pipeline is constructed, so a
    # pipeline is built for every request.
    translation_pipeline = pipeline(
        TASK,
        model=du_model,
        tokenizer=tokenizer,
        src_lang=src_lang,
        tgt_lang=tgt_lang,
        max_length=max_length,
        device=device,
    )

    result = translation_pipeline(text)
    return result[0]['translation_text']

# Inputs for the base-model interface: free text plus a source and a target
# language chosen from LANGS.
_fb_inputs = [
    gr.components.Textbox(label="Text"),
    gr.components.Dropdown(label="Source Language", choices=LANGS),
    gr.components.Dropdown(label="Target Language", choices=LANGS),
]

# Interface wrapping translate_fb; max_length is not exposed, so the
# function's default is used.
gr_fb = gr.Interface(
    fn=translate_fb,
    inputs=_fb_inputs,
    outputs=['text'],
    cache_examples=False,
    title="nllb-200-distilled-600M",
)

# Inputs for the fine-tuned-model interface: free text plus a source and a
# target language chosen from LANGS.
_du_inputs = [
    gr.components.Textbox(label="Text"),
    gr.components.Dropdown(label="Source Language", choices=LANGS),
    gr.components.Dropdown(label="Target Language", choices=LANGS),
]

# Interface wrapping translate_du; only the three inputs above are passed to
# the function.
gr_du = gr.Interface(
    fn=translate_du,
    inputs=_du_inputs,
    outputs=['text'],
    cache_examples=False,
    title="nllb-200-distilled-600M-Finetuned",
)
# Combine the base and fine-tuned interfaces with gr.Parallel and start the app.
# NOTE(review): gr.Parallel appears to exist only in older Gradio releases --
# confirm the pinned gradio version before upgrading.
_comparison_app = gr.Parallel(gr_fb, gr_du)
_comparison_app.launch()