thanhtvt committed
Commit e9812a3
Parent: 5e958ff

add new decoders

Files changed (4)
  1. app.py +190 -12
  2. decode.py +44 -0
  3. model.py +28 -8
  4. requirements.txt +1 -2
app.py CHANGED
@@ -1,9 +1,12 @@
+import base64
 import gradio as gr
 import librosa
 import logging
 import os
 import soundfile as sf
-import tensorflow as tf
+import subprocess
+import tempfile
+import urllib.request
 
 from datetime import datetime
 from time import time
@@ -13,7 +16,7 @@ from model import UETASRModel
 
 
 def get_duration(filename: str) -> float:
-    return librosa.get_duration(filename=filename)
+    return librosa.get_duration(path=filename)
 
 
 def convert_to_wav(in_filename: str) -> str:
@@ -24,6 +27,39 @@ def convert_to_wav(in_filename: str) -> str:
     return out_filename
 
 
+def convert_to_wav1(in_filename: str) -> str:
+    """Convert the input audio file to a wave file"""
+    out_filename = in_filename + ".wav"
+    logging.info(f"Converting '{in_filename}' to '{out_filename}'")
+    _ = os.system(f"ffmpeg -hide_banner -i '{in_filename}' -ar 16000 '{out_filename}'")
+    _ = os.system(
+        f"ffmpeg -hide_banner -loglevel error -i '{in_filename}' -ar 16000 '{out_filename}.flac'"
+    )
+
+    with open(out_filename + ".flac", "rb") as f:
+        s = "\n" + out_filename + "\n"
+        s += base64.b64encode(f.read()).decode()
+        logging.info(s)
+
+    return out_filename
+
+
+def convert_to_wav2(in_filename: str) -> str:
+    """Convert the input audio file to a wave file"""
+    out_filename = in_filename + ".wav"
+    logging.info(f"Converting '{in_filename}' to '{out_filename}'")
+
+    sp_args = ["ffmpeg", "-hide_banner", "-i", in_filename, "-ar", "16000", out_filename]
+    sp_args.insert(2, "-y") if os.path.exists(out_filename) else None
+    # Create a subprocess to run the ffmpeg command.
+    _ = subprocess.Popen(
+        sp_args,
+        stdin=subprocess.PIPE,
+    )
+
+    return out_filename
+
+
 def build_html_output(s: str, style: str = "result_item_success"):
     return f"""
     <div class='result'>
@@ -34,7 +70,34 @@ def build_html_output(s: str, style: str = "result_item_success"):
     """
 
 
-def process_uploaded_file(in_filename: str):
+def process_url(
+    url: str,
+    decoding_method: str,
+    beam_size: int,
+    max_symbols_per_step: int,
+    max_out_seq_len_ratio: float,
+):
+    logging.info(f"Processing URL: {url}")
+    with tempfile.NamedTemporaryFile() as f:
+        try:
+            urllib.request.urlretrieve(url, f.name)
+            return process(in_filename=f.name,
+                           decoding_method=decoding_method,
+                           beam_size=beam_size,
+                           max_symbols_per_step=max_symbols_per_step,
+                           max_out_seq_len_ratio=max_out_seq_len_ratio)
+        except Exception as e:
+            logging.info(str(e))
+            return "", build_html_output(str(e), "result_item_error")
+
+
+def process_uploaded_file(
+    in_filename: str,
+    decoding_method: str,
+    beam_size: int,
+    max_symbols_per_step: int,
+    max_out_seq_len_ratio: float,
+):
     if in_filename is None or in_filename == "":
         return "", build_html_output(
             "Please first upload a file and then click "
@@ -44,13 +107,23 @@ def process_uploaded_file(in_filename: str):
 
     logging.info(f"Processing uploaded file: {in_filename}")
     try:
-        return process(in_filename=in_filename)
+        return process(in_filename=in_filename,
+                       decoding_method=decoding_method,
+                       beam_size=beam_size,
+                       max_symbols_per_step=max_symbols_per_step,
+                       max_out_seq_len_ratio=max_out_seq_len_ratio)
     except Exception as e:
-        logging.error(str(e))
+        logging.info(str(e))
         return "", build_html_output(str(e), "result_item_error")
 
 
-def process_microphone(in_filename: str):
+def process_microphone(
+    in_filename: str,
+    decoding_method: str,
+    beam_size: int,
+    max_symbols_per_step: int,
+    max_out_seq_len_ratio: float,
+):
     if in_filename is None or in_filename == "":
         return "", build_html_output(
             "Please first upload a file and then click "
@@ -60,13 +133,23 @@ def process_microphone(in_filename: str):
 
     logging.info(f"Processing microphone: {in_filename}")
    try:
-        return process(in_filename=in_filename)
+        return process(in_filename=in_filename,
+                       decoding_method=decoding_method,
+                       beam_size=beam_size,
+                       max_symbols_per_step=max_symbols_per_step,
+                       max_out_seq_len_ratio=max_out_seq_len_ratio)
     except Exception as e:
-        logging.error(str(e))
+        logging.info(str(e))
         return "", build_html_output(str(e), "result_item_error")
 
 
-def process(in_filename: str):
+def process(
+    in_filename: str,
+    decoding_method: str,
+    beam_size: int,
+    max_symbols_per_step: int,
+    max_out_seq_len_ratio: float,
+):
     logging.info(f"in_filename: {in_filename}")
 
     filename = convert_to_wav(in_filename)
@@ -79,7 +162,11 @@ def process(in_filename: str):
 
     start = time()
 
-    recognizer = UETASRModel(repo_id)
+    recognizer = UETASRModel(repo_id,
+                             decoding_method,
+                             beam_size,
+                             max_symbols_per_step,
+                             max_out_seq_len_ratio)
     text = recognizer.predict(filename)
 
     date_time = now.strftime("%d/%m/%Y, %H:%M:%S.%f")
@@ -130,6 +217,61 @@ demo = gr.Blocks(css=css)
 with demo:
     gr.Markdown(title)
 
+    decode_method_radio = gr.Radio(
+        label="Decoding method",
+        choices=["greedy_search", "beam_search", "alsd_search"],
+        value="greedy_search",
+        interactive=True,
+    )
+
+    with gr.Column(visible=False) as beam_col:
+        beam_size = gr.Slider(
+            label="Beam size",
+            minimum=1,
+            maximum=10,
+            step=1,
+            value=5,
+            interactive=True,
+        )
+
+    def enable_beam_col(decoding_method):
+        if decoding_method != "greedy_search":
+            return gr.update(visible=True)
+        else:
+            return gr.update(visible=False)
+
+    decode_method_radio.change(enable_beam_col, decode_method_radio, beam_col)
+
+    max_symbols_per_step_slider = gr.Slider(
+        label="Maximum symbols per step",
+        minimum=1,
+        maximum=15,
+        step=1,
+        value=5,
+        interactive=True,
+        visible=True,
+    )
+
+    max_out_seq_len_slider = gr.Slider(
+        label="Maximum output sequence length ratio",
+        minimum=0,
+        maximum=1,
+        step=0.01,
+        value=0.6,
+        interactive=True,
+        visible=False,
+    )
+
+    def switch_slider(decoding_method):
+        if decoding_method == "alsd_search":
+            return gr.update(visible=False), gr.update(visible=True)
+        else:
+            return gr.update(visible=True), gr.update(visible=False)
+
+    decode_method_radio.change(switch_slider,
+                               decode_method_radio,
+                               [max_symbols_per_step_slider, max_out_seq_len_slider])
+
     with gr.Tabs():
         with gr.TabItem("Upload from disk"):
             uploaded_file = gr.Audio(
@@ -166,17 +308,53 @@ with demo:
                 fn=process_microphone,
             )
 
+        with gr.TabItem("From URL"):
+            url_textbox = gr.Textbox(
+                max_lines=1,
+                placeholder="URL to an audio file",
+                label="URL",
+                interactive=True,
+            )
+
+            url_button = gr.Button("Submit for recognition")
+            url_output = gr.Textbox(label="Recognized speech from URL")
+            url_html_info = gr.HTML(label="Info")
+
         upload_button.click(
             process_uploaded_file,
-            inputs=uploaded_file,
+            inputs=[
+                uploaded_file,
+                decode_method_radio,
+                beam_size,
+                max_symbols_per_step_slider,
+                max_out_seq_len_slider,
+            ],
             outputs=[uploaded_output, uploaded_html_info],
         )
 
         record_button.click(
             process_microphone,
-            inputs=microphone,
+            inputs=[
+                microphone,
+                decode_method_radio,
+                beam_size,
+                max_symbols_per_step_slider,
+                max_out_seq_len_slider,
+            ],
             outputs=[recorded_output, recorded_html_info],
         )
+
+        url_button.click(
+            process_url,
+            inputs=[
+                url_textbox,
+                decode_method_radio,
+                beam_size,
+                max_symbols_per_step_slider,
+                max_out_seq_len_slider,
+            ],
+            outputs=[url_output, url_html_info],
+        )
     gr.Markdown(description)
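A note on the new conversion helpers: convert_to_wav2 starts ffmpeg with subprocess.Popen and returns without waiting, so a caller can race ahead of the encoder and read a .wav that is not fully written yet. A minimal blocking sketch, not part of this commit (the function name is hypothetical):

    import os
    import subprocess

    def convert_to_wav_blocking(in_filename: str) -> str:
        out_filename = in_filename + ".wav"
        args = ["ffmpeg", "-hide_banner", "-i", in_filename, "-ar", "16000", out_filename]
        if os.path.exists(out_filename):
            args.insert(1, "-y")  # overwrite an existing output file
        # subprocess.run blocks until ffmpeg exits; check=True raises on failure
        subprocess.run(args, check=True)
        return out_filename

Passing an argument list to subprocess.run also sidesteps the shell-quoting pitfalls of the os.system calls in convert_to_wav1.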
decode.py ADDED
@@ -0,0 +1,44 @@
+import logging
+import tensorflow as tf
+from functools import lru_cache
+from uetasr.searchers import GreedyRNNT, BeamRNNT, ALSDBeamRNNT
+
+
+@lru_cache(maxsize=5)
+def get_searcher(
+    searcher_type: str,
+    decoder: tf.keras.Model,
+    jointer: tf.keras.Model,
+    text_decoder: tf.keras.layers.experimental.preprocessing.PreprocessingLayer,
+    beam_size: int,
+    max_symbols_per_step: int,
+    max_output_seq_length_ratio: float,
+):
+    common_kwargs = {
+        "decoder": decoder,
+        "jointer": jointer,
+        "text_decoder": text_decoder,
+        "return_scores": False,
+    }
+    if searcher_type == "greedy_search":
+        searcher = GreedyRNNT(
+            max_symbols_per_step=max_symbols_per_step,
+            **common_kwargs,
+        )
+    elif searcher_type == "beam_search":
+        searcher = BeamRNNT(
+            max_symbols_per_step=max_symbols_per_step,
+            beam=beam_size,
+            alpha=0.0,
+            **common_kwargs,
+        )
+    elif searcher_type == "alsd_search":
+        searcher = ALSDBeamRNNT(
+            fraction=max_output_seq_length_ratio,
+            beam_size=beam_size,
+            **common_kwargs,
+        )
+    else:
+        logging.info(f"Unknown searcher type: {searcher_type}")
+
+    return searcher
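Since get_searcher is wrapped in functools.lru_cache, repeated requests with the same decoding settings reuse an existing searcher rather than rebuilding one. A usage sketch mirroring the call site in model.py, assuming decoder, jointer, and text_encoder have been pulled from the hyperpyyaml config as _get_conformer_pre_trained_model does:

    # Components are assumed to come from the model config, as in model.py:
    #   decoder      = config["decoder_model"]
    #   jointer      = config["jointer_model"]
    #   text_encoder = config["text_encoder"]
    searcher = get_searcher(
        "beam_search",  # or "greedy_search" / "alsd_search"
        decoder,
        jointer,
        text_encoder,
        beam_size=5,
        max_symbols_per_step=5,
        max_output_seq_length_ratio=0.6,
    )

One caveat: an unknown searcher_type only logs a message and leaves searcher unbound, so the final return searcher raises UnboundLocalError; raising ValueError in that branch would fail more clearly.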
model.py CHANGED
@@ -5,6 +5,8 @@ from huggingface_hub import hf_hub_download
 from hyperpyyaml import load_hyperpyyaml
 from typing import Union
 
+from decode import get_searcher
+
 os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
 
@@ -69,17 +71,20 @@ def _get_conformer_pre_trained_model(repo_id: str, checkpoint_dir: str = "checkp
         local_dir=os.path.dirname(__file__),  # noqa
         local_dir_use_symlinks=True,
     )
-    print(config_path)
+
     with open(config_path, "r") as f:
         config = load_hyperpyyaml(f)
 
     encoder_model = config["encoder_model"]
-    searcher = config["decoder"]
+    text_encoder = config["text_encoder"]
+    jointer = config["jointer_model"]
+    decoder = config["decoder_model"]
+    # searcher = config["decoder"]
     model = config["model"]
     audio_encoder = config["audio_encoder"]
     model.load_weights(os.path.join(checkpoint_dir, "avg_top5_27-32.ckpt")).expect_partial()
 
-    return audio_encoder, encoder_model, searcher, model
+    return audio_encoder, encoder_model, jointer, decoder, text_encoder, model
 
 
 def read_audio(in_filename: str):
@@ -90,16 +95,30 @@ def read_audio(in_filename: str):
 
 
 class UETASRModel:
-    def __init__(self, repo_id: str):
-        self.featurizer, self.encoder_model, self.searcher, self.model = _get_conformer_pre_trained_model(repo_id)
+    def __init__(
+        self,
+        repo_id: str,
+        decoding_method: str,
+        beam_size: int,
+        max_symbols_per_step: int,
+        max_output_seq_length_ratio: float,
+    ):
+        self.featurizer, self.encoder_model, jointer, decoder, text_encoder, self.model = _get_conformer_pre_trained_model(repo_id)
+        self.searcher = get_searcher(
+            decoding_method,
+            decoder,
+            jointer,
+            text_encoder,
+            beam_size,
+            max_symbols_per_step,
+            max_output_seq_length_ratio,
+        )
 
     def predict(self, in_filename: str):
         inputs = read_audio(in_filename)
         features = self.featurizer(inputs)
         features = self.model.cmvn(features) if self.model.use_cmvn else features
 
-        batch_size = tf.shape(features)[0]
-        dim = tf.shape(features)[-1]
         mask = tf.sequence_mask([tf.shape(features)[1]], maxlen=tf.shape(features)[1])
         mask = tf.expand_dims(mask, axis=1)
         encoder_outputs, encoder_masks = self.encoder_model(
@@ -111,7 +130,8 @@ class UETASRModel:
             axis=1
         )
 
-        outputs = self.searcher(encoder_outputs, features_length)
+        outputs = self.searcher.infer(encoder_outputs, features_length)
+        outputs = tf.squeeze(outputs)
         outputs = tf.compat.as_str_any(outputs.numpy())
 
         return outputs
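End to end, the new constructor threads the UI's decoding options straight into the searcher. A minimal usage sketch with the default values exposed in app.py (repo_id is assumed to be resolved the same way app.py does):

    # Hypothetical standalone use; app.py passes these same five arguments.
    recognizer = UETASRModel(
        repo_id,          # Hugging Face repo id, as resolved in app.py
        "greedy_search",  # decoding_method: greedy_search / beam_search / alsd_search
        5,                # beam_size (unused by greedy_search)
        5,                # max_symbols_per_step
        0.6,              # max_output_seq_length_ratio (used by alsd_search)
    )
    text = recognizer.predict("speech.wav")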
requirements.txt CHANGED
@@ -1,3 +1,2 @@
-uetasr @ git+https://github.com/thanhtvt/uetasr@v0.1.0-beta
-librosa
+uetasr @ git+https://github.com/thanhtvt/uetasr@v0.2.0
 requests==2.28.2