from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
from gyraudio.audio_separation.parser import shared_parser
from gyraudio.audio_separation.infer import launch_infer, RECORD_KEYS, SNR_OUT, SNR_IN, NBATCH, SAVE_IDX
from gyraudio.audio_separation.properties import TEST, NAME, SHORT_NAME, CURRENT_EPOCH, SNR_FILTER
import sys
import os
from dash import Dash, html, dcc, callback, Output, Input, dash_table
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
from typing import List
import torch
from pathlib import Path
# Radio-button label, and the name of the derived column holding SNR_OUT - SNR_IN.
DIFF_SNR = "SNR out - SNR in"
def get_app(record_row_dfs: pd.DataFrame, eval_dfs: List[pd.DataFrame]):
    """Build the Dash app comparing per-file SNR results across experiments.

    Args:
        record_row_dfs: One row per experiment; the ``SHORT_NAME`` and ``NAME``
            columns are used to build the legend. Row ``k`` (by position)
            describes the experiment whose evaluation table is ``eval_dfs[k]``.
        eval_dfs: Per-file evaluation tables, one per experiment, each with at
            least the ``SNR_IN``, ``SNR_OUT`` and ``SAVE_IDX`` columns.

    Returns:
        A configured ``Dash`` application (not started).
    """
    app = Dash(__name__)
    app.layout = html.Div([
        html.H1(children='Inference results', style={'textAlign': 'center'}),
        # Chooses the metric on the y axis: raw output SNR or the improvement over the input SNR.
        dcc.RadioItems([SNR_OUT, DIFF_SNR], DIFF_SNR, inline=True, id='radio-plot-out'),
        dcc.Graph(id='graph-content'),
    ])

    # Register on this app instance rather than in dash's global callback
    # registry, so building more than one app in a process does not declare
    # the same output twice.
    @app.callback(
        Output('graph-content', 'figure'),
        Input('radio-plot-out', 'value'),
    )
    def update_graph(radio_plot_out):
        """Draw a per-file scatter (top) and an integer-binned box plot (bottom).

        Args:
            radio_plot_out: Column plotted on the y axis (``SNR_OUT`` or ``DIFF_SNR``).

        Returns:
            The plotly figure for the ``graph-content`` component.
        """
        fig = make_subplots(rows=2, cols=1)
        colors = px.colors.qualitative.Plotly
        # Pair records with evaluations by *position*: eval_dfs is a plain list,
        # so using the DataFrame index would break if it were not 0..n-1.
        for position, (_, record) in enumerate(record_row_dfs.iterrows()):
            color = colors[position % len(colors)]
            # sort_values returns a new frame, so the caller's table is never mutated.
            eval_df = eval_dfs[position].sort_values(by=SNR_IN)
            eval_df[DIFF_SNR] = eval_df[SNR_OUT] - eval_df[SNR_IN]
            legend = f'{record[SHORT_NAME]}_{record[NAME]}'
            fig.add_trace(
                go.Scatter(
                    x=eval_df[SNR_IN],
                    y=eval_df[radio_plot_out],
                    mode="markers",
                    marker={'color': color},
                    name=legend,
                    # Plotly hover labels use HTML line breaks.
                    hovertemplate='File : %{text}<br>%{y}<br>',
                    text=[f"{save_idx:.0f}" for save_idx in eval_df[SAVE_IDX]],
                ),
                row=1, col=1,
            )
            # Round the input SNR to integers so each box aggregates one bin.
            # Use an explicit copy so the rounding does not leak into eval_df.
            eval_df_bins = eval_df.copy()
            eval_df_bins[SNR_IN] = eval_df_bins[SNR_IN].apply(round)
            fig.add_trace(
                go.Box(
                    x=eval_df_bins[SNR_IN],
                    y=eval_df_bins[radio_plot_out],
                    fillcolor=color,
                    marker={'color': color},
                    name=legend,
                ),
                row=2, col=1,
            )
        fig.update_layout(
            title="SNR performances",
            xaxis2_title=SNR_IN,
            yaxis_title=radio_plot_out,
            hovermode='x unified',
            # Put boxes from different experiments side by side within a bin
            # instead of drawing them on top of each other (plotly's default).
            boxmode='group',
        )
        return fig

    return app
def main(argv):
    """Run inference for each requested experiment, then serve the comparison dashboard.

    Args:
        argv: Command-line arguments without the program name. They are parsed
            by ``shared_parser`` extended with the inference options below.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser_def = shared_parser(help="Launch training \nCheck results at: https://wandb.ai/balthazarneveu/audio-sep"
                               + ("\n<<>>" if default_device == "cuda" else ""))
    parser_def.add_argument("-i", "--input-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
    parser_def.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
    parser_def.add_argument("-d", "--device", type=str, default=default_device,
                            help="Training device", choices=["cpu", "cuda"])
    parser_def.add_argument("-b", "--nb-batch", type=int, default=None,
                            help="Number of batches to process")
    parser_def.add_argument("-s", "--snr-filter", type=float, nargs="+", default=None,
                            help="SNR filters on the inference dataloader")
    args = parser_def.parse_args(argv)

    record_rows = []
    eval_dfs = []
    for exp in args.experiments:
        record_row_df, evaluation_path = launch_infer(
            exp,
            model_dir=Path(args.input_dir),
            output_dir=Path(args.output_dir),
            device=args.device,
            max_batches=args.nb_batch,
            snr_filter=args.snr_filter
        )
        # Both lists grow in lockstep: get_app pairs eval_dfs[k] with record row k.
        record_rows.append(record_row_df)
        eval_dfs.append(pd.read_csv(evaluation_path))

    # Concatenate once (instead of growing a frame inside the loop). This also
    # avoids pandas' FutureWarning about concatenating with an empty frame.
    record_row_dfs = pd.concat(record_rows, ignore_index=True) if record_rows else pd.DataFrame()
    # Always expose the RECORD_KEYS columns first, followed by any extra columns.
    record_row_dfs = record_row_dfs.reindex(
        columns=list(dict.fromkeys([*RECORD_KEYS, *record_row_dfs.columns]))
    )

    app = get_app(record_row_dfs, eval_dfs)
    # Keep the debug UI but disable the Werkzeug reloader: it re-executes this
    # script in a child process, which would run every inference a second time.
    app.run(debug=True, use_reloader=False)
if __name__ == "__main__":
    # NOTE(review): presumably works around the Intel OpenMP "runtime already
    # initialized" abort when several libraries bundle their own copy - confirm.
    os.environ.update(KMP_DUPLICATE_LIB_OK="True")
    cli_args = sys.argv[1:]
    main(cli_args)