Labbeti committed on
Commit
ae94a43
1 Parent(s): 0e27ede

Add/Mod: add a new audio recorder, a tags threshold hyperparameter, and a single range slider for the minimal and maximal number of predicted words.

Files changed (2)
  1. app.py +113 -82
  2. requirements.txt +2 -2
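
A note on the single range slider mentioned in the commit message: it relies on standard Streamlit behavior, where passing a tuple as the default value to st.slider renders one slider with two handles and returns a (low, high) tuple. A minimal sketch of that pattern (the default bounds here are illustrative; app.py takes them from model.config):

import streamlit as st

# A tuple default turns st.slider into a two-handle range slider
# that returns both bounds at once.
min_pred_size, max_pred_size = st.slider(
    "Minimal and maximal number of words",
    min_value=1,
    max_value=30,
    value=(3, 20),  # illustrative defaults, not the app's real ones
)
st.write(f"Will generate between {min_pred_size} and {max_pred_size} words.")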
app.py CHANGED
@@ -1,17 +1,25 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-import os.path as osp
-
-from tempfile import NamedTemporaryFile
-from typing import Any
+from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
+from typing import Any, Optional
 
 import streamlit as st
 
-from audiorecorder import audiorecorder
+from st_audiorec import st_audiorec
 from streamlit.runtime.uploaded_file_manager import UploadedFile
 
 from conette import CoNeTTEModel, conette
+from conette.utils.collections import dict_list_to_list_dict
+
+
+ALLOW_REP_MODES = ("stopwords", "all", "none")
+MAX_BEAM_SIZE = 20
+MAX_PRED_SIZE = 30
+MAX_BATCH_SIZE = 32
+RECORD_AUDIO_FNAME = "record.wav"
+DEFAULT_THRESHOLD = 0.3
+THRESHOLD_PRECISION = 100
 
 
 @st.cache_resource
@@ -26,46 +34,86 @@ def format_candidate(candidate: str) -> str:
     return f"{candidate[0].title()}{candidate[1:]}."
 
 
+def format_tags(tags: Optional[list[str]]) -> str:
+    if tags is None or len(tags) == 0:
+        return "None."
+    else:
+        return ", ".join(tags)
+
+
+def get_result_hash(audio_fname: str, generate_kwds: dict[str, Any]) -> str:
+    return f"{audio_fname}-{generate_kwds}"
+
+
 def get_results(
     model: CoNeTTEModel,
-    audios: list[UploadedFile],
+    audio_files: dict[str, bytes],
     generate_kwds: dict[str, Any],
-) -> tuple[list[str], list[str]]:
-    audio_to_predict = []
-    cands = [""] * len(audios)
-    tmp_files = []
-    tmp_fpaths = []
-    audio_fnames = []
-
-    for i, audio in enumerate(audios):
-        audio_fname = audio.name
-        audio_fnames.append(audio_fname)
-        cand_key = f"{audio_fname}-{generate_kwds}"
-
-        if cand_key in st.session_state:
-            cand = st.session_state[cand_key]
-            cands[i] = cand
-        else:
-            tmp_file = NamedTemporaryFile()
-            tmp_file.write(audio.getvalue())
-            tmp_files.append(tmp_file)
-            audio_to_predict.append((i, cand_key, tmp_file))
-
-            tmp_fpath = tmp_file.name
-            tmp_fpaths.append(tmp_fpath)
-
-    if len(tmp_fpaths) > 0:
-        outputs = model(
-            tmp_fpaths,
+) -> dict[str, dict[str, Any]]:
+    # Get audio to be processed
+    audio_to_predict: dict[str, bytes] = {}
+    for audio_fname, audio in audio_files.items():
+        result_hash = get_result_hash(audio_fname, generate_kwds)
+        if result_hash not in st.session_state or audio_fname == RECORD_AUDIO_FNAME:
+            audio_to_predict[result_hash] = audio
+
+    # Save audio to be processed
+    tmp_files: dict[str, _TemporaryFileWrapper] = {}
+    for result_hash, audio in audio_to_predict.items():
+        tmp_file = NamedTemporaryFile()
+        tmp_file.write(audio)
+        tmp_files[result_hash] = tmp_file
+
+    # Generate predictions and store them in session state
+    for start in range(0, len(tmp_files), MAX_BATCH_SIZE):
+        end = min(start + MAX_BATCH_SIZE, len(tmp_files))
+        result_hashes_j = list(tmp_files.keys())[start:end]
+        tmp_files_j = list(tmp_files.values())[start:end]
+        tmp_paths_j = [tmp_file.name for tmp_file in tmp_files_j]
+        outputs_j = model(
+            tmp_paths_j,
             **generate_kwds,
         )
-        for i, (j, cand_key, tmp_file) in enumerate(audio_to_predict):
-            cand = outputs["cands"][i]
-            cands[j] = cand
-            st.session_state[cand_key] = cand
+        for tmp_file in tmp_files_j:
             tmp_file.close()
+        outputs_lst = dict_list_to_list_dict(outputs_j)  # type: ignore
+        for result_hash, output_i in zip(result_hashes_j, outputs_lst):
+            st.session_state[result_hash] = output_i
+
+    # Get outputs
+    outputs = {}
+    for audio_fname in audio_files.keys():
+        result_hash = get_result_hash(audio_fname, generate_kwds)
+        output_i = st.session_state[result_hash]
+        outputs[audio_fname] = output_i
+
+    return outputs
+
+
+def show_results(outputs: dict[str, dict[str, Any]]) -> None:
+    st.divider()
 
-    return audio_fnames, cands
+    for audio_fname, output in outputs.items():
+        cand = output["cands"]
+        lprobs = output["lprobs"]
+        tags = output.get("tags")
+
+        cand = format_candidate(cand)
+        tags = format_tags(tags)
+        prob = lprobs.exp().tolist()
+
+        if audio_fname == RECORD_AUDIO_FNAME:
+            header = "##### Result for microphone input:"
+        else:
+            header = f'##### Result for "{audio_fname}"'
+
+        content = f"""
+        {header}
+        - **Description:** "{cand}"
+        - **Mean confidence:** {prob*100:.0f}%
+        - **Tags:** {tags}"""
+        st.markdown(content)
+        st.divider()
 
 
 def main() -> None:
@@ -73,46 +121,35 @@ def main() -> None:
 
     model = load_conette(model_kwds=dict(device="cpu"))
 
-    st.warning(
-        "Recommanded audio: lasting from **1 to 30s**, sampled at **32 kHz** minimum."
-    )
-    audios = st.file_uploader(
-        "**Upload audio files here:**",
+    # st.warning(
+    #     "Recommanded audio: lasting from **1 to 30s**, sampled at **32 kHz** minimum."
+    # )
+
+    record_data = st_audiorec()
+    audio_files: Optional[list[UploadedFile]] = st.file_uploader(
+        "**Or upload audio files here:**",
         type=["wav", "flac", "mp3", "ogg", "avi"],
         accept_multiple_files=True,
     )
-    st.write("**OR**")
-    record = audiorecorder(
-        start_prompt="Start recording", stop_prompt="Stop recording", pause_prompt=""
-    )
-
-    record_fpath = "record.wav"
-    if len(record) > 0:
-        record.export(record_fpath, format="wav")
-        st.write(
-            f"Record frame rate: {record.frame_rate}Hz, record duration: {record.duration_seconds:.2f}s"
-        )
-        st.audio(record.export().read())  # type: ignore
 
     with st.expander("Model hyperparameters"):
         task = st.selectbox("Task embedding input", model.tasks, 0)
-        allow_rep_mode = st.selectbox(
-            "Allow repetition of words", ["stopwords", "all", "none"], 0
-        )
+        allow_rep_mode = st.selectbox("Allow repetition of words", ALLOW_REP_MODES, 0)
         beam_size: int = st.select_slider(  # type: ignore
             "Beam size",
-            list(range(1, 21)),
+            list(range(1, MAX_BEAM_SIZE + 1)),
             model.config.beam_size,
         )
-        min_pred_size: int = st.select_slider(  # type: ignore
-            "Minimal number of words",
-            list(range(1, 31)),
-            model.config.min_pred_size,
+        min_pred_size, max_pred_size = st.slider(
+            "Minimal and maximal number of words",
+            1,
+            MAX_PRED_SIZE,
+            (model.config.min_pred_size, model.config.max_pred_size),
         )
-        max_pred_size: int = st.select_slider(  # type: ignore
-            "Maximal number of words",
-            list(range(1, 31)),
-            model.config.max_pred_size,
+        threshold = st.select_slider(
+            "Tags threshold",
+            [(i / THRESHOLD_PRECISION) for i in range(THRESHOLD_PRECISION + 1)],
+            DEFAULT_THRESHOLD,
        )
 
     if allow_rep_mode == "all":
@@ -122,7 +159,6 @@ def main() -> None:
     elif allow_rep_mode == "stopwords":
         forbid_rep_mode = "content_words"
     else:
-        ALLOW_REP_MODES = ("all", "none", "stopwords")
         raise ValueError(
             f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})"
         )
@@ -134,23 +170,18 @@ def main() -> None:
         min_pred_size=min_pred_size,
         max_pred_size=max_pred_size,
         forbid_rep_mode=forbid_rep_mode,
+        threshold=threshold,
     )
 
-    if audios is not None and len(audios) > 0:
-        audio_fnames, cands = get_results(model, audios, generate_kwds)
-
-        for audio_fname, cand in zip(audio_fnames, cands):
-            st.success(f"**Output for {audio_fname}:**\n- {format_candidate(cand)}")
+    audios: dict[str, bytes] = {}
+    if audio_files is not None:
+        audios |= {audio.name: audio.getvalue() for audio in audio_files}
+    if record_data is not None:
+        audios |= {RECORD_AUDIO_FNAME: record_data}
 
-    if len(record) > 0:
-        outputs = model(
-            record_fpath,
-            **generate_kwds,
-        )
-        cand = outputs["cands"][0]
-        st.success(
-            f"**Output for {osp.basename(record_fpath)}:**\n- {format_candidate(cand)}"
-        )
+    if len(audios) > 0:
+        outputs = get_results(model, audios, generate_kwds)
+        show_results(outputs)
 
 
 if __name__ == "__main__":
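
A note on the batching loop in get_results: it relies on conette's dict_list_to_list_dict helper to turn the model's batched output (a dict of per-key lists) into one dict per audio, so that each result can be cached under its own session-state key. Assuming the helper does what its name suggests (not verified against the conette source), it would be equivalent to this sketch:

from typing import Any

def dict_list_to_list_dict(batch: dict[str, list[Any]]) -> list[dict[str, Any]]:
    # Transpose {"cands": [c0, c1], "lprobs": [p0, p1]}
    # into [{"cands": c0, "lprobs": p0}, {"cands": c1, "lprobs": p1}].
    keys = list(batch.keys())
    length = len(batch[keys[0]]) if keys else 0
    return [{key: batch[key][i] for key in keys} for i in range(length)]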
requirements.txt CHANGED
@@ -1,2 +1,2 @@
-conette~=0.2.0
-streamlit-audiorecorder
+conette~=0.2.2
+streamlit-audiorec~=0.1.3
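
The dependency swap follows the recorder change in app.py: the old streamlit-audiorecorder component returned a pydub AudioSegment (hence the export, frame_rate, and duration_seconds calls removed from main), while streamlit-audiorec exposes a single st_audiorec component that returns the recording as raw WAV bytes, or None until something is recorded. Minimal usage, per that package's documented API:

import streamlit as st
from st_audiorec import st_audiorec

# Renders the in-browser recorder widget; returns WAV bytes,
# or None while no recording has been made.
wav_audio_data = st_audiorec()

if wav_audio_data is not None:
    st.audio(wav_audio_data, format="audio/wav")

Returning plain bytes is what lets app.py treat the recording and the uploaded files uniformly in its audios dict.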