jfaustin committed on
Commit
354bfc2
·
verified ·
1 Parent(s): f601557

Improve aesthetics of correlation tab (#7)

Browse files

- git-ignore draft notebooks and np output (1521e428e0cd285be1e9d19853270de06eb08e0a)
- improve correlation legend (fe7692a5de1e886c3d22d6b39a5b9d6839b8af71)
- show corr plot on button click (30947eff1f1b60460a326f0ffe27110a55ad6bb2)
- add option to show x-axis as log (640e97433b3ecd0cdd3d7c9e689ade72c060e196)
- add text to explain the correlation tab purpose (924cdcdcd7d3a52a25f994bb52d163ef371b2b8f)
- use blue/purple theme (ca02709b5b7f00c3e9cefa66bf7f3e17c94fa59a)

.gitignore CHANGED
@@ -3,3 +3,6 @@
3
  output/
4
  sequences/
5
  boltz_results/
 
 
 
 
3
  output/
4
  sequences/
5
  boltz_results/
6
+
7
+ *.ipynb
8
+ *.npz
folding_studio_demo/app.py CHANGED
@@ -10,9 +10,10 @@ from gradio_molecule3d import Molecule3D
10
  from folding_studio_demo.correlate import (
11
  SCORE_COLUMNS,
12
  fake_predict_and_correlate,
13
- select_correlation_plot,
14
  )
15
  from folding_studio_demo.predict import predict, predict_comparison
 
16
 
17
  logger = logging.getLogger(__name__)
18
 
@@ -102,7 +103,13 @@ def simple_prediction(api_key: str) -> None:
102
  with gr.Column():
103
  sequence = sequence_input()
104
 
105
- predict_btn = gr.Button("Predict")
 
 
 
 
 
 
106
 
107
  with gr.Row():
108
  mol_output = Molecule3D(label="Protein Structure", reps=MOLECULE_REPS)
@@ -134,7 +141,13 @@ def model_comparison(api_key: str) -> None:
134
  with gr.Column():
135
  sequence = sequence_input()
136
 
137
- predict_btn = gr.Button("Compare Models")
 
 
 
 
 
 
138
 
139
  with gr.Row():
140
  mol_outputs = Molecule3D(
@@ -154,6 +167,20 @@ def model_comparison(api_key: str) -> None:
154
 
155
  def create_correlation_tab():
156
  gr.Markdown("# Correlation with experimental binding affinity data")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  spr_data_with_scores = pd.read_csv("spr_af_scores_mapped.csv")
158
  prettified_columns = {
159
  "antibody_name": "Antibody Name",
@@ -179,36 +206,60 @@ def create_correlation_tab():
179
 
180
  gr.Markdown("# Prediction and correlation")
181
  with gr.Row():
182
- fake_predict_btn = gr.Button("Predict structures of all complexes")
 
 
 
 
 
183
  with gr.Row():
184
  prediction_dataframe = gr.Dataframe(label="Predicted Structures Data")
185
  with gr.Row():
186
  correlation_ranking_plot = gr.Plot(label="Correlation ranking")
187
  with gr.Row():
188
- # User can select the columns to display in the correlation plot
189
- correlation_column = gr.Dropdown(
190
- label="Score data to display", choices=SCORE_COLUMNS, multiselect=False
191
- )
192
- correlation_plot = gr.Plot(label="Correlation with binding affinity")
 
 
 
 
 
 
193
 
194
  fake_predict_btn.click(
195
  fn=lambda x: fake_predict_and_correlate(
196
  spr_data_with_scores, SCORE_COLUMNS, ["Antibody Name", "KD (nM)"]
197
  ),
198
  inputs=None,
199
- outputs=[prediction_dataframe, correlation_ranking_plot],
200
  )
201
-
202
- # Call function to update the correlation plot when the user selects the columns
 
 
203
  correlation_column.change(
204
- fn=lambda score: select_correlation_plot(spr_data_with_scores, score),
205
- inputs=correlation_column,
 
 
 
 
 
 
206
  outputs=correlation_plot,
207
  )
208
 
209
 
210
  def __main__():
211
- with gr.Blocks(title="Folding Studio Demo") as demo:
 
 
 
 
 
212
  gr.Markdown(
213
  """
214
  # Folding Studio: Harness the Power of Protein Folding 🧬
 
10
  from folding_studio_demo.correlate import (
11
  SCORE_COLUMNS,
12
  fake_predict_and_correlate,
13
+ make_correlation_plot,
14
  )
15
  from folding_studio_demo.predict import predict, predict_comparison
16
+ from folding_studio_demo.config import BLUE, PURPLE
17
 
18
  logger = logging.getLogger(__name__)
19
 
 
103
  with gr.Column():
104
  sequence = sequence_input()
105
 
106
+ predict_btn = gr.Button(
107
+ "Predict",
108
+ elem_classes="gradient-button",
109
+ elem_id="predict-btn",
110
+ variant="primary",
111
+ # css=f".gradio-container #predict-btn {{background: linear-gradient(90deg, {BLUE}, {PURPLE});}}",
112
+ )
113
 
114
  with gr.Row():
115
  mol_output = Molecule3D(label="Protein Structure", reps=MOLECULE_REPS)
 
141
  with gr.Column():
142
  sequence = sequence_input()
143
 
144
+ predict_btn = gr.Button(
145
+ "Compare Models",
146
+ elem_classes=["gradient-button"],
147
+ elem_id="compare-models-btn",
148
+ variant="primary",
149
+ # css=f".gradio-container #compare-models-btn {{background: linear-gradient(90deg, {BLUE}, {PURPLE});}}"
150
+ )
151
 
152
  with gr.Row():
153
  mol_outputs = Molecule3D(
 
167
 
168
  def create_correlation_tab():
169
  gr.Markdown("# Correlation with experimental binding affinity data")
170
+ gr.Markdown("""
171
+ This analysis explores the relationship between protein folding model confidence scores and experimental binding affinity data.
172
+
173
+ The experimental dataset contains binding affinity measurements (KD in nM) between antibody-antigen pairs.
174
+ Each data point includes:
175
+ - The antibody's light and heavy chain sequences
176
+ - The antigen sequence
177
+ - The experimental KD value
178
+
179
+ The analysis involves submitting these sequences to protein folding models for 3D structure prediction.
180
+ The models generate various confidence scores for each prediction. These scores are then correlated
181
+ with the experimental binding affinity measurements to evaluate their effectiveness as predictors
182
+ of binding strength.
183
+ """)
184
  spr_data_with_scores = pd.read_csv("spr_af_scores_mapped.csv")
185
  prettified_columns = {
186
  "antibody_name": "Antibody Name",
 
206
 
207
  gr.Markdown("# Prediction and correlation")
208
  with gr.Row():
209
+ fake_predict_btn = gr.Button(
210
+ "Predict structures of all complexes",
211
+ elem_classes="gradient-button",
212
+ variant="primary",
213
+ # css=f".gradio-container #fake-predict-btn {{background: linear-gradient(90deg, {BLUE}, {PURPLE});}}",
214
+ )
215
  with gr.Row():
216
  prediction_dataframe = gr.Dataframe(label="Predicted Structures Data")
217
  with gr.Row():
218
  correlation_ranking_plot = gr.Plot(label="Correlation ranking")
219
  with gr.Row():
220
+ with gr.Column():
221
+ with gr.Row():
222
+ # User can select the columns to display in the correlation plot
223
+ correlation_column = gr.Dropdown(
224
+ label="Score data to display", choices=SCORE_COLUMNS, multiselect=False, value=SCORE_COLUMNS[0]
225
+ )
226
+ # Add checkbox for log scale and update plot when either input changes
227
+ with gr.Row():
228
+ log_scale = gr.Checkbox(label="Display x-axis on logarithmic scale", value=False)
229
+ with gr.Column():
230
+ correlation_plot = gr.Plot(label="Correlation with binding affinity")
231
 
232
  fake_predict_btn.click(
233
  fn=lambda x: fake_predict_and_correlate(
234
  spr_data_with_scores, SCORE_COLUMNS, ["Antibody Name", "KD (nM)"]
235
  ),
236
  inputs=None,
237
+ outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot],
238
  )
239
+
240
+ def update_plot(score, use_log):
241
+ return make_correlation_plot(spr_data_with_scores, score, use_log)
242
+
243
  correlation_column.change(
244
+ fn=update_plot,
245
+ inputs=[correlation_column, log_scale],
246
+ outputs=correlation_plot,
247
+ )
248
+
249
+ log_scale.change(
250
+ fn=update_plot,
251
+ inputs=[correlation_column, log_scale],
252
  outputs=correlation_plot,
253
  )
254
 
255
 
256
  def __main__():
257
+
258
+ theme = gr.themes.Ocean(
259
+ primary_hue="blue",
260
+ secondary_hue="purple",
261
+ )
262
+ with gr.Blocks(theme=theme, title="Folding Studio Demo") as demo:
263
  gr.Markdown(
264
  """
265
  # Folding Studio: Harness the Power of Protein Folding 🧬
folding_studio_demo/correlate.py CHANGED
@@ -68,15 +68,17 @@ def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: l
68
  cols_to_show = main_cols[:]
69
  cols_to_show.extend(score_cols)
70
 
71
- return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot
72
 
73
- def select_correlation_plot(spr_data_with_scores: pd.DataFrame, score: str) -> go.Figure:
 
 
74
  """Select the correlation plot to display."""
75
  # corr_plot is a scatter plot of the correlation between the binding affinity and each of the scores
76
  scatter = go.Scatter(
77
  x=spr_data_with_scores["KD (nM)"],
78
  y=spr_data_with_scores[score],
79
- name=f"KD (nM) vs {score}",
80
  mode='markers', # Only show markers/dots, no lines
81
  hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
82
  marker=dict(color='#1f77b4') # Set color to match default first color
@@ -91,9 +93,9 @@ def select_correlation_plot(spr_data_with_scores: pd.DataFrame, score: str) -> g
91
  yanchor="bottom",
92
  y=1.02,
93
  xanchor="right",
94
- x=1
95
- )
96
- # xaxis_type="log" # Set x-axis to logarithmic scale
97
  )
98
  # compute the correlation line
99
  corr_line = np.polyfit(spr_data_with_scores["KD (nM)"], spr_data_with_scores[score], 1)
@@ -104,7 +106,7 @@ def select_correlation_plot(spr_data_with_scores: pd.DataFrame, score: str) -> g
104
  x=corr_line_x,
105
  y=corr_line_y,
106
  mode='lines',
107
- name=f"Correlation",
108
  line=dict(color='#1f77b4') # Set same color as scatter points
109
  ))
110
  return corr_plot
 
68
  cols_to_show = main_cols[:]
69
  cols_to_show.extend(score_cols)
70
 
71
+ corr_plot = make_correlation_plot(spr_data_with_scores, score_cols[0], use_log=False)
72
 
73
+ return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
74
+
75
+ def make_correlation_plot(spr_data_with_scores: pd.DataFrame, score: str, use_log: bool) -> go.Figure:
76
  """Select the correlation plot to display."""
77
  # corr_plot is a scatter plot of the correlation between the binding affinity and each of the scores
78
  scatter = go.Scatter(
79
  x=spr_data_with_scores["KD (nM)"],
80
  y=spr_data_with_scores[score],
81
+ name=f"Samples",
82
  mode='markers', # Only show markers/dots, no lines
83
  hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
84
  marker=dict(color='#1f77b4') # Set color to match default first color
 
93
  yanchor="bottom",
94
  y=1.02,
95
  xanchor="right",
96
+ x=1,
97
+ ),
98
+ xaxis_type="log" if use_log else "linear" # Set x-axis to logarithmic scale
99
  )
100
  # compute the correlation line
101
  corr_line = np.polyfit(spr_data_with_scores["KD (nM)"], spr_data_with_scores[score], 1)
 
106
  x=corr_line_x,
107
  y=corr_line_y,
108
  mode='lines',
109
+ name=f"Regression line",
110
  line=dict(color='#1f77b4') # Set same color as scatter points
111
  ))
112
  return corr_plot