#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os.path as osp

from tempfile import NamedTemporaryFile
from typing import Any

import streamlit as st

from audiorecorder import audiorecorder
from streamlit.runtime.uploaded_file_manager import UploadedFile

from conette import CoNeTTEModel, conette


@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
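    # Cache the loaded model across Streamlit reruns so it is created only once.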
    return conette(*args, **kwargs)


def format_candidate(candidate: str) -> str:
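    # Capitalize the first letter of the caption and append a trailing period.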
    if len(candidate) == 0:
        return ""
    else:
        return f"{candidate[0].title()}{candidate[1:]}."


def get_results(
    model: CoNeTTEModel,
    audios: list[UploadedFile],
    generate_kwds: dict[str, Any],
) -> tuple[list[str], list[str]]:
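    # Captions are cached in st.session_state, keyed by file name and
    # generation arguments, so Streamlit reruns do not recompute them.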
    audio_to_predict = []
    cands = [""] * len(audios)
    tmp_files = []
    tmp_fpaths = []
    audio_fnames = []

    for i, audio in enumerate(audios):
        audio_fname = audio.name
        audio_fnames.append(audio_fname)
        cand_key = f"{audio_fname}-{generate_kwds}"

        if cand_key in st.session_state:
            cand = st.session_state[cand_key]
            cands[i] = cand
        else:
            tmp_file = NamedTemporaryFile()
            tmp_file.write(audio.getvalue())
            # Flush buffered bytes so the model reads the complete file from disk.
            tmp_file.flush()
            tmp_files.append(tmp_file)
            audio_to_predict.append((i, cand_key, tmp_file))

            tmp_fpath = tmp_file.name
            tmp_fpaths.append(tmp_fpath)

    if len(tmp_fpaths) > 0:
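        # Caption all uncached files in a single batched call.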
        outputs = model(
            tmp_fpaths,
            **generate_kwds,
        )
        for i, (j, cand_key, tmp_file) in enumerate(audio_to_predict):
            cand = outputs["cands"][i]
            cands[j] = cand
            st.session_state[cand_key] = cand
            tmp_file.close()

    return audio_fnames, cands


def main() -> None:
    st.header("Describe audio content with CoNeTTE")

    model = load_conette(model_kwds=dict(device="cpu"))

    st.warning(
        "Recommended audio: **1 to 30 s** long, sampled at **32 kHz** or higher."
    )
    audios = st.file_uploader(
        "**Upload audio files here:**",
        type=["wav", "flac", "mp3", "ogg", "avi"],
        accept_multiple_files=True,
    )
    st.write("**OR**")
    record = audiorecorder(
        start_prompt="Start recording", stop_prompt="Stop recording", pause_prompt=""
    )

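    # Save the microphone recording to disk so it can be passed to the model by path.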
    record_fpath = "record.wav"
    if len(record) > 0:
        record.export(record_fpath, format="wav")
        st.write(
            f"Record frame rate: {record.frame_rate} Hz, record duration: {record.duration_seconds:.2f} s"
        )
        st.audio(record.export().read())  # type: ignore

    with st.expander("Model hyperparameters"):
        task = st.selectbox("Task embedding input", model.tasks, 0)
        allow_rep_mode = st.selectbox(
            "Allow repetition of words", ["stopwords", "all", "none"], 0
        )
        beam_size: int = st.select_slider(  # type: ignore
            "Beam size",
            list(range(1, 21)),
            model.config.beam_size,
        )
        min_pred_size: int = st.select_slider(  # type: ignore
            "Minimal number of words",
            list(range(1, 31)),
            model.config.min_pred_size,
        )
        max_pred_size: int = st.select_slider(  # type: ignore
            "Maximal number of words",
            list(range(1, 31)),
            model.config.max_pred_size,
        )

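        # Map the user-facing "allow repetition" choice onto the model's
        # complementary forbid_rep_mode argument.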
        if allow_rep_mode == "all":
            forbid_rep_mode = "none"
        elif allow_rep_mode == "none":
            forbid_rep_mode = "all"
        elif allow_rep_mode == "stopwords":
            forbid_rep_mode = "content_words"
        else:
            ALLOW_REP_MODES = ("all", "none", "stopwords")
            raise ValueError(
                f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})"
            )
        del allow_rep_mode

        generate_kwds: dict[str, Any] = dict(
            task=task,
            beam_size=beam_size,
            min_pred_size=min_pred_size,
            max_pred_size=max_pred_size,
            forbid_rep_mode=forbid_rep_mode,
        )

    if audios is not None and len(audios) > 0:
        audio_fnames, cands = get_results(model, audios, generate_kwds)

        for audio_fname, cand in zip(audio_fnames, cands):
            st.success(f"**Output for {audio_fname}:**\n- {format_candidate(cand)}")

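    # Caption the microphone recording (not cached, unlike uploaded files).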
    if len(record) > 0:
        outputs = model(
            record_fpath,
            **generate_kwds,
        )
        cand = outputs["cands"][0]
        st.success(
            f"**Output for {osp.basename(record_fpath)}:**\n- {format_candidate(cand)}"
        )


if __name__ == "__main__":
    main()