Labbeti commited on
Commit
d290679
·
1 Parent(s): afed00c

Mod: Update forward to compute all audio files per batch and improve UI for hyperparameters.

Browse files
Files changed (1) hide show
  1. app.py +80 -56
app.py CHANGED
@@ -23,68 +23,92 @@ def main() -> None:
23
 
24
  model = load_conette(model_kwds=dict(device="cpu"))
25
 
26
- task = st.selectbox("Task embedding input", model.tasks, 0)
27
- allow_rep_mode = st.selectbox("Allow repetition of words", ["stopwords", "all", "none"], 0)
28
- beam_size: int = st.select_slider( # type: ignore
29
- "Beam size",
30
- list(range(1, 21)),
31
- model.config.beam_size,
32
- )
33
- min_pred_size: int = st.select_slider( # type: ignore
34
- "Minimal number of words",
35
- list(range(1, 31)),
36
- model.config.min_pred_size,
37
- )
38
- max_pred_size: int = st.select_slider( # type: ignore
39
- "Maximal number of words",
40
- list(range(1, 31)),
41
- model.config.max_pred_size,
42
- )
43
-
44
- st.markdown("Recommanded audio: lasting from **1 to 30s**, sampled at **32 kHz**.")
45
  audios = st.file_uploader(
46
- "Upload an audio file",
47
  type=["wav", "flac", "mp3", "ogg", "avi"],
48
  accept_multiple_files=True,
49
  )
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  if audios is not None and len(audios) > 0:
52
- for audio in audios:
53
- with NamedTemporaryFile() as temp:
54
- temp.write(audio.getvalue())
55
- fpath = temp.name
56
-
57
- if allow_rep_mode == "all":
58
- forbid_rep_mode = "none"
59
- elif allow_rep_mode == "none":
60
- forbid_rep_mode = "all"
61
- elif allow_rep_mode == "stopwords":
62
- forbid_rep_mode = "content_words"
63
- else:
64
- ALLOW_REP_MODES = ("all", "none", "stopwords")
65
- raise ValueError(f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})")
66
-
67
- kwargs: dict[str, Any] = dict(
68
- task=task,
69
- beam_size=beam_size,
70
- min_pred_size=min_pred_size,
71
- max_pred_size=max_pred_size,
72
- forbid_rep_mode=forbid_rep_mode,
73
- )
74
- cand_key = f"{audio.name}-{kwargs}"
75
-
76
- if cand_key in st.session_state:
77
- cand = st.session_state[cand_key]
78
- else:
79
- outputs = model(
80
- fpath,
81
- **kwargs,
82
- )
83
- cand = outputs["cands"][0]
84
- st.session_state[cand_key] = cand
85
-
86
- st.markdown(f"Output for {audio.name}:")
87
- st.markdown(f" - red[{format_cand(cand)}]")
88
 
89
 
90
  if __name__ == "__main__":
 
23
 
24
  model = load_conette(model_kwds=dict(device="cpu"))
25
 
26
+ st.warning("Recommanded audio: lasting from **1 to 30s**, sampled at **32 kHz**.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  audios = st.file_uploader(
28
+ "Upload audio files here:",
29
  type=["wav", "flac", "mp3", "ogg", "avi"],
30
  accept_multiple_files=True,
31
  )
32
 
33
+ with st.expander("Model hyperparameters"):
34
+ task = st.selectbox("Task embedding input", model.tasks, 0)
35
+ allow_rep_mode = st.selectbox(
36
+ "Allow repetition of words", ["stopwords", "all", "none"], 0
37
+ )
38
+ beam_size: int = st.select_slider( # type: ignore
39
+ "Beam size",
40
+ list(range(1, 21)),
41
+ model.config.beam_size,
42
+ )
43
+ min_pred_size: int = st.select_slider( # type: ignore
44
+ "Minimal number of words",
45
+ list(range(1, 31)),
46
+ model.config.min_pred_size,
47
+ )
48
+ max_pred_size: int = st.select_slider( # type: ignore
49
+ "Maximal number of words",
50
+ list(range(1, 31)),
51
+ model.config.max_pred_size,
52
+ )
53
+
54
+ if allow_rep_mode == "all":
55
+ forbid_rep_mode = "none"
56
+ elif allow_rep_mode == "none":
57
+ forbid_rep_mode = "all"
58
+ elif allow_rep_mode == "stopwords":
59
+ forbid_rep_mode = "content_words"
60
+ else:
61
+ ALLOW_REP_MODES = ("all", "none", "stopwords")
62
+ raise ValueError(
63
+ f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})"
64
+ )
65
+ del allow_rep_mode
66
+
67
+ kwargs: dict[str, Any] = dict(
68
+ task=task,
69
+ beam_size=beam_size,
70
+ min_pred_size=min_pred_size,
71
+ max_pred_size=max_pred_size,
72
+ forbid_rep_mode=forbid_rep_mode,
73
+ )
74
+
75
  if audios is not None and len(audios) > 0:
76
+ audio_to_predict = []
77
+ cands = [""] * len(audios)
78
+ tmp_files = []
79
+ tmp_fpaths = []
80
+ audio_fnames = []
81
+
82
+ for i, audio in enumerate(audios):
83
+ audio_fname = audio.name
84
+ audio_fnames.append(audio_fname)
85
+ cand_key = f"{audio_fname}-{kwargs}"
86
+
87
+ if cand_key in st.session_state:
88
+ cand = st.session_state[cand_key]
89
+ cands[i] = cand
90
+ else:
91
+ tmp_file = NamedTemporaryFile()
92
+ tmp_file.write(audio.getvalue())
93
+ tmp_files.append(tmp_file)
94
+ audio_to_predict.append((i, cand_key, tmp_file))
95
+
96
+ tmp_fpath = tmp_file.name
97
+ tmp_fpaths.append(tmp_fpath)
98
+
99
+ if len(tmp_fpaths) > 0:
100
+ outputs = model(
101
+ tmp_fpaths,
102
+ **kwargs,
103
+ )
104
+ for i, (j, cand_key, tmp_file) in enumerate(audio_to_predict):
105
+ cand = outputs["cands"][i]
106
+ cands[j] = cand
107
+ st.session_state[cand_key] = cand
108
+ tmp_file.close()
109
+
110
+ for audio_fname, cand in zip(audio_fnames, cands):
111
+ st.success(f"**Output for {audio_fname}:**\n- {format_cand(cand)}")
112
 
113
 
114
  if __name__ == "__main__":