Spaces:
Running
Running
hlnicholls
commited on
Commit
·
8a6cf88
1
Parent(s):
15625ec
feat: updated interace
Browse files- __pycache__/dynamic_shap_plot.cpython-38.pyc +0 -0
- __pycache__/dynamic_shap_plots.cpython-38.pyc +0 -0
- __pycache__/shap_plots.cpython-38.pyc +0 -0
- app.py +132 -101
- dynamic_shap_plot.py +0 -118
- dynamic_shap_plots.py +346 -0
- requirements.txt +1 -0
- shap_plots.py +729 -729
__pycache__/dynamic_shap_plot.cpython-38.pyc
ADDED
Binary file (3.08 kB). View file
|
|
__pycache__/dynamic_shap_plots.cpython-38.pyc
ADDED
Binary file (8.14 kB). View file
|
|
__pycache__/shap_plots.cpython-38.pyc
ADDED
Binary file (16.9 kB). View file
|
|
app.py
CHANGED
@@ -7,11 +7,12 @@ import sklearn
|
|
7 |
import catboost
|
8 |
import shap
|
9 |
from shap_plots import shap_summary_plot
|
10 |
-
from
|
11 |
import plotly.tools as tls
|
12 |
-
|
13 |
-
import matplotlib
|
14 |
import plotly.graph_objs as go
|
|
|
15 |
try:
|
16 |
import matplotlib.pyplot as pl
|
17 |
from matplotlib.colors import LinearSegmentedColormap
|
@@ -21,133 +22,163 @@ except ImportError:
|
|
21 |
|
22 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
23 |
|
24 |
-
seed=
|
25 |
|
26 |
annotations = pd.read_csv("all_genes_merged_ml_data.csv")
|
27 |
-
# TODO remove this placeholder when imputation is finished:
|
28 |
annotations.fillna(0, inplace=True)
|
29 |
annotations = annotations.set_index("Gene")
|
30 |
|
31 |
-
|
32 |
-
model_path = "best_model_fitted.pkl" # Update this path if your model is stored elsewhere
|
33 |
with open(model_path, 'rb') as file:
|
34 |
catboost_model = pickle.load(file)
|
35 |
|
36 |
-
# For a multi-class classification model, obtaining probabilities per class
|
37 |
probabilities = catboost_model.predict_proba(annotations)
|
38 |
-
|
39 |
-
# Creating a DataFrame for these probabilities
|
40 |
-
# Assuming classes are ordered as 'most likely', 'probable', and 'least likely' in the model
|
41 |
-
prob_df = pd.DataFrame(probabilities,
|
42 |
-
index=annotations.index,
|
43 |
-
columns=['Probability_Most_Likely', 'Probability_Probable', 'Probability_Least_Likely'])
|
44 |
-
|
45 |
-
# Dynamically including all original features from annotations plus the new probability columns
|
46 |
df_total = pd.concat([prob_df, annotations], axis=1)
|
47 |
|
|
|
|
|
|
|
|
|
48 |
|
49 |
st.title('Blood Pressure Gene Prioritisation Post-GWAS')
|
50 |
-
st.markdown("""
|
51 |
-
A machine learning pipeline for predicting disease-causing genes post-genome-wide association study in blood pressure.
|
52 |
-
|
53 |
-
|
54 |
-
""")
|
55 |
-
|
56 |
-
collect_genes = lambda x : [str(i) for i in re.split(",|,\s+|\s+", x) if i != ""]
|
57 |
|
|
|
|
|
58 |
input_gene_list = st.text_input("Input a list of multiple HGNC genes (enter comma separated):")
|
59 |
gene_list = collect_genes(input_gene_list)
|
60 |
explainer = shap.TreeExplainer(catboost_model)
|
61 |
|
62 |
@st.cache_data
|
63 |
def convert_df(df):
|
64 |
-
|
65 |
|
66 |
probability_columns = ['Probability_Most_Likely', 'Probability_Probable', 'Probability_Least_Likely']
|
67 |
features_list = [column for column in df_total.columns if column not in probability_columns]
|
68 |
features = df_total[features_list]
|
69 |
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
"text/csv",
|
88 |
-
key='download-csv'
|
89 |
-
)
|
90 |
-
|
91 |
-
# For SHAP values, assuming explainer is already fitted to your model
|
92 |
-
df_shap = df.drop(columns=probability_columns + ['Gene']) # Exclude non-feature columns
|
93 |
-
shap_values = explainer.shap_values(df_shap)
|
94 |
-
|
95 |
-
# Handle multiclass scenario: SHAP values will be a list of matrices, one per class
|
96 |
-
# Plotting the summary plot for the first class as an example
|
97 |
-
# You may loop through each class or handle it differently based on your needs
|
98 |
-
class_index = 0 # Example: plotting for the first class
|
99 |
-
shap.summary_plot(shap_values[class_index], df_shap, show=False)
|
100 |
-
st.pyplot(bbox_inches='tight')
|
101 |
-
st.caption("SHAP Summary Plot of All Input Genes")
|
102 |
-
|
103 |
-
else:
|
104 |
-
pass
|
105 |
|
|
|
|
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
-
if input_gene:
|
120 |
-
if ' ' in input_gene or ',' in input_gene:
|
121 |
-
st.write('Input Error: Please input only a single HGNC gene name with no white spaces or commas.')
|
122 |
else:
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
-
|
141 |
-
|
142 |
-
""
|
143 |
-
|
144 |
-
|
145 |
-
df_total_output['Gene'] = df_total_output.index
|
146 |
-
df_total_output.reset_index(drop=True, inplace=True)
|
147 |
-
#df_total_output = df_total_output[['Gene','XGB_Score', 'mousescore_Exomiser',
|
148 |
-
# 'SDI', 'Liver_GTExTPM', 'pLI_ExAC',
|
149 |
-
# 'HIPred',
|
150 |
-
# 'Cells - EBV-transformed lymphocytes_GTExTPM',
|
151 |
-
# 'Pituitary_GTExTPM',
|
152 |
-
# 'IPA_BP_annotation']]
|
153 |
-
st.dataframe(df_total_output)
|
|
|
7 |
import catboost
|
8 |
import shap
|
9 |
from shap_plots import shap_summary_plot
|
10 |
+
from dynamic_shap_plots import matplotlib_to_plotly, summary_plot_plotly_fig
|
11 |
import plotly.tools as tls
|
12 |
+
from dash import dcc
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
import plotly.graph_objs as go
|
15 |
+
|
16 |
try:
|
17 |
import matplotlib.pyplot as pl
|
18 |
from matplotlib.colors import LinearSegmentedColormap
|
|
|
22 |
|
23 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
24 |
|
25 |
+
seed = 0
|
26 |
|
27 |
annotations = pd.read_csv("all_genes_merged_ml_data.csv")
|
|
|
28 |
annotations.fillna(0, inplace=True)
|
29 |
annotations = annotations.set_index("Gene")
|
30 |
|
31 |
+
model_path = "best_model_fitted.pkl"
|
|
|
32 |
with open(model_path, 'rb') as file:
|
33 |
catboost_model = pickle.load(file)
|
34 |
|
|
|
35 |
probabilities = catboost_model.predict_proba(annotations)
|
36 |
+
prob_df = pd.DataFrame(probabilities, index=annotations.index, columns=['Probability_Most_Likely', 'Probability_Probable', 'Probability_Least_Likely'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
df_total = pd.concat([prob_df, annotations], axis=1)
|
38 |
|
39 |
+
# Create tabs for navigation
|
40 |
+
with st.sidebar:
|
41 |
+
st.sidebar.title("Navigation")
|
42 |
+
tab = st.sidebar.radio("Go to", ("Gene Prioritisation", "Interactive SHAP Plot", "Supervised SHAP Clustering"))
|
43 |
|
44 |
st.title('Blood Pressure Gene Prioritisation Post-GWAS')
|
45 |
+
st.markdown("""A machine learning pipeline for predicting disease-causing genes post-genome-wide association study in blood pressure.""")
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
+
# Define a function to collect genes from input
|
48 |
+
collect_genes = lambda x: [str(i) for i in re.split(",|,\s+|\s+", x) if i != ""]
|
49 |
input_gene_list = st.text_input("Input a list of multiple HGNC genes (enter comma separated):")
|
50 |
gene_list = collect_genes(input_gene_list)
|
51 |
explainer = shap.TreeExplainer(catboost_model)
|
52 |
|
53 |
@st.cache_data
|
54 |
def convert_df(df):
|
55 |
+
return df.to_csv(index=False).encode('utf-8')
|
56 |
|
57 |
probability_columns = ['Probability_Most_Likely', 'Probability_Probable', 'Probability_Least_Likely']
|
58 |
features_list = [column for column in df_total.columns if column not in probability_columns]
|
59 |
features = df_total[features_list]
|
60 |
|
61 |
+
# Page 1: Gene Prioritisation
|
62 |
+
if tab == "Gene Prioritisation":
|
63 |
+
if len(gene_list) > 1:
|
64 |
+
df = df_total[df_total.index.isin(gene_list)]
|
65 |
+
df['Gene'] = df.index
|
66 |
+
df.reset_index(drop=True, inplace=True)
|
67 |
+
|
68 |
+
required_columns = ['Gene'] + probability_columns + [column for column in df.columns if column not in probability_columns and column != 'Gene']
|
69 |
+
df = df[required_columns]
|
70 |
+
st.dataframe(df)
|
71 |
+
|
72 |
+
output = df[['Gene'] + probability_columns]
|
73 |
+
csv = convert_df(output)
|
74 |
+
st.download_button("Download Gene Prioritisation", csv, "bp_gene_prioritisation.csv", "text/csv", key='download-csv')
|
75 |
+
|
76 |
+
df_shap = df.drop(columns=probability_columns + ['Gene'])
|
77 |
+
shap_values = explainer.shap_values(df_shap)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
+
col1, col2 = st.columns(2)
|
80 |
+
class_names = ["Most likely", "Probable", "Least likely"]
|
81 |
|
82 |
+
with col1:
|
83 |
+
st.subheader("Global SHAP Summary Plot")
|
84 |
+
shap.summary_plot(shap_values, df_shap, plot_type="bar", class_names=class_names)
|
85 |
+
st.pyplot(bbox_inches='tight', clear_figure=True)
|
86 |
|
87 |
+
with col2:
|
88 |
+
st.subheader(f"{class_names[0]} Gene Prediction")
|
89 |
+
shap.summary_plot(shap_values[0], df_shap)
|
90 |
+
st.pyplot(bbox_inches='tight', clear_figure=True)
|
91 |
+
|
92 |
+
col3, col4 = st.columns(2)
|
93 |
+
|
94 |
+
with col3:
|
95 |
+
st.subheader(f"{class_names[1]} Gene Prediction")
|
96 |
+
shap.summary_plot(shap_values[1], df_shap)
|
97 |
+
st.pyplot(bbox_inches='tight', clear_figure=True)
|
98 |
+
|
99 |
+
with col4:
|
100 |
+
st.subheader(f"{class_names[2]} Gene Prediction")
|
101 |
+
shap.summary_plot(shap_values[2], df_shap)
|
102 |
+
st.pyplot(bbox_inches='tight', clear_figure=True)
|
103 |
|
|
|
|
|
|
|
104 |
else:
|
105 |
+
pass
|
106 |
+
|
107 |
+
input_gene = st.text_input("Input an individual HGNC gene:")
|
108 |
+
if input_gene:
|
109 |
+
df2 = df_total[df_total.index == input_gene]
|
110 |
+
class_names = ["Most likely", "Probable", "Least likely"]
|
111 |
+
if not df2.empty:
|
112 |
+
df2['Gene'] = df2.index
|
113 |
+
df2.reset_index(drop=True, inplace=True)
|
114 |
|
115 |
+
required_columns = ['Gene'] + probability_columns + [col for col in df2.columns if col not in probability_columns and col != 'Gene']
|
116 |
+
df2 = df2[required_columns]
|
117 |
+
st.dataframe(df2)
|
118 |
+
|
119 |
+
if ' ' in input_gene or ',' in input_gene:
|
120 |
+
st.write('Input Error: Please input only a single HGNC gene name with no white spaces or commas.')
|
121 |
+
else:
|
122 |
+
df2_shap = df_total.loc[[input_gene], [col for col in df_total.columns if col not in probability_columns + ['Gene']]]
|
123 |
+
print(df2_shap.columns)
|
124 |
+
shap_values = explainer.shap_values(df2_shap)
|
125 |
+
shap.getjs()
|
126 |
+
|
127 |
+
for i in range(3):
|
128 |
+
st.subheader(f"Force Plot for {class_names[i]} Prediction")
|
129 |
+
force_plot = shap.force_plot(
|
130 |
+
explainer.expected_value[i],
|
131 |
+
shap_values[i],
|
132 |
+
df2_shap,
|
133 |
+
matplotlib=True,
|
134 |
+
show=False
|
135 |
+
)
|
136 |
+
st.pyplot(fig=force_plot)
|
137 |
+
else:
|
138 |
+
st.write("Gene not found in the dataset.")
|
139 |
+
else:
|
140 |
+
pass
|
141 |
+
|
142 |
+
st.markdown("""
|
143 |
+
### Total Gene Prioritisation Results for All Genes:
|
144 |
+
""")
|
145 |
+
|
146 |
+
df_total_output = df_total
|
147 |
+
df_total_output['Gene'] = df_total_output.index
|
148 |
+
#df_total_output.reset_index(drop=True, inplace=True)
|
149 |
+
st.dataframe(df_total_output)
|
150 |
+
csv = convert_df(df_total_output)
|
151 |
+
st.download_button("Download Gene Prioritisation", csv, "all_genes_bp_prioritisation.csv", "text/csv", key='download-all-csv')
|
152 |
+
|
153 |
+
# Page 2: Interactive SHAP Plot
|
154 |
+
|
155 |
+
elif tab == "Interactive SHAP Plot":
|
156 |
+
st.title("Interactive SHAP Plot")
|
157 |
+
if len(gene_list) > 1:
|
158 |
+
df = df_total[df_total.index.isin(gene_list)]
|
159 |
+
df['Gene'] = df.index
|
160 |
+
df.reset_index(drop=True, inplace=True)
|
161 |
+
|
162 |
+
required_columns = ['Gene'] + probability_columns + [column for column in df.columns if column not in probability_columns and column != 'Gene']
|
163 |
+
df = df[required_columns]
|
164 |
+
st.dataframe(df)
|
165 |
+
|
166 |
+
output = df[['Gene'] + probability_columns]
|
167 |
+
csv = convert_df(output)
|
168 |
+
st.download_button("Download Gene Prioritisation", csv, "bp_gene_prioritisation.csv", "text/csv", key='download-csv')
|
169 |
+
|
170 |
+
df_shap = df.drop(columns=probability_columns + ['Gene'])
|
171 |
+
shap_values = explainer.shap_values(df_shap)
|
172 |
+
|
173 |
+
# Use shap's summary_plot function for interactivity
|
174 |
+
# summary_plot = shap.summary_plot(shap_values[0], df_shap, plot_type='interactive', max_display=10)
|
175 |
+
summary_plot = summary_plot_plotly_fig(df_shap, shap_values[0], max_display=10)
|
176 |
+
st.pyplot(summary_plot)
|
177 |
+
st.caption("SHAP Summary Plot of All Input Genes")
|
178 |
+
|
179 |
|
180 |
+
# Page 3: Supervised SHAP Clustering
|
181 |
+
elif tab == "Supervised SHAP Clustering":
|
182 |
+
st.title("Supervised SHAP Clustering")
|
183 |
+
# Add your code here to implement supervised SHAP clustering
|
184 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dynamic_shap_plot.py
DELETED
@@ -1,118 +0,0 @@
|
|
1 |
-
from shap_plots import shap_summary_plot, shap_dependence_plot
|
2 |
-
import plotly.tools as tls
|
3 |
-
import dash_core_components as dcc
|
4 |
-
import pandas as pd
|
5 |
-
import numpy as np
|
6 |
-
import xgboost
|
7 |
-
import shap
|
8 |
-
import matplotlib
|
9 |
-
import plotly.graph_objs as go
|
10 |
-
try:
|
11 |
-
import matplotlib.pyplot as pl
|
12 |
-
from matplotlib.colors import LinearSegmentedColormap
|
13 |
-
from matplotlib.ticker import MaxNLocator
|
14 |
-
except ImportError:
|
15 |
-
pass
|
16 |
-
from sklearn import preprocessing
|
17 |
-
|
18 |
-
cdict1 = {
|
19 |
-
'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
|
20 |
-
(1.0, 0.9607843137254902, 0.9607843137254902)),
|
21 |
-
|
22 |
-
'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
|
23 |
-
(1.0, 0.15294117647058825, 0.15294117647058825)),
|
24 |
-
|
25 |
-
'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
|
26 |
-
(1.0, 0.3411764705882353, 0.3411764705882353)),
|
27 |
-
|
28 |
-
'alpha': ((0.0, 1, 1),
|
29 |
-
(0.5, 1, 1),
|
30 |
-
(1.0, 1, 1))
|
31 |
-
} # #1E88E5 -> #ff0052
|
32 |
-
red_blue = LinearSegmentedColormap('RedBlue', cdict1)
|
33 |
-
|
34 |
-
def matplotlib_to_plotly(cmap, pl_entries):
|
35 |
-
h = 1.0/(pl_entries-1)
|
36 |
-
pl_colorscale = []
|
37 |
-
|
38 |
-
for k in range(pl_entries):
|
39 |
-
C = list(map(np.uint8, np.array(cmap(k*h)[:3])*255))
|
40 |
-
pl_colorscale.append([k*h, 'rgb'+str((C[0], C[1], C[2]))])
|
41 |
-
|
42 |
-
return pl_colorscale
|
43 |
-
|
44 |
-
red_blue = matplotlib_to_plotly(red_blue, 255)
|
45 |
-
|
46 |
-
def summary_plot_plotly_fig(shap_values, df_shap, feature_names, max_display = 8):
|
47 |
-
#data = pd.read_csv(dataset, encoding="ISO-8859-1")
|
48 |
-
#X = data.drop(['target column'], axis=1)
|
49 |
-
|
50 |
-
#y = data[target]
|
51 |
-
#y = y/max(y)
|
52 |
-
|
53 |
-
#X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
|
54 |
-
|
55 |
-
#X_train.fillna((-999), inplace=True)
|
56 |
-
#X_test.fillna((-999), inplace=True)
|
57 |
-
|
58 |
-
#_, shap_values, feature_names = train_model_and_return_shap_values(X, y, target)
|
59 |
-
|
60 |
-
mpl_fig = shap_summary_plot(shap_values, df_shap, feature_names=feature_names, max_display=20)
|
61 |
-
|
62 |
-
plotly_fig = tls.mpl_to_plotly(mpl_fig)
|
63 |
-
|
64 |
-
plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}}
|
65 |
-
|
66 |
-
feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
|
67 |
-
feature_order = feature_order[-min(max_display, len(feature_order)):]
|
68 |
-
text = [df_shap.index[i] for i in df_shap.index]
|
69 |
-
text = iter(text)
|
70 |
-
|
71 |
-
for i in range(1, len(plotly_fig['data']), 2):
|
72 |
-
t = text.__next__()
|
73 |
-
plotly_fig['data'][i]['name'] = ''
|
74 |
-
plotly_fig['data'][i]['text'] = t
|
75 |
-
plotly_fig['data'][i]['hoverinfo'] = 'text'
|
76 |
-
#plotly_fig['data'][i]['text'] = df_shap.index
|
77 |
-
plotly_fig['data'][i]['y'] = feature_names[feature_order]
|
78 |
-
|
79 |
-
|
80 |
-
colorbar_trace = go.Scatter(x=[None],
|
81 |
-
y=[None],
|
82 |
-
mode='markers',
|
83 |
-
marker=dict(
|
84 |
-
colorscale=red_blue,
|
85 |
-
showscale=True,
|
86 |
-
cmin=-5,
|
87 |
-
cmax=5,
|
88 |
-
colorbar=dict(thickness=5, tickvals=[-5, 5], ticktext=['Low', 'High'], outlinewidth=0)
|
89 |
-
),
|
90 |
-
hoverinfo='none'
|
91 |
-
)
|
92 |
-
|
93 |
-
plotly_fig['layout']['showlegend'] = False
|
94 |
-
plotly_fig['layout']['hovermode'] = 'closest'
|
95 |
-
plotly_fig['layout']['height']=600
|
96 |
-
plotly_fig['layout']['width']=500
|
97 |
-
|
98 |
-
plotly_fig['layout']['xaxis'].update(zeroline=True, showline=True, ticklen=4, showgrid=False)
|
99 |
-
plotly_fig['layout']['yaxis'].update(dict(visible=True))
|
100 |
-
plotly_fig.add_trace(colorbar_trace)
|
101 |
-
plotly_fig.layout.update(
|
102 |
-
annotations=[dict(
|
103 |
-
x=1.18,
|
104 |
-
align="right",
|
105 |
-
valign="top",
|
106 |
-
text='Gene',
|
107 |
-
showarrow=False,
|
108 |
-
xref="paper",
|
109 |
-
yref="paper",
|
110 |
-
xanchor="right",
|
111 |
-
yanchor="middle",
|
112 |
-
textangle=-90,
|
113 |
-
font=dict(family='Calibri', size=14)
|
114 |
-
)
|
115 |
-
],
|
116 |
-
margin=dict(t=20)
|
117 |
-
)
|
118 |
-
return plotly_fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dynamic_shap_plots.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from shap_plots import shap_summary_plot, shap_dependence_plot
|
2 |
+
import plotly.tools as tls
|
3 |
+
import dash_core_components as dcc
|
4 |
+
import pandas as pd
|
5 |
+
from sklearn.model_selection import train_test_split
|
6 |
+
import numpy as np
|
7 |
+
import xgboost
|
8 |
+
import shap
|
9 |
+
import matplotlib
|
10 |
+
import plotly.graph_objs as go
|
11 |
+
try:
|
12 |
+
import matplotlib.pyplot as pl
|
13 |
+
from matplotlib.colors import LinearSegmentedColormap
|
14 |
+
from matplotlib.ticker import MaxNLocator
|
15 |
+
except ImportError:
|
16 |
+
pass
|
17 |
+
from sklearn import preprocessing
|
18 |
+
|
19 |
+
cdict1 = {
|
20 |
+
'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
|
21 |
+
(1.0, 0.9607843137254902, 0.9607843137254902)),
|
22 |
+
|
23 |
+
'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
|
24 |
+
(1.0, 0.15294117647058825, 0.15294117647058825)),
|
25 |
+
|
26 |
+
'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
|
27 |
+
(1.0, 0.3411764705882353, 0.3411764705882353)),
|
28 |
+
|
29 |
+
'alpha': ((0.0, 1, 1),
|
30 |
+
(0.5, 1, 1),
|
31 |
+
(1.0, 1, 1))
|
32 |
+
} # #1E88E5 -> #ff0052
|
33 |
+
red_blue = LinearSegmentedColormap('RedBlue', cdict1)
|
34 |
+
|
35 |
+
def matplotlib_to_plotly(cmap, pl_entries):
|
36 |
+
h = 1.0/(pl_entries-1)
|
37 |
+
pl_colorscale = []
|
38 |
+
|
39 |
+
for k in range(pl_entries):
|
40 |
+
C = list(map(np.uint8, np.array(cmap(k*h)[:3])*255))
|
41 |
+
pl_colorscale.append([k*h, 'rgb'+str((C[0], C[1], C[2]))])
|
42 |
+
|
43 |
+
return pl_colorscale
|
44 |
+
|
45 |
+
red_blue = matplotlib_to_plotly(red_blue, 255)
|
46 |
+
|
47 |
+
def summary_plot_plotly_fig(dataset, shap_values, target='target column', max_display = 20):
|
48 |
+
feature_names=dataset.columns
|
49 |
+
mpl_fig = shap_summary_plot(shap_values, dataset, feature_names=feature_names, max_display=20)
|
50 |
+
|
51 |
+
plotly_fig = tls.mpl_to_plotly(mpl_fig)
|
52 |
+
|
53 |
+
plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}}
|
54 |
+
|
55 |
+
feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
|
56 |
+
feature_order = feature_order[-min(max_display, len(feature_order)):]
|
57 |
+
text = [feature_names[i] for i in feature_order]
|
58 |
+
text = iter(text)
|
59 |
+
|
60 |
+
for i in range(1, len(plotly_fig['data']), 2):
|
61 |
+
t = text.__next__()
|
62 |
+
plotly_fig['data'][i]['name'] = ''
|
63 |
+
plotly_fig['data'][i]['text'] = t
|
64 |
+
plotly_fig['data'][i]['hoverinfo'] = 'text'
|
65 |
+
|
66 |
+
colorbar_trace = go.Scatter(x=[None],
|
67 |
+
y=[None],
|
68 |
+
mode='markers',
|
69 |
+
marker=dict(
|
70 |
+
colorscale=red_blue,
|
71 |
+
showscale=True,
|
72 |
+
cmin=-5,
|
73 |
+
cmax=5,
|
74 |
+
colorbar=dict(thickness=5, tickvals=[-5, 5], ticktext=['Low', 'High'], outlinewidth=0)
|
75 |
+
),
|
76 |
+
hoverinfo='none'
|
77 |
+
)
|
78 |
+
|
79 |
+
plotly_fig['layout']['showlegend'] = False
|
80 |
+
plotly_fig['layout']['hovermode'] = 'closest'
|
81 |
+
plotly_fig['layout']['height']=600
|
82 |
+
plotly_fig['layout']['width']=500
|
83 |
+
|
84 |
+
plotly_fig['layout']['xaxis'].update(zeroline=True, showline=True, ticklen=4, showgrid=False)
|
85 |
+
plotly_fig['layout']['yaxis'].update(dict(visible=False))
|
86 |
+
plotly_fig.add_trace(colorbar_trace)
|
87 |
+
plotly_fig.layout.update(
|
88 |
+
annotations=[dict(
|
89 |
+
x=1.18,
|
90 |
+
align="right",
|
91 |
+
valign="top",
|
92 |
+
text='Feature value',
|
93 |
+
showarrow=False,
|
94 |
+
xref="paper",
|
95 |
+
yref="paper",
|
96 |
+
xanchor="right",
|
97 |
+
yanchor="middle",
|
98 |
+
textangle=-90,
|
99 |
+
font=dict(family='Calibri', size=14)
|
100 |
+
)
|
101 |
+
],
|
102 |
+
margin=dict(t=20)
|
103 |
+
)
|
104 |
+
return plotly_fig
|
105 |
+
|
106 |
+
def train_model_and_return_shap_values(X, y, target):
|
107 |
+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
|
108 |
+
|
109 |
+
X_train.fillna((-999), inplace=True)
|
110 |
+
X_test.fillna((-999), inplace=True)
|
111 |
+
|
112 |
+
# Some of values are float or integer and some object. This is why we need to cast them:
|
113 |
+
for f in X_train.columns:
|
114 |
+
if X_train[f].dtype=='object':
|
115 |
+
lbl = preprocessing.LabelEncoder()
|
116 |
+
lbl.fit(list(X_train[f].values))
|
117 |
+
X_train[f] = lbl.transform(list(X_train[f].values))
|
118 |
+
|
119 |
+
for f in X_test.columns:
|
120 |
+
if X_test[f].dtype=='object':
|
121 |
+
lbl = preprocessing.LabelEncoder()
|
122 |
+
lbl.fit(list(X_test[f].values))
|
123 |
+
X_test[f] = lbl.transform(list(X_test[f].values))
|
124 |
+
|
125 |
+
X_train=np.array(X_train)
|
126 |
+
X_test=np.array(X_test)
|
127 |
+
X_train = X_train.astype(float)
|
128 |
+
X_test = X_test.astype(float)
|
129 |
+
|
130 |
+
d_train = xgboost.DMatrix(X_train, label=y_train, feature_names=list(X))
|
131 |
+
d_test = xgboost.DMatrix(X_test, label=y_test, feature_names=list(X))
|
132 |
+
|
133 |
+
# train the model
|
134 |
+
params = {
|
135 |
+
"eta": 0.01,
|
136 |
+
"subsample": 0.5,
|
137 |
+
"base_score": np.mean(y_train),
|
138 |
+
"silent": 1
|
139 |
+
}
|
140 |
+
|
141 |
+
model = xgboost.train(params, d_train, 5000, evals = [(d_test, "test")], verbose_eval=None, early_stopping_rounds=50)
|
142 |
+
feature_names = model.feature_names
|
143 |
+
shap_values = shap.TreeExplainer(model).shap_values(pd.DataFrame(X_train, columns=X.columns))
|
144 |
+
return model, shap_values, feature_names
|
145 |
+
|
146 |
+
def dependence_plot_to_plotly_fig(dataset, target='target column', max_display=10):
|
147 |
+
data = pd.read_csv(dataset, encoding="ISO-8859-1")
|
148 |
+
X = data.drop(['target column'], axis=1)
|
149 |
+
y = data[target]
|
150 |
+
y = y/max(y)
|
151 |
+
|
152 |
+
xgb_full = xgboost.DMatrix(X, label=y)
|
153 |
+
|
154 |
+
# create a train/test split
|
155 |
+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
|
156 |
+
xgb_train = xgboost.DMatrix(X_train, label=y_train)
|
157 |
+
xgb_test = xgboost.DMatrix(X_test, label=y_test)
|
158 |
+
|
159 |
+
# use validation set to choose # of trees
|
160 |
+
params = {
|
161 |
+
# "eta": 0.002,
|
162 |
+
# "max_depth": 3,
|
163 |
+
# "subsample": 0.5,
|
164 |
+
"silent": 1
|
165 |
+
}
|
166 |
+
model_train = xgboost.train(params, xgb_train, 3000, evals = [(xgb_test, "test")], verbose_eval=None)
|
167 |
+
|
168 |
+
# train final model on the full data set
|
169 |
+
params = {
|
170 |
+
# "eta": 0.002,
|
171 |
+
# "max_depth": 3,
|
172 |
+
# "subsample": 0.5,
|
173 |
+
"silent": 1
|
174 |
+
}
|
175 |
+
model = xgboost.train(params, xgb_full, 1500, evals = [(xgb_full, "test")], verbose_eval=None)
|
176 |
+
features = model.feature_names
|
177 |
+
shap_values = shap.TreeExplainer(model).shap_values(X)
|
178 |
+
|
179 |
+
feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
|
180 |
+
feature_order = feature_order[-min(max_display, len(feature_order)):]
|
181 |
+
features = [features[i] for i in feature_order[::-1]]
|
182 |
+
|
183 |
+
lis = []
|
184 |
+
for i in features:
|
185 |
+
mpl_fig, interaction_index = shap_dependence_plot(i, shap_values, X)
|
186 |
+
plotly_fig = tls.mpl_to_plotly(mpl_fig)
|
187 |
+
|
188 |
+
# The x-tick labels start by default from 0, which is not necessarily the min value of the feature.
|
189 |
+
# So, we need to increment the x-tick labels by 1. But while doing so, the y-axis gets shifted.
|
190 |
+
# To prevent that, we need to manually control the x-axis range from r_min to r_max
|
191 |
+
new_x = []
|
192 |
+
for j in plotly_fig['data'][0]['x']:
|
193 |
+
new_x.append(j)
|
194 |
+
|
195 |
+
r_min = min(plotly_fig['data'][0]['x'])
|
196 |
+
r_max = max(plotly_fig['data'][0]['x'])
|
197 |
+
|
198 |
+
plotly_fig['layout']['xaxis'].update(range=[r_min-1, r_max+1])
|
199 |
+
plotly_fig['data'][0]['x'] = tuple(new_x)
|
200 |
+
|
201 |
+
# Define the colorbar
|
202 |
+
colorbar_trace = go.Scatter(x=[None],
|
203 |
+
y=[None],
|
204 |
+
mode='markers',
|
205 |
+
marker=dict(
|
206 |
+
colorscale=red_blue,
|
207 |
+
showscale=True,
|
208 |
+
colorbar=dict(thickness=5, outlinewidth=0),
|
209 |
+
color=[min(X[X.columns[interaction_index]]), max(X[X.columns[interaction_index]])],
|
210 |
+
),
|
211 |
+
hoverinfo='none'
|
212 |
+
)
|
213 |
+
|
214 |
+
plotly_fig['layout']['showlegend'] = False
|
215 |
+
plotly_fig['layout']['hovermode'] = 'closest'
|
216 |
+
plotly_fig['layout']['height']=380
|
217 |
+
plotly_fig['layout']['width']=450
|
218 |
+
plotly_fig['layout']['xaxis'].update(zeroline=True,
|
219 |
+
showline=True,
|
220 |
+
ticklen=4,
|
221 |
+
showgrid=False,
|
222 |
+
tickmode='linear')
|
223 |
+
title = plotly_fig['layout']['yaxis']['title']
|
224 |
+
plotly_fig['layout']['yaxis'].update(title=title.split(' -')[0])
|
225 |
+
|
226 |
+
plotly_fig.add_trace(colorbar_trace)
|
227 |
+
plotly_fig.layout.update(
|
228 |
+
annotations=[dict(
|
229 |
+
x=1.23,
|
230 |
+
align="right",
|
231 |
+
valign="top",
|
232 |
+
text=X.columns[interaction_index],
|
233 |
+
showarrow=False,
|
234 |
+
xref="paper",
|
235 |
+
yref="paper",
|
236 |
+
xanchor="right",
|
237 |
+
yanchor="middle",
|
238 |
+
textangle=-90,
|
239 |
+
font=dict(family='Calibri', size=14)
|
240 |
+
)
|
241 |
+
],
|
242 |
+
margin=dict(t=50, b=50, l=50, r=80)
|
243 |
+
)
|
244 |
+
lis.append(plotly_fig)
|
245 |
+
return lis, features
|
246 |
+
|
247 |
+
def interaction_plot_to_plotly_fig(dataset, target_col='target column', max_display=10):
|
248 |
+
data = pd.read_csv(dataset, encoding="ISO-8859-1")
|
249 |
+
X = data.drop(['target column'], axis=1)
|
250 |
+
y = data[target_col]
|
251 |
+
y = y/max(y)
|
252 |
+
|
253 |
+
xgb_full = xgboost.DMatrix(X, label=y)
|
254 |
+
|
255 |
+
# create a train/test split
|
256 |
+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
|
257 |
+
xgb_train = xgboost.DMatrix(X_train, label=y_train)
|
258 |
+
xgb_test = xgboost.DMatrix(X_test, label=y_test)
|
259 |
+
|
260 |
+
# use validation set to choose # of trees
|
261 |
+
params = {
|
262 |
+
# "eta": 0.002,
|
263 |
+
# "max_depth": 3,
|
264 |
+
# "subsample": 0.5,
|
265 |
+
"silent": 1
|
266 |
+
}
|
267 |
+
model_train = xgboost.train(params, xgb_train, 3000, evals = [(xgb_test, "test")], verbose_eval=None)
|
268 |
+
|
269 |
+
# train final model on the full data set
|
270 |
+
params = {
|
271 |
+
# "eta": 0.002,
|
272 |
+
# "max_depth": 3,
|
273 |
+
# "subsample": 0.5,
|
274 |
+
"silent": 1
|
275 |
+
}
|
276 |
+
model = xgboost.train(params, xgb_full, 1500, evals = [(xgb_full, "test")], verbose_eval=None)
|
277 |
+
features = model.feature_names
|
278 |
+
shap_values = shap.TreeExplainer(model).shap_values(X)
|
279 |
+
|
280 |
+
feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
|
281 |
+
feature_order = feature_order[-min(max_display, len(feature_order)):]
|
282 |
+
features = [features[i] for i in feature_order[::-1]]
|
283 |
+
|
284 |
+
shap_interaction_values = shap.TreeExplainer(model).shap_interaction_values(X)
|
285 |
+
|
286 |
+
lis = []
|
287 |
+
for i in features:
|
288 |
+
for j in features:
|
289 |
+
mpl_fig = pl.figure()
|
290 |
+
ax = mpl_fig.add_subplot(111)
|
291 |
+
_, interaction_index = shap_dependence_plot ( (i, j), shap_interaction_values, X.iloc[:2000,:] )
|
292 |
+
plotly_fig = tls.mpl_to_plotly(mpl_fig)
|
293 |
+
|
294 |
+
r_min = min(plotly_fig['data'][0]['x'])
|
295 |
+
r_max = max(plotly_fig['data'][0]['x'])
|
296 |
+
|
297 |
+
plotly_fig['layout']['xaxis'].update(range=[r_min-1, r_max+1])
|
298 |
+
plotly_fig['layout']['showlegend'] = False
|
299 |
+
plotly_fig['layout']['hovermode'] = 'closest'
|
300 |
+
plotly_fig['layout']['height']=380
|
301 |
+
plotly_fig['layout']['width']=450
|
302 |
+
plotly_fig['layout']['xaxis'].update(zeroline=True,
|
303 |
+
showline=True,
|
304 |
+
ticklen=4,
|
305 |
+
showgrid=False,
|
306 |
+
tickmode='linear')
|
307 |
+
plotly_fig['layout']['yaxis'].update(showline=True)
|
308 |
+
|
309 |
+
if i!=j:
|
310 |
+
# plotly_fig['layout']['height']=380
|
311 |
+
plotly_fig['layout']['width']=480
|
312 |
+
plotly_fig['layout']['yaxis']['title'] = "SHAP interaction value for {} and {}".format(i.split('-')[0], j.split('-')[0])
|
313 |
+
# Define the colorbar
|
314 |
+
colorbar_trace = go.Scatter(x=[None],
|
315 |
+
y=[None],
|
316 |
+
mode='markers',
|
317 |
+
marker=dict(
|
318 |
+
colorscale=red_blue,
|
319 |
+
showscale=True,
|
320 |
+
colorbar=dict(thickness=5, outlinewidth=0),
|
321 |
+
color=[min(X[X.columns[interaction_index]]), max(X[X.columns[interaction_index]])],
|
322 |
+
),
|
323 |
+
hoverinfo='none'
|
324 |
+
)
|
325 |
+
plotly_fig.add_trace(colorbar_trace)
|
326 |
+
plotly_fig.layout.update(
|
327 |
+
annotations=[dict(
|
328 |
+
x=1.23,
|
329 |
+
align="right",
|
330 |
+
valign="top",
|
331 |
+
text=X.columns[interaction_index],
|
332 |
+
showarrow=False,
|
333 |
+
xref="paper",
|
334 |
+
yref="paper",
|
335 |
+
xanchor="right",
|
336 |
+
yanchor="middle",
|
337 |
+
textangle=-90,
|
338 |
+
font=dict(family='Calibri', size=14)
|
339 |
+
)
|
340 |
+
],
|
341 |
+
margin=dict(t=30, b=30, l=60, r=80)
|
342 |
+
)
|
343 |
+
else:
|
344 |
+
plotly_fig['layout']['yaxis']['title'] = "SHAP main effect value for {}".format(i.split('-')[0])
|
345 |
+
lis.append(plotly_fig)
|
346 |
+
return lis, features
|
requirements.txt
CHANGED
@@ -3,6 +3,7 @@ numpy==1.23.4
|
|
3 |
altair==5.1.2
|
4 |
scikit-learn==1.1.3
|
5 |
pandas
|
|
|
6 |
xgboost==1.3.3
|
7 |
shap==0.41.0
|
8 |
plotly
|
|
|
3 |
altair==5.1.2
|
4 |
scikit-learn==1.1.3
|
5 |
pandas
|
6 |
+
catboost
|
7 |
xgboost==1.3.3
|
8 |
shap==0.41.0
|
9 |
plotly
|
shap_plots.py
CHANGED
@@ -1,730 +1,730 @@
|
|
1 |
-
import warnings
|
2 |
-
import iml
|
3 |
-
import numpy as np
|
4 |
-
from iml import Instance, Model
|
5 |
-
from iml.datatypes import DenseData
|
6 |
-
from iml.explanations import AdditiveExplanation
|
7 |
-
from iml.links import IdentityLink
|
8 |
-
from scipy.stats import gaussian_kde
|
9 |
-
import matplotlib
|
10 |
-
try:
|
11 |
-
import matplotlib.pyplot as pl
|
12 |
-
from matplotlib.colors import LinearSegmentedColormap
|
13 |
-
from matplotlib.ticker import MaxNLocator
|
14 |
-
|
15 |
-
cdict1 = {
|
16 |
-
'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
|
17 |
-
(1.0, 0.9607843137254902, 0.9607843137254902)),
|
18 |
-
|
19 |
-
'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
|
20 |
-
(1.0, 0.15294117647058825, 0.15294117647058825)),
|
21 |
-
|
22 |
-
'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
|
23 |
-
(1.0, 0.3411764705882353, 0.3411764705882353)),
|
24 |
-
|
25 |
-
'alpha': ((0.0, 1, 1),
|
26 |
-
(0.5, 0.3, 0.3),
|
27 |
-
(1.0, 1, 1))
|
28 |
-
} # #1E88E5 -> #ff0052
|
29 |
-
red_blue = LinearSegmentedColormap('RedBlue', cdict1)
|
30 |
-
|
31 |
-
cdict1 = {
|
32 |
-
'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
|
33 |
-
(1.0, 0.9607843137254902, 0.9607843137254902)),
|
34 |
-
|
35 |
-
'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
|
36 |
-
(1.0, 0.15294117647058825, 0.15294117647058825)),
|
37 |
-
|
38 |
-
'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
|
39 |
-
(1.0, 0.3411764705882353, 0.3411764705882353)),
|
40 |
-
|
41 |
-
'alpha': ((0.0, 1, 1),
|
42 |
-
(0.5, 1, 1),
|
43 |
-
(1.0, 1, 1))
|
44 |
-
} # #1E88E5 -> #ff0052
|
45 |
-
red_blue_solid = LinearSegmentedColormap('RedBlue', cdict1)
|
46 |
-
except ImportError:
|
47 |
-
pass
|
48 |
-
|
49 |
-
labels = {
|
50 |
-
'MAIN_EFFECT': "SHAP main effect value for\n%s",
|
51 |
-
'INTERACTION_VALUE': "SHAP interaction value",
|
52 |
-
'INTERACTION_EFFECT': "SHAP interaction value for\n%s and %s",
|
53 |
-
'VALUE': "SHAP value (impact on model output)",
|
54 |
-
'VALUE_FOR': "SHAP value for\n%s",
|
55 |
-
'PLOT_FOR': "SHAP plot for %s",
|
56 |
-
'FEATURE': "Feature %s",
|
57 |
-
'FEATURE_VALUE': "Feature value",
|
58 |
-
'FEATURE_VALUE_LOW': "Low",
|
59 |
-
'FEATURE_VALUE_HIGH': "High",
|
60 |
-
'JOINT_VALUE': "Joint SHAP value"
|
61 |
-
}
|
62 |
-
|
63 |
-
def shap_summary_plot(shap_values, features=None, feature_names=None, max_display=None, plot_type="dot",
|
64 |
-
color=None, axis_color="#333333", title=None, alpha=1, show=True, sort=True,
|
65 |
-
color_bar=True, auto_size_plot=True, layered_violin_max_num_bins=20):
|
66 |
-
"""Create a SHAP summary plot, colored by feature values when they are provided.
|
67 |
-
|
68 |
-
Parameters
|
69 |
-
----------
|
70 |
-
shap_values : numpy.array
|
71 |
-
Matrix of SHAP values (# samples x # features)
|
72 |
-
|
73 |
-
features : numpy.array or pandas.DataFrame or list
|
74 |
-
Matrix of feature values (# samples x # features) or a feature_names list as shorthand
|
75 |
-
|
76 |
-
feature_names : list
|
77 |
-
Names of the features (length # features)
|
78 |
-
|
79 |
-
max_display : int
|
80 |
-
How many top features to include in the plot (default is 20, or 7 for interaction plots)
|
81 |
-
|
82 |
-
plot_type : "dot" (default) or "violin"
|
83 |
-
What type of summary plot to produce
|
84 |
-
"""
|
85 |
-
|
86 |
-
assert len(shap_values.shape) != 1, "Summary plots need a matrix of shap_values, not a vector."
|
87 |
-
|
88 |
-
# default color:
|
89 |
-
if color is None:
|
90 |
-
color = "coolwarm" if plot_type == 'layered_violin' else "#ff0052"
|
91 |
-
|
92 |
-
# convert from a DataFrame or other types
|
93 |
-
if str(type(features)) == "<class 'pandas.core.frame.DataFrame'>":
|
94 |
-
if feature_names is None:
|
95 |
-
feature_names = features.columns
|
96 |
-
features = features.values
|
97 |
-
elif str(type(features)) == "<class 'list'>":
|
98 |
-
if feature_names is None:
|
99 |
-
feature_names = features
|
100 |
-
features = None
|
101 |
-
elif (features is not None) and len(features.shape) == 1 and feature_names is None:
|
102 |
-
feature_names = features
|
103 |
-
features = None
|
104 |
-
|
105 |
-
if feature_names is None:
|
106 |
-
feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1] - 1)]
|
107 |
-
|
108 |
-
mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))
|
109 |
-
|
110 |
-
# plotting SHAP interaction values
|
111 |
-
if len(shap_values.shape) == 3:
|
112 |
-
if max_display is None:
|
113 |
-
max_display = 7
|
114 |
-
else:
|
115 |
-
max_display = min(len(feature_names), max_display)
|
116 |
-
|
117 |
-
sort_inds = np.argsort(-np.abs(shap_values[:, :-1, :-1].sum(1)).sum(0))
|
118 |
-
|
119 |
-
# get plotting limits
|
120 |
-
delta = 1.0 / (shap_values.shape[1] ** 2)
|
121 |
-
slow = np.nanpercentile(shap_values, delta)
|
122 |
-
shigh = np.nanpercentile(shap_values, 100 - delta)
|
123 |
-
v = max(abs(slow), abs(shigh))
|
124 |
-
slow = -0.2
|
125 |
-
shigh = 0.2
|
126 |
-
|
127 |
-
# mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))
|
128 |
-
ax = mpl_fig.subplot(1, max_display, 1)
|
129 |
-
proj_shap_values = shap_values[:, sort_inds[0], np.hstack((sort_inds, len(sort_inds)))]
|
130 |
-
proj_shap_values[:, 1:] *= 2 # because off diag effects are split in half
|
131 |
-
shap_summary_plot(
|
132 |
-
proj_shap_values, features[:, sort_inds],
|
133 |
-
feature_names=feature_names[sort_inds],
|
134 |
-
sort=False, show=False, color_bar=False,
|
135 |
-
auto_size_plot=False,
|
136 |
-
max_display=max_display
|
137 |
-
)
|
138 |
-
pl.xlim((slow, shigh))
|
139 |
-
pl.xlabel("")
|
140 |
-
title_length_limit = 11
|
141 |
-
pl.title(shorten_text(feature_names[sort_inds[0]], title_length_limit))
|
142 |
-
for i in range(1, max_display):
|
143 |
-
ind = sort_inds[i]
|
144 |
-
pl.subplot(1, max_display, i + 1)
|
145 |
-
proj_shap_values = shap_values[:, ind, np.hstack((sort_inds, len(sort_inds)))]
|
146 |
-
proj_shap_values *= 2
|
147 |
-
proj_shap_values[:, i] /= 2 # because only off diag effects are split in half
|
148 |
-
shap_summary_plot(
|
149 |
-
proj_shap_values, features[:, sort_inds],
|
150 |
-
sort=False,
|
151 |
-
feature_names=
|
152 |
-
show=False,
|
153 |
-
color_bar=False,
|
154 |
-
auto_size_plot=False,
|
155 |
-
max_display=max_display
|
156 |
-
)
|
157 |
-
pl.xlim((slow, shigh))
|
158 |
-
pl.xlabel("")
|
159 |
-
if i == max_display // 2:
|
160 |
-
pl.xlabel(labels['INTERACTION_VALUE'])
|
161 |
-
pl.title(shorten_text(feature_names[ind], title_length_limit))
|
162 |
-
pl.tight_layout(pad=0, w_pad=0, h_pad=0.0)
|
163 |
-
pl.subplots_adjust(hspace=0, wspace=0.1)
|
164 |
-
# if show:
|
165 |
-
# # pl.show()
|
166 |
-
return mpl_fig
|
167 |
-
|
168 |
-
if max_display is None:
|
169 |
-
max_display = 20
|
170 |
-
|
171 |
-
if sort:
|
172 |
-
# order features by the sum of their effect magnitudes
|
173 |
-
feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
|
174 |
-
feature_order = feature_order[-min(max_display, len(feature_order)):]
|
175 |
-
else:
|
176 |
-
feature_order = np.flip(np.arange(min(max_display, shap_values.shape[1] - 1)), 0)
|
177 |
-
|
178 |
-
row_height = 0.4
|
179 |
-
if auto_size_plot:
|
180 |
-
pl.gcf().set_size_inches(8, len(feature_order) * row_height + 1.5)
|
181 |
-
pl.axvline(x=0, color="#999999", zorder=-1)
|
182 |
-
|
183 |
-
if plot_type == "dot":
|
184 |
-
for pos, i in enumerate(feature_order):
|
185 |
-
pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
|
186 |
-
shaps = shap_values[:, i]
|
187 |
-
values = None if features is None else features[:, i]
|
188 |
-
inds = np.arange(len(shaps))
|
189 |
-
np.random.shuffle(inds)
|
190 |
-
if values is not None:
|
191 |
-
values = values[inds]
|
192 |
-
shaps = shaps[inds]
|
193 |
-
colored_feature = True
|
194 |
-
try:
|
195 |
-
values = np.array(values, dtype=np.float64) # make sure this can be numeric
|
196 |
-
except:
|
197 |
-
colored_feature = False
|
198 |
-
N = len(shaps)
|
199 |
-
# hspacing = (np.max(shaps) - np.min(shaps)) / 200
|
200 |
-
# curr_bin = []
|
201 |
-
nbins = 100
|
202 |
-
quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
|
203 |
-
inds = np.argsort(quant + np.random.randn(N) * 1e-6)
|
204 |
-
layer = 0
|
205 |
-
last_bin = -1
|
206 |
-
ys = np.zeros(N)
|
207 |
-
for ind in inds:
|
208 |
-
if quant[ind] != last_bin:
|
209 |
-
layer = 0
|
210 |
-
ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
|
211 |
-
layer += 1
|
212 |
-
last_bin = quant[ind]
|
213 |
-
ys *= 0.9 * (row_height / np.max(ys + 1))
|
214 |
-
|
215 |
-
if features is not None and colored_feature:
|
216 |
-
# trim the color range, but prevent the color range from collapsing
|
217 |
-
vmin = np.nanpercentile(values, 5)
|
218 |
-
vmax = np.nanpercentile(values, 95)
|
219 |
-
if vmin == vmax:
|
220 |
-
vmin = np.nanpercentile(values, 1)
|
221 |
-
vmax = np.nanpercentile(values, 99)
|
222 |
-
if vmin == vmax:
|
223 |
-
vmin = np.min(values)
|
224 |
-
vmax = np.max(values)
|
225 |
-
|
226 |
-
assert features.shape[0] == len(shaps), "Feature and SHAP matrices must have the same number of rows!"
|
227 |
-
nan_mask = np.isnan(values)
|
228 |
-
pl.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777", vmin=vmin,
|
229 |
-
vmax=vmax, s=16, alpha=alpha, linewidth=0,
|
230 |
-
zorder=3, rasterized=len(shaps) > 500)
|
231 |
-
pl.scatter(shaps[np.invert(nan_mask)], pos + ys[np.invert(nan_mask)],
|
232 |
-
cmap=red_blue, vmin=vmin, vmax=vmax, s=16,
|
233 |
-
c=values[np.invert(nan_mask)], alpha=alpha, linewidth=0,
|
234 |
-
zorder=3, rasterized=len(shaps) > 500)
|
235 |
-
else:
|
236 |
-
|
237 |
-
pl.scatter(shaps, pos + ys, s=16, alpha=alpha, linewidth=0, zorder=3,
|
238 |
-
color=color if colored_feature else "#777777", rasterized=len(shaps) > 500)
|
239 |
-
|
240 |
-
elif plot_type == "violin":
|
241 |
-
for pos, i in enumerate(feature_order):
|
242 |
-
pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
|
243 |
-
|
244 |
-
if features is not None:
|
245 |
-
global_low = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 1)
|
246 |
-
global_high = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 99)
|
247 |
-
for pos, i in enumerate(feature_order):
|
248 |
-
shaps = shap_values[:, i]
|
249 |
-
shap_min, shap_max = np.min(shaps), np.max(shaps)
|
250 |
-
rng = shap_max - shap_min
|
251 |
-
xs = np.linspace(np.min(shaps) - rng * 0.2, np.max(shaps) + rng * 0.2, 100)
|
252 |
-
if np.std(shaps) < (global_high - global_low) / 100:
|
253 |
-
ds = gaussian_kde(shaps + np.random.randn(len(shaps)) * (global_high - global_low) / 100)(xs)
|
254 |
-
else:
|
255 |
-
ds = gaussian_kde(shaps)(xs)
|
256 |
-
ds /= np.max(ds) * 3
|
257 |
-
|
258 |
-
values = features[:, i]
|
259 |
-
window_size = max(10, len(values) // 20)
|
260 |
-
smooth_values = np.zeros(len(xs) - 1)
|
261 |
-
sort_inds = np.argsort(shaps)
|
262 |
-
trailing_pos = 0
|
263 |
-
leading_pos = 0
|
264 |
-
running_sum = 0
|
265 |
-
back_fill = 0
|
266 |
-
for j in range(len(xs) - 1):
|
267 |
-
|
268 |
-
while leading_pos < len(shaps) and xs[j] >= shaps[sort_inds[leading_pos]]:
|
269 |
-
running_sum += values[sort_inds[leading_pos]]
|
270 |
-
leading_pos += 1
|
271 |
-
if leading_pos - trailing_pos > 20:
|
272 |
-
running_sum -= values[sort_inds[trailing_pos]]
|
273 |
-
trailing_pos += 1
|
274 |
-
if leading_pos - trailing_pos > 0:
|
275 |
-
smooth_values[j] = running_sum / (leading_pos - trailing_pos)
|
276 |
-
for k in range(back_fill):
|
277 |
-
smooth_values[j - k - 1] = smooth_values[j]
|
278 |
-
else:
|
279 |
-
back_fill += 1
|
280 |
-
|
281 |
-
vmin = np.nanpercentile(values, 5)
|
282 |
-
vmax = np.nanpercentile(values, 95)
|
283 |
-
if vmin == vmax:
|
284 |
-
vmin = np.nanpercentile(values, 1)
|
285 |
-
vmax = np.nanpercentile(values, 99)
|
286 |
-
if vmin == vmax:
|
287 |
-
vmin = np.min(values)
|
288 |
-
vmax = np.max(values)
|
289 |
-
pl.scatter(shaps, np.ones(shap_values.shape[0]) * pos, s=9, cmap=red_blue_solid, vmin=vmin, vmax=vmax,
|
290 |
-
c=values, alpha=alpha, linewidth=0, zorder=1)
|
291 |
-
# smooth_values -= nxp.nanpercentile(smooth_values, 5)
|
292 |
-
# smooth_values /= np.nanpercentile(smooth_values, 95)
|
293 |
-
smooth_values -= vmin
|
294 |
-
if vmax - vmin > 0:
|
295 |
-
smooth_values /= vmax - vmin
|
296 |
-
for i in range(len(xs) - 1):
|
297 |
-
if ds[i] > 0.05 or ds[i + 1] > 0.05:
|
298 |
-
pl.fill_between([xs[i], xs[i + 1]], [pos + ds[i], pos + ds[i + 1]],
|
299 |
-
[pos - ds[i], pos - ds[i + 1]], color=red_blue_solid(smooth_values[i]),
|
300 |
-
zorder=2)
|
301 |
-
|
302 |
-
else:
|
303 |
-
parts = pl.violinplot(shap_values[:, feature_order], range(len(feature_order)), points=200, vert=False,
|
304 |
-
widths=0.7,
|
305 |
-
showmeans=False, showextrema=False, showmedians=False)
|
306 |
-
|
307 |
-
for pc in parts['bodies']:
|
308 |
-
pc.set_facecolor(color)
|
309 |
-
pc.set_edgecolor('none')
|
310 |
-
pc.set_alpha(alpha)
|
311 |
-
|
312 |
-
elif plot_type == "layered_violin": # courtesy of @kodonnell
|
313 |
-
num_x_points = 200
|
314 |
-
bins = np.linspace(0, features.shape[0], layered_violin_max_num_bins + 1).round(0).astype(
|
315 |
-
'int') # the indices of the feature data corresponding to each bin
|
316 |
-
shap_min, shap_max = np.min(shap_values[:, :-1]), np.max(shap_values[:, :-1])
|
317 |
-
x_points = np.linspace(shap_min, shap_max, num_x_points)
|
318 |
-
|
319 |
-
# loop through each feature and plot:
|
320 |
-
for pos, ind in enumerate(feature_order):
|
321 |
-
# decide how to handle: if #unique < layered_violin_max_num_bins then split by unique value, otherwise use bins/percentiles.
|
322 |
-
# to keep simpler code, in the case of uniques, we just adjust the bins to align with the unique counts.
|
323 |
-
feature = features[:, ind]
|
324 |
-
unique, counts = np.unique(feature, return_counts=True)
|
325 |
-
if unique.shape[0] <= layered_violin_max_num_bins:
|
326 |
-
order = np.argsort(unique)
|
327 |
-
thesebins = np.cumsum(counts[order])
|
328 |
-
thesebins = np.insert(thesebins, 0, 0)
|
329 |
-
else:
|
330 |
-
thesebins = bins
|
331 |
-
nbins = thesebins.shape[0] - 1
|
332 |
-
# order the feature data so we can apply percentiling
|
333 |
-
order = np.argsort(feature)
|
334 |
-
# x axis is located at y0 = pos, with pos being there for offset
|
335 |
-
y0 = np.ones(num_x_points) * pos
|
336 |
-
# calculate kdes:
|
337 |
-
ys = np.zeros((nbins, num_x_points))
|
338 |
-
for i in range(nbins):
|
339 |
-
# get shap values in this bin:
|
340 |
-
shaps = shap_values[order[thesebins[i]:thesebins[i + 1]], ind]
|
341 |
-
# if there's only one element, then we can't
|
342 |
-
if shaps.shape[0] == 1:
|
343 |
-
warnings.warn(
|
344 |
-
"not enough data in bin #%d for feature %s, so it'll be ignored. Try increasing the number of records to plot."
|
345 |
-
% (i, feature_names[ind]))
|
346 |
-
# to ignore it, just set it to the previous y-values (so the area between them will be zero). Not ys is already 0, so there's
|
347 |
-
# nothing to do if i == 0
|
348 |
-
if i > 0:
|
349 |
-
ys[i, :] = ys[i - 1, :]
|
350 |
-
continue
|
351 |
-
# save kde of them: note that we add a tiny bit of gaussian noise to avoid singular matrix errors
|
352 |
-
ys[i, :] = gaussian_kde(shaps + np.random.normal(loc=0, scale=0.001, size=shaps.shape[0]))(x_points)
|
353 |
-
# scale it up so that the 'size' of each y represents the size of the bin. For continuous data this will
|
354 |
-
# do nothing, but when we've gone with the unqique option, this will matter - e.g. if 99% are male and 1%
|
355 |
-
# female, we want the 1% to appear a lot smaller.
|
356 |
-
size = thesebins[i + 1] - thesebins[i]
|
357 |
-
bin_size_if_even = features.shape[0] / nbins
|
358 |
-
relative_bin_size = size / bin_size_if_even
|
359 |
-
ys[i, :] *= relative_bin_size
|
360 |
-
# now plot 'em. We don't plot the individual strips, as this can leave whitespace between them.
|
361 |
-
# instead, we plot the full kde, then remove outer strip and plot over it, etc., to ensure no
|
362 |
-
# whitespace
|
363 |
-
ys = np.cumsum(ys, axis=0)
|
364 |
-
width = 0.8
|
365 |
-
scale = ys.max() * 2 / width # 2 is here as we plot both sides of x axis
|
366 |
-
for i in range(nbins - 1, -1, -1):
|
367 |
-
y = ys[i, :] / scale
|
368 |
-
c = pl.get_cmap(color)(i / (
|
369 |
-
nbins - 1)) if color in pl.cm.datad else color # if color is a cmap, use it, otherwise use a color
|
370 |
-
pl.fill_between(x_points, pos - y, pos + y, facecolor=c)
|
371 |
-
pl.xlim(shap_min, shap_max)
|
372 |
-
|
373 |
-
# draw the color bar
|
374 |
-
if color_bar and features is not None and (plot_type != "layered_violin" or color in pl.cm.datad):
|
375 |
-
import matplotlib.cm as cm
|
376 |
-
m = cm.ScalarMappable(cmap=red_blue_solid if plot_type != "layered_violin" else pl.get_cmap(color))
|
377 |
-
m.set_array([0, 1])
|
378 |
-
cb = pl.colorbar(m, ticks=[0, 1], aspect=1000)
|
379 |
-
cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
|
380 |
-
cb.set_label(labels['FEATURE_VALUE'], size=12, labelpad=0)
|
381 |
-
cb.ax.tick_params(labelsize=11, length=0)
|
382 |
-
cb.set_alpha(1)
|
383 |
-
cb.outline.set_visible(False)
|
384 |
-
bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
|
385 |
-
cb.ax.set_aspect((bbox.height - 0.9) * 20)
|
386 |
-
# cb.draw_all()
|
387 |
-
|
388 |
-
pl.gca().xaxis.set_ticks_position('bottom')
|
389 |
-
pl.gca().yaxis.set_ticks_position('none')
|
390 |
-
pl.gca().spines['right'].set_visible(False)
|
391 |
-
pl.gca().spines['top'].set_visible(False)
|
392 |
-
pl.gca().spines['left'].set_visible(False)
|
393 |
-
pl.gca().tick_params(color=axis_color, labelcolor=axis_color)
|
394 |
-
pl.yticks(range(len(feature_order)), [feature_names[i] for i in feature_order], fontsize=13)
|
395 |
-
pl.gca().tick_params('y', length=20, width=0.5, which='major')
|
396 |
-
pl.gca().tick_params('x', labelsize=11)
|
397 |
-
pl.ylim(-1, len(feature_order))
|
398 |
-
pl.xlabel(labels['VALUE'], fontsize=13)
|
399 |
-
pl.tight_layout()
|
400 |
-
# if show:
|
401 |
-
# pl.show()
|
402 |
-
return mpl_fig
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
def approx_interactions(index, shap_values, X):
|
410 |
-
""" Order other features by how much interaction they seem to have with the feature at the given index.
|
411 |
-
|
412 |
-
This just bins the SHAP values for a feature along that feature's value. For true Shapley interaction
|
413 |
-
index values for SHAP see the interaction_contribs option implemented in XGBoost.
|
414 |
-
"""
|
415 |
-
|
416 |
-
if X.shape[0] > 10000:
|
417 |
-
a = np.arange(X.shape[0])
|
418 |
-
np.random.shuffle(a)
|
419 |
-
inds = a[:10000]
|
420 |
-
else:
|
421 |
-
inds = np.arange(X.shape[0])
|
422 |
-
|
423 |
-
x = X[inds, index]
|
424 |
-
srt = np.argsort(x)
|
425 |
-
shap_ref = shap_values[inds, index]
|
426 |
-
shap_ref = shap_ref[srt]
|
427 |
-
inc = max(min(int(len(x) / 10.0), 50), 1)
|
428 |
-
interactions = []
|
429 |
-
for i in range(X.shape[1]):
|
430 |
-
val_other = X[inds, i][srt].astype(np.float)
|
431 |
-
v = 0.0
|
432 |
-
if not (i == index or np.sum(np.abs(val_other)) < 1e-8):
|
433 |
-
for j in range(0, len(x), inc):
|
434 |
-
if np.std(val_other[j:j + inc]) > 0 and np.std(shap_ref[j:j + inc]) > 0:
|
435 |
-
v += abs(np.corrcoef(shap_ref[j:j + inc], val_other[j:j + inc])[0, 1])
|
436 |
-
interactions.append(v)
|
437 |
-
|
438 |
-
return np.argsort(-np.abs(interactions))
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
def shap_dependence_plot(ind, shap_values, features, feature_names=None, display_features=None,
|
447 |
-
interaction_index="auto", color="#1E88E5", axis_color="#333333",
|
448 |
-
dot_size=16, alpha=1, title=None, show=True):
|
449 |
-
"""
|
450 |
-
Create a SHAP dependence plot, colored by an interaction feature.
|
451 |
-
|
452 |
-
Parameters
|
453 |
-
----------
|
454 |
-
ind : int
|
455 |
-
Index of the feature to plot.
|
456 |
-
|
457 |
-
shap_values : numpy.array
|
458 |
-
Matrix of SHAP values (# samples x # features)
|
459 |
-
|
460 |
-
features : numpy.array or pandas.DataFrame
|
461 |
-
(The remainder of the removed shap_plots.py — old lines ~461-730, the tail of shap_dependence_plot — is omitted here: the deleted content is line-for-line identical to the corresponding part of the re-added file below.)
import warnings
import iml
import numpy as np
from iml import Instance, Model
from iml.datatypes import DenseData
from iml.explanations import AdditiveExplanation
from iml.links import IdentityLink
from scipy.stats import gaussian_kde
import matplotlib
try:
    import matplotlib.pyplot as pl
    from matplotlib.colors import LinearSegmentedColormap
    from matplotlib.ticker import MaxNLocator

    cdict1 = {
        'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
                (1.0, 0.9607843137254902, 0.9607843137254902)),

        'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
                  (1.0, 0.15294117647058825, 0.15294117647058825)),

        'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
                 (1.0, 0.3411764705882353, 0.3411764705882353)),

        'alpha': ((0.0, 1, 1),
                  (0.5, 0.3, 0.3),
                  (1.0, 1, 1))
    }  # #1E88E5 -> #ff0052
    red_blue = LinearSegmentedColormap('RedBlue', cdict1)

    cdict1 = {
        'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
                (1.0, 0.9607843137254902, 0.9607843137254902)),

        'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
                  (1.0, 0.15294117647058825, 0.15294117647058825)),

        'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
                 (1.0, 0.3411764705882353, 0.3411764705882353)),

        'alpha': ((0.0, 1, 1),
                  (0.5, 1, 1),
                  (1.0, 1, 1))
    }  # #1E88E5 -> #ff0052
    red_blue_solid = LinearSegmentedColormap('RedBlue', cdict1)
except ImportError:
    pass

labels = {
    'MAIN_EFFECT': "SHAP main effect value for\n%s",
    'INTERACTION_VALUE': "SHAP interaction value",
    'INTERACTION_EFFECT': "SHAP interaction value for\n%s and %s",
    'VALUE': "SHAP value (impact on model output)",
    'VALUE_FOR': "SHAP value for\n%s",
    'PLOT_FOR': "SHAP plot for %s",
    'FEATURE': "Feature %s",
    'FEATURE_VALUE': "Feature value",
    'FEATURE_VALUE_LOW': "Low",
    'FEATURE_VALUE_HIGH': "High",
    'JOINT_VALUE': "Joint SHAP value"
}
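The two colormaps above differ only in their alpha ramp: red_blue fades towards transparency in the middle of its range, while red_blue_solid stays fully opaque. A throwaway check (run wherever the module's names are importable; the printed alpha of roughly 0.3 versus 1.0 is the only difference):

# compare the alpha channel of the two colormaps at mid-range
from shap_plots import red_blue, red_blue_solid
print(red_blue(0.5))        # RGBA tuple, alpha ~0.3 (transparent mid-range)
print(red_blue_solid(0.5))  # RGBA tuple, alpha = 1.0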
def shap_summary_plot(shap_values, features=None, feature_names=None, max_display=None, plot_type="dot",
                      color=None, axis_color="#333333", title=None, alpha=1, show=True, sort=True,
                      color_bar=True, auto_size_plot=True, layered_violin_max_num_bins=20):
    """Create a SHAP summary plot, colored by feature values when they are provided.

    Parameters
    ----------
    shap_values : numpy.array
        Matrix of SHAP values (# samples x # features)

    features : numpy.array or pandas.DataFrame or list
        Matrix of feature values (# samples x # features) or a feature_names list as shorthand

    feature_names : list
        Names of the features (length # features)

    max_display : int
        How many top features to include in the plot (default is 20, or 7 for interaction plots)

    plot_type : "dot" (default) or "violin"
        What type of summary plot to produce
    """

    assert len(shap_values.shape) != 1, "Summary plots need a matrix of shap_values, not a vector."

    # default color:
    if color is None:
        color = "coolwarm" if plot_type == 'layered_violin' else "#ff0052"

    # convert from a DataFrame or other types
    if str(type(features)) == "<class 'pandas.core.frame.DataFrame'>":
        if feature_names is None:
            feature_names = features.columns
        features = features.values
    elif str(type(features)) == "<class 'list'>":
        if feature_names is None:
            feature_names = features
        features = None
    elif (features is not None) and len(features.shape) == 1 and feature_names is None:
        feature_names = features
        features = None

    if feature_names is None:
        feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1] - 1)]

    mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))

    # plotting SHAP interaction values
    if len(shap_values.shape) == 3:
        if max_display is None:
            max_display = 7
        else:
            max_display = min(len(feature_names), max_display)

        sort_inds = np.argsort(-np.abs(shap_values[:, :-1, :-1].sum(1)).sum(0))

        # get plotting limits
        delta = 1.0 / (shap_values.shape[1] ** 2)
        slow = np.nanpercentile(shap_values, delta)
        shigh = np.nanpercentile(shap_values, 100 - delta)
        v = max(abs(slow), abs(shigh))
        slow = -0.2
        shigh = 0.2

        # mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1))
        ax = mpl_fig.subplot(1, max_display, 1)
        proj_shap_values = shap_values[:, sort_inds[0], np.hstack((sort_inds, len(sort_inds)))]
        proj_shap_values[:, 1:] *= 2  # because off diag effects are split in half
        shap_summary_plot(
            proj_shap_values, features[:, sort_inds],
            feature_names=feature_names[sort_inds],
            sort=False, show=False, color_bar=False,
            auto_size_plot=False,
            max_display=max_display
        )
        pl.xlim((slow, shigh))
        pl.xlabel("")
        title_length_limit = 11
        pl.title(shorten_text(feature_names[sort_inds[0]], title_length_limit))
        for i in range(1, max_display):
            ind = sort_inds[i]
            pl.subplot(1, max_display, i + 1)
            proj_shap_values = shap_values[:, ind, np.hstack((sort_inds, len(sort_inds)))]
            proj_shap_values *= 2
            proj_shap_values[:, i] /= 2  # because only off diag effects are split in half
            shap_summary_plot(
                proj_shap_values, features[:, sort_inds],
                sort=False,
                feature_names=["" for i in range(features.shape[1])],
                show=False,
                color_bar=False,
                auto_size_plot=False,
                max_display=max_display
            )
            pl.xlim((slow, shigh))
            pl.xlabel("")
            if i == max_display // 2:
                pl.xlabel(labels['INTERACTION_VALUE'])
            pl.title(shorten_text(feature_names[ind], title_length_limit))
        pl.tight_layout(pad=0, w_pad=0, h_pad=0.0)
        pl.subplots_adjust(hspace=0, wspace=0.1)
        # if show:
        #     # pl.show()
        return mpl_fig

    if max_display is None:
        max_display = 20

    if sort:
        # order features by the sum of their effect magnitudes
        feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
        feature_order = feature_order[-min(max_display, len(feature_order)):]
    else:
        feature_order = np.flip(np.arange(min(max_display, shap_values.shape[1] - 1)), 0)

    row_height = 0.4
    if auto_size_plot:
        pl.gcf().set_size_inches(8, len(feature_order) * row_height + 1.5)
    pl.axvline(x=0, color="#999999", zorder=-1)

    if plot_type == "dot":
        for pos, i in enumerate(feature_order):
            pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
            shaps = shap_values[:, i]
            values = None if features is None else features[:, i]
            inds = np.arange(len(shaps))
            np.random.shuffle(inds)
            if values is not None:
                values = values[inds]
            shaps = shaps[inds]
            colored_feature = True
            try:
                values = np.array(values, dtype=np.float64)  # make sure this can be numeric
            except:
                colored_feature = False
            N = len(shaps)
            # hspacing = (np.max(shaps) - np.min(shaps)) / 200
            # curr_bin = []
            nbins = 100
            quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
            inds = np.argsort(quant + np.random.randn(N) * 1e-6)
            layer = 0
            last_bin = -1
            ys = np.zeros(N)
            for ind in inds:
                if quant[ind] != last_bin:
                    layer = 0
                ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
                layer += 1
                last_bin = quant[ind]
            ys *= 0.9 * (row_height / np.max(ys + 1))

            if features is not None and colored_feature:
                # trim the color range, but prevent the color range from collapsing
                vmin = np.nanpercentile(values, 5)
                vmax = np.nanpercentile(values, 95)
                if vmin == vmax:
                    vmin = np.nanpercentile(values, 1)
                    vmax = np.nanpercentile(values, 99)
                    if vmin == vmax:
                        vmin = np.min(values)
                        vmax = np.max(values)

                assert features.shape[0] == len(shaps), "Feature and SHAP matrices must have the same number of rows!"
                nan_mask = np.isnan(values)
                pl.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777", vmin=vmin,
                           vmax=vmax, s=16, alpha=alpha, linewidth=0,
                           zorder=3, rasterized=len(shaps) > 500)
                pl.scatter(shaps[np.invert(nan_mask)], pos + ys[np.invert(nan_mask)],
                           cmap=red_blue, vmin=vmin, vmax=vmax, s=16,
                           c=values[np.invert(nan_mask)], alpha=alpha, linewidth=0,
                           zorder=3, rasterized=len(shaps) > 500)
            else:

                pl.scatter(shaps, pos + ys, s=16, alpha=alpha, linewidth=0, zorder=3,
                           color=color if colored_feature else "#777777", rasterized=len(shaps) > 500)

    elif plot_type == "violin":
        for pos, i in enumerate(feature_order):
            pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

        if features is not None:
            global_low = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 1)
            global_high = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 99)
            for pos, i in enumerate(feature_order):
                shaps = shap_values[:, i]
                shap_min, shap_max = np.min(shaps), np.max(shaps)
                rng = shap_max - shap_min
                xs = np.linspace(np.min(shaps) - rng * 0.2, np.max(shaps) + rng * 0.2, 100)
                if np.std(shaps) < (global_high - global_low) / 100:
                    ds = gaussian_kde(shaps + np.random.randn(len(shaps)) * (global_high - global_low) / 100)(xs)
                else:
                    ds = gaussian_kde(shaps)(xs)
                ds /= np.max(ds) * 3

                values = features[:, i]
                window_size = max(10, len(values) // 20)
                smooth_values = np.zeros(len(xs) - 1)
                sort_inds = np.argsort(shaps)
                trailing_pos = 0
                leading_pos = 0
                running_sum = 0
                back_fill = 0
                for j in range(len(xs) - 1):

                    while leading_pos < len(shaps) and xs[j] >= shaps[sort_inds[leading_pos]]:
                        running_sum += values[sort_inds[leading_pos]]
                        leading_pos += 1
                        if leading_pos - trailing_pos > 20:
                            running_sum -= values[sort_inds[trailing_pos]]
                            trailing_pos += 1
                    if leading_pos - trailing_pos > 0:
                        smooth_values[j] = running_sum / (leading_pos - trailing_pos)
                        for k in range(back_fill):
                            smooth_values[j - k - 1] = smooth_values[j]
                    else:
                        back_fill += 1

                vmin = np.nanpercentile(values, 5)
                vmax = np.nanpercentile(values, 95)
                if vmin == vmax:
                    vmin = np.nanpercentile(values, 1)
                    vmax = np.nanpercentile(values, 99)
                    if vmin == vmax:
                        vmin = np.min(values)
                        vmax = np.max(values)
                pl.scatter(shaps, np.ones(shap_values.shape[0]) * pos, s=9, cmap=red_blue_solid, vmin=vmin, vmax=vmax,
                           c=values, alpha=alpha, linewidth=0, zorder=1)
                # smooth_values -= nxp.nanpercentile(smooth_values, 5)
                # smooth_values /= np.nanpercentile(smooth_values, 95)
                smooth_values -= vmin
                if vmax - vmin > 0:
                    smooth_values /= vmax - vmin
                for i in range(len(xs) - 1):
                    if ds[i] > 0.05 or ds[i + 1] > 0.05:
                        pl.fill_between([xs[i], xs[i + 1]], [pos + ds[i], pos + ds[i + 1]],
                                        [pos - ds[i], pos - ds[i + 1]], color=red_blue_solid(smooth_values[i]),
                                        zorder=2)

        else:
            parts = pl.violinplot(shap_values[:, feature_order], range(len(feature_order)), points=200, vert=False,
                                  widths=0.7,
                                  showmeans=False, showextrema=False, showmedians=False)

            for pc in parts['bodies']:
                pc.set_facecolor(color)
                pc.set_edgecolor('none')
                pc.set_alpha(alpha)

    elif plot_type == "layered_violin":  # courtesy of @kodonnell
        num_x_points = 200
        bins = np.linspace(0, features.shape[0], layered_violin_max_num_bins + 1).round(0).astype(
            'int')  # the indices of the feature data corresponding to each bin
        shap_min, shap_max = np.min(shap_values[:, :-1]), np.max(shap_values[:, :-1])
        x_points = np.linspace(shap_min, shap_max, num_x_points)

        # loop through each feature and plot:
        for pos, ind in enumerate(feature_order):
            # decide how to handle: if #unique < layered_violin_max_num_bins then split by unique value, otherwise use bins/percentiles.
            # to keep simpler code, in the case of uniques, we just adjust the bins to align with the unique counts.
            feature = features[:, ind]
            unique, counts = np.unique(feature, return_counts=True)
            if unique.shape[0] <= layered_violin_max_num_bins:
                order = np.argsort(unique)
                thesebins = np.cumsum(counts[order])
                thesebins = np.insert(thesebins, 0, 0)
            else:
                thesebins = bins
            nbins = thesebins.shape[0] - 1
            # order the feature data so we can apply percentiling
            order = np.argsort(feature)
            # x axis is located at y0 = pos, with pos being there for offset
            y0 = np.ones(num_x_points) * pos
            # calculate kdes:
            ys = np.zeros((nbins, num_x_points))
            for i in range(nbins):
                # get shap values in this bin:
                shaps = shap_values[order[thesebins[i]:thesebins[i + 1]], ind]
                # if there's only one element, then we can't
                if shaps.shape[0] == 1:
                    warnings.warn(
                        "not enough data in bin #%d for feature %s, so it'll be ignored. Try increasing the number of records to plot."
                        % (i, feature_names[ind]))
                    # to ignore it, just set it to the previous y-values (so the area between them will be zero). Not ys is already 0, so there's
                    # nothing to do if i == 0
                    if i > 0:
                        ys[i, :] = ys[i - 1, :]
                    continue
                # save kde of them: note that we add a tiny bit of gaussian noise to avoid singular matrix errors
                ys[i, :] = gaussian_kde(shaps + np.random.normal(loc=0, scale=0.001, size=shaps.shape[0]))(x_points)
                # scale it up so that the 'size' of each y represents the size of the bin. For continuous data this will
                # do nothing, but when we've gone with the unqique option, this will matter - e.g. if 99% are male and 1%
                # female, we want the 1% to appear a lot smaller.
                size = thesebins[i + 1] - thesebins[i]
                bin_size_if_even = features.shape[0] / nbins
                relative_bin_size = size / bin_size_if_even
                ys[i, :] *= relative_bin_size
            # now plot 'em. We don't plot the individual strips, as this can leave whitespace between them.
            # instead, we plot the full kde, then remove outer strip and plot over it, etc., to ensure no
            # whitespace
            ys = np.cumsum(ys, axis=0)
            width = 0.8
            scale = ys.max() * 2 / width  # 2 is here as we plot both sides of x axis
            for i in range(nbins - 1, -1, -1):
                y = ys[i, :] / scale
                c = pl.get_cmap(color)(i / (
                    nbins - 1)) if color in pl.cm.datad else color  # if color is a cmap, use it, otherwise use a color
                pl.fill_between(x_points, pos - y, pos + y, facecolor=c)
        pl.xlim(shap_min, shap_max)

    # draw the color bar
    if color_bar and features is not None and (plot_type != "layered_violin" or color in pl.cm.datad):
        import matplotlib.cm as cm
        m = cm.ScalarMappable(cmap=red_blue_solid if plot_type != "layered_violin" else pl.get_cmap(color))
        m.set_array([0, 1])
        cb = pl.colorbar(m, ticks=[0, 1], aspect=1000)
        cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
        cb.set_label(labels['FEATURE_VALUE'], size=12, labelpad=0)
        cb.ax.tick_params(labelsize=11, length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
        cb.ax.set_aspect((bbox.height - 0.9) * 20)
        # cb.draw_all()

    pl.gca().xaxis.set_ticks_position('bottom')
    pl.gca().yaxis.set_ticks_position('none')
    pl.gca().spines['right'].set_visible(False)
    pl.gca().spines['top'].set_visible(False)
    pl.gca().spines['left'].set_visible(False)
    pl.gca().tick_params(color=axis_color, labelcolor=axis_color)
    pl.yticks(range(len(feature_order)), [feature_names[i] for i in feature_order], fontsize=13)
    pl.gca().tick_params('y', length=20, width=0.5, which='major')
    pl.gca().tick_params('x', labelsize=11)
    pl.ylim(-1, len(feature_order))
    pl.xlabel(labels['VALUE'], fontsize=13)
    pl.tight_layout()
    # if show:
    #     pl.show()
    return mpl_fig
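shap_summary_plot returns the matplotlib figure instead of showing it, which is what lets the app hand it to Plotly for interactive display. Two points worth noting when calling it: the SHAP matrix is expected in the old layout with one extra trailing column (the expected value), since the code drops the last column with [:-1] indexing; and max_display should be passed explicitly, because the figure size is computed from it before the default of 20 is applied. A minimal usage sketch, assuming a classic shap/NumPy/matplotlib stack compatible with this module and a placeholder scikit-learn model (not the app's CatBoost model):

import numpy as np
import shap
import plotly.tools as tls
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
from shap_plots import shap_summary_plot

# placeholder data and model, for illustration only
X, y = load_diabetes(return_X_y=True, as_frame=True)
model = RandomForestRegressor(n_estimators=100, random_state=0).fit(X, y)

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)                        # (n_samples, n_features)
expected = float(np.atleast_1d(explainer.expected_value)[0])
shap_matrix = np.hstack([shap_values, np.full((X.shape[0], 1), expected)])  # trailing column is ignored via [:-1]

fig = shap_summary_plot(shap_matrix, X, feature_names=X.columns, max_display=10, show=False)
plotly_fig = tls.mpl_to_plotly(fig)   # conversion is lossy; e.g. the colour bar may be dropped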
def approx_interactions(index, shap_values, X):
    """ Order other features by how much interaction they seem to have with the feature at the given index.

    This just bins the SHAP values for a feature along that feature's value. For true Shapley interaction
    index values for SHAP see the interaction_contribs option implemented in XGBoost.
    """

    if X.shape[0] > 10000:
        a = np.arange(X.shape[0])
        np.random.shuffle(a)
        inds = a[:10000]
    else:
        inds = np.arange(X.shape[0])

    x = X[inds, index]
    srt = np.argsort(x)
    shap_ref = shap_values[inds, index]
    shap_ref = shap_ref[srt]
    inc = max(min(int(len(x) / 10.0), 50), 1)
    interactions = []
    for i in range(X.shape[1]):
        val_other = X[inds, i][srt].astype(np.float)
        v = 0.0
        if not (i == index or np.sum(np.abs(val_other)) < 1e-8):
            for j in range(0, len(x), inc):
                if np.std(val_other[j:j + inc]) > 0 and np.std(shap_ref[j:j + inc]) > 0:
                    v += abs(np.corrcoef(shap_ref[j:j + inc], val_other[j:j + inc])[0, 1])
        interactions.append(v)

    return np.argsort(-np.abs(interactions))
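approx_interactions is a heuristic: it sorts the samples by the chosen feature, walks them in bins, and sums the per-bin correlation between that feature's SHAP values and every other feature, returning feature indices ranked from strongest to weakest apparent interaction. Note that the np.float alias used here (and again in shap_dependence_plot below) was removed in NumPy 1.24, so this module needs an older NumPy or a small patch. Continuing the hypothetical sketch above:

# rank candidate interaction partners for the first feature (plain SHAP matrix, no bias column)
ranked = approx_interactions(0, shap_values, X.values)
print("strongest apparent partner:", X.columns[ranked[0]])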
def shap_dependence_plot(ind, shap_values, features, feature_names=None, display_features=None,
                         interaction_index="auto", color="#1E88E5", axis_color="#333333",
                         dot_size=16, alpha=1, title=None, show=True):
    """
    Create a SHAP dependence plot, colored by an interaction feature.

    Parameters
    ----------
    ind : int
        Index of the feature to plot.

    shap_values : numpy.array
        Matrix of SHAP values (# samples x # features)

    features : numpy.array or pandas.DataFrame
        Matrix of feature values (# samples x # features)

    feature_names : list
        Names of the features (length # features)

    display_features : numpy.array or pandas.DataFrame
        Matrix of feature values for visual display (such as strings instead of coded values)

    interaction_index : "auto", None, or int
        The index of the feature used to color the plot.
    """

    # convert from DataFrames if we got any
    if str(type(features)).endswith("'pandas.core.frame.DataFrame'>"):
        if feature_names is None:
            feature_names = features.columns
        features = features.values
    if str(type(display_features)).endswith("'pandas.core.frame.DataFrame'>"):
        if feature_names is None:
            feature_names = display_features.columns
        display_features = display_features.values
    elif display_features is None:
        display_features = features

    if feature_names is None:
        feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1] - 1)]

    # allow vectors to be passed
    if len(shap_values.shape) == 1:
        shap_values = np.reshape(shap_values, len(shap_values), 1)
    if len(features.shape) == 1:
        features = np.reshape(features, len(features), 1)

    def convert_name(ind):
        if type(ind) == str:
            nzinds = np.where(feature_names == ind)[0]
            if len(nzinds) == 0:
                print("Could not find feature named: " + ind)
                return None
            else:
                return nzinds[0]
        else:
            return ind

    ind = convert_name(ind)

    mpl_fig = pl.gcf()
    ax = mpl_fig.gca()

    # plotting SHAP interaction values
    if len(shap_values.shape) == 3 and len(ind) == 2:
        ind1 = convert_name(ind[0])
        ind2 = convert_name(ind[1])
        if ind1 == ind2:
            proj_shap_values = shap_values[:, ind2, :]
        else:
            proj_shap_values = shap_values[:, ind2, :] * 2  # off-diag values are split in half

        # TODO: remove recursion; generally the functions should be shorter for more maintainable code
        return shap_dependence_plot(
            ind1, proj_shap_values, features, feature_names=feature_names,
            interaction_index=ind2, display_features=display_features, show=False
        )

    assert shap_values.shape[0] == features.shape[0], \
        "'shap_values' and 'features' values must have the same number of rows!"
    assert shap_values.shape[1] == features.shape[1], \
        "'shap_values' must have the same number of columns as 'features'!"

    # get both the raw and display feature values
    xv = features[:, ind]
    xd = display_features[:, ind]
    s = shap_values[:, ind]
    if type(xd[0]) == str:
        name_map = {}
        for i in range(len(xv)):
            name_map[xd[i]] = xv[i]
        xnames = list(name_map.keys())

    # allow a single feature name to be passed alone
    if type(feature_names) == str:
        feature_names = [feature_names]
    name = feature_names[ind]

    # guess what other feature as the stongest interaction with the plotted feature
    if interaction_index == "auto":
        interaction_index = approx_interactions(ind, shap_values, features)[0]
    interaction_index = convert_name(interaction_index)
    categorical_interaction = False

    # get both the raw and display color values
    if interaction_index is not None:
        cv = features[:, interaction_index]
        cd = display_features[:, interaction_index]
        clow = np.nanpercentile(features[:, interaction_index].astype(np.float), 5)
        chigh = np.nanpercentile(features[:, interaction_index].astype(np.float), 95)
        if type(cd[0]) == str:
            cname_map = {}
            for i in range(len(cv)):
                cname_map[cd[i]] = cv[i]
            cnames = list(cname_map.keys())
            categorical_interaction = True
        elif clow % 1 == 0 and chigh % 1 == 0 and len(set(features[:, interaction_index])) < 50:
            categorical_interaction = True

    # discritize colors for categorical features
    color_norm = None
    if categorical_interaction and clow != chigh:
        bounds = np.linspace(clow, chigh, chigh - clow + 2)
        color_norm = matplotlib.colors.BoundaryNorm(bounds, red_blue.N)

    # the actual scatter plot, TODO: adapt the dot_size to the number of data points?
    if interaction_index is not None:
        pl.scatter(xv, s, s=dot_size, linewidth=0, c=features[:, interaction_index], cmap=red_blue,
                   alpha=alpha, vmin=clow, vmax=chigh, norm=color_norm, rasterized=len(xv) > 500)
    else:
        pl.scatter(xv, s, s=dot_size, linewidth=0, color="#1E88E5",
                   alpha=alpha, rasterized=len(xv) > 500)

    if interaction_index != ind and interaction_index is not None:
        # draw the color bar
        if type(cd[0]) == str:
            tick_positions = [cname_map[n] for n in cnames]
            if len(tick_positions) == 2:
                tick_positions[0] -= 0.25
                tick_positions[1] += 0.25
            cb = pl.colorbar(ticks=tick_positions)
            cb.set_ticklabels(cnames)
        else:
            cb = pl.colorbar()

        cb.set_label(feature_names[interaction_index], size=13)
        cb.ax.tick_params(labelsize=11)
        if categorical_interaction:
            cb.ax.tick_params(length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
        cb.ax.set_aspect((bbox.height - 0.7) * 20)

    # make the plot more readable
    if interaction_index != ind:
        pl.gcf().set_size_inches(7.5, 5)
    else:
        pl.gcf().set_size_inches(6, 5)
    # pl.xlabel(name, color=axis_color, fontsize=13)
    # pl.ylabel(labels['VALUE_FOR'] % name, color=axis_color, fontsize=13)
    if title is not None:
        pl.title(title, color=axis_color, fontsize=13)
    pl.gca().xaxis.set_ticks_position('bottom')
    pl.gca().yaxis.set_ticks_position('left')
    pl.gca().spines['right'].set_visible(False)
    pl.gca().spines['top'].set_visible(False)
    pl.gca().tick_params(color=axis_color, labelcolor=axis_color, labelsize=11)
    for spine in pl.gca().spines.values():
        spine.set_edgecolor(axis_color)
    if type(xd[0]) == str:
        pl.xticks([name_map[n] for n in xnames], xnames, rotation='vertical', fontsize=11)
    # if show:
    #     pl.show()


    if ind1 == ind2:
        pl.ylabel(labels['MAIN_EFFECT'] % feature_names[ind1])
    else:
        pl.ylabel(labels['INTERACTION_EFFECT'] % (feature_names[ind1], feature_names[ind2]))

    return mpl_fig, interaction_index


    # # if show:
    # # pl.show()
    # return
    # return mpl_fig

    # assert shap_values.shape[0] == features.shape[0], "'shap_values' and 'features' values must have the same number of rows!"
    # assert shap_values.shape[1] == features.shape[1] + 1, "'shap_values' must have one more column than 'features'!"

    # get both the raw and display feature values
    xv = features[:, ind]
    xd = display_features[:, ind]
    s = shap_values[:, ind]
    if type(xd[0]) == str:
        name_map = {}
        for i in range(len(xv)):
            name_map[xd[i]] = xv[i]
        xnames = list(name_map.keys())

    # allow a single feature name to be passed alone
    if type(feature_names) == str:
        feature_names = [feature_names]
    name = feature_names[ind]

    # guess what other feature as the stongest interaction with the plotted feature
    if interaction_index == "auto":
        interaction_index = approx_interactions(ind, shap_values, features)[0]
    interaction_index = convert_name(interaction_index)
    categorical_interaction = False

    # get both the raw and display color values
    if interaction_index is not None:
        cv = features[:, interaction_index]
        cd = display_features[:, interaction_index]
        clow = np.nanpercentile(features[:, interaction_index].astype(np.float), 5)
        chigh = np.nanpercentile(features[:, interaction_index].astype(np.float), 95)
        if type(cd[0]) == str:
            cname_map = {}
            for i in range(len(cv)):
                cname_map[cd[i]] = cv[i]
            cnames = list(cname_map.keys())
            categorical_interaction = True
        elif clow % 1 == 0 and chigh % 1 == 0 and len(set(features[:, interaction_index])) < 50:
            categorical_interaction = True

    # discritize colors for categorical features
    color_norm = None
    if categorical_interaction and clow != chigh:
        bounds = np.linspace(clow, chigh, chigh - clow + 2)
        color_norm = matplotlib.colors.BoundaryNorm(bounds, red_blue.N)

    # the actual scatter plot, TODO: adapt the dot_size to the number of data points?
    if interaction_index is not None:
        pl.scatter(xv, s, s=dot_size, linewidth=0, c=features[:, interaction_index], cmap=red_blue,
                   alpha=alpha, vmin=clow, vmax=chigh, norm=color_norm, rasterized=len(xv) > 500)
    else:
        pl.scatter(xv, s, s=dot_size, linewidth=0, color="#1E88E5",
                   alpha=alpha, rasterized=len(xv) > 500)

    if interaction_index != ind and interaction_index is not None:
        # draw the color bar
        if type(cd[0]) == str:
            tick_positions = [cname_map[n] for n in cnames]
            if len(tick_positions) == 2:
                tick_positions[0] -= 0.25
                tick_positions[1] += 0.25
            cb = pl.colorbar(ticks=tick_positions)
            cb.set_ticklabels(cnames)
        else:
            cb = pl.colorbar()

        cb.set_label(feature_names[interaction_index], size=13)
        cb.ax.tick_params(labelsize=11)
        if categorical_interaction:
            cb.ax.tick_params(length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)
        bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
        cb.ax.set_aspect((bbox.height - 0.7) * 20)

    # make the plot more readable
    if interaction_index != ind:
        pl.gcf().set_size_inches(7.5, 5)
    else:
        pl.gcf().set_size_inches(6, 5)
    pl.xlabel(name, color=axis_color, fontsize=13)
    pl.ylabel(labels['VALUE_FOR'] % name, color=axis_color, fontsize=13)
    if title is not None:
        pl.title(title, color=axis_color, fontsize=13)
    pl.gca().xaxis.set_ticks_position('bottom')
    pl.gca().yaxis.set_ticks_position('left')
    pl.gca().spines['right'].set_visible(False)
    pl.gca().spines['top'].set_visible(False)
    pl.gca().tick_params(color=axis_color, labelcolor=axis_color, labelsize=11)
    for spine in pl.gca().spines.values():
        spine.set_edgecolor(axis_color)
    if type(xd[0]) == str:
        pl.xticks([name_map[n] for n in xnames], xnames, rotation='vertical', fontsize=11)
    # if show:
    #     pl.show()
    return mpl_fig, interaction_index
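shap_dependence_plot returns the matplotlib figure together with the interaction feature it picked, so the caller can both render the plot and report which feature drove the colouring. As committed, the trailing ylabel block refers to ind1 and ind2, which are only bound inside the interaction branch, so a plain 2-D call raises a NameError after the scatter and colour bar have already been drawn; everything after the first `return mpl_fig, interaction_index` is unreachable code left over from an earlier version. A defensive call sketch (hypothetical, continuing the example above and assuming the older NumPy this module expects):

import matplotlib.pyplot as plt
try:
    fig, picked = shap_dependence_plot(0, shap_values, X.values,
                                       feature_names=list(X.columns), show=False)
except NameError:
    # the scatter and colour bar are already on the current figure at this point
    fig, picked = plt.gcf(), None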