#!/usr/bin/env python
# -*- coding: utf-8 -*-
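
"""Streamlit demo that captions uploaded or recorded audio with the CoNeTTE model."""
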
import os.path as osp
from tempfile import NamedTemporaryFile
from typing import Any

import streamlit as st
from audiorecorder import audiorecorder

from conette import CoNeTTEModel, conette
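

# Load the CoNeTTE model once and cache it: Streamlit reruns this script on
# every interaction, so caching avoids reloading the weights each time.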
@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    return conette(*args, **kwargs)


def format_cand(cand: str) -> str:
    return f"{cand[0].title()}{cand[1:]}."


def get_results(
    model: CoNeTTEModel,
    audios: list,
    generate_kwds: dict[str, Any],
) -> tuple[list[str], list[str]]:
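    """Return (audio filenames, captions) for the uploaded audios.

    Captions are cached in st.session_state under a key built from the file
    name and the generation keyword arguments, so only new combinations of
    audio and settings are actually predicted.
    """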
    audio_to_predict = []
    cands = [""] * len(audios)
    tmp_files = []
    tmp_fpaths = []
    audio_fnames = []

    for i, audio in enumerate(audios):
        audio_fname = audio.name
        audio_fnames.append(audio_fname)
        cand_key = f"{audio_fname}-{generate_kwds}"

        if cand_key in st.session_state:
            cand = st.session_state[cand_key]
            cands[i] = cand
        else:
            tmp_file = NamedTemporaryFile()
            tmp_file.write(audio.getvalue())
            tmp_file.flush()  # ensure the bytes are on disk; the model reads the file by path
            tmp_files.append(tmp_file)
            audio_to_predict.append((i, cand_key, tmp_file))
            tmp_fpath = tmp_file.name
            tmp_fpaths.append(tmp_fpath)
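
    # Run a single batched prediction over all uncached audio files.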
    if len(tmp_fpaths) > 0:
        outputs = model(
            tmp_fpaths,
            **generate_kwds,
        )
        for i, (j, cand_key, tmp_file) in enumerate(audio_to_predict):
            cand = outputs["cands"][i]
            cands[j] = cand
            st.session_state[cand_key] = cand
            tmp_file.close()

    return audio_fnames, cands


def main() -> None:
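    """Build the Streamlit page and caption the provided audio files."""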
    st.header("Describe audio content with CoNeTTE")
    model = load_conette(model_kwds=dict(device="cpu"))

    st.warning(
        "Recommended audio: between **1 and 30 s** long, sampled at **32 kHz** or higher."
    )
    audios = st.file_uploader(
        "**Upload audio files here:**",
        type=["wav", "flac", "mp3", "ogg", "avi"],
        accept_multiple_files=True,
    )
    st.write("**OR**")
    record = audiorecorder(
        start_prompt="Start recording", stop_prompt="Stop recording", pause_prompt=""
    )
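
    # audiorecorder returns a pydub.AudioSegment; an empty segment means nothing was recorded.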
    record_fpath = "record.wav"
    if len(record) > 0:
        record.export(record_fpath, format="wav")
        st.write(
            f"Record frame rate: {record.frame_rate} Hz, record duration: {record.duration_seconds:.2f} s"
        )
        st.audio(record.export().read())  # type: ignore

    with st.expander("Model hyperparameters"):
        task = st.selectbox("Task embedding input", model.tasks, 0)
        allow_rep_mode = st.selectbox(
            "Allow repetition of words", ["stopwords", "all", "none"], 0
        )
        beam_size: int = st.select_slider(  # type: ignore
            "Beam size",
            list(range(1, 21)),
            model.config.beam_size,
        )
        min_pred_size: int = st.select_slider(  # type: ignore
            "Minimal number of words",
            list(range(1, 31)),
            model.config.min_pred_size,
        )
        max_pred_size: int = st.select_slider(  # type: ignore
            "Maximal number of words",
            list(range(1, 31)),
            model.config.max_pred_size,
        )
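
        # The UI asks which repetitions to *allow*; the model expects the
        # inverse *forbid* mode, so translate between the two conventions.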
        if allow_rep_mode == "all":
            forbid_rep_mode = "none"
        elif allow_rep_mode == "none":
            forbid_rep_mode = "all"
        elif allow_rep_mode == "stopwords":
            forbid_rep_mode = "content_words"
        else:
            ALLOW_REP_MODES = ("all", "none", "stopwords")
            raise ValueError(
                f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})"
            )
        del allow_rep_mode
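
    # Generation options forwarded to every CoNeTTE captioning call.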
    generate_kwds: dict[str, Any] = dict(
        task=task,
        beam_size=beam_size,
        min_pred_size=min_pred_size,
        max_pred_size=max_pred_size,
        forbid_rep_mode=forbid_rep_mode,
    )
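
    # Caption the uploaded files (results are cached per file and settings).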
    if audios is not None and len(audios) > 0:
        audio_fnames, cands = get_results(model, audios, generate_kwds)
        for audio_fname, cand in zip(audio_fnames, cands):
            st.success(f"**Output for {audio_fname}:**\n- {format_cand(cand)}")
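
    # Caption the microphone recording, if any.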
    if len(record) > 0:
        outputs = model(
            record_fpath,
            **generate_kwds,
        )
        cand = outputs["cands"][0]
        st.success(
            f"**Output for {osp.basename(record_fpath)}:**\n- {format_cand(cand)}"
        )


if __name__ == "__main__":
    main()