Labbeti commited on
Commit
afed00c
·
1 Parent(s): a58a491

Add: Allow repetition mode option.

Browse files
Files changed (1) hide show
  1. app.py +20 -4
app.py CHANGED
@@ -14,15 +14,20 @@ def load_conette(*args, **kwargs) -> CoNeTTEModel:
14
  return conette(*args, **kwargs)
15
 
16
 
 
 
 
 
17
  def main() -> None:
18
  st.header("Describe audio content with CoNeTTE")
19
 
20
  model = load_conette(model_kwds=dict(device="cpu"))
21
 
22
  task = st.selectbox("Task embedding input", model.tasks, 0)
 
23
  beam_size: int = st.select_slider( # type: ignore
24
  "Beam size",
25
- list(range(1, 20)),
26
  model.config.beam_size,
27
  )
28
  min_pred_size: int = st.select_slider( # type: ignore
@@ -36,7 +41,7 @@ def main() -> None:
36
  model.config.max_pred_size,
37
  )
38
 
39
- st.write("Recommanded audio: lasting from 1s to 30s, sampled at 32 kHz.")
40
  audios = st.file_uploader(
41
  "Upload an audio file",
42
  type=["wav", "flac", "mp3", "ogg", "avi"],
@@ -49,11 +54,22 @@ def main() -> None:
49
  temp.write(audio.getvalue())
50
  fpath = temp.name
51
 
 
 
 
 
 
 
 
 
 
 
52
  kwargs: dict[str, Any] = dict(
53
  task=task,
54
  beam_size=beam_size,
55
  min_pred_size=min_pred_size,
56
  max_pred_size=max_pred_size,
 
57
  )
58
  cand_key = f"{audio.name}-{kwargs}"
59
 
@@ -67,8 +83,8 @@ def main() -> None:
67
  cand = outputs["cands"][0]
68
  st.session_state[cand_key] = cand
69
 
70
- st.write(f"Output for {audio.name}:")
71
- st.write(" - ", cand)
72
 
73
 
74
  if __name__ == "__main__":
 
14
  return conette(*args, **kwargs)
15
 
16
 
17
+ def format_cand(cand: str) -> str:
18
+ return f"{cand[0].title()}{cand[1:]}."
19
+
20
+
21
  def main() -> None:
22
  st.header("Describe audio content with CoNeTTE")
23
 
24
  model = load_conette(model_kwds=dict(device="cpu"))
25
 
26
  task = st.selectbox("Task embedding input", model.tasks, 0)
27
+ allow_rep_mode = st.selectbox("Allow repetition of words", ["stopwords", "all", "none"], 0)
28
  beam_size: int = st.select_slider( # type: ignore
29
  "Beam size",
30
+ list(range(1, 21)),
31
  model.config.beam_size,
32
  )
33
  min_pred_size: int = st.select_slider( # type: ignore
 
41
  model.config.max_pred_size,
42
  )
43
 
44
+ st.markdown("Recommanded audio: lasting from **1 to 30s**, sampled at **32 kHz**.")
45
  audios = st.file_uploader(
46
  "Upload an audio file",
47
  type=["wav", "flac", "mp3", "ogg", "avi"],
 
54
  temp.write(audio.getvalue())
55
  fpath = temp.name
56
 
57
+ if allow_rep_mode == "all":
58
+ forbid_rep_mode = "none"
59
+ elif allow_rep_mode == "none":
60
+ forbid_rep_mode = "all"
61
+ elif allow_rep_mode == "stopwords":
62
+ forbid_rep_mode = "content_words"
63
+ else:
64
+ ALLOW_REP_MODES = ("all", "none", "stopwords")
65
+ raise ValueError(f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})")
66
+
67
  kwargs: dict[str, Any] = dict(
68
  task=task,
69
  beam_size=beam_size,
70
  min_pred_size=min_pred_size,
71
  max_pred_size=max_pred_size,
72
+ forbid_rep_mode=forbid_rep_mode,
73
  )
74
  cand_key = f"{audio.name}-{kwargs}"
75
 
 
83
  cand = outputs["cands"][0]
84
  st.session_state[cand_key] = cand
85
 
86
+ st.markdown(f"Output for {audio.name}:")
87
+ st.markdown(f" - red[{format_cand(cand)}]")
88
 
89
 
90
  if __name__ == "__main__":