Spaces:
Sleeping
Sleeping
import streamlit as st | |
import re | |
import numpy as np | |
import pandas as pd | |
import pickle | |
import sklearn | |
import catboost | |
import shap | |
from shap_plots import shap_summary_plot | |
from dynamic_shap_plot import matplotlib_to_plotly, summary_plot_plotly_fig | |
import plotly.tools as tls | |
import dash_core_components as dcc | |
import matplotlib | |
import plotly.graph_objs as go | |
try: | |
import matplotlib.pyplot as pl | |
from matplotlib.colors import LinearSegmentedColormap | |
from matplotlib.ticker import MaxNLocator | |
except ImportError: | |
pass | |
# Older Streamlit releases warn when st.pyplot() is called without a figure
# unless this option is off; releases that dropped the option raise
# StreamlitAPIException for the unknown key, so tolerate its absence rather
# than crashing at start-up.
from streamlit.errors import StreamlitAPIException

try:
    st.set_option('deprecation.showPyplotGlobalUse', False)
except StreamlitAPIException:
    pass

seed = 42  # NOTE(review): not referenced anywhere else in this file.

# Per-gene feature table, indexed by its 'Gene' column.
annotations = pd.read_csv("all_genes_merged_ml_data.csv")
# TODO remove this placeholder when imputation is finished:
annotations.fillna(0, inplace=True)
annotations = annotations.set_index("Gene")

# Read in best_model_fitted.pkl as catboost_model
model_path = "best_model_fitted.pkl"  # Update this path if your model is stored elsewhere
with open(model_path, 'rb') as file:
    # pickle.load can execute arbitrary code: only ever load a trusted,
    # locally produced model file here.
    catboost_model = pickle.load(file)

# Class probabilities for every gene: one row per gene, one column per class.
probabilities = catboost_model.predict_proba(annotations)
# NOTE(review): predict_proba columns follow the model's own class order;
# confirm it really is most likely / probable / least likely.
prob_df = pd.DataFrame(
    probabilities,
    index=annotations.index,
    columns=['Probability_Most_Likely', 'Probability_Probable', 'Probability_Least_Likely'],
)
# Probability columns first, followed by every original annotation feature.
df_total = pd.concat([prob_df, annotations], axis=1)
# Page heading followed by a one-line description of the app.
st.title('Blood Pressure Gene Prioritisation Post-GWAS')
_app_description = """
A machine learning pipeline for predicting disease-causing genes post-genome-wide association study in blood pressure.
"""
st.markdown(_app_description)
def collect_genes(x):
    """Split free-text gene input into a list of gene symbols.

    Tokens may be separated by any mix of commas and whitespace. Empty
    tokens (from leading, trailing or repeated separators) are dropped,
    and input order is preserved.

    >>> collect_genes("AGT, ACE  NOS3")
    ['AGT', 'ACE', 'NOS3']
    """
    # Raw string so the regex escape reaches `re` intact; one character
    # class covers every comma/whitespace run.
    return [token for token in re.split(r"[,\s]+", x) if token]
# Free-text box for a batch of gene symbols, tokenised by collect_genes.
_gene_list_prompt = "Input a list of multiple HGNC genes (enter comma separated):"
input_gene_list = st.text_input(_gene_list_prompt)
gene_list = collect_genes(input_gene_list)
# A single SHAP TreeExplainer over the fitted CatBoost model.
explainer = shap.TreeExplainer(model=catboost_model)
def convert_df(df):
    """Serialise *df* to UTF-8 encoded CSV bytes, omitting the index."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')
# Names of the model-output columns that lead df_total.
probability_columns = [
    'Probability_Most_Likely',
    'Probability_Probable',
    'Probability_Least_Likely',
]
# Every remaining column of df_total is an annotation feature.
# NOTE(review): features_list / features are not referenced elsewhere in this file.
features_list = [
    name for name in df_total.columns
    if name not in probability_columns
]
features = df_total[features_list]
# Batch view: only shown when at least two gene symbols were entered.
if len(gene_list) > 1:
    # .copy() makes df independent of df_total, so adding the Gene column
    # below does not trigger pandas' SettingWithCopyWarning.
    df = df_total[df_total.index.isin(gene_list)].copy()
    # Report requested symbols with no row in the data (dict.fromkeys
    # de-duplicates while keeping input order).
    missing_genes = [gene for gene in dict.fromkeys(gene_list) if gene not in df_total.index]
    if missing_genes:
        st.warning("Genes not found in the annotation data: " + ", ".join(missing_genes))
    df['Gene'] = df.index
    df.reset_index(drop=True, inplace=True)
    # Column order: Gene, probability columns, then all other features.
    required_columns = ['Gene'] + probability_columns + [
        col for col in df.columns if col not in probability_columns and col != 'Gene'
    ]
    df = df[required_columns]
    st.dataframe(df)

    # Downloadable CSV with just the genes and their class probabilities.
    output = df[['Gene'] + probability_columns]
    csv = convert_df(output)
    st.download_button(
        "Download Gene Prioritisation",
        csv,
        "bp_gene_prioritisation.csv",
        "text/csv",
        key='download-csv'
    )

    # SHAP needs at least one row; skip the plot when no gene matched.
    if not df.empty:
        # Keep only the model's feature columns.
        df_shap = df.drop(columns=probability_columns + ['Gene'])
        shap_values = explainer.shap_values(df_shap)
        # Plot class 0, presumably the 'most likely' class — confirm against
        # the model's class order.
        class_index = 0
        # Older shap returns a list with one (n_samples, n_features) array per
        # class; newer releases return one (n_samples, n_features, n_classes)
        # array, where indexing [class_index] would select a sample instead.
        if isinstance(shap_values, list):
            class_shap_values = shap_values[class_index]
        elif np.ndim(shap_values) == 3:
            class_shap_values = shap_values[:, :, class_index]
        else:
            class_shap_values = shap_values
        shap.summary_plot(class_shap_values, df_shap, show=False)
        # Hand the figure to Streamlit explicitly (no global-pyplot use) and
        # close it so figures do not pile up across reruns.
        fig = pl.gcf()
        st.pyplot(fig, bbox_inches='tight')
        pl.close(fig)
        st.caption("SHAP Summary Plot of All Input Genes")
# Single-gene view. Surrounding whitespace (e.g. from copy-paste) is ignored.
input_gene = st.text_input("Input an individual HGNC gene:").strip()
# .copy() so adding the Gene column below does not trigger SettingWithCopyWarning.
df2 = df_total[df_total.index == input_gene].copy()
df2['Gene'] = df2.index
df2.reset_index(drop=True, inplace=True)
# Gene first, then the probability columns (probability_columns is defined
# earlier in the script), then every other feature.
required_columns = ['Gene'] + probability_columns + [
    col for col in df2.columns if col not in probability_columns and col != 'Gene'
]
df2 = df2[required_columns]
st.dataframe(df2)
if input_gene:
    if ' ' in input_gene or ',' in input_gene:
        st.write('Input Error: Please input only a single HGNC gene name with no white spaces or commas.')
    elif input_gene not in df_total.index:
        # .loc with a missing label raises KeyError, so check membership first.
        st.warning(f"Gene '{input_gene}' was not found in the annotation data.")
    else:
        feature_columns = [col for col in df_total.columns if col not in probability_columns + ['Gene']]
        df2_shap = df_total.loc[[input_gene], feature_columns]
        shap_values = explainer.shap_values(df2_shap)
        # Explain the most probable class. argmax over predict_proba gives a
        # positional class index; predict() returns class labels, which need
        # not be integers usable for indexing.
        predicted_class_index = int(np.argmax(catboost_model.predict_proba(df2_shap)[0]))
        # Older shap: list of per-class arrays; newer: one
        # (n_samples, n_features, n_classes) array.
        if isinstance(shap_values, list):
            class_shap_values = shap_values[predicted_class_index]
        elif np.ndim(shap_values) == 3:
            class_shap_values = shap_values[:, :, predicted_class_index]
        else:
            class_shap_values = shap_values
        expected_values = np.ravel(explainer.expected_value)
        class_expected_value = (
            expected_values[predicted_class_index] if expected_values.size > 1 else expected_values[0]
        )
        # shap.plots.waterfall requires a single-row shap.Explanation object,
        # not a raw numpy array of SHAP values.
        explanation = shap.Explanation(
            values=class_shap_values[0],
            base_values=class_expected_value,
            data=df2_shap.iloc[0].to_numpy(),
            feature_names=list(df2_shap.columns),
        )
        shap.plots.waterfall(explanation, max_display=10, show=False)
        # Pass the figure explicitly and close it after rendering.
        fig = pl.gcf()
        st.pyplot(fig, bbox_inches='tight')
        pl.close(fig)
st.markdown("""
### Total Gene Prioritisation Results:
""")
# Full results table with the gene index turned into a leading 'Gene' column,
# matching the Gene-first layout of the per-gene tables. reset_index returns a
# new frame, so df_total itself is not mutated.
df_total_output = df_total.rename_axis('Gene').reset_index()
st.dataframe(df_total_output)