File size: 7,705 Bytes
f60ce93
7d04c1c
0059ef7
 
bcf8eca
bf91270
a2e3d4b
a828d88
 
 
 
 
 
 
 
 
 
 
9ccc625
33c190e
7944a63
0059ef7
037b3d4
66ad10a
0059ef7
4d92e12
4b60c06
 
 
 
0059ef7
4b60c06
 
 
 
5c158f1
4b60c06
 
 
 
 
 
 
cfb90d5
4b60c06
 
 
d5fefc1
037b3d4
0059ef7
037b3d4
 
ac213c9
f60ce93
060dcc2
45b475d
 
 
 
 
 
 
659d788
f60ce93
 
 
 
a9b361d
c73e4be
a044018
a2e3d4b
 
 
d53ce83
3592cb3
 
 
 
a828d88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff8cf9b
330195f
1b01b6f
5ef2581
77897c3
 
 
 
 
 
 
1b01b6f
 
3592cb3
5e24005
3592cb3
1b01b6f
3592cb3
 
 
cfb90d5
5e24005
 
 
a2e3d4b
f83ec42
a828d88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2e3d4b
 
 
 
 
 
 
33c190e
8fb59d9
5ebff3c
8fb59d9
 
 
 
 
a2e3d4b
9ccc625
 
2ab2dc2
 
33c190e
e558219
a2e3d4b
 
1dab293
2ab2dc2
e558219
f83ec42
330195f
 
a044018
 
 
 
 
2ab2dc2
 
8fb59d9
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
import streamlit as st
import re
import numpy as np
import pandas as pd
import sklearn
import xgboost
import shap
import plotly.tools as tls
import dash_core_components as dcc
import matplotlib
import plotly.graph_objs as go
try:
    import matplotlib.pyplot as pl
    from matplotlib.colors import LinearSegmentedColormap
    from matplotlib.ticker import MaxNLocator
except ImportError:
    pass

# Turn off Streamlit's warning about calling st.pyplot() without a figure.
st.set_option('deprecation.showPyplotGlobalUse', False)

# Seed passed to the XGBoost regressor as its random_state.
seed=42

# Annotation table for the genes to be scored, indexed by its "Gene" column.
annotations = pd.read_csv("annotations_dataset.csv")
annotations = annotations.set_index("Gene")

# Characters that are replaced with "_" in training column names.
# NOTE(review): presumably needed because XGBoost rejects feature names
# containing "[", "]" or "<" - confirm against the XGBoost version in use.
regex = re.compile(r"[\[\]<]")

training_data = pd.read_csv("./selected_features_training_data.csv", header=0)
# Sanitise only the columns that actually contain one of those characters;
# every other column label is kept untouched.
training_data.columns = [
    regex.sub("_", col) if any(x in str(col) for x in set(("[", "]", "<"))) else col
    for col in training_data.columns.values
]

# Numeric regression target derived from the categorical BPlabel column.
label_scores = {"most likely": 1, "probable": 0.75, "least likely": 0.1}
training_data["BPlabel_encoded"] = training_data["BPlabel"].map(label_scores)

# Target vector, and the feature matrix with both label columns removed.
Y = training_data["BPlabel_encoded"]
X = training_data.drop(columns=["BPlabel_encoded", "BPlabel"])

# Regressor hyper-parameters, collected in one place.
xgb_params = {
    "n_estimators": 40,
    "learning_rate": 0.2,
    "max_depth": 4,
    "reg_alpha": 1,
    "reg_lambda": 1,
    "random_state": seed,
    "objective": "reg:squarederror",
}
xgb = xgboost.XGBRegressor(**xgb_params)
xgb.fit(X, Y)

# Score every row of the annotation table and round to two decimals.
prediction_list = list(xgb.predict(annotations))
predictions = [round(score, 2) for score in prediction_list]

# Attach the scores to the annotations under the column name "XGB_Score".
output = pd.Series(predictions, index=annotations.index, name="XGB_Score")
df_total = pd.concat([annotations, output], axis=1)

# Keep the score first, followed by the annotation columns used by the app.
selected_columns = [
    'XGB_Score',
    'mousescore_Exomiser',
    'SDI',
    'Liver_GTExTPM',
    'pLI_ExAC',
    'HIPred',
    'Cells - EBV-transformed lymphocytes_GTExTPM',
    'Pituitary_GTExTPM',
    'IPA_BP_annotation',
]
df_total = df_total[selected_columns]


# Page header: app title followed by a one-sentence description.
st.title('Blood Pressure Gene Prioritisation Post-GWAS')
st.markdown("""
A machine learning pipeline for predicting disease-causing genes post-genome-wide association study in blood pressure.
""")


def collect_genes(x):
    """Split free-text input into a list of gene symbols.

    Symbols may be separated by commas, whitespace, or any mix of the two;
    empty fragments produced by repeated separators are discarded.

    Args:
        x: Raw text typed by the user.

    Returns:
        The non-empty tokens, in the order they appear in *x*.
    """
    # Raw string: "\s" in a normal literal is an invalid escape sequence.
    return [token for token in re.split(r"[,\s]+", x) if token]

# Text box for the multi-gene query; the raw text is parsed into a list of symbols.
input_gene_list = st.text_input("Input list of HGNC genes (enter comma separated):")
gene_list = collect_genes(input_gene_list)
# SHAP tree explainer built on the fitted XGBoost model.
explainer = shap.TreeExplainer(xgb)

@st.experimental_memo
def convert_df(df):
    """Serialise *df* to CSV without the index column and return UTF-8 bytes."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')


# Segment data for a two-colour linear gradient. Each channel lists
# (position, value, value) anchors at positions 0.0 and 1.0; alpha is 1
# at every anchor, i.e. fully opaque throughout.
cdict1 = {
    'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
            (1.0, 0.9607843137254902, 0.9607843137254902)),

    'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
              (1.0, 0.15294117647058825, 0.15294117647058825)),

    'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
             (1.0, 0.3411764705882353, 0.3411764705882353)),

    'alpha': ((0.0, 1, 1),
              (0.5, 1, 1),
              (1.0, 1, 1))
}  # #1E88E5 -> #F52757 (the RGB endpoints above, scaled by 255, in hex)
# Matplotlib colormap named 'RedBlue' built from the segment data.
red_blue = LinearSegmentedColormap('RedBlue', cdict1)

def matplotlib_to_plotly(cmap, pl_entries):
    """Sample a colormap into a Plotly-style colorscale.

    Args:
        cmap: Callable mapping a float position to a colour sequence whose
            first three items are the red, green and blue channels as
            floats (scaled by 255 here).
        pl_entries: Number of evenly spaced samples to take, starting at
            position 0. A single entry samples position 0 only.

    Returns:
        A list of ``[position, 'rgb(r, g, b)']`` pairs, one per sample.
    """
    # Guard the single-sample case, which would otherwise divide by zero.
    h = 1.0 / (pl_entries - 1) if pl_entries > 1 else 0.0
    pl_colorscale = []

    for k in range(pl_entries):
        position = k * h
        # Plain ints (truncated, as before) keep the text as "rgb(30, 136, 229)".
        # NumPy scalars would print as "np.uint8(30)" under NumPy >= 2.
        red, green, blue = (int(channel * 255) for channel in cmap(position)[:3])
        pl_colorscale.append([position, 'rgb' + str((red, green, blue))])

    return pl_colorscale

# Rebind red_blue from the matplotlib colormap to its 255-step Plotly colorscale.
red_blue = matplotlib_to_plotly(red_blue, 255)


# Multi-gene view: results table, CSV download and SHAP summary plots.
# Only shown when more than one gene symbol was entered.
# .copy() yields an independent frame, so nothing below writes into a slice
# of df_total.
selected_genes = df_total[df_total.index.isin(gene_list)].copy()

if len(gene_list) > 1 and selected_genes.empty:
    # Nothing matched: report it rather than passing an empty frame to SHAP.
    st.warning("None of the input genes were found in the annotations dataset.")
elif len(gene_list) > 1:
    # Table with the gene symbol as an ordinary first column.
    df = selected_genes.copy()
    df['Gene'] = df.index
    df.reset_index(drop=True, inplace=True)
    df = df[['Gene','XGB_Score', 'mousescore_Exomiser',
 'SDI', 'Liver_GTExTPM',  'pLI_ExAC',
 'HIPred',
 'Cells - EBV-transformed lymphocytes_GTExTPM',
 'Pituitary_GTExTPM',
 'IPA_BP_annotation']]
    st.dataframe(df)

    # Offer just the gene/score pairs as a CSV download.
    output = df[['Gene', 'XGB_Score']]
    csv = convert_df(output)
    st.download_button(
       "Download Gene Prioritisation",
       csv,
       "bp_gene_prioritisation.csv",
       "text/csv",
       key='download-csv'
    )

    # SHAP values are computed on the feature columns only (score removed).
    df_shap = selected_genes.drop(columns='XGB_Score')
    shap_values = explainer.shap_values(df_shap)

    # Static plot: with show=False the plot is drawn on the current pyplot
    # figure, which is handed to Streamlit and then cleared.
    shap.summary_plot(shap_values, df_shap, show=False)
    st.caption("SHAP Summary Plot of All Input Genes")
    st.pyplot(pl.gcf(), clear_figure=True)

    # Interactive plot: draw the summary again, grab the pyplot figure and
    # convert it to Plotly.
    st.caption("Interactive SHAP Summary Plot of All Input Genes")
    shap.summary_plot(shap_values, df_shap, show=False)
    mpl_fig = pl.gcf()
    plotly_fig = tls.mpl_to_plotly(mpl_fig)
    pl.close(mpl_fig)
    plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}}

    # Give every second trace (starting at index 1) the next entered gene as
    # its hover text; zip stops quietly when either sequence runs out.
    # NOTE(review): whether these traces line up one-to-one with the entered
    # genes depends on how mpl_to_plotly converts the SHAP figure - confirm.
    for trace, gene in zip(plotly_fig['data'][1::2], gene_list):
        trace['name'] = ''
        trace['text'] = gene
        trace['hoverinfo'] = 'text'

    # Invisible marker trace whose only purpose is to render the colour bar.
    colorbar_trace  = go.Scatter(x=[None],
                                 y=[None],
                                 mode='markers',
                                 marker=dict(
                                     colorscale=red_blue, 
                                     showscale=True,
                                     cmin=-5,
                                     cmax=5,
                                     colorbar=dict(thickness=5, tickvals=[-5, 5], ticktext=['Low', 'High'], outlinewidth=0)
                                 ),
                                 hoverinfo='none'
                                )
    plotly_fig['layout']['showlegend'] = False
    plotly_fig['layout']['hovermode'] = 'closest'
    plotly_fig['layout']['height']=600
    plotly_fig['layout']['width']=500
    plotly_fig['layout']['xaxis'].update(zeroline=True, showline=True, ticklen=4, showgrid=False)
    plotly_fig['layout']['yaxis'].update(dict(visible=False))
    plotly_fig.add_trace(colorbar_trace)
    # Rotated "Feature value" label placed beside the colour bar.
    plotly_fig.layout.update(
                             annotations=[dict(
                                  x=1.18,
                                  align="right",
                                  valign="top",
                                  text='Feature value',
                                  showarrow=False,
                                  xref="paper",
                                  yref="paper",
                                  xanchor="right",
                                  yanchor="middle",
                                  textangle=-90,
                                  font=dict(family='Calibri', size=14)
                                )
                             ],
                             margin=dict(t=20)
                            )
    st.plotly_chart(plotly_fig)


# Single-gene view: the matching table row plus a SHAP force plot.
# Surrounding whitespace is stripped so " AGT " still matches "AGT".
input_gene = st.text_input("Input individual HGNC gene:").strip()

# .copy() yields an independent frame, so nothing below writes into a slice
# of df_total.
gene_rows = df_total[df_total.index == input_gene].copy()

# Table with the gene symbol as an ordinary first column. It is shown even
# when nothing matched, in which case it is empty.
df2 = gene_rows.copy()
df2['Gene'] = df2.index
df2.reset_index(drop=True, inplace=True)
df2 = df2[['Gene','XGB_Score', 'mousescore_Exomiser',
 'SDI', 'Liver_GTExTPM',  'pLI_ExAC',
 'HIPred',
 'Cells - EBV-transformed lymphocytes_GTExTPM',
 'Pituitary_GTExTPM',
 'IPA_BP_annotation']]
st.dataframe(df2)

if input_gene and gene_rows.empty:
    # Unknown symbol: report it rather than passing an empty frame to SHAP.
    st.warning("Gene not found in the annotations dataset.")
elif input_gene:
    # SHAP values are computed on the feature columns only (score removed).
    df2_shap = gene_rows.drop(columns='XGB_Score')
    shap_values = explainer.shap_values(df2_shap)
    # The result of force_plot (matplotlib=True, show=False) is handed
    # straight to st.pyplot as the figure to render.
    force_plot = shap.force_plot(
        explainer.expected_value,
        shap_values,
        df2_shap,
        matplotlib=True,
        show=False,
    )
    st.pyplot(fig=force_plot)

st.markdown("""
Total Gene Prioritisation Results:
""")

# Build the display table as a new frame. Plain assignment would only alias
# df_total, so adding the column and resetting the index would modify
# df_total itself.
df_total_output = df_total.assign(Gene=df_total.index).reset_index(drop=True)
df_total_output = df_total_output[['Gene','XGB_Score', 'mousescore_Exomiser',
 'SDI', 'Liver_GTExTPM',  'pLI_ExAC',
 'HIPred',
 'Cells - EBV-transformed lymphocytes_GTExTPM',
 'Pituitary_GTExTPM',
 'IPA_BP_annotation']]
st.dataframe(df_total_output)