Spaces:
Running
Running
Commit
·
a828d88
1
Parent(s):
35437a1
Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,17 @@ import pandas as pd
|
|
5 |
import sklearn
|
6 |
import xgboost
|
7 |
import shap
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
9 |
|
10 |
seed=42
|
@@ -65,6 +76,36 @@ explainer = shap.TreeExplainer(xgb)
|
|
65 |
def convert_df(df):
|
66 |
return df.to_csv(index=False).encode('utf-8')
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
if len(gene_list) > 1:
|
69 |
df = df_total[df_total.index.isin(gene_list)]
|
70 |
df['Gene'] = df.index
|
@@ -91,6 +132,57 @@ if len(gene_list) > 1:
|
|
91 |
summary_plot = shap.summary_plot(shap_values, df_shap, show=False)
|
92 |
st.caption("SHAP Summary Plot of All Input Genes")
|
93 |
st.pyplot(fig=summary_plot)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
else:
|
96 |
pass
|
|
|
5 |
import sklearn
|
6 |
import xgboost
|
7 |
import shap
|
8 |
+
import plotly.tools as tls
|
9 |
+
import dash_core_components as dcc
|
10 |
+
import matplotlib
|
11 |
+
import plotly.graph_objs as go
|
12 |
+
try:
|
13 |
+
import matplotlib.pyplot as pl
|
14 |
+
from matplotlib.colors import LinearSegmentedColormap
|
15 |
+
from matplotlib.ticker import MaxNLocator
|
16 |
+
except ImportError:
|
17 |
+
pass
|
18 |
+
|
19 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
20 |
|
21 |
seed=42
|
|
|
76 |
def convert_df(df):
    """Serialize ``df`` to CSV text without the index and return it as UTF-8 bytes."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')
|
78 |
|
79 |
+
|
80 |
+
# Endpoints of the two-colour gradient as (red, green, blue) fractions in [0, 1]:
# the low end is #1E88E5 (blue), the high end is #F52757 (red).
_LOW_RGB = (0.11764705882352941, 0.5333333333333333, 0.8980392156862745)
_HIGH_RGB = (0.9607843137254902, 0.15294117647058825, 0.3411764705882353)

# Segment data in the format LinearSegmentedColormap expects: for each channel
# a sequence of (position, value_below, value_above) rows.
cdict1 = {
    channel: ((0.0, low, low), (1.0, high, high))
    for channel, low, high in zip(('red', 'green', 'blue'), _LOW_RGB, _HIGH_RGB)
}
# Alpha stays at 1 across the whole range.
cdict1['alpha'] = (
    (0.0, 1, 1),
    (0.5, 1, 1),
    (1.0, 1, 1),
)

red_blue = LinearSegmentedColormap('RedBlue', cdict1)
|
95 |
+
|
96 |
+
def matplotlib_to_plotly(cmap, pl_entries):
    """Sample a colormap into a Plotly-style colorscale.

    Args:
        cmap: Callable mapping a float position in [0, 1] to a sequence whose
            first three items are the red, green and blue components as
            fractions in [0, 1] (any further items, e.g. alpha, are ignored).
        pl_entries: Number of evenly spaced samples to take; must be >= 2.

    Returns:
        A list of ``[position, 'rgb(r, g, b)']`` pairs, with positions running
        from exactly 0.0 to exactly 1.0 and integer channel values in 0-255.

    Raises:
        ValueError: If ``pl_entries`` is smaller than 2.
    """
    if pl_entries < 2:
        # A colorscale needs both end points; the previous code failed here
        # with an opaque ZeroDivisionError.
        raise ValueError(
            f"pl_entries must be at least 2, got {pl_entries!r}"
        )

    last = pl_entries - 1
    pl_colorscale = []
    for k in range(pl_entries):
        # Divide per sample instead of multiplying by a precomputed step so
        # the final position is exactly 1.0.
        position = k / last
        # Convert to plain ints: str() of NumPy scalars renders as
        # 'np.uint8(30)' on NumPy >= 2, which is not a valid colour string.
        red, green, blue = (int(channel * 255) for channel in cmap(position)[:3])
        pl_colorscale.append([position, f'rgb({red}, {green}, {blue})'])

    return pl_colorscale
|
105 |
+
|
106 |
+
# Rebind red_blue from the matplotlib colormap to the 255-entry list of
# [position, 'rgb(...)'] pairs returned by matplotlib_to_plotly.
red_blue = matplotlib_to_plotly(red_blue, 255)
|
107 |
+
|
108 |
+
|
109 |
if len(gene_list) > 1:
|
110 |
df = df_total[df_total.index.isin(gene_list)]
|
111 |
df['Gene'] = df.index
|
|
|
132 |
summary_plot = shap.summary_plot(shap_values, df_shap, show=False)
|
133 |
st.caption("SHAP Summary Plot of All Input Genes")
|
134 |
st.pyplot(fig=summary_plot)
|
135 |
+
st.caption("Interactive SHAP Summary Plot of All Input Genes")
|
136 |
+
mpl_fig = shap_summary_plot(shap_values, df_shap, show=False)
|
137 |
+
plotly_fig = tls.mpl_to_plotly(mpl_fig)
|
138 |
+
plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}}
|
139 |
+
|
140 |
+
feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
|
141 |
+
feature_order = feature_order[-min(max_display, len(feature_order)):]
|
142 |
+
text = gene_list
|
143 |
+
text = iter(text)
|
144 |
+
for i in range(1, len(plotly_fig['df_shap']), 2):
|
145 |
+
t = text.__next__()
|
146 |
+
plotly_fig['df_shap'][i]['name'] = ''
|
147 |
+
plotly_fig['df_shap'][i]['text'] = t
|
148 |
+
plotly_fig['df_shap'][i]['hoverinfo'] = 'text'
|
149 |
+
colorbar_trace = go.Scatter(x=[None],
|
150 |
+
y=[None],
|
151 |
+
mode='markers',
|
152 |
+
marker=dict(
|
153 |
+
colorscale=red_blue,
|
154 |
+
showscale=True,
|
155 |
+
cmin=-5,
|
156 |
+
cmax=5,
|
157 |
+
colorbar=dict(thickness=5, tickvals=[-5, 5], ticktext=['Low', 'High'], outlinewidth=0)
|
158 |
+
),
|
159 |
+
hoverinfo='none'
|
160 |
+
)
|
161 |
+
plotly_fig['layout']['showlegend'] = False
|
162 |
+
plotly_fig['layout']['hovermode'] = 'closest'
|
163 |
+
plotly_fig['layout']['height']=600
|
164 |
+
plotly_fig['layout']['width']=500
|
165 |
+
plotly_fig['layout']['xaxis'].update(zeroline=True, showline=True, ticklen=4, showgrid=False)
|
166 |
+
plotly_fig['layout']['yaxis'].update(dict(visible=False))
|
167 |
+
plotly_fig.add_trace(colorbar_trace)
|
168 |
+
plotly_fig.layout.update(
|
169 |
+
annotations=[dict(
|
170 |
+
x=1.18,
|
171 |
+
align="right",
|
172 |
+
valign="top",
|
173 |
+
text='Feature value',
|
174 |
+
showarrow=False,
|
175 |
+
xref="paper",
|
176 |
+
yref="paper",
|
177 |
+
xanchor="right",
|
178 |
+
yanchor="middle",
|
179 |
+
textangle=-90,
|
180 |
+
font=dict(family='Calibri', size=14)
|
181 |
+
)
|
182 |
+
],
|
183 |
+
margin=dict(t=20)
|
184 |
+
)
|
185 |
+
st.plotly_chart(plotly_fig)
|
186 |
|
187 |
else:
|
188 |
pass
|