datnth1709 committed
Commit 85cd50e · 1 Parent(s): 765e08f
Files changed (2):
  1. app.py +46 -107
  2. app_old.py +0 -362
app.py CHANGED
@@ -1,7 +1,6 @@
 import gradio as gr
 import nltk
 import librosa
-import soundfile as sf
 from transformers import pipeline
 from transformers.file_utils import cached_path, hf_bucket_url
 import os, zipfile
@@ -80,7 +79,9 @@ def speech2text_vi(audio):
 """English speech2text"""
 nltk.download("punkt")
 # Loading the model and the tokenizer
-eng_asr = pipeline("automatic-speech-recognition", "facebook/wav2vec2-base-960h")
+model_name = "facebook/wav2vec2-base-960h"
+eng_tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
+eng_model = Wav2Vec2ForCTC.from_pretrained(model_name)
 
 def load_data(input_file):
     """ Function for resampling to ensure that the speech input is sampled at 16KHz.
@@ -93,7 +94,7 @@ def load_data(input_file):
     # Resampling at 16KHz since wav2vec2-base-960h is pretrained and fine-tuned on speech audio sampled at 16 KHz.
     if sample_rate != 16000:
         speech = librosa.resample(speech, sample_rate, 16000)
-    return speech, sample_rate
+    return speech
 
 def correct_casing(input_sentence):
     """ This function is for correcting the casing of the generated transcribed text
@@ -105,10 +106,18 @@ def correct_casing(input_sentence):
 def speech2text_en(input_file):
     """This function generates transcripts for the provided audio input
     """
-    speech, samplerate = load_data(input_file)
+    speech = load_data(input_file)
     # Tokenize
-    text = eng_asr(speech)["text"]
-    return text
+    input_values = eng_tokenizer(speech, return_tensors="pt").input_values
+    # Take logits
+    logits = eng_model(input_values).logits
+    # Take argmax
+    predicted_ids = torch.argmax(logits, dim=-1)
+    # Get the words from predicted word ids
+    transcription = eng_tokenizer.decode(predicted_ids[0])
+    # Output is all upper case
+    transcription = correct_casing(transcription.lower())
+    return transcription
 
 
 """Machine translation"""
@@ -138,33 +147,6 @@ def inference_envi(audio):
     return en_text, vi_text
 
 def transcribe_vi(audio, state_vi="", state_en=""):
-    ds = speech_file_to_array_fn(audio.name)
-    # infer model
-    input_values = processor(
-        ds["speech"],
-        sampling_rate=ds["sampling_rate"],
-        return_tensors="pt"
-    ).input_values
-    # decode ctc output
-    logits = vi_model(input_values).logits[0]
-    pred_ids = torch.argmax(logits, dim=-1)
-    greedy_search_output = processor.decode(pred_ids)
-    beam_search_output = ngram_lm_model.decode(logits.cpu().detach().numpy(), beam_width=500)
-    state_vi += beam_search_output + " "
-    en_text = translate_vi2en(beam_search_output)
-    state_en += en_text + " "
-    return state_vi, state_en
-
-def transcribe_en(audio, state_en="", state_vi=""):
-    speech, samplerate = load_data(audio)
-    # Tokenize
-    transcription = eng_asr(speech)["text"]
-    state_en += transcription + " "
-    vi_text = translate_en2vi(transcription)
-    state_vi += vi_text + " "
-    return state_en, state_vi
-
-def transcribe_vi_rm(audio, state_vi="", state_en=""):
     ds = speech_file_to_array_fn(audio.name)
     # infer model
     input_values = processor(
@@ -182,41 +164,23 @@ def transcribe_vi_rm(audio, state_vi="", state_en=""):
     state_en += en_text + " "
     return state_vi, state_en, state_vi, state_en
 
-def transcribe_en_rm(audio, state_en="", state_vi=""):
-    speech, samplerate = load_data(audio)
+def transcribe_en(audio, state_en="", state_vi=""):
+    speech = load_data(audio)
     # Tokenize
-    transcription = eng_asr(speech)["text"]
+    input_values = eng_tokenizer(speech, return_tensors="pt").input_values
+    # Take logits
+    logits = eng_model(input_values).logits
+    # Take argmax
+    predicted_ids = torch.argmax(logits, dim=-1)
+    # Get the words from predicted word ids
+    transcription = eng_tokenizer.decode(predicted_ids[0])
+    # Output is all upper case
+    transcription = correct_casing(transcription.lower())
     state_en += transcription + " "
     vi_text = translate_en2vi(transcription)
     state_vi += vi_text + " "
     return state_en, state_vi, state_en, state_vi
 
-def transcribe_vi_rd(audio, state=""):
-    ds = speech_file_to_array_fn(audio.name)
-    # infer model
-    input_values = processor(
-        ds["speech"],
-        sampling_rate=ds["sampling_rate"],
-        return_tensors="pt"
-    ).input_values
-    # decode ctc output
-    logits = vi_model(input_values).logits[0]
-    pred_ids = torch.argmax(logits, dim=-1)
-    greedy_search_output = processor.decode(pred_ids)
-    beam_search_output = ngram_lm_model.decode(logits.cpu().detach().numpy(), beam_width=500)
-    en_text = translate_vi2en(beam_search_output)
-    state += en_text + " "
-    return state, state
-
-def transcribe_en_rd(audio, state=""):
-    speech, samplerate = load_data(audio)
-    # Tokenize
-    transcription = eng_asr(speech)["text"]
-    transcription = correct_casing(transcription.lower())
-    vi_text = translate_en2vi(transcription)
-    state += vi_text + " "
-    return state, state
-
 """Gradio demo"""
 
 vi_example_text = ["Có phải bạn đang muốn tìm mua nhà ở ngoại ô thành phố Hồ Chí Minh không?",
@@ -243,39 +207,27 @@ with gr.Blocks() as demo:
             translate_button_vien_1.click(lambda text: translate_vi2en(text), inputs=vietnamese_text, outputs=english_out_1)
             gr.Examples(examples=vi_example_text,
                         inputs=[vietnamese_text])
-
         with gr.TabItem("Speech2text and Vi-En Translation"):
             with gr.Row():
                 with gr.Column():
-                    vi_audio_1 = gr.Audio(source="microphone", label="Input Vietnamese Audio", type="file", streaming=False)
+                    vi_audio = gr.Audio(source="microphone", label="Input Vietnamese Audio", type="file", streaming=False)
                     translate_button_vien_2 = gr.Button(value="Translate To English")
                 with gr.Column():
                     speech2text_vi1 = gr.Textbox(label="Vietnamese Text")
                     english_out_2 = gr.Textbox(label="English Text")
-            translate_button_vien_2.click(lambda vi_voice: inference_vien(vi_voice), inputs=vi_audio_1, outputs=[speech2text_vi1, english_out_2])
+
+            translate_button_vien_2.click(lambda vi_voice: inference_vien(vi_voice), inputs=vi_audio, outputs=[speech2text_vi1, english_out_2])
             gr.Examples(examples=vi_example_voice,
-                        inputs=[vi_audio_1])
-
+                        inputs=[vi_audio])
         with gr.TabItem("Vi-En Realtime Translation"):
-            # with gr.Row():
-            #     with gr.Column():
-            #         vi_audio_2 = gr.Audio(source="microphone", label="Input Vietnamese Audio", type="file", streaming=True)
-            #     with gr.Column():
-            #         speech2text_vi2 = gr.Textbox(label="Vietnamese Text")
-            #         english_out_3 = gr.Textbox(label="English Text")
-            # vi_audio_2.change(transcribe_vi, [vi_audio_2, speech2text_vi2, english_out_3], [speech2text_vi2, english_out_3])
-
-            gr.Interface(
-                fn=transcribe_vi_rd,
-                inputs=[
-                    gr.Audio(source="microphone", type="file", streaming=True),
-                    "state"
-                ],
-                outputs=[
-                    "textbox",
-                    "state"
-                ],
-                live=True).launch()
+            with gr.Row():
+                with gr.Column():
+                    vi_audio = gr.Audio(source="microphone", label="Input Vietnamese Audio", type="file", streaming=True)
+                    translate_button_vien_2 = gr.Button(value="Translate To English")
+                with gr.Column():
+                    speech2text_vi2 = gr.Textbox(label="Vietnamese Text")
+                    english_out_3 = gr.Textbox(label="English Text")
+            vi_audio.change(transcribe_vi, [vi_audio, "state_vi", "state_en"], [speech2text_vi2, english_out_3, "state_vi", "state_en"])
 
 
     with gr.Tabs():
@@ -303,27 +255,14 @@ with gr.Blocks() as demo:
                         inputs=[en_audio_1])
 
         with gr.TabItem("En-Vi Realtime Translation"):
-            # with gr.Row():
-            #     with gr.Column():
-            #         en_audio_2 = gr.Audio(source="microphone", label="Input English Audio", type="filepath", streaming=True)
-            #     with gr.Column():
-            #         speech2text_en2 = gr.Textbox(label="English Text")
-            #         vietnamese_out_3 = gr.Textbox(label="Vietnamese Text")
-            # en_audio_2.change(transcribe_en, [en_audio_2, speech2text_en2, vietnamese_out_3], [speech2text_en2, vietnamese_out_3])
-            # speech2text_en2, vietnamese_out_3 = transcribe_en(en_audio_2, speech2text_en2, vietnamese_out_3)
-
-            gr.Interface(
-                fn=transcribe_en_rd,
-                inputs=[
-                    gr.Audio(source="microphone", type="filepath", streaming=True),
-                    "state"
-                ],
-                outputs=[
-                    "textbox",
-                    "state"
-                ],
-                live=True).launch()
-
+            with gr.Row():
+                with gr.Column():
+                    en_audio_2 = gr.Audio(source="microphone", label="Input English Audio", type="filepath", streaming=True)
+                    # translate_button_envi_2 = gr.Button(value="Translate To Vietnamese")
+                with gr.Column():
+                    speech2text_en2 = gr.Textbox(label="English Text")
+                    vietnamese_out_3 = gr.Textbox(label="Vietnamese Text")
+            en_audio_2.change(transcribe_en, [en_audio_2, "state_en", "state_vi"], [speech2text_en2, vietnamese_out_3, "state_en", "state_vi"])
 
 if __name__ == "__main__":
     demo.launch()
app_old.py DELETED
@@ -1,362 +0,0 @@
-import gradio as gr
-import nltk
-import librosa
-import soundfile as sf
-from transformers import pipeline
-from transformers.file_utils import cached_path, hf_bucket_url
-import os, zipfile
-from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2Tokenizer
-from datasets import load_dataset
-import torch
-import kenlm
-import torchaudio
-from pyctcdecode import Alphabet, BeamSearchDecoderCTC, LanguageModel
-
-"""Vietnamese speech2text"""
-cache_dir = './cache/'
-processor = Wav2Vec2Processor.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h", cache_dir=cache_dir)
-vi_model = Wav2Vec2ForCTC.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h", cache_dir=cache_dir)
-lm_file = hf_bucket_url("nguyenvulebinh/wav2vec2-base-vietnamese-250h", filename='vi_lm_4grams.bin.zip')
-lm_file = cached_path(lm_file, cache_dir=cache_dir)
-with zipfile.ZipFile(lm_file, 'r') as zip_ref:
-    zip_ref.extractall(cache_dir)
-lm_file = cache_dir + 'vi_lm_4grams.bin'
-
-def get_decoder_ngram_model(tokenizer, ngram_lm_path):
-    vocab_dict = tokenizer.get_vocab()
-    sort_vocab = sorted((value, key) for (key, value) in vocab_dict.items())
-    vocab = [x[1] for x in sort_vocab][:-2]
-    vocab_list = vocab
-    # convert ctc blank character representation
-    vocab_list[tokenizer.pad_token_id] = ""
-    # replace special characters
-    vocab_list[tokenizer.unk_token_id] = ""
-    # vocab_list[tokenizer.bos_token_id] = ""
-    # vocab_list[tokenizer.eos_token_id] = ""
-    # convert space character representation
-    vocab_list[tokenizer.word_delimiter_token_id] = " "
-    # specify ctc blank char index, since conventially it is the last entry of the logit matrix
-    alphabet = Alphabet.build_alphabet(vocab_list, ctc_token_idx=tokenizer.pad_token_id)
-    lm_model = kenlm.Model(ngram_lm_path)
-    decoder = BeamSearchDecoderCTC(alphabet,
-                                   language_model=LanguageModel(lm_model))
-    return decoder
-ngram_lm_model = get_decoder_ngram_model(processor.tokenizer, lm_file)
-
-# define function to read in sound file
-def speech_file_to_array_fn(path, max_seconds=10):
-    batch = {"file": path}
-    speech_array, sampling_rate = torchaudio.load(batch["file"])
-    if sampling_rate != 16000:
-        transform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
-                                                   new_freq=16000)
-        speech_array = transform(speech_array)
-    speech_array = speech_array[0]
-    if max_seconds > 0:
-        speech_array = speech_array[:max_seconds*16000]
-    batch["speech"] = speech_array.numpy()
-    batch["sampling_rate"] = 16000
-    return batch
-
-# tokenize
-def speech2text_vi(audio):
-    # read in sound file
-    # load dummy dataset and read soundfiles
-    ds = speech_file_to_array_fn(audio.name)
-    # infer model
-    input_values = processor(
-        ds["speech"],
-        sampling_rate=ds["sampling_rate"],
-        return_tensors="pt"
-    ).input_values
-    # decode ctc output
-    logits = vi_model(input_values).logits[0]
-    pred_ids = torch.argmax(logits, dim=-1)
-    greedy_search_output = processor.decode(pred_ids)
-    beam_search_output = ngram_lm_model.decode(logits.cpu().detach().numpy(), beam_width=500)
-    return beam_search_output
-
-
-"""English speech2text"""
-nltk.download("punkt")
-# Loading the model and the tokenizer
-model_name = "facebook/wav2vec2-base-960h"
-eng_tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
-eng_model = Wav2Vec2ForCTC.from_pretrained(model_name)
-
-def load_data(input_file):
-    """ Function for resampling to ensure that the speech input is sampled at 16KHz.
-    """
-    # read the file
-    speech, sample_rate = librosa.load(input_file)
-    # make it 1-D
-    if len(speech.shape) > 1:
-        speech = speech[:, 0] + speech[:, 1]
-    # Resampling at 16KHz since wav2vec2-base-960h is pretrained and fine-tuned on speech audio sampled at 16 KHz.
-    if sample_rate != 16000:
-        speech = librosa.resample(speech, sample_rate, 16000)
-    return speech, sample_rate
-
-def correct_casing(input_sentence):
-    """ This function is for correcting the casing of the generated transcribed text
-    """
-    sentences = nltk.sent_tokenize(input_sentence)
-    return (' '.join([s.replace(s[0], s[0].capitalize(), 1) for s in sentences]))
-
-
-def speech2text_en(input_file):
-    """This function generates transcripts for the provided audio input
-    """
-    speech, samplerate = load_data(input_file)
-    # Tokenize
-    input_values = eng_tokenizer(speech, sampling_rate = samplerate, return_tensors="pt").input_values
-    # Take logits
-    logits = eng_model(input_values).logits
-    # Take argmax
-    predicted_ids = torch.argmax(logits, dim=-1)
-    # Get the words from predicted word ids
-    transcription = eng_tokenizer.decode(predicted_ids[0])
-    # Output is all upper case
-    transcription = correct_casing(transcription.lower())
-    return transcription
-
-
-"""Machine translation"""
-vien_model_checkpoint = "datnth1709/finetuned_HelsinkiNLP-opus-mt-vi-en_PhoMT"
-envi_model_checkpoint = "datnth1709/finetuned_HelsinkiNLP-opus-mt-en-vi_PhoMT"
-vien_translator = pipeline("translation", model=vien_model_checkpoint)
-envi_translator = pipeline("translation", model=envi_model_checkpoint)
-
-def translate_vi2en(Vietnamese):
-    return vien_translator(Vietnamese)[0]['translation_text']
-
-def translate_en2vi(English):
-    return envi_translator(English)[0]['translation_text']
-
-
-
-
-""" Inference"""
-def inference_vien(audio):
-    vi_text = speech2text_vi(audio)
-    en_text = translate_vi2en(vi_text)
-    return vi_text, en_text
-
-def inference_envi(audio):
-    en_text = speech2text_en(audio)
-    vi_text = translate_en2vi(en_text)
-    return en_text, vi_text
-
-def transcribe_vi(audio, state_vi="", state_en=""):
-    ds = speech_file_to_array_fn(audio.name)
-    # infer model
-    input_values = processor(
-        ds["speech"],
-        sampling_rate=ds["sampling_rate"],
-        return_tensors="pt"
-    ).input_values
-    # decode ctc output
-    logits = vi_model(input_values).logits[0]
-    pred_ids = torch.argmax(logits, dim=-1)
-    greedy_search_output = processor.decode(pred_ids)
-    beam_search_output = ngram_lm_model.decode(logits.cpu().detach().numpy(), beam_width=500)
-    state_vi += beam_search_output + " "
-    en_text = translate_vi2en(beam_search_output)
-    state_en += en_text + " "
-    return state_vi, state_en
-
-def transcribe_en(audio, state_en="", state_vi=""):
-    speech, samplerate = load_data(audio)
-    # Tokenize
-    input_values = eng_tokenizer(speech, sampling_rate = samplerate, return_tensors="pt").input_values
-    # Take logits
-    logits = eng_model(input_values).logits
-    # Take argmax
-    predicted_ids = torch.argmax(logits, dim=-1)
-    # Get the words from predicted word ids
-    transcription = eng_tokenizer.decode(predicted_ids[0])
-    # Output is all upper case
-    transcription = correct_casing(transcription.lower())
-    state_en += transcription + " "
-    vi_text = translate_en2vi(transcription)
-    state_vi += vi_text + " "
-    return state_en, state_vi
-
-def transcribe_vi_rm(audio, state_vi="", state_en=""):
-    ds = speech_file_to_array_fn(audio.name)
-    # infer model
-    input_values = processor(
-        ds["speech"],
-        sampling_rate=ds["sampling_rate"],
-        return_tensors="pt"
-    ).input_values
-    # decode ctc output
-    logits = vi_model(input_values).logits[0]
-    pred_ids = torch.argmax(logits, dim=-1)
-    greedy_search_output = processor.decode(pred_ids)
-    beam_search_output = ngram_lm_model.decode(logits.cpu().detach().numpy(), beam_width=500)
-    state_vi += beam_search_output + " "
-    en_text = translate_vi2en(beam_search_output)
-    state_en += en_text + " "
-    return state_vi, state_en, state_vi, state_en
-
-def transcribe_en_rm(audio, state_en="", state_vi=""):
-    speech, samplerate = load_data(audio)
-    # Tokenize
-    input_values = eng_tokenizer(speech, sampling_rate = samplerate, return_tensors="pt").input_values
-    # Take logits
-    logits = eng_model(input_values).logits
-    # Take argmax
-    predicted_ids = torch.argmax(logits, dim=-1)
-    # Get the words from predicted word ids
-    transcription = eng_tokenizer.decode(predicted_ids[0])
-    # Output is all upper case
-    transcription = correct_casing(transcription.lower())
-    state_en += transcription + " "
-    vi_text = translate_en2vi(transcription)
-    state_vi += vi_text + " "
-    return state_en, state_vi, state_en, state_vi
-
-def transcribe_vi_rd(audio, state=""):
-    ds = speech_file_to_array_fn(audio.name)
-    # infer model
-    input_values = processor(
-        ds["speech"],
-        sampling_rate=ds["sampling_rate"],
-        return_tensors="pt"
-    ).input_values
-    # decode ctc output
-    logits = vi_model(input_values).logits[0]
-    pred_ids = torch.argmax(logits, dim=-1)
-    greedy_search_output = processor.decode(pred_ids)
-    beam_search_output = ngram_lm_model.decode(logits.cpu().detach().numpy(), beam_width=500)
-    en_text = translate_vi2en(beam_search_output)
-    state += en_text + " "
-    return state, state
-
-def transcribe_en_rd(audio, state=""):
-    speech, samplerate = load_data(audio)
-    # Tokenize
-    input_values = eng_tokenizer(speech, sampling_rate = samplerate, return_tensors="pt").input_values
-    # Take logits
-    logits = eng_model(input_values).logits
-    # Take argmax
-    predicted_ids = torch.argmax(logits, dim=-1)
-    # Get the words from predicted word ids
-    transcription = eng_tokenizer.decode(predicted_ids[0])
-    # Output is all upper case
-    transcription = correct_casing(transcription.lower())
-    vi_text = translate_en2vi(transcription)
-    state += vi_text + " "
-    return state, state
-
-"""Gradio demo"""
-
-vi_example_text = ["Có phải bạn đang muốn tìm mua nhà ở ngoại ô thành phố Hồ Chí Minh không?",
-                   "Ánh mắt ta chạm nhau. Chỉ muốn ngắm anh lâu thật lâu.",
-                   "Nếu như một câu nói có thể khiến em vui."]
-vi_example_voice =[['vi_speech_01.wav'], ['vi_speech_02.wav'], ['vi_speech_03.wav']]
-
-en_example_text = ["According to a study by Statista, the global AI market is set to grow up to 54 percent every single year.",
-                   "As one of the world's greatest cities, Air New Zealand is proud to add the Big Apple to its list of 29 international destinations.",
-                   "And yet, earlier this month, I found myself at Halloween Horror Nights at Universal Orlando Resort, one of the most popular Halloween events in the US among hardcore horror buffs."
-                   ]
-en_example_voice =[['en_speech_01.wav'], ['en_speech_02.wav'], ['en_speech_03.wav']]
-
-
-with gr.Blocks() as demo:
-    with gr.Tabs():
-        with gr.TabItem("Translation: Vietnamese to English"):
-            with gr.Row():
-                with gr.Column():
-                    vietnamese_text = gr.Textbox(label="Vietnamese Text")
-                    translate_button_vien_1 = gr.Button(value="Translate To English")
-                with gr.Column():
-                    english_out_1 = gr.Textbox(label="English Text")
-            translate_button_vien_1.click(lambda text: translate_vi2en(text), inputs=vietnamese_text, outputs=english_out_1)
-            gr.Examples(examples=vi_example_text,
-                        inputs=[vietnamese_text])
-
-        with gr.TabItem("Speech2text and Vi-En Translation"):
-            with gr.Row():
-                with gr.Column():
-                    vi_audio_1 = gr.Audio(source="microphone", label="Input Vietnamese Audio", type="file", streaming=False)
-                    translate_button_vien_2 = gr.Button(value="Translate To English")
-                with gr.Column():
-                    speech2text_vi1 = gr.Textbox(label="Vietnamese Text")
-                    english_out_2 = gr.Textbox(label="English Text")
-            translate_button_vien_2.click(lambda vi_voice: inference_vien(vi_voice), inputs=vi_audio_1, outputs=[speech2text_vi1, english_out_2])
-            gr.Examples(examples=vi_example_voice,
-                        inputs=[vi_audio_1])
-
-        with gr.TabItem("Vi-En Realtime Translation"):
-            # with gr.Row():
-            #     with gr.Column():
-            #         vi_audio_2 = gr.Audio(source="microphone", label="Input Vietnamese Audio", type="file", streaming=True)
-            #     with gr.Column():
-            #         speech2text_vi2 = gr.Textbox(label="Vietnamese Text")
-            #         english_out_3 = gr.Textbox(label="English Text")
-            # vi_audio_2.change(transcribe_vi, [vi_audio_2, speech2text_vi2, english_out_3], [speech2text_vi2, english_out_3])
-
-            gr.Interface(
-                fn=transcribe_vi_rd,
-                inputs=[
-                    gr.Audio(source="microphone", type="file", streaming=True),
-                    "state"
-                ],
-                outputs=[
-                    "textbox",
-                    "state"
-                ],
-                live=True).launch()
-
-
-    with gr.Tabs():
-        with gr.TabItem("Translation: English to Vietnamese"):
-            with gr.Row():
-                with gr.Column():
-                    english_text = gr.Textbox(label="English Text")
-                    translate_button_envi_1 = gr.Button(value="Translate To Vietnamese")
-                with gr.Column():
-                    vietnamese_out_1 = gr.Textbox(label="Vietnamese Text")
-            translate_button_envi_1.click(lambda text: translate_en2vi(text), inputs=english_text, outputs=vietnamese_out_1)
-            gr.Examples(examples=en_example_text,
-                        inputs=[english_text])
-
-        with gr.TabItem("Speech2text and En-Vi Translation"):
-            with gr.Row():
-                with gr.Column():
-                    en_audio_1 = gr.Audio(source="microphone", label="Input English Audio", type="filepath", streaming=False)
-                    translate_button_envi_2 = gr.Button(value="Translate To Vietnamese")
-                with gr.Column():
-                    speech2text_en1 = gr.Textbox(label="English Text")
-                    vietnamese_out_2 = gr.Textbox(label="Vietnamese Text")
-            translate_button_envi_2.click(lambda en_voice: inference_envi(en_voice), inputs=en_audio_1, outputs=[speech2text_en1, vietnamese_out_2])
-            gr.Examples(examples=en_example_voice,
-                        inputs=[en_audio_1])
-
-        with gr.TabItem("En-Vi Realtime Translation"):
-            # with gr.Row():
-            #     with gr.Column():
-            #         en_audio_2 = gr.Audio(source="microphone", label="Input English Audio", type="filepath", streaming=True)
-            #     with gr.Column():
-            #         speech2text_en2 = gr.Textbox(label="English Text")
-            #         vietnamese_out_3 = gr.Textbox(label="Vietnamese Text")
-            # en_audio_2.change(transcribe_en, [en_audio_2, speech2text_en2, vietnamese_out_3], [speech2text_en2, vietnamese_out_3])
-            # speech2text_en2, vietnamese_out_3 = transcribe_en(en_audio_2, speech2text_en2, vietnamese_out_3)
-
-            gr.Interface(
-                fn=transcribe_en_rd,
-                inputs=[
-                    gr.Audio(source="microphone", type="filepath", streaming=True),
-                    "state"
-                ],
-                outputs=[
-                    "textbox",
-                    "state"
-                ],
-                live=True).launch()
-
-
-if __name__ == "__main__":
-    demo.launch()