hlnicholls commited on
Commit
a828d88
·
1 Parent(s): 35437a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py CHANGED
@@ -5,6 +5,17 @@ import pandas as pd
5
  import sklearn
6
  import xgboost
7
  import shap
 
 
 
 
 
 
 
 
 
 
 
8
  st.set_option('deprecation.showPyplotGlobalUse', False)
9
 
10
  seed=42
@@ -65,6 +76,36 @@ explainer = shap.TreeExplainer(xgb)
65
  def convert_df(df):
66
  return df.to_csv(index=False).encode('utf-8')
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  if len(gene_list) > 1:
69
  df = df_total[df_total.index.isin(gene_list)]
70
  df['Gene'] = df.index
@@ -91,6 +132,57 @@ if len(gene_list) > 1:
91
  summary_plot = shap.summary_plot(shap_values, df_shap, show=False)
92
  st.caption("SHAP Summary Plot of All Input Genes")
93
  st.pyplot(fig=summary_plot)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  else:
96
  pass
 
5
  import sklearn
6
  import xgboost
7
  import shap
8
+ import plotly.tools as tls
9
+ import dash_core_components as dcc
10
+ import matplotlib
11
+ import plotly.graph_objs as go
12
+ try:
13
+ import matplotlib.pyplot as pl
14
+ from matplotlib.colors import LinearSegmentedColormap
15
+ from matplotlib.ticker import MaxNLocator
16
+ except ImportError:
17
+ pass
18
+
19
  st.set_option('deprecation.showPyplotGlobalUse', False)
20
 
21
  seed=42
 
76
  def convert_df(df):
77
  return df.to_csv(index=False).encode('utf-8')
78
 
79
+
80
+ cdict1 = {
81
+ 'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
82
+ (1.0, 0.9607843137254902, 0.9607843137254902)),
83
+
84
+ 'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
85
+ (1.0, 0.15294117647058825, 0.15294117647058825)),
86
+
87
+ 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
88
+ (1.0, 0.3411764705882353, 0.3411764705882353)),
89
+
90
+ 'alpha': ((0.0, 1, 1),
91
+ (0.5, 1, 1),
92
+ (1.0, 1, 1))
93
+ } # #1E88E5 -> #ff0052
94
+ red_blue = LinearSegmentedColormap('RedBlue', cdict1)
95
+
96
+ def matplotlib_to_plotly(cmap, pl_entries):
97
+ h = 1.0/(pl_entries-1)
98
+ pl_colorscale = []
99
+
100
+ for k in range(pl_entries):
101
+ C = list(map(np.uint8, np.array(cmap(k*h)[:3])*255))
102
+ pl_colorscale.append([k*h, 'rgb'+str((C[0], C[1], C[2]))])
103
+
104
+ return pl_colorscale
105
+
106
+ red_blue = matplotlib_to_plotly(red_blue, 255)
107
+
108
+
109
  if len(gene_list) > 1:
110
  df = df_total[df_total.index.isin(gene_list)]
111
  df['Gene'] = df.index
 
132
  summary_plot = shap.summary_plot(shap_values, df_shap, show=False)
133
  st.caption("SHAP Summary Plot of All Input Genes")
134
  st.pyplot(fig=summary_plot)
135
+ st.caption("Interactive SHAP Summary Plot of All Input Genes")
136
+ mpl_fig = shap_summary_plot(shap_values, df_shap, show=False)
137
+ plotly_fig = tls.mpl_to_plotly(mpl_fig)
138
+ plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}}
139
+
140
+ feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
141
+ feature_order = feature_order[-min(max_display, len(feature_order)):]
142
+ text = gene_list
143
+ text = iter(text)
144
+ for i in range(1, len(plotly_fig['df_shap']), 2):
145
+ t = text.__next__()
146
+ plotly_fig['df_shap'][i]['name'] = ''
147
+ plotly_fig['df_shap'][i]['text'] = t
148
+ plotly_fig['df_shap'][i]['hoverinfo'] = 'text'
149
+ colorbar_trace = go.Scatter(x=[None],
150
+ y=[None],
151
+ mode='markers',
152
+ marker=dict(
153
+ colorscale=red_blue,
154
+ showscale=True,
155
+ cmin=-5,
156
+ cmax=5,
157
+ colorbar=dict(thickness=5, tickvals=[-5, 5], ticktext=['Low', 'High'], outlinewidth=0)
158
+ ),
159
+ hoverinfo='none'
160
+ )
161
+ plotly_fig['layout']['showlegend'] = False
162
+ plotly_fig['layout']['hovermode'] = 'closest'
163
+ plotly_fig['layout']['height']=600
164
+ plotly_fig['layout']['width']=500
165
+ plotly_fig['layout']['xaxis'].update(zeroline=True, showline=True, ticklen=4, showgrid=False)
166
+ plotly_fig['layout']['yaxis'].update(dict(visible=False))
167
+ plotly_fig.add_trace(colorbar_trace)
168
+ plotly_fig.layout.update(
169
+ annotations=[dict(
170
+ x=1.18,
171
+ align="right",
172
+ valign="top",
173
+ text='Feature value',
174
+ showarrow=False,
175
+ xref="paper",
176
+ yref="paper",
177
+ xanchor="right",
178
+ yanchor="middle",
179
+ textangle=-90,
180
+ font=dict(family='Calibri', size=14)
181
+ )
182
+ ],
183
+ margin=dict(t=20)
184
+ )
185
+ st.plotly_chart(plotly_fig)
186
 
187
  else:
188
  pass