thanhtvt committed
Commit 4ac7ffc (1 parent: bcd8e2f)

remove alsd

Files changed (4)
  1. app.py +26 -72
  2. decode.py +1 -8
  3. model.py +0 -2
  4. requirements.txt +1 -1
app.py CHANGED
@@ -19,7 +19,7 @@ def get_duration(filename: str) -> float:
     return librosa.get_duration(path=filename)
 
 
-def convert_to_wav1(in_filename: str) -> str:
+def convert_to_wav(in_filename: str) -> str:
     out_filename = os.path.splitext(in_filename)[0] + ".wav"
     logging.info(f"Converting {in_filename} to {out_filename}")
     y, sr = librosa.load(in_filename, sr=16000)
@@ -27,22 +27,6 @@ def convert_to_wav1(in_filename: str) -> str:
     return out_filename
 
 
-def convert_to_wav(in_filename: str) -> str:
-    """Convert the input audio file to a wave file"""
-    out_filename = in_filename + ".wav"
-    logging.info(f"Converting '{in_filename}' to '{out_filename}'")
-
-    sp_args = ["ffmpeg", "-hide_banner", "-i", in_filename, "-ar", "16000", out_filename]
-    sp_args.insert(2, "-y") if os.path.exists(out_filename) else None
-    # Create a subprocess to run the ffmpeg command.
-    _ = subprocess.Popen(
-        sp_args,
-        stdin=subprocess.PIPE,
-    )
-
-    return out_filename
-
-
 def build_html_output(s: str, style: str = "result_item_success"):
     return f"""
     <div class='result'>
@@ -58,7 +42,6 @@ def process_url(
     decoding_method: str,
     beam_size: int,
     max_symbols_per_step: int,
-    max_out_seq_len_ratio: float,
 ):
     logging.info(f"Processing URL: {url}")
     with tempfile.NamedTemporaryFile() as f:
@@ -67,8 +50,7 @@
             return process(in_filename=f.name,
                            decoding_method=decoding_method,
                            beam_size=beam_size,
-                           max_symbols_per_step=max_symbols_per_step,
-                           max_out_seq_len_ratio=max_out_seq_len_ratio)
+                           max_symbols_per_step=max_symbols_per_step)
         except Exception as e:
             logging.info(str(e))
             return "", build_html_output(str(e), "result_item_error")
@@ -79,7 +61,6 @@ def process_uploaded_file(
     decoding_method: str,
     beam_size: int,
     max_symbols_per_step: int,
-    max_out_seq_len_ratio: float,
 ):
     if in_filename is None or in_filename == "":
         return "", build_html_output(
@@ -93,8 +74,7 @@ def process_uploaded_file(
         return process(in_filename=in_filename,
                        decoding_method=decoding_method,
                        beam_size=beam_size,
-                       max_symbols_per_step=max_symbols_per_step,
-                       max_out_seq_len_ratio=max_out_seq_len_ratio)
+                       max_symbols_per_step=max_symbols_per_step)
     except Exception as e:
         logging.info(str(e))
         return "", build_html_output(str(e), "result_item_error")
@@ -105,7 +85,6 @@ def process_microphone(
     decoding_method: str,
     beam_size: int,
     max_symbols_per_step: int,
-    max_out_seq_len_ratio: float,
 ):
     if in_filename is None or in_filename == "":
         return "", build_html_output(
@@ -119,8 +98,7 @@ def process_microphone(
         return process(in_filename=in_filename,
                        decoding_method=decoding_method,
                        beam_size=beam_size,
-                       max_symbols_per_step=max_symbols_per_step,
-                       max_out_seq_len_ratio=max_out_seq_len_ratio)
+                       max_symbols_per_step=max_symbols_per_step)
     except Exception as e:
         logging.info(str(e))
         return "", build_html_output(str(e), "result_item_error")
@@ -131,7 +109,6 @@ def process(
     decoding_method: str,
     beam_size: int,
     max_symbols_per_step: int,
-    max_out_seq_len_ratio: float,
 ):
     logging.info(f"in_filename: {in_filename}")
 
@@ -148,8 +125,7 @@ def process(
     recognizer = UETASRModel(repo_id,
                              decoding_method,
                              beam_size,
-                             max_symbols_per_step,
-                             max_out_seq_len_ratio)
+                             max_symbols_per_step)
     text = recognizer.predict(filename)
 
     date_time = now.strftime("%d/%m/%Y, %H:%M:%S.%f")
@@ -167,7 +143,7 @@ def process(
     """
     if rtf > 1:
         info += (
-            "<br/>We are loading the model for the first run. "
+            "<br/>We are loading required resources for the first run. "
             "Please run again to measure the real RTF.<br/>"
         )
 
@@ -202,59 +178,40 @@ with demo:
 
     decode_method_radio = gr.Radio(
         label="Decoding method",
-        choices=["greedy_search", "beam_search", "alsd_search"],
+        choices=["greedy_search", "beam_search"],
         value="greedy_search",
         interactive=True,
     )
 
-    with gr.Column(visible=False) as beam_col:
-        beam_size = gr.Slider(
-            label="Beam size",
-            minimum=1,
-            maximum=10,
-            step=1,
-            value=5,
-            interactive=True,
-        )
+    beam_size_slider = gr.Slider(
+        label="Beam size",
+        minimum=1,
+        maximum=20,
+        step=1,
+        value=1,
+        interactive=False,
+    )
 
-    def enable_beam_col(decoding_method):
-        if decoding_method != "greedy_search":
-            return gr.update(visible=True)
+    def interact_beam_slider(decoding_method):
+        if decoding_method == "greedy_search":
+            return gr.update(value=1, interactive=False)
         else:
-            return gr.update(visible=False)
+            return gr.update(interactive=True)
 
-    decode_method_radio.change(enable_beam_col, decode_method_radio, beam_col)
+    decode_method_radio.change(interact_beam_slider,
+                               decode_method_radio,
+                               beam_size_slider)
 
     max_symbols_per_step_slider = gr.Slider(
         label="Maximum symbols per step",
         minimum=1,
-        maximum=15,
+        maximum=20,
         step=1,
         value=5,
         interactive=True,
         visible=True,
     )
 
-    max_out_seq_len_slider = gr.Slider(
-        label="Maximum output sequence length ratio",
-        minimum=0,
-        maximum=1,
-        step=0.01,
-        value=0.6,
-        interactive=True,
-        visible=False,
-    )
-
-    def switch_slider(decoding_method):
-        if decoding_method == "alsd_search":
-            return gr.update(visible=False), gr.update(visible=True)
-        else:
-            return gr.update(visible=True), gr.update(visible=False)
-
-    decode_method_radio.change(switch_slider,
-                               decode_method_radio,
-                               [max_symbols_per_step_slider, max_out_seq_len_slider])
-
     with gr.Tabs():
         with gr.TabItem("Upload from disk"):
             uploaded_file = gr.Audio(
@@ -308,9 +265,8 @@ with demo:
             inputs=[
                 uploaded_file,
                 decode_method_radio,
-                beam_size,
+                beam_size_slider,
                 max_symbols_per_step_slider,
-                max_out_seq_len_slider,
             ],
             outputs=[uploaded_output, uploaded_html_info],
         )
@@ -320,9 +276,8 @@ with demo:
             inputs=[
                 microphone,
                 decode_method_radio,
-                beam_size,
+                beam_size_slider,
                 max_symbols_per_step_slider,
-                max_out_seq_len_slider,
             ],
            outputs=[recorded_output, recorded_html_info],
         )
@@ -332,9 +287,8 @@ with demo:
             inputs=[
                 url_textbox,
                 decode_method_radio,
-                beam_size,
+                beam_size_slider,
                 max_symbols_per_step_slider,
-                max_out_seq_len_slider,
             ],
             outputs=[url_output, url_html_info],
         )
decode.py CHANGED
@@ -1,7 +1,7 @@
 import logging
 import tensorflow as tf
 from functools import lru_cache
-from uetasr.searchers import GreedyRNNT, BeamRNNT, ALSDBeamRNNT
+from uetasr.searchers import GreedyRNNT, BeamRNNT
 
 
 @lru_cache(maxsize=5)
@@ -12,7 +12,6 @@ def get_searcher(
     text_decoder: tf.keras.layers.experimental.preprocessing.PreprocessingLayer,
     beam_size: int,
     max_symbols_per_step: int,
-    max_output_seq_length_ratio: float,
 ):
     common_kwargs = {
         "decoder": decoder,
@@ -32,12 +31,6 @@ def get_searcher(
             alpha=0.0,
             **common_kwargs,
         )
-    elif searcher_type == "alsd_search":
-        searcher = ALSDBeamRNNT(
-            fraction=max_output_seq_length_ratio,
-            beam_size=beam_size,
-            **common_kwargs,
-        )
     else:
         logging.info(f"Unknown searcher type: {searcher_type}")
 
model.py CHANGED
@@ -101,7 +101,6 @@ class UETASRModel:
         decoding_method: str,
         beam_size: int,
         max_symbols_per_step: int,
-        max_output_seq_length_ratio: float,
     ):
         self.featurizer, self.encoder_model, jointer, decoder, text_encoder, self.model = _get_conformer_pre_trained_model(repo_id)
         self.searcher = get_searcher(
@@ -111,7 +110,6 @@ class UETASRModel:
             text_encoder,
             beam_size,
             max_symbols_per_step,
-            max_output_seq_length_ratio,
         )
 
     def predict(self, in_filename: str):
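
With this change UETASRModel takes only four constructor arguments, as shown in the diff above. A minimal usage sketch for reference; the repo id, audio path, and the `from model import ...` path are placeholders/assumptions, not part of this commit:

    # Hypothetical usage sketch of the trimmed constructor; repo id and file name are placeholders.
    from model import UETASRModel

    recognizer = UETASRModel(
        "thanhtvt/some-uetasr-conformer-repo",  # repo_id (placeholder)
        "greedy_search",                        # decoding_method
        1,                                      # beam_size (app.py fixes it to 1 for greedy search)
        5,                                      # max_symbols_per_step
    )
    text = recognizer.predict("audio.wav")      # local audio file (placeholder)
    print(text)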
requirements.txt CHANGED
@@ -1,2 +1,2 @@
-uetasr @ git+https://github.com/thanhtvt/uetasr
+uetasr @ git+https://github.com/thanhtvt/uetasr@v0.2.1
 requests==2.28.2