Spaces:
Running
Running
import streamlit as st | |
import re | |
import numpy as np | |
import pandas as pd | |
import pickle | |
import sklearn | |
import catboost | |
import shap | |
import plotly.tools as tls | |
from dash import dcc | |
from sklearn.cluster import KMeans | |
from sklearn.decomposition import PCA | |
from sklearn.metrics import silhouette_score | |
import plotly.express as px | |
import matplotlib.pyplot as plt | |
import plotly.graph_objs as go | |
import plotly.graph_objects as go | |
try: | |
import matplotlib.pyplot as pl | |
from matplotlib.colors import LinearSegmentedColormap | |
from matplotlib.ticker import MaxNLocator | |
except ImportError: | |
pass | |
# Switch off the 'deprecation.showPyplotGlobalUse' notice; the pages below call
# st.pyplot() without passing an explicit figure.
st.set_option('deprecation.showPyplotGlobalUse', False)
seed = 0

# Per-gene feature table: missing values are replaced by 0, then rows are
# keyed by the "Gene" column.
annotations = (
    pd.read_csv("all_genes_imputed_features.csv")
    .fillna(0)
    .set_index("Gene")
)

# NOTE(review): pickle.load can execute arbitrary code from the file it reads;
# make sure this model artefact only ever comes from a trusted source.
model_path = "best_model_fitted.pkl"
with open(model_path, 'rb') as model_file:
    catboost_model = pickle.load(model_file)

# Class probabilities for every gene, one column per class, placed in front of
# the feature columns in df_total.
probabilities = catboost_model.predict_proba(annotations)
prob_df = pd.DataFrame(
    probabilities,
    index=annotations.index,
    columns=[
        'Probability_Most_Likely',
        'Probability_Probable',
        'Probability_Least_Likely',
    ],
)
df_total = pd.concat([prob_df, annotations], axis=1)
# Sidebar navigation: a radio button chooses which page is rendered further down.
with st.sidebar:
    st.title("Navigation")
    tab = st.radio(
        "Go to",
        ("Gene Prioritisation", "Interactive SHAP Plot", "Supervised SHAP Clustering"),
    )

# Main-page heading and one-line description.
# NOTE(review): assumes the heading belongs to the main page rather than the
# sidebar -- confirm against the deployed app.
st.title('Blood Pressure Gene Prioritisation Post-GWAS')
st.markdown("""A machine learning pipeline for predicting disease-causing genes post-genome-wide association study in blood pressure.""")
# Define a function to collect genes from input
def collect_genes(x):
    """Split free-text input into a list of gene symbols.

    Symbols may be separated by commas, whitespace, or any mixture of the
    two. Empty fragments (from consecutive, leading or trailing separators)
    are discarded.

    Args:
        x: Raw text typed by the user.

    Returns:
        The non-empty tokens of ``x`` in their original order.
    """
    # Raw string: "\s" inside a plain literal is an invalid escape sequence.
    return [token for token in re.split(r"[,\s]+", x) if token]
# SHAP explainer built from the fitted model; it does not depend on user input.
explainer = shap.TreeExplainer(catboost_model)

# Free-text box for several genes; the raw text is tokenised by collect_genes.
input_gene_list = st.text_input(
    "Input a list of multiple HGNC genes (enter comma separated):"
)
gene_list = collect_genes(input_gene_list)
def convert_df(df):
    """Serialise ``df`` to CSV without its index and return UTF-8 bytes."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')
# Names of the three predicted-probability columns in df_total.
probability_columns = [
    'Probability_Most_Likely',
    'Probability_Probable',
    'Probability_Least_Likely',
]

# Every other column of df_total, in its original order.
features_list = df_total.columns[~df_total.columns.isin(probability_columns)].tolist()
features = df_total.loc[:, features_list]
# Pages: one renderer per sidebar tab; the dispatch at the bottom of this
# section calls the renderer matching `tab`.

# Display names used for the three classes in plot titles.
# NOTE(review): the order is assumed to match the class order of the model's
# predict_proba / SHAP output -- confirm against the training labels.
_CLASS_NAMES = ["Most likely", "Probable", "Least likely"]


def _select_input_genes(genes):
    """Return the rows of ``df_total`` whose index is in ``genes``.

    The result has a fresh 0..n-1 index and its columns are ordered as
    'Gene', the probability columns, then every remaining column.
    """
    # .copy() so that adding the 'Gene' column writes to an independent frame
    # instead of a slice of df_total (avoids pandas' SettingWithCopyWarning).
    selected = df_total[df_total.index.isin(genes)].copy()
    selected['Gene'] = selected.index
    selected.reset_index(drop=True, inplace=True)
    leading = ['Gene'] + probability_columns
    trailing = [column for column in selected.columns if column not in leading]
    return selected[leading + trailing]


def _show_selection(df):
    """Show ``df`` and offer its gene/probability columns as a CSV download."""
    st.dataframe(df)
    csv = convert_df(df[['Gene'] + probability_columns])
    st.download_button("Download Gene Prioritisation", csv, "bp_gene_prioritisation.csv", "text/csv", key='download-csv')


def _shap_for_selection(df):
    """Return ``(feature_frame, shap_values)`` for a frame from ``_select_input_genes``."""
    df_shap = df.drop(columns=probability_columns + ['Gene'])
    return df_shap, explainer.shap_values(df_shap)


def _summary_plot(container, title, values, df_shap, **plot_kwargs):
    """Draw one SHAP summary plot, with a subheader, inside ``container``."""
    with container:
        st.subheader(title)
        shap.summary_plot(values, df_shap, **plot_kwargs)
        # No figure is passed: st.pyplot renders the current global figure.
        st.pyplot(bbox_inches='tight', clear_figure=True)


def _render_single_gene(input_gene):
    """Show the prediction row and per-class force plots for one gene.

    Writes a "not found" message instead when ``input_gene`` is not in the
    index of ``df_total``.
    """
    df2 = _select_input_genes([input_gene])
    if df2.empty:
        st.write("Gene not found in the dataset.")
        return
    st.dataframe(df2)

    excluded = probability_columns + ['Gene']
    df2_shap = df_total.loc[[input_gene], [col for col in df_total.columns if col not in excluded]]
    # NOTE(review): indexing shap_values / expected_value by class assumes the
    # explainer returns one entry per class -- confirm for the installed shap.
    shap_values = explainer.shap_values(df2_shap)
    for i, class_name in enumerate(_CLASS_NAMES):
        st.subheader(f"Force Plot for {class_name} Prediction")
        force_plot = shap.force_plot(
            explainer.expected_value[i],
            shap_values[i],
            df2_shap,
            matplotlib=True,
            show=False
        )
        st.pyplot(fig=force_plot)


def _render_gene_prioritisation():
    """Page 1: multi-gene summary plots, single-gene lookup and the full table."""
    # The multi-gene view is shown only when at least two genes were entered.
    if len(gene_list) > 1:
        df = _select_input_genes(gene_list)
        _show_selection(df)
        df_shap, shap_values = _shap_for_selection(df)
        col1, col2 = st.columns(2)
        _summary_plot(col1, "Global SHAP Summary Plot", shap_values, df_shap, plot_type="bar", class_names=_CLASS_NAMES)
        _summary_plot(col2, f"{_CLASS_NAMES[0]} Gene Prediction", shap_values[0], df_shap)
        col3, col4 = st.columns(2)
        _summary_plot(col3, f"{_CLASS_NAMES[1]} Gene Prediction", shap_values[1], df_shap)
        _summary_plot(col4, f"{_CLASS_NAMES[2]} Gene Prediction", shap_values[2], df_shap)

    input_gene = st.text_input("Input an individual HGNC gene:")
    if input_gene:
        # Validate before the lookup: malformed input can never match a gene,
        # so checking afterwards would only ever report "not found".
        if ' ' in input_gene or ',' in input_gene:
            st.write('Input Error: Please input only a single HGNC gene name with no white spaces or commas.')
        else:
            _render_single_gene(input_gene)
            # Shown only for a well-formed gene name, so no empty link is
            # rendered. Plain Markdown is enough for a link, so user text is
            # not passed through unsafe_allow_html.
            url = f"https://astrazeneca-cgr-publications.github.io/DrugnomeAI/geneview.html?gene={input_gene}"
            st.markdown(f"[{input_gene} druggability in DrugnomeAI]({url})")

    st.markdown("\n### Total Gene Prioritisation Results for All Genes:\n")
    # Work on a copy so the shared df_total does not gain a 'Gene' column.
    df_total_output = df_total.copy()
    df_total_output['Gene'] = df_total_output.index
    st.dataframe(df_total_output)
    csv = convert_df(df_total_output)
    st.download_button("Download Gene Prioritisation", csv, "all_genes_bp_prioritisation.csv", "text/csv", key='download-all-csv')


def _render_interactive_shap():
    """Page 2: interactive plot of the first class's SHAP values for the input genes."""
    st.title("Interactive SHAP Plot")
    if len(gene_list) <= 1:
        return
    df = _select_input_genes(gene_list)
    _show_selection(df)
    df_shap, shap_values = _shap_for_selection(df)

    # Rank features by mean |SHAP| for the first class; keep the 20 largest,
    # most important first.
    shap_values_first_class = shap_values[0]
    feature_importance = np.abs(shap_values_first_class).mean(axis=0)
    top_features_indices = np.argsort(feature_importance)[-20:][::-1]
    features_top = df_shap.columns[top_features_indices]
    shap_values_top = shap_values_first_class[:, top_features_indices]

    # One point per (feature, gene) pair, all in a single trace.
    x_values = []
    y_values = []
    hover_texts = []
    for i, feature_name in enumerate(features_top):
        for gene, value in zip(df['Gene'], shap_values_top[:, i]):
            x_values.append(value)
            y_values.append(feature_name)
            hover_texts.append(f'{gene}: {value:.3f}')

    fig = go.Figure(data=go.Scatter(
        x=x_values,
        y=y_values,
        mode='markers',
        marker=dict(
            color=x_values,  # colour each point by its SHAP value
            colorbar=dict(title="SHAP Value"),
            colorscale=[(0, "blue"), (1, "red")],
        ),
        text=hover_texts,
        hoverinfo="text+x",
    ))
    fig.update_layout(
        title="SHAP Summary Plot - Top 20 Features",
        xaxis_title="SHAP Value",
        yaxis=dict(autorange="reversed", title="Feature"),
        showlegend=False,
    )
    st.plotly_chart(fig, use_container_width=True)
    st.caption("SHAP Summary Plot of All Input Genes - Top 20 Features")


def _gene_scatter(points, name, color=None):
    """Build a marker trace for ``points`` (columns 'PCA_1', 'PCA_2', 'Gene')."""
    marker = dict(color=color) if color is not None else None
    return go.Scatter(
        x=points['PCA_1'], y=points['PCA_2'],
        mode='markers',
        name=name,
        text=points['Gene'],
        marker=marker,
        hoverinfo="text+x+y",
    )


def _render_shap_clustering():
    """Page 3: PCA + k-means view of the first class's SHAP values for all genes."""
    st.title("Supervised SHAP Clustering")
    training_genes = pd.read_csv("training_cleaned.csv")
    # Keep the training genes whose encoded label is 0, keyed by gene.
    training_genes = training_genes[training_genes['BPlabel_encoded'] == 0].set_index('Gene')

    # SHAP values of the first class for the full dataset.
    shap_values_full = explainer.shap_values(annotations)
    shap_values_full_array = np.array(shap_values_full[0])

    # Project onto two components, then cluster in the projected space. Both
    # steps use the module-level seed so the layout is stable across reruns.
    pca = PCA(n_components=2, random_state=seed)
    shap_values_pca = pca.fit_transform(shap_values_full_array)
    kmeans = KMeans(n_clusters=3, random_state=seed).fit(shap_values_pca)
    labels = kmeans.labels_

    df_for_plot = pd.DataFrame({
        'PCA_1': shap_values_pca[:, 0],
        'PCA_2': shap_values_pca[:, 1],
        'Cluster': labels.astype(str),
        'Gene': annotations.index,
    })
    # Mark the highlighted groups; user input is applied last and so wins over
    # the training label.
    df_for_plot['SpecialGroup'] = 'None'
    df_for_plot.loc[df_for_plot['Gene'].isin(training_genes.index), 'SpecialGroup'] = 'Most Likely Training Gene'
    if gene_list:
        df_for_plot.loc[df_for_plot['Gene'].isin(gene_list), 'SpecialGroup'] = 'User Input Gene'

    fig = go.Figure()
    # One trace per cluster for the genes in no highlighted group.
    for cluster in df_for_plot['Cluster'].unique():
        in_cluster = (df_for_plot['Cluster'] == cluster) & (df_for_plot['SpecialGroup'] == 'None')
        fig.add_trace(_gene_scatter(df_for_plot[in_cluster], f'Cluster {cluster}'))
    # Highlighted groups are added afterwards so they are drawn on top.
    for group, color in (
        ('Most Likely Training Gene', 'rgba(255, 0, 0, .9)'),
        ('User Input Gene', 'rgba(0, 255, 0, .9)'),
    ):
        fig.add_trace(_gene_scatter(df_for_plot[df_for_plot['SpecialGroup'] == group], group, color))

    fig.update_layout(
        title='Supervised SHAP Clustering with PCA',
        xaxis_title='First Principal Component',
        yaxis_title='Second Principal Component',
        showlegend=True,
        legend_title_text='Gene Category',
    )
    st.plotly_chart(fig, use_container_width=True)


# Render the page selected in the sidebar.
if tab == "Gene Prioritisation":
    _render_gene_prioritisation()
elif tab == "Interactive SHAP Plot":
    _render_interactive_shap()
elif tab == "Supervised SHAP Clustering":
    _render_shap_clustering()