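"""Streamlit demo: describe the content of audio files with the CoNeTTE
audio captioning model, from uploaded files or a microphone recording.

Launch with: streamlit run <path/to/this/script>
"""
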
import os.path as osp

from tempfile import NamedTemporaryFile
from typing import Any

import streamlit as st

from audiorecorder import audiorecorder

from conette import CoNeTTEModel, conette

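
# Cache the loaded model as a Streamlit resource so it is instantiated once per
# process and reused across script reruns.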
@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    return conette(*args, **kwargs)


def format_cand(cand: str) -> str:
    """Capitalize the first word of a candidate caption and add a final period."""
    return f"{cand[0].title()}{cand[1:]}."


def get_results(
    model: CoNeTTEModel,
    audios: list,
    generate_kwds: dict[str, Any],
) -> tuple[list[str], list[str]]:
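    """Caption the given uploaded audios with CoNeTTE.

    Captions are cached in ``st.session_state``, keyed by file name and
    generation settings, so only combinations not seen before are predicted.
    """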
    audio_to_predict = []
    cands = [""] * len(audios)
    tmp_fpaths = []
    audio_fnames = []

    for i, audio in enumerate(audios):
        audio_fname = audio.name
        audio_fnames.append(audio_fname)
        cand_key = f"{audio_fname}-{generate_kwds}"

        if cand_key in st.session_state:
            # Caption already computed on a previous rerun with these settings.
            cands[i] = st.session_state[cand_key]
        else:
            # Dump the uploaded bytes to a named temporary file so that the
            # model can read the audio from a path. flush() ensures the bytes
            # are actually on disk before the model opens the file.
            tmp_file = NamedTemporaryFile()
            tmp_file.write(audio.getvalue())
            tmp_file.flush()
            audio_to_predict.append((i, cand_key, tmp_file))
            tmp_fpaths.append(tmp_file.name)
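
    # Caption all audios that missed the cache in a single batched call.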
    if len(tmp_fpaths) > 0:
        outputs = model(
            tmp_fpaths,
            **generate_kwds,
        )
        for i, (j, cand_key, tmp_file) in enumerate(audio_to_predict):
            cand = outputs["cands"][i]
            cands[j] = cand
            st.session_state[cand_key] = cand
            # Closing the NamedTemporaryFile also removes it from disk.
            tmp_file.close()

    return audio_fnames, cands


def main() -> None:
    st.header("Describe audio content with CoNeTTE")

    model = load_conette(model_kwds=dict(device="cpu"))

    st.warning(
        "Recommended audio: lasting from **1 to 30 s**, sampled at **32 kHz** minimum."
    )
    audios = st.file_uploader(
        "**Upload audio files here:**",
        type=["wav", "flac", "mp3", "ogg", "avi"],
        accept_multiple_files=True,
    )
    st.write("**OR**")
    record = audiorecorder(
        start_prompt="Start recording", stop_prompt="Stop recording", pause_prompt=""
    )
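
    # `audiorecorder` returns a pydub AudioSegment; export it to a WAV file so
    # the model can read it by path, and echo it back in an audio player.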
    record_fpath = "record.wav"
    if len(record) > 0:
        record.export(record_fpath, format="wav")
        st.write(
            f"Record frame rate: {record.frame_rate} Hz, record duration: {record.duration_seconds:.2f}s"
        )
        st.audio(record.export().read())
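
    # Generation hyperparameters; the slider defaults come from the pretrained
    # model configuration.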
    with st.expander("Model hyperparameters"):
        task = st.selectbox("Task embedding input", model.tasks, 0)
        allow_rep_mode = st.selectbox(
            "Allow repetition of words", ["stopwords", "all", "none"], 0
        )
        beam_size: int = st.select_slider(
            "Beam size",
            list(range(1, 21)),
            model.config.beam_size,
        )
        min_pred_size: int = st.select_slider(
            "Minimal number of words",
            list(range(1, 31)),
            model.config.min_pred_size,
        )
        max_pred_size: int = st.select_slider(
            "Maximal number of words",
            list(range(1, 31)),
            model.config.max_pred_size,
        )
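
    # The UI asks which words the user *allows* to repeat, while the model
    # expects the inverse "forbid" convention; translate between the two.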
    if allow_rep_mode == "all":
        forbid_rep_mode = "none"
    elif allow_rep_mode == "none":
        forbid_rep_mode = "all"
    elif allow_rep_mode == "stopwords":
        forbid_rep_mode = "content_words"
    else:
        ALLOW_REP_MODES = ("all", "none", "stopwords")
        raise ValueError(
            f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})"
        )
    del allow_rep_mode
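
    # Collect all generation options into a single kwargs dict, reused for both
    # uploaded files and the microphone recording.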
    generate_kwds: dict[str, Any] = dict(
        task=task,
        beam_size=beam_size,
        min_pred_size=min_pred_size,
        max_pred_size=max_pred_size,
        forbid_rep_mode=forbid_rep_mode,
    )

    if audios is not None and len(audios) > 0:
        audio_fnames, cands = get_results(model, audios, generate_kwds)

        for audio_fname, cand in zip(audio_fnames, cands):
            st.success(f"**Output for {audio_fname}:**\n- {format_cand(cand)}")

    if len(record) > 0:
        # The microphone recording is captioned directly, without the
        # session-state cache used for uploaded files.
        outputs = model(
            record_fpath,
            **generate_kwds,
        )
        cand = outputs["cands"][0]
        st.success(
            f"**Output for {osp.basename(record_fpath)}:**\n- {format_cand(cand)}"
        )


if __name__ == "__main__":
    main()