yvankob commited on
Commit
06397e6
·
1 Parent(s): 03da8e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -19
app.py CHANGED
@@ -28,23 +28,50 @@ pipe = pipeline(
28
  )
29
 
30
 
31
- def load_translation_model():
32
- model_name = 'facebook/nllb-200-distilled-1.3B'
33
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
34
- tokenizer = AutoTokenizer.from_pretrained(model_name)
35
- return model, tokenizer
36
-
37
- translation_model, translation_tokenizer = load_translation_model()
38
-
39
-
40
- def translate_text(text, source_language, target_language):
41
- source_code = flores_codes[source_language]
42
- target_code = flores_codes[target_language]
43
-
44
- translator = pipeline('translation', model=translation_model, tokenizer=translation_tokenizer, src_lang=source_code, tgt_lang=target_code)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  output = translator(text, max_length=400)
46
- return output[0]['translation_text']
47
 
 
 
 
 
 
 
 
 
48
 
49
 
50
  def transcribe(inputs, task):
@@ -52,7 +79,7 @@ def transcribe(inputs, task):
52
  raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
53
 
54
  text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
55
- translated_text = translation(source_lang, target_lang, text)["result"]
56
  return text, translated_text
57
 
58
 
@@ -109,16 +136,16 @@ def yt_transcribe(yt_url, task, max_filesize=75.0):
109
  inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
110
 
111
  text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
112
- translated_text = translation(source_lang, target_lang, text)["result"]
113
 
114
  return html_embed_str, text, translated_text
115
 
116
 
117
- lang_codes = list(flores_codes.keys())
118
-
119
 
120
  demo = gr.Blocks()
121
 
 
 
122
  mf_transcribe = gr.Interface(
123
  fn=transcribe,
124
  inputs=[
 
28
  )
29
 
30
 
31
+ def load_models():
32
+ # build model and tokenizer
33
+ model_name_dict = {
34
+ 'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
35
+ #'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
36
+ #'nllb-1.3B': 'facebook/nllb-200-1.3B',
37
+ #'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
38
+ #'nllb-3.3B': 'facebook/nllb-200-3.3B',
39
+ # 'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
40
+ }
41
+
42
+ model_dict = {}
43
+
44
+ for call_name, real_name in model_name_dict.items():
45
+ print('\tLoading model: %s' % call_name)
46
+ model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
47
+ tokenizer = AutoTokenizer.from_pretrained(real_name)
48
+ model_dict[call_name+'_model'] = model
49
+ model_dict[call_name+'_tokenizer'] = tokenizer
50
+
51
+ return model_dict
52
+
53
def translation(source, target, text):
    """Translate `text` between two languages given by human-readable names.

    Args:
        source: source-language display name; must be a key of `flores_codes`.
        target: target-language display name; must be a key of `flores_codes`.
        text: the string to translate.

    Returns:
        dict with keys 'inference_time' (seconds), 'source' and 'target'
        (the resolved FLORES-200 codes), and 'result' (translated text).

    Raises:
        KeyError: if a language name is missing from `flores_codes`.
    """
    # BUGFIX: model_name was previously assigned only inside
    # `if len(model_dict) == 2:`, so any other registry size raised
    # NameError below. Only one checkpoint is registered — select it
    # unconditionally.
    model_name = 'nllb-distilled-1.3B'

    start_time = time.time()
    # Map display names to the FLORES-200 codes NLLB expects.
    source = flores_codes[source]
    target = flores_codes[target]

    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']

    # NOTE(review): a fresh pipeline is built per request because the
    # src/tgt languages vary; caching per (source, target) pair would
    # cut per-call latency.
    translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=source, tgt_lang=target)
    output = translator(text, max_length=400)

    end_time = time.time()

    output = output[0]['translation_text']
    result = {'inference_time': end_time - start_time,
              'source': source,
              'target': target,
              'result': output}
    return result
75
 
76
 
77
  def transcribe(inputs, task):
 
79
  raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
80
 
81
  text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
82
+ translated_text = translation(source, target, text)["result"]
83
  return text, translated_text
84
 
85
 
 
136
  inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
137
 
138
  text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
139
+ translated_text = translation(source, target, text)["result"]
140
 
141
  return html_embed_str, text, translated_text
142
 
143
 
 
 
144
 
145
  demo = gr.Blocks()
146
 
147
+ lang_codes = list(flores_codes.keys())
148
+
149
  mf_transcribe = gr.Interface(
150
  fn=transcribe,
151
  inputs=[