Kleber commited on
Commit
8f6ad47
1 Parent(s): 937b21a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
+ import torch
4
+
5
+ LANGS = ["kin_Latn","eng_Latn"]
6
+ TASK = "translation"
7
+ # CKPT = "DigitalUmuganda/Finetuned-NLLB"
8
+ # MODELS = ["facebook/nllb-200-distilled-600M","DigitalUmuganda/Finetuned-NLLB"]
9
+ # model = AutoModelForSeq2SeqLM.from_pretrained(CKPT)
10
+ # tokenizer = AutoTokenizer.from_pretrained(CKPT)
11
+
12
+ device = 0 if torch.cuda.is_available() else -1
13
+
14
+ #general_model = AutoModelForSeq2SeqLM.from_pretrained("mbazaNLP/Nllb_finetuned_general_en_kin")
15
+ #education_model = AutoModelForSeq2SeqLM.from_pretrained("mbazaNLP/Nllb_finetuned_education_en_kin")
16
+ tourism_model = AutoModelForSeq2SeqLM.from_pretrained("mbazaNLP/Nllb_finetuned_tourism_en_kin")
17
+ #MODELS = {"General model":general_model_model,"Education model":education_model,"Tourism model":tourism_model}
18
+ #MODELS = {"Education model":education_model,"Tourism model":tourism_model}
19
+
20
+
21
+ tokenizer = AutoTokenizer.from_pretrained("mbazaNLP/Nllb_finetuned_general_en_kin")
22
+ # def translate(text, src_lang, tgt_lang, max_length=400):
23
+
24
+ TASK = "translation"
25
+
26
+
27
+ device = 0 if torch.cuda.is_available() else -1
28
+
29
+
30
+
31
+ def translate(text, source_lang, target_lang, max_length=400):
32
+ """
33
+ Translate text from source language to target language
34
+ """
35
+ # src_lang = choose_language(source_lang)
36
+ # tgt_lang= choose_language(target_lang)
37
+ # if src_lang==None:
38
+ # return "Error: the source langage is incorrect"
39
+ # elif tgt_lang==None:
40
+ # return "Error: the target language is incorrect"
41
+
42
+ translation_pipeline = pipeline(TASK,
43
+ model=tourism_model,
44
+ tokenizer=tokenizer,
45
+ src_lang=source_lang,
46
+ tgt_lang=target_lang,
47
+ max_length=max_length,
48
+ device=device)
49
+ result = translation_pipeline(text)
50
+ return result[0]['translation_text']
51
+
52
+
53
+ gradio_ui= gr.Interface(
54
+ fn=translate,
55
+ title="NLLB-Tourism EN-KIN Translation Demo",
56
+ inputs= [
57
+ gr.components.Textbox(label="Text"),
58
+ gr.components.Dropdown(label="Source Language", choices=LANGS),
59
+ gr.components.Dropdown(label="Target Language", choices=LANGS),
60
+ # gr.components.Slider(8, 400, value=400, step=8, label="Max Length")
61
+ ],
62
+ outputs=gr.outputs.Textbox(label="Translated text")
63
+ )
64
+
65
+ gradio_ui.launch()