roman commited on
Commit
9c6f6ce
·
1 Parent(s): 6775bac

reset state of output

Browse files
Files changed (1) hide show
  1. app.py +28 -2
app.py CHANGED
@@ -65,7 +65,7 @@ def decode_predictions(logits_list):
65
  return current_hypotheses[0]
66
 
67
 
68
- def transcribe(audio, state):
69
  if state is None:
70
  state = [np.array([], dtype=np.float32), []]
71
 
@@ -88,6 +88,31 @@ def transcribe(audio, state):
88
  text = decode_predictions(state[1])
89
  return text, state
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  gr.Interface(
93
  fn=transcribe,
@@ -95,7 +120,8 @@ gr.Interface(
95
  # gr.Audio(source="upload", type="filepath", streaming=True),
96
  gr.Audio(source="upload", type="filepath"),
97
  # "state"
98
- gr.State(None)
 
99
  ],
100
  outputs=[
101
  "textbox",
 
65
  return current_hypotheses[0]
66
 
67
 
68
+ def transcribe_(audio, state):
69
  if state is None:
70
  state = [np.array([], dtype=np.float32), []]
71
 
 
88
  text = decode_predictions(state[1])
89
  return text, state
90
 
91
+ def transcribe(audio, state, reset_state):
92
+ if reset_state:
93
+ state = [np.array([], dtype=np.float32), []]
94
+
95
+ if state is None:
96
+ state = [np.array([], dtype=np.float32), []]
97
+
98
+ audio_16k = resample(audio)
99
+
100
+ # join to audio sequence
101
+ state[0] = np.concatenate([state[0], audio_16k])
102
+
103
+ while (len(state[0]) > total_buffer):
104
+ buffer = state[0][:total_buffer]
105
+ state[0] = state[0][total_buffer - overhead_len:]
106
+ # run model
107
+ logits = model(buffer)
108
+ # add logits
109
+ state[1].append(logits)
110
+
111
+ if len(state[1]) == 0:
112
+ text = ""
113
+ else:
114
+ text = decode_predictions(state[1])
115
+ return text, state
116
 
117
  gr.Interface(
118
  fn=transcribe,
 
120
  # gr.Audio(source="upload", type="filepath", streaming=True),
121
  gr.Audio(source="upload", type="filepath"),
122
  # "state"
123
+ gr.State(None),
124
+ gr.Button(text="Reset State", label="Reset State")
125
  ],
126
  outputs=[
127
  "textbox",