#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Streamlit demo: describe audio content with the CoNeTTE captioning model."""

import os.path as osp

from tempfile import NamedTemporaryFile
from typing import Any

import streamlit as st

from audiorecorder import audiorecorder
from conette import CoNeTTEModel, conette


@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    # Cache the model so it is instantiated only once per server process.
    return conette(*args, **kwargs)


def format_cand(cand: str) -> str:
    # Capitalize the first letter and terminate the candidate with a period.
    return f"{cand[0].title()}{cand[1:]}."


def get_results(
    model: CoNeTTEModel,
    audios: list,
    generate_kwds: dict[str, Any],
) -> tuple[list[str], list[str]]:
    """Return the filenames and captions for the uploaded audios.

    Captions are cached in st.session_state under a key combining the
    filename and the generation hyperparameters, so only new
    (file, settings) pairs trigger an inference call.
    """
    audio_to_predict: list[tuple[int, str, Any]] = []
    cands = [""] * len(audios)
    tmp_fpaths = []
    audio_fnames = []

    for i, audio in enumerate(audios):
        audio_fname = audio.name
        audio_fnames.append(audio_fname)
        cand_key = f"{audio_fname}-{generate_kwds}"

        if cand_key in st.session_state:
            # Reuse the cached caption for this (file, hyperparameters) pair.
            cands[i] = st.session_state[cand_key]
        else:
            # Write the uploaded bytes to a temporary file so the model can
            # read the audio from disk. The open handle is kept alive in
            # audio_to_predict so the file is not deleted before inference.
            tmp_file = NamedTemporaryFile()
            tmp_file.write(audio.getvalue())
            tmp_file.flush()  # make sure the bytes are on disk before inference
            audio_to_predict.append((i, cand_key, tmp_file))
            tmp_fpaths.append(tmp_file.name)

    if len(tmp_fpaths) > 0:
        # Single batched inference over all uncached files.
        outputs = model(tmp_fpaths, **generate_kwds)

        for i, (j, cand_key, tmp_file) in enumerate(audio_to_predict):
            cand = outputs["cands"][i]
            cands[j] = cand
            st.session_state[cand_key] = cand
            tmp_file.close()  # closing deletes the temporary file

    return audio_fnames, cands


def main() -> None:
    st.header("Describe audio content with CoNeTTE")
    model = load_conette(model_kwds=dict(device="cpu"))

    st.warning(
        "Recommended audio: between **1 and 30 s** long, sampled at **32 kHz** or higher."
    )
    audios = st.file_uploader(
        "**Upload audio files here:**",
        type=["wav", "flac", "mp3", "ogg", "avi"],
        accept_multiple_files=True,
    )

    st.write("**OR**")

    # In-browser microphone recording as an alternative to file upload.
    record = audiorecorder(
        start_prompt="Start recording", stop_prompt="Stop recording", pause_prompt=""
    )
    record_fpath = "record.wav"

    if len(record) > 0:
        record.export(record_fpath, format="wav")
        st.write(
            f"Record frame rate: {record.frame_rate} Hz, record duration: {record.duration_seconds:.2f} s"
        )
        st.audio(record.export().read())  # type: ignore

    with st.expander("Model hyperparameters"):
        task = st.selectbox("Task embedding input", model.tasks, 0)
        allow_rep_mode = st.selectbox(
            "Allow repetition of words", ["stopwords", "all", "none"], 0
        )
        beam_size: int = st.select_slider(  # type: ignore
            "Beam size",
            list(range(1, 21)),
            model.config.beam_size,
        )
        min_pred_size: int = st.select_slider(  # type: ignore
            "Minimal number of words",
            list(range(1, 31)),
            model.config.min_pred_size,
        )
        max_pred_size: int = st.select_slider(  # type: ignore
            "Maximal number of words",
            list(range(1, 31)),
            model.config.max_pred_size,
        )

    # The UI asks which repetitions are *allowed*; the model expects the
    # inverse (*forbidden* repetitions), so map between the two conventions.
    if allow_rep_mode == "all":
        forbid_rep_mode = "none"
    elif allow_rep_mode == "none":
        forbid_rep_mode = "all"
    elif allow_rep_mode == "stopwords":
        forbid_rep_mode = "content_words"
    else:
        ALLOW_REP_MODES = ("all", "none", "stopwords")
        raise ValueError(
            f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})"
        )
    del allow_rep_mode

    generate_kwds: dict[str, Any] = dict(
        task=task,
        beam_size=beam_size,
        min_pred_size=min_pred_size,
        max_pred_size=max_pred_size,
        forbid_rep_mode=forbid_rep_mode,
    )

    if audios is not None and len(audios) > 0:
        audio_fnames, cands = get_results(model, audios, generate_kwds)
        for audio_fname, cand in zip(audio_fnames, cands):
            st.success(f"**Output for {audio_fname}:**\n- {format_cand(cand)}")

    if len(record) > 0:
        # Caption the microphone recording (not cached, unlike uploads).
        outputs = model(record_fpath, **generate_kwds)
        cand = outputs["cands"][0]
        st.success(
            f"**Output for {osp.basename(record_fpath)}:**\n- {format_cand(cand)}"
        )


if __name__ == "__main__":
    main()
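
# ---------------------------------------------------------------------------
# Usage notes (a minimal sketch, not part of the app itself).
#
# To launch the demo locally (assuming this file is saved as "app.py"):
#
#   streamlit run app.py
#
# The model can also be called outside Streamlit with the same interface the
# app uses above; "audio.wav" is a hypothetical input path and the keyword
# values are illustrative (they mirror the app's generate_kwds):
#
#   from conette import conette
#
#   model = conette(model_kwds=dict(device="cpu"))
#   outputs = model("audio.wav", beam_size=3, forbid_rep_mode="content_words")
#   print(outputs["cands"][0])
# ---------------------------------------------------------------------------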