# Spaces: Sleeping
# (page-status residue captured together with the source; not part of the program)
import streamlit as st | |
import re | |
import numpy as np | |
import pandas as pd | |
import sklearn | |
import xgboost | |
import shap | |
import plotly.tools as tls | |
import dash_core_components as dcc | |
import matplotlib | |
import plotly.graph_objs as go | |
try: | |
import matplotlib.pyplot as pl | |
from matplotlib.colors import LinearSegmentedColormap | |
from matplotlib.ticker import MaxNLocator | |
except ImportError: | |
pass | |
# Turn off Streamlit's 'deprecation.showPyplotGlobalUse' option.
# NOTE(review): recent Streamlit releases may no longer recognise this
# option -- confirm against the pinned Streamlit version.
st.set_option('deprecation.showPyplotGlobalUse', False)

seed = 42

# Per-gene annotation table, indexed by gene symbol; this is what gets scored.
annotations = pd.read_csv("annotations_dataset.csv")
annotations = annotations.set_index("Gene")

training_data = pd.read_csv("./selected_features_training_data.csv", header=0)

# Replace "[", "]" and "<" in column names with "_".  The pattern is compiled
# here because nothing else in the script defines `regex`; without it the
# first column containing one of these characters raises NameError.
# (XGBoost is understood to reject feature names containing these characters.)
regex = re.compile(r"\[|\]|<")
training_data.columns = [
    regex.sub("_", col) if regex.search(str(col)) else col
    for col in training_data.columns.values
]

# Encode the three label strings as numeric regression targets.
training_data["BPlabel_encoded"] = training_data["BPlabel"].map(
    {"most likely": 1, "probable": 0.75, "least likely": 0.1}
)
Y = training_data["BPlabel_encoded"]
X = training_data.drop(columns=["BPlabel_encoded", "BPlabel"])

xgb = xgboost.XGBRegressor(
    n_estimators=40,
    learning_rate=0.2,
    max_depth=4,
    reg_alpha=1,
    reg_lambda=1,
    random_state=seed,
    objective="reg:squarederror",
)
xgb.fit(X, Y)

# Score every annotated gene and round the score to two decimals.
# NOTE(review): assumes the columns of `annotations` match the training
# features in `X` -- confirm against the two CSV files.
prediction_list = list(xgb.predict(annotations))
predictions = [round(prediction, 2) for prediction in prediction_list]
output = pd.Series(data=predictions, index=annotations.index, name="XGB_Score")

# Score first, followed by the annotation columns shown in the app.
df_total = pd.concat([annotations, output], axis=1)
df_total = df_total[['XGB_Score', 'mousescore_Exomiser',
                     'SDI', 'Liver_GTExTPM', 'pLI_ExAC',
                     'HIPred',
                     'Cells - EBV-transformed lymphocytes_GTExTPM',
                     'Pituitary_GTExTPM',
                     'IPA_BP_annotation']]

st.title('Blood Pressure Gene Prioritisation Post-GWAS')
st.markdown("""
A machine learning pipeline for predicting disease-causing genes post-genome-wide association study in blood pressure.
""")
def collect_genes(x):
    """Split a free-text gene list into individual gene symbols.

    Args:
        x: String of symbols separated by commas and/or whitespace.

    Returns:
        List of the non-empty tokens, in input order.
    """
    # One character class covers commas, whitespace and any mix of the two;
    # the raw string avoids the invalid "\s" escape of a plain literal.
    return [token for token in re.split(r"[,\s]+", x) if token]
# SHAP explainer for the fitted booster; it is used by both the multi-gene
# and the single-gene views.
explainer = shap.TreeExplainer(xgb)

# Free-text box for the multi-gene query, tokenised into a list of symbols.
input_gene_list = st.text_input(
    "Input list of HGNC genes (enter comma separated):"
)
gene_list = collect_genes(input_gene_list)
def convert_df(df: pd.DataFrame) -> bytes:
    """Serialise *df* to CSV (index omitted) and return it as UTF-8 bytes.

    Args:
        df: Frame to export.

    Returns:
        The CSV text encoded as UTF-8.
    """
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')
# End points of the colour ramp as RGB fractions in [0, 1].  Scaled to 8-bit
# they are #1E88E5 (at 0.0) and #F52757 (at 1.0).
_ramp_start = (0.11764705882352941, 0.5333333333333333, 0.8980392156862745)
_ramp_end = (0.9607843137254902, 0.15294117647058825, 0.3411764705882353)

# One two-stop segment per colour channel: (position, value, value).
cdict1 = {
    channel: ((0.0, low, low), (1.0, high, high))
    for channel, low, high in zip(('red', 'green', 'blue'), _ramp_start, _ramp_end)
}
# Alpha is 1 at every stop.
cdict1['alpha'] = ((0.0, 1, 1), (0.5, 1, 1), (1.0, 1, 1))

red_blue = LinearSegmentedColormap('RedBlue', cdict1)
def matplotlib_to_plotly(cmap, pl_entries):
    """Sample a colormap callable into a list of ``[stop, 'rgb(r, g, b)']`` pairs.

    Args:
        cmap: Callable taking a float position and returning a sequence whose
            first three items are the red, green and blue components; they
            are multiplied by 255 and truncated to unsigned 8-bit integers.
        pl_entries: Number of evenly spaced stops to sample (at least 2).

    Returns:
        A list of ``pl_entries`` pairs ``[position, 'rgb(r, g, b)']`` with
        positions running from 0.0 to 1.0.

    Raises:
        ValueError: If ``pl_entries`` is smaller than 2, since a single stop
            has no spacing and previously caused a division by zero.
    """
    if pl_entries < 2:
        raise ValueError("pl_entries must be at least 2")

    last = pl_entries - 1
    pl_colorscale = []
    for k in range(pl_entries):
        # Dividing by `last` makes the final stop exactly 1.0.
        stop = k / last
        channels = (np.array(cmap(stop)[:3]) * 255).astype(np.uint8).tolist()
        # Plain Python ints are formatted explicitly: NumPy 2 renders uint8
        # scalars as "np.uint8(30)", which is not a valid colour string.
        r, g, b = channels
        pl_colorscale.append([stop, f'rgb({r}, {g}, {b})'])
    return pl_colorscale
# Re-express the matplotlib colormap as a Plotly colorscale (255 stops); it is
# used for the colour bar of the interactive plot below.
red_blue = matplotlib_to_plotly(red_blue, 255)

# Rows of the scored table whose gene symbol is in the input list.
selected = df_total[df_total.index.isin(gene_list)]

if len(gene_list) > 1 and selected.empty:
    # Without this guard an empty frame is passed to SHAP further down.
    st.warning("None of the input genes were found in the annotation dataset.")
elif len(gene_list) > 1:
    # Turn the index into a 'Gene' column on a new frame instead of writing
    # into a slice of df_total.
    df = selected.rename_axis('Gene').reset_index()
    df = df[['Gene', 'XGB_Score', 'mousescore_Exomiser',
             'SDI', 'Liver_GTExTPM', 'pLI_ExAC',
             'HIPred',
             'Cells - EBV-transformed lymphocytes_GTExTPM',
             'Pituitary_GTExTPM',
             'IPA_BP_annotation']]
    st.dataframe(df)

    output = df[['Gene', 'XGB_Score']]
    csv = convert_df(output)
    st.download_button(
        "Download Gene Prioritisation",
        csv,
        "bp_gene_prioritisation.csv",
        "text/csv",
        key='download-csv'
    )

    # SHAP is computed on the displayed columns minus the score itself.
    df_shap = selected.drop(columns='XGB_Score')
    shap_values = explainer.shap_values(df_shap)

    # Static summary plot: draw with show=False, then hand the current
    # matplotlib figure to Streamlit and clear it for the next plot.
    shap.summary_plot(shap_values, df_shap, show=False)
    st.caption("SHAP Summary Plot of All Input Genes")
    st.pyplot(pl.gcf(), clear_figure=True)

    # Interactive version: draw the summary plot again and convert the
    # resulting matplotlib figure to Plotly.
    st.caption("Interactive SHAP Summary Plot of All Input Genes")
    shap.summary_plot(shap_values, df_shap, show=False)
    mpl_fig = pl.gcf()
    plotly_fig = tls.mpl_to_plotly(mpl_fig)
    pl.close(mpl_fig)  # the matplotlib figure is no longer needed
    plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}}

    # Attach a gene symbol as hover text to every second trace.  zip() stops
    # at the shorter sequence, so running out of genes no longer raises.
    # NOTE(review): assumes the odd-indexed traces line up with the input
    # genes in order -- confirm against the converted figure.
    for trace_index, gene in zip(range(1, len(plotly_fig['data']), 2), gene_list):
        trace = plotly_fig['data'][trace_index]
        trace['name'] = ''
        trace['text'] = gene
        trace['hoverinfo'] = 'text'

    # Invisible trace whose only purpose is to show the Low/High colour bar.
    colorbar_trace = go.Scatter(
        x=[None],
        y=[None],
        mode='markers',
        marker=dict(
            colorscale=red_blue,
            showscale=True,
            cmin=-5,
            cmax=5,
            colorbar=dict(thickness=5, tickvals=[-5, 5], ticktext=['Low', 'High'], outlinewidth=0)
        ),
        hoverinfo='none'
    )

    plotly_fig['layout']['showlegend'] = False
    plotly_fig['layout']['hovermode'] = 'closest'
    plotly_fig['layout']['height'] = 600
    plotly_fig['layout']['width'] = 500
    plotly_fig['layout']['xaxis'].update(zeroline=True, showline=True, ticklen=4, showgrid=False)
    plotly_fig['layout']['yaxis'].update(dict(visible=False))
    plotly_fig.add_trace(colorbar_trace)
    plotly_fig.layout.update(
        annotations=[dict(
            x=1.18,
            align="right",
            valign="top",
            text='Feature value',
            showarrow=False,
            xref="paper",
            yref="paper",
            xanchor="right",
            yanchor="middle",
            textangle=-90,
            font=dict(family='Calibri', size=14)
        )],
        margin=dict(t=20)
    )
    st.plotly_chart(plotly_fig)
# Single-gene lookup; surrounding whitespace is ignored so that " AGT " still
# matches the index.
input_gene = st.text_input("Input individual HGNC gene:").strip()
gene_rows = df_total[df_total.index == input_gene]

# Turn the index into a 'Gene' column on a new frame instead of writing into
# a slice of df_total.
df2 = gene_rows.rename_axis('Gene').reset_index()
df2 = df2[['Gene', 'XGB_Score', 'mousescore_Exomiser',
           'SDI', 'Liver_GTExTPM', 'pLI_ExAC',
           'HIPred',
           'Cells - EBV-transformed lymphocytes_GTExTPM',
           'Pituitary_GTExTPM',
           'IPA_BP_annotation']]
st.dataframe(df2)

if input_gene and gene_rows.empty:
    # Without this guard an empty frame is passed to SHAP.
    st.warning(f"Gene '{input_gene}' was not found in the annotation dataset.")
elif input_gene:
    df2_shap = gene_rows.drop(columns='XGB_Score')
    shap_values = explainer.shap_values(df2_shap)
    # matplotlib=True with show=False: the plot is drawn with matplotlib and
    # the returned object is passed to Streamlit for rendering.
    force_plot = shap.force_plot(
        explainer.expected_value,
        shap_values,
        df2_shap,
        matplotlib=True,
        show=False,
    )
    st.pyplot(fig=force_plot)
st.markdown("""
Total Gene Prioritisation Results:
""")

# Build the display table on a new frame: binding df_total to a second name
# and editing it in place would also rewrite df_total itself.
df_total_output = df_total.rename_axis('Gene').reset_index()
df_total_output = df_total_output[['Gene', 'XGB_Score', 'mousescore_Exomiser',
                                   'SDI', 'Liver_GTExTPM', 'pLI_ExAC',
                                   'HIPred',
                                   'Cells - EBV-transformed lymphocytes_GTExTPM',
                                   'Pituitary_GTExTPM',
                                   'IPA_BP_annotation']]
st.dataframe(df_total_output)