#!/usr/bin/env python
# -*- coding: utf-8 -*-
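"""Streamlit demo: caption audio files or a microphone recording with CoNeTTE."""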

from tempfile import NamedTemporaryFile
from typing import Any

import streamlit as st

from audiorecorder import audiorecorder

from conette import CoNeTTEModel, conette


@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    # Load the CoNeTTE model once and reuse it across Streamlit reruns.
    return conette(*args, **kwargs)


def format_cand(cand: str) -> str:
    # Capitalize the first letter of the candidate caption and add a final period.
    return f"{cand[0].title()}{cand[1:]}."


def main() -> None:
    st.header("Describe audio content with CoNeTTE")

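    # Instantiate the (cached) model on CPU.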
    model = load_conette(model_kwds=dict(device="cpu"))

    st.warning("Recommended audio: **1 to 30 s** long, sampled at **32 kHz** or higher.")
    audios = st.file_uploader(
        "**Upload audio files here:**",
        type=["wav", "flac", "mp3", "ogg", "avi"],
        accept_multiple_files=True,
    )
    st.write("**OR**")
    record = audiorecorder(
        start_prompt="Start recording", stop_prompt="Stop recording", pause_prompt=""
    )

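    # If the user recorded something, save it to disk so the model can read it as a file path.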
    record_fpath = "record.wav"
    if len(record) > 0:
        record.export(record_fpath, format="wav")
        st.write(
            f"Record frame rate: {record.frame_rate}Hz, record duration: {record.duration_seconds:.2f}s"
        )
        st.audio(record.export().read())  # type: ignore

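    # Decoding hyperparameters, exposed in a collapsible panel.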
    with st.expander("Model hyperparameters"):
        task = st.selectbox("Task embedding input", model.tasks, 0)
        allow_rep_mode = st.selectbox(
            "Allow repetition of words", ["stopwords", "all", "none"], 0
        )
        beam_size: int = st.select_slider(  # type: ignore
            "Beam size",
            list(range(1, 21)),
            model.config.beam_size,
        )
        min_pred_size: int = st.select_slider(  # type: ignore
            "Minimal number of words",
            list(range(1, 31)),
            model.config.min_pred_size,
        )
        max_pred_size: int = st.select_slider(  # type: ignore
            "Maximal number of words",
            list(range(1, 31)),
            model.config.max_pred_size,
        )

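        # The model expects forbid_rep_mode; translate the user-facing
        # "allow repetition" choice into its complement.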
        if allow_rep_mode == "all":
            forbid_rep_mode = "none"
        elif allow_rep_mode == "none":
            forbid_rep_mode = "all"
        elif allow_rep_mode == "stopwords":
            forbid_rep_mode = "content_words"
        else:
            ALLOW_REP_MODES = ("all", "none", "stopwords")
            raise ValueError(
                f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})"
            )
        del allow_rep_mode

        kwargs: dict[str, Any] = dict(
            task=task,
            beam_size=beam_size,
            min_pred_size=min_pred_size,
            max_pred_size=max_pred_size,
            forbid_rep_mode=forbid_rep_mode,
        )

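    # Caption the uploaded files. Results are cached in st.session_state, keyed by
    # file name and hyperparameters, so unchanged files are not re-processed.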
    if audios is not None and len(audios) > 0:
        audio_to_predict = []
        cands = [""] * len(audios)
        tmp_files = []
        tmp_fpaths = []
        audio_fnames = []

        for i, audio in enumerate(audios):
            audio_fname = audio.name
            audio_fnames.append(audio_fname)
            cand_key = f"{audio_fname}-{kwargs}"

            if cand_key in st.session_state:
                cand = st.session_state[cand_key]
                cands[i] = cand
            else:
                tmp_file = NamedTemporaryFile()
                tmp_file.write(audio.getvalue())
                tmp_file.flush()  # ensure all bytes are on disk before the model reads the path
                tmp_files.append(tmp_file)
                audio_to_predict.append((i, cand_key, tmp_file))

                tmp_fpath = tmp_file.name
                tmp_fpaths.append(tmp_fpath)

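        # Run a single batched forward pass over all files that still need a caption.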
        if len(tmp_fpaths) > 0:
            outputs = model(
                tmp_fpaths,
                **kwargs,
            )
            for i, (j, cand_key, tmp_file) in enumerate(audio_to_predict):
                cand = outputs["cands"][i]
                cands[j] = cand
                st.session_state[cand_key] = cand
                tmp_file.close()

        for audio_fname, cand in zip(audio_fnames, cands):
            st.success(f"**Output for {audio_fname}:**\n- {format_cand(cand)}")

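    # Caption the microphone recording, if any, with the same hyperparameters.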
    if len(record) > 0:
        outputs = model(
            record_fpath,
            **kwargs,
        )
        cand = outputs["cands"][0]
        st.success(f"**Output for {record_fpath}:**\n- {format_cand(cand)}")


if __name__ == "__main__":
    main()
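
# To launch the demo locally (assuming this file is saved as app.py):
#   streamlit run app.py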