#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Streamlit demo UI for the CoNeTTE audio-captioning model."""

from pathlib import Path
from tempfile import NamedTemporaryFile

import streamlit as st

from conette import CoNeTTEModel, conette


@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    """Build the CoNeTTE model once and cache it across Streamlit reruns.

    All positional/keyword arguments are forwarded verbatim to
    :func:`conette.conette`.
    """
    return conette(*args, **kwargs)


def main() -> None:
    """Render the demo page: decoding options, audio upload, and captions."""
    st.header("CoNeTTE model test")

    model = load_conette(model_kwds=dict(device="cpu"))

    task = st.selectbox("Task embedding input", model.tasks, 0)
    beam_size: int = st.select_slider(  # type: ignore
        "Beam size",
        list(range(1, 20)),
        model.config.beam_size,
    )
    min_pred_size: int = st.select_slider(  # type: ignore
        "Minimal number of words",
        list(range(1, 31)),
        model.config.min_pred_size,
    )
    max_pred_size: int = st.select_slider(  # type: ignore
        "Maximal number of words",
        list(range(1, 31)),
        model.config.max_pred_size,
    )

    audios = st.file_uploader(
        "Upload an audio file",
        type=["wav", "flac", "mp3", "ogg", "avi"],
        accept_multiple_files=True,
    )

    # Truthiness covers both "no widget value yet" (None) and an empty list.
    if audios:
        for audio in audios:
            # Keep the upload's extension on the temp file: audio backends
            # typically pick a decoder from the file suffix, and a bare
            # NamedTemporaryFile() path has none.
            suffix = Path(audio.name).suffix
            with NamedTemporaryFile(suffix=suffix) as temp:
                temp.write(audio.getvalue())
                # BUG FIX: flush buffered bytes to disk before the model
                # opens the file by path; without this the model may read
                # a truncated or empty file.
                temp.flush()
                fpath = temp.name

                outputs = model(
                    fpath,
                    task=task,
                    beam_size=beam_size,
                    min_pred_size=min_pred_size,
                    max_pred_size=max_pred_size,
                )
                # Model returns a batch; a single path yields one candidate.
                cand = outputs["cands"][0]
                st.write(f"Output for {audio.name}:")
                st.write(" - ", cand)


if __name__ == "__main__":
    main()