Added decoding at the end
app.py CHANGED
@@ -29,24 +29,30 @@ def resample(sr, audio_data):
     return audio_16k
 
 
-def model(audio_16k, is_start):
+def model(audio_16k):
     logits, logits_len, greedy_predictions = asr_model.forward(
         input_signal=torch.tensor([audio_16k]),
         input_signal_length=torch.tensor([len(audio_16k)])
     )
-
-    # cut overhead
-    buffer_len = len(audio_16k)
-    logits_overhead = (logits.shape[1] - 1) * overhead_len // buffer_len
-    logits_overhead //= 2
-    delay = (logits.shape[1] - 1) - (2 * logits_overhead)
-    start_cut = 0 if is_start else logits_overhead
-    delay += 0 if not is_start else logits_overhead
-    logits = logits[:, start_cut:start_cut+delay]
     return logits
 
 
-def decode_predictions(logits):
+def decode_predictions(logits_list):
+    # calc overhead
+    logits_overhead = logits_list[0].shape[1] * overhead_len // total_buffer
+    logits_overhead //= 2
+    #delay = (logits.shape[1] - 1) - (2 * logits_overhead)
+
+    # cut overhead
+    cutted_logits = []
+    for idx in range(len(logits_list)):
+        start_cut = 0 if (idx==0) else logits_overhead
+        end_cut = 1 if (idx==len(logits_list)-1) else logits_overhead
+        logits = logits_list[idx][:, start_cut:-end_cut]
+        cutted_logits.append(logits)
+
+    # join
+    logits = torch.cat(cutted_logits, axis=1)
     logits_len = torch.tensor([logits.shape[1]])
     current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
         logits, decoder_lengths=logits_len, return_hypotheses=False,
@@ -57,8 +63,7 @@ def decode_predictions(logits):
 
 def transcribe(audio, state):
     if state is None:
-        state = [np.array([], dtype=np.float32), None]
-    is_start = state[1] is None
+        state = [np.array([], dtype=np.float32), []]
 
     sr, audio_data = audio
     audio_16k = resample(sr, audio_data)
@@ -70,15 +75,11 @@ def transcribe(audio, state):
     buffer = state[0][:total_buffer]
     state[0] = state[0][total_buffer - overhead_len:]
     # run model
-
-    logits = model(buffer, is_start)
+    logits = model(buffer)
     # add logits
-    if is_start:
-        state[1] = logits
-    else:
-        state[1] = torch.cat([state[1],logits], axis=1)
+    state[1].append(logits)
 
-    if state[1] is None:
+    if len(state[1]) == 0:
         text = ""
     else:
         text = decode_predictions(state[1])