|
|
|
|
|
|
|
import hashlib
import html
import os
import time
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import Any, Optional, Union

import streamlit as st
import torchaudio
from st_audiorec import st_audiorec
from streamlit.runtime.uploaded_file_manager import UploadedFile
from torch import Tensor

from conette import CoNeTTEModel, conette
from conette.utils.collections import dict_list_to_list_dict
|
|
|
|
|
# Generation options exposed in the "Model options" panel.
DEFAULT_TASK = "audiocaps"  # preselected task when the model lists it
ALLOW_REP_MODES = ("stopwords", "all", "none")  # UI choices; index 0 is preselected
MAX_BEAM_SIZE = 20  # upper bound of the "Beam size" slider
MAX_PRED_SIZE = 30  # upper bound of the min/max number of words slider
MAX_BATCH_SIZE = 16  # maximum number of audio paths passed to the model per call

# "Tags threshold" slider: options are i / THRESHOLD_PRECISION for i in [0, THRESHOLD_PRECISION].
DEFAULT_THRESHOLD = 0.3
THRESHOLD_PRECISION = 100

# Audio inputs.
RECORD_AUDIO_FNAME = "microphone_conette_record.wav"  # name given to the microphone recording
MIN_AUDIO_DURATION_SEC = 0.3  # shorter audio gets an error message instead of a description
MAX_AUDIO_DURATION_SEC = 60  # longer audio gets an error message instead of a description
TMP_FILE_PREFIX = "audio_tmp_file_"  # prefix of the temporary files read by the model

# Results cached in st.session_state.
HASH_PREFIX = "hash_"  # prefix shared by every cached-result key
SECOND_BEFORE_CLEAR_CACHE = 600  # 10 minutes between two script runs before cached results are dropped
|
|
|
|
|
@st.cache_resource
def load_conette(*model_args, **model_kwargs) -> CoNeTTEModel:
    """Build a CoNeTTE model through ``conette(...)``, memoized by ``st.cache_resource``.

    All positional and keyword arguments are forwarded unchanged to ``conette``.
    The decorator lets Streamlit reuse the returned model across script reruns
    instead of rebuilding it each time.
    """
    model = conette(*model_args, **model_kwargs)
    return model
|
|
|
|
|
def format_candidate(candidate: str) -> str:
    """Turn a raw caption into a displayable sentence.

    The first character is title-cased and a final period is appended.
    An empty caption is returned as an empty string (no period added).
    """
    if not candidate:
        return ""
    head, tail = candidate[0], candidate[1:]
    return head.title() + tail + "."
|
|
|
|
|
def format_tags(tags: Optional[list[str]]) -> str:
    """Join tags with ", ", or return "None." when there are no tags.

    Args:
        tags: Tag labels, or None when the model produced no tags.
    """
    has_tags = tags is not None and len(tags) > 0
    return ", ".join(tags) if has_tags else "None."
|
|
|
|
|
def get_result_hash(audio_fname: str, generate_kwds: dict[str, Any]) -> str:
    """Build the ``st.session_state`` key under which a result is cached.

    The key is ``HASH_PREFIX`` followed by the audio name, a dash, and the
    ``str()`` of the generation options. The prefix lets the cache be cleared
    by scanning for keys that start with it.
    """
    return "".join((HASH_PREFIX, audio_fname, "-", str(generate_kwds)))
|
|
|
|
|
def _duration_error_msg(audio_fname: str, duration: float, problem: str) -> str:
    """Build the markdown error stored for audio whose duration is out of range."""
    # `problem` is "short" or "long"; the rest of the text is shared.
    return f"""
##### Result for "{audio_fname}"
Audio file is too {problem}. (found {duration:.2f}s but the model expect audio in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}])
"""


def get_results(
    model: CoNeTTEModel,
    audio_files: dict[str, bytes],
    generate_kwds: dict[str, Any],
) -> dict[str, Union[dict[str, Any], str]]:
    """Return the model output for each audio file, caching results in session state.

    Results are stored in ``st.session_state`` under a key derived from the
    file name, a SHA-256 digest of the file content and ``generate_kwds``.
    Only files without a cached result are decoded and passed to the model,
    in batches of at most ``MAX_BATCH_SIZE`` paths.

    Args:
        model: CoNeTTE model, called as ``model(paths, **generate_kwds)``.
        audio_files: Mapping from file name to raw encoded audio bytes.
        generate_kwds: Generation keyword arguments forwarded to the model;
            they are also part of the cache key.

    Returns:
        Mapping from each name in ``audio_files`` to either the model output
        dict for that file, or a markdown error message (``str``) when its
        duration is outside [MIN_AUDIO_DURATION_SEC, MAX_AUDIO_DURATION_SEC].
    """
    # The content digest is part of the key. A different file uploaded under an
    # already-seen name, or a new microphone recording (whose name never
    # changes), therefore never reuses a stale cached result.
    hash_by_fname = {
        audio_fname: get_result_hash(
            f"{audio_fname}-{hashlib.sha256(audio).hexdigest()}", generate_kwds
        )
        for audio_fname, audio in audio_files.items()
    }
    audio_to_predict = {
        hash_by_fname[audio_fname]: (audio_fname, audio)
        for audio_fname, audio in audio_files.items()
        if hash_by_fname[audio_fname] not in st.session_state
    }

    # Every temporary file created below is removed in `finally`, including
    # rejected (too short/long) files and files left behind by an exception.
    created_paths: list[str] = []
    valid_paths: dict[str, str] = {}
    try:
        for result_hash, (audio_fname, audio) in audio_to_predict.items():
            # delete=False: the file must survive closing so that it can be read
            # by path through torchaudio and the model.
            with NamedTemporaryFile(delete=False, prefix=TMP_FILE_PREFIX) as tmp_file:
                created_paths.append(tmp_file.name)
                tmp_file.write(audio)

            metadata = torchaudio.info(tmp_file.name)
            duration = metadata.num_frames / metadata.sample_rate

            if duration < MIN_AUDIO_DURATION_SEC:
                st.session_state[result_hash] = _duration_error_msg(
                    audio_fname, duration, "short"
                )
            elif duration > MAX_AUDIO_DURATION_SEC:
                st.session_state[result_hash] = _duration_error_msg(
                    audio_fname, duration, "long"
                )
            else:
                valid_paths[result_hash] = tmp_file.name

        pending = list(valid_paths.items())
        for start in range(0, len(pending), MAX_BATCH_SIZE):
            batch = pending[start : start + MAX_BATCH_SIZE]
            batch_hashes = [result_hash for result_hash, _ in batch]
            batch_paths = [path for _, path in batch]
            batch_outputs = model(batch_paths, **generate_kwds)
            # The model returns a dict of lists; split it into one dict per file.
            for result_hash, output_i in zip(
                batch_hashes, dict_list_to_list_dict(batch_outputs)
            ):
                st.session_state[result_hash] = output_i
    finally:
        for path in created_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass  # already gone: nothing left to clean up

    return {
        audio_fname: st.session_state[result_hash]
        for audio_fname, result_hash in hash_by_fname.items()
    }
|
|
|
|
|
def show_results(outputs: dict[str, Union[dict[str, Any], str]]) -> None:
    """Render generated descriptions (or error messages) with Streamlit.

    Entries are displayed in reverse insertion order of ``outputs``. Each one
    is followed by a divider.

    Args:
        outputs: Mapping from audio file name to either a markdown error
            message (``str``, displayed with ``st.error``) or a model output
            dict. Output dicts are read through the keys ``"cands"``,
            ``"lprobs"``, ``"mult_cands"``, ``"mult_lprobs"`` and the optional
            ``"tags"``.
    """
    st.divider()

    for audio_fname, output in reversed(list(outputs.items())):
        if isinstance(output, str):
            st.error(output)
            st.divider()
            continue

        raw_cand: str = output["cands"]
        lprobs: Tensor = output["lprobs"]
        tags_lst = output.get("tags")
        raw_mult_cands: list[str] = output["mult_cands"]
        mult_lprobs: Tensor = output["mult_lprobs"]

        cand = format_candidate(raw_cand)
        prob = lprobs.exp().tolist()
        tags = format_tags(tags_lst)
        formatted_mult_cands = [format_candidate(cand_i) for cand_i in raw_mult_cands]
        all_mult_probs = mult_lprobs.exp()

        # Sort the alternative candidates by decreasing probability and skip the
        # first one (presumably the main caption already displayed above).
        indexes = all_mult_probs.argsort(descending=True)[1:]
        mult_probs = all_mult_probs[indexes].tolist()
        mult_cands = [formatted_mult_cands[idx] for idx in indexes]

        # These lines are rendered with unsafe_allow_html=True. The file name is
        # user-controlled, so it and the caption are escaped to make "<", ">"
        # and "&" display literally instead of being parsed as HTML.
        if audio_fname == RECORD_AUDIO_FNAME:
            header = "##### Result for microphone input:"
        else:
            header = f'##### Result for "{html.escape(audio_fname, quote=False)}"'

        lines = [
            header,
            f'<center><p class="space"><p class="big-font">"{html.escape(cand, quote=False)}"</p></p></center>',
        ]

        st.markdown(
            """
<style>
.big-font {
    font-size:22px !important;
    background-color: rgba(0, 255, 0, 0.1);
    padding: 10px;
}
</style>
""",
            unsafe_allow_html=True,
        )
        st.markdown("<br>".join(lines), unsafe_allow_html=True)

        details = [f"- **Probability**: {prob*100:.1f}%"]
        if len(mult_cands) > 0:
            details.append("- **Other descriptions:**")
        for cand_i, prob_i in zip(mult_cands, mult_probs):
            details.append(f' - "{cand_i}" ({prob_i*100:.1f}%)')
        details.append(f"- **Tags:** {tags}")

        # Plain markdown here: no HTML is needed, so HTML rendering stays off.
        st.markdown("\n".join(details), unsafe_allow_html=False)
        st.divider()
|
|
|
|
|
def _read_generate_kwds(model: CoNeTTEModel) -> dict[str, Any]:
    """Render the "Model options" widgets and return the model's generation kwargs.

    Raises:
        ValueError: If the repetition option is not one of ``ALLOW_REP_MODES``.
    """
    # UI "allow repetition" choice -> value of the model's `forbid_rep_mode`.
    forbid_by_allow = {
        "all": "none",
        "none": "all",
        "stopwords": "content_words",
    }

    with st.expander("Model options"):
        if DEFAULT_TASK in model.tasks:
            default_task_idx = list(model.tasks).index(DEFAULT_TASK)
        else:
            default_task_idx = 0

        task = st.selectbox("Task embedding input", model.tasks, default_task_idx)
        allow_rep_mode = st.selectbox("Allow repetition of words", ALLOW_REP_MODES, 0)
        beam_size: int = st.select_slider(
            "Beam size",
            list(range(1, MAX_BEAM_SIZE + 1)),
            model.config.beam_size,
        )
        min_pred_size, max_pred_size = st.slider(
            "Minimal and maximal number of words",
            1,
            MAX_PRED_SIZE,
            (model.config.min_pred_size, model.config.max_pred_size),
        )
        threshold = st.select_slider(
            "Tags threshold",
            [(i / THRESHOLD_PRECISION) for i in range(THRESHOLD_PRECISION + 1)],
            DEFAULT_THRESHOLD,
        )

    if allow_rep_mode not in forbid_by_allow:
        msg = (
            f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})"
        )
        raise ValueError(msg)

    # Key order is kept stable: str() of this dict is part of the result cache key.
    return dict(
        task=task,
        beam_size=beam_size,
        min_pred_size=min_pred_size,
        max_pred_size=max_pred_size,
        forbid_rep_mode=forbid_by_allow[allow_rep_mode],
        threshold=threshold,
    )


def _clear_expired_result_cache() -> None:
    """Drop cached results when more than SECOND_BEFORE_CLEAR_CACHE s separate two runs.

    "last_generation" is refreshed on every call, so the delay is measured
    from the previous script run.
    """
    current = time.perf_counter()
    last_generation = st.session_state.get("last_generation", current)
    if current > last_generation + SECOND_BEFORE_CLEAR_CACHE:
        print("Removing result cache...")
        # Snapshot the keys first: deleting entries while iterating over a
        # live keys view of the same mapping is unsafe.
        stale_keys = [
            key
            for key in st.session_state.keys()
            if isinstance(key, str) and key.startswith(HASH_PREFIX)
        ]
        for key in stale_keys:
            del st.session_state[key]
    st.session_state["last_generation"] = current


def main() -> None:
    """Streamlit entry point: collect audio, read options, then show descriptions."""
    model = load_conette(model_kwds=dict(device="cpu"))

    st.header("Describe audio content with CoNeTTE")
    st.markdown(
        "This interface allows you to generate a short description of the sound events of any recording using an Audio Captioning system. You can try it from your microphone or upload a file below."
    )
    st.markdown(
        "Use '**Start Recording**' and '**Stop**' to record an audio from your microphone."
    )
    record_data = st_audiorec()

    with st.expander("Or upload audio files here:"):
        audio_files: Optional[list[UploadedFile]] = st.file_uploader(
            f"Audio files are automatically resampled to 32 kHz.\nTheir duration must be in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}] seconds.",
            type=["wav", "flac", "mp3", "ogg", "avi"],
            accept_multiple_files=True,
            help="Supports wav, flac, mp3, ogg and avi files.",
        )

    generate_kwds = _read_generate_kwds(model)

    # Uploaded files first, microphone recording last.
    audios: dict[str, bytes] = {}
    if audio_files is not None:
        audios.update({audio.name: audio.getvalue() for audio in audio_files})
    if record_data is not None:
        audios[RECORD_AUDIO_FNAME] = record_data

    if audios:
        with st.spinner("Generating descriptions..."):
            outputs = get_results(model, audios, generate_kwds)
            st.header("Results:")
            show_results(outputs)

    _clear_expired_result_cache()


if __name__ == "__main__":
    main()
|
|