File size: 2,913 Bytes
590f41d
 
 
 
a480cb3
590f41d
927ba9d
 
590f41d
 
 
 
 
 
 
 
afed00c
 
 
 
590f41d
a480cb3
 
f96230b
 
 
afed00c
f96230b
 
afed00c
f96230b
 
 
 
83af184
f96230b
 
 
 
83af184
f96230b
 
 
afed00c
590f41d
 
 
 
 
 
 
 
 
 
 
a480cb3
afed00c
 
 
 
 
 
 
 
 
 
a480cb3
f96230b
 
 
 
afed00c
f96230b
a480cb3
 
 
 
 
 
 
 
 
 
 
590f41d
afed00c
 
590f41d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from tempfile import NamedTemporaryFile
from typing import Any

import streamlit as st

from conette import CoNeTTEModel, conette


@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    """Build a CoNeTTE model, reusing Streamlit's resource cache across reruns.

    Arguments are forwarded verbatim to :func:`conette.conette`.
    """
    model = conette(*args, **kwargs)
    return model


def format_cand(cand: str) -> str:
    """Return *cand* as a sentence: first letter capitalized, trailing period.

    Only the first character is upper-cased (``str.title`` on one char);
    the rest of the caption is left untouched. An empty string is returned
    unchanged — the previous ``cand[0]`` indexing raised ``IndexError`` on
    an empty model output.
    """
    if not cand:
        return ""
    return f"{cand[0].title()}{cand[1:]}."


def main() -> None:
    """Streamlit page: upload audio files and caption them with CoNeTTE.

    Captions are memoized in ``st.session_state`` keyed by (file name,
    decoding options), so widget-triggered reruns do not repeat inference.
    """
    st.header("Describe audio content with CoNeTTE")

    # CPU inference keeps the demo runnable on GPU-less hosts; the model
    # object itself is cached across reruns by @st.cache_resource.
    model = load_conette(model_kwds=dict(device="cpu"))

    task = st.selectbox("Task embedding input", model.tasks, 0)
    allow_rep_mode = st.selectbox("Allow repetition of words", ["stopwords", "all", "none"], 0)
    beam_size: int = st.select_slider(  # type: ignore
        "Beam size",
        list(range(1, 21)),
        model.config.beam_size,
    )
    min_pred_size: int = st.select_slider(  # type: ignore
        "Minimal number of words",
        list(range(1, 31)),
        model.config.min_pred_size,
    )
    max_pred_size: int = st.select_slider(  # type: ignore
        "Maximal number of words",
        list(range(1, 31)),
        model.config.max_pred_size,
    )

    # Typo fix: "Recommanded" -> "Recommended".
    st.markdown("Recommended audio: lasting from **1 to 30s**, sampled at **32 kHz**.")
    audios = st.file_uploader(
        "Upload an audio file",
        type=["wav", "flac", "mp3", "ogg", "avi"],
        accept_multiple_files=True,
    )

    if not audios:
        return

    # The UI asks what to *allow*; the model API expects what to *forbid*.
    # This mapping (and everything below that doesn't depend on the file)
    # is loop-invariant, so compute it once instead of once per upload.
    allow_to_forbid = {"all": "none", "none": "all", "stopwords": "content_words"}
    if allow_rep_mode not in allow_to_forbid:
        ALLOW_REP_MODES = ("all", "none", "stopwords")
        raise ValueError(f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})")
    forbid_rep_mode = allow_to_forbid[allow_rep_mode]

    kwargs: dict[str, Any] = dict(
        task=task,
        beam_size=beam_size,
        min_pred_size=min_pred_size,
        max_pred_size=max_pred_size,
        forbid_rep_mode=forbid_rep_mode,
    )

    for audio in audios:
        # The model expects a file path, so spill the upload to a temp file.
        with NamedTemporaryFile() as temp:
            temp.write(audio.getvalue())
            # Flush so the model reads the full file, not a partially
            # buffered one (the previous code never flushed before reading).
            temp.flush()
            fpath = temp.name

            cand_key = f"{audio.name}-{kwargs}"

            if cand_key in st.session_state:
                cand = st.session_state[cand_key]
            else:
                outputs = model(
                    fpath,
                    **kwargs,
                )
                cand = outputs["cands"][0]
                st.session_state[cand_key] = cand

            st.markdown(f"Output for {audio.name}:")
            # NOTE(review): Streamlit colored text syntax is ":red[...]";
            # " - red[...]" renders literally — confirm whether ":red[...]"
            # was intended before changing the string.
            st.markdown(f" - red[{format_cand(cand)}]")


# Script entry point: launch the Streamlit page when run directly
# (e.g. `streamlit run <this file>`), not when imported as a module.
if __name__ == "__main__":
    main()