Spaces:
Sleeping
Sleeping
Added global decoding
Browse files
app.py
CHANGED
@@ -39,8 +39,11 @@ def model(audio_16k):
|
|
39 |
logits_overhead = logits.shape[1] * overhead_len // total_buffer
|
40 |
extra = 1 if (logits.shape[1] * overhead_len % total_buffer) else 0
|
41 |
logits = logits[:,logits_overhead:-logits_overhead-extra]
|
42 |
-
|
43 |
|
|
|
|
|
|
|
44 |
current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
|
45 |
logits, decoder_lengths=logits_len, return_hypotheses=False,
|
46 |
)
|
@@ -50,7 +53,7 @@ def model(audio_16k):
|
|
50 |
|
51 |
def transcribe(audio, state):
|
52 |
if state is None:
|
53 |
-
state = [np.array([], dtype=np.float32),
|
54 |
|
55 |
sr, audio_data = audio
|
56 |
audio_16k = resample(sr, audio_data)
|
@@ -64,13 +67,15 @@ def transcribe(audio, state):
|
|
64 |
buffer = state[0][:buffer_len]
|
65 |
state[0] = state[0][buffer_len - overhead_len:]
|
66 |
# run model
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
74 |
|
75 |
|
76 |
gr.Interface(
|
|
|
39 |
logits_overhead = logits.shape[1] * overhead_len // total_buffer
|
40 |
extra = 1 if (logits.shape[1] * overhead_len % total_buffer) else 0
|
41 |
logits = logits[:,logits_overhead:-logits_overhead-extra]
|
42 |
+
return logits
|
43 |
|
44 |
+
|
45 |
+
def decode_predictions(logits):
|
46 |
+
logits_len = torch.tensor([logits.shape[1]])
|
47 |
current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
|
48 |
logits, decoder_lengths=logits_len, return_hypotheses=False,
|
49 |
)
|
|
|
53 |
|
54 |
def transcribe(audio, state):
|
55 |
if state is None:
|
56 |
+
state = [np.array([], dtype=np.float32), None]
|
57 |
|
58 |
sr, audio_data = audio
|
59 |
audio_16k = resample(sr, audio_data)
|
|
|
67 |
buffer = state[0][:buffer_len]
|
68 |
state[0] = state[0][buffer_len - overhead_len:]
|
69 |
# run model
|
70 |
+
logits = model(buffer)
|
71 |
+
# add logits
|
72 |
+
if state[1] is None:
|
73 |
+
state[1] = logits
|
74 |
+
else:
|
75 |
+
state[1] = torch.cat([state[1],logits], axis=1)
|
76 |
+
|
77 |
+
text = decode_predictions(state[1])
|
78 |
+
return text, state
|
79 |
|
80 |
|
81 |
gr.Interface(
|