hlnicholls commited on
Commit
5464374
·
1 Parent(s): 7b49260

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -16
app.py CHANGED
@@ -106,24 +106,16 @@ def matplotlib_to_plotly(cmap, pl_entries):
106
 
107
  red_blue = matplotlib_to_plotly(red_blue, 255)
108
 
109
- df_shap = df_total[df_total.index.isin(gene_list)]
110
- df_shap.drop(columns='XGB_Score', inplace=True)
111
- shap_values = explainer.shap_values(df_shap)
112
- summary_plot = shap.summary_plot(shap_values, df_shap)
113
- feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
114
- feature_order = feature_order[-min(8, len(feature_order)):]
115
- col_order = [df_shap.columns[i] for i in feature_order]
116
-
117
  if len(gene_list) > 1:
118
  df = df_total[df_total.index.isin(gene_list)]
119
  df['Gene'] = df.index
120
  df.reset_index(drop=True, inplace=True)
121
  df = df[['Gene','XGB_Score', 'mousescore_Exomiser',
122
- 'SDI', 'Liver_GTExTPM', 'pLI_ExAC',
123
- 'HIPred',
124
- 'Cells - EBV-transformed lymphocytes_GTExTPM',
125
- 'Pituitary_GTExTPM',
126
- 'IPA_BP_annotation']]
127
  st.dataframe(df)
128
  output = df[['Gene', 'XGB_Score']]
129
  csv = convert_df(output)
@@ -138,13 +130,13 @@ if len(gene_list) > 1:
138
  df_shap.drop(columns='XGB_Score', inplace=True)
139
  shap_values = explainer.shap_values(df_shap)
140
  summary_plot = shap.summary_plot(shap_values, df_shap)
 
 
141
  feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
142
  feature_order = feature_order[-min(8, len(feature_order)):]
143
  col_order = [df.columns[i] for i in feature_order]
144
- st.caption("SHAP Summary Plot of All Input Genes")
145
- st.pyplot(fig=summary_plot)
146
  st.caption("Interactive SHAP Summary Plot of All Input Genes")
147
- mpl_fig = shap_summary_plot(shap_values, df_shap, max_display=8, show=False, feature_names=df_shap.columns)
148
  plotly_fig = tls.mpl_to_plotly(mpl_fig)
149
  plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}}
150
  max_display=8
 
106
 
107
  red_blue = matplotlib_to_plotly(red_blue, 255)
108
 
 
 
 
 
 
 
 
 
109
  if len(gene_list) > 1:
110
  df = df_total[df_total.index.isin(gene_list)]
111
  df['Gene'] = df.index
112
  df.reset_index(drop=True, inplace=True)
113
  df = df[['Gene','XGB_Score', 'mousescore_Exomiser',
114
+ 'SDI', 'Liver_GTExTPM', 'pLI_ExAC',
115
+ 'HIPred',
116
+ 'Cells - EBV-transformed lymphocytes_GTExTPM',
117
+ 'Pituitary_GTExTPM',
118
+ 'IPA_BP_annotation']]
119
  st.dataframe(df)
120
  output = df[['Gene', 'XGB_Score']]
121
  csv = convert_df(output)
 
130
  df_shap.drop(columns='XGB_Score', inplace=True)
131
  shap_values = explainer.shap_values(df_shap)
132
  summary_plot = shap.summary_plot(shap_values, df_shap)
133
+ st.pyplot(fig=summary_plot)
134
+ st.caption("SHAP Summary Plot of All Input Genes")
135
  feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
136
  feature_order = feature_order[-min(8, len(feature_order)):]
137
  col_order = [df.columns[i] for i in feature_order]
 
 
138
  st.caption("Interactive SHAP Summary Plot of All Input Genes")
139
+ mpl_fig = shap_summary_plot(shap_values, df_shap, max_display=8, show=False, feature_names=col_order)
140
  plotly_fig = tls.mpl_to_plotly(mpl_fig)
141
  plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}}
142
  max_display=8