Kleber commited on
Commit
051ace5
1 Parent(s): 5d2ba07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -10,22 +10,23 @@ MODELS = ["facebook/nllb-200-distilled-600M","DigitalUmuganda/Finetuned-NLLB"]
10
  # tokenizer = AutoTokenizer.from_pretrained(CKPT)
11
 
12
  device = 0 if torch.cuda.is_available() else -1
 
 
 
13
 
14
-
15
  # def translate(text, src_lang, tgt_lang, max_length=400):
16
  def translate(text, src_lang, tgt_lang, CKPT, max_length=400):
17
 
18
  """
19
  Translate the text from source lang to target lang
20
  """
21
- model = AutoModelForSeq2SeqLM.from_pretrained(CKPT)
22
- tokenizer = AutoTokenizer.from_pretrained(CKPT)
23
-
24
  translation_pipeline = pipeline(TASK,
25
  tokenizer=tokenizer,
26
  src_lang=src_lang,
27
  tgt_lang=tgt_lang,
28
- model = model,
29
  max_length=max_length,
30
  device=device)
31
 
 
10
  # tokenizer = AutoTokenizer.from_pretrained(CKPT)
11
 
12
  device = 0 if torch.cuda.is_available() else -1
13
+ fb_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
14
+ du_model = AutoModelForSeq2SeqLM.from_pretrained("DigitalUmuganda/Finetuned-NLLB")
15
+ tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
16
 
17
+ models = {"facebook/nllb-200-distilled-600M":fb_model,"DigitalUmuganda/Finetuned-NLLB":du_model}
18
  # def translate(text, src_lang, tgt_lang, max_length=400):
19
  def translate(text, src_lang, tgt_lang, CKPT, max_length=400):
20
 
21
  """
22
  Translate the text from source lang to target lang
23
  """
24
+
 
 
25
  translation_pipeline = pipeline(TASK,
26
  tokenizer=tokenizer,
27
  src_lang=src_lang,
28
  tgt_lang=tgt_lang,
29
+ model = models[CKPT],
30
  max_length=max_length,
31
  device=device)
32