File size: 4,420 Bytes
d22276b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303f06b
d22276b
 
 
 
 
303f06b
d22276b
303f06b
 
ca9fc66
d22276b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca9fc66
d22276b
 
 
 
 
 
303f06b
d22276b
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from shap_plots import shap_summary_plot, shap_dependence_plot
import plotly.tools as tls
import dash_core_components as dcc
import pandas as pd
import numpy as np
import xgboost
import shap
import matplotlib
import plotly.graph_objs as go
try:
    import matplotlib.pyplot as pl
    from matplotlib.colors import LinearSegmentedColormap
    from matplotlib.ticker import MaxNLocator
except ImportError:
    pass
from sklearn import preprocessing 

# Endpoint colours of a two-stop linear colormap, as (red, green, blue)
# components in [0, 1].
#   low end  (position 0.0): rgb(30, 136, 229)  == #1E88E5 (blue)
#   high end (position 1.0): rgb(245, 39, 87)   == #F52757 (red)
_LOW_RGB = (0.11764705882352941, 0.5333333333333333, 0.8980392156862745)
_HIGH_RGB = (0.9607843137254902, 0.15294117647058825, 0.3411764705882353)

# Segment data for LinearSegmentedColormap: each channel maps to
# ((x, y_below, y_above), ...) anchor rows; colours interpolate linearly
# between the two endpoints.
cdict1 = {
    channel: ((0.0, low, low), (1.0, high, high))
    for channel, low, high in zip(('red', 'green', 'blue'), _LOW_RGB, _HIGH_RGB)
}
# Fully opaque everywhere.
cdict1['alpha'] = ((0.0, 1, 1),
                   (0.5, 1, 1),
                   (1.0, 1, 1))
red_blue = LinearSegmentedColormap('RedBlue', cdict1)

def matplotlib_to_plotly(cmap, pl_entries):
    """Sample a matplotlib-style colormap into a Plotly colorscale.

    Args:
        cmap: Callable mapping a float in [0, 1] to an RGBA sequence whose
            components are in [0, 1] (e.g. a matplotlib ``Colormap``).
            Only the RGB part is used; alpha is dropped.
        pl_entries: Number of evenly spaced stops to sample. Must be >= 2 so
            the scale spans both 0 and 1.

    Returns:
        List of ``[position, 'rgb(r, g, b)']`` pairs. Positions run from
        exactly 0.0 to exactly 1.0, as Plotly's ``colorscale`` expects.

    Raises:
        ValueError: If ``pl_entries`` is less than 2.
    """
    if pl_entries < 2:
        raise ValueError(
            'pl_entries must be at least 2 to span [0, 1], got {}'.format(pl_entries))

    last = pl_entries - 1
    pl_colorscale = []
    for k in range(pl_entries):
        # k / last (rather than k * (1 / last)) makes the final stop exactly 1.0.
        position = k / last
        # Truncate each channel to a 0-255 int, as the former np.uint8 cast did,
        # and clamp for safety. Plain ints matter: with NumPy >= 2,
        # repr(np.uint8(30)) is 'np.uint8(30)', which used to leak into the
        # 'rgb(...)' string and produce an invalid colour.
        r, g, b = (min(255, max(0, int(channel * 255)))
                   for channel in cmap(position)[:3])
        pl_colorscale.append([position, 'rgb({}, {}, {})'.format(r, g, b)])
    return pl_colorscale

# Rebind ``red_blue`` from the matplotlib colormap to its 255-stop Plotly
# colorscale (a list of [position, 'rgb(r, g, b)'] pairs); this list is what
# the colour-bar trace in summary_plot_plotly_fig passes as ``colorscale``.
red_blue = matplotlib_to_plotly(cmap=red_blue, pl_entries=255)

def summary_plot_plotly_fig(shap_values, df_shap, feature_names, max_display=8,
                            colorbar_label='Gene'):
    """Build an interactive Plotly version of a SHAP summary plot.

    Steps:
    1. Draw the figure with the project's matplotlib-based
       ``shap_summary_plot``.
    2. Convert it with ``plotly.tools.mpl_to_plotly``.
    3. Restyle the result: x-axis title, hover text, feature-name y values,
       a Low/High colour bar and a vertical colour-bar label.

    Args:
        shap_values: 2-D array of SHAP values; column sums of ``|shap_values|``
            are used to rank features. NOTE(review): the last column is
            excluded from the ranking (see below) -- presumably a bias /
            expected-value column; confirm against ``shap_summary_plot``.
        df_shap: DataFrame whose index labels are used as hover text.
        feature_names: Feature names, indexable by column position (list,
            tuple, numpy array or pandas Index).
        max_display: Maximum number of features to draw and label. It is now
            passed through to ``shap_summary_plot`` (previously hard-coded to
            20, which disagreed with the labels computed here).
        colorbar_label: Text shown vertically beside the colour bar.

    Returns:
        The converted and restyled Plotly figure.
    """
    mpl_fig = shap_summary_plot(shap_values, df_shap, feature_names=feature_names,
                                max_display=max_display)
    plotly_fig = tls.mpl_to_plotly(mpl_fig)

    # Discard the converter's layout and start from a minimal one.
    plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}}

    # Rank features by total |SHAP| in ascending order and keep the top
    # max_display, so the most important feature comes last.
    # NOTE(review): [:-1] drops the last column -- confirm it is not a feature.
    feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
    n_shown = min(max_display, len(feature_order))
    # Slice from the front so that n_shown == 0 yields nothing (a [-0:] slice
    # would have kept every feature).
    feature_order = feature_order[len(feature_order) - n_shown:]
    # asarray lets plain lists be fancy-indexed; computed once, outside the loop.
    y_labels = np.asarray(feature_names)[feature_order]

    # Use the index labels directly. The previous df_shap.index[label] lookup
    # only worked when the index was a 0..n-1 RangeIndex.
    hover_labels = iter(df_shap.index)

    # Restyle every second trace (odd positions). Presumably these are the
    # coloured per-feature scatter traces produced by the conversion -- confirm.
    for trace_idx in range(1, len(plotly_fig['data']), 2):
        trace = plotly_fig['data'][trace_idx]
        trace['name'] = ''
        trace['text'] = next(hover_labels)
        trace['hoverinfo'] = 'text'
        trace['y'] = y_labels

    # Invisible single-point trace whose only purpose is to display a
    # Low/High colour bar using the module-level Plotly colorscale.
    colorbar_trace = go.Scatter(
        x=[None],
        y=[None],
        mode='markers',
        marker=dict(
            colorscale=red_blue,
            showscale=True,
            cmin=-5,
            cmax=5,
            colorbar=dict(thickness=5, tickvals=[-5, 5], ticktext=['Low', 'High'],
                          outlinewidth=0),
        ),
        hoverinfo='none',
    )

    plotly_fig['layout']['showlegend'] = False
    plotly_fig['layout']['hovermode'] = 'closest'
    plotly_fig['layout']['height'] = 600
    plotly_fig['layout']['width'] = 500

    plotly_fig['layout']['xaxis'].update(zeroline=True, showline=True, ticklen=4,
                                         showgrid=False)
    plotly_fig['layout']['yaxis'].update(dict(visible=True))
    plotly_fig.add_trace(colorbar_trace)
    plotly_fig.layout.update(
        annotations=[dict(
            # Positioned in paper coordinates just right of the plot area,
            # rotated to run alongside the colour bar.
            x=1.18,
            align="right",
            valign="top",
            text=colorbar_label,
            showarrow=False,
            xref="paper",
            yref="paper",
            xanchor="right",
            yanchor="middle",
            textangle=-90,
            font=dict(family='Calibri', size=14),
        )],
        margin=dict(t=20),
    )
    return plotly_fig