Spaces:
Sleeping
Sleeping
Commit
·
20aea5e
1
Parent(s):
9cea6b7
Upload shap_plots.py
Browse files- shap_plots.py +730 -0
shap_plots.py
ADDED
@@ -0,0 +1,730 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
import iml
|
3 |
+
import numpy as np
|
4 |
+
from iml import Instance, Model
|
5 |
+
from iml.datatypes import DenseData
|
6 |
+
from iml.explanations import AdditiveExplanation
|
7 |
+
from iml.links import IdentityLink
|
8 |
+
from scipy.stats import gaussian_kde
|
9 |
+
import matplotlib
|
10 |
+
try:
    import matplotlib.pyplot as pl
    from matplotlib.colors import LinearSegmentedColormap
    from matplotlib.ticker import MaxNLocator

    # Colour ramp running from #1E88E5 (position 0.0) to #F52757 (position 1.0).
    # This first variant dips to 30% opacity at the midpoint of the ramp.
    cdict1 = dict(
        red=((0.0, 0.11764705882352941, 0.11764705882352941),
             (1.0, 0.9607843137254902, 0.9607843137254902)),
        green=((0.0, 0.5333333333333333, 0.5333333333333333),
               (1.0, 0.15294117647058825, 0.15294117647058825)),
        blue=((0.0, 0.8980392156862745, 0.8980392156862745),
              (1.0, 0.3411764705882353, 0.3411764705882353)),
        alpha=((0.0, 1, 1),
               (0.5, 0.3, 0.3),
               (1.0, 1, 1)),
    )
    red_blue = LinearSegmentedColormap('RedBlue', cdict1)

    # Same RGB ramp, but fully opaque along its whole length. A new dict is built so
    # the segment data already handed to `red_blue` is left untouched.
    cdict1 = dict(cdict1, alpha=((0.0, 1, 1),
                                 (0.5, 1, 1),
                                 (1.0, 1, 1)))
    red_blue_solid = LinearSegmentedColormap('RedBlue', cdict1)
except ImportError:
    # matplotlib is optional at import time; the plotting names stay undefined without it
    pass
|
48 |
+
|
49 |
+
# Text templates for axis labels, titles and colour-bar captions used by the plots
# in this module. Entries containing "%s" are filled in with feature names.
labels = dict(
    MAIN_EFFECT="SHAP main effect value for\n%s",
    INTERACTION_VALUE="SHAP interaction value",
    INTERACTION_EFFECT="SHAP interaction value for\n%s and %s",
    VALUE="SHAP value (impact on model output)",
    VALUE_FOR="SHAP value for\n%s",
    PLOT_FOR="SHAP plot for %s",
    FEATURE="Feature %s",
    FEATURE_VALUE="Feature value",
    FEATURE_VALUE_LOW="Low",
    FEATURE_VALUE_HIGH="High",
    JOINT_VALUE="Joint SHAP value",
)
|
62 |
+
|
63 |
+
def shap_summary_plot(shap_values, features=None, feature_names=None, max_display=None, plot_type="dot",
                      color=None, axis_color="#333333", title=None, alpha=1, show=True, sort=True,
                      color_bar=True, auto_size_plot=True, layered_violin_max_num_bins=20):
    """Create a SHAP summary plot, colored by feature values when they are provided.

    Parameters
    ----------
    shap_values : numpy.array
        Matrix of SHAP values (# samples x # columns), or a 3-D array of SHAP
        interaction values, in which case one "dot" panel is drawn per top-ranked
        feature. The last column (and, for 3-D input, the last row) is excluded from
        the feature ranking.

    features : numpy.array or pandas.DataFrame or list
        Matrix of feature values (# samples x # features) or a feature_names list as shorthand

    feature_names : list
        Names of the features (length # features)

    max_display : int
        How many top features to include in the plot (default is 20, or 7 for interaction plots)

    plot_type : "dot" (default), "violin" or "layered_violin"
        What type of summary plot to produce

    color : str
        Dot/violin color, or a colormap name for "layered_violin". Defaults to
        "coolwarm" for "layered_violin" and "#ff0052" otherwise.

    axis_color : str
        Color of the tick marks and tick labels.

    title, show
        Accepted for interface compatibility; this function does not use them
        (the figure is returned instead of being shown).

    alpha : float
        Opacity of the plotted points/violins.

    sort : bool
        Rank features by summed absolute SHAP value when True, otherwise keep column order.

    color_bar : bool
        Draw the Low/High feature-value color bar when feature values are available.

    auto_size_plot : bool
        Resize the figure to 8 inches wide and a height proportional to the number of rows.

    layered_violin_max_num_bins : int
        Maximum number of feature-value layers per row in a "layered_violin" plot.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure holding the plot.
    """

    assert len(shap_values.shape) != 1, "Summary plots need a matrix of shap_values, not a vector."

    # default color:
    if color is None:
        color = "coolwarm" if plot_type == 'layered_violin' else "#ff0052"

    # convert from a DataFrame or other types
    if str(type(features)) == "<class 'pandas.core.frame.DataFrame'>":
        if feature_names is None:
            feature_names = features.columns
        features = features.values
    elif str(type(features)) == "<class 'list'>":
        if feature_names is None:
            feature_names = features
        features = None
    elif (features is not None) and len(features.shape) == 1 and feature_names is None:
        feature_names = features
        features = None

    if feature_names is None:
        feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1] - 1)]

    # plotting SHAP interaction values: one panel per top-ranked feature
    if len(shap_values.shape) == 3:
        sort_inds = np.argsort(-np.abs(shap_values[:, :-1, :-1].sum(1)).sum(0))

        # never request more panels than there are ranked features or names
        max_display = min(len(sort_inds), len(feature_names), 7 if max_display is None else max_display)

        # NOTE(review): the x-limits were hard-coded to +/-0.2 in the original, overriding
        # a percentile-based range that was computed and then discarded -- confirm intent.
        slow = -0.2
        shigh = 0.2

        # columns shown in every panel: the ranked features followed by the trailing column
        panel_cols = np.hstack((sort_inds, len(sort_inds)))
        panel_features = None if features is None else features[:, sort_inds]
        ranked_names = [feature_names[j] for j in sort_inds]  # works for lists and arrays alike
        blank_names = [""] * len(sort_inds)
        title_length_limit = 11

        mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, max_display + 1))
        for i in range(max_display):
            ind = sort_inds[i]
            pl.subplot(1, max_display, i + 1)
            proj_shap_values = shap_values[:, ind, panel_cols]  # fancy indexing -> a copy
            # only off-diagonal effects are split in half, so double everything but column i
            proj_shap_values[:, np.arange(proj_shap_values.shape[1]) != i] *= 2
            _draw_summary_panel(
                proj_shap_values, panel_features,
                ranked_names if i == 0 else blank_names,  # row labels on the first panel only
                max_display, "dot", "#ff0052", "#333333", 1,
                sort=False, color_bar=False, auto_size_plot=False,
                layered_violin_max_num_bins=20
            )
            pl.xlim((slow, shigh))
            pl.xlabel(labels['INTERACTION_VALUE'] if i == max_display // 2 else "")
            pl.title(_shorten_text(str(feature_names[ind]), title_length_limit))
        pl.tight_layout(pad=0, w_pad=0, h_pad=0.0)
        pl.subplots_adjust(hspace=0, wspace=0.1)
        return mpl_fig

    # resolve the default before it is used to size the figure
    if max_display is None:
        max_display = 20

    mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, max_display + 1))
    _draw_summary_panel(
        shap_values, features, feature_names, max_display, plot_type, color, axis_color, alpha,
        sort=sort, color_bar=color_bar, auto_size_plot=auto_size_plot,
        layered_violin_max_num_bins=layered_violin_max_num_bins
    )
    return mpl_fig


def _shorten_text(text, length_limit):
    """Truncate `text` to `length_limit` characters, ending in "..." when it was cut."""
    if len(text) > length_limit:
        return text[:length_limit - 3] + "..."
    return text


def _trimmed_color_range(values):
    """Return (vmin, vmax): the 5th-95th percentile span, widened if it collapses to a point."""
    vmin = np.nanpercentile(values, 5)
    vmax = np.nanpercentile(values, 95)
    if vmin == vmax:
        vmin = np.nanpercentile(values, 1)
        vmax = np.nanpercentile(values, 99)
        if vmin == vmax:
            vmin = np.min(values)
            vmax = np.max(values)
    return vmin, vmax


def _draw_summary_panel(shap_values, features, feature_names, max_display, plot_type, color,
                        axis_color, alpha, sort, color_bar, auto_size_plot,
                        layered_violin_max_num_bins):
    """Draw one 2-D summary plot onto the current matplotlib axes (no figure is created)."""
    if max_display is None:
        max_display = 20

    if sort:
        # order features by the sum of their effect magnitudes (last column excluded)
        feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
        feature_order = feature_order[-min(max_display, len(feature_order)):]
    else:
        feature_order = np.flip(np.arange(min(max_display, shap_values.shape[1] - 1)), 0)

    row_height = 0.4
    if auto_size_plot:
        pl.gcf().set_size_inches(8, len(feature_order) * row_height + 1.5)
    pl.axvline(x=0, color="#999999", zorder=-1)

    if plot_type == "dot":
        for pos, i in enumerate(feature_order):
            pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
            shaps = shap_values[:, i]
            values = None if features is None else features[:, i]
            # shuffle so the drawing order does not depend on the sample order
            inds = np.arange(len(shaps))
            np.random.shuffle(inds)
            if values is not None:
                values = values[inds]
            shaps = shaps[inds]
            colored_feature = True
            try:
                values = np.array(values, dtype=np.float64)  # make sure this can be numeric
            except (TypeError, ValueError):
                colored_feature = False
            N = len(shaps)
            nbins = 100
            quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
            # tiny jitter gives a random order among points that share a bin
            inds = np.argsort(quant + np.random.randn(N) * 1e-6)
            layer = 0
            last_bin = -1
            ys = np.zeros(N)
            for ind in inds:
                if quant[ind] != last_bin:
                    layer = 0
                # stack points of one bin alternately above and below the row line
                ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
                layer += 1
                last_bin = quant[ind]
            ys *= 0.9 * (row_height / np.max(ys + 1))

            if features is not None and colored_feature:
                # trim the color range, but prevent the color range from collapsing
                vmin, vmax = _trimmed_color_range(values)

                assert features.shape[0] == len(shaps), "Feature and SHAP matrices must have the same number of rows!"
                nan_mask = np.isnan(values)
                # points with a missing feature value are drawn in flat grey
                pl.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777",
                           s=16, alpha=alpha, linewidth=0,
                           zorder=3, rasterized=len(shaps) > 500)
                pl.scatter(shaps[np.invert(nan_mask)], pos + ys[np.invert(nan_mask)],
                           cmap=red_blue, vmin=vmin, vmax=vmax, s=16,
                           c=values[np.invert(nan_mask)], alpha=alpha, linewidth=0,
                           zorder=3, rasterized=len(shaps) > 500)
            else:
                pl.scatter(shaps, pos + ys, s=16, alpha=alpha, linewidth=0, zorder=3,
                           color=color if colored_feature else "#777777", rasterized=len(shaps) > 500)

    elif plot_type == "violin":
        for pos, i in enumerate(feature_order):
            pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

        if features is not None:
            global_low = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 1)
            global_high = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 99)
            for pos, i in enumerate(feature_order):
                shaps = shap_values[:, i]
                shap_min, shap_max = np.min(shaps), np.max(shaps)
                rng = shap_max - shap_min
                xs = np.linspace(shap_min - rng * 0.2, shap_max + rng * 0.2, 100)
                if np.std(shaps) < (global_high - global_low) / 100:
                    # nearly constant values: add noise before the density estimate
                    ds = gaussian_kde(shaps + np.random.randn(len(shaps)) * (global_high - global_low) / 100)(xs)
                else:
                    ds = gaussian_kde(shaps)(xs)
                ds /= np.max(ds) * 3

                # running mean of the feature value (window of at most 20 samples) along the SHAP axis
                values = features[:, i]
                smooth_values = np.zeros(len(xs) - 1)
                sort_inds = np.argsort(shaps)
                trailing_pos = 0
                leading_pos = 0
                running_sum = 0
                back_fill = 0
                for j in range(len(xs) - 1):
                    while leading_pos < len(shaps) and xs[j] >= shaps[sort_inds[leading_pos]]:
                        running_sum += values[sort_inds[leading_pos]]
                        leading_pos += 1
                        if leading_pos - trailing_pos > 20:
                            running_sum -= values[sort_inds[trailing_pos]]
                            trailing_pos += 1
                    if leading_pos - trailing_pos > 0:
                        smooth_values[j] = running_sum / (leading_pos - trailing_pos)
                        for k in range(back_fill):
                            smooth_values[j - k - 1] = smooth_values[j]
                    else:
                        back_fill += 1

                vmin, vmax = _trimmed_color_range(values)
                pl.scatter(shaps, np.ones(shap_values.shape[0]) * pos, s=9, cmap=red_blue_solid, vmin=vmin, vmax=vmax,
                           c=values, alpha=alpha, linewidth=0, zorder=1)
                # rescale the running mean into the colormap's input range
                smooth_values -= vmin
                if vmax - vmin > 0:
                    smooth_values /= vmax - vmin
                for seg in range(len(xs) - 1):
                    if ds[seg] > 0.05 or ds[seg + 1] > 0.05:
                        pl.fill_between([xs[seg], xs[seg + 1]], [pos + ds[seg], pos + ds[seg + 1]],
                                        [pos - ds[seg], pos - ds[seg + 1]],
                                        color=red_blue_solid(smooth_values[seg]),
                                        zorder=2)

        else:
            parts = pl.violinplot(shap_values[:, feature_order], range(len(feature_order)), points=200, vert=False,
                                  widths=0.7,
                                  showmeans=False, showextrema=False, showmedians=False)

            for pc in parts['bodies']:
                pc.set_facecolor(color)
                pc.set_edgecolor('none')
                pc.set_alpha(alpha)

    elif plot_type == "layered_violin":  # courtesy of @kodonnell
        num_x_points = 200
        # the indices of the feature data corresponding to each bin
        bins = np.linspace(0, features.shape[0], layered_violin_max_num_bins + 1).round(0).astype('int')
        shap_min, shap_max = np.min(shap_values[:, :-1]), np.max(shap_values[:, :-1])
        x_points = np.linspace(shap_min, shap_max, num_x_points)

        # loop through each feature and plot:
        for pos, ind in enumerate(feature_order):
            # decide how to handle: if #unique < layered_violin_max_num_bins then split by unique value,
            # otherwise use bins/percentiles. To keep simpler code, in the case of uniques, we just
            # adjust the bins to align with the unique counts.
            feature = features[:, ind]
            unique, counts = np.unique(feature, return_counts=True)
            if unique.shape[0] <= layered_violin_max_num_bins:
                order = np.argsort(unique)
                thesebins = np.cumsum(counts[order])
                thesebins = np.insert(thesebins, 0, 0)
            else:
                thesebins = bins
            nbins = thesebins.shape[0] - 1
            # order the feature data so we can apply percentiling
            order = np.argsort(feature)
            # calculate kdes:
            ys = np.zeros((nbins, num_x_points))
            for i in range(nbins):
                # get shap values in this bin:
                shaps = shap_values[order[thesebins[i]:thesebins[i + 1]], ind]
                # a density cannot be estimated from a single element
                if shaps.shape[0] == 1:
                    warnings.warn(
                        "not enough data in bin #%d for feature %s, so it'll be ignored. Try increasing the number of records to plot."
                        % (i, feature_names[ind]))
                    # to ignore it, just set it to the previous y-values (so the area between them
                    # will be zero). ys is already 0, so there's nothing to do if i == 0
                    if i > 0:
                        ys[i, :] = ys[i - 1, :]
                    continue
                # save kde of them: note that we add a tiny bit of gaussian noise to avoid singular matrix errors
                ys[i, :] = gaussian_kde(shaps + np.random.normal(loc=0, scale=0.001, size=shaps.shape[0]))(x_points)
                # scale it up so that the 'size' of each y represents the size of the bin. For continuous
                # data this will do nothing, but when we've gone with the unique option, this will
                # matter - e.g. if 99% are male and 1% female, we want the 1% to appear a lot smaller.
                size = thesebins[i + 1] - thesebins[i]
                bin_size_if_even = features.shape[0] / nbins
                relative_bin_size = size / bin_size_if_even
                ys[i, :] *= relative_bin_size
            # now plot 'em. We don't plot the individual strips, as this can leave whitespace between
            # them. Instead, we plot the full kde, then remove outer strip and plot over it, etc.,
            # to ensure no whitespace
            ys = np.cumsum(ys, axis=0)
            width = 0.8
            scale = ys.max() * 2 / width  # 2 is here as we plot both sides of x axis
            for i in range(nbins - 1, -1, -1):
                y = ys[i, :] / scale
                # if color is a cmap, use it, otherwise use a color; max(..., 1) avoids 0/0 for one bin
                c = pl.get_cmap(color)(i / max(nbins - 1, 1)) if color in pl.cm.datad else color
                pl.fill_between(x_points, pos - y, pos + y, facecolor=c)
        pl.xlim(shap_min, shap_max)

    # draw the color bar
    if color_bar and features is not None and (plot_type != "layered_violin" or color in pl.cm.datad):
        import matplotlib.cm as cm
        m = cm.ScalarMappable(cmap=red_blue_solid if plot_type != "layered_violin" else pl.get_cmap(color))
        m.set_array([0, 1])
        # the mappable is not attached to any axes, so name the axes to take space from
        cb = pl.colorbar(m, ax=pl.gca(), ticks=[0, 1], aspect=1000)
        cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
        cb.set_label(labels['FEATURE_VALUE'], size=12, labelpad=0)
        cb.ax.tick_params(labelsize=11, length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
        cb.ax.set_aspect((bbox.height - 0.9) * 20)

    ax = pl.gca()
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('none')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.tick_params(color=axis_color, labelcolor=axis_color)
    pl.yticks(range(len(feature_order)), [feature_names[i] for i in feature_order], fontsize=13)
    ax.tick_params('y', length=20, width=0.5, which='major')
    ax.tick_params('x', labelsize=11)
    pl.ylim(-1, len(feature_order))
    pl.xlabel(labels['VALUE'], fontsize=13)
    pl.tight_layout()
|
403 |
+
|
404 |
+
|
405 |
+
|
406 |
+
|
407 |
+
|
408 |
+
|
409 |
+
def approx_interactions(index, shap_values, X):
    """ Order other features by how much interaction they seem to have with the feature at the given index.

    This just bins the SHAP values for a feature along that feature's value. For true Shapley interaction
    index values for SHAP see the interaction_contribs option implemented in XGBoost.

    Parameters
    ----------
    index : int
        Column of the feature whose interactions are being ranked.

    shap_values : numpy.array
        Matrix of SHAP values, indexed as [sample, feature].

    X : numpy.array
        Matrix of feature values, indexed as [sample, feature]. Every column must be
        convertible to float.

    Returns
    -------
    numpy.array
        Column indices of X ordered from the strongest to the weakest estimated
        interaction. The feature at `index` itself always gets a score of zero.
    """

    # cap the work at 10,000 randomly chosen rows
    if X.shape[0] > 10000:
        a = np.arange(X.shape[0])
        np.random.shuffle(a)
        inds = a[:10000]
    else:
        inds = np.arange(X.shape[0])

    # sort the reference feature's SHAP values along that feature's value
    x = X[inds, index]
    srt = np.argsort(x)
    shap_ref = shap_values[inds, index][srt]

    # bin width: a tenth of the rows, clamped to the range [1, 50]
    inc = max(min(int(len(x) / 10.0), 50), 1)
    interactions = []
    for i in range(X.shape[1]):
        # builtin float: the `np.float` alias was removed in NumPy 1.24
        val_other = X[inds, i][srt].astype(float)
        v = 0.0
        if not (i == index or np.sum(np.abs(val_other)) < 1e-8):
            # accumulate |correlation| between the other feature and the SHAP values per bin
            for j in range(0, len(x), inc):
                other_bin = val_other[j:j + inc]
                shap_bin = shap_ref[j:j + inc]
                # the correlation is undefined when either side is constant in the bin
                if np.std(other_bin) > 0 and np.std(shap_bin) > 0:
                    v += abs(np.corrcoef(shap_bin, other_bin)[0, 1])
        interactions.append(v)

    return np.argsort(-np.abs(interactions))
|
439 |
+
|
440 |
+
|
441 |
+
|
442 |
+
|
443 |
+
|
444 |
+
|
445 |
+
|
446 |
+
def shap_dependence_plot(ind, shap_values, features, feature_names=None, display_features=None,
                         interaction_index="auto", color="#1E88E5", axis_color="#333333",
                         dot_size=16, alpha=1, title=None, show=True):
    """
    Create a SHAP dependence plot, colored by an interaction feature.

    Parameters
    ----------
    ind : int, str, or pair of those
        Index (or name) of the feature to plot. When `shap_values` is a 3-D array of
        interaction values, a pair (ind1, ind2) selects the interaction to plot.

    shap_values : numpy.array
        Matrix of SHAP values (# samples x # features)

    features : numpy.array or pandas.DataFrame
        Matrix of feature values (# samples x # features)

    feature_names : list
        Names of the features (length # features)

    display_features : numpy.array or pandas.DataFrame
        Matrix of feature values for visual display (such as strings instead of coded values)

    interaction_index : "auto", None, or int
        The index of the feature used to color the plot. "auto" picks the top-ranked
        feature returned by `approx_interactions`; None draws every point in `color`.

    color : str
        Point color used when there is no interaction feature to color by.

    axis_color : str
        Color of axis labels, ticks, spines and the title.

    dot_size : int
        Marker size passed to the scatter plot.

    alpha : float
        Opacity of the points.

    title : str
        Optional plot title.

    show : bool
        Accepted for interface compatibility; not used (the figure is returned instead).

    Returns
    -------
    (matplotlib.figure.Figure, int or None)
        The current figure and the resolved index of the feature used for coloring.
    """

    # convert from DataFrames if we got any
    if str(type(features)).endswith("'pandas.core.frame.DataFrame'>"):
        if feature_names is None:
            feature_names = features.columns
        features = features.values
    if str(type(display_features)).endswith("'pandas.core.frame.DataFrame'>"):
        if feature_names is None:
            feature_names = display_features.columns
        display_features = display_features.values

    # allow vectors to be passed (the new shape must be given as a tuple)
    if len(shap_values.shape) == 1:
        shap_values = np.reshape(shap_values, (len(shap_values), 1))
    if len(features.shape) == 1:
        features = np.reshape(features, (len(features), 1))
    if display_features is None:
        display_features = features
    elif len(display_features.shape) == 1:
        display_features = np.reshape(display_features, (len(display_features), 1))

    # allow a single feature name to be passed alone
    if isinstance(feature_names, str):
        feature_names = [feature_names]
    if feature_names is None:
        # one name per feature column
        feature_names = [labels['FEATURE'] % str(i) for i in range(features.shape[1])]

    def convert_name(key):
        """Translate a feature name into its column index; anything else is returned as-is."""
        if isinstance(key, str):
            # asarray so that plain lists of names compare element-wise
            matches = np.where(np.asarray(feature_names) == key)[0]
            if len(matches) == 0:
                print("Could not find feature named: " + key)
                return None
            return matches[0]
        return key

    ind = convert_name(ind)

    # plotting SHAP interaction values
    if len(shap_values.shape) == 3 and hasattr(ind, "__len__") and len(ind) == 2:
        ind1 = convert_name(ind[0])
        ind2 = convert_name(ind[1])
        if ind1 == ind2:
            proj_shap_values = shap_values[:, ind2, :]
        else:
            proj_shap_values = shap_values[:, ind2, :] * 2  # off-diag values are split in half

        # TODO: remove recursion; generally the functions should be shorter for more maintainable code
        mpl_fig, color_index = shap_dependence_plot(
            ind1, proj_shap_values, features, feature_names=feature_names,
            interaction_index=ind2, display_features=display_features,
            color=color, axis_color=axis_color, dot_size=dot_size, alpha=alpha,
            title=title, show=False
        )
        # the values plotted here are interaction values, so relabel the y axis accordingly
        if ind1 == ind2:
            pl.ylabel(labels['MAIN_EFFECT'] % feature_names[ind1], color=axis_color, fontsize=13)
        else:
            pl.ylabel(labels['INTERACTION_EFFECT'] % (feature_names[ind1], feature_names[ind2]),
                      color=axis_color, fontsize=13)
        return mpl_fig, color_index

    assert shap_values.shape[0] == features.shape[0], \
        "'shap_values' and 'features' values must have the same number of rows!"

    mpl_fig = pl.gcf()

    # get both the raw and display feature values
    xv = features[:, ind]
    xd = display_features[:, ind]
    s = shap_values[:, ind]
    x_has_text_labels = isinstance(xd[0], str)
    if x_has_text_labels:
        # map each display label to its raw value (the last occurrence wins)
        name_map = dict(zip(xd, xv))
        xnames = list(name_map.keys())

    name = feature_names[ind]

    # guess what other feature has the strongest interaction with the plotted feature
    if isinstance(interaction_index, str) and interaction_index == "auto":
        interaction_index = approx_interactions(ind, shap_values, features)[0]
    interaction_index = convert_name(interaction_index)
    categorical_interaction = False
    color_has_text_labels = False

    # get both the raw and display color values
    if interaction_index is not None:
        cv = features[:, interaction_index]
        cd = display_features[:, interaction_index]
        # builtin float: the `np.float` alias was removed in NumPy 1.24
        cv_float = cv.astype(float)
        clow = np.nanpercentile(cv_float, 5)
        chigh = np.nanpercentile(cv_float, 95)
        color_has_text_labels = isinstance(cd[0], str)
        if color_has_text_labels:
            cname_map = dict(zip(cd, cv))
            cnames = list(cname_map.keys())
            categorical_interaction = True
        elif clow % 1 == 0 and chigh % 1 == 0 and len(set(cv)) < 50:
            # integer-valued with few distinct values: treat as categorical
            categorical_interaction = True

    # discretize colors for categorical features
    color_norm = None
    if categorical_interaction and clow != chigh:
        # np.linspace needs an integer number of samples
        bounds = np.linspace(clow, chigh, int(chigh - clow + 2))
        color_norm = matplotlib.colors.BoundaryNorm(bounds, red_blue.N)

    # the actual scatter plot, TODO: adapt the dot_size to the number of data points?
    if interaction_index is not None:
        # a norm already carries its own limits; newer matplotlib rejects norm plus vmin/vmax
        if color_norm is None:
            color_kwargs = dict(vmin=clow, vmax=chigh)
        else:
            color_kwargs = dict(norm=color_norm)
        pl.scatter(xv, s, s=dot_size, linewidth=0, c=cv_float, cmap=red_blue,
                   alpha=alpha, rasterized=len(xv) > 500, **color_kwargs)
    else:
        pl.scatter(xv, s, s=dot_size, linewidth=0, color=color,
                   alpha=alpha, rasterized=len(xv) > 500)

    if interaction_index != ind and interaction_index is not None:
        # draw the color bar
        if color_has_text_labels:
            tick_positions = [cname_map[n] for n in cnames]
            if len(tick_positions) == 2:
                tick_positions[0] -= 0.25
                tick_positions[1] += 0.25
            cb = pl.colorbar(ticks=tick_positions)
            cb.set_ticklabels(cnames)
        else:
            cb = pl.colorbar()

        cb.set_label(feature_names[interaction_index], size=13)
        cb.ax.tick_params(labelsize=11)
        if categorical_interaction:
            cb.ax.tick_params(length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
        cb.ax.set_aspect((bbox.height - 0.7) * 20)

    # make the plot more readable
    if interaction_index != ind:
        pl.gcf().set_size_inches(7.5, 5)
    else:
        pl.gcf().set_size_inches(6, 5)
    pl.xlabel(name, color=axis_color, fontsize=13)
    pl.ylabel(labels['VALUE_FOR'] % name, color=axis_color, fontsize=13)
    if title is not None:
        pl.title(title, color=axis_color, fontsize=13)
    ax = pl.gca()
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.tick_params(color=axis_color, labelcolor=axis_color, labelsize=11)
    for spine in ax.spines.values():
        spine.set_edgecolor(axis_color)
    if x_has_text_labels:
        pl.xticks([name_map[n] for n in xnames], xnames, rotation='vertical', fontsize=11)
    return mpl_fig, interaction_index
|