# "Last commit not found" -- banner text captured from the web file viewer; not part of the program.
import numpy as np | |
import streamlit as st | |
import librosa | |
import soundfile as sf | |
import librosa.display | |
from config import CONFIG | |
import torch | |
from dataset import MaskGenerator | |
import onnxruntime, onnx | |
import matplotlib.pyplot as plt | |
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas | |
from pystoi import stoi | |
from pesq import pesq | |
import pandas as pd | |
import torchaudio | |
from torchmetrics.audio import ShortTimeObjectiveIntelligibility as STOI | |
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality as PESQ | |
from PLCMOS.plc_mos import PLCMOSEstimator | |
from speechmos import dnsmos | |
from speechmos import plcmos | |
def load_model(path='lightning_logs/version_0/checkpoints/frn.onnx'):
    """Load the ONNX model and open an ONNX Runtime session for it.

    Args:
        path: Location of the ``.onnx`` file. Defaults to the checkpoint
            path the app previously hard-coded, so zero-argument calls
            behave exactly as before.

    Returns:
        A 4-tuple ``(session, onnx_model, input_names, output_names)``:
        the ``onnxruntime.InferenceSession``, the model object returned by
        ``onnx.load``, and the names of the session's inputs and outputs
        in the order the session reports them.
    """
    # The parsed model is returned alongside the session because the caller
    # reads the declared input shapes from ``onnx_model.graph.input``.
    onnx_model = onnx.load(path)

    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 2
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = onnxruntime.InferenceSession(path, options)

    input_names = [node.name for node in session.get_inputs()]
    output_names = [node.name for node in session.get_outputs()]
    return session, onnx_model, input_names, output_names
def inference(re_im, session, onnx_model, input_names, output_names):
    """Run the ONNX session once per frame and resynthesise a waveform.

    Every graph input starts as a float32 zero array whose shape is taken
    from the dimensions declared in ``onnx_model``. For each entry along the
    first axis of ``re_im`` the first input is replaced by that entry, the
    session is run, and the last three of its four outputs are fed back as
    inputs 1-3 of the following step.

    The inverse STFT uses the module-level ``window``, ``stride`` and
    ``hann`` values, which must be defined before this function is called.

    Args:
        re_im: Array of per-frame model inputs, iterated along axis 0.
        session: Object exposing ``run(output_names, inputs)`` that returns
            four values per call.
        onnx_model: Model whose ``graph.input`` entries declare input shapes.
        input_names: Input names; indices 0-3 are used inside the loop.
        output_names: Output names passed to ``session.run``.

    Returns:
        The NumPy array produced by ``torch.istft`` on the collected frames.
    """
    # Zero-initialise every declared graph input (including recurrent state).
    feed = {}
    for idx, graph_input in enumerate(onnx_model.graph.input):
        name = input_names[idx]
        dims = [dim.dim_value for dim in graph_input.type.tensor_type.shape.dim]
        feed[name] = np.zeros(dims, dtype=np.float32)

    frames = []
    for frame in re_im:
        feed[input_names[0]] = frame
        spec, magnitude, predictor, mlp = session.run(output_names, feed)
        # Carry the returned state over to the next step.
        feed[input_names[1]] = magnitude
        feed[input_names[2]] = predictor
        feed[input_names[3]] = mlp
        frames.append(spec)

    stacked = torch.tensor(np.concatenate(frames, 0))
    # Swap the first two axes, then reinterpret the trailing pair as complex.
    spectrum = torch.view_as_complex(stacked.permute(1, 0, 2).contiguous())
    waveform = torch.istft(spectrum, window, stride, window=hann)
    return waveform.numpy()
def visualize(hr, lr, recon, sr):
    """Draw dB spectrograms of three signals in vertically stacked subplots.

    Each signal is transformed with a 1024-point Hann-windowed STFT
    (hop 512); the magnitude is scaled by ``2 / sum(window)`` and converted
    with ``librosa.amplitude_to_db`` before being drawn on a log-frequency
    axis. The three subplots share both axes.

    Args:
        hr: Signal shown in the top subplot.
        lr: Signal shown in the middle subplot.
        recon: Signal shown in the bottom subplot.
        sr: Sampling rate forwarded to ``librosa.display.specshow``.

    Returns:
        The matplotlib figure holding the three subplots.
    """
    n_fft = 1024
    taper = np.hanning(n_fft)

    def scaled_magnitude(signal):
        # STFT magnitude scaled by 2 / sum(window).
        spectrum = librosa.core.spectrum.stft(signal, n_fft=n_fft, hop_length=512, window=taper)
        return 2 * np.abs(spectrum) / np.sum(taper)

    magnitudes = [scaled_magnitude(hr), scaled_magnitude(lr), scaled_magnitude(recon)]

    fig, axes = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 12))
    captions = ('Оригинальный сигнал', 'Сигнал с потерями', 'Улучшенный сигнал')
    for axis, caption in zip(axes, captions):
        axis.title.set_text(caption)

    # Attach an Agg canvas to the figure (the canvas object itself is unused).
    FigureCanvas(fig)

    for axis, magnitude in zip(axes, magnitudes):
        librosa.display.specshow(librosa.amplitude_to_db(magnitude), ax=axis, y_axis='log', x_axis='time', sr=sr)

    # Labels are set after plotting so they replace the ones specshow applies.
    for axis in axes:
        axis.set_xlabel('Время, с')
        axis.set_ylabel('Частота, Гц')
    return fig
# ---------------------------------------------------------------------------
# Streamlit page: upload audio, simulate packet loss, run the concealment
# model, then show spectrograms, audio players and quality metrics.
# ---------------------------------------------------------------------------
packet_size = CONFIG.DATA.EVAL.packet_size
window = CONFIG.DATA.window_size
stride = CONFIG.DATA.stride

# The uploader asks for 48 kHz audio, so the input is loaded at this rate
# explicitly. Without ``sr=`` librosa.load silently resamples to 22050 Hz,
# which made the later "resample from 48000" steps operate on mislabelled
# audio.
SAMPLE_RATE = 48000


def _resample_all(signals, orig_sr, target_sr):
    """Return ``signals`` resampled to ``target_sr`` (unchanged if already there)."""
    if orig_sr == target_sr:
        return list(signals)
    return [librosa.resample(sig, orig_sr=orig_sr, target_sr=target_sr) for sig in signals]


title = 'Сокрытие потерь пакетов'
st.set_page_config(page_title=title, page_icon=":sound:")
st.title(title)

# --- 1. Audio upload -------------------------------------------------------
st.subheader('1. Загрузка аудио')
uploaded_file = st.file_uploader("Загрузите аудио формата (.wav) 48 КГц")
is_file_uploaded = uploaded_file is not None
if not is_file_uploaded:
    # Fall back to the bundled sample when nothing has been uploaded.
    uploaded_file = 'sample.wav'

target, sr = librosa.load(uploaded_file, sr=SAMPLE_RATE)
# Trim to a whole number of packets so the signal reshapes cleanly below.
target = target[:packet_size * (len(target) // packet_size)]
st.text('Ваше аудио')
st.audio(uploaded_file)

# --- 2. Loss simulation ----------------------------------------------------
st.subheader('2. Выберите желаемый процент потерь')
loss_percent = st.slider("Ожидаемый процент потерь для генератора потерь цепи Маркова", 0, 100, step=1) / 100
# NOTE(review): MaskGenerator is defined in dataset.py; the meaning of
# ``probs`` is inferred from the slider label -- confirm against that module.
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])

# One row per packet; the mask gets a trailing axis so it broadcasts per row.
lossy_input = target.copy().reshape(-1, packet_size)
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
lossy_input *= mask
lossy_input = lossy_input.reshape(-1)

# Square-root Hann window, shared by the forward STFT here and the inverse
# STFT inside inference().
hann = torch.sqrt(torch.hann_window(window))
lossy_input_tensor = torch.tensor(lossy_input)
# ``return_complex=False`` is deprecated in torch.stft; view_as_real on the
# complex result yields the same trailing (real, imag) layout.
lossy_spectrum = torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=True)
re_im = torch.view_as_real(lossy_spectrum).permute(1, 0, 2).unsqueeze(1).numpy().astype(np.float32)

session, onnx_model, input_names, output_names = load_model()

# NOTE(review): the captured source lost its indentation. Everything below
# needs ``output`` or the .wav files written after inference, so it is
# assumed to belong to the button branch -- confirm against the original.
if st.button('Сгенерировать потери'):
    with st.spinner('Ожидайте...'):
        output = inference(re_im, session, onnx_model, input_names, output_names)

    # --- 3. Visualisation --------------------------------------------------
    st.subheader('3. Визуализация')
    fig = visualize(target, lossy_input, output, sr)
    st.pyplot(fig)
    # Release the figure; pyplot otherwise keeps one more open per rerun.
    plt.close(fig)
    st.success('Сделано!')

    sf.write('target.wav', target, sr)
    sf.write('lossy.wav', lossy_input, sr)
    sf.write('enhanced.wav', output, sr)
    st.text('Оригинальное аудио')
    st.audio('target.wav')
    st.text('Аудио с потерями')
    st.audio('lossy.wav')
    st.text('Улучшенное аудио')
    st.audio('enhanced.wav')

    # --- Metrics -----------------------------------------------------------
    # Read each file once and keep the untrimmed signals for PLCMOS v1, which
    # the original code computed on a fresh, untrimmed read.
    full_clean, samplerate = sf.read('target.wav')
    full_lossy, _ = sf.read('lossy.wav')
    full_enhanced, _ = sf.read('enhanced.wav')

    # STOI and PESQ compare signals sample by sample: trim to a common length.
    min_len = min(full_clean.shape[0], full_lossy.shape[0], full_enhanced.shape[0])
    data_clean = full_clean[:min_len]
    data_lossy = full_lossy[:min_len]
    data_enhanced = full_enhanced[:min_len]

    stoi_mass = [
        round(stoi(data_clean, degraded, samplerate, extended=False), 5)
        for degraded in (data_clean, data_lossy, data_enhanced)
    ]

    # PESQ is run in narrow-band mode at 8 kHz; resample from the rate that
    # was actually read from disk rather than an assumed 48 kHz.
    nb_clean, nb_lossy, nb_enhanced = _resample_all(
        (data_clean, data_lossy, data_enhanced), samplerate, 8000)
    psq_mas = [
        pesq(fs=8000, ref=nb_clean, deg=degraded, mode='nb')
        for degraded in (nb_clean, nb_lossy, nb_enhanced)
    ]

    # PLCMOS v1 is fed 16 kHz audio together with the clean reference.
    wb_clean, wb_lossy, wb_enhanced = _resample_all(
        (full_clean, full_lossy, full_enhanced), samplerate, 16000)
    PLC_example = PLCMOSEstimator()
    PLC_massv1 = [
        PLC_example.run(audio_degraded=degraded, audio_clean=wb_clean)[0]
        for degraded in (wb_clean, wb_lossy, wb_enhanced)
    ]

    # speechmos metrics are computed directly from the written files.
    wav_files = ("target.wav", "lossy.wav", "enhanced.wav")
    PLC_massv2 = [plcmos.run(wav, sr=16000)['plcmos'] for wav in wav_files]
    DNS = [dnsmos.run(wav, sr=16000)['ovrl_mos'] for wav in wav_files]

    df_1 = pd.DataFrame({
        'Audio': ['Clean', 'Lossy', 'Enhanced'],
        'PESQ': psq_mas,
        'STOI': stoi_mass,
        'PLCMOSv1': PLC_massv1,
        'PLCMOSv2': PLC_massv2,
        'DNSMOS': DNS,
    })

    st.checkbox("Use container width", value=False, key="use_container_width")
    st.dataframe(df_1, use_container_width=st.session_state.use_container_width, hide_index=True)