File size: 4,965 Bytes
f6b56a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
from gyraudio.audio_separation.parser import shared_parser
from gyraudio.audio_separation.infer import launch_infer, RECORD_KEYS, SNR_OUT, SNR_IN, NBATCH, SAVE_IDX
from gyraudio.audio_separation.properties import TEST, NAME, SHORT_NAME, CURRENT_EPOCH, SNR_FILTER 
import sys
import os
from dash import Dash, html, dcc, callback, Output, Input, dash_table
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
from typing import List
import torch
from pathlib import Path
# Name of the derived SNR-gain quantity: used both as a radio-button option and
# as the DataFrame column that stores SNR_OUT - SNR_IN in get_app.
DIFF_SNR = 'SNR out - SNR in'



def get_app(record_row_dfs: pd.DataFrame, eval_dfs: List[pd.DataFrame]) -> Dash:
    """Build the Dash application that displays inference SNR results.

    The page holds a radio selector (output SNR or SNR out - SNR in) and a
    two-row figure: a per-sample scatter plot on top and, below it, a box plot
    grouped by rounded input SNR. Each experiment uses one colour for both of
    its traces.

    Args:
        record_row_dfs: One row per experiment. Only the ``SHORT_NAME`` and
            ``NAME`` columns are read, to build the trace names.
        eval_dfs: Per-experiment evaluation tables, in the same order as the
            rows of ``record_row_dfs``. The ``SNR_IN``, ``SNR_OUT`` and
            ``SAVE_IDX`` columns are read from each of them.

    Returns:
        The configured Dash application (not started).
    """
    app = Dash(__name__)
    app.layout = html.Div([
        html.H1(children='Inference results', style={'textAlign': 'center'}),
        dcc.RadioItems([SNR_OUT, DIFF_SNR], DIFF_SNR, inline=True, id='radio-plot-out'),
        dcc.Graph(id='graph-content'),
    ])

    # Register on this app instance rather than through the module-global
    # `callback`, so building several apps in one process does not clash.
    @app.callback(
        Output('graph-content', 'figure'),
        Input('radio-plot-out', 'value'),
    )
    def update_graph(radio_plot_out):
        """Redraw both subplots for the selected y-axis quantity.

        Args:
            radio_plot_out: Column plotted on the y axis, one of the radio
                options (``SNR_OUT`` or ``DIFF_SNR``).

        Returns:
            The plotly figure holding the scatter and box subplots.
        """
        fig = make_subplots(rows=2, cols=1)
        palette = px.colors.qualitative.Plotly
        # Pair rows with eval_dfs by position, not by index label, so a
        # non-default DataFrame index cannot mis-align the two containers.
        for position, (_, record) in enumerate(record_row_dfs.iterrows()):
            color = palette[position % len(palette)]
            # sort_values returns a new frame: the caller's table is not modified.
            eval_df = eval_dfs[position].sort_values(by=SNR_IN)
            eval_df[DIFF_SNR] = eval_df[SNR_OUT] - eval_df[SNR_IN]
            trace_name = f'{record[SHORT_NAME]}_{record[NAME]}'
            fig.add_trace(
                go.Scatter(
                    x=eval_df[SNR_IN],
                    y=eval_df[radio_plot_out],
                    mode="markers",
                    marker={'color': color},
                    name=trace_name,
                    hovertemplate='File : %{text}' + '<br>%{y}<br>',
                    text=[f"{save_idx:.0f}" for save_idx in eval_df[SAVE_IDX]],
                ),
                row=1, col=1
            )
            # Round the input SNR into integer bins for the box plot, kept in
            # a separate series so the per-sample values stay untouched.
            binned_snr_in = eval_df[SNR_IN].apply(lambda snr: round(snr))
            fig.add_trace(
                go.Box(
                    x=binned_snr_in,
                    y=eval_df[radio_plot_out],
                    fillcolor=color,
                    marker={'color': color},
                    name=trace_name,
                ),
                row=2, col=1
            )

        fig.update_layout(
            title="SNR performances",
            xaxis2_title=SNR_IN,
            yaxis_title=radio_plot_out,
            hovermode='x unified',
        )
        return fig

    return app

    
def main(argv):
    """Run inference for each requested experiment and serve the results.

    Parses the command line, calls ``launch_infer`` once per experiment, loads
    the evaluation CSV it reports, then starts the Dash app built by
    ``get_app`` in debug mode.

    Args:
        argv: Command-line arguments without the program name. The
            ``experiments`` attribute read below is not declared here;
            presumably ``shared_parser`` defines it -- confirm there.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser_def = shared_parser(help="Launch inference \nCheck results at: https://wandb.ai/balthazarneveu/audio-sep"
                               + ("\n<<<Cuda available>>>" if default_device == "cuda" else ""))
    parser_def.add_argument("-i", "--input-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
    parser_def.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
    parser_def.add_argument("-d", "--device", type=str, default=default_device,
                            help="Inference device", choices=["cpu", "cuda"])
    parser_def.add_argument("-b", "--nb-batch", type=int, default=None,
                            help="Number of batches to process")
    parser_def.add_argument("-s", "--snr-filter", type=float, nargs="+", default=None,
                            help="SNR filters on the inference dataloader")
    args = parser_def.parse_args(argv)

    # Both lists are filled in lockstep: get_app matches row i of the records
    # with eval_dfs[i], so the order must stay identical.
    record_rows = []
    eval_dfs = []
    for exp in args.experiments:
        record_row_df, evaluation_path = launch_infer(
            exp,
            model_dir=Path(args.input_dir),
            output_dir=Path(args.output_dir),
            device=args.device,
            max_batches=args.nb_batch,
            snr_filter=args.snr_filter
        )
        record_rows.append(record_row_df)
        eval_dfs.append(pd.read_csv(evaluation_path))

    # Concatenate once instead of growing a DataFrame inside the loop; this
    # also avoids concatenating with an empty seed frame.
    if record_rows:
        record_row_dfs = pd.concat(record_rows, ignore_index=True)
    else:
        record_row_dfs = pd.DataFrame(columns=RECORD_KEYS)

    app = get_app(record_row_dfs, eval_dfs)
    app.run(debug=True)


# Script entry point: only runs when the file is executed directly.
if __name__ == '__main__':
    # NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
    # abort -- confirm it is still needed, and that setting it after the
    # top-of-file imports is early enough.
    os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
    # Forward the command-line arguments without the program name.
    main(sys.argv[1:])