Kleber commited on
Commit
45cc880
1 Parent(s): bed31b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -36
app.py CHANGED
@@ -4,18 +4,23 @@ import torch
4
 
5
  LANGS = ["kin_Latn","eng_Latn"]
6
  TASK = "translation"
7
- CKPT = "DigitalUmuganda/Finetuned-NLLB"
8
-
9
- model = AutoModelForSeq2SeqLM.from_pretrained(CKPT)
10
- tokenizer = AutoTokenizer.from_pretrained(CKPT)
11
 
12
  device = 0 if torch.cuda.is_available() else -1
13
 
14
 
15
- def translate(text, src_lang, tgt_lang, max_length=400):
 
 
16
  """
17
  Translate the text from source lang to target lang
18
  """
 
 
 
19
  translation_pipeline = pipeline(TASK,
20
  model=model,
21
  tokenizer=tokenizer,
@@ -31,6 +36,7 @@ def translate(text, src_lang, tgt_lang, max_length=400):
31
  gr.Interface(
32
  translate,
33
  [
 
34
  gr.components.Textbox(label="Text"),
35
  gr.components.Dropdown(label="Source Language", choices=LANGS),
36
  gr.components.Dropdown(label="Target Language", choices=LANGS),
@@ -44,34 +50,3 @@ gr.Interface(
44
  #description=description
45
  ).launch()
46
 
47
- def translate(text, src_lang, tgt_lang, max_length=400):
48
- """
49
- Translate the text from source lang to target lang
50
- """
51
- translation_pipeline = pipeline(TASK,
52
- model=model,
53
- tokenizer=tokenizer,
54
- src_lang=src_lang,
55
- tgt_lang=tgt_lang,
56
- max_length=max_length,
57
- device=device)
58
-
59
- result = translation_pipeline(text)
60
- return result[0]['translation_text']
61
-
62
-
63
- gr.Interface(
64
- translate,
65
- [
66
- gr.components.Textbox(label="Text"),
67
- gr.components.Dropdown(label="Source Language", choices=LANGS),
68
- gr.components.Dropdown(label="Target Language", choices=LANGS),
69
- #gr.components.Slider(8, 512, value=400, step=8, label="Max Length")
70
- ],
71
- ["text"],
72
- #examples=examples,
73
- # article=article,
74
- cache_examples=False,
75
- title="Finetuned-NLLB",
76
- #description=description
77
- ).launch()
 
4
 
5
  LANGS = ["kin_Latn","eng_Latn"]
6
  TASK = "translation"
7
+ # CKPT = "DigitalUmuganda/Finetuned-NLLB"
8
+ models = ["facebook/nllb-200-distilled-600M","DigitalUmuganda/Finetuned-NLLB"]
9
+ # model = AutoModelForSeq2SeqLM.from_pretrained(CKPT)
10
+ # tokenizer = AutoTokenizer.from_pretrained(CKPT)
11
 
12
  device = 0 if torch.cuda.is_available() else -1
13
 
14
 
15
+ # def translate(text, src_lang, tgt_lang, max_length=400):
16
+ def translate(model,text, src_lang, tgt_lang, max_length=400):
17
+
18
  """
19
  Translate the text from source lang to target lang
20
  """
21
+ model = AutoModelForSeq2SeqLM.from_pretrained(model)
22
+ tokenizer = AutoTokenizer.from_pretrained(model)
23
+
24
  translation_pipeline = pipeline(TASK,
25
  model=model,
26
  tokenizer=tokenizer,
 
36
  gr.Interface(
37
  translate,
38
  [
39
+ gr.components.Dropdown(label="choose a model",choices=models)
40
  gr.components.Textbox(label="Text"),
41
  gr.components.Dropdown(label="Source Language", choices=LANGS),
42
  gr.components.Dropdown(label="Target Language", choices=LANGS),
 
50
  #description=description
51
  ).launch()
52