hlnicholls committed on
Commit
20aea5e
·
1 Parent(s): 9cea6b7

Upload shap_plots.py

Browse files
Files changed (1) hide show
  1. shap_plots.py +730 -0
shap_plots.py ADDED
@@ -0,0 +1,730 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ import iml
3
+ import numpy as np
4
+ from iml import Instance, Model
5
+ from iml.datatypes import DenseData
6
+ from iml.explanations import AdditiveExplanation
7
+ from iml.links import IdentityLink
8
+ from scipy.stats import gaussian_kde
9
+ import matplotlib
10
try:
    import matplotlib.pyplot as pl
    from matplotlib.colors import LinearSegmentedColormap
    from matplotlib.ticker import MaxNLocator

    # Segment data for the blue -> red ramp used by every plot in this module.
    # The RGB end points are #1E88E5 (value 0.0) and #F52757 (value 1.0); the
    # alpha channel dips to 0.3 in the middle of the range.
    cdict1 = {
        'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
                (1.0, 0.9607843137254902, 0.9607843137254902)),

        'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
                  (1.0, 0.15294117647058825, 0.15294117647058825)),

        'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
                 (1.0, 0.3411764705882353, 0.3411764705882353)),

        'alpha': ((0.0, 1, 1),
                  (0.5, 0.3, 0.3),
                  (1.0, 1, 1)),
    }
    red_blue = LinearSegmentedColormap('RedBlue', cdict1)

    # Same RGB ramp, but fully opaque over the whole range. A new dict is built
    # so the segment data already handed to ``red_blue`` is left untouched.
    cdict1 = dict(cdict1, alpha=((0.0, 1, 1),
                                 (0.5, 1, 1),
                                 (1.0, 1, 1)))
    red_blue_solid = LinearSegmentedColormap('RedBlue', cdict1)
except ImportError:
    pass

# Axis / colour-bar captions shared by the plotting functions below. Entries
# containing ``%s`` are filled in with feature names via the ``%`` operator.
labels = dict(
    MAIN_EFFECT="SHAP main effect value for\n%s",
    INTERACTION_VALUE="SHAP interaction value",
    INTERACTION_EFFECT="SHAP interaction value for\n%s and %s",
    VALUE="SHAP value (impact on model output)",
    VALUE_FOR="SHAP value for\n%s",
    PLOT_FOR="SHAP plot for %s",
    FEATURE="Feature %s",
    FEATURE_VALUE="Feature value",
    FEATURE_VALUE_LOW="Low",
    FEATURE_VALUE_HIGH="High",
    JOINT_VALUE="Joint SHAP value",
)
62
+
63
def _shorten_text(text, length_limit):
    """Return ``text`` cut to at most ``length_limit`` characters, ending in "..." when cut."""
    text = str(text)
    if len(text) > length_limit:
        return text[:length_limit - 3] + "..."
    return text


def _draw_summary_axes(shap_values, features, feature_names, max_display, plot_type, color,
                       axis_color, alpha, sort, color_bar, auto_size_plot,
                       layered_violin_max_num_bins):
    """Draw one 2-D SHAP summary panel onto the current matplotlib axes.

    ``shap_values`` is a (# samples x # columns) matrix whose last column is
    dropped before plotting; ``features`` is the matching feature matrix or None.
    Nothing is returned: the panel is drawn through the ``pl`` state machine.
    """
    if sort:
        # order features by the sum of their effect magnitudes (last column excluded)
        feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
        feature_order = feature_order[-min(max_display, len(feature_order)):]
    else:
        feature_order = np.flip(np.arange(min(max_display, shap_values.shape[1] - 1)), 0)

    row_height = 0.4
    if auto_size_plot:
        pl.gcf().set_size_inches(8, len(feature_order) * row_height + 1.5)
    pl.axvline(x=0, color="#999999", zorder=-1)

    if plot_type == "dot":
        for pos, i in enumerate(feature_order):
            pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
            shaps = shap_values[:, i]
            values = None if features is None else features[:, i]

            # shuffle so that overlapping dots are not drawn in input order
            inds = np.arange(len(shaps))
            np.random.shuffle(inds)
            if values is not None:
                values = values[inds]
            shaps = shaps[inds]

            # features that cannot be converted to floats are drawn in grey
            colored_feature = True
            try:
                values = np.array(values, dtype=np.float64)
            except (TypeError, ValueError):
                colored_feature = False

            # Bin the SHAP values into 100 buckets along x and stack the dots of
            # each bucket alternately above and below the row's base line.
            N = len(shaps)
            nbins = 100
            quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
            inds = np.argsort(quant + np.random.randn(N) * 1e-6)
            layer = 0
            last_bin = -1
            ys = np.zeros(N)
            for ind in inds:
                if quant[ind] != last_bin:
                    layer = 0
                ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
                layer += 1
                last_bin = quant[ind]
            ys *= 0.9 * (row_height / np.max(ys + 1))

            if features is not None and colored_feature:
                # trim the color range, but prevent the color range from collapsing
                vmin = np.nanpercentile(values, 5)
                vmax = np.nanpercentile(values, 95)
                if vmin == vmax:
                    vmin = np.nanpercentile(values, 1)
                    vmax = np.nanpercentile(values, 99)
                    if vmin == vmax:
                        vmin = np.min(values)
                        vmax = np.max(values)

                assert features.shape[0] == len(shaps), "Feature and SHAP matrices must have the same number of rows!"
                nan_mask = np.isnan(values)
                # missing feature values in grey, everything else on the colour map
                pl.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777", vmin=vmin,
                           vmax=vmax, s=16, alpha=alpha, linewidth=0,
                           zorder=3, rasterized=len(shaps) > 500)
                pl.scatter(shaps[np.invert(nan_mask)], pos + ys[np.invert(nan_mask)],
                           cmap=red_blue, vmin=vmin, vmax=vmax, s=16,
                           c=values[np.invert(nan_mask)], alpha=alpha, linewidth=0,
                           zorder=3, rasterized=len(shaps) > 500)
            else:
                pl.scatter(shaps, pos + ys, s=16, alpha=alpha, linewidth=0, zorder=3,
                           color=color if colored_feature else "#777777", rasterized=len(shaps) > 500)

    elif plot_type == "violin":
        for pos, i in enumerate(feature_order):
            pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

        if features is not None:
            global_low = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 1)
            global_high = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 99)
            for pos, i in enumerate(feature_order):
                shaps = shap_values[:, i]
                shap_min, shap_max = np.min(shaps), np.max(shaps)
                rng = shap_max - shap_min
                xs = np.linspace(np.min(shaps) - rng * 0.2, np.max(shaps) + rng * 0.2, 100)
                # jitter nearly-constant columns before the density estimate
                if np.std(shaps) < (global_high - global_low) / 100:
                    ds = gaussian_kde(shaps + np.random.randn(len(shaps)) * (global_high - global_low) / 100)(xs)
                else:
                    ds = gaussian_kde(shaps)(xs)
                ds /= np.max(ds) * 3

                # Running mean of the feature value over (at most) the 20 samples
                # with the largest SHAP value not exceeding each x position; leading
                # positions with no sample yet are back-filled from the first
                # position that has one.
                values = features[:, i]
                smooth_values = np.zeros(len(xs) - 1)
                sort_inds = np.argsort(shaps)
                trailing_pos = 0
                leading_pos = 0
                running_sum = 0
                back_fill = 0
                for j in range(len(xs) - 1):
                    while leading_pos < len(shaps) and xs[j] >= shaps[sort_inds[leading_pos]]:
                        running_sum += values[sort_inds[leading_pos]]
                        leading_pos += 1
                        if leading_pos - trailing_pos > 20:
                            running_sum -= values[sort_inds[trailing_pos]]
                            trailing_pos += 1
                    if leading_pos - trailing_pos > 0:
                        smooth_values[j] = running_sum / (leading_pos - trailing_pos)
                        for k in range(back_fill):
                            smooth_values[j - k - 1] = smooth_values[j]
                    else:
                        back_fill += 1

                vmin = np.nanpercentile(values, 5)
                vmax = np.nanpercentile(values, 95)
                if vmin == vmax:
                    vmin = np.nanpercentile(values, 1)
                    vmax = np.nanpercentile(values, 99)
                    if vmin == vmax:
                        vmin = np.min(values)
                        vmax = np.max(values)
                pl.scatter(shaps, np.ones(shap_values.shape[0]) * pos, s=9, cmap=red_blue_solid, vmin=vmin, vmax=vmax,
                           c=values, alpha=alpha, linewidth=0, zorder=1)
                # rescale the smoothed values by the trimmed colour range
                smooth_values -= vmin
                if vmax - vmin > 0:
                    smooth_values /= vmax - vmin
                for seg in range(len(xs) - 1):
                    if ds[seg] > 0.05 or ds[seg + 1] > 0.05:
                        pl.fill_between([xs[seg], xs[seg + 1]], [pos + ds[seg], pos + ds[seg + 1]],
                                        [pos - ds[seg], pos - ds[seg + 1]], color=red_blue_solid(smooth_values[seg]),
                                        zorder=2)

        else:
            parts = pl.violinplot(shap_values[:, feature_order], range(len(feature_order)), points=200, vert=False,
                                  widths=0.7,
                                  showmeans=False, showextrema=False, showmedians=False)

            for pc in parts['bodies']:
                pc.set_facecolor(color)
                pc.set_edgecolor('none')
                pc.set_alpha(alpha)

    elif plot_type == "layered_violin":  # courtesy of @kodonnell
        num_x_points = 200
        # the indices of the feature data corresponding to each bin
        bins = np.linspace(0, features.shape[0], layered_violin_max_num_bins + 1).round(0).astype('int')
        shap_min, shap_max = np.min(shap_values[:, :-1]), np.max(shap_values[:, :-1])
        x_points = np.linspace(shap_min, shap_max, num_x_points)

        # loop through each feature and plot:
        for pos, ind in enumerate(feature_order):
            # if #unique <= layered_violin_max_num_bins then split by unique value,
            # otherwise use the evenly sized bins; in the unique case the bin edges
            # are simply aligned with the cumulative unique counts.
            feature = features[:, ind]
            unique, counts = np.unique(feature, return_counts=True)
            if unique.shape[0] <= layered_violin_max_num_bins:
                order = np.argsort(unique)
                thesebins = np.cumsum(counts[order])
                thesebins = np.insert(thesebins, 0, 0)
            else:
                thesebins = bins
            nbins = thesebins.shape[0] - 1
            # order the feature data so we can apply percentiling
            order = np.argsort(feature)
            # calculate kdes:
            ys = np.zeros((nbins, num_x_points))
            for i in range(nbins):
                # get shap values in this bin:
                shaps = shap_values[order[thesebins[i]:thesebins[i + 1]], ind]
                # a single element is not enough for a density estimate
                if shaps.shape[0] == 1:
                    warnings.warn(
                        "not enough data in bin #%d for feature %s, so it'll be ignored. Try increasing the number of records to plot."
                        % (i, feature_names[ind]))
                    # copy the previous row's values; ys is already 0, so there
                    # is nothing to do if i == 0
                    if i > 0:
                        ys[i, :] = ys[i - 1, :]
                    continue
                # a tiny bit of gaussian noise avoids singular matrix errors
                ys[i, :] = gaussian_kde(shaps + np.random.normal(loc=0, scale=0.001, size=shaps.shape[0]))(x_points)
                # scale so that the 'size' of each y represents the size of the bin,
                # e.g. if 99% are male and 1% female the 1% appears a lot smaller
                size = thesebins[i + 1] - thesebins[i]
                bin_size_if_even = features.shape[0] / nbins
                relative_bin_size = size / bin_size_if_even
                ys[i, :] *= relative_bin_size
            # plot the full kde, then the next-outer strip over it, etc., so that
            # no whitespace is left between the strips
            ys = np.cumsum(ys, axis=0)
            width = 0.8
            scale = ys.max() * 2 / width  # 2 is here as we plot both sides of x axis
            for i in range(nbins - 1, -1, -1):
                y = ys[i, :] / scale
                # if color is a cmap, use it, otherwise use a color; max(..., 1)
                # avoids dividing by zero when the feature has a single bin
                c = pl.get_cmap(color)(i / max(nbins - 1, 1)) if color in pl.cm.datad else color
                pl.fill_between(x_points, pos - y, pos + y, facecolor=c)
        pl.xlim(shap_min, shap_max)

    # draw the color bar
    if color_bar and features is not None and (plot_type != "layered_violin" or color in pl.cm.datad):
        import matplotlib.cm as cm
        m = cm.ScalarMappable(cmap=red_blue_solid if plot_type != "layered_violin" else pl.get_cmap(color))
        m.set_array([0, 1])
        cb = pl.colorbar(m, ticks=[0, 1], aspect=1000)
        cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
        cb.set_label(labels['FEATURE_VALUE'], size=12, labelpad=0)
        cb.ax.tick_params(labelsize=11, length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
        cb.ax.set_aspect((bbox.height - 0.9) * 20)

    pl.gca().xaxis.set_ticks_position('bottom')
    pl.gca().yaxis.set_ticks_position('none')
    pl.gca().spines['right'].set_visible(False)
    pl.gca().spines['top'].set_visible(False)
    pl.gca().spines['left'].set_visible(False)
    pl.gca().tick_params(color=axis_color, labelcolor=axis_color)
    pl.yticks(range(len(feature_order)), [feature_names[i] for i in feature_order], fontsize=13)
    pl.gca().tick_params('y', length=20, width=0.5, which='major')
    pl.gca().tick_params('x', labelsize=11)
    pl.ylim(-1, len(feature_order))
    pl.xlabel(labels['VALUE'], fontsize=13)
    pl.tight_layout()


def shap_summary_plot(shap_values, features=None, feature_names=None, max_display=None, plot_type="dot",
                      color=None, axis_color="#333333", title=None, alpha=1, show=True, sort=True,
                      color_bar=True, auto_size_plot=True, layered_violin_max_num_bins=20):
    """Create a SHAP summary plot, colored by feature values when they are provided.

    A new matplotlib figure is created for every call and returned; nothing is
    shown by this function itself.

    Parameters
    ----------
    shap_values : numpy.array
        Matrix of SHAP values (# samples x # columns), or a 3-D array of SHAP
        interaction values. NOTE(review): the last column (and, for interaction
        values, the last row/column of each sample's matrix) is excluded from
        ranking and plotting, which looks like a trailing bias/base-value
        column -- confirm that callers really pass # features + 1 columns.

    features : numpy.array or pandas.DataFrame or list
        Matrix of feature values (# samples x # features) or a feature_names list as shorthand

    feature_names : list
        Names of the features (length # features)

    max_display : int
        How many top features to include in the plot (default is 20, or 7 for interaction plots)

    plot_type : "dot" (default), "violin" or "layered_violin"
        What type of summary plot to produce

    color : str
        Dot/violin color, or a matplotlib colormap name for "layered_violin".
        Defaults to "coolwarm" for "layered_violin" and "#ff0052" otherwise.

    axis_color : str
        Color of the tick marks and tick labels.

    title : str
        Accepted for interface compatibility; not used by this function.

    alpha : float
        Opacity of the plotted markers.

    show : bool
        Accepted for interface compatibility; the figure is returned, never shown.

    sort : bool
        Order features by the sum of their absolute SHAP values when True,
        otherwise keep the input column order.

    color_bar : bool
        Draw the Low/High feature value color bar when feature values are given.

    auto_size_plot : bool
        Resize the figure to 8 inches wide and 0.4 inches per displayed feature.

    layered_violin_max_num_bins : int
        Maximum number of feature-value bins for the "layered_violin" plot.

    Returns
    -------
    matplotlib.figure.Figure
        The figure the plot was drawn on.
    """

    assert len(shap_values.shape) != 1, "Summary plots need a matrix of shap_values, not a vector."

    # default color:
    if color is None:
        color = "coolwarm" if plot_type == 'layered_violin' else "#ff0052"

    # convert from a DataFrame or other types
    if str(type(features)) == "<class 'pandas.core.frame.DataFrame'>":
        if feature_names is None:
            feature_names = features.columns
        features = features.values
    elif isinstance(features, list):
        if feature_names is None:
            feature_names = features
        features = None
    elif (features is not None) and len(features.shape) == 1 and feature_names is None:
        feature_names = features
        features = None

    if feature_names is None:
        feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1] - 1)]

    # plotting SHAP interaction values: one narrow panel per top feature
    if len(shap_values.shape) == 3:
        sort_inds = np.argsort(-np.abs(shap_values[:, :-1, :-1].sum(1)).sum(0))

        if max_display is None:
            max_display = 7
        # never ask for more panels than there are ranked features / names
        max_display = min(max_display, len(feature_names), len(sort_inds))

        # The figure is created only after max_display is known: computing the
        # size from the None default would raise a TypeError.
        mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))

        # NOTE(review): the x limits are fixed to +/-0.2 rather than derived from
        # the data; kept as found -- confirm this matches the expected value range.
        slow = -0.2
        shigh = 0.2

        sorted_features = None if features is None else features[:, sort_inds]
        # list comprehension so that plain lists of names work as well as arrays
        sorted_names = [feature_names[j] for j in sort_inds]
        blank_names = ["" for _ in sorted_names]
        # ranked feature columns followed by the trailing (dropped) column
        column_inds = np.hstack((sort_inds, len(sort_inds)))
        title_length_limit = 11

        for i in range(max_display):
            ind = sort_inds[i]
            pl.subplot(1, max_display, i + 1)
            proj_shap_values = shap_values[:, ind, column_inds]
            proj_shap_values *= 2
            proj_shap_values[:, i] /= 2  # because only off diag effects are split in half
            # Panels are drawn straight onto the current subplot; only the first
            # panel carries the feature names on its y axis.
            _draw_summary_axes(
                proj_shap_values, sorted_features,
                sorted_names if i == 0 else blank_names,
                max_display=max_display, plot_type="dot", color="#ff0052",
                axis_color="#333333", alpha=1, sort=False, color_bar=False,
                auto_size_plot=False, layered_violin_max_num_bins=20
            )
            pl.xlim((slow, shigh))
            pl.xlabel("")
            if i == max_display // 2:
                pl.xlabel(labels['INTERACTION_VALUE'])
            pl.title(_shorten_text(feature_names[ind], title_length_limit))
        pl.tight_layout(pad=0, w_pad=0, h_pad=0.0)
        pl.subplots_adjust(hspace=0, wspace=0.1)
        return mpl_fig

    if max_display is None:
        max_display = 20

    mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))
    _draw_summary_axes(
        shap_values, features, feature_names,
        max_display=max_display, plot_type=plot_type, color=color,
        axis_color=axis_color, alpha=alpha, sort=sort, color_bar=color_bar,
        auto_size_plot=auto_size_plot,
        layered_violin_max_num_bins=layered_violin_max_num_bins
    )
    return mpl_fig
403
+
404
+
405
+
406
+
407
+
408
+
409
+ def approx_interactions(index, shap_values, X):
410
+ """ Order other features by how much interaction they seem to have with the feature at the given index.
411
+
412
+ This just bins the SHAP values for a feature along that feature's value. For true Shapley interaction
413
+ index values for SHAP see the interaction_contribs option implemented in XGBoost.
414
+ """
415
+
416
+ if X.shape[0] > 10000:
417
+ a = np.arange(X.shape[0])
418
+ np.random.shuffle(a)
419
+ inds = a[:10000]
420
+ else:
421
+ inds = np.arange(X.shape[0])
422
+
423
+ x = X[inds, index]
424
+ srt = np.argsort(x)
425
+ shap_ref = shap_values[inds, index]
426
+ shap_ref = shap_ref[srt]
427
+ inc = max(min(int(len(x) / 10.0), 50), 1)
428
+ interactions = []
429
+ for i in range(X.shape[1]):
430
+ val_other = X[inds, i][srt].astype(np.float)
431
+ v = 0.0
432
+ if not (i == index or np.sum(np.abs(val_other)) < 1e-8):
433
+ for j in range(0, len(x), inc):
434
+ if np.std(val_other[j:j + inc]) > 0 and np.std(shap_ref[j:j + inc]) > 0:
435
+ v += abs(np.corrcoef(shap_ref[j:j + inc], val_other[j:j + inc])[0, 1])
436
+ interactions.append(v)
437
+
438
+ return np.argsort(-np.abs(interactions))
439
+
440
+
441
+
442
+
443
+
444
+
445
+
446
def shap_dependence_plot(ind, shap_values, features, feature_names=None, display_features=None,
                         interaction_index="auto", color="#1E88E5", axis_color="#333333",
                         dot_size=16, alpha=1, title=None, show=True):
    """
    Create a SHAP dependence plot, colored by an interaction feature.

    The plot is drawn on the current matplotlib figure, which is returned;
    nothing is shown by this function itself.

    Parameters
    ----------
    ind : int or str, or a pair of them
        Index (or name) of the feature to plot. A pair selects a main effect /
        interaction effect when ``shap_values`` holds 3-D interaction values.

    shap_values : numpy.array
        Matrix of SHAP values (# samples x # features)

    features : numpy.array or pandas.DataFrame
        Matrix of feature values (# samples x # features)

    feature_names : list
        Names of the features (length # features)

    display_features : numpy.array or pandas.DataFrame
        Matrix of feature values for visual display (such as strings instead of coded values)

    interaction_index : "auto", None, or int
        The index of the feature used to color the plot.

    color : str
        Dot color used when no interaction feature colors the plot.

    axis_color : str
        Color of the axis labels, ticks, spines and title.

    dot_size : float
        Marker size passed to the scatter plot.

    alpha : float
        Opacity of the plotted markers.

    title : str
        Optional plot title.

    show : bool
        Accepted for interface compatibility; the figure is returned, never shown.

    Returns
    -------
    (matplotlib.figure.Figure, int or None)
        The figure drawn on and the index of the feature used for coloring.
    """

    # convert from DataFrames if we got any
    if str(type(features)).endswith("'pandas.core.frame.DataFrame'>"):
        if feature_names is None:
            feature_names = features.columns
        features = features.values
    if str(type(display_features)).endswith("'pandas.core.frame.DataFrame'>"):
        if feature_names is None:
            feature_names = display_features.columns
        display_features = display_features.values
    elif display_features is None:
        display_features = features

    # one default name per SHAP column, so every column accepted by the shape
    # check below has a name
    if feature_names is None:
        feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1])]

    # allow a single feature name to be passed alone
    if isinstance(feature_names, str):
        feature_names = [feature_names]

    # allow vectors to be passed (the shape must be given as one tuple)
    if len(shap_values.shape) == 1:
        shap_values = np.reshape(shap_values, (len(shap_values), 1))
    if len(features.shape) == 1:
        features = np.reshape(features, (len(features), 1))
    if len(display_features.shape) == 1:
        display_features = np.reshape(display_features, (len(display_features), 1))

    def convert_name(ind):
        """Map a feature name to its column index; non-string values pass through."""
        if isinstance(ind, str):
            # np.asarray so that plain lists of names are compared element-wise
            nzinds = np.where(np.asarray(feature_names) == ind)[0]
            if len(nzinds) == 0:
                print("Could not find feature named: " + ind)
                return None
            return nzinds[0]
        return ind

    ind = convert_name(ind)

    mpl_fig = pl.gcf()

    # plotting SHAP interaction values
    if len(shap_values.shape) == 3 and hasattr(ind, "__len__") and len(ind) == 2:
        ind1 = convert_name(ind[0])
        ind2 = convert_name(ind[1])
        if ind1 == ind2:
            proj_shap_values = shap_values[:, ind2, :]
        else:
            proj_shap_values = shap_values[:, ind2, :] * 2  # off-diag values are split in half

        # TODO: remove recursion; generally the functions should be shorter for more maintainable code
        result = shap_dependence_plot(
            ind1, proj_shap_values, features, feature_names=feature_names,
            interaction_index=ind2, display_features=display_features, show=False
        )
        # replace the generic y label set by the recursive call
        if ind1 == ind2:
            pl.ylabel(labels['MAIN_EFFECT'] % feature_names[ind1])
        else:
            pl.ylabel(labels['INTERACTION_EFFECT'] % (feature_names[ind1], feature_names[ind2]))
        return result

    assert shap_values.shape[0] == features.shape[0], \
        "'shap_values' and 'features' values must have the same number of rows!"
    assert shap_values.shape[1] == features.shape[1], \
        "'shap_values' must have the same number of columns as 'features'!"

    # get both the raw and display feature values
    xv = features[:, ind]
    xd = display_features[:, ind]
    s = shap_values[:, ind]
    x_is_text = isinstance(xd[0], str)
    if x_is_text:
        # display string -> raw value, used for the x tick labels
        name_map = {}
        for i in range(len(xv)):
            name_map[xd[i]] = xv[i]
        xnames = list(name_map.keys())

    name = feature_names[ind]

    # guess what other feature has the strongest interaction with the plotted feature
    if interaction_index == "auto":
        interaction_index = approx_interactions(ind, shap_values, features)[0]
    interaction_index = convert_name(interaction_index)
    categorical_interaction = False

    # get both the raw and display color values
    color_is_text = False
    if interaction_index is not None:
        cv = features[:, interaction_index]
        cd = display_features[:, interaction_index]
        # np.float64 rather than the removed ``np.float`` alias
        clow = np.nanpercentile(cv.astype(np.float64), 5)
        chigh = np.nanpercentile(cv.astype(np.float64), 95)
        color_is_text = isinstance(cd[0], str)
        if color_is_text:
            cname_map = {}
            for i in range(len(cv)):
                cname_map[cd[i]] = cv[i]
            cnames = list(cname_map.keys())
            categorical_interaction = True
        elif clow % 1 == 0 and chigh % 1 == 0 and len(set(cv)) < 50:
            categorical_interaction = True

    # discretize colors for categorical features
    color_norm = None
    if categorical_interaction and clow != chigh:
        # the number of boundaries must be an int
        bounds = np.linspace(clow, chigh, int(chigh - clow + 2))
        color_norm = matplotlib.colors.BoundaryNorm(bounds, red_blue.N)

    # the actual scatter plot, TODO: adapt the dot_size to the number of data points?
    if interaction_index is not None:
        # A norm and explicit vmin/vmax must not be passed together; the
        # boundaries of the norm already span clow..chigh.
        if color_norm is None:
            range_kwargs = {"vmin": clow, "vmax": chigh}
        else:
            range_kwargs = {"norm": color_norm}
        pl.scatter(xv, s, s=dot_size, linewidth=0, c=features[:, interaction_index], cmap=red_blue,
                   alpha=alpha, rasterized=len(xv) > 500, **range_kwargs)
    else:
        pl.scatter(xv, s, s=dot_size, linewidth=0, color=color,
                   alpha=alpha, rasterized=len(xv) > 500)

    if interaction_index != ind and interaction_index is not None:
        # draw the color bar
        if color_is_text:
            tick_positions = [cname_map[n] for n in cnames]
            if len(tick_positions) == 2:
                # pull the two ticks of a binary feature towards the bar ends
                tick_positions[0] -= 0.25
                tick_positions[1] += 0.25
            cb = pl.colorbar(ticks=tick_positions)
            cb.set_ticklabels(cnames)
        else:
            cb = pl.colorbar()

        cb.set_label(feature_names[interaction_index], size=13)
        cb.ax.tick_params(labelsize=11)
        if categorical_interaction:
            cb.ax.tick_params(length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
        cb.ax.set_aspect((bbox.height - 0.7) * 20)

    # make the plot more readable
    if interaction_index != ind:
        pl.gcf().set_size_inches(7.5, 5)
    else:
        pl.gcf().set_size_inches(6, 5)
    pl.xlabel(name, color=axis_color, fontsize=13)
    pl.ylabel(labels['VALUE_FOR'] % name, color=axis_color, fontsize=13)
    if title is not None:
        pl.title(title, color=axis_color, fontsize=13)
    pl.gca().xaxis.set_ticks_position('bottom')
    pl.gca().yaxis.set_ticks_position('left')
    pl.gca().spines['right'].set_visible(False)
    pl.gca().spines['top'].set_visible(False)
    pl.gca().tick_params(color=axis_color, labelcolor=axis_color, labelsize=11)
    for spine in pl.gca().spines.values():
        spine.set_edgecolor(axis_color)
    if x_is_text:
        pl.xticks([name_map[n] for n in xnames], xnames, rotation='vertical', fontsize=11)

    return mpl_fig, interaction_index