hlnicholls committed on
Commit
ef3a227
·
1 Parent(s): 614dc1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -400
app.py CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
5
  import sklearn
6
  import xgboost
7
  import shap
 
8
  import plotly.tools as tls
9
  import dash_core_components as dcc
10
  import matplotlib
@@ -76,406 +77,6 @@ explainer = shap.TreeExplainer(xgb)
76
  def convert_df(df):
77
  return df.to_csv(index=False).encode('utf-8')
78
 
79
- import warnings
80
- import iml
81
- import numpy as np
82
- from iml import Instance, Model
83
- from iml.datatypes import DenseData
84
- from iml.explanations import AdditiveExplanation
85
- from iml.links import IdentityLink
86
- from scipy.stats import gaussian_kde
87
- import matplotlib
88
- try:
89
- import matplotlib.pyplot as pl
90
- from matplotlib.colors import LinearSegmentedColormap
91
- from matplotlib.ticker import MaxNLocator
92
-
93
- cdict1 = {
94
- 'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
95
- (1.0, 0.9607843137254902, 0.9607843137254902)),
96
-
97
- 'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
98
- (1.0, 0.15294117647058825, 0.15294117647058825)),
99
-
100
- 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
101
- (1.0, 0.3411764705882353, 0.3411764705882353)),
102
-
103
- 'alpha': ((0.0, 1, 1),
104
- (0.5, 0.3, 0.3),
105
- (1.0, 1, 1))
106
- } # #1E88E5 -> #ff0052
107
- red_blue = LinearSegmentedColormap('RedBlue', cdict1)
108
-
109
- cdict1 = {
110
- 'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
111
- (1.0, 0.9607843137254902, 0.9607843137254902)),
112
-
113
- 'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
114
- (1.0, 0.15294117647058825, 0.15294117647058825)),
115
-
116
- 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
117
- (1.0, 0.3411764705882353, 0.3411764705882353)),
118
-
119
- 'alpha': ((0.0, 1, 1),
120
- (0.5, 1, 1),
121
- (1.0, 1, 1))
122
- } # #1E88E5 -> #ff0052
123
- red_blue_solid = LinearSegmentedColormap('RedBlue', cdict1)
124
- except ImportError:
125
- pass
126
-
127
- labels = {
128
- 'MAIN_EFFECT': "SHAP main effect value for\n%s",
129
- 'INTERACTION_VALUE': "SHAP interaction value",
130
- 'INTERACTION_EFFECT': "SHAP interaction value for\n%s and %s",
131
- 'VALUE': "SHAP value (impact on model output)",
132
- 'VALUE_FOR': "SHAP value for\n%s",
133
- 'PLOT_FOR': "SHAP plot for %s",
134
- 'FEATURE': "Feature %s",
135
- 'FEATURE_VALUE': "Feature value",
136
- 'FEATURE_VALUE_LOW': "Low",
137
- 'FEATURE_VALUE_HIGH': "High",
138
- 'JOINT_VALUE': "Joint SHAP value"
139
- }
140
-
141
- def shap_summary_plot(shap_values, features=None, feature_names=None, max_display=None, plot_type="dot",
142
- color=None, axis_color="#333333", title=None, alpha=1, show=True, sort=True,
143
- color_bar=True, auto_size_plot=True, layered_violin_max_num_bins=20):
144
- """Create a SHAP summary plot, colored by feature values when they are provided.
145
- Parameters
146
- ----------
147
- shap_values : numpy.array
148
- Matrix of SHAP values (# samples x # features)
149
- features : numpy.array or pandas.DataFrame or list
150
- Matrix of feature values (# samples x # features) or a feature_names list as shorthand
151
- feature_names : list
152
- Names of the features (length # features)
153
- max_display : int
154
- How many top features to include in the plot (default is 20, or 7 for interaction plots)
155
- plot_type : "dot" (default) or "violin"
156
- What type of summary plot to produce
157
- """
158
-
159
- assert len(shap_values.shape) != 1, "Summary plots need a matrix of shap_values, not a vector."
160
-
161
- # default color:
162
- if color is None:
163
- color = "coolwarm" if plot_type == 'layered_violin' else "#ff0052"
164
-
165
- # convert from a DataFrame or other types
166
- if str(type(features)) == "<class 'pandas.core.frame.DataFrame'>":
167
- if feature_names is None:
168
- feature_names = features.columns
169
- features = features.values
170
- elif str(type(features)) == "<class 'list'>":
171
- if feature_names is None:
172
- feature_names = features
173
- features = None
174
- elif (features is not None) and len(features.shape) == 1 and feature_names is None:
175
- feature_names = features
176
- features = None
177
-
178
- if feature_names is None:
179
- feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1] - 1)]
180
-
181
- mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))
182
-
183
- # plotting SHAP interaction values
184
- if len(shap_values.shape) == 3:
185
- if max_display is None:
186
- max_display = 7
187
- else:
188
- max_display = min(len(feature_names), max_display)
189
-
190
- sort_inds = np.argsort(-np.abs(shap_values[:, :-1, :-1].sum(1)).sum(0))
191
-
192
- # get plotting limits
193
- delta = 1.0 / (shap_values.shape[1] ** 2)
194
- slow = np.nanpercentile(shap_values, delta)
195
- shigh = np.nanpercentile(shap_values, 100 - delta)
196
- v = max(abs(slow), abs(shigh))
197
- slow = -0.2
198
- shigh = 0.2
199
-
200
- # mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))
201
- ax = mpl_fig.subplot(1, max_display, 1)
202
- proj_shap_values = shap_values[:, sort_inds[0], np.hstack((sort_inds, len(sort_inds)))]
203
- proj_shap_values[:, 1:] *= 2 # because off diag effects are split in half
204
- shap_summary_plot(
205
- proj_shap_values, features[:, sort_inds],
206
- feature_names=feature_names[sort_inds],
207
- sort=False, show=False, color_bar=False,
208
- auto_size_plot=False,
209
- max_display=max_display
210
- )
211
- pl.xlim((slow, shigh))
212
- pl.xlabel("")
213
- title_length_limit = 11
214
- pl.title(shorten_text(feature_names[sort_inds[0]], title_length_limit))
215
- for i in range(1, max_display):
216
- ind = sort_inds[i]
217
- pl.subplot(1, max_display, i + 1)
218
- proj_shap_values = shap_values[:, ind, np.hstack((sort_inds, len(sort_inds)))]
219
- proj_shap_values *= 2
220
- proj_shap_values[:, i] /= 2 # because only off diag effects are split in half
221
- shap_summary_plot(
222
- proj_shap_values, features[:, sort_inds],
223
- sort=False,
224
- feature_names=df_shap.columns, #["" for i in range(features.shape[1])],
225
- show=False,
226
- color_bar=False,
227
- auto_size_plot=False,
228
- max_display=max_display
229
- )
230
- pl.xlim((slow, shigh))
231
- pl.xlabel("")
232
- if i == max_display // 2:
233
- pl.xlabel(labels['INTERACTION_VALUE'])
234
- pl.title(shorten_text(feature_names[ind], title_length_limit))
235
- pl.tight_layout(pad=0, w_pad=0, h_pad=0.0)
236
- pl.subplots_adjust(hspace=0, wspace=0.1)
237
- # if show:
238
- # # pl.show()
239
- return mpl_fig
240
-
241
- if max_display is None:
242
- max_display = 20
243
-
244
- if sort:
245
- # order features by the sum of their effect magnitudes
246
- feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
247
- feature_order = feature_order[-min(max_display, len(feature_order)):]
248
- else:
249
- feature_order = np.flip(np.arange(min(max_display, shap_values.shape[1] - 1)), 0)
250
-
251
- row_height = 0.4
252
- if auto_size_plot:
253
- pl.gcf().set_size_inches(8, len(feature_order) * row_height + 1.5)
254
- pl.axvline(x=0, color="#999999", zorder=-1)
255
-
256
- if plot_type == "dot":
257
- for pos, i in enumerate(feature_order):
258
- pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
259
- shaps = shap_values[:, i]
260
- values = None if features is None else features[:, i]
261
- inds = np.arange(len(shaps))
262
- np.random.shuffle(inds)
263
- if values is not None:
264
- values = values[inds]
265
- shaps = shaps[inds]
266
- colored_feature = True
267
- try:
268
- values = np.array(values, dtype=np.float64) # make sure this can be numeric
269
- except:
270
- colored_feature = False
271
- N = len(shaps)
272
- # hspacing = (np.max(shaps) - np.min(shaps)) / 200
273
- # curr_bin = []
274
- nbins = 100
275
- quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
276
- inds = np.argsort(quant + np.random.randn(N) * 1e-6)
277
- layer = 0
278
- last_bin = -1
279
- ys = np.zeros(N)
280
- for ind in inds:
281
- if quant[ind] != last_bin:
282
- layer = 0
283
- ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
284
- layer += 1
285
- last_bin = quant[ind]
286
- ys *= 0.9 * (row_height / np.max(ys + 1))
287
-
288
- if features is not None and colored_feature:
289
- # trim the color range, but prevent the color range from collapsing
290
- vmin = np.nanpercentile(values, 5)
291
- vmax = np.nanpercentile(values, 95)
292
- if vmin == vmax:
293
- vmin = np.nanpercentile(values, 1)
294
- vmax = np.nanpercentile(values, 99)
295
- if vmin == vmax:
296
- vmin = np.min(values)
297
- vmax = np.max(values)
298
-
299
- assert features.shape[0] == len(shaps), "Feature and SHAP matrices must have the same number of rows!"
300
- nan_mask = np.isnan(values)
301
- pl.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777", vmin=vmin,
302
- vmax=vmax, s=16, alpha=alpha, linewidth=0,
303
- zorder=3, rasterized=len(shaps) > 500)
304
- pl.scatter(shaps[np.invert(nan_mask)], pos + ys[np.invert(nan_mask)],
305
- cmap=red_blue, vmin=vmin, vmax=vmax, s=16,
306
- c=values[np.invert(nan_mask)], alpha=alpha, linewidth=0,
307
- zorder=3, rasterized=len(shaps) > 500)
308
- else:
309
-
310
- pl.scatter(shaps, pos + ys, s=16, alpha=alpha, linewidth=0, zorder=3,
311
- color=color if colored_feature else "#777777", rasterized=len(shaps) > 500)
312
-
313
- elif plot_type == "violin":
314
- for pos, i in enumerate(feature_order):
315
- pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
316
-
317
- if features is not None:
318
- global_low = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 1)
319
- global_high = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 99)
320
- for pos, i in enumerate(feature_order):
321
- shaps = shap_values[:, i]
322
- shap_min, shap_max = np.min(shaps), np.max(shaps)
323
- rng = shap_max - shap_min
324
- xs = np.linspace(np.min(shaps) - rng * 0.2, np.max(shaps) + rng * 0.2, 100)
325
- if np.std(shaps) < (global_high - global_low) / 100:
326
- ds = gaussian_kde(shaps + np.random.randn(len(shaps)) * (global_high - global_low) / 100)(xs)
327
- else:
328
- ds = gaussian_kde(shaps)(xs)
329
- ds /= np.max(ds) * 3
330
-
331
- values = features[:, i]
332
- window_size = max(10, len(values) // 20)
333
- smooth_values = np.zeros(len(xs) - 1)
334
- sort_inds = np.argsort(shaps)
335
- trailing_pos = 0
336
- leading_pos = 0
337
- running_sum = 0
338
- back_fill = 0
339
- for j in range(len(xs) - 1):
340
-
341
- while leading_pos < len(shaps) and xs[j] >= shaps[sort_inds[leading_pos]]:
342
- running_sum += values[sort_inds[leading_pos]]
343
- leading_pos += 1
344
- if leading_pos - trailing_pos > 20:
345
- running_sum -= values[sort_inds[trailing_pos]]
346
- trailing_pos += 1
347
- if leading_pos - trailing_pos > 0:
348
- smooth_values[j] = running_sum / (leading_pos - trailing_pos)
349
- for k in range(back_fill):
350
- smooth_values[j - k - 1] = smooth_values[j]
351
- else:
352
- back_fill += 1
353
-
354
- vmin = np.nanpercentile(values, 5)
355
- vmax = np.nanpercentile(values, 95)
356
- if vmin == vmax:
357
- vmin = np.nanpercentile(values, 1)
358
- vmax = np.nanpercentile(values, 99)
359
- if vmin == vmax:
360
- vmin = np.min(values)
361
- vmax = np.max(values)
362
- pl.scatter(shaps, np.ones(shap_values.shape[0]) * pos, s=9, cmap=red_blue_solid, vmin=vmin, vmax=vmax,
363
- c=values, alpha=alpha, linewidth=0, zorder=1)
364
- # smooth_values -= nxp.nanpercentile(smooth_values, 5)
365
- # smooth_values /= np.nanpercentile(smooth_values, 95)
366
- smooth_values -= vmin
367
- if vmax - vmin > 0:
368
- smooth_values /= vmax - vmin
369
- for i in range(len(xs) - 1):
370
- if ds[i] > 0.05 or ds[i + 1] > 0.05:
371
- pl.fill_between([xs[i], xs[i + 1]], [pos + ds[i], pos + ds[i + 1]],
372
- [pos - ds[i], pos - ds[i + 1]], color=red_blue_solid(smooth_values[i]),
373
- zorder=2)
374
-
375
- else:
376
- parts = pl.violinplot(shap_values[:, feature_order], range(len(feature_order)), points=200, vert=False,
377
- widths=0.7,
378
- showmeans=False, showextrema=False, showmedians=False)
379
-
380
- for pc in parts['bodies']:
381
- pc.set_facecolor(color)
382
- pc.set_edgecolor('none')
383
- pc.set_alpha(alpha)
384
-
385
- elif plot_type == "layered_violin": # courtesy of @kodonnell
386
- num_x_points = 200
387
- bins = np.linspace(0, features.shape[0], layered_violin_max_num_bins + 1).round(0).astype(
388
- 'int') # the indices of the feature data corresponding to each bin
389
- shap_min, shap_max = np.min(shap_values[:, :-1]), np.max(shap_values[:, :-1])
390
- x_points = np.linspace(shap_min, shap_max, num_x_points)
391
-
392
- # loop through each feature and plot:
393
- for pos, ind in enumerate(feature_order):
394
- # decide how to handle: if #unique < layered_violin_max_num_bins then split by unique value, otherwise use bins/percentiles.
395
- # to keep simpler code, in the case of uniques, we just adjust the bins to align with the unique counts.
396
- feature = features[:, ind]
397
- unique, counts = np.unique(feature, return_counts=True)
398
- if unique.shape[0] <= layered_violin_max_num_bins:
399
- order = np.argsort(unique)
400
- thesebins = np.cumsum(counts[order])
401
- thesebins = np.insert(thesebins, 0, 0)
402
- else:
403
- thesebins = bins
404
- nbins = thesebins.shape[0] - 1
405
- # order the feature data so we can apply percentiling
406
- order = np.argsort(feature)
407
- # x axis is located at y0 = pos, with pos being there for offset
408
- y0 = np.ones(num_x_points) * pos
409
- # calculate kdes:
410
- ys = np.zeros((nbins, num_x_points))
411
- for i in range(nbins):
412
- # get shap values in this bin:
413
- shaps = shap_values[order[thesebins[i]:thesebins[i + 1]], ind]
414
- # if there's only one element, then we can't
415
- if shaps.shape[0] == 1:
416
- warnings.warn(
417
- "not enough data in bin #%d for feature %s, so it'll be ignored. Try increasing the number of records to plot."
418
- % (i, feature_names[ind]))
419
- # to ignore it, just set it to the previous y-values (so the area between them will be zero). Note ys is already 0, so there's
420
- # nothing to do if i == 0
421
- if i > 0:
422
- ys[i, :] = ys[i - 1, :]
423
- continue
424
- # save kde of them: note that we add a tiny bit of gaussian noise to avoid singular matrix errors
425
- ys[i, :] = gaussian_kde(shaps + np.random.normal(loc=0, scale=0.001, size=shaps.shape[0]))(x_points)
426
- # scale it up so that the 'size' of each y represents the size of the bin. For continuous data this will
427
- # do nothing, but when we've gone with the unique option, this will matter - e.g. if 99% are male and 1%
428
- # female, we want the 1% to appear a lot smaller.
429
- size = thesebins[i + 1] - thesebins[i]
430
- bin_size_if_even = features.shape[0] / nbins
431
- relative_bin_size = size / bin_size_if_even
432
- ys[i, :] *= relative_bin_size
433
- # now plot 'em. We don't plot the individual strips, as this can leave whitespace between them.
434
- # instead, we plot the full kde, then remove outer strip and plot over it, etc., to ensure no
435
- # whitespace
436
- ys = np.cumsum(ys, axis=0)
437
- width = 0.8
438
- scale = ys.max() * 2 / width # 2 is here as we plot both sides of x axis
439
- for i in range(nbins - 1, -1, -1):
440
- y = ys[i, :] / scale
441
- c = pl.get_cmap(color)(i / (
442
- nbins - 1)) if color in pl.cm.datad else color # if color is a cmap, use it, otherwise use a color
443
- pl.fill_between(x_points, pos - y, pos + y, facecolor=c)
444
- pl.xlim(shap_min, shap_max)
445
-
446
- # draw the color bar
447
- if color_bar and features is not None and (plot_type != "layered_violin" or color in pl.cm.datad):
448
- import matplotlib.cm as cm
449
- m = cm.ScalarMappable(cmap=red_blue_solid if plot_type != "layered_violin" else pl.get_cmap(color))
450
- m.set_array([0, 1])
451
- cb = pl.colorbar(m, ticks=[0, 1], aspect=1000)
452
- cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
453
- cb.set_label(labels['FEATURE_VALUE'], size=12, labelpad=0)
454
- cb.ax.tick_params(labelsize=11, length=0)
455
- cb.set_alpha(1)
456
- cb.outline.set_visible(False)
457
- bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
458
- cb.ax.set_aspect((bbox.height - 0.9) * 20)
459
- # cb.draw_all()
460
-
461
- pl.gca().xaxis.set_ticks_position('bottom')
462
- pl.gca().yaxis.set_ticks_position('none')
463
- pl.gca().spines['right'].set_visible(False)
464
- pl.gca().spines['top'].set_visible(False)
465
- pl.gca().spines['left'].set_visible(False)
466
- pl.gca().tick_params(color=axis_color, labelcolor=axis_color)
467
- col_order = [df.columns[i] for i in feature_order]
468
- pl.yticks(col_order, fontsize=13)
469
- #pl.yticks(range(len(feature_order)), [feature_names[i] for i in feature_order], fontsize=13)
470
- pl.gca().tick_params('y', length=20, width=0.5, which='major')
471
- pl.gca().tick_params('x', labelsize=11)
472
- pl.ylim(-1, len(feature_order))
473
- pl.xlabel(labels['VALUE'], fontsize=13)
474
- pl.tight_layout()
475
- # if show:
476
- # pl.show()
477
- return mpl_fig
478
-
479
 
480
  cdict1 = {
481
  'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
 
5
  import sklearn
6
  import xgboost
7
  import shap
8
+ from shap_plots import shap_summary_plot
9
  import plotly.tools as tls
10
  import dash_core_components as dcc
11
  import matplotlib
 
77
  def convert_df(df):
78
  return df.to_csv(index=False).encode('utf-8')
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  cdict1 = {
82
  'red': ((0.0, 0.11764705882352941, 0.11764705882352941),