#!/usr/bin/env python
# -*- coding: utf-8 -*-
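"""Streamlit demo for the CoNeTTE audio captioning model.

Generates a short textual description of the sound events in uploaded audio
files or a microphone recording.
"""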

import os
import time
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import Any, Optional, Union

import streamlit as st
import torchaudio
from conette import CoNeTTEModel, conette, __version__
from conette.utils.collections import dict_list_to_list_dict
from st_audiorec import st_audiorec
from streamlit.runtime.uploaded_file_manager import UploadedFile
from torch import Tensor

ALLOW_REP_MODES = ("stopwords", "all", "none")
DEFAULT_TASK = "audiocaps"
MAX_BEAM_SIZE = 20
MAX_PRED_SIZE = 30
MAX_BATCH_SIZE = 16
RECORD_AUDIO_FNAME = "microphone_conette_record.wav"
DEFAULT_THRESHOLD = 0.3
THRESHOLD_PRECISION = 100
MIN_AUDIO_DURATION_SEC = 0.3
MAX_AUDIO_DURATION_SEC = 60
HASH_PREFIX = "hash_"
TMP_FILE_PREFIX = "audio_tmp_file_"
SECOND_BEFORE_CLEAR_CACHE = 10 * 60


def load_conette(*args, **kwargs) -> CoNeTTEModel:
    return conette(*args, **kwargs)


def format_candidate(candidate: str) -> str:
    if len(candidate) == 0:
        return ""
    else:
        return f"{candidate[0].title()}{candidate[1:]}."


def format_tags(tags: Optional[list[str]]) -> str:
    if tags is None or len(tags) == 0:
        return "None."
    else:
        return ", ".join(tags)


def get_result_hash(audio_fname: str, generate_kwds: dict[str, Any]) -> str:
    return f"{HASH_PREFIX}{audio_fname}-{generate_kwds}"


def get_results(
    model: CoNeTTEModel,
    audio_files: dict[str, bytes],
    generate_kwds: dict[str, Any],
) -> dict[str, Union[dict[str, Any], str]]:
    # Get audio to be processed
    audio_to_predict: dict[str, tuple[str, bytes]] = {}
    for audio_fname, audio in audio_files.items():
        result_hash = get_result_hash(audio_fname, generate_kwds)
        if result_hash not in st.session_state or audio_fname == RECORD_AUDIO_FNAME:
            audio_to_predict[result_hash] = (audio_fname, audio)

    # Save audio to be processed
    tmp_files: dict[str, _TemporaryFileWrapper] = {}
    for result_hash, (audio_fname, audio) in audio_to_predict.items():
        tmp_file = NamedTemporaryFile(delete=False, prefix=TMP_FILE_PREFIX)
        tmp_file.write(audio)
        tmp_file.close()

        metadata = torchaudio.info(tmp_file.name)  # type: ignore
        duration = metadata.num_frames / metadata.sample_rate
        if duration < MIN_AUDIO_DURATION_SEC:
            error_msg = f"""
##### Result for "{audio_fname}"
Audio file is too short (found {duration:.2f}s, but the model expects audio in the range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}] seconds).
"""
            st.session_state[result_hash] = error_msg

        elif duration > MAX_AUDIO_DURATION_SEC:
            error_msg = f"""
##### Result for "{audio_fname}"
Audio file is too long (found {duration:.2f}s, but the model expects audio in the range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}] seconds).
"""
            st.session_state[result_hash] = error_msg

        else:
            tmp_files[result_hash] = tmp_file
    # Generate predictions and store them in session state
    for start in range(0, len(tmp_files), MAX_BATCH_SIZE):
        end = min(start + MAX_BATCH_SIZE, len(tmp_files))

        result_hashes_j = list(tmp_files.keys())[start:end]
        tmp_files_j = list(tmp_files.values())[start:end]
        tmp_paths_j = [tmp_file.name for tmp_file in tmp_files_j]

        outputs_j = model(
            tmp_paths_j,
            **generate_kwds,
        )
        outputs_lst = dict_list_to_list_dict(outputs_j)  # type: ignore
        for result_hash, output_i in zip(result_hashes_j, outputs_lst):
            st.session_state[result_hash] = output_i

    # Get outputs
    outputs = {}
    for audio_fname in audio_files.keys():
        result_hash = get_result_hash(audio_fname, generate_kwds)
        output_i = st.session_state[result_hash]
        outputs[audio_fname] = output_i

    # Remove temporary audio files
    for tmp_file in tmp_files.values():
        os.remove(tmp_file.name)

    return outputs


def show_results(outputs: dict[str, Union[dict[str, Any], str]]) -> None:
    # Show the most recent results first
    keys = list(outputs.keys())[::-1]
    outputs = {key: outputs[key] for key in keys}

    st.divider()

    for audio_fname, output in outputs.items():
        if isinstance(output, str):
            st.error(output)
            st.divider()
            continue

        cand: str = output["cands"]
        lprobs: Tensor = output["lprobs"]
        tags_lst = output.get("tags")
        mult_cands: list[str] = output["mult_cands"]
        mult_lprobs: Tensor = output["mult_lprobs"]

        cand = format_candidate(cand)
        prob = lprobs.exp().tolist()
        tags = format_tags(tags_lst)

        mult_cands = [format_candidate(cand_i) for cand_i in mult_cands]
        mult_probs = mult_lprobs.exp()
        # Rank alternative captions by probability and drop the top-ranked one,
        # which is already shown as the main caption.
        indexes = mult_probs.argsort(descending=True)[1:]
        mult_probs = mult_probs[indexes].tolist()
        mult_cands = [mult_cands[idx] for idx in indexes]

        if audio_fname == RECORD_AUDIO_FNAME:
            header = "##### Result for microphone input:"
        else:
            header = f'##### Result for "{audio_fname}"'

        lines = [
            header,
            f'<center><p class="space"><p class="big-font">"{cand}"</p></p></center>',
        ]
        st.markdown(
            """
<style>
.big-font {
    font-size: 22px !important;
    background-color: rgba(0, 255, 0, 0.1);
    padding: 10px;
}
</style>
""",
            unsafe_allow_html=True,
        )
        content = "<br>".join(lines)
        st.markdown(content, unsafe_allow_html=True)

        lines = [
            f"- **Probability**: {prob * 100:.1f}%",
        ]

        if len(mult_cands) > 0:
            msg = "- **Other descriptions:**"
            lines.append(msg)

            for cand_i, prob_i in zip(mult_cands, mult_probs):
                msg = f' - "{cand_i}" ({prob_i * 100:.1f}%)'
                lines.append(msg)

        msg = f"- **Tags:** {tags}"
        lines.append(msg)

        content = "\n".join(lines)
        st.markdown(content, unsafe_allow_html=False)
        st.divider()


def main() -> None:
    model = load_conette(model_kwds=dict(device="cpu"))

    st.header("Describe audio content with CoNeTTE")
    st.markdown(
        "This interface allows you to generate a short description of the sound events in any recording using an Audio Captioning system. You can try it from your microphone or upload a file below."
    )
    st.markdown(
        "Use '**Start Recording**' and '**Stop**' to record audio from your microphone."
    )
    record_data = st_audiorec()

    with st.expander("Or upload audio files here:"):
        audio_files: Optional[list[UploadedFile]] = st.file_uploader(
            f"Audio files are automatically resampled to 32 kHz.\nTheir duration must be in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}] seconds.",
            type=["wav", "flac", "mp3", "ogg", "avi"],
            accept_multiple_files=True,
            help="Supports wav, flac, mp3, ogg and avi files.",
        )
with st.expander("Model options"): | |
if DEFAULT_TASK in model.tasks: | |
default_task_idx = list(model.tasks).index(DEFAULT_TASK) | |
else: | |
default_task_idx = 0 | |
task = st.selectbox("Task embedding input", model.tasks, default_task_idx) | |
allow_rep_mode = st.selectbox("Allow repetition of words", ALLOW_REP_MODES, 0) | |
beam_size: int = st.select_slider( # type: ignore | |
"Beam size", | |
list(range(1, MAX_BEAM_SIZE + 1)), | |
model.config.beam_size, | |
) | |
min_pred_size, max_pred_size = st.slider( | |
"Minimal and maximal number of words", | |
1, | |
MAX_PRED_SIZE, | |
(model.config.min_pred_size, model.config.max_pred_size), | |
) | |
threshold = st.select_slider( | |
"Tags threshold", | |
[(i / THRESHOLD_PRECISION) for i in range(THRESHOLD_PRECISION + 1)], | |
DEFAULT_THRESHOLD, | |
) | |
if allow_rep_mode == "all": | |
forbid_rep_mode = "none" | |
elif allow_rep_mode == "none": | |
forbid_rep_mode = "all" | |
elif allow_rep_mode == "stopwords": | |
forbid_rep_mode = "content_words" | |
else: | |
msg = ( | |
f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})" | |
) | |
raise ValueError(msg) | |
del allow_rep_mode | |
generate_kwds: dict[str, Any] = dict( | |
task=task, | |
beam_size=beam_size, | |
min_pred_size=min_pred_size, | |
max_pred_size=max_pred_size, | |
forbid_rep_mode=forbid_rep_mode, | |
threshold=threshold, | |
) | |
    audios: dict[str, bytes] = {}
    if audio_files is not None:
        audios |= {audio.name: audio.getvalue() for audio in audio_files}
    if record_data is not None:
        audios |= {RECORD_AUDIO_FNAME: record_data}

    if len(audios) > 0:
        with st.spinner("Generating descriptions..."):
            outputs = get_results(model, audios, generate_kwds)

        st.header("Results:")
        show_results(outputs)

        # Clear cached results when no generation happened for a while
        current = time.perf_counter()
        last_generation = st.session_state.get("last_generation", current)
        if current > last_generation + SECOND_BEFORE_CLEAR_CACHE:
            print("Removing result cache...")
            # Copy the keys to avoid mutating the session state while iterating over it
            for key in list(st.session_state.keys()):
                if isinstance(key, str) and key.startswith(HASH_PREFIX):
                    del st.session_state[key]

        st.session_state["last_generation"] = current

    content = f"""CoNeTTE version {__version__}. <a href="https://github.com/Labbeti/conette-audio-captioning/">Source code on GitHub</a>. <a href="https://ieeexplore.ieee.org/document/10603439">Academic Paper</a>."""
    st.divider()
    st.markdown(content, unsafe_allow_html=True)


if __name__ == "__main__":
    main()