#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os.path as osp
from tempfile import NamedTemporaryFile
from typing import Any

import streamlit as st
from audiorecorder import audiorecorder
from conette import CoNeTTEModel, conette
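

# Assumption (not part of the original snippet): cache the loaded model with
# st.cache_resource so conette() is not called again on every Streamlit rerun.
@st.cache_resource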
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    return conette(*args, **kwargs)


def format_cand(cand: str) -> str:
    return f"{cand[0].title()}{cand[1:]}."
def get_results(
    model: CoNeTTEModel,
    audios: list,
    generate_kwds: dict[str, Any],
) -> tuple[list[str], list[str]]:
    audio_to_predict = []
    cands = [""] * len(audios)
    tmp_files = []
    tmp_fpaths = []
    audio_fnames = []

    for i, audio in enumerate(audios):
        audio_fname = audio.name
        audio_fnames.append(audio_fname)
        cand_key = f"{audio_fname}-{generate_kwds}"

        if cand_key in st.session_state:
            cand = st.session_state[cand_key]
            cands[i] = cand
        else:
            # Keep the original file extension, which helps audio backends identify the format.
            tmp_file = NamedTemporaryFile(suffix=osp.splitext(audio_fname)[1])
            tmp_file.write(audio.getvalue())
            tmp_file.flush()  # ensure all bytes are on disk before the model reads the path
            tmp_files.append(tmp_file)
            audio_to_predict.append((i, cand_key, tmp_file))

            tmp_fpath = tmp_file.name
            tmp_fpaths.append(tmp_fpath)

    if len(tmp_fpaths) > 0:
        outputs = model(
            tmp_fpaths,
            **generate_kwds,
        )

        for i, (j, cand_key, tmp_file) in enumerate(audio_to_predict):
            cand = outputs["cands"][i]
            cands[j] = cand
            st.session_state[cand_key] = cand
            tmp_file.close()

    return audio_fnames, cands
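

# Streamlit page: load the model, accept file uploads or a microphone
# recording, expose the generation hyperparameters, and print one caption
# per input.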
def main() -> None:
    st.header("Describe audio content with CoNeTTE")

    model = load_conette(model_kwds=dict(device="cpu"))

    st.warning(
        "Recommended audio: lasting from **1 to 30s**, sampled at **32 kHz** minimum."
    )
    audios = st.file_uploader(
        "**Upload audio files here:**",
        type=["wav", "flac", "mp3", "ogg", "avi"],
        accept_multiple_files=True,
    )

    st.write("**OR**")

    record = audiorecorder(
        start_prompt="Start recording", stop_prompt="Stop recording", pause_prompt=""
    )
    record_fpath = "record.wav"

    if len(record) > 0:
        record.export(record_fpath, format="wav")
        st.write(
            f"Record frame rate: {record.frame_rate}Hz, record duration: {record.duration_seconds:.2f}s"
        )
        st.audio(record.export().read())  # type: ignore
    with st.expander("Model hyperparameters"):
        task = st.selectbox("Task embedding input", model.tasks, 0)
        allow_rep_mode = st.selectbox(
            "Allow repetition of words", ["stopwords", "all", "none"], 0
        )
        beam_size: int = st.select_slider(  # type: ignore
            "Beam size",
            list(range(1, 21)),
            model.config.beam_size,
        )
        min_pred_size: int = st.select_slider(  # type: ignore
            "Minimal number of words",
            list(range(1, 31)),
            model.config.min_pred_size,
        )
        max_pred_size: int = st.select_slider(  # type: ignore
            "Maximal number of words",
            list(range(1, 31)),
            model.config.max_pred_size,
        )
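
    # The UI asks which words MAY repeat, while the model's generation kwargs
    # expect the complementary forbid_rep_mode; translate between the two here.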
    if allow_rep_mode == "all":
        forbid_rep_mode = "none"
    elif allow_rep_mode == "none":
        forbid_rep_mode = "all"
    elif allow_rep_mode == "stopwords":
        forbid_rep_mode = "content_words"
    else:
        ALLOW_REP_MODES = ("all", "none", "stopwords")
        raise ValueError(
            f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})"
        )
    del allow_rep_mode

    generate_kwds: dict[str, Any] = dict(
        task=task,
        beam_size=beam_size,
        min_pred_size=min_pred_size,
        max_pred_size=max_pred_size,
        forbid_rep_mode=forbid_rep_mode,
    )
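
    # Caption the uploaded files (cached per file and per settings in
    # get_results), then caption the microphone recording, if any.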
    if audios is not None and len(audios) > 0:
        audio_fnames, cands = get_results(model, audios, generate_kwds)
        for audio_fname, cand in zip(audio_fnames, cands):
            st.success(f"**Output for {audio_fname}:**\n- {format_cand(cand)}")

    if len(record) > 0:
        outputs = model(
            record_fpath,
            **generate_kwds,
        )
        cand = outputs["cands"][0]
        st.success(
            f"**Output for {osp.basename(record_fpath)}:**\n- {format_cand(cand)}"
        )


if __name__ == "__main__":
    main()