Spaces:
Sleeping
Sleeping
import warnings | |
import iml | |
import numpy as np | |
from iml import Instance, Model | |
from iml.datatypes import DenseData | |
from iml.explanations import AdditiveExplanation | |
from iml.links import IdentityLink | |
from scipy.stats import gaussian_kde | |
import matplotlib | |
try:
    import matplotlib.pyplot as pl
    from matplotlib.colors import LinearSegmentedColormap
    from matplotlib.ticker import MaxNLocator

    # Shared RGB ramp: blue #1E88E5 at 0.0 -> red #F52757 at 1.0.
    _red_blue_rgb = {
        'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
                (1.0, 0.9607843137254902, 0.9607843137254902)),
        'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
                  (1.0, 0.15294117647058825, 0.15294117647058825)),
        'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
                 (1.0, 0.3411764705882353, 0.3411764705882353)),
    }

    # Translucent variant: opacity dips to 0.3 at the middle of the range, so values near the
    # middle of the color scale are drawn more transparently.
    cdict1 = dict(_red_blue_rgb, alpha=((0.0, 1, 1), (0.5, 0.3, 0.3), (1.0, 1, 1)))
    red_blue = LinearSegmentedColormap('RedBlue', cdict1)

    # Fully opaque variant of the same ramp.
    cdict1 = dict(_red_blue_rgb, alpha=((0.0, 1, 1), (0.5, 1, 1), (1.0, 1, 1)))
    red_blue_solid = LinearSegmentedColormap('RedBlue', cdict1)
except ImportError as err:
    # The plotting functions need `pl`, `red_blue` and `red_blue_solid`; surface the failure
    # now instead of letting it turn into an unexplained NameError at plot time.
    warnings.warn("matplotlib.pyplot could not be imported (%s); the SHAP plotting functions "
                  "will not work." % err)
# Text templates for axis labels, titles and color-bar captions. Entries containing "%s" are
# %-format templates that take one (or, for INTERACTION_EFFECT, two) feature names.
labels = dict(
    MAIN_EFFECT="SHAP main effect value for\n%s",
    INTERACTION_VALUE="SHAP interaction value",
    INTERACTION_EFFECT="SHAP interaction value for\n%s and %s",
    VALUE="SHAP value (impact on model output)",
    VALUE_FOR="SHAP value for\n%s",
    PLOT_FOR="SHAP plot for %s",
    FEATURE="Feature %s",
    FEATURE_VALUE="Feature value",
    FEATURE_VALUE_LOW="Low",
    FEATURE_VALUE_HIGH="High",
    JOINT_VALUE="Joint SHAP value",
)
def shap_summary_plot(shap_values, features=None, feature_names=None, max_display=None, plot_type="dot",
                      color=None, axis_color="#333333", title=None, alpha=1, show=True, sort=True,
                      color_bar=True, auto_size_plot=True, layered_violin_max_num_bins=20):
    """Create a SHAP summary plot, colored by feature values when they are provided.

    A new matplotlib figure is created, drawn into and returned; it is not shown.

    Parameters
    ----------
    shap_values : numpy.array
        Matrix of SHAP values (# samples x # columns), or a 3-D array of SHAP interaction
        values (# samples x # columns x # columns). The last column is never plotted as a
        feature: it is excluded from the ordering and from the default feature names.
    features : numpy.array or pandas.DataFrame or list
        Matrix of feature values (# samples x # features) or a feature_names list as shorthand
    feature_names : list
        Names of the features (length # features)
    max_display : int
        How many top features to include in the plot (default is 20, or 7 for interaction plots)
    plot_type : "dot" (default), "violin" or "layered_violin"
        What type of summary plot to produce
    color : str
        Dot/violin color, or a colormap name for "layered_violin". Defaults to "coolwarm" for
        "layered_violin" and "#ff0052" otherwise.
    axis_color : str
        Color of tick marks and tick labels.
    title, show
        Accepted for API compatibility; currently unused.
    alpha : float
        Opacity of the plotted points / violins.
    sort : bool
        Order features by their total absolute SHAP value (otherwise keep column order).
    color_bar : bool
        Draw a "Low"/"High" feature-value color bar when ``features`` are given.
    auto_size_plot : bool
        Resize the figure height to the number of displayed features.
    layered_violin_max_num_bins : int
        Maximum number of feature-value layers per row for "layered_violin".

    Returns
    -------
    matplotlib.figure.Figure
    """
    assert len(shap_values.shape) != 1, "Summary plots need a matrix of shap_values, not a vector."

    # default color:
    if color is None:
        color = "coolwarm" if plot_type == 'layered_violin' else "#ff0052"

    # convert from a DataFrame or other types
    if str(type(features)).endswith("'pandas.core.frame.DataFrame'>"):
        if feature_names is None:
            feature_names = features.columns
        features = features.values
    elif isinstance(features, list):
        # a plain list is shorthand for the feature names
        if feature_names is None:
            feature_names = features
        features = None
    elif (features is not None) and len(features.shape) == 1 and feature_names is None:
        feature_names = features
        features = None
    if feature_names is None:
        feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1] - 1)]

    # plotting SHAP interaction values
    if len(shap_values.shape) == 3:
        return _summary_interaction_plot(shap_values, features, feature_names, max_display)

    if max_display is None:
        max_display = 20
    # The figure size depends on max_display, so it is only created once that has a value.
    mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))
    _summary_plot_axes(shap_values, features, feature_names, max_display, plot_type, color,
                       axis_color, alpha, sort, color_bar, auto_size_plot,
                       layered_violin_max_num_bins)
    return mpl_fig


def _shorten_text(text, length_limit):
    """Return ``text`` as a string, cut to ``length_limit`` characters (ending in "...") if longer."""
    text = str(text)
    if len(text) > length_limit:
        return text[:length_limit - 3] + "..."
    return text


def _summary_color_range(values):
    """Return a ``(vmin, vmax)`` color range for ``values`` that avoids collapsing to one point.

    Tries the 5th/95th percentiles, then the 1st/99th, then the plain min/max.
    """
    vmin = np.nanpercentile(values, 5)
    vmax = np.nanpercentile(values, 95)
    if vmin == vmax:
        vmin = np.nanpercentile(values, 1)
        vmax = np.nanpercentile(values, 99)
    if vmin == vmax:
        vmin = np.min(values)
        vmax = np.max(values)
    return vmin, vmax


def _summary_interaction_plot(shap_values, features, feature_names, max_display):
    """Draw 3-D SHAP interaction values as side-by-side dot-summary panels, one per top feature."""
    # rank features by the summed magnitude of their row-summed interaction effects
    sort_inds = np.argsort(-np.abs(shap_values[:, :-1, :-1].sum(1)).sum(0))
    if max_display is None:
        max_display = 7
    # never ask for more panels than there are features
    max_display = min(len(sort_inds), max_display)

    # symmetric x-limits taken from the extreme (delta / 100 - delta) percentiles of all values
    delta = 1.0 / (shap_values.shape[1] ** 2)
    slow = np.nanpercentile(shap_values, delta)
    shigh = np.nanpercentile(shap_values, 100 - delta)
    v = max(abs(slow), abs(shigh))
    slow, shigh = -v, v

    mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))
    # column order used by every panel: features by rank, then the trailing last column
    col_order = np.hstack((sort_inds, len(sort_inds)))
    sorted_features = None if features is None else features[:, sort_inds]
    sorted_names = [feature_names[k] for k in sort_inds]
    blank_names = ["" for _ in range(len(feature_names))]
    title_length_limit = 11
    for i in range(max_display):
        ind = sort_inds[i]
        pl.subplot(1, max_display, i + 1)
        # off-diagonal effects are split in half between (j, k) and (k, j): double everything,
        # then undo the doubling for the diagonal (main effect) column of this panel
        proj_shap_values = shap_values[:, ind, col_order] * 2.0
        proj_shap_values[:, i] /= 2
        _summary_plot_axes(
            proj_shap_values, sorted_features,
            sorted_names if i == 0 else blank_names,  # only the left-most panel labels the rows
            max_display, plot_type="dot", color="#ff0052", axis_color="#333333", alpha=1,
            sort=False, color_bar=False, auto_size_plot=False, layered_violin_max_num_bins=20
        )
        pl.xlim((slow, shigh))
        pl.xlabel("")
        if i == max_display // 2:
            pl.xlabel(labels['INTERACTION_VALUE'])
        pl.title(_shorten_text(feature_names[ind], title_length_limit))
    pl.tight_layout(pad=0, w_pad=0, h_pad=0.0)
    pl.subplots_adjust(hspace=0, wspace=0.1)
    return mpl_fig


def _summary_plot_axes(shap_values, features, feature_names, max_display, plot_type, color,
                       axis_color, alpha, sort, color_bar, auto_size_plot,
                       layered_violin_max_num_bins):
    """Draw a 2-D summary plot (one row per feature) into the current matplotlib axes."""
    if sort:
        # order features by the sum of their effect magnitudes (last column excluded)
        feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
        # keep the top max_display; slicing by start index makes max_display=0 keep nothing
        feature_order = feature_order[len(feature_order) - min(max_display, len(feature_order)):]
    else:
        feature_order = np.flip(np.arange(min(max_display, shap_values.shape[1] - 1)), 0)

    row_height = 0.4
    if auto_size_plot:
        pl.gcf().set_size_inches(8, len(feature_order) * row_height + 1.5)
    pl.axvline(x=0, color="#999999", zorder=-1)

    if plot_type == "dot":
        _summary_dot_rows(shap_values, features, feature_order, row_height, color, alpha)
    elif plot_type == "violin":
        _summary_violin_rows(shap_values, features, feature_names, feature_order, color, alpha)
    elif plot_type == "layered_violin":  # courtesy of @kodonnell
        _summary_layered_violin_rows(shap_values, features, feature_names, feature_order, color,
                                     layered_violin_max_num_bins)

    # draw the color bar
    if color_bar and features is not None and (plot_type != "layered_violin" or color in pl.cm.datad):
        import matplotlib.cm as cm
        cmap = red_blue_solid if plot_type != "layered_violin" else pl.get_cmap(color)
        m = cm.ScalarMappable(cmap=cmap)
        m.set_array([0, 1])
        # A free-standing ScalarMappable has no Axes, so recent matplotlib needs `ax` explicitly.
        cb = pl.colorbar(m, ax=pl.gca(), ticks=[0, 1], aspect=1000)
        cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
        cb.set_label(labels['FEATURE_VALUE'], size=12, labelpad=0)
        cb.ax.tick_params(labelsize=11, length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
        cb.ax.set_aspect((bbox.height - 0.9) * 20)

    ax = pl.gca()
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('none')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.tick_params(color=axis_color, labelcolor=axis_color)
    pl.yticks(range(len(feature_order)), [feature_names[i] for i in feature_order], fontsize=13)
    ax.tick_params('y', length=20, width=0.5, which='major')
    ax.tick_params('x', labelsize=11)
    pl.ylim(-1, len(feature_order))
    pl.xlabel(labels['VALUE'], fontsize=13)
    pl.tight_layout()


def _summary_dot_rows(shap_values, features, feature_order, row_height, color, alpha):
    """Draw one jittered row of dots per feature, colored by feature value when available."""
    if features is not None:
        assert features.shape[0] == shap_values.shape[0], \
            "Feature and SHAP matrices must have the same number of rows!"
    for pos, i in enumerate(feature_order):
        pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
        shaps = shap_values[:, i]
        values = None if features is None else features[:, i]
        # randomize the drawing order (presumably so overlapping dots are not stacked by row order)
        inds = np.arange(len(shaps))
        np.random.shuffle(inds)
        if values is not None:
            values = values[inds]
        shaps = shaps[inds]
        colored_feature = True
        try:
            values = np.array(values, dtype=np.float64)  # make sure this can be numeric
        except (TypeError, ValueError):
            colored_feature = False

        # Jitter: put each dot in one of 100 SHAP-value bins and stack dots sharing a bin at
        # vertical offsets 0, +1, -1, +2, -2, ... (scaled below to fit inside the row).
        n_samples = len(shaps)
        nbins = 100
        quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
        inds = np.argsort(quant + np.random.randn(n_samples) * 1e-6)
        layer = 0
        last_bin = -1
        ys = np.zeros(n_samples)
        for ind in inds:
            if quant[ind] != last_bin:
                layer = 0
            ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
            layer += 1
            last_bin = quant[ind]
        ys *= 0.9 * (row_height / np.max(ys + 1))

        rasterize = len(shaps) > 500
        if features is not None and colored_feature:
            # trim the color range, but prevent the color range from collapsing
            vmin, vmax = _summary_color_range(values)
            nan_mask = np.isnan(values)
            # NaN feature values cannot be colormapped, so they are drawn in gray
            pl.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777", s=16,
                       alpha=alpha, linewidth=0, zorder=3, rasterized=rasterize)
            pl.scatter(shaps[~nan_mask], pos + ys[~nan_mask], cmap=red_blue, vmin=vmin, vmax=vmax,
                       s=16, c=values[~nan_mask], alpha=alpha, linewidth=0, zorder=3,
                       rasterized=rasterize)
        else:
            pl.scatter(shaps, pos + ys, s=16, alpha=alpha, linewidth=0, zorder=3,
                       color=color if colored_feature else "#777777", rasterized=rasterize)


def _summary_smooth_values(shaps, values, xs, window=20):
    """Running mean of ``values`` (ordered by ``shaps``) evaluated at each grid point ``xs[:-1]``.

    Each grid point averages the feature values of the (at most ``window``) most recent samples
    whose SHAP value is <= that grid point. Leading grid points reached before any sample are
    back-filled with the first available average.
    """
    smooth_values = np.zeros(len(xs) - 1)
    sort_inds = np.argsort(shaps)
    trailing_pos = 0
    leading_pos = 0
    running_sum = 0
    back_fill = 0
    for j in range(len(xs) - 1):
        while leading_pos < len(shaps) and xs[j] >= shaps[sort_inds[leading_pos]]:
            running_sum += values[sort_inds[leading_pos]]
            leading_pos += 1
            if leading_pos - trailing_pos > window:
                running_sum -= values[sort_inds[trailing_pos]]
                trailing_pos += 1
        if leading_pos - trailing_pos > 0:
            smooth_values[j] = running_sum / (leading_pos - trailing_pos)
            for k in range(back_fill):
                smooth_values[j - k - 1] = smooth_values[j]
            # Reset so later grid points do not overwrite (and shift) already-computed averages.
            back_fill = 0
        else:
            back_fill += 1
    return smooth_values


def _summary_violin_rows(shap_values, features, feature_names, feature_order, color, alpha):
    """Draw one violin per feature; with ``features``, color it by the smoothed feature value."""
    for pos in range(len(feature_order)):
        pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

    if features is None:
        parts = pl.violinplot(shap_values[:, feature_order], range(len(feature_order)), points=200,
                              vert=False, widths=0.7,
                              showmeans=False, showextrema=False, showmedians=False)
        for pc in parts['bodies']:
            pc.set_facecolor(color)
            pc.set_edgecolor('none')
            pc.set_alpha(alpha)
        return

    global_low = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 1)
    global_high = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 99)
    for pos, i in enumerate(feature_order):
        shaps = shap_values[:, i]
        shap_min, shap_max = np.min(shaps), np.max(shaps)
        spread = shap_max - shap_min
        xs = np.linspace(shap_min - spread * 0.2, shap_max + spread * 0.2, 100)
        # nearly-constant SHAP values get a little noise (presumably so gaussian_kde does not fail)
        if np.std(shaps) < (global_high - global_low) / 100:
            ds = gaussian_kde(shaps + np.random.randn(len(shaps)) * (global_high - global_low) / 100)(xs)
        else:
            ds = gaussian_kde(shaps)(xs)
        ds /= np.max(ds) * 3  # violin half-width in row units, at most 1/3

        values = features[:, i]
        smooth_values = _summary_smooth_values(shaps, values, xs)
        vmin, vmax = _summary_color_range(values)
        pl.scatter(shaps, np.ones(shap_values.shape[0]) * pos, s=9, cmap=red_blue_solid, vmin=vmin,
                   vmax=vmax, c=values, alpha=alpha, linewidth=0, zorder=1)
        # map the smoothed values onto the colormap's [0, 1] input range
        smooth_values -= vmin
        if vmax - vmin > 0:
            smooth_values /= vmax - vmin
        for k in range(len(xs) - 1):
            if ds[k] > 0.05 or ds[k + 1] > 0.05:
                pl.fill_between([xs[k], xs[k + 1]], [pos + ds[k], pos + ds[k + 1]],
                                [pos - ds[k], pos - ds[k + 1]],
                                color=red_blue_solid(smooth_values[k]), zorder=2)


def _summary_layered_violin_rows(shap_values, features, feature_names, feature_order, color,
                                 max_num_bins):
    """Draw each feature as stacked KDE layers, one layer per feature-value bin (or unique value)."""
    num_x_points = 200
    # the indices of the feature data corresponding to each equal-count bin
    bins = np.linspace(0, features.shape[0], max_num_bins + 1).round(0).astype('int')
    shap_min, shap_max = np.min(shap_values[:, :-1]), np.max(shap_values[:, :-1])
    x_points = np.linspace(shap_min, shap_max, num_x_points)
    # a colormap name colors the layers from the map; any other value is used as a single color
    use_cmap = color in pl.cm.datad
    for pos, ind in enumerate(feature_order):
        # if #unique <= max_num_bins, split by unique value (bin edges aligned with the cumulative
        # unique counts); otherwise use the equal-count bins.
        feature = features[:, ind]
        unique, counts = np.unique(feature, return_counts=True)
        if unique.shape[0] <= max_num_bins:
            thesebins = np.insert(np.cumsum(counts), 0, 0)  # np.unique returns sorted values
        else:
            thesebins = bins
        nbins = thesebins.shape[0] - 1
        # order the feature data so we can apply percentiling
        order = np.argsort(feature)
        ys = np.zeros((nbins, num_x_points))
        for i in range(nbins):
            # get shap values in this bin:
            shaps = shap_values[order[thesebins[i]:thesebins[i + 1]], ind]
            if shaps.shape[0] == 1:
                warnings.warn(
                    "not enough data in bin #%d for feature %s, so it'll be ignored. Try increasing the number of records to plot."
                    % (i, feature_names[ind]))
                # Leave this layer's density at zero: after the cumulative sum below it then adds
                # no area between its neighbours.
                continue
            # small gaussian noise avoids singular matrix errors in the KDE
            ys[i, :] = gaussian_kde(shaps + np.random.normal(loc=0, scale=0.001, size=shaps.shape[0]))(x_points)
            # Scale by the bin's share of the rows so e.g. a 1% category is drawn much thinner than
            # a 99% one (equal-count bins are left essentially unchanged).
            size = thesebins[i + 1] - thesebins[i]
            bin_size_if_even = features.shape[0] / nbins
            ys[i, :] *= size / bin_size_if_even
        # Stack the layers, then paint the outermost first and each inner layer over it, so no
        # whitespace is left between strips.
        ys = np.cumsum(ys, axis=0)
        width = 0.8
        scale = ys.max() * 2 / width  # 2 is here as we plot both sides of x axis
        for i in range(nbins - 1, -1, -1):
            y = ys[i, :] / scale
            # max(..., 1) avoids dividing by zero when a feature has a single layer
            c = pl.get_cmap(color)(i / max(nbins - 1, 1)) if use_cmap else color
            pl.fill_between(x_points, pos - y, pos + y, facecolor=c)
    pl.xlim(shap_min, shap_max)
def approx_interactions(index, shap_values, X):
    """Order other features by how much interaction they seem to have with the feature at ``index``.

    This just bins the SHAP values for a feature along that feature's value. For true Shapley
    interaction index values for SHAP see the interaction_contribs option implemented in XGBoost.

    Rows are sorted by feature ``index`` and cut into consecutive windows of ``inc`` rows (about
    a tenth of the rows, capped at 50, at least 1). For every other feature, the absolute
    correlation between the SHAP values of ``index`` and that feature's values is summed over
    the windows. More than 10000 rows are first randomly subsampled to 10000 (using NumPy's
    global random state).

    Parameters
    ----------
    index : int
        Column whose SHAP values are analysed.
    shap_values : numpy.array
        SHAP values, indexed as ``shap_values[row, index]``.
    X : numpy.array
        Feature matrix (# samples x # features); every column must be castable to float.

    Returns
    -------
    numpy.array
        Feature indices from strongest to weakest apparent interaction. ``index`` itself and
        columns whose absolute values sum to less than 1e-8 score 0.
    """
    if X.shape[0] > 10000:
        a = np.arange(X.shape[0])
        np.random.shuffle(a)
        inds = a[:10000]
    else:
        inds = np.arange(X.shape[0])

    x = X[inds, index]
    srt = np.argsort(x)
    shap_ref = shap_values[inds, index][srt]
    inc = max(min(int(len(x) / 10.0), 50), 1)

    interactions = []
    for i in range(X.shape[1]):
        # np.float was removed in NumPy 1.24; np.float64 is the equivalent dtype.
        val_other = X[inds, i][srt].astype(np.float64)
        v = 0.0
        if not (i == index or np.sum(np.abs(val_other)) < 1e-8):
            for j in range(0, len(x), inc):
                if np.std(val_other[j:j + inc]) > 0 and np.std(shap_ref[j:j + inc]) > 0:
                    v += abs(np.corrcoef(shap_ref[j:j + inc], val_other[j:j + inc])[0, 1])
        interactions.append(v)
    return np.argsort(-np.abs(interactions))
def shap_dependence_plot(ind, shap_values, features, feature_names=None, display_features=None,
                         interaction_index="auto", color="#1E88E5", axis_color="#333333",
                         dot_size=16, alpha=1, title=None, show=True):
    """
    Create a SHAP dependence plot, colored by an interaction feature.

    Plots the SHAP values of one feature against that feature's values, drawing into the
    current matplotlib figure (``pl.gcf()``).

    Parameters
    ----------
    ind : int, str, or pair of int/str
        Index (or name) of the feature to plot. When ``shap_values`` is a 3-D array of SHAP
        interaction values, a pair ``(a, b)`` plots the interaction effect of ``a`` and ``b``
        (the main effect when ``a == b``).
    shap_values : numpy.array
        Matrix of SHAP values (# samples x # features), a vector for a single feature,
        or a 3-D array of SHAP interaction values.
    features : numpy.array or pandas.DataFrame
        Matrix of feature values (# samples x # features), or a vector for a single feature
    feature_names : list
        Names of the features (length # features)
    display_features : numpy.array or pandas.DataFrame
        Matrix of feature values for visual display (such as strings instead of coded values)
    interaction_index : "auto", None, int, or str
        The index (or name) of the feature used to color the plot. "auto" picks the first
        feature returned by ``approx_interactions``; None draws uncolored dots.
    color : str
        Dot color used when no interaction feature colors the plot.
    axis_color : str
        Color of labels, ticks and spines.
    dot_size : number
        Marker size passed to ``pl.scatter``.
    alpha : float
        Marker opacity.
    title : str or None
        Optional plot title.
    show : bool
        Accepted for API compatibility; currently unused.

    Returns
    -------
    (matplotlib.figure.Figure, interaction index)
        The figure drawn into and the interaction index actually used (None if uncolored).
    """
    # convert from DataFrames if we got any
    if str(type(features)).endswith("'pandas.core.frame.DataFrame'>"):
        if feature_names is None:
            feature_names = features.columns
        features = features.values
    if str(type(display_features)).endswith("'pandas.core.frame.DataFrame'>"):
        if feature_names is None:
            feature_names = display_features.columns
        display_features = display_features.values

    # allow vectors to be passed: a single feature becomes a one-column matrix
    if len(shap_values.shape) == 1:
        shap_values = np.reshape(shap_values, (len(shap_values), 1))
    if len(features.shape) == 1:
        features = np.reshape(features, (len(features), 1))
    if display_features is None:
        display_features = features
    elif len(display_features.shape) == 1:
        display_features = np.reshape(display_features, (len(display_features), 1))

    if feature_names is None:
        feature_names = [labels['FEATURE'] % str(i) for i in range(features.shape[1])]
    elif isinstance(feature_names, str):
        # allow a single feature name to be passed alone
        feature_names = [feature_names]

    def convert_name(name):
        """Map a feature name to its column index; indices (and index pairs) pass through."""
        if isinstance(name, str):
            matches = np.where(np.asarray(feature_names) == name)[0]
            if len(matches) == 0:
                print("Could not find feature named: " + name)
                return None
            return matches[0]
        return name

    ind = convert_name(ind)
    mpl_fig = pl.gcf()

    # plotting SHAP interaction values for a pair of features
    if len(shap_values.shape) == 3 and hasattr(ind, "__len__") and len(ind) == 2:
        ind1 = convert_name(ind[0])
        ind2 = convert_name(ind[1])
        if ind1 == ind2:
            proj_shap_values = shap_values[:, ind2, :]
        else:
            proj_shap_values = shap_values[:, ind2, :] * 2  # off-diag values are split in half
        # TODO: remove recursion; generally the functions should be shorter for more maintainable code
        result = shap_dependence_plot(
            ind1, proj_shap_values, features, feature_names=feature_names,
            interaction_index=ind2, display_features=display_features, color=color,
            axis_color=axis_color, dot_size=dot_size, alpha=alpha, title=title, show=False
        )
        # replace the generic y label with a main-effect / interaction-effect label
        if ind1 == ind2:
            pl.ylabel(labels['MAIN_EFFECT'] % feature_names[ind1], color=axis_color, fontsize=13)
        else:
            pl.ylabel(labels['INTERACTION_EFFECT'] % (feature_names[ind1], feature_names[ind2]),
                      color=axis_color, fontsize=13)
        return result

    assert shap_values.shape[0] == features.shape[0], \
        "'shap_values' and 'features' values must have the same number of rows!"

    # get both the raw and display feature values
    xv = features[:, ind]
    xd = display_features[:, ind]
    s = shap_values[:, ind]
    x_is_categorical = isinstance(xd[0], str)
    if x_is_categorical:
        # display label -> raw value (first-seen label order kept, later rows win on value)
        name_map = dict(zip(xd, xv))
        xnames = list(name_map.keys())
    name = feature_names[ind]

    # guess what other feature has the strongest interaction with the plotted feature
    if isinstance(interaction_index, str) and interaction_index == "auto":
        interaction_index = approx_interactions(ind, shap_values, features)[0]
    interaction_index = convert_name(interaction_index)

    # get both the raw and display color values
    categorical_interaction = False
    color_norm = None
    if interaction_index is not None:
        cv = features[:, interaction_index]
        cd = display_features[:, interaction_index]
        cv_numeric = cv.astype(np.float64)  # np.float was removed in NumPy 1.24
        clow = np.nanpercentile(cv_numeric, 5)
        chigh = np.nanpercentile(cv_numeric, 95)
        c_is_categorical = isinstance(cd[0], str)
        if c_is_categorical:
            cname_map = dict(zip(cd, cv))
            cnames = list(cname_map.keys())
            categorical_interaction = True
        elif clow % 1 == 0 and chigh % 1 == 0 and len(set(cv)) < 50:
            categorical_interaction = True
        # discretize colors for categorical features
        if categorical_interaction and clow != chigh:
            # np.linspace requires an integer number of samples
            bounds = np.linspace(clow, chigh, int(chigh - clow + 2))
            color_norm = matplotlib.colors.BoundaryNorm(bounds, red_blue.N)

    # the actual scatter plot, TODO: adapt the dot_size to the number of data points?
    if interaction_index is not None:
        # matplotlib rejects vmin/vmax together with an explicit norm; the BoundaryNorm already
        # spans [clow, chigh].
        if color_norm is not None:
            color_range = {"norm": color_norm}
        else:
            color_range = {"vmin": clow, "vmax": chigh}
        pl.scatter(xv, s, s=dot_size, linewidth=0, c=cv, cmap=red_blue,
                   alpha=alpha, rasterized=len(xv) > 500, **color_range)
    else:
        pl.scatter(xv, s, s=dot_size, linewidth=0, color=color,
                   alpha=alpha, rasterized=len(xv) > 500)

    if interaction_index != ind and interaction_index is not None:
        # draw the color bar
        if c_is_categorical:
            tick_positions = [cname_map[n] for n in cnames]
            if len(tick_positions) == 2:
                tick_positions[0] -= 0.25
                tick_positions[1] += 0.25
            cb = pl.colorbar(ticks=tick_positions)
            cb.set_ticklabels(cnames)
        else:
            cb = pl.colorbar()
        cb.set_label(feature_names[interaction_index], size=13)
        cb.ax.tick_params(labelsize=11)
        if categorical_interaction:
            cb.ax.tick_params(length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
        cb.ax.set_aspect((bbox.height - 0.7) * 20)

    # make the plot more readable
    if interaction_index != ind:
        pl.gcf().set_size_inches(7.5, 5)
    else:
        pl.gcf().set_size_inches(6, 5)
    pl.xlabel(name, color=axis_color, fontsize=13)
    pl.ylabel(labels['VALUE_FOR'] % name, color=axis_color, fontsize=13)
    if title is not None:
        pl.title(title, color=axis_color, fontsize=13)
    ax = pl.gca()
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.tick_params(color=axis_color, labelcolor=axis_color, labelsize=11)
    for spine in ax.spines.values():
        spine.set_edgecolor(axis_color)
    if x_is_categorical:
        pl.xticks([name_map[n] for n in xnames], xnames, rotation='vertical', fontsize=11)
    return mpl_fig, interaction_index