theodotus committed
Commit 2a5f9c9 · 1 Parent(s): 2f05f3a

Nearly fixed streaming bug

Files changed (1)
  1. app.py +20 -13
app.py CHANGED
@@ -17,7 +17,7 @@ asr_model.decoder.freeze()
 
 
 total_buffer = asr_model.cfg["sample_rate"]
-overhead_len = asr_model.cfg["sample_rate"] // 4
+overhead_len = asr_model.cfg["sample_rate"] // 2
 model_stride = 4
 
 
@@ -29,16 +29,20 @@ def resample(sr, audio_data):
     return audio_16k
 
 
-def model(audio_16k):
+def model(audio_16k, is_start):
     logits, logits_len, greedy_predictions = asr_model.forward(
         input_signal=torch.tensor([audio_16k]),
         input_signal_length=torch.tensor([len(audio_16k)])
     )
 
     # cut overhead
-    logits_overhead = logits.shape[1] * overhead_len // total_buffer
-    extra = 1 if (logits.shape[1] * overhead_len % total_buffer) else 0
-    logits = logits[:,logits_overhead:-logits_overhead-extra]
+    buffer_len = len(audio_16k)
+    logits_overhead = (logits.shape[1] - 1) * overhead_len // buffer_len
+    logits_overhead //= 2
+    delay = (logits.shape[1] - 1) - (2 * logits_overhead)
+    start_cut = 0 if is_start else logits_overhead
+    delay += 0 if not is_start else logits_overhead
+    logits = logits[:, start_cut:start_cut+delay]
     return logits
 
 
@@ -54,6 +58,7 @@ def decode_predictions(logits):
 def transcribe(audio, state):
     if state is None:
         state = [np.array([], dtype=np.float32), None]
+    is_start = state[1] is None
 
     sr, audio_data = audio
     audio_16k = resample(sr, audio_data)
@@ -61,20 +66,22 @@ def transcribe(audio, state):
     # join to audio sequence
     state[0] = np.concatenate([state[0], audio_16k])
 
-    buffer_len = len(state[0])
-    if (buffer_len > total_buffer):
-        buffer_len = buffer_len - buffer_len % total_buffer
-        buffer = state[0][:buffer_len]
-        state[0] = state[0][buffer_len - overhead_len:]
+    while (len(state[0]) > total_buffer):
+        buffer = state[0][:total_buffer]
+        state[0] = state[0][total_buffer - overhead_len:]
         # run model
-        logits = model(buffer)
+        is_start = state[1] is None
+        logits = model(buffer, is_start)
        # add logits
-        if state[1] is None:
+        if is_start:
            state[1] = logits
        else:
            state[1] = torch.cat([state[1],logits], axis=1)
 
-    text = decode_predictions(state[1])
+    if is_start:
+        text = ""
+    else:
+        text = decode_predictions(state[1])
     return text, state
 
 
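For reference, the reworked "cut overhead" block keeps roughly the middle half of each one-second window: with overhead_len at half the buffer, logits_overhead works out to about a quarter of the frames and is cut from each side, except that the first chunk also keeps its left edge. Below is a minimal standalone sketch of that arithmetic, assuming 16 kHz audio and roughly one logit frame per 40 ms (a 10 ms feature hop times model_stride = 4); the frame count T is an assumption for illustration, not a value measured from the model.

# Standalone sketch of the frame trimming in model(); assumed sizes, no NeMo needed.
sample_rate = 16000
total_buffer = sample_rate          # 1 s of audio per model call
overhead_len = sample_rate // 2     # 0.5 s overlap between consecutive calls

def trim_bounds(num_frames, buffer_len, is_start):
    # Mirrors the "cut overhead" block: which frame slice survives per chunk.
    logits_overhead = (num_frames - 1) * overhead_len // buffer_len  # frames inside the overlap
    logits_overhead //= 2                                            # half trimmed from each side
    delay = (num_frames - 1) - (2 * logits_overhead)                 # frames kept per chunk
    start_cut = 0 if is_start else logits_overhead                   # first chunk keeps its left edge
    delay += logits_overhead if is_start else 0
    return start_cut, start_cut + delay

T = 26  # ~1 s at one frame per 40 ms, plus an edge frame
print(trim_bounds(T, total_buffer, is_start=True))   # (0, 19)
print(trim_bounds(T, total_buffer, is_start=False))  # (6, 19)

With these numbers each later chunk contributes 13 frames (about 520 ms) while the audio pointer only advances by total_buffer - overhead_len = 0.5 s (about 12.5 frames), leaving a residual half-frame drift per chunk, which would fit the "nearly" in the commit message.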
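The transcribe(audio, state) signature, taking a (sample_rate, numpy array) tuple plus a state and returning (text, state), matches Gradio's streaming-microphone pattern, though the Interface wiring itself is outside this diff. A hypothetical hookup, assuming Gradio 3.x (under 4.x the microphone argument is sources=["microphone"] rather than source="microphone"):

# Hypothetical app wiring; not part of this commit's diff.
import gradio as gr

demo = gr.Interface(
    fn=transcribe,                                         # from app.py above
    inputs=[gr.Audio(source="microphone", streaming=True), "state"],
    outputs=["text", "state"],
    live=True,                                             # re-run on every audio chunk
)
demo.launch()

Each incoming chunk unpacks as sr, audio_data = audio, exactly as at the top of transcribe, and the returned state is fed back in on the next call.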