hlnicholls commited on
Commit
d22276b
·
1 Parent(s): 1ededbc

Upload dynamic_shap_plot.py

Browse files
Files changed (1) hide show
  1. dynamic_shap_plot.py +115 -0
dynamic_shap_plot.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from shap_plots import shap_summary_plot, shap_dependence_plot
2
+ import plotly.tools as tls
3
+ import dash_core_components as dcc
4
+ import pandas as pd
5
+ import numpy as np
6
+ import xgboost
7
+ import shap
8
+ import matplotlib
9
+ import plotly.graph_objs as go
10
+ try:
11
+ import matplotlib.pyplot as pl
12
+ from matplotlib.colors import LinearSegmentedColormap
13
+ from matplotlib.ticker import MaxNLocator
14
+ except ImportError:
15
+ pass
16
+ from sklearn import preprocessing
17
+
18
+ cdict1 = {
19
+ 'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
20
+ (1.0, 0.9607843137254902, 0.9607843137254902)),
21
+
22
+ 'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
23
+ (1.0, 0.15294117647058825, 0.15294117647058825)),
24
+
25
+ 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
26
+ (1.0, 0.3411764705882353, 0.3411764705882353)),
27
+
28
+ 'alpha': ((0.0, 1, 1),
29
+ (0.5, 1, 1),
30
+ (1.0, 1, 1))
31
+ } # #1E88E5 -> #ff0052
32
+ red_blue = LinearSegmentedColormap('RedBlue', cdict1)
33
+
34
+ def matplotlib_to_plotly(cmap, pl_entries):
35
+ h = 1.0/(pl_entries-1)
36
+ pl_colorscale = []
37
+
38
+ for k in range(pl_entries):
39
+ C = list(map(np.uint8, np.array(cmap(k*h)[:3])*255))
40
+ pl_colorscale.append([k*h, 'rgb'+str((C[0], C[1], C[2]))])
41
+
42
+ return pl_colorscale
43
+
44
+ red_blue = matplotlib_to_plotly(red_blue, 255)
45
+
46
+ def summary_plot_plotly_fig(shap_values, df_shap, feature_names, max_display = 8):
47
+ #data = pd.read_csv(dataset, encoding="ISO-8859-1")
48
+ #X = data.drop(['target column'], axis=1)
49
+
50
+ #y = data[target]
51
+ #y = y/max(y)
52
+
53
+ #X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
54
+
55
+ #X_train.fillna((-999), inplace=True)
56
+ #X_test.fillna((-999), inplace=True)
57
+
58
+ #_, shap_values, feature_names = train_model_and_return_shap_values(X, y, target)
59
+
60
+ mpl_fig = shap_summary_plot(shap_values, df_shap, feature_names=feature_names, max_display=20)
61
+
62
+ plotly_fig = tls.mpl_to_plotly(mpl_fig)
63
+
64
+ plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}}
65
+
66
+ feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
67
+ feature_order = feature_order[-min(max_display, len(feature_order)):]
68
+ text = [feature_names[i] for i in feature_order]
69
+ text = iter(text)
70
+
71
+ for i in range(1, len(plotly_fig['data']), 2):
72
+ t = text.__next__()
73
+ plotly_fig['data'][i]['name'] = ''
74
+ plotly_fig['data'][i]['text'] = t
75
+ plotly_fig['data'][i]['hoverinfo'] = 'text'
76
+
77
+ colorbar_trace = go.Scatter(x=[None],
78
+ y=[None],
79
+ mode='markers',
80
+ marker=dict(
81
+ colorscale=red_blue,
82
+ showscale=True,
83
+ cmin=-5,
84
+ cmax=5,
85
+ colorbar=dict(thickness=5, tickvals=[-5, 5], ticktext=['Low', 'High'], outlinewidth=0)
86
+ ),
87
+ hoverinfo='none'
88
+ )
89
+
90
+ plotly_fig['layout']['showlegend'] = False
91
+ plotly_fig['layout']['hovermode'] = 'closest'
92
+ plotly_fig['layout']['height']=600
93
+ plotly_fig['layout']['width']=500
94
+
95
+ plotly_fig['layout']['xaxis'].update(zeroline=True, showline=True, ticklen=4, showgrid=False)
96
+ plotly_fig['layout']['yaxis'].update(dict(visible=False))
97
+ plotly_fig.add_trace(colorbar_trace)
98
+ plotly_fig.layout.update(
99
+ annotations=[dict(
100
+ x=1.18,
101
+ align="right",
102
+ valign="top",
103
+ text='Feature value',
104
+ showarrow=False,
105
+ xref="paper",
106
+ yref="paper",
107
+ xanchor="right",
108
+ yanchor="middle",
109
+ textangle=-90,
110
+ font=dict(family='Calibri', size=14)
111
+ )
112
+ ],
113
+ margin=dict(t=20)
114
+ )
115
+ return plotly_fig