theodotus committed
Commit 0992503 · 1 Parent(s): ba10ffc

Copied theodotus/streaming-asr-uk

Files changed (4)
  1. packages.txt +2 -0
  2. .gitignore +3 -0
  3. app.py +102 -0
  4. requirements.txt +1 -0
packages.txt ADDED
@@ -0,0 +1,2 @@
+ libsndfile1
+ ffmpeg
.gitignore ADDED
@@ -0,0 +1,3 @@
+ .env
+ .vscode
+ flagged
app.py ADDED
@@ -0,0 +1,102 @@
+ import gradio as gr
+ import numpy as np
+ import resampy
+ import torch
+
+ from math import floor, ceil
+ import nemo.collections.asr as nemo_asr
+
+
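+ # load a Ukrainian Citrinet CTC model from the Hugging Face Hub onto the CPU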
+ asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
+     "NeonBohdan/stt_uk_citrinet_512_gamma_0_25", map_location="cpu"
+ )
+
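+ # inference-only setup: disable dither and padding so feature extraction is deterministic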
+ asr_model.preprocessor.featurizer.dither = 0.0
+ asr_model.preprocessor.featurizer.pad_to = 0
+ asr_model.eval()
+ asr_model.encoder.freeze()
+ asr_model.decoder.freeze()
+
+
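+ # 1.9 s of audio per model call; consecutive buffers overlap by half a buffer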
+ total_buffer = asr_model.cfg["sample_rate"] * 19 // 10
+ overhead_len = total_buffer // 2
+ model_stride = 4
+
+
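+ # scale integer PCM to float32 in [-1, 1] and resample to the model sample rate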
+ def resample(sr, audio_data):
+     audio_fp32 = np.divide(audio_data, np.iinfo(audio_data.dtype).max, dtype=np.float32)
+     audio_16k = resampy.resample(audio_fp32, sr, asr_model.cfg["sample_rate"])
+     return audio_16k
+
+
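+ # run one audio buffer through the acoustic model and return per-frame CTC logits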
+ def model(audio_16k):
+     logits, logits_len, greedy_predictions = asr_model.forward(
+         input_signal=torch.tensor([audio_16k]),
+         input_signal_length=torch.tensor([len(audio_16k)])
+     )
+     return logits
+
+
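+ # merge per-buffer logits: trim the overlapping halves, concatenate, then CTC-decode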
+ def decode_predictions(logits_list):
+     # per-side overhead of each buffer, measured in logit frames
+     logits_overhead = logits_list[0].shape[1] * overhead_len / total_buffer / 2
+     if (logits_overhead * 2 != int(logits_overhead * 2)):
+         raise ValueError("Wrong total_buffer")
+
+     # cut the overhead from both ends of every buffer
+     trimmed_logits = []
+     for idx in range(len(logits_list)):
+         start_cut = 0 if (idx == 0) else floor(logits_overhead)
+         end_cut = 1 if (idx == len(logits_list) - 1) else ceil(logits_overhead)
+         if (logits_overhead == int(logits_overhead)) and (end_cut != 1):
+             end_cut += 1
+         logits = logits_list[idx][:, start_cut:-end_cut]
+         trimmed_logits.append(logits)
+
+     # join the trimmed buffers and decode
+     logits = torch.cat(trimmed_logits, dim=1)
+     logits_len = torch.tensor([logits.shape[1]])
+     current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
+         logits, decoder_lengths=logits_len, return_hypotheses=False,
+     )
+     return current_hypotheses[0]
+
+
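+ # streaming callback; state holds [buffered audio samples, logits of processed buffers]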
+ def transcribe(audio, state):
+     if state is None:
+         state = [np.array([], dtype=np.float32), []]
+
+     sr, audio_data = audio
+     audio_16k = resample(sr, audio_data)
+
+     # append the new chunk to the buffered audio sequence
+     state[0] = np.concatenate([state[0], audio_16k])
+
+     while (len(state[0]) > total_buffer):
+         buffer = state[0][:total_buffer]
+         state[0] = state[0][total_buffer - overhead_len:]
+         # run the model on the full buffer and keep its logits
+         logits = model(buffer)
+         state[1].append(logits)
+
+     if len(state[1]) == 0:
+         text = ""
+     else:
+         text = decode_predictions(state[1])
+     return text, state
+
+
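+ # live interface: gradio streams microphone chunks into transcribe() as they arrive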
+ gr.Interface(
+     fn=transcribe,
+     inputs=[
+         gr.Audio(source="microphone", type="numpy", streaming=True),
+         gr.State(None)
+     ],
+     outputs=[
+         "textbox",
+         "state"
+     ],
+     live=True).launch()
requirements.txt ADDED
@@ -0,0 +1 @@
+ nemo_toolkit[asr]
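
For orientation, here is a minimal standalone sketch of the sliding-window arithmetic that app.py sets up. It assumes the checkpoint's sample rate is 16 kHz; app.py itself reads the value from asr_model.cfg["sample_rate"], so the concrete numbers below are an assumption, not taken from the commit.

    # sketch of the buffer layout in app.py, assuming a 16 kHz model sample rate
    sample_rate = 16000                    # assumption; app.py reads asr_model.cfg["sample_rate"]
    total_buffer = sample_rate * 19 // 10  # 30400 samples = 1.9 s per model call
    overhead_len = total_buffer // 2       # 15200 samples shared with the next buffer
    step = total_buffer - overhead_len     # 15200 samples of new audio consumed per call

    # buffer k covers samples [k*step, k*step + total_buffer)
    for k in range(3):
        start = k * step
        end = start + total_buffer
        print(f"buffer {k}: samples {start}..{end} "
              f"({start / sample_rate:.2f}s..{end / sample_rate:.2f}s)")

Under these numbers, every 0.95 s of new audio triggers one model call, and apart from the stream edges each sample is seen by exactly two buffers; the floor/ceil trimming in decode_predictions removes that duplicated half-window from the logits before the final CTC decode.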