ashourzadeh7 committed
Commit
cefb4b4
1 Parent(s): 518fac7

Update app.py

Files changed (1)
  app.py +63 -20
app.py CHANGED
@@ -6,36 +6,79 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 from flores200_codes import flores_codes
 
 
-def transfer(input):
-    with open(input, 'r', encoding="utf-8") as f:
-        text = f.read()
-
-    output_file = "out.txt"
-    with open(output_file, 'w', encoding="utf-8") as f:
-        file = f.write(text)
-    return file
-
-
-if __name__ == '__main__':
-
-    #inputs = [gr.inputs.Radio(['nllb-distilled-600M', 'nllb-1.3B', 'nllb-distilled-1.3B'], label='NLLB Model'),
-    inputs = [gr.components.file(label="Input File")]
-
-    outputs = gr.components.file(label="Translated File", value=file)
-
-    title = "NLLB distilled 600M demo"
-
-    demo_status = "Demo is running on CPU"
-    description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
-
-    gr.Interface(translation,
-                 inputs,
-                 outputs,
-                 title=title,
-                 description=description,
-                 ).launch()
-
+def load_models():
+    # build model and tokenizer
+    model_name_dict = {'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
+                       #'nllb-1.3B': 'facebook/nllb-200-1.3B',
+                       #'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
+                       #'nllb-3.3B': 'facebook/nllb-200-3.3B',
+                       }
+
+    model_dict = {}
+
+    for call_name, real_name in model_name_dict.items():
+        print('\tLoading model: %s' % call_name)
+        model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
+        tokenizer = AutoTokenizer.from_pretrained(real_name)
+        model_dict[call_name + '_model'] = model
+        model_dict[call_name + '_tokenizer'] = tokenizer
+
+    return model_dict
+
+
+LANGS = ["pes_Arab", "ckb_Arab", "eng_Latn"]
+# UI language names (Persian, Kurdish, English) mapped to FLORES-200 codes
+langs_dict = {
+    "فارسی": "pes_Arab",
+    "کردی": "ckb_Arab",
+    "انگلیسی": "eng_Latn"
+}
+
+
+def translate(text, src_lang, tgt_lang):
+    """
+    Translate the text from the source language to the target language.
+    """
+    if len(model_dict) == 2:
+        # only the distilled 600M model is loaded above, so use it
+        model_name = 'nllb-distilled-600M'
+    model = model_dict[model_name + '_model']
+    tokenizer = model_dict[model_name + '_tokenizer']
+
+    translation_pipeline = pipeline("translation", model=model, tokenizer=tokenizer, src_lang=langs_dict[src_lang], tgt_lang=langs_dict[tgt_lang], max_length=400, device="cpu")
+    result = translation_pipeline(text)
+    return result[0]['translation_text']
+
+
+def file_translate(source_file_path, pred_file_path):
+    # read the uploaded file line by line
+    source_list = []
+    with open(source_file_path, "r", encoding="utf-8") as source_file:
+        for line in source_file:
+            source_list.append(line.strip())
+
+    # translate every line from Persian to Kurdish
+    pred_list = []
+    for line in source_list:
+        pred_list.append(translate(line, list(langs_dict.keys())[0], list(langs_dict.keys())[1]))
+
+    # fall back to a default name if no output file name was given
+    pred_file_path = pred_file_path or "out.txt"
+    with open(pred_file_path, "w", encoding="utf-8") as output_file:
+        for translation in pred_list:
+            output_file.write(translation + "\n")
+    return pred_file_path
+
+
+if __name__ == '__main__':
+    print('\tinit models')
+
+    global model_dict
+    model_dict = load_models()
+
+    interface = gr.Interface(
+        fn=file_translate,
+        inputs=[
+            gr.components.File(label="Input File"),
+            gr.components.Textbox(label="Output File Name (optional)"),
+        ],
+        outputs=[
+            gr.components.File(label="Translated File"),
+        ],
+        title="NLLB distilled 600M demo",
+        description="Translates each line of the uploaded file from Persian to Kurdish with NLLB. Details: https://github.com/facebookresearch/fairseq/tree/nllb",
+    )
+    interface.launch()
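
For context, a minimal sketch of how the functions introduced in this commit could be driven without the Gradio UI, assuming app.py is importable as a module named app and the model weights can be downloaded; the sample file names and the Persian test sentence are hypothetical:

# Minimal sketch (assumption: this app.py is importable as "app").
import app

app.model_dict = app.load_models()  # translate() reads model_dict from app's module scope

# single-sentence translation using the keys of langs_dict (Persian -> English)
print(app.translate("سلام دنیا", "فارسی", "انگلیسی"))

# line-by-line file translation (Persian -> Kurdish, as hard-coded in file_translate)
with open("sample_fa.txt", "w", encoding="utf-8") as f:  # hypothetical input file
    f.write("سلام دنیا\n")
print(app.file_translate("sample_fa.txt", "sample_ckb.txt"))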