Spaces:
Sleeping
Sleeping
import streamlit as st | |
import re | |
import numpy as np | |
import pandas as pd | |
import sklearn | |
import xgboost | |
import shap | |
# Turn off Streamlit's 'deprecation.showPyplotGlobalUse' option (presumably to
# silence the warning emitted when st.pyplot() is called without an explicit
# figure object).
st.set_option('deprecation.showPyplotGlobalUse', False)

# Random seed handed to the XGBoost regressor.
seed = 42

# Annotation table that gets scored and displayed; rows are keyed by the
# "Gene" column.
data = pd.read_csv("annotations_dataset.csv")
data = data.set_index("Gene")

# Labelled training set (first row is the header).
training_data = pd.read_csv("./selected_features_training_data.csv", header=0)

# XGBoost does not accept feature names containing "[", "]" or "<", so those
# characters are replaced with "_". The pattern is compiled here because the
# name `regex` is used below and was not defined anywhere before.
regex = re.compile(r"\[|\]|<")
training_data.columns = [
    regex.sub("_", col) if isinstance(col, str) else col
    for col in training_data.columns.values
]

# Encode the textual "BPlabel" classes as numeric regression targets.
# NOTE(review): labels outside this mapping become NaN — confirm the training
# file only contains these three values.
training_data["BPlabel_encoded"] = training_data["BPlabel"].map(
    {"most likely": 1, "probable": 0.75, "least likely": 0.1}
)

# Target vector and feature matrix (both label columns removed).
Y = training_data["BPlabel_encoded"]
X = training_data.drop(columns=["BPlabel_encoded", "BPlabel"])
# Hyper-parameters for the XGBoost regressor, collected in one place so they
# are easy to scan; `seed` is the module-level random seed.
xgb_params = {
    "objective": "reg:squarederror",
    "n_estimators": 40,
    "max_depth": 4,
    "learning_rate": 0.2,
    "reg_alpha": 1,
    "reg_lambda": 1,
    "random_state": seed,
}

# Build the model and fit it on the training features X against targets Y.
xgb = xgboost.XGBRegressor(**xgb_params)
xgb.fit(X, Y)
# Score every row of the annotation table with the fitted model and round
# each prediction to two decimals.
predictions = [round(score, 2) for score in xgb.predict(data)]
output = pd.Series(data=predictions, index=data.index, name="XGB_Score")

# Attach the score as an extra column next to the annotations.
df_total = pd.concat([data, output], axis=1)

# Keep the genes as the index, explicitly named "Gene". The previous
# `rename_axis(...).reset_index()` call discarded its result and therefore
# did nothing; the index must NOT be reset because the gene lookups further
# down filter on `df_total.index`.
df_total = df_total.rename_axis("Gene")

# Column order for display: the model score first, then the annotations.
df_total = df_total[[
    'XGB_Score',
    'mousescore_Exomiser',
    'SDI',
    'Liver_GTExTPM',
    'pLI_ExAC',
    'HIPred',
    'Cells - EBV-transformed lymphocytes_GTExTPM',
    'Pituitary_GTExTPM',
    'IPA_BP_annotation',
]]
# Page header: the app title followed by a one-sentence description.
app_description = """
A machine learning pipeline for predicting disease-causing genes post-genome-wide association study in blood pressure.
"""
st.title('Blood Pressure Gene Prioritisation Post-GWAS')
st.markdown(app_description)
def collect_genes(x):
    """Split a comma-separated string of gene symbols into a clean list.

    Whitespace around each symbol is stripped and empty entries (from
    repeated or trailing commas, or an empty input) are dropped. The
    previous pattern ``",|, "`` always matched the bare comma first, so
    ``"A, B"`` produced ``" B"`` with a leading space.

    Args:
        x: Raw text, e.g. ``"AGT, ACE,NOS3"``.

    Returns:
        List of non-empty gene symbols in input order.
    """
    stripped = (part.strip() for part in x.split(","))
    return [gene for gene in stripped if gene]
# Multi-gene lookup: read a comma-separated list of gene symbols.
input_gene_list = st.text_input("Input list of HGNC genes (enter comma separated):")
gene_list = collect_genes(input_gene_list)

# SHAP explainer for the fitted XGBoost model; also used by the single-gene
# section further down.
explainer = shap.TreeExplainer(xgb)

# The summary plot is only produced when more than one gene was entered.
if len(gene_list) > 1:
    df = df_total[df_total.index.isin(gene_list)]
    st.dataframe(df)
    # Drop the score column via reassignment rather than in place: `df` is a
    # filtered slice of df_total, and mutating it in place triggers pandas'
    # SettingWithCopyWarning.
    df = df.drop(columns='XGB_Score')
    if df.empty:
        # None of the symbols matched the index; SHAP cannot explain an
        # empty frame, so tell the user instead of failing.
        st.warning("None of the input genes were found in the dataset.")
    else:
        shap_values = explainer.shap_values(df)
        # NOTE(review): with show=False summary_plot draws on matplotlib's
        # current figure; if it returns None, st.pyplot falls back to that
        # global figure — confirm against the installed shap version.
        summary_plot = shap.summary_plot(shap_values, df, show=False)
        st.caption("SHAP Summary Plot of All Input Genes")
        st.pyplot(fig=summary_plot)
# Single-gene lookup: show the gene's row and a SHAP force plot for it.
# Surrounding whitespace is stripped so " AGT " still matches the index.
input_gene = st.text_input("Input individual HGNC gene:").strip()
df2 = df_total[df_total.index == input_gene]
st.dataframe(df2)
# Reassign instead of dropping in place: df2 is a filtered slice of df_total.
df2 = df2.drop(columns='XGB_Score')

if input_gene:
    if df2.empty:
        # Unknown symbol: nothing to explain.
        st.warning(f"Gene '{input_gene}' was not found in the dataset.")
    else:
        # The matplotlib renderer of force_plot handles one sample only, so
        # explain the first matching row.
        single_gene = df2.iloc[[0]]
        # shap_values() returns a plain array here, so it is passed directly
        # (it has no `.values` attribute).
        shap_values = explainer.shap_values(single_gene)
        # matplotlib=True makes force_plot render with matplotlib; together
        # with show=False it hands back a figure that st.pyplot can display
        # (the default renderer is a JavaScript widget).
        force_plot = shap.force_plot(
            explainer.expected_value,
            shap_values,
            single_gene,
            matplotlib=True,
            show=False,
        )
        st.pyplot(fig=force_plot)
# Footer section: heading followed by the full scored table for all genes.
results_heading = """
Total Gene Prioritisation Results:
"""
st.markdown(results_heading)
st.dataframe(df_total)