datnth1709 committed
Commit 7c45a5f · Parent: f0cffbb

realtime translate

Files changed (2):
  1. app.py +14 -74
  2. convert.ipynb +112 -0
app.py CHANGED

@@ -181,40 +181,6 @@ def transcribe_en(audio, state_en="", state_vi=""):
     state_vi += vi_text + "+"
     return state_en, state_vi
 
-def transcribe_vi_1(audio, state_en=""):
-    ds = speech_file_to_array_fn(audio.name)
-    # infer model
-    input_values = processor(
-        ds["speech"],
-        sampling_rate=ds["sampling_rate"],
-        return_tensors="pt"
-    ).input_values
-    # decode ctc output
-    logits = vi_model(input_values).logits[0]
-    pred_ids = torch.argmax(logits, dim=-1)
-    greedy_search_output = processor.decode(pred_ids)
-    beam_search_output = ngram_lm_model.decode(logits.cpu().detach().numpy(), beam_width=500)
-    en_text = translate_vi2en(beam_search_output)
-    state_en += en_text + " "
-    return state_en, state_en
-
-def transcribe_en_1(audio, state_vi=""):
-    speech = load_data(audio)
-    # Tokenize
-    input_values = eng_tokenizer(speech, return_tensors="pt").input_values
-    # Take logits
-    logits = eng_model(input_values).logits
-    # Take argmax
-    predicted_ids = torch.argmax(logits, dim=-1)
-    # Get the words from predicted word ids
-    transcription = eng_tokenizer.decode(predicted_ids[0])
-    # Output is all upper case
-    transcription = correct_casing(transcription.lower())
-    vi_text = translate_en2vi(transcription)
-    state_vi += vi_text + "+"
-    return state_vi, state_vi
-
-
 """Gradio demo"""
 
 vi_example_text = ["Có phải bạn đang muốn tìm mua nhà ở ngoại ô thành phố Hồ Chí Minh không?",
@@ -255,26 +221,13 @@ with gr.Blocks() as demo:
                         inputs=[vi_audio_1])
 
         with gr.TabItem("Vi-En Realtime Translation"):
-            gr.Interface(
-                fn=transcribe_vi_1,
-                inputs=[
-                    gr.Audio(source="microphone", label="Input Vietnamese Audio", type="file", streaming=True),
-                    "state",
-                ],
-                outputs=[
-                    "text",
-                    "state",
-
-                ],
-                live=True).launch()
-
-            # with gr.Row():
-            #     with gr.Column():
-            #         vi_audio_2 = gr.Audio(source="microphone", label="Input Vietnamese Audio", type="file", streaming=True)
-            #     with gr.Column():
-            #         speech2text_vi2 = gr.Textbox(label="Vietnamese Text")
-            #         english_out_3 = gr.Textbox(label="English Text")
-            #     vi_audio_2.change(transcribe_vi, [vi_audio_2, speech2text_vi2, english_out_3], [speech2text_vi2, english_out_3])
+            with gr.Row():
+                with gr.Column():
+                    vi_audio_2 = gr.Audio(source="microphone", label="Input Vietnamese Audio", type="file", streaming=True)
+                with gr.Column():
+                    speech2text_vi2 = gr.Textbox(label="Vietnamese Text")
+                    english_out_3 = gr.Textbox(label="English Text")
+                vi_audio_2.change(transcribe_vi, [vi_audio_2, speech2text_vi2, english_out_3], [speech2text_vi2, english_out_3])
 
 
     with gr.Tabs():
@@ -302,26 +255,13 @@ with gr.Blocks() as demo:
                         inputs=[en_audio_1])
 
         with gr.TabItem("En-Vi Realtime Translation"):
-            gr.Interface(
-                fn=transcribe_en_1,
-                inputs=[
-                    gr.Audio(source="microphone", label="Input English Audio", type="filepath", streaming=True),
-                    "state",
-                ],
-                outputs=[
-                    "text",
-                    "state",
-
-                ],
-                live=True).launch()
-
-            # with gr.Row():
-            #     with gr.Column():
-            #         en_audio_2 = gr.Audio(source="microphone", label="Input English Audio", type="filepath", streaming=True)
-            #     with gr.Column():
-            #         speech2text_en2 = gr.Textbox(label="English Text")
-            #         vietnamese_out_3 = gr.Textbox(label="Vietnamese Text")
-            #     en_audio_2.change(transcribe_en, [en_audio_2, speech2text_en2, vietnamese_out_3], [speech2text_en2, vietnamese_out_3])
+            with gr.Row():
+                with gr.Column():
+                    en_audio_2 = gr.Audio(source="microphone", label="Input English Audio", type="filepath", streaming=True)
+                with gr.Column():
+                    speech2text_en2 = gr.Textbox(label="English Text")
+                    vietnamese_out_3 = gr.Textbox(label="Vietnamese Text")
+                en_audio_2.change(transcribe_en, [en_audio_2, speech2text_en2, vietnamese_out_3], [speech2text_en2, vietnamese_out_3])
 
 if __name__ == "__main__":
     demo.launch()
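
The new layout wires the streaming microphone component straight to the existing transcribe_vi / transcribe_en callbacks through a change event, instead of calling gr.Interface(...).launch() inside the Blocks context, which starts a separate app rather than embedding the interface in the tab. A minimal sketch of that pattern, assuming the Gradio 3.x gr.Audio(source=..., streaming=True) API used throughout this file; the transcribe placeholder stands in for the app's real ASR-plus-translation callbacks:

import gradio as gr

# Placeholder for the app's transcribe_vi / transcribe_en functions:
# the real callbacks run ASR on the latest chunk, translate it, and
# append to the accumulated textbox contents.
def transcribe(audio, transcript, translation):
    return transcript + " <chunk>", translation + " <chunk>"

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            mic = gr.Audio(source="microphone", type="filepath", streaming=True)
        with gr.Column():
            transcript = gr.Textbox(label="Transcript")
            translation = gr.Textbox(label="Translation")
        # Each streamed chunk fires .change(); feeding the textboxes back
        # in as inputs lets their contents accumulate across chunks.
        mic.change(transcribe, [mic, transcript, translation],
                   [transcript, translation])

demo.launch()

Passing the two textboxes as both inputs and outputs is what lets the text grow as chunks arrive; the removed transcribe_vi_1 / transcribe_en_1 helpers threaded that accumulation through an extra "state" argument instead.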
convert.ipynb ADDED

@@ -0,0 +1,112 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/json": {
+       "ascii": false,
+       "bar_format": null,
+       "colour": null,
+       "elapsed": 0.014345407485961914,
+       "initial": 0,
+       "n": 0,
+       "ncols": null,
+       "nrows": null,
+       "postfix": null,
+       "prefix": "Downloading",
+       "rate": null,
+       "total": 1596,
+       "unit": "B",
+       "unit_divisor": 1024,
+       "unit_scale": true
+      },
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6c1e4c5c553c4150b92ef38251ec5ccd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading: 0%| | 0.00/1.56k [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "ename": "ValueError",
+     "evalue": "Unrecognized configuration class <class 'transformers.models.wav2vec2.configuration_wav2vec2.Wav2Vec2Config'> for this kind of AutoModel: AutoModelForSequenceClassification.\nModel type should be one of AlbertConfig, BartConfig, BertConfig, BigBirdConfig, BigBirdPegasusConfig, BloomConfig, CamembertConfig, CanineConfig, ConvBertConfig, CTRLConfig, Data2VecTextConfig, DebertaConfig, DebertaV2Config, DistilBertConfig, ElectraConfig, FlaubertConfig, FNetConfig, FunnelConfig, GPT2Config, GPTNeoConfig, GPTJConfig, IBertConfig, LayoutLMConfig, LayoutLMv2Config, LayoutLMv3Config, LEDConfig, LongformerConfig, MBartConfig, MegatronBertConfig, MobileBertConfig, MPNetConfig, NystromformerConfig, OpenAIGPTConfig, PerceiverConfig, PLBartConfig, QDQBertConfig, ReformerConfig, RemBertConfig, RobertaConfig, RoFormerConfig, SqueezeBertConfig, TapasConfig, TransfoXLConfig, XLMConfig, XLMRobertaConfig, XLMRobertaXLConfig, XLNetConfig, YosoConfig.",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
+      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_29808\\3826148515.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;31m# load model and tokenizer\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0mmodel_id\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m\"facebook/wav2vec2-base-960h\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 6\u001b[1;33m \u001b[0mmodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mAutoModelForSequenceClassification\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_pretrained\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel_id\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 7\u001b[0m \u001b[0mtokenizer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mAutoTokenizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_pretrained\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel_id\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 8\u001b[0m \u001b[0mdummy_model_input\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"This is a sample\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mreturn_tensors\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m\"pt\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32mc:\\Python37\\lib\\site-packages\\transformers\\models\\auto\\auto_factory.py\u001b[0m in \u001b[0;36mfrom_pretrained\u001b[1;34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[0m\n\u001b[0;32m 446\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mmodel_class\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_pretrained\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpretrained_model_name_or_path\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0mmodel_args\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 447\u001b[0m raise ValueError(\n\u001b[1;32m--> 448\u001b[1;33m \u001b[1;34mf\"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\\n\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 449\u001b[0m \u001b[1;34mf\"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}.\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 450\u001b[0m )\n",
+      "\u001b[1;31mValueError\u001b[0m: Unrecognized configuration class <class 'transformers.models.wav2vec2.configuration_wav2vec2.Wav2Vec2Config'> for this kind of AutoModel: AutoModelForSequenceClassification.\nModel type should be one of AlbertConfig, BartConfig, BertConfig, BigBirdConfig, BigBirdPegasusConfig, BloomConfig, CamembertConfig, CanineConfig, ConvBertConfig, CTRLConfig, Data2VecTextConfig, DebertaConfig, DebertaV2Config, DistilBertConfig, ElectraConfig, FlaubertConfig, FNetConfig, FunnelConfig, GPT2Config, GPTNeoConfig, GPTJConfig, IBertConfig, LayoutLMConfig, LayoutLMv2Config, LayoutLMv3Config, LEDConfig, LongformerConfig, MBartConfig, MegatronBertConfig, MobileBertConfig, MPNetConfig, NystromformerConfig, OpenAIGPTConfig, PerceiverConfig, PLBartConfig, QDQBertConfig, ReformerConfig, RemBertConfig, RobertaConfig, RoFormerConfig, SqueezeBertConfig, TapasConfig, TransfoXLConfig, XLMConfig, XLMRobertaConfig, XLMRobertaXLConfig, XLNetConfig, YosoConfig."
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "from transformers import AutoModelForSequenceClassification, AutoTokenizer, Wav2Vec2Tokenizer, Wav2Vec2ForCTC\n",
+    "\n",
+    "# load model and tokenizer\n",
+    "model_name = \"facebook/wav2vec2-base-960h\"\n",
+    "tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)\n",
+    "model = Wav2Vec2ForCTC.from_pretrained(model_name)\n",
+    "dummy_model_input = tokenizer(\"This is a sample\", return_tensors=\"pt\")\n",
+    "\n",
+    "# export\n",
+    "torch.onnx.export(\n",
+    "    model, \n",
+    "    tuple(dummy_model_input.values()),\n",
+    "    f=\"torch-model.onnx\", \n",
+    "    input_names=['input_ids', 'attention_mask'], \n",
+    "    output_names=['logits'], \n",
+    "    dynamic_axes={'input_ids': {0: 'batch_size', 1: 'sequence'}, \n",
+    "                  'attention_mask': {0: 'batch_size', 1: 'sequence'}, \n",
+    "                  'logits': {0: 'batch_size', 1: 'sequence'}}, \n",
+    "    do_constant_folding=True, \n",
+    "    opset_version=13, \n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.7.9 64-bit",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.9"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "d49c3f6d6dd49f9272b571d9fad348ab55b8c6c3f691520d74ed0af1f69c3dd8"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
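
The saved outputs in this cell are stale: the ValueError traceback comes from an earlier run that loaded the checkpoint with AutoModelForSequenceClassification, which rejects Wav2Vec2Config, while the committed source already uses Wav2Vec2ForCTC. The source still builds dummy_model_input by tokenizing text, though, so the export would trace the model under input_ids/attention_mask names even though wav2vec2 consumes a raw float waveform (input_values). A hedged sketch of an export driven by an audio-shaped dummy input; the output file name, sample length, and axis names are illustrative assumptions, not part of the notebook:

import torch
from transformers import Wav2Vec2ForCTC

model_name = "facebook/wav2vec2-base-960h"
model = Wav2Vec2ForCTC.from_pretrained(model_name)
model.eval()

# One second of 16 kHz audio: wav2vec2 takes a raw float waveform
# of shape (batch, samples), not tokenized text.
dummy_input_values = torch.randn(1, 16000)

torch.onnx.export(
    model,
    (dummy_input_values,),
    f="wav2vec2-ctc.onnx",  # illustrative file name
    input_names=["input_values"],
    output_names=["logits"],
    dynamic_axes={
        "input_values": {0: "batch_size", 1: "samples"},
        "logits": {0: "batch_size", 1: "sequence"},
    },
    do_constant_folding=True,
    opset_version=13,
)

With the time axis marked dynamic, the exact dummy length matters little, but the traced input must still have the rank and dtype the convolutional feature extractor expects.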