ashourzadeh7 committed on
Commit
cc91c93
1 Parent(s): 307f292

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -35
app.py CHANGED
@@ -1,29 +1,12 @@
1
- import os
2
- import torch
3
- import gradio as gr
4
- import time
5
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
- from flores200_codes import flores_codes
7
-
8
 
9
- def load_models():
10
- # build model and tokenizer
11
- model_name_dict = {'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
12
- #'nllb-1.3B': 'facebook/nllb-200-1.3B',
13
- #'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
14
- #'nllb-3.3B': 'facebook/nllb-200-3.3B',
15
- }
16
-
17
- model_dict = {}
18
 
19
- for call_name, real_name in model_name_dict.items():
20
- print('\tLoading model: %s' % call_name)
21
- model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
22
- tokenizer = AutoTokenizer.from_pretrained(real_name)
23
- model_dict[call_name+'_model'] = model
24
- model_dict[call_name+'_tokenizer'] = tokenizer
25
 
26
- return model_dict
27
 
28
  LANGS = ["pes_Arab", "ckb_Arab", "eng_Latn"]
29
  langs_dict = {
@@ -36,13 +19,7 @@ def translate(text, src_lang, tgt_lang):
36
  """
37
  Translate the text from source lang to target lang
38
  """
39
-
40
- if len(model_dict) == 2:
41
- model_name = 'nllb-3.3B'
42
- model = model_dict[model_name + '_model']
43
- tokenizer = model_dict[model_name + '_tokenizer']
44
-
45
- translation_pipeline = pipeline("translation", model=model, tokenizer=tokenizer, src_lang=langs_dict[src_lang], tgt_lang=langs_dict[tgt_lang], max_length=400, device="cpu")
46
  result = translation_pipeline(text)
47
  return result[0]['translation_text']
48
 
@@ -76,12 +53,7 @@ def add_line(input_path, output_path):
76
  return output_path
77
 
78
  if __name__ == '__main__':
79
- print('\tinit models')
80
-
81
- #global model_dict
82
 
83
- #model_dict = load_models()
84
-
85
  interface = gr.Interface(
86
  fn=file_translate,
87
  inputs=[
 
 
 
 
 
 
 
 
1
 
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
+ import torch
 
 
 
 
 
 
 
4
 
5
+ # this model was loaded from https://hf.co/models
6
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
7
+ model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-1.3B").to(device)
8
+ tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-1.3B")
 
 
9
 
 
10
 
11
  LANGS = ["pes_Arab", "ckb_Arab", "eng_Latn"]
12
  langs_dict = {
 
19
  """
20
  Translate the text from source lang to target lang
21
  """
22
+ translation_pipeline = pipeline("translation", model=model, tokenizer=tokenizer, src_lang=langs_dict[src_lang], tgt_lang=langs_dict[tgt_lang], max_length=400, device=device)
 
 
 
 
 
 
23
  result = translation_pipeline(text)
24
  return result[0]['translation_text']
25
 
 
53
  return output_path
54
 
55
  if __name__ == '__main__':
 
 
 
56
 
 
 
57
  interface = gr.Interface(
58
  fn=file_translate,
59
  inputs=[