File size: 10,033 Bytes
590f41d
 
 
4ff8b3b
 
ae94a43
5f47c66
590f41d
927ba9d
5f47c66
927ba9d
88b33c2
 
ae94a43
a7d5a37
5f47c66
d7b4867
ae94a43
 
4ff8b3b
ae94a43
 
4ff8b3b
5f47c66
ae94a43
 
5f47c66
 
4ff8b3b
 
 
590f41d
 
 
 
 
 
 
bbf3398
 
 
 
 
afed00c
 
ae94a43
 
 
 
 
 
 
 
4ff8b3b
ae94a43
 
c37029b
 
ae94a43
c37029b
5f47c66
ae94a43
5f47c66
ae94a43
 
 
5f47c66
ae94a43
 
 
5f47c66
4ff8b3b
ae94a43
5f47c66
 
 
 
 
 
28a7057
 
 
 
5f47c66
 
 
28a7057
 
 
 
5f47c66
 
 
 
ae94a43
 
 
 
 
 
 
 
 
c37029b
 
ae94a43
 
 
 
 
 
 
 
 
 
 
4ff8b3b
 
 
ae94a43
 
 
5f47c66
2ec8ef6
 
 
ae94a43
c37029b
ae94a43
5f47c66
 
 
 
 
 
 
 
 
 
ae94a43
 
 
5f47c66
 
 
 
 
 
 
ae94a43
 
 
 
 
 
4ff8b3b
5f47c66
4ff8b3b
 
 
db7515e
 
4ff8b3b
 
 
 
 
 
 
db7515e
 
 
4ff8b3b
 
 
 
 
5f47c66
 
 
4ff8b3b
5f47c66
 
 
4ff8b3b
 
 
 
5f47c66
4ff8b3b
 
ae94a43
c37029b
 
590f41d
f96230b
 
5f47c66
 
4ff8b3b
5f47c66
4ff8b3b
 
590f41d
4ff8b3b
 
 
 
 
 
 
db7515e
4ff8b3b
590f41d
4ff8b3b
 
 
 
 
 
 
ae94a43
d290679
 
ae94a43
d290679
 
ae94a43
 
 
 
 
d290679
ae94a43
 
 
 
d290679
 
 
 
 
 
 
 
 
db7515e
d290679
 
db7515e
 
d290679
 
c37029b
d290679
 
 
 
 
ae94a43
d290679
 
ae94a43
 
 
 
 
590f41d
ae94a43
2ec8ef6
 
 
ae94a43
d7b4867
4ff8b3b
 
 
 
 
 
 
 
88b33c2
 
 
 
 
4ff8b3b
590f41d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import html
import os
import time
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import Any, Optional, Union

import streamlit as st
import torchaudio

from conette import CoNeTTEModel, conette, __version__
from conette.utils.collections import dict_list_to_list_dict
from st_audiorec import st_audiorec
from streamlit.runtime.uploaded_file_manager import UploadedFile
from torch import Tensor


# Choices offered for word repetition (translated to `forbid_rep_mode` in main).
ALLOW_REP_MODES = ("stopwords", "all", "none")
# Task preselected in the "Task embedding input" box when the model provides it.
DEFAULT_TASK = "audiocaps"

# Upper bounds of the beam-size and prediction-length sliders.
MAX_BEAM_SIZE = 20
MAX_PRED_SIZE = 30
# Maximal number of audio files passed to the model in a single call.
MAX_BATCH_SIZE = 16

# File name under which the microphone recording is processed and cached.
RECORD_AUDIO_FNAME = "microphone_conette_record.wav"

# Tags threshold: default value, and number of slider steps over [0, 1].
DEFAULT_THRESHOLD = 0.3
THRESHOLD_PRECISION = 100

# Accepted audio duration range, in seconds.
MIN_AUDIO_DURATION_SEC = 0.3
MAX_AUDIO_DURATION_SEC = 60

# Prefix of the session-state keys that hold cached results.
HASH_PREFIX = "hash_"
# Prefix of the temporary audio files written to disk before inference.
TMP_FILE_PREFIX = "audio_tmp_file_"
# Delay since the previous generation after which cached results are dropped
# (10 minutes).
SECOND_BEFORE_CLEAR_CACHE = 600


@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    """Build a CoNeTTE model through ``conette``, cached by ``st.cache_resource``.

    Every positional and keyword argument is forwarded unchanged to
    ``conette``; the decorator reuses the returned model instead of rebuilding
    it when called again with the same arguments.
    """
    model = conette(*args, **kwargs)
    return model


def format_candidate(candidate: str) -> str:
    """Turn a raw caption into a displayable sentence.

    The first character is title-cased and a period is appended; the rest of
    the text is left as is. An empty caption yields an empty string (no period).
    """
    if not candidate:
        return ""
    head, tail = candidate[0], candidate[1:]
    return head.title() + tail + "."


def format_tags(tags: Optional[list[str]]) -> str:
    """Join tag names with ", ", or return "None." when there are no tags."""
    if tags:
        return ", ".join(tags)
    return "None."


def get_result_hash(audio_fname: str, generate_kwds: dict[str, Any]) -> str:
    """Build the session-state key caching the result of one file under given options.

    The key is ``HASH_PREFIX`` followed by the file name, a dash, and the
    string form of the generation keyword arguments.
    """
    return "{}{}-{}".format(HASH_PREFIX, audio_fname, generate_kwds)


def _duration_error_msg(audio_fname: str, duration: float) -> Optional[str]:
    """Return a Markdown error message when *duration* (seconds) is out of range.

    Returns ``None`` when ``MIN_AUDIO_DURATION_SEC <= duration <= MAX_AUDIO_DURATION_SEC``.
    """
    if MIN_AUDIO_DURATION_SEC > duration:
        return f"""
            ##### Result for "{audio_fname}"
            Audio file is too short. (found {duration:.2f}s but the model expect audio in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}])
            """
    if duration > MAX_AUDIO_DURATION_SEC:
        return f"""
            ##### Result for "{audio_fname}"
            Audio file is too long. (found {duration:.2f}s but the model expect audio in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}])
            """
    return None


def get_results(
    model: CoNeTTEModel,
    audio_files: dict[str, bytes],
    generate_kwds: dict[str, Any],
) -> dict[str, Union[dict[str, Any], str]]:
    """Return the model output, or an error message, for each audio file.

    Results are cached in ``st.session_state`` under ``get_result_hash(fname,
    generate_kwds)``, so a file already processed with the same options is not
    run again. The microphone recording (``RECORD_AUDIO_FNAME``) is always
    re-run because its name does not change between recordings.

    Args:
        model: CoNeTTE model, called on batches of temporary file paths.
        audio_files: Mapping from file name to the raw (encoded) audio bytes.
        generate_kwds: Keyword arguments forwarded to ``model``; also part of
            the cache key.

    Returns:
        Mapping from every name in ``audio_files`` to either the per-file model
        output dict or a Markdown error string (duration out of range).
    """
    # Select audio whose result is not cached yet.
    audio_to_predict: dict[str, tuple[str, bytes]] = {}
    for audio_fname, audio in audio_files.items():
        result_hash = get_result_hash(audio_fname, generate_kwds)
        if result_hash not in st.session_state or audio_fname == RECORD_AUDIO_FNAME:
            audio_to_predict[result_hash] = (audio_fname, audio)

    # Every temporary file created below, including rejected ones, so that all
    # of them are removed even when decoding or inference raises.
    created_paths: list[str] = []
    try:
        valid_paths: dict[str, str] = {}
        for result_hash, (audio_fname, audio) in audio_to_predict.items():
            # Keep the uploaded file's extension: audio backends may use it to
            # detect the container format.
            suffix = os.path.splitext(audio_fname)[1]
            with NamedTemporaryFile(
                delete=False, prefix=TMP_FILE_PREFIX, suffix=suffix
            ) as tmp_file:
                created_paths.append(tmp_file.name)
                tmp_file.write(audio)

            metadata = torchaudio.info(tmp_file.name)  # type: ignore
            duration = metadata.num_frames / metadata.sample_rate

            error_msg = _duration_error_msg(audio_fname, duration)
            if error_msg is not None:
                st.session_state[result_hash] = error_msg
            else:
                valid_paths[result_hash] = tmp_file.name

        # Generate predictions by batches and store them in session state.
        result_hashes = list(valid_paths.keys())
        paths = list(valid_paths.values())
        for start in range(0, len(paths), MAX_BATCH_SIZE):
            end = start + MAX_BATCH_SIZE
            outputs_j = model(paths[start:end], **generate_kwds)
            # The model returns a dict of per-batch lists; turn it into one
            # dict per file.
            outputs_lst = dict_list_to_list_dict(outputs_j)  # type: ignore
            for result_hash, output_i in zip(result_hashes[start:end], outputs_lst):
                st.session_state[result_hash] = output_i
    finally:
        for path in created_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    return {
        audio_fname: st.session_state[get_result_hash(audio_fname, generate_kwds)]
        for audio_fname in audio_files
    }


def show_results(outputs: dict[str, Union[dict[str, Any], str]]) -> None:
    """Render one result section per audio file, most recently added file first.

    Args:
        outputs: Mapping from audio file name to either a Markdown error string
            (shown with ``st.error``) or a model output dict read through the
            keys ``"cands"``, ``"lprobs"``, ``"mult_cands"``, ``"mult_lprobs"``
            and the optional ``"tags"``. The ``*lprobs`` values are presumably
            torch tensors of log-probabilities (``.exp()`` is applied) — TODO
            confirm against the model output spec.
    """
    st.divider()

    # The highlight style is global to the page: inject it once, and only when
    # at least one caption will use it.
    if any(not isinstance(output, str) for output in outputs.values()):
        st.markdown(
            """
        <style>
        .big-font {
            font-size:22px !important;
            background-color: rgba(0, 255, 0, 0.1);
            padding: 10px;
        }
        </style>
        """,
            unsafe_allow_html=True,
        )

    for audio_fname, output in reversed(list(outputs.items())):
        if isinstance(output, str):
            st.error(output)
            st.divider()
            continue

        cand = format_candidate(output["cands"])
        prob: float = output["lprobs"].exp().tolist()
        tags = format_tags(output.get("tags"))

        mult_cands = [format_candidate(cand_i) for cand_i in output["mult_cands"]]
        mult_probs = output["mult_lprobs"].exp()
        # Alternatives by decreasing probability, dropping the most probable one
        # (presumably the same caption as ``cand`` — TODO confirm).
        indexes = mult_probs.argsort(descending=True)[1:]
        other_probs: list[float] = mult_probs[indexes].tolist()
        other_cands = [mult_cands[idx] for idx in indexes]

        if audio_fname == RECORD_AUDIO_FNAME:
            header = "##### Result for microphone input:"
        else:
            # The file name comes from the user's upload and this content is
            # rendered with unsafe_allow_html=True, so escape it.
            header = f'##### Result for "{html.escape(audio_fname)}"'
        caption = (
            '<center><p class="space"><p class="big-font">'
            f'"{html.escape(cand)}"'
            "</p></p></center>"
        )
        st.markdown("<br>".join([header, caption]), unsafe_allow_html=True)

        lines = [f"- **Probability**: {prob * 100:.1f}%"]
        if other_cands:
            lines.append("- **Other descriptions:**")
        lines.extend(
            f'  - "{cand_i}" ({prob_i * 100:.1f}%)'
            for cand_i, prob_i in zip(other_cands, other_probs)
        )
        lines.append(f"- **Tags:** {tags}")

        st.markdown("\n".join(lines), unsafe_allow_html=False)
        st.divider()


# User-facing "Allow repetition of words" choice -> model ``forbid_rep_mode``.
_FORBID_REP_MODE_BY_ALLOW_MODE: dict[str, str] = {
    "all": "none",
    "none": "all",
    "stopwords": "content_words",
}


def _select_generate_kwds(model: CoNeTTEModel) -> dict[str, Any]:
    """Render the "Model options" expander and return the chosen generation kwargs.

    Raises:
        ValueError: if the repetition option is not one of ``ALLOW_REP_MODES``.
    """
    with st.expander("Model options"):
        if DEFAULT_TASK in model.tasks:
            default_task_idx = list(model.tasks).index(DEFAULT_TASK)
        else:
            default_task_idx = 0

        task = st.selectbox("Task embedding input", model.tasks, default_task_idx)
        allow_rep_mode = st.selectbox("Allow repetition of words", ALLOW_REP_MODES, 0)
        beam_size: int = st.select_slider(  # type: ignore
            "Beam size",
            list(range(1, MAX_BEAM_SIZE + 1)),
            model.config.beam_size,
        )
        min_pred_size, max_pred_size = st.slider(
            "Minimal and maximal number of words",
            1,
            MAX_PRED_SIZE,
            (model.config.min_pred_size, model.config.max_pred_size),
        )
        threshold = st.select_slider(
            "Tags threshold",
            [(i / THRESHOLD_PRECISION) for i in range(THRESHOLD_PRECISION + 1)],
            DEFAULT_THRESHOLD,
        )

        # The UI asks what to *allow*; the model expects what to *forbid*.
        forbid_rep_mode = _FORBID_REP_MODE_BY_ALLOW_MODE.get(allow_rep_mode)
        if forbid_rep_mode is None:
            msg = (
                f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})"
            )
            raise ValueError(msg)

    return dict(
        task=task,
        beam_size=beam_size,
        min_pred_size=min_pred_size,
        max_pred_size=max_pred_size,
        forbid_rep_mode=forbid_rep_mode,
        threshold=threshold,
    )


def _expire_result_cache() -> None:
    """Drop cached results when the previous generation is too old, then record now.

    Cached results are the session-state entries whose key starts with
    ``HASH_PREFIX``; they are removed when more than ``SECOND_BEFORE_CLEAR_CACHE``
    seconds elapsed since the previous generation of this session.
    """
    current = time.perf_counter()
    last_generation = st.session_state.get("last_generation", current)
    if current > last_generation + SECOND_BEFORE_CLEAR_CACHE:
        print("Removing result cache...")
        # Snapshot the keys first: deleting entries while iterating a live
        # keys view of the session state is not safe.
        stale_keys = [
            key
            for key in list(st.session_state.keys())
            if isinstance(key, str) and key.startswith(HASH_PREFIX)
        ]
        for key in stale_keys:
            del st.session_state[key]
    st.session_state["last_generation"] = current


def main() -> None:
    """Render the CoNeTTE audio captioning page.

    Collects audio from the microphone recorder and the file uploader, lets the
    user tune generation options, then shows a description for each audio.
    """
    model = load_conette(model_kwds=dict(device="cpu"))

    st.header("Describe audio content with CoNeTTE")
    st.markdown(
        "This interface allows you to generate a short description of the sound events of any recording using an Audio Captioning system. You can try it from your microphone or upload a file below."
    )
    st.markdown(
        "Use '**Start Recording**' and '**Stop**' to record an audio from your microphone."
    )
    record_data = st_audiorec()

    with st.expander("Or upload audio files here:"):
        audio_files: Optional[list[UploadedFile]] = st.file_uploader(
            f"Audio files are automatically resampled to 32 kHz.\nTheir duration must be in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}] seconds.",
            type=["wav", "flac", "mp3", "ogg", "avi"],
            accept_multiple_files=True,
            help="Supports wav, flac, mp3, ogg and avi files.",
        )

    generate_kwds = _select_generate_kwds(model)

    audios: dict[str, bytes] = {}
    if audio_files is not None:
        audios |= {audio.name: audio.getvalue() for audio in audio_files}
    if record_data is not None:
        audios |= {RECORD_AUDIO_FNAME: record_data}

    if len(audios) > 0:
        with st.spinner("Generating descriptions..."):
            outputs = get_results(model, audios, generate_kwds)
        st.header("Results:")
        show_results(outputs)
        _expire_result_cache()

    content = f"""CoNeTTE version {__version__}. <a href="https://github.com/Labbeti/conette-audio-captioning/">Source code on GitHub</a>. <a href="https://ieeexplore.ieee.org/document/10603439">Academic Paper</a>."""
    st.divider()
    st.markdown(content, unsafe_allow_html=True)



if __name__ == "__main__":
    # Build the page when this file is executed as the main script
    # (presumably launched with `streamlit run`).
    main()