hiyata committed on
Commit
910c6c2
·
verified ·
1 Parent(s): 03f2bb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -31
app.py CHANGED
@@ -5,6 +5,7 @@ import numpy as np
5
  from itertools import product
6
  import torch.nn as nn
7
  import matplotlib.pyplot as plt
 
8
  import io
9
  from PIL import Image
10
 
@@ -144,20 +145,17 @@ def find_extreme_subregion(shap_means, window_size=500, mode="max"):
144
  """
145
  n = len(shap_means)
146
  if n == 0:
147
- # Edge case: empty array
148
  return (0, 0, 0.0)
149
  if window_size >= n:
150
- # If the window is bigger than the entire sequence, return entire seq
151
  avg_val = float(np.mean(shap_means))
152
  return (0, n, avg_val)
153
 
154
- # We'll build csum as length n+1 so csum[i] = sum of shap_means[:i]
155
- # That means sum in [start, start+window_size) = csum[start+window_size] - csum[start].
156
  csum = np.zeros(n + 1, dtype=np.float32)
157
  csum[1:] = np.cumsum(shap_means)
158
 
159
  best_start = 0
160
- # Initialize with the first window: [0, window_size)
161
  best_sum = csum[window_size] - csum[0]
162
  best_avg = best_sum / window_size
163
 
@@ -188,29 +186,65 @@ def fig_to_image(fig):
188
  plt.close(fig)
189
  return img
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
192
  """
193
- Plots a 1D heatmap of per-base SHAP contributions.
194
- Negative = push toward Non-Human, Positive = push toward Human.
195
- Optionally can show only a subrange (start:end).
196
- Adjust layout so the colorbar is well below the x-axis:
197
- - orientation='horizontal', pad=0.35
198
- - plt.subplots_adjust(bottom=0.4)
 
199
  """
200
  if start is not None and end is not None:
201
- shap_means = shap_means[start:end]
202
  subtitle = f" (positions {start}-{end})"
203
  else:
 
204
  subtitle = ""
205
 
206
- heatmap_data = shap_means.reshape(1, -1) # shape (1, region_length)
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  fig, ax = plt.subplots(figsize=(12, 2))
209
- cax = ax.imshow(heatmap_data, aspect='auto', cmap='RdBu_r')
 
 
 
 
 
 
210
 
211
- # Place colorbar below and add extra margin
212
  cbar = plt.colorbar(cax, orientation='horizontal', pad=0.35)
213
- cbar.set_label('SHAP Contribution')
214
 
215
  ax.set_yticks([])
216
  ax.set_xlabel('Position in Sequence')
@@ -231,7 +265,8 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
231
  values = shap_values[indices]
232
  features = [kmers[i] for i in indices]
233
 
234
- colors = ['#ff9999' if v > 0 else '#99ccff' for v in values]
 
235
 
236
  plt.barh(range(len(values)), values, color=colors)
237
  plt.yticks(range(len(values)), features)
@@ -244,7 +279,6 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
244
  def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
245
  """
246
  Simple histogram of SHAP values in the subregion.
247
- Helps see how many positions push human vs non-human.
248
  """
249
  fig, ax = plt.subplots(figsize=(6, 4))
250
  ax.hist(shap_array, bins=30, color='gray', edgecolor='black')
@@ -294,12 +328,11 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
294
  # Load model and scaler
295
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
296
  try:
297
- # Use weights_only=True to address the FutureWarning about untrusted pickle data
298
  state_dict = torch.load('model.pt', map_location=device, weights_only=True)
299
  model = VirusClassifier(256).to(device)
300
  model.load_state_dict(state_dict)
301
-
302
- # Load scaler (warning if version mismatch)
303
  scaler = joblib.load('scaler.pkl')
304
  except Exception as e:
305
  return (f"Error loading model/scaler: {str(e)}", None, None, None, None)
@@ -353,7 +386,6 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
353
  "shap_means": shap_means
354
  }
355
 
356
- # Return exactly 5 items
357
  return (results_text, bar_img, heatmap_img, state_dict_out, header)
358
 
359
  ###############################################################################
@@ -438,9 +470,11 @@ css = """
438
 
439
  with gr.Blocks(css=css) as iface:
440
  gr.Markdown("""
441
- # Virus Host Classifier (with Interactive Region Viewer)
442
  **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
443
  **Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.
 
 
444
  """)
445
 
446
  with gr.Tab("1) Full-Sequence Analysis"):
@@ -477,12 +511,12 @@ with gr.Blocks(css=css) as iface:
477
  label="Classification Results", lines=12, interactive=False
478
  )
479
  kmer_img = gr.Image(label="Top k-mer SHAP")
480
- genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
481
 
482
  seq_state = gr.State()
483
  header_state = gr.State()
484
 
485
- # analyze_sequence(...) returns 5 items.
486
  analyze_btn.click(
487
  analyze_sequence,
488
  inputs=[file_input, top_k, text_input, win_size],
@@ -492,7 +526,8 @@ with gr.Blocks(css=css) as iface:
492
  with gr.Tab("2) Subregion Exploration"):
493
  gr.Markdown("""
494
  **Subregion Analysis**
495
- Select start/end positions to view local SHAP signals, distribution, and GC content.
 
496
  """)
497
  with gr.Row():
498
  region_start = gr.Number(label="Region Start", value=0)
@@ -505,7 +540,7 @@ with gr.Blocks(css=css) as iface:
505
  interactive=False
506
  )
507
  with gr.Row():
508
- subregion_img = gr.Image(label="Subregion SHAP Heatmap")
509
  subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
510
 
511
  region_btn.click(
@@ -517,10 +552,15 @@ with gr.Blocks(css=css) as iface:
517
  gr.Markdown("""
518
  ### Interface Features
519
  - **Overall Classification** (human vs non-human) using k-mer frequencies.
520
- - **Top k-mer SHAP**: which k-mers push the classifier output.
521
- - **Genome-Wide SHAP Heatmap**: each base's average SHAP across overlapping k-mers.
522
- - **Identify Subregions** (sliding window) with the strongest push for human or non-human.
523
- - **Subregion Exploration**: local SHAP heatmap & histogram, GC content, fraction of positions pushing human vs. non-human.
 
 
 
 
 
524
  """)
525
 
526
  if __name__ == "__main__":
 
5
  from itertools import product
6
  import torch.nn as nn
7
  import matplotlib.pyplot as plt
8
+ import matplotlib.colors as mcolors
9
  import io
10
  from PIL import Image
11
 
 
145
  """
146
  n = len(shap_means)
147
  if n == 0:
 
148
  return (0, 0, 0.0)
149
  if window_size >= n:
150
+ # entire sequence
151
  avg_val = float(np.mean(shap_means))
152
  return (0, n, avg_val)
153
 
154
+ # We'll build csum of length n+1
 
155
  csum = np.zeros(n + 1, dtype=np.float32)
156
  csum[1:] = np.cumsum(shap_means)
157
 
158
  best_start = 0
 
159
  best_sum = csum[window_size] - csum[0]
160
  best_avg = best_sum / window_size
161
 
 
186
  plt.close(fig)
187
  return img
188
 
189
def get_zero_centered_cmap():
    """
    Build the diverging blue-white-red colormap used for SHAP heatmaps.

    The colormap is defined by three anchor points on the normalized
    [0, 1] scale:
      - 0.0 -> blue  (low end, used for negative values)
      - 0.5 -> white (midpoint)
      - 1.0 -> red   (high end, used for positive values)

    Note: white sits at the midpoint of the normalized range, so a data
    value of 0 is drawn white only when the plot uses limits that are
    symmetric around zero (e.g. vmin=-extent, vmax=+extent).

    Returns:
        A matplotlib LinearSegmentedColormap named "blue_white_red".
    """
    anchor_positions = [0.0, 0.5, 1.0]
    anchor_colors = ['blue', 'white', 'red']

    # from_list accepts (position, color) pairs; pair them up in order.
    color_stops = list(zip(anchor_positions, anchor_colors))

    return mcolors.LinearSegmentedColormap.from_list("blue_white_red", color_stops)
203
+
204
  def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
205
  """
206
+ Plots a 1D heatmap of per-base SHAP contributions with a custom colormap:
207
+ - Negative = blue
208
+ - 0 = white
209
+ - Positive = red
210
+ We'll force the range to be symmetrical around 0 by using:
211
+ vmin=-extent, vmax=+extent
212
+ so 0 is in the middle.
213
  """
214
  if start is not None and end is not None:
215
+ local_shap = shap_means[start:end]
216
  subtitle = f" (positions {start}-{end})"
217
  else:
218
+ local_shap = shap_means
219
  subtitle = ""
220
 
221
+ if len(local_shap) == 0:
222
+ # Edge case: no data to plot
223
+ local_shap = np.array([0.0])
224
+
225
+ # Build 2D array for imshow
226
+ heatmap_data = local_shap.reshape(1, -1)
227
+
228
+ # Force symmetrical range
229
+ min_val = np.min(local_shap)
230
+ max_val = np.max(local_shap)
231
+ extent = max(abs(min_val), abs(max_val))
232
+
233
+ # Create custom colormap
234
+ custom_cmap = get_zero_centered_cmap()
235
 
236
  fig, ax = plt.subplots(figsize=(12, 2))
237
+ cax = ax.imshow(
238
+ heatmap_data,
239
+ aspect='auto',
240
+ cmap=custom_cmap,
241
+ vmin=-extent,
242
+ vmax=+extent
243
+ )
244
 
245
+ # Place colorbar below with plenty of margin
246
  cbar = plt.colorbar(cax, orientation='horizontal', pad=0.35)
247
+ cbar.set_label('SHAP Contribution (negative=blue, zero=white, positive=red)')
248
 
249
  ax.set_yticks([])
250
  ax.set_xlabel('Position in Sequence')
 
265
  values = shap_values[indices]
266
  features = [kmers[i] for i in indices]
267
 
268
+ # negative -> blue, positive -> red
269
+ colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
270
 
271
  plt.barh(range(len(values)), values, color=colors)
272
  plt.yticks(range(len(values)), features)
 
279
  def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
280
  """
281
  Simple histogram of SHAP values in the subregion.
 
282
  """
283
  fig, ax = plt.subplots(figsize=(6, 4))
284
  ax.hist(shap_array, bins=30, color='gray', edgecolor='black')
 
328
  # Load model and scaler
329
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
330
  try:
331
+ # Use weights_only=True for safer loading
332
  state_dict = torch.load('model.pt', map_location=device, weights_only=True)
333
  model = VirusClassifier(256).to(device)
334
  model.load_state_dict(state_dict)
335
+
 
336
  scaler = joblib.load('scaler.pkl')
337
  except Exception as e:
338
  return (f"Error loading model/scaler: {str(e)}", None, None, None, None)
 
386
  "shap_means": shap_means
387
  }
388
 
 
389
  return (results_text, bar_img, heatmap_img, state_dict_out, header)
390
 
391
  ###############################################################################
 
470
 
471
  with gr.Blocks(css=css) as iface:
472
  gr.Markdown("""
473
+ # Virus Host Classifier with White-Centered Gradient
474
  **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
475
  **Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.
476
+
477
+ **Color Scale**: Negative SHAP = Blue, Zero = White, Positive = Red.
478
  """)
479
 
480
  with gr.Tab("1) Full-Sequence Analysis"):
 
511
  label="Classification Results", lines=12, interactive=False
512
  )
513
  kmer_img = gr.Image(label="Top k-mer SHAP")
514
+ genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")
515
 
516
  seq_state = gr.State()
517
  header_state = gr.State()
518
 
519
+ # analyze_sequence(...) returns 5 items
520
  analyze_btn.click(
521
  analyze_sequence,
522
  inputs=[file_input, top_k, text_input, win_size],
 
526
  with gr.Tab("2) Subregion Exploration"):
527
  gr.Markdown("""
528
  **Subregion Analysis**
529
+ Select start/end positions to view local SHAP signals, distribution, and GC content.
530
+ The heatmap also uses the same Blue-White-Red scale.
531
  """)
532
  with gr.Row():
533
  region_start = gr.Number(label="Region Start", value=0)
 
540
  interactive=False
541
  )
542
  with gr.Row():
543
+ subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
544
  subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
545
 
546
  region_btn.click(
 
552
  gr.Markdown("""
553
  ### Interface Features
554
  - **Overall Classification** (human vs non-human) using k-mer frequencies.
555
+ - **SHAP Analysis** to see which k-mers push classification toward or away from human.
556
+ - **White-Centered SHAP Gradient**:
557
+ - Negative (blue), 0 (white), Positive (red), with symmetrical color range around 0.
558
+ - **Identify Subregions** with the strongest push for human or non-human.
559
+ - **Subregion Exploration**:
560
+ - Local SHAP heatmap & histogram
561
+ - GC content
562
+ - Fraction of positions pushing human vs. non-human
563
+ - Simple logic-based classification
564
  """)
565
 
566
  if __name__ == "__main__":