Labbeti commited on
Commit
a480cb3
·
1 Parent(s): 83af184

Mod: Cache candidate results in memory to avoid re-computation.

Browse files
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -2,6 +2,7 @@
2
  # -*- coding: utf-8 -*-
3
 
4
  from tempfile import NamedTemporaryFile
 
5
 
6
  import streamlit as st
7
 
@@ -14,7 +15,8 @@ def load_conette(*args, **kwargs) -> CoNeTTEModel:
14
 
15
 
16
  def main() -> None:
17
- st.header("CoNeTTE model test")
 
18
  model = load_conette(model_kwds=dict(device="cpu"))
19
 
20
  task = st.selectbox("Task embedding input", model.tasks, 0)
@@ -34,6 +36,7 @@ def main() -> None:
34
  model.config.max_pred_size,
35
  )
36
 
 
37
  audios = st.file_uploader(
38
  "Upload an audio file",
39
  type=["wav", "flac", "mp3", "ogg", "avi"],
@@ -45,14 +48,24 @@ def main() -> None:
45
  with NamedTemporaryFile() as temp:
46
  temp.write(audio.getvalue())
47
  fpath = temp.name
48
- outputs = model(
49
- fpath,
50
  task=task,
51
  beam_size=beam_size,
52
  min_pred_size=min_pred_size,
53
  max_pred_size=max_pred_size,
54
  )
55
- cand = outputs["cands"][0]
 
 
 
 
 
 
 
 
 
 
56
 
57
  st.write(f"Output for {audio.name}:")
58
  st.write(" - ", cand)
 
2
  # -*- coding: utf-8 -*-
3
 
4
  from tempfile import NamedTemporaryFile
5
+ from typing import Any
6
 
7
  import streamlit as st
8
 
 
15
 
16
 
17
  def main() -> None:
18
+ st.header("Describe audio content with CoNeTTE")
19
+
20
  model = load_conette(model_kwds=dict(device="cpu"))
21
 
22
  task = st.selectbox("Task embedding input", model.tasks, 0)
 
36
  model.config.max_pred_size,
37
  )
38
 
39
+ st.write("Recommanded audio: lasting from 1s to 30s, sampled at 32 kHz.")
40
  audios = st.file_uploader(
41
  "Upload an audio file",
42
  type=["wav", "flac", "mp3", "ogg", "avi"],
 
48
  with NamedTemporaryFile() as temp:
49
  temp.write(audio.getvalue())
50
  fpath = temp.name
51
+
52
+ kwargs: dict[str, Any] = dict(
53
  task=task,
54
  beam_size=beam_size,
55
  min_pred_size=min_pred_size,
56
  max_pred_size=max_pred_size,
57
  )
58
+ cand_key = f"{audio.name}-{kwargs}"
59
+
60
+ if cand_key in st.session_state:
61
+ cand = st.session_state[cand_key]
62
+ else:
63
+ outputs = model(
64
+ fpath,
65
+ **kwargs,
66
+ )
67
+ cand = outputs["cands"][0]
68
+ st.session_state[cand_key] = cand
69
 
70
  st.write(f"Output for {audio.name}:")
71
  st.write(" - ", cand)