Labbeti committed on
Commit
c37029b
·
1 Parent(s): b8c79c9

Mod: Refactor model forward.

Browse files
Files changed (1) hide show
  1. app.py +53 -37
app.py CHANGED
@@ -1,6 +1,8 @@
1
  #!/usr/bin/env python
2
  # -*- coding: utf-8 -*-
3
 
 
 
4
  from tempfile import NamedTemporaryFile
5
  from typing import Any
6
 
@@ -20,12 +22,56 @@ def format_cand(cand: str) -> str:
20
  return f"{cand[0].title()}{cand[1:]}."
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def main() -> None:
24
  st.header("Describe audio content with CoNeTTE")
25
 
26
  model = load_conette(model_kwds=dict(device="cpu"))
27
 
28
- st.warning("Recommanded audio: lasting from **1 to 30s**, sampled at **32 kHz** minimum.")
 
 
29
  audios = st.file_uploader(
30
  "**Upload audio files here:**",
31
  type=["wav", "flac", "mp3", "ogg", "avi"],
@@ -78,7 +124,7 @@ def main() -> None:
78
  )
79
  del allow_rep_mode
80
 
81
- kwargs: dict[str, Any] = dict(
82
  task=task,
83
  beam_size=beam_size,
84
  min_pred_size=min_pred_size,
@@ -87,39 +133,7 @@ def main() -> None:
87
  )
88
 
89
  if audios is not None and len(audios) > 0:
90
- audio_to_predict = []
91
- cands = [""] * len(audios)
92
- tmp_files = []
93
- tmp_fpaths = []
94
- audio_fnames = []
95
-
96
- for i, audio in enumerate(audios):
97
- audio_fname = audio.name
98
- audio_fnames.append(audio_fname)
99
- cand_key = f"{audio_fname}-{kwargs}"
100
-
101
- if cand_key in st.session_state:
102
- cand = st.session_state[cand_key]
103
- cands[i] = cand
104
- else:
105
- tmp_file = NamedTemporaryFile()
106
- tmp_file.write(audio.getvalue())
107
- tmp_files.append(tmp_file)
108
- audio_to_predict.append((i, cand_key, tmp_file))
109
-
110
- tmp_fpath = tmp_file.name
111
- tmp_fpaths.append(tmp_fpath)
112
-
113
- if len(tmp_fpaths) > 0:
114
- outputs = model(
115
- tmp_fpaths,
116
- **kwargs,
117
- )
118
- for i, (j, cand_key, tmp_file) in enumerate(audio_to_predict):
119
- cand = outputs["cands"][i]
120
- cands[j] = cand
121
- st.session_state[cand_key] = cand
122
- tmp_file.close()
123
 
124
  for audio_fname, cand in zip(audio_fnames, cands):
125
  st.success(f"**Output for {audio_fname}:**\n- {format_cand(cand)}")
@@ -127,10 +141,12 @@ def main() -> None:
127
  if len(record) > 0:
128
  outputs = model(
129
  record_fpath,
130
- **kwargs,
131
  )
132
  cand = outputs["cands"][0]
133
- st.success(f"**Output for {'test'}:**\n- {format_cand(cand)}")
 
 
134
 
135
 
136
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python
2
  # -*- coding: utf-8 -*-
3
 
4
+ import os.path as osp
5
+
6
  from tempfile import NamedTemporaryFile
7
  from typing import Any
8
 
 
22
  return f"{cand[0].title()}{cand[1:]}."
23
 
24
 
25
def get_results(
    model: CoNeTTEModel,
    audios: list,
    generate_kwds: dict[str, Any],
) -> tuple[list[str], list[str]]:
    """Return the file names and candidate captions for uploaded audios.

    Candidates already stored in ``st.session_state`` under the key
    ``f"{audio.name}-{generate_kwds}"`` are reused. Every other audio is
    written to a temporary file, and all of those files are passed to
    ``model`` in a single call together with ``generate_kwds``. Newly
    computed candidates are stored back in ``st.session_state``.

    Args:
        model: Callable invoked as ``model(list_of_paths, **generate_kwds)``;
            its result is indexed as ``outputs["cands"][k]`` for the k-th path.
        audios: Uploaded file objects exposing ``.name`` and ``.getvalue()``.
        generate_kwds: Keyword arguments forwarded to ``model``; their string
            representation is also part of the cache key.

    Returns:
        A tuple ``(audio_fnames, cands)`` of two lists aligned with ``audios``.
    """
    audio_fnames = [audio.name for audio in audios]
    cands = [""] * len(audios)
    # One entry per audio still to predict: (index in audios, cache key, temp file).
    audio_to_predict = []

    try:
        for i, audio in enumerate(audios):
            cand_key = f"{audio.name}-{generate_kwds}"

            if cand_key in st.session_state:
                cands[i] = st.session_state[cand_key]
                continue

            tmp_file = NamedTemporaryFile()
            # Register before writing so the file is closed even if the write fails.
            audio_to_predict.append((i, cand_key, tmp_file))
            tmp_file.write(audio.getvalue())
            # The model re-opens the file by path, so buffered bytes must be
            # pushed to disk first; otherwise it may see truncated content.
            tmp_file.flush()

        if audio_to_predict:
            tmp_fpaths = [tmp_file.name for _, _, tmp_file in audio_to_predict]
            outputs = model(
                tmp_fpaths,
                **generate_kwds,
            )
            for k, (j, cand_key, _) in enumerate(audio_to_predict):
                cand = outputs["cands"][k]
                cands[j] = cand
                st.session_state[cand_key] = cand
    finally:
        # Closing a NamedTemporaryFile deletes it; do so on failure paths too.
        for _, _, tmp_file in audio_to_predict:
            tmp_file.close()

    return audio_fnames, cands
65
+
66
+
67
  def main() -> None:
68
  st.header("Describe audio content with CoNeTTE")
69
 
70
  model = load_conette(model_kwds=dict(device="cpu"))
71
 
72
+ st.warning(
73
+ "Recommanded audio: lasting from **1 to 30s**, sampled at **32 kHz** minimum."
74
+ )
75
  audios = st.file_uploader(
76
  "**Upload audio files here:**",
77
  type=["wav", "flac", "mp3", "ogg", "avi"],
 
124
  )
125
  del allow_rep_mode
126
 
127
+ generate_kwds: dict[str, Any] = dict(
128
  task=task,
129
  beam_size=beam_size,
130
  min_pred_size=min_pred_size,
 
133
  )
134
 
135
  if audios is not None and len(audios) > 0:
136
+ audio_fnames, cands = get_results(model, audios, generate_kwds)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  for audio_fname, cand in zip(audio_fnames, cands):
139
  st.success(f"**Output for {audio_fname}:**\n- {format_cand(cand)}")
 
141
  if len(record) > 0:
142
  outputs = model(
143
  record_fpath,
144
+ **generate_kwds,
145
  )
146
  cand = outputs["cands"][0]
147
+ st.success(
148
+ f"**Output for {osp.basename(record_fpath)}:**\n- {format_cand(cand)}"
149
+ )
150
 
151
 
152
  if __name__ == "__main__":