Spaces:
Sleeping
Sleeping
File size: 2,913 Bytes
590f41d a480cb3 590f41d 927ba9d 590f41d afed00c 590f41d a480cb3 f96230b afed00c f96230b afed00c f96230b 83af184 f96230b 83af184 f96230b afed00c 590f41d a480cb3 afed00c a480cb3 f96230b afed00c f96230b a480cb3 590f41d afed00c 590f41d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import hashlib
from tempfile import NamedTemporaryFile
from typing import Any

import streamlit as st
from conette import CoNeTTEModel, conette
@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    """Build a CoNeTTE model, forwarding every argument to ``conette``.

    The function is wrapped in ``st.cache_resource`` so that Streamlit can
    reuse the returned model instead of rebuilding it on each script rerun.

    Args:
        *args: Positional arguments passed through to ``conette``.
        **kwargs: Keyword arguments passed through to ``conette``.

    Returns:
        The model object produced by ``conette``.
    """
    model = conette(*args, **kwargs)
    return model
def format_cand(cand: str) -> str:
    """Turn a raw caption into a display sentence.

    The first character is title-cased, the rest of the caption is kept
    unchanged, and a trailing period is appended.

    Args:
        cand: Caption text produced by the model. May be empty.

    Returns:
        The formatted caption. An empty caption yields ``"."``.
    """
    # Slice instead of indexing so an empty caption does not raise IndexError.
    return f"{cand[:1].title()}{cand[1:]}."
def main() -> None:
    """Render the CoNeTTE audio-captioning page.

    Builds the Streamlit widgets for the decoding options (task, repetition
    policy, beam size, minimal and maximal caption length), then, for every
    uploaded audio file, runs the model and displays the first candidate
    caption.

    Captions are memoised in ``st.session_state`` under a key made of the file
    name, a digest of the file content and the decoding options, so a file is
    only re-captioned when its content or the options change.

    Raises:
        ValueError: If the repetition option is not one of the known modes.
    """
    st.header("Describe audio content with CoNeTTE")

    model = load_conette(model_kwds=dict(device="cpu"))

    task = st.selectbox("Task embedding input", model.tasks, 0)
    allow_rep_mode = st.selectbox("Allow repetition of words", ["stopwords", "all", "none"], 0)
    beam_size: int = st.select_slider(  # type: ignore
        "Beam size",
        list(range(1, 21)),
        model.config.beam_size,
    )
    min_pred_size: int = st.select_slider(  # type: ignore
        "Minimal number of words",
        list(range(1, 31)),
        model.config.min_pred_size,
    )
    max_pred_size: int = st.select_slider(  # type: ignore
        "Maximal number of words",
        list(range(1, 31)),
        model.config.max_pred_size,
    )

    st.markdown("Recommanded audio: lasting from **1 to 30s**, sampled at **32 kHz**.")
    audios = st.file_uploader(
        "Upload an audio file",
        type=["wav", "flac", "mp3", "ogg", "avi"],
        accept_multiple_files=True,
    )

    # The widget asks which words MAY repeat, whereas the model option states
    # which words may NOT repeat, hence the inverted mapping.
    forbid_rep_modes = {
        "all": "none",
        "none": "all",
        "stopwords": "content_words",
    }
    if allow_rep_mode not in forbid_rep_modes:
        expected = tuple(forbid_rep_modes)
        raise ValueError(f"Unknown option {allow_rep_mode=}. (expected one of {expected})")

    # The decoding options are identical for every uploaded file: build once.
    kwargs: dict[str, Any] = dict(
        task=task,
        beam_size=beam_size,
        min_pred_size=min_pred_size,
        max_pred_size=max_pred_size,
        forbid_rep_mode=forbid_rep_modes[allow_rep_mode],
    )

    for audio in audios or []:
        data = audio.getvalue()
        # Include a content digest in the key so that two different files
        # sharing the same name do not reuse each other's caption.
        digest = hashlib.sha256(data).hexdigest()
        cand_key = f"{audio.name}-{digest}-{kwargs}"

        if cand_key in st.session_state:
            cand = st.session_state[cand_key]
        else:
            with NamedTemporaryFile() as temp:
                temp.write(data)
                # The model opens the file by path; flush so the bytes are
                # actually on disk and not still sitting in the write buffer.
                temp.flush()
                outputs = model(
                    temp.name,
                    **kwargs,
                )
            cand = outputs["cands"][0]
            st.session_state[cand_key] = cand

        st.markdown(f"Output for {audio.name}:")
        # ":red[...]" is Streamlit's colored-text directive; without the
        # leading colon the literal text "red[...]" is displayed.
        st.markdown(f" - :red[{format_cand(cand)}]")
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|