hiyata committed on
Commit
d76e76a
·
verified ·
1 Parent(s): 962ae70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -125
app.py CHANGED
@@ -96,20 +96,19 @@ def calculate_shap_values(model, x_tensor):
96
  shap_values = []
97
  x_zeroed = x_tensor.clone()
98
  for i in range(x_tensor.shape[1]):
99
- orig_value = x_zeroed[0, i].item()
100
  x_zeroed[0, i] = 0.0
101
  output = model(x_zeroed)
102
  probs = torch.softmax(output, dim=1)
103
  prob = probs[0, 1].item()
104
- impact = baseline_prob - prob # how much removing the feature changed the prediction
105
  shap_values.append(impact)
106
- x_zeroed[0, i] = orig_value # restore the original value
107
-
108
  return np.array(shap_values), baseline_prob
109
 
110
 
111
  ###############################################################################
112
- # 4. PER-BASE SHAP AGGREGATION (LINEAR HEATMAP)
113
  ###############################################################################
114
 
115
  def compute_positionwise_scores(sequence, shap_values, k=4):
@@ -122,60 +121,98 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
122
  kmer_dict = {km: i for i, km in enumerate(kmers)}
123
 
124
  seq_len = len(sequence)
125
-
126
- # Arrays to accumulate sums (SHAP) and coverage counts
127
  shap_sums = np.zeros(seq_len, dtype=np.float32)
128
  coverage = np.zeros(seq_len, dtype=np.float32)
129
 
130
- # Slide over the sequence, summing SHAP values for overlapping positions
131
  for i in range(seq_len - k + 1):
132
  kmer = sequence[i:i+k]
133
  if kmer in kmer_dict:
134
- # Get the SHAP value for this k-mer
135
- value = shap_values[kmer_dict[kmer]]
136
- # Accumulate it for each base in the k-mer
137
- shap_sums[i : i + k] += value
138
  coverage[i : i + k] += 1
139
 
140
- # Compute the average SHAP per base (avoid divide-by-zero)
141
  with np.errstate(divide='ignore', invalid='ignore'):
142
  shap_means = np.where(coverage > 0, shap_sums / coverage, 0.0)
143
 
144
  return shap_means
145
 
146
- def plot_linear_heatmap(shap_means):
 
 
 
 
 
147
  """
148
  Plots a 1D heatmap of per-base SHAP contributions.
149
  Negative = push toward Non-Human, Positive = push toward Human.
150
  """
151
- # Reshape into (1, -1) so that imshow displays it as a single row
152
- heatmap_data = shap_means.reshape(1, -1)
153
-
154
  fig, ax = plt.subplots(figsize=(12, 2))
155
 
156
- # We'll use a diverging color map (red/blue)
157
  cax = ax.imshow(heatmap_data, aspect='auto', cmap='RdBu_r')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
- # Add colorbar
 
160
  cbar = plt.colorbar(cax, orientation='horizontal', pad=0.2)
161
  cbar.set_label('SHAP Contribution')
162
 
163
- ax.set_yticks([]) # single row, so hide the y-axis
164
- ax.set_xlabel('Position in Sequence')
165
- ax.set_title('Per-base SHAP Heatmap')
166
 
167
  plt.tight_layout()
168
  return fig
169
 
170
 
171
  ###############################################################################
172
- # 5. OTHER PLOTS: BAR PLOT OF TOP-K AND SEQUENCE IMPACT VISUALIZATION
173
  ###############################################################################
174
 
175
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
176
  """Create a bar plot of the most important k-mers."""
177
  plt.rcParams.update({'font.size': 10})
178
- plt.figure(figsize=(10, 6))
179
 
180
  # Sort by absolute importance
181
  indices = np.argsort(np.abs(shap_values))[-top_k:]
@@ -188,83 +225,16 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
188
  plt.yticks(range(len(values)), features)
189
  plt.xlabel('SHAP value (impact on model output)')
190
  plt.title(f'Top {top_k} Most Influential k-mers')
191
- plt.gca().invert_yaxis() # most important at top
192
-
193
- return plt.gcf()
194
-
195
- def visualize_sequence_impacts(sequence, kmers, shap_values, base_prob):
196
- """
197
- Create a SHAP-style visualization of sequence impacts.
198
- Shows each k-mer's contribution in context.
199
- """
200
- k = 4 # k-mer size
201
- kmer_dict = {km: i for i, km in enumerate(kmers)}
202
-
203
- # Find all k-mers and their impacts
204
- kmer_impacts = []
205
- for i in range(len(sequence) - k + 1):
206
- kmer = sequence[i:i+k]
207
- if kmer in kmer_dict:
208
- impact = shap_values[kmer_dict[kmer]]
209
- kmer_impacts.append((i, kmer, impact))
210
-
211
- # Sort by absolute impact
212
- kmer_impacts.sort(key=lambda x: abs(x[2]), reverse=True)
213
-
214
- # Limit display to top 30 k-mers
215
- display_kmers = kmer_impacts[:30]
216
-
217
- # Calculate figure height based on number of k-mers
218
- fig_height = min(20, max(8, len(display_kmers) * 0.4))
219
-
220
- # Create figure with controlled size
221
- fig = plt.figure(figsize=(12, fig_height))
222
- ax = plt.gca()
223
-
224
- # Add title and base value
225
- plt.text(0.01, 1.02, f"base value = {base_prob:.3f}", transform=ax.transAxes, fontsize=10)
226
-
227
- # Plot k-mers with controlled spacing
228
- y_spacing = 0.9 / max(len(display_kmers), 1)
229
- y_position = 0.95
230
-
231
- for pos, kmer, impact in display_kmers:
232
- pre_sequence = sequence[max(0, pos-20):pos]
233
- post_sequence = sequence[pos+len(kmer):min(pos+len(kmer)+20, len(sequence))]
234
-
235
- # Add ellipsis if truncated
236
- pre_ellipsis = "..." if pos > 20 else ""
237
- post_ellipsis = "..." if pos+len(kmer)+20 < len(sequence) else ""
238
-
239
- # Choose color based on impact
240
- color = '#ffcccb' if impact > 0 else '#cce0ff'
241
- arrow = '↑' if impact > 0 else '↓'
242
-
243
- # Draw text elements
244
- plt.text(0.01, y_position, f"{pre_ellipsis}{pre_sequence}", fontsize=9)
245
- plt.text(0.01 + len(f"{pre_ellipsis}{pre_sequence}")/50, y_position,
246
- kmer, fontsize=9, bbox=dict(facecolor=color, alpha=0.3, pad=1))
247
- plt.text(0.01 + (len(f"{pre_ellipsis}{pre_sequence}") + len(kmer))/50,
248
- y_position, f"{post_sequence}{post_ellipsis}", fontsize=9)
249
-
250
- # Add impact value
251
- plt.text(0.8, y_position, f"{arrow} {impact:+.3f}", fontsize=9)
252
-
253
- y_position -= y_spacing
254
-
255
- plt.axis('off')
256
-
257
- # Adjust layout
258
- plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05)
259
  return fig
260
 
261
-
262
  ###############################################################################
263
- # 6. HELPER FUNCTION: FIG TO IMAGE
264
  ###############################################################################
265
 
266
  def fig_to_image(fig):
267
  """Convert a Matplotlib figure to a PIL Image."""
 
268
  buf = io.BytesIO()
269
  fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
270
  buf.seek(0)
@@ -272,12 +242,11 @@ def fig_to_image(fig):
272
  plt.close(fig)
273
  return img
274
 
275
-
276
  ###############################################################################
277
- # 7. MAIN PREDICTION FUNCTION
278
  ###############################################################################
279
 
280
- def predict(file_obj, top_kmers=10, fasta_text=""):
281
  """Main prediction function for Gradio interface."""
282
  # Handle input
283
  if fasta_text.strip():
@@ -302,7 +271,6 @@ def predict(file_obj, top_kmers=10, fasta_text=""):
302
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
303
  try:
304
  model = VirusClassifier(256).to(device)
305
- # Remove 'weights_only=True' if it causes errors; it's not a standard argument.
306
  model.load_state_dict(torch.load('model.pt', map_location=device))
307
  scaler = joblib.load('scaler.pkl')
308
  except Exception as e:
@@ -321,31 +289,34 @@ def predict(file_obj, top_kmers=10, fasta_text=""):
321
  f"Sequence: {header}",
322
  f"Prediction: {'Human' if prob_human > 0.5 else 'Non-human'} Origin",
323
  f"Confidence: {max(prob_human, 1 - prob_human):.3f}",
324
- f"Human Probability: {prob_human:.3f}",
325
- "\nTop Contributing k-mers:"
326
  ]
327
 
328
- # Create k-mer lists for visualization
329
  kmers = [''.join(p) for p in product("ACGT", repeat=4)]
330
 
331
- # 1) K-mer importance bar plot
332
- importance_plot = create_importance_bar_plot(shap_values, kmers, top_kmers)
333
- importance_img = fig_to_image(importance_plot)
334
-
335
- # 2) SHAP-style textual sequence impact
336
- sequence_plot = visualize_sequence_impacts(seq, kmers, shap_values, prob_human)
337
- sequence_img = fig_to_image(sequence_plot)
338
 
339
- # 3) Linear heatmap across full genome
340
  shap_means = compute_positionwise_scores(seq, shap_values, k=4)
341
- heatmap_fig = plot_linear_heatmap(shap_means)
342
  heatmap_img = fig_to_image(heatmap_fig)
343
 
344
- return "\n".join(results), importance_img, sequence_img, heatmap_img
 
 
 
 
 
 
 
 
345
 
346
 
347
  ###############################################################################
348
- # 8. BUILD GRADIO INTERFACE
349
  ###############################################################################
350
 
351
  css = """
@@ -379,31 +350,34 @@ with gr.Blocks(css=css) as iface:
379
  step=1,
380
  label="Number of top k-mers to display"
381
  )
 
 
 
 
 
 
 
382
  submit_btn = gr.Button("Analyze Sequence", variant="primary")
383
 
384
  with gr.Column(scale=2):
385
- results = gr.Textbox(label="Analysis Results", lines=10)
386
- kmer_plot = gr.Image(label="K-mer Importance Plot")
387
- shap_plot = gr.Image(label="Sequence Impact Visualization (SHAP-style)")
388
- heatmap_plot = gr.Image(label="Genome Heatmap")
389
 
390
  submit_btn.click(
391
  predict,
392
- inputs=[file_input, top_k, text_input],
393
- outputs=[results, kmer_plot, shap_plot, heatmap_plot]
394
  )
395
 
396
  gr.Markdown("""
397
  ### Visualization Guide
398
- - **K-mer Importance Plot**: Shows the most influential k-mers and their SHAP values
399
- - **Sequence Impact Visualization**: Shows the sequence with highlighted k-mers:
400
- - Red highlights = pushing toward human origin
401
- - Blue highlights = pushing toward non-human origin
402
- - Arrows (↑/↓) show impact direction
403
- - Values show impact magnitude
404
- - **Genome Heatmap**: Per-base SHAP values across the entire sequence
405
  - Red = push toward human
406
  - Blue = push toward non-human
 
407
  """)
408
 
409
  if __name__ == "__main__":
 
96
  shap_values = []
97
  x_zeroed = x_tensor.clone()
98
  for i in range(x_tensor.shape[1]):
99
+ original_value = x_zeroed[0, i].item()
100
  x_zeroed[0, i] = 0.0
101
  output = model(x_zeroed)
102
  probs = torch.softmax(output, dim=1)
103
  prob = probs[0, 1].item()
104
+ impact = baseline_prob - prob
105
  shap_values.append(impact)
106
+ x_zeroed[0, i] = original_value # restore
 
107
  return np.array(shap_values), baseline_prob
108
 
109
 
110
  ###############################################################################
111
+ # 4. PER-BASE SHAP AGGREGATION
112
  ###############################################################################
113
 
114
  def compute_positionwise_scores(sequence, shap_values, k=4):
 
121
  kmer_dict = {km: i for i, km in enumerate(kmers)}
122
 
123
  seq_len = len(sequence)
 
 
124
  shap_sums = np.zeros(seq_len, dtype=np.float32)
125
  coverage = np.zeros(seq_len, dtype=np.float32)
126
 
 
127
  for i in range(seq_len - k + 1):
128
  kmer = sequence[i:i+k]
129
  if kmer in kmer_dict:
130
+ val = shap_values[kmer_dict[kmer]]
131
+ shap_sums[i : i + k] += val
 
 
132
  coverage[i : i + k] += 1
133
 
 
134
  with np.errstate(divide='ignore', invalid='ignore'):
135
  shap_means = np.where(coverage > 0, shap_sums / coverage, 0.0)
136
 
137
  return shap_means
138
 
139
+
140
+ ###############################################################################
141
+ # 5. HEATMAP PLOTS
142
+ ###############################################################################
143
+
144
def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap"):
    """
    Draw per-base SHAP contributions as a single-row heatmap.

    Negative = push toward Non-Human, Positive = push toward Human.

    Args:
        shap_means: 1D array of per-base scores; it is reshaped to one row.
        title: Text placed above the heatmap.

    Returns:
        The Matplotlib figure holding the heatmap and its colorbar.
    """
    # imshow needs 2D input: lay the sequence out as one row of seq_len cells.
    single_row = shap_means.reshape(1, -1)

    fig, ax = plt.subplots(figsize=(12, 2))
    image = ax.imshow(single_row, aspect='auto', cmap='RdBu_r')

    colorbar = fig.colorbar(image, ax=ax, orientation='horizontal', pad=0.2)
    colorbar.set_label('SHAP Contribution')

    # Only one row is drawn, so the y-axis carries no information.
    ax.set_yticks([])
    ax.set_xlabel('Position in Sequence')
    ax.set_title(title)

    fig.tight_layout()
    return fig
161
+
162
def get_top_signal_region(shap_means, window_size=500):
    """
    Find the window of length `window_size` with the highest sum of
    absolute SHAP values.

    Args:
        shap_means: 1D sequence/array of per-base SHAP scores.
        window_size: Window length in bases. Coerced to int, so a float
            coming from a UI slider is accepted.

    Returns:
        (start_index, end_index) as Python ints, end exclusive. The whole
        sequence (0, len) is returned when the window is at least as long
        as the sequence, and the empty range (0, 0) when the window is
        not positive. When several windows share the maximum, the earliest
        one wins (up to floating-point rounding).
    """
    window_size = int(window_size)
    # float64 keeps the prefix sums accurate even for long float32 inputs.
    abs_values = np.abs(np.asarray(shap_means, dtype=np.float64))
    seq_len = abs_values.shape[0]

    if window_size <= 0:
        return 0, 0  # nothing to search for
    if window_size >= seq_len:
        return 0, seq_len  # entire sequence if window too large

    # prefix[i] = sum(abs_values[:i]); every window sum is then a single
    # difference, replacing the per-base Python loop with vectorized work.
    prefix = np.concatenate(([0.0], np.cumsum(abs_values)))
    window_sums = prefix[window_size:] - prefix[:-window_size]

    # argmax reports the first maximum, matching a strict ">" scan.
    max_start = int(np.argmax(window_sums))
    return max_start, max_start + window_size
186
+
187
def plot_zoomed_heatmap(shap_means, window_size=500, title="Zoomed SHAP Region"):
    """
    Plot a 1D heatmap of the sub-region with the largest absolute SHAP sum.

    The region is chosen by `get_top_signal_region` using a fixed window of
    `window_size` bases.

    Args:
        shap_means: 1D array of per-base SHAP scores.
        window_size: Length of the window to search for and display.
        title: Text placed above the heatmap.

    Returns:
        The Matplotlib figure holding the zoomed heatmap and its colorbar.
    """
    start, end = get_top_signal_region(shap_means, window_size)
    sub_means = shap_means[start:end].reshape(1, -1)

    fig, ax = plt.subplots(figsize=(12, 2))
    # Anchor the image at its real sequence coordinates so the x tick labels
    # agree with the "Position in Sequence" axis label; without `extent` they
    # would restart at 0. The half-cell offsets keep each base centred on its
    # integer position, as imshow does by default.
    cax = ax.imshow(
        sub_means,
        aspect='auto',
        cmap='RdBu_r',
        extent=(start - 0.5, end - 0.5, 0.5, -0.5),
    )
    cbar = fig.colorbar(cax, ax=ax, orientation='horizontal', pad=0.2)
    cbar.set_label('SHAP Contribution')

    ax.set_yticks([])  # single row, so the y-axis carries no information
    ax.set_xlabel(f'Position in Sequence (zoomed in {start} - {end})')
    ax.set_title(title)

    fig.tight_layout()
    return fig
206
 
207
 
208
  ###############################################################################
209
+ # 6. OTHER PLOT: TOP-K K-MER BAR PLOT
210
  ###############################################################################
211
 
212
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
213
  """Create a bar plot of the most important k-mers."""
214
  plt.rcParams.update({'font.size': 10})
215
+ fig = plt.figure(figsize=(10, 5))
216
 
217
  # Sort by absolute importance
218
  indices = np.argsort(np.abs(shap_values))[-top_k:]
 
225
  plt.yticks(range(len(values)), features)
226
  plt.xlabel('SHAP value (impact on model output)')
227
  plt.title(f'Top {top_k} Most Influential k-mers')
228
+ plt.gca().invert_yaxis()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  return fig
230
 
 
231
  ###############################################################################
232
+ # 7. HELPER FUNCTION: FIG TO IMAGE
233
  ###############################################################################
234
 
235
  def fig_to_image(fig):
236
  """Convert a Matplotlib figure to a PIL Image."""
237
+ import io
238
  buf = io.BytesIO()
239
  fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
240
  buf.seek(0)
 
242
  plt.close(fig)
243
  return img
244
 
 
245
  ###############################################################################
246
+ # 8. MAIN PREDICTION FUNCTION
247
  ###############################################################################
248
 
249
+ def predict(file_obj, top_kmers=10, fasta_text="", zoom_window=500):
250
  """Main prediction function for Gradio interface."""
251
  # Handle input
252
  if fasta_text.strip():
 
271
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
272
  try:
273
  model = VirusClassifier(256).to(device)
 
274
  model.load_state_dict(torch.load('model.pt', map_location=device))
275
  scaler = joblib.load('scaler.pkl')
276
  except Exception as e:
 
289
  f"Sequence: {header}",
290
  f"Prediction: {'Human' if prob_human > 0.5 else 'Non-human'} Origin",
291
  f"Confidence: {max(prob_human, 1 - prob_human):.3f}",
292
+ f"Human Probability: {prob_human:.3f}"
 
293
  ]
294
 
295
+ # Create k-mer list (4-mers in lexicographic order)
296
  kmers = [''.join(p) for p in product("ACGT", repeat=4)]
297
 
298
+ # 1) Top-k k-mer bar plot
299
+ importance_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
300
+ importance_img = fig_to_image(importance_fig)
 
 
 
 
301
 
302
+ # 2) Full-genome per-base SHAP heatmap
303
  shap_means = compute_positionwise_scores(seq, shap_values, k=4)
304
+ heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide Per-base SHAP")
305
  heatmap_img = fig_to_image(heatmap_fig)
306
 
307
+ # 3) Zoomed region (optional, using the largest absolute SHAP region)
308
+ if zoom_window > 0:
309
+ zoom_fig = plot_zoomed_heatmap(shap_means, window_size=zoom_window,
310
+ title=f"Top SHAP Region (window={zoom_window})")
311
+ zoom_img = fig_to_image(zoom_fig)
312
+ else:
313
+ zoom_img = None
314
+
315
+ return "\n".join(results), importance_img, heatmap_img, zoom_img
316
 
317
 
318
  ###############################################################################
319
+ # 9. BUILD GRADIO INTERFACE
320
  ###############################################################################
321
 
322
  css = """
 
350
  step=1,
351
  label="Number of top k-mers to display"
352
  )
353
+ zoom_window = gr.Slider(
354
+ minimum=0,
355
+ maximum=5000,
356
+ value=500,
357
+ step=100,
358
+ label="Zoom Window Size (0 to disable zoom plot)"
359
+ )
360
  submit_btn = gr.Button("Analyze Sequence", variant="primary")
361
 
362
  with gr.Column(scale=2):
363
+ results_box = gr.Textbox(label="Analysis Results", lines=5)
364
+ kmer_plot = gr.Image(label="Top k-mer SHAP")
365
+ full_heatmap = gr.Image(label="Genome-wide SHAP Heatmap")
366
+ zoomed_heatmap = gr.Image(label="Zoomed SHAP Region (largest signal)")
367
 
368
  submit_btn.click(
369
  predict,
370
+ inputs=[file_input, top_k, text_input, zoom_window],
371
+ outputs=[results_box, kmer_plot, full_heatmap, zoomed_heatmap]
372
  )
373
 
374
  gr.Markdown("""
375
  ### Visualization Guide
376
+ - **Top k-mer SHAP**: Shows the most influential k-mers and their SHAP values.
377
+ - **Genome-wide SHAP Heatmap**: Per-base SHAP values across the entire sequence.
 
 
 
 
 
378
  - Red = push toward human
379
  - Blue = push toward non-human
380
+ - **Zoomed SHAP Region**: Shows the subregion of length 'Zoom Window Size' that has the highest absolute SHAP sum.
381
  """)
382
 
383
  if __name__ == "__main__":