hlnicholls's picture
Update app.py
cbf0145
raw
history blame
4.63 kB
import streamlit as st
import re
import numpy as np
import pandas as pd
import sklearn
import xgboost
import shap
from shap_plots import shap_summary_plot
import plotly.tools as tls
import dash_core_components as dcc
import matplotlib
import plotly.graph_objs as go
try:
import matplotlib.pyplot as pl
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import MaxNLocator
except ImportError:
pass
# Hide Streamlit's deprecation warning about global pyplot use.
st.set_option('deprecation.showPyplotGlobalUse', False)

# Seed passed to the model as its random_state.
seed = 42

# Gene annotations, indexed by gene symbol.
annotations = pd.read_csv("annotations_dataset.csv").set_index("Gene")

# Labelled training set; the first row of the file holds the column names.
training_data = pd.read_csv("./selected_features_training_data.csv", header=0)
# Replace "[", "]" and "<" in column names with "_".
# NOTE(review): presumably needed because XGBoost rejects feature names
# containing these characters -- confirm against the XGBoost version in use.
# The original code referenced `regex` without ever defining it, so the first
# offending column name raised NameError; the pattern is now compiled here.
regex = re.compile(r"[\[\]<]")
training_data.columns = [
    regex.sub("_", col) if any(ch in str(col) for ch in ("[", "]", "<")) else col
    for col in training_data.columns.values
]
# Numeric target value assigned to each textual BPlabel category.
bp_label_scores = {"most likely": 1, "probable": 0.75, "least likely": 0.1}
training_data["BPlabel_encoded"] = training_data["BPlabel"].map(bp_label_scores)

# Regression target, and the feature matrix with both label columns removed.
Y = training_data["BPlabel_encoded"]
X = training_data.drop(columns=["BPlabel_encoded", "BPlabel"])
# Hyperparameters of the gradient-boosted regressor.
xgb_params = {
    "n_estimators": 40,
    "learning_rate": 0.2,
    "max_depth": 4,
    "reg_alpha": 1,
    "reg_lambda": 1,
    "random_state": seed,
    "objective": "reg:squarederror",
}

# Build the model and fit it on the full training set.
xgb = xgboost.XGBRegressor(**xgb_params)
xgb.fit(X, Y)
# Score every row of the annotations table with the fitted model and round
# each score to two decimals.
prediction_list = list(xgb.predict(annotations))
predictions = [round(score, 2) for score in prediction_list]
output = pd.Series(predictions, index=annotations.index, name="XGB_Score")

# Columns kept for display: the model score first, then the features.
ordered_columns = [
    'XGB_Score',
    'mousescore_Exomiser',
    'SDI',
    'Liver_GTExTPM',
    'pLI_ExAC',
    'HIPred',
    'Cells - EBV-transformed lymphocytes_GTExTPM',
    'Pituitary_GTExTPM',
    'IPA_BP_annotation',
]

# Attach the scores to the annotations and apply the display column order.
df_total = pd.concat([annotations, output], axis=1)[ordered_columns]
# Page heading and a short description with a link to the paper.
page_title = 'Blood Pressure Gene Prioritisation Post-GWAS'
page_intro = """
A machine learning pipeline for predicting disease-causing genes post-genome-wide association study in blood pressure.
https://doi.org/10.21203/rs.3.rs-2402775/v1
"""
st.title(page_title)
st.markdown(page_intro)
def collect_genes(x):
    """Split a free-text gene list into individual gene symbols.

    Symbols may be separated by commas, whitespace, or any mix of the two.
    Empty fragments (from leading, trailing or repeated separators) are
    dropped.

    Args:
        x: The raw text typed by the user.

    Returns:
        A list of the non-empty symbols, in input order.
    """
    # One character class covers every separator run; the raw string avoids
    # the invalid "\s" escape of the original non-raw pattern.
    return [str(token) for token in re.split(r"[,\s]+", x) if token != ""]
# SHAP explainer for the fitted tree model; used by the SHAP plots below.
explainer = shap.TreeExplainer(xgb)

# Free-text box for several genes, parsed into a list of symbols.
input_gene_list = st.text_input(
    "Input a list of multiple HGNC genes (enter comma separated):"
)
gene_list = collect_genes(input_gene_list)
@st.cache_data
def convert_df(df):
    """Return *df* serialised as CSV, without the index, as UTF-8 bytes.

    The function is wrapped in Streamlit's ``st.cache_data`` decorator.
    """
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')
# Multi-gene view: shown only when more than one gene symbol was entered.
if len(gene_list) > 1:
    matched = df_total[df_total.index.isin(gene_list)]
    if matched.empty:
        # None of the requested symbols are in the table, so there is nothing
        # to tabulate, download or plot.
        st.warning("None of the input genes were found in the dataset.")
    else:
        # assign() returns a new frame, so the "Gene" column is never written
        # into a slice of df_total (avoids pandas' SettingWithCopyWarning).
        df = matched.assign(Gene=matched.index).reset_index(drop=True)
        df = df[['Gene', 'XGB_Score', 'mousescore_Exomiser',
                 'SDI', 'Liver_GTExTPM', 'pLI_ExAC',
                 'HIPred',
                 'Cells - EBV-transformed lymphocytes_GTExTPM',
                 'Pituitary_GTExTPM',
                 'IPA_BP_annotation']]
        st.dataframe(df)

        # Offer just the gene names and their scores as a CSV download.
        output = df[['Gene', 'XGB_Score']]
        csv = convert_df(output)
        st.download_button(
            "Download Gene Prioritisation",
            csv,
            "bp_gene_prioritisation.csv",
            "text/csv",
            key='download-csv'
        )

        # SHAP values are computed on the feature columns only; drop() without
        # inplace returns a new frame instead of mutating a slice.
        df_shap = matched.drop(columns='XGB_Score')
        shap_values = explainer.shap_values(df_shap)

        # Draw onto the current matplotlib figure without showing it, then
        # hand that figure to Streamlit explicitly. The original passed the
        # call's return value to st.pyplot, relying on the global figure.
        # clear_figure=True keeps the figure from leaking into later plots.
        shap.summary_plot(shap_values, df_shap, show=False)
        st.pyplot(fig=pl.gcf(), clear_figure=True)
        st.caption("SHAP Summary Plot of All Input Genes")
# Single-gene view: a table row plus a SHAP force plot for one gene.
input_gene = st.text_input("Input an individual HGNC gene:")
gene_rows = df_total[df_total.index == input_gene]

# assign() builds a new frame, so the "Gene" column is not written into a
# slice of df_total (avoids pandas' SettingWithCopyWarning).
df2 = gene_rows.assign(Gene=gene_rows.index).reset_index(drop=True)
df2 = df2[['Gene', 'XGB_Score', 'mousescore_Exomiser',
           'SDI', 'Liver_GTExTPM', 'pLI_ExAC',
           'HIPred',
           'Cells - EBV-transformed lymphocytes_GTExTPM',
           'Pituitary_GTExTPM',
           'IPA_BP_annotation']]
# The table is rendered even before a gene is typed (it is then empty).
st.dataframe(df2)

if input_gene:
    if gene_rows.empty:
        # Unknown symbol: there is no row to explain, so report it rather
        # than running the explainer on zero rows.
        st.warning(f"Gene '{input_gene}' was not found in the dataset.")
    else:
        # Explain the prediction using the feature columns only.
        df2_shap = gene_rows.drop(columns='XGB_Score')
        shap_values = explainer.shap_values(df2_shap)
        # matplotlib=True with show=False: the plot is drawn with matplotlib
        # and the call's return value is passed to Streamlit for rendering.
        force_plot = shap.force_plot(
            explainer.expected_value,
            shap_values,
            df2_shap,
            matplotlib=True,
            show=False,
        )
        st.pyplot(fig=force_plot)
# Full results table for every gene in the dataset.
st.markdown("""
Total Gene Prioritisation Results:
""")
# Build the display frame without touching df_total. The original bound
# df_total_output to the same object and then added a column and reset the
# index in place, which silently modified df_total as well.
df_total_output = df_total.assign(Gene=df_total.index).reset_index(drop=True)
df_total_output = df_total_output[['Gene', 'XGB_Score', 'mousescore_Exomiser',
                                   'SDI', 'Liver_GTExTPM', 'pLI_ExAC',
                                   'HIPred',
                                   'Cells - EBV-transformed lymphocytes_GTExTPM',
                                   'Pituitary_GTExTPM',
                                   'IPA_BP_annotation']]
st.dataframe(df_total_output)