jiuuee committed
Commit 604efba · verified · 1 Parent(s): dfc887a

Update app.py

Files changed (1)
  1. app.py +318 -31
app.py CHANGED
@@ -1,40 +1,327 @@
 import gradio as gr
-from nemo.collections.asr.models import ASRModel
+import json
 import librosa
+import os
+import soundfile as sf
 import tempfile
+import uuid
+import torch
+
+from nemo.collections.asr.models import ASRModel
+from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
+from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
+
+SAMPLE_RATE = 16000 # Hz
+MAX_AUDIO_MINUTES = 1 # wont try to transcribe if longer than this

-# Load the NeMo ASR model
 model = ASRModel.from_pretrained("nvidia/canary-1b")
 model.eval()

-# Function to preprocess the audio
-def preprocess_audio(audio, sample_rate):
-    # Save audio to a temporary file
-    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
-        temp_audio_path = temp_audio_file.name
-        librosa.output.write_wav(temp_audio_path, audio.squeeze(), sample_rate)
-    return temp_audio_path
-
-# Function to transcribe audio
-def transcribe_audio(audio):
-    # Preprocess audio
-    audio_path = preprocess_audio(audio, 16000)
-
-    # Perform speech recognition
-    transcription = model.transcribe([audio_path])
-
-    return transcription[0]
-
-# Interface
-audio_input = gr.inputs.Audio(source="microphone", label="Record Audio")
-output_text = gr.outputs.Textbox(label="Transcription")
-
-iface = gr.Interface(
-    transcribe_audio,
-    audio_input,
-    output_text,
-    title="Automatic Speech Recognition using Canary 1b",
-    description="Click 'Record Audio' to start recording.",
+# make sure beam size always 1 for consistency
+model.change_decoding_strategy(None)
+decoding_cfg = model.cfg.decoding
+decoding_cfg.beam.beam_size = 1
+model.change_decoding_strategy(decoding_cfg)
+
+# setup for buffered inference
+model.cfg.preprocessor.dither = 0.0
+model.cfg.preprocessor.pad_to = 0
+
+feature_stride = model.cfg.preprocessor['window_stride']
+model_stride_in_secs = feature_stride * 8 # 8 = model stride, which is 8 for FastConformer
+
+frame_asr = FrameBatchMultiTaskAED(
+    asr_model=model,
+    frame_len=40.0,
+    total_buffer=40.0,
+    batch_size=16,
 )

-iface.launch()
+amp_dtype = torch.float16
+
+def convert_audio(audio_filepath, tmpdir, utt_id):
+    """
+    Convert all files to monochannel 16 kHz wav files.
+    Do not convert and raise error if audio too long.
+    Returns output filename and duration.
+    """
+
+    data, sr = librosa.load(audio_filepath, sr=None, mono=True)
+
+    duration = librosa.get_duration(y=data, sr=sr)
+
+    if duration / 60.0 > MAX_AUDIO_MINUTES:
+        raise gr.Error(
+            f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio. "
+            "If you wish, you may trim the audio using the Audio viewer in Step 1 "
+            "(click on the scissors icon to start trimming audio)."
+        )
+
+    if sr != SAMPLE_RATE:
+        data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
+
+    out_filename = os.path.join(tmpdir, utt_id + '.wav')
+
+    # save output audio
+    sf.write(out_filename, data, SAMPLE_RATE)
+
+    return out_filename, duration
+
+
+def transcribe(audio_filepath, src_lang, tgt_lang, pnc):
+
+    if audio_filepath is None:
+        raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
+
+    utt_id = uuid.uuid4()
+    with tempfile.TemporaryDirectory() as tmpdir:
+        converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))
+
+        # map src_lang and tgt_lang from long versions to short
+        LANG_LONG_TO_LANG_SHORT = {
+            "English": "en",
+            "Spanish": "es",
+            "French": "fr",
+            "German": "de",
+        }
+        if src_lang not in LANG_LONG_TO_LANG_SHORT.keys():
+            raise ValueError(f"src_lang must be one of {LANG_LONG_TO_LANG_SHORT.keys()}")
+        else:
+            src_lang = LANG_LONG_TO_LANG_SHORT[src_lang]
+
+        if tgt_lang not in LANG_LONG_TO_LANG_SHORT.keys():
+            raise ValueError(f"tgt_lang must be one of {LANG_LONG_TO_LANG_SHORT.keys()}")
+        else:
+            tgt_lang = LANG_LONG_TO_LANG_SHORT[tgt_lang]
+
+
+        # infer taskname from src_lang and tgt_lang
+        if src_lang == tgt_lang:
+            taskname = "asr"
+        else:
+            taskname = "s2t_translation"
+
+        # update pnc variable to be "yes" or "no"
+        pnc = "yes" if pnc else "no"
+
+        # make manifest file and save
+        manifest_data = {
+            "audio_filepath": converted_audio_filepath,
+            "source_lang": src_lang,
+            "target_lang": tgt_lang,
+            "taskname": taskname,
+            "pnc": pnc,
+            "answer": "predict",
+            "duration": str(duration),
+        }
+
+        manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
+
+        with open(manifest_filepath, 'w') as fout:
+            line = json.dumps(manifest_data)
+            fout.write(line + '\n')
+
+        # call transcribe, passing in manifest filepath
+        if duration < 40:
+            output_text = model.transcribe(manifest_filepath)[0]
+        else: # do buffered inference
+            with torch.cuda.amp.autocast(dtype=amp_dtype): # TODO: make it work if no cuda
+                with torch.no_grad():
+                    hyps = get_buffered_pred_feat_multitaskAED(
+                        frame_asr,
+                        model.cfg.preprocessor,
+                        model_stride_in_secs,
+                        model.device,
+                        manifest=manifest_filepath,
+                        filepaths=None,
+                    )
+
+                    output_text = hyps[0].text
+
+    return output_text
+
+# add logic to make sure dropdown menus only suggest valid combos
+def on_src_or_tgt_lang_change(src_lang_value, tgt_lang_value, pnc_value):
+    """Callback function for when src_lang or tgt_lang dropdown menus are changed.
+    Args:
+        src_lang_value(string), tgt_lang_value (string), pnc_value(bool) - the current
+            chosen "values" of each Gradio component
+    Returns:
+        src_lang, tgt_lang, pnc - these are the new Gradio components that will be displayed
+
+    Note: I found the required logic is easier to understand if you think about the possible src & tgt langs as
+    a matrix, e.g. with English, Spanish, French, German as the langs, and only transcription in the same language,
+    and X -> English and English -> X translation being allowed, the matrix looks like the diagram below ("Y" means it is
+    allowed to go into that state).
+    It is easier to understand the code if you think about which state you are in, given the current src_lang_value and
+    tgt_lang_value, and then which states you can go to from there.
+
+                   tgt lang
+                 - |EN |ES |FR |DE
+                ------------------
+            EN  | Y | Y | Y | Y
+                ------------------
+    src     ES  | Y | Y |   |
+    lang        ------------------
+            FR  | Y |   | Y |
+                ------------------
+            DE  | Y |   |   | Y
+    """
+
+    if src_lang_value == "English" and tgt_lang_value == "English":
+        # src_lang and tgt_lang can go anywhere
+        src_lang = gr.Dropdown(
+            choices=["English", "Spanish", "French", "German"],
+            value=src_lang_value,
+            label="Input audio is spoken in:"
+        )
+        tgt_lang = gr.Dropdown(
+            choices=["English", "Spanish", "French", "German"],
+            value=tgt_lang_value,
+            label="Transcribe in language:"
+        )
+    elif src_lang_value == "English":
+        # src is English & tgt is non-English
+        # => src can only be English or current tgt_lang_values
+        # & tgt can be anything
+        src_lang = gr.Dropdown(
+            choices=["English", tgt_lang_value],
+            value=src_lang_value,
+            label="Input audio is spoken in:"
+        )
+        tgt_lang = gr.Dropdown(
+            choices=["English", "Spanish", "French", "German"],
+            value=tgt_lang_value,
+            label="Transcribe in language:"
+        )
+    elif tgt_lang_value == "English":
+        # src is non-English & tgt is English
+        # => src can be anything
+        # & tgt can only be English or current src_lang_value
+        src_lang = gr.Dropdown(
+            choices=["English", "Spanish", "French", "German"],
+            value=src_lang_value,
+            label="Input audio is spoken in:"
+        )
+        tgt_lang = gr.Dropdown(
+            choices=["English", src_lang_value],
+            value=tgt_lang_value,
+            label="Transcribe in language:"
+        )
+    else:
+        # both src and tgt are non-English
+        # => both src and tgt can only be switch to English or themselves
+        src_lang = gr.Dropdown(
+            choices=["English", src_lang_value],
+            value=src_lang_value,
+            label="Input audio is spoken in:"
+        )
+        tgt_lang = gr.Dropdown(
+            choices=["English", tgt_lang_value],
+            value=tgt_lang_value,
+            label="Transcribe in language:"
+        )
+    # let pnc be anything if src_lang_value == tgt_lang_value, else fix to True
+    if src_lang_value == tgt_lang_value:
+        pnc = gr.Checkbox(
+            value=pnc_value,
+            label="Punctuation & Capitalization in transcript?",
+            interactive=True
+        )
+    else:
+        pnc = gr.Checkbox(
+            value=True,
+            label="Punctuation & Capitalization in transcript?",
+            interactive=False
+        )
+    return src_lang, tgt_lang, pnc
+
+
+with gr.Blocks(
+    title="NeMo Canary Model",
+    css="""
+        textarea { font-size: 18px;}
+        #model_output_text_box span {
+            font-size: 18px;
+            font-weight: bold;
+        }
+    """,
+    theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg) # make text slightly bigger (default is text_md )
+) as demo:
+
+    gr.HTML("<h1 style='text-align: center'>NeMo Canary model: Transcribe & Translate audio</h1>")
+
+    with gr.Row():
+        with gr.Column():
+            gr.HTML(
+                "<p><b>Step 1:</b> Upload an audio file or record with your microphone.</p>"
+
+                "<p style='color: #A0A0A0;'>This demo supports audio files up to 10 mins long. "
+                "You can transcribe longer files locally with this NeMo "
+                "<a href='https://github.com/NVIDIA/NeMo/blob/main/examples/asr/speech_multitask/speech_to_text_aed_chunked_infer.py'>script</a>.</p>"
+            )
+
+            audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")
+
+            gr.HTML("<p><b>Step 2:</b> Choose the input and output language.</p>")
+
+            src_lang = gr.Dropdown(
+                choices=["English", "Spanish", "French", "German"],
+                value="English",
+                label="Input audio is spoken in:"
+            )
+
+        with gr.Column():
+            tgt_lang = gr.Dropdown(
+                choices=["English", "Spanish", "French", "German"],
+                value="English",
+                label="Transcribe in language:"
+            )
+            pnc = gr.Checkbox(
+                value=True,
+                label="Punctuation & Capitalization in transcript?",
+            )
+
+        with gr.Column():
+
+            gr.HTML("<p><b>Step 3:</b> Run the model.</p>")
+
+            go_button = gr.Button(
+                value="Run model",
+                variant="primary", # make "primary" so it stands out (default is "secondary")
+            )
+
+            model_output_text_box = gr.Textbox(
+                label="Model Output",
+                elem_id="model_output_text_box",
+            )
+
+    with gr.Row():
+
+        gr.HTML(
+            "<p style='text-align: center'>"
+            "🐤 <a href='https://huggingface.co/nvidia/canary-1b' target='_blank'>Canary model</a> | "
+            "🧑‍💻 <a href='https://github.com/NVIDIA/NeMo' target='_blank'>NeMo Repository</a>"
+            "</p>"
+        )
+
+    go_button.click(
+        fn=transcribe,
+        inputs = [audio_file, src_lang, tgt_lang, pnc],
+        outputs = [model_output_text_box]
+    )
+
+    # call on_src_or_tgt_lang_change whenever src_lang or tgt_lang dropdown menus are changed
+    src_lang.change(
+        fn=on_src_or_tgt_lang_change,
+        inputs=[src_lang, tgt_lang, pnc],
+        outputs=[src_lang, tgt_lang, pnc],
+    )
+    tgt_lang.change(
+        fn=on_src_or_tgt_lang_change,
+        inputs=[src_lang, tgt_lang, pnc],
+        outputs=[src_lang, tgt_lang, pnc],
+    )
+
+
+demo.queue()
+demo.launch()