Spaces:
Runtime error
Runtime error
chnk58hoang
committed on
Commit
·
a882d4c
1
Parent(s):
2eacc27
upload source file
Browse files- README.md +5 -5
- __pycache__/app.cpython-310.pyc +0 -0
- __pycache__/decode.cpython-310.pyc +0 -0
- __pycache__/examples.cpython-310.pyc +0 -0
- __pycache__/model.cpython-310.pyc +0 -0
- app.py +299 -0
- checkpoints/avg_top5_27-32.ckpt.data-00000-of-00001 +1 -0
- checkpoints/avg_top5_27-32.ckpt.index +1 -0
- config.yaml +1 -0
- decode.py +37 -0
- examples.py +5 -0
- model.py +135 -0
- requirements.txt +3 -0
- test_wavs/2022_1004_00001300_00002239.wav +0 -0
- test_wavs/2022_1004_00087158_00087929.wav +0 -0
- test_wavs/2022_1008_00110083_00110571.wav +0 -0
- vocabs/subword_vietnamese_500.model +1 -0
- vocabs/subword_vietnamese_500.vocab +1 -0
README.md
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
|
|
1 |
---
|
2 |
+
title: Uetasr
|
3 |
+
emoji: π
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: blue
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.24.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
__pycache__/app.cpython-310.pyc
ADDED
Binary file (5.98 kB). View file
|
|
__pycache__/decode.cpython-310.pyc
ADDED
Binary file (998 Bytes). View file
|
|
__pycache__/examples.cpython-310.pyc
ADDED
Binary file (297 Bytes). View file
|
|
__pycache__/model.cpython-310.pyc
ADDED
Binary file (3.46 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import gradio as gr
|
3 |
+
import librosa
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
import soundfile as sf
|
7 |
+
import subprocess
|
8 |
+
import tempfile
|
9 |
+
import urllib.request
|
10 |
+
|
11 |
+
from datetime import datetime
|
12 |
+
from time import time
|
13 |
+
|
14 |
+
from examples import examples
|
15 |
+
from model import UETASRModel
|
16 |
+
|
17 |
+
|
18 |
+
def get_duration(filename: str) -> float:
    """Return the duration, in seconds, of the audio stored at *filename*."""
    duration_seconds = librosa.get_duration(path=filename)
    return duration_seconds
|
20 |
+
|
21 |
+
|
22 |
+
def convert_to_wav(in_filename: str) -> str:
    """Resample *in_filename* to 16 kHz and write it as a ``.wav`` next to the input.

    Returns the path of the converted file.
    """
    base, _ = os.path.splitext(in_filename)
    out_filename = base + ".wav"
    logging.info(f"Converting {in_filename} to {out_filename}")
    samples, rate = librosa.load(in_filename, sr=16000)
    sf.write(out_filename, samples, rate)
    return out_filename
|
28 |
+
|
29 |
+
|
30 |
+
def build_html_output(s: str, style: str = "result_item_success"):
    """Wrap *s* in the result ``<div>`` markup, styled with the given CSS class."""
    return f"""
    <div class='result'>
        <div class='result_item {style}'>
        {s}
        </div>
    </div>
    """
|
38 |
+
|
39 |
+
|
40 |
+
def process_url(
    url: str,
    decoding_method: str,
    beam_size: int,
    max_symbols_per_step: int,
):
    """Download the audio at *url* into a temp file and run recognition on it.

    On any failure the exception text is returned as an error HTML panel
    instead of propagating, so the UI always gets a (text, html) pair.
    """
    logging.info(f"Processing URL: {url}")
    with tempfile.NamedTemporaryFile() as f:
        try:
            urllib.request.urlretrieve(url, f.name)
            return process(
                in_filename=f.name,
                decoding_method=decoding_method,
                beam_size=beam_size,
                max_symbols_per_step=max_symbols_per_step,
            )
        except Exception as e:
            error_text = str(e)
            logging.info(error_text)
            return "", build_html_output(error_text, "result_item_error")
|
57 |
+
|
58 |
+
|
59 |
+
def process_uploaded_file(
    in_filename: str,
    decoding_method: str,
    beam_size: int,
    max_symbols_per_step: int,
):
    """Recognize speech from a file uploaded through the "Upload from disk" tab.

    Returns a (text, html_info) pair; errors are reported as an error HTML
    panel rather than raised.
    """
    if not in_filename:
        return "", build_html_output(
            "Please first upload a file and then click "
            'the button "submit for recognition"',
            "result_item_error",
        )

    logging.info(f"Processing uploaded file: {in_filename}")
    try:
        return process(
            in_filename=in_filename,
            decoding_method=decoding_method,
            beam_size=beam_size,
            max_symbols_per_step=max_symbols_per_step,
        )
    except Exception as e:
        error_text = str(e)
        logging.info(error_text)
        return "", build_html_output(error_text, "result_item_error")
|
81 |
+
|
82 |
+
|
83 |
+
def process_microphone(
    in_filename: str,
    decoding_method: str,
    beam_size: int,
    max_symbols_per_step: int,
):
    """Recognize speech from a microphone recording (a temp file path from Gradio).

    Returns a (text, html_info) pair; errors are reported as an error HTML
    panel rather than raised.
    """
    if not in_filename:
        return "", build_html_output(
            "Please first upload a file and then click "
            'the button "submit for recognition"',
            "result_item_error",
        )

    logging.info(f"Processing microphone: {in_filename}")
    try:
        return process(
            in_filename=in_filename,
            decoding_method=decoding_method,
            beam_size=beam_size,
            max_symbols_per_step=max_symbols_per_step,
        )
    except Exception as e:
        error_text = str(e)
        logging.info(error_text)
        return "", build_html_output(error_text, "result_item_error")
|
105 |
+
|
106 |
+
|
107 |
+
def process(
    in_filename: str,
    decoding_method: str,
    beam_size: int,
    max_symbols_per_step: int,
):
    """Run speech recognition on *in_filename*.

    The input is converted to a 16 kHz wav, then decoded with the pretrained
    UETASR conformer model from the Hub.

    Returns:
        (text, html_info): the recognized transcript and an HTML snippet with
        timing statistics (wave duration, processing time, RTF).
    """
    logging.info(f"in_filename: {in_filename}")

    filename = convert_to_wav(in_filename)

    now = datetime.now()
    date_time = now.strftime("%d/%m/%Y, %H:%M:%S.%f")
    logging.info(f"Started at {date_time}")

    repo_id = "thanhtvt/uetasr-conformer_30.3m"

    start = time()

    # Model loading is lru_cache'd inside the model/searcher helpers, so
    # only the first call pays the download/build cost.
    recognizer = UETASRModel(repo_id,
                             decoding_method,
                             beam_size,
                             max_symbols_per_step)
    text = recognizer.predict(filename)

    # FIX: take a fresh timestamp here. The original reused the `now`
    # captured before recognition, so "Finished at" always logged the
    # start time. Also dropped the stray " s" after the timestamp.
    date_time = datetime.now().strftime("%d/%m/%Y, %H:%M:%S.%f")
    end = time()

    duration = get_duration(filename)
    rtf = (end - start) / duration

    logging.info(f"Finished at {date_time}. Elapsed: {end - start: .3f} s")

    info = f"""
    Wave duration  : {duration: .3f} s <br/>
    Processing time: {end - start: .3f} s <br/>
    RTF: {end - start: .3f}/{duration: .3f} = {rtf:.3f} <br/>
    """
    if rtf > 1:
        # First run includes model download/compilation, so RTF is inflated.
        info += (
            "<br/>We are loading required resources for the first run. "
            "Please run again to measure the real RTF.<br/>"
        )

    logging.info(info)

    return text, build_html_output(info)
|
153 |
+
|
154 |
+
|
155 |
+
title = "Educa ASR"
description = """
A space demo for Automatic Speech Recognition.
"""

# css style is copied from
# https://huggingface.co/spaces/alphacep/asr/blob/main/app.py#L113
css = """
.result {display:flex;flex-direction:column}
.result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
.result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
.result_item_error {background-color:#ff7070;color:white;align-self:start}
"""

demo = gr.Blocks(css=css)


with demo:
    gr.Markdown(title)

    decode_method_radio = gr.Radio(
        label="Decoding method",
        choices=["greedy_search", "beam_search"],
        value="greedy_search",
        interactive=True,
    )

    # Beam size is locked while greedy search is selected (see callback below).
    beam_size_slider = gr.Slider(
        label="Beam size",
        minimum=1,
        maximum=20,
        step=1,
        value=1,
        interactive=False,
    )

    def interact_beam_slider(decoding_method):
        """Enable the beam-size slider only when beam search is selected."""
        if decoding_method == "greedy_search":
            return gr.update(value=1, interactive=False)
        else:
            return gr.update(interactive=True)

    decode_method_radio.change(interact_beam_slider,
                               decode_method_radio,
                               beam_size_slider)

    max_symbols_per_step_slider = gr.Slider(
        label="Maximum symbols per step",
        minimum=1,
        maximum=20,
        step=1,
        value=5,
        interactive=True,
        visible=True,
    )

    with gr.Tabs():
        with gr.TabItem("Upload from disk"):
            uploaded_file = gr.Audio(
                source="upload",  # Choose between "microphone", "upload"
                type="filepath",
                label="Upload from disk",
            )
            upload_button = gr.Button("Submit for recognition")
            uploaded_output = gr.Textbox(label="Recognized speech from uploaded file")
            uploaded_html_info = gr.HTML(label="Info")

            gr.Examples(
                examples=examples,
                inputs=uploaded_file,
                outputs=[uploaded_output, uploaded_html_info],
                fn=process_uploaded_file,
            )

        with gr.TabItem("Record from microphone"):
            microphone = gr.Audio(
                source="microphone",
                type="filepath",
                label="Record from microphone",
            )

            record_button = gr.Button("Submit for recognition")
            recorded_output = gr.Textbox(label="Recognized speech from recordings")
            recorded_html_info = gr.HTML(label="Info")

            gr.Examples(
                examples=examples,
                inputs=microphone,
                # FIX: results now go to this tab's own widgets; the original
                # wrote them into the upload tab's outputs.
                outputs=[recorded_output, recorded_html_info],
                fn=process_microphone,
            )

        with gr.TabItem("From URL"):
            url_textbox = gr.Textbox(
                max_lines=1,
                placeholder="URL to an audio file",
                label="URL",
                interactive=True,
            )

            url_button = gr.Button("Submit for recognition")
            url_output = gr.Textbox(label="Recognized speech from URL")
            url_html_info = gr.HTML(label="Info")

    upload_button.click(
        process_uploaded_file,
        inputs=[
            uploaded_file,
            decode_method_radio,
            beam_size_slider,
            max_symbols_per_step_slider,
        ],
        outputs=[uploaded_output, uploaded_html_info],
    )

    record_button.click(
        process_microphone,
        inputs=[
            microphone,
            decode_method_radio,
            beam_size_slider,
            max_symbols_per_step_slider,
        ],
        outputs=[recorded_output, recorded_html_info],
    )

    url_button.click(
        process_url,
        inputs=[
            url_textbox,
            decode_method_radio,
            beam_size_slider,
            max_symbols_per_step_slider,
        ],
        outputs=[url_output, url_html_info],
    )
    gr.Markdown(description)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    demo.launch()
|
checkpoints/avg_top5_27-32.ckpt.data-00000-of-00001
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
../../../.cache/huggingface/hub/models--thanhtvt--uetasr-conformer_30.3m/blobs/5806d3e06fac6c2fde5a70f5c5d29bf31dcf81557ef744fd67339804ef736050
|
checkpoints/avg_top5_27-32.ckpt.index
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
../../../.cache/huggingface/hub/models--thanhtvt--uetasr-conformer_30.3m/blobs/620561277420e8326531c0036fe013b9f57aa6e577fb7390236679ae15c70ca7
|
config.yaml
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
../../.cache/huggingface/hub/models--thanhtvt--uetasr-conformer_30.3m/blobs/665f8dfb04042619a6b114dca8622bd0cfe1ad3d
|
decode.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import tensorflow as tf
|
3 |
+
from functools import lru_cache
|
4 |
+
from uetasr.searchers import GreedyRNNT, BeamRNNT
|
5 |
+
|
6 |
+
|
7 |
+
@lru_cache(maxsize=5)
def get_searcher(
    searcher_type: str,
    decoder: tf.keras.Model,
    jointer: tf.keras.Model,
    text_decoder: tf.keras.layers.experimental.preprocessing.PreprocessingLayer,
    beam_size: int,
    max_symbols_per_step: int,
):
    """Build (and cache) an RNN-T searcher for the given decoding method.

    Args:
        searcher_type: "greedy_search" or "beam_search".
        decoder, jointer, text_decoder: pretrained model components.
        beam_size: beam width; only used by "beam_search".
        max_symbols_per_step: cap on symbols emitted per encoder frame.

    Raises:
        ValueError: if *searcher_type* is not a known method.
    """
    common_kwargs = {
        "decoder": decoder,
        "jointer": jointer,
        "text_decoder": text_decoder,
        "return_scores": False,
    }
    if searcher_type == "greedy_search":
        searcher = GreedyRNNT(
            max_symbols_per_step=max_symbols_per_step,
            **common_kwargs,
        )
    elif searcher_type == "beam_search":
        searcher = BeamRNNT(
            max_symbols_per_step=max_symbols_per_step,
            beam=beam_size,
            alpha=0.0,
            **common_kwargs,
        )
    else:
        # FIX: the original only logged here, then crashed with
        # UnboundLocalError on `return searcher`. Fail loudly instead.
        raise ValueError(f"Unknown searcher type: {searcher_type}")

    return searcher
|
examples.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Sample Vietnamese test utterances bundled with the Space, used by the
# gr.Examples widgets in app.py.
examples = [
    f"./test_wavs/{stem}.wav"
    for stem in (
        "2022_1004_00001300_00002239",
        "2022_1004_00087158_00087929",
        "2022_1008_00110083_00110571",
    )
]
|
model.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tensorflow as tf
|
3 |
+
from functools import lru_cache
|
4 |
+
from huggingface_hub import hf_hub_download
|
5 |
+
from hyperpyyaml import load_hyperpyyaml
|
6 |
+
from typing import Union
|
7 |
+
|
8 |
+
from decode import get_searcher
|
9 |
+
|
10 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
11 |
+
|
12 |
+
|
13 |
+
def _get_checkpoint_filename(
    repo_id: str,
    filename: str,
    local_dir: str = None,
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "checkpoints"
) -> str:
    """Fetch one checkpoint file from the Hub repo and return its local path."""
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=local_dir,
        local_dir_use_symlinks=local_dir_use_symlinks,
    )
|
28 |
+
|
29 |
+
|
30 |
+
def _get_bpe_model_filename(
    repo_id: str,
    filename: str,
    local_dir: str = None,
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "vocabs"
) -> str:
    """Fetch one BPE vocab/model file from the Hub repo and return its local path."""
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=local_dir,
        local_dir_use_symlinks=local_dir_use_symlinks,
    )
|
45 |
+
|
46 |
+
|
47 |
+
@lru_cache(maxsize=1)
def _get_conformer_pre_trained_model(repo_id: str, checkpoint_dir: str = "checkpoints"):
    """Download the pretrained conformer assets from *repo_id* and build the model.

    Fetches the checkpoint shards, the SentencePiece BPE files and config.yaml
    into this file's directory (symlinked from the HF cache), instantiates all
    components from the hyperpyyaml config, and loads the checkpoint weights.

    Cached with maxsize=1, so the download/build cost is paid once per process.

    Returns:
        (audio_encoder, encoder_model, jointer, decoder, text_encoder, model)
    """
    # Checkpoint comes as two TF shards: .index and .data-00000-of-00001.
    for postfix in ["index", "data-00000-of-00001"]:
        tmp = _get_checkpoint_filename(
            repo_id=repo_id,
            filename="avg_top5_27-32.ckpt.{}".format(postfix),
            subfolder=checkpoint_dir,
            local_dir=os.path.dirname(__file__),  # noqa
            local_dir_use_symlinks=True,
        )
        print(tmp)

    # SentencePiece subword files: the trained .model and its .vocab listing.
    for postfix in ["model", "vocab"]:
        tmp = _get_bpe_model_filename(
            repo_id=repo_id,
            filename="subword_vietnamese_500.{}".format(postfix),
            local_dir=os.path.dirname(__file__),  # noqa
            local_dir_use_symlinks=True,
        )
        print(tmp)

    config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.yaml",
        local_dir=os.path.dirname(__file__),  # noqa
        local_dir_use_symlinks=True,
    )

    # hyperpyyaml instantiates the model objects directly from the config.
    with open(config_path, "r") as f:
        config = load_hyperpyyaml(f)

    encoder_model = config["encoder_model"]
    text_encoder = config["text_encoder"]
    jointer = config["jointer_model"]
    decoder = config["decoder_model"]
    # searcher = config["decoder"]
    model = config["model"]
    audio_encoder = config["audio_encoder"]
    # expect_partial(): the checkpoint may hold more variables (e.g. optimizer
    # state) than this inference graph restores.
    model.load_weights(os.path.join(checkpoint_dir, "avg_top5_27-32.ckpt")).expect_partial()

    return audio_encoder, encoder_model, jointer, decoder, text_encoder, model
|
88 |
+
|
89 |
+
|
90 |
+
def read_audio(in_filename: str):
    """Decode a WAV file into a float tensor batched as (1, num_samples)."""
    raw_bytes = tf.io.read_file(in_filename)
    waveform = tf.audio.decode_wav(raw_bytes)[0]
    # Drop the trailing channel axis, then add a leading batch axis.
    mono = tf.squeeze(waveform, axis=-1)
    return tf.expand_dims(mono, axis=0)
|
95 |
+
|
96 |
+
|
97 |
+
class UETASRModel:
    """Pretrained UETASR conformer bundled with its featurizer and searcher."""

    def __init__(
        self,
        repo_id: str,
        decoding_method: str,
        beam_size: int,
        max_symbols_per_step: int,
    ):
        (self.featurizer,
         self.encoder_model,
         jointer,
         decoder,
         text_encoder,
         self.model) = _get_conformer_pre_trained_model(repo_id)
        self.searcher = get_searcher(
            decoding_method,
            decoder,
            jointer,
            text_encoder,
            beam_size,
            max_symbols_per_step,
        )

    def predict(self, in_filename: str):
        """Transcribe the WAV file at *in_filename* and return the text."""
        waveform = read_audio(in_filename)
        feats = self.featurizer(waveform)
        if self.model.use_cmvn:
            feats = self.model.cmvn(feats)

        # Full-length mask for the single utterance, shaped for the encoder.
        num_frames = tf.shape(feats)[1]
        frame_mask = tf.sequence_mask([num_frames], maxlen=num_frames)
        frame_mask = tf.expand_dims(frame_mask, axis=1)

        enc_out, enc_masks = self.encoder_model(
            feats, frame_mask, training=False)

        # Valid encoder frame counts, derived from the (subsampled) mask.
        enc_mask = tf.squeeze(enc_masks, axis=1)
        enc_lengths = tf.math.reduce_sum(
            tf.cast(enc_mask, tf.int32),
            axis=1
        )

        hypothesis = self.searcher.infer(enc_out, enc_lengths)
        hypothesis = tf.squeeze(hypothesis)
        return tf.compat.as_str_any(hypothesis.numpy())
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
uetasr @ git+https://github.com/thanhtvt/[email protected]
|
2 |
+
requests==2.28.2
|
3 |
+
gradio==3.26.0
|
test_wavs/2022_1004_00001300_00002239.wav
ADDED
Binary file (301 kB). View file
|
|
test_wavs/2022_1004_00087158_00087929.wav
ADDED
Binary file (247 kB). View file
|
|
test_wavs/2022_1008_00110083_00110571.wav
ADDED
Binary file (156 kB). View file
|
|
vocabs/subword_vietnamese_500.model
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
../../../.cache/huggingface/hub/models--thanhtvt--uetasr-conformer_30.3m/blobs/cd533c4c7981f9d1caeb89be0d0e444177af9e00acf27919d66888625b939b96
|
vocabs/subword_vietnamese_500.vocab
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
../../../.cache/huggingface/hub/models--thanhtvt--uetasr-conformer_30.3m/blobs/d6daeedf9cfdad6ec7c7267e23583d2c231d4925
|