Spaces:
Sleeping
Sleeping
Copied theodotus/streaming-asr-uk
Browse files- packages.txt +2 -0
- .gitignore +3 -0
- app.py +102 -0
- requirements.txt +1 -0
packages.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
libsndfile1
|
2 |
+
ffmpeg
|
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
.env
|
2 |
+
.vscode
|
3 |
+
flagged
|
app.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import resampy
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from math import floor,ceil
|
7 |
+
import nemo.collections.asr as nemo_asr
|
8 |
+
|
9 |
+
|
10 |
+
asr_model = nemo_asr.models.EncDecCTCModelBPE. \
|
11 |
+
from_pretrained("NeonBohdan/stt_uk_citrinet_512_gamma_0_25",map_location="cpu")
|
12 |
+
|
13 |
+
asr_model.preprocessor.featurizer.dither = 0.0
|
14 |
+
asr_model.preprocessor.featurizer.pad_to = 0
|
15 |
+
asr_model.eval()
|
16 |
+
asr_model.encoder.freeze()
|
17 |
+
asr_model.decoder.freeze()
|
18 |
+
|
19 |
+
|
20 |
+
total_buffer = asr_model.cfg["sample_rate"] * 19 // 10
|
21 |
+
overhead_len = total_buffer // 2
|
22 |
+
model_stride = 4
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
def resample(sr, audio_data):
|
27 |
+
audio_fp32 = np.divide(audio_data, np.iinfo(audio_data.dtype).max, dtype=np.float32)
|
28 |
+
audio_16k = resampy.resample(audio_fp32, sr, asr_model.cfg["sample_rate"])
|
29 |
+
|
30 |
+
return audio_16k
|
31 |
+
|
32 |
+
|
33 |
+
def model(audio_16k):
|
34 |
+
logits, logits_len, greedy_predictions = asr_model.forward(
|
35 |
+
input_signal=torch.tensor([audio_16k]),
|
36 |
+
input_signal_length=torch.tensor([len(audio_16k)])
|
37 |
+
)
|
38 |
+
return logits
|
39 |
+
|
40 |
+
|
41 |
+
def decode_predictions(logits_list):
|
42 |
+
# calc overhead
|
43 |
+
logits_overhead = logits_list[0].shape[1] * overhead_len / total_buffer / 2
|
44 |
+
if (logits_overhead * 2 != int(logits_overhead * 2)):
|
45 |
+
raise ValueError("Wrong total_buffer")
|
46 |
+
|
47 |
+
# cut overhead
|
48 |
+
cutted_logits = []
|
49 |
+
for idx in range(len(logits_list)):
|
50 |
+
start_cut = 0 if (idx==0) else floor(logits_overhead)
|
51 |
+
end_cut = 1 if (idx==len(logits_list)-1) else ceil(logits_overhead)
|
52 |
+
if (logits_overhead == int(logits_overhead)) and (end_cut != 1):
|
53 |
+
end_cut +=1
|
54 |
+
logits = logits_list[idx][:, start_cut:-end_cut]
|
55 |
+
cutted_logits.append(logits)
|
56 |
+
|
57 |
+
# join
|
58 |
+
logits = torch.cat(cutted_logits, axis=1)
|
59 |
+
logits_len = torch.tensor([logits.shape[1]])
|
60 |
+
current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
|
61 |
+
logits, decoder_lengths=logits_len, return_hypotheses=False,
|
62 |
+
)
|
63 |
+
|
64 |
+
return current_hypotheses[0]
|
65 |
+
|
66 |
+
|
67 |
+
def transcribe(audio, state):
|
68 |
+
if state is None:
|
69 |
+
state = [np.array([], dtype=np.float32), []]
|
70 |
+
|
71 |
+
sr, audio_data = audio
|
72 |
+
audio_16k = resample(sr, audio_data)
|
73 |
+
|
74 |
+
# join to audio sequence
|
75 |
+
state[0] = np.concatenate([state[0], audio_16k])
|
76 |
+
|
77 |
+
while (len(state[0]) > total_buffer):
|
78 |
+
buffer = state[0][:total_buffer]
|
79 |
+
state[0] = state[0][total_buffer - overhead_len:]
|
80 |
+
# run model
|
81 |
+
logits = model(buffer)
|
82 |
+
# add logits
|
83 |
+
state[1].append(logits)
|
84 |
+
|
85 |
+
if len(state[1]) == 0:
|
86 |
+
text = ""
|
87 |
+
else:
|
88 |
+
text = decode_predictions(state[1])
|
89 |
+
return text, state
|
90 |
+
|
91 |
+
|
92 |
+
gr.Interface(
|
93 |
+
fn=transcribe,
|
94 |
+
inputs=[
|
95 |
+
gr.Audio(source="microphone", type="numpy", streaming=True),
|
96 |
+
gr.State(None)
|
97 |
+
],
|
98 |
+
outputs=[
|
99 |
+
"textbox",
|
100 |
+
"state"
|
101 |
+
],
|
102 |
+
live=True).launch()
|
requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
nemo_toolkit[asr]
|