theodotus committed on
Commit
2f05f3a
·
1 Parent(s): a20f918

Added global decoding

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -39,8 +39,11 @@ def model(audio_16k):
39
  logits_overhead = logits.shape[1] * overhead_len // total_buffer
40
  extra = 1 if (logits.shape[1] * overhead_len % total_buffer) else 0
41
  logits = logits[:,logits_overhead:-logits_overhead-extra]
42
- logits_len -= 2 * logits_overhead + extra
43
 
 
 
 
44
  current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
45
  logits, decoder_lengths=logits_len, return_hypotheses=False,
46
  )
@@ -50,7 +53,7 @@ def model(audio_16k):
50
 
51
  def transcribe(audio, state):
52
  if state is None:
53
- state = [np.array([], dtype=np.float32), ""]
54
 
55
  sr, audio_data = audio
56
  audio_16k = resample(sr, audio_data)
@@ -64,13 +67,15 @@ def transcribe(audio, state):
64
  buffer = state[0][:buffer_len]
65
  state[0] = state[0][buffer_len - overhead_len:]
66
  # run model
67
- text = model(buffer)
68
- else:
69
- text = ""
70
-
71
- if (len(text) != 0):
72
- state[1] += text + " "
73
- return state[1], state
 
 
74
 
75
 
76
  gr.Interface(
 
39
  logits_overhead = logits.shape[1] * overhead_len // total_buffer
40
  extra = 1 if (logits.shape[1] * overhead_len % total_buffer) else 0
41
  logits = logits[:,logits_overhead:-logits_overhead-extra]
42
+ return logits
43
 
44
+
45
+ def decode_predictions(logits):
46
+ logits_len = torch.tensor([logits.shape[1]])
47
  current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
48
  logits, decoder_lengths=logits_len, return_hypotheses=False,
49
  )
 
53
 
54
  def transcribe(audio, state):
55
  if state is None:
56
+ state = [np.array([], dtype=np.float32), None]
57
 
58
  sr, audio_data = audio
59
  audio_16k = resample(sr, audio_data)
 
67
  buffer = state[0][:buffer_len]
68
  state[0] = state[0][buffer_len - overhead_len:]
69
  # run model
70
+ logits = model(buffer)
71
+ # add logits
72
+ if state[1] is None:
73
+ state[1] = logits
74
+ else:
75
+ state[1] = torch.cat([state[1],logits], axis=1)
76
+
77
+ text = decode_predictions(state[1])
78
+ return text, state
79
 
80
 
81
  gr.Interface(