|
""" |
|
General streamlit diarization application |
|
""" |
|
import os |
|
import shutil |
|
from io import BytesIO |
|
from typing import Dict, Union |
|
from pathlib import Path |
|
|
|
import librosa |
|
import librosa.display |
|
import matplotlib.figure |
|
import numpy as np |
|
import streamlit as st |
|
import streamlit.uploaded_file_manager |
|
from PIL import Image |
|
from pydub import AudioSegment |
|
from matplotlib import pyplot as plt |
|
|
|
import configs |
|
from utils import audio_utils, text_utils, general_utils, streamlit_utils |
|
from diarizers import pyannote_diarizer, nemo_diarizer |
|
|
|
plt.rcParams["figure.figsize"] = (10, 5) |
|
|
|
|
|
def plot_audio_diarization(diarization_figure: Union[plt.gcf, np.array], diarization_name: str, |
|
audio_data: np.array, |
|
sampling_frequency: int): |
|
""" |
|
Function that plots the audio along with the different applied diarization techniques |
|
Args: |
|
diarization_figure (plt.gcf): the diarization figure to plot |
|
diarization_name (str): the name of the diarization technique |
|
audio_data (np.array): the audio numpy array |
|
sampling_frequency (int): the audio sampling frequency |
|
""" |
|
col1, col2 = st.columns([3, 5]) |
|
with col1: |
|
st.markdown( |
|
f"<h4 style='text-align: center; color: black;'>Original</h5>", |
|
unsafe_allow_html=True, |
|
) |
|
st.markdown("<br></br>", unsafe_allow_html=True) |
|
|
|
st.audio(audio_utils.create_st_audio_virtualfile(audio_data, sampling_frequency)) |
|
with col2: |
|
st.markdown( |
|
f"<h4 style='text-align: center; color: black;'>{diarization_name}</h5>", |
|
unsafe_allow_html=True, |
|
) |
|
|
|
if type(diarization_figure) == matplotlib.figure.Figure: |
|
buf = BytesIO() |
|
diarization_figure.savefig(buf, format="png") |
|
st.image(buf) |
|
else: |
|
st.image(diarization_figure) |
|
st.markdown("---") |
|
|
|
|
|
def execute_diarization(file_uploader: st.uploaded_file_manager.UploadedFile, selected_option: any, |
|
sample_option_dict: Dict[str, str], |
|
diarization_checkbox_dict: Dict[str, bool], |
|
session_id: str): |
|
""" |
|
Function that exectutes the diarization based on the specified files and pipelines |
|
Args: |
|
file_uploader (st.uploaded_file_manager.UploadedFile): the uploaded streamlit audio file |
|
selected_option (any): the selected option of samples |
|
Dict[str, str]: a dictionary where the name is the file name (without extension to be listed |
|
as an option for the user) and the value is the original file name |
|
diarization_checkbox_dict (Dict[str, bool]): dictionary where the key is the Diarization |
|
technique name and the value is a boolean indicating whether to apply that technique |
|
session_id (str): unique id of the user session |
|
""" |
|
user_folder = os.path.join(configs.UPLOADED_AUDIO_SAMPLES_DIR, session_id) |
|
Path(user_folder).mkdir(parents=True, exist_ok=True) |
|
|
|
if file_uploader is not None: |
|
file_name = file_uploader.name |
|
file_path = os.path.join(user_folder, file_name) |
|
audio = AudioSegment.from_wav(file_uploader).set_channels(1) |
|
|
|
audio = audio[0:1000 * 30] |
|
audio.export(file_path, format='wav') |
|
else: |
|
file_name = sample_option_dict[selected_option] |
|
file_path = os.path.join(configs.AUDIO_SAMPLES_DIR, file_name) |
|
|
|
audio_data, sampling_frequency = librosa.load(file_path) |
|
|
|
nb_pipelines_to_run = sum(pipeline_bool for pipeline_bool in diarization_checkbox_dict.values()) |
|
pipeline_count = 0 |
|
for diarization_idx, (diarization_name, diarization_bool) in \ |
|
enumerate(diarization_checkbox_dict.items()): |
|
|
|
if diarization_bool: |
|
pipeline_count += 1 |
|
if diarization_name == 'pyannote': |
|
diarizer = pyannote_diarizer.PyannoteDiarizer(file_path) |
|
elif diarization_name == 'NeMo': |
|
diarizer = nemo_diarizer.NemoDiarizer(file_path, user_folder) |
|
else: |
|
raise NotImplementedError('Framework not recognized') |
|
|
|
if file_uploader is not None: |
|
with st.spinner( |
|
f"Executing {pipeline_count}/{nb_pipelines_to_run} diarization pipelines " |
|
f"({diarization_name}). This might take 1-2 minutes..."): |
|
diarizer_figure = diarizer.get_diarization_figure() |
|
else: |
|
diarizer_figure = Image.open(f"{configs.PRECOMPUTED_DIARIZATION_FIGURE}/" |
|
f"{file_name.rsplit('.')[0]}_{diarization_name}.png") |
|
|
|
plot_audio_diarization(diarizer_figure, diarization_name, audio_data, |
|
sampling_frequency) |
|
|
|
shutil.rmtree(user_folder) |
|
|
|
|
|
def main(): |
|
|
|
text_utils.intro_container() |
|
diarization_container = st.container() |
|
|
|
|
|
diarization_container.markdown("---") |
|
|
|
|
|
text_utils.demo_container(diarization_container) |
|
diarization_container.markdown("Choose the Diarization method here:") |
|
|
|
diarization_checkbox_dict = {} |
|
for diarization_method in configs.DIARIZATION_METHODS: |
|
diarization_checkbox_dict[diarization_method] = diarization_container.checkbox( |
|
diarization_method) |
|
|
|
|
|
diarization_container.markdown("(Optional) Upload an audio file here:") |
|
file_uploader = diarization_container.file_uploader( |
|
label="", type=[".wav", ".wave"] |
|
) |
|
|
|
sample_option_dict = general_utils.get_dict_of_audio_samples(configs.AUDIO_SAMPLES_DIR) |
|
diarization_container.markdown("Or select a sample file here:") |
|
selected_option = diarization_container.selectbox( |
|
label="", options=list(sample_option_dict.keys()) |
|
) |
|
diarization_container.markdown("---") |
|
|
|
|
|
if diarization_container.button("Apply"): |
|
session_id = streamlit_utils.get_session() |
|
execute_diarization( |
|
file_uploader=file_uploader, |
|
selected_option=selected_option, |
|
sample_option_dict=sample_option_dict, |
|
diarization_checkbox_dict=diarization_checkbox_dict, |
|
session_id=session_id |
|
) |
|
text_utils.conlusion_container() |
|
|
|
|
|
if __name__ == "__main__": |
|
st.set_page_config(layout="wide", page_title="Audio diarization visualization") |
|
main() |
|
|