conette / app.py
Labbeti's picture
Add: Allow repetition mode option.
afed00c
raw
history blame
2.91 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import hashlib
from tempfile import NamedTemporaryFile
from typing import Any

import streamlit as st
from conette import CoNeTTEModel, conette
@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    """Build a CoNeTTE model, forwarding all arguments to ``conette``.

    The function is wrapped with ``st.cache_resource``; presumably this lets
    Streamlit reuse one model instance across script reruns instead of
    reloading it each time (see the Streamlit caching documentation).

    Args:
        *args: Positional arguments passed through to ``conette`` unchanged.
        **kwargs: Keyword arguments passed through to ``conette`` unchanged.

    Returns:
        The object returned by ``conette(*args, **kwargs)``.
    """
    model = conette(*args, **kwargs)
    return model
def format_cand(cand: str) -> str:
    """Turn a raw caption into a displayable sentence.

    The first character is title-cased and a trailing period is appended;
    the rest of the caption is left untouched.

    Args:
        cand: Caption text to format.

    Returns:
        The formatted sentence, or an empty string when ``cand`` is empty
        (an empty caption has no first character to capitalise and should
        not be rendered as a lone ".").
    """
    if not cand:
        return ""
    return f"{cand[0].title()}{cand[1:]}."
def main() -> None:
    """Render the Streamlit page that captions uploaded audio files.

    The page exposes the decoding options (task, repetition mode, beam size,
    minimal/maximal number of words), lets the user upload one or more audio
    files, runs the model on each file and prints the first candidate
    caption. Captions are memoised in ``st.session_state`` so that a rerun
    with the same file content and the same options does not call the model
    again.

    Raises:
        ValueError: If the repetition-mode widget returns a value that is
            not one of "all", "none" or "stopwords".
    """
    st.header("Describe audio content with CoNeTTE")

    model = load_conette(model_kwds=dict(device="cpu"))

    task = st.selectbox("Task embedding input", model.tasks, 0)
    allow_rep_mode = st.selectbox("Allow repetition of words", ["stopwords", "all", "none"], 0)
    beam_size: int = st.select_slider(  # type: ignore
        "Beam size",
        list(range(1, 21)),
        model.config.beam_size,
    )
    min_pred_size: int = st.select_slider(  # type: ignore
        "Minimal number of words",
        list(range(1, 31)),
        model.config.min_pred_size,
    )
    max_pred_size: int = st.select_slider(  # type: ignore
        "Maximal number of words",
        list(range(1, 31)),
        model.config.max_pred_size,
    )

    st.markdown("Recommended audio: lasting from **1 to 30s**, sampled at **32 kHz**.")

    audios = st.file_uploader(
        "Upload an audio file",
        type=["wav", "flac", "mp3", "ogg", "avi"],
        accept_multiple_files=True,
    )
    if not audios:
        return

    # The widget asks which words MAY repeat, while the model option names
    # which words may NOT repeat, so the selection is inverted here.
    forbid_rep_modes = {
        "all": "none",
        "none": "all",
        "stopwords": "content_words",
    }
    if allow_rep_mode not in forbid_rep_modes:
        ALLOW_REP_MODES = tuple(forbid_rep_modes)
        raise ValueError(f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})")

    # The decoding options are identical for every uploaded file, so they are
    # built once instead of once per file.
    kwargs: dict[str, Any] = dict(
        task=task,
        beam_size=beam_size,
        min_pred_size=min_pred_size,
        max_pred_size=max_pred_size,
        forbid_rep_mode=forbid_rep_modes[allow_rep_mode],
    )

    for audio in audios:
        data = audio.getvalue()
        # Key the cache on the file content as well as its name: two different
        # uploads sharing a file name must not reuse each other's caption.
        digest = hashlib.sha256(data).hexdigest()
        cand_key = f"{audio.name}-{digest}-{kwargs}"

        if cand_key in st.session_state:
            cand = st.session_state[cand_key]
        else:
            # Only materialise the upload on disk when the model really runs.
            with NamedTemporaryFile() as temp:
                temp.write(data)
                # The model receives a path, not the file object, so the
                # buffered bytes must reach the file before it is read.
                temp.flush()
                outputs = model(
                    temp.name,
                    **kwargs,
                )
            cand = outputs["cands"][0]
            st.session_state[cand_key] = cand

        st.markdown(f"Output for {audio.name}:")
        # ":red[...]" is Streamlit's coloured-text directive; without the
        # leading colon the text "red[...]" is rendered literally.
        st.markdown(f" - :red[{format_cand(cand)}]")
# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()