hiyata committed on
Commit
962ae70
·
verified ·
1 Parent(s): 7aea9ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -35
app.py CHANGED
@@ -8,6 +8,10 @@ import matplotlib.pyplot as plt
8
  import io
9
  from PIL import Image
10
 
 
 
 
 
11
  class VirusClassifier(nn.Module):
12
  def __init__(self, input_shape: int):
13
  super(VirusClassifier, self).__init__()
@@ -28,6 +32,11 @@ class VirusClassifier(nn.Module):
28
  def forward(self, x):
29
  return self.network(x)
30
 
 
 
 
 
 
31
  def parse_fasta(text):
32
  """Parse FASTA formatted text into a list of (header, sequence)."""
33
  sequences = []
@@ -66,6 +75,11 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
66
 
67
  return vec
68
 
 
 
 
 
 
69
  def calculate_shap_values(model, x_tensor):
70
  """
71
  Calculate SHAP values using a simple ablation approach.
@@ -76,22 +90,88 @@ def calculate_shap_values(model, x_tensor):
76
  # Get baseline prediction
77
  baseline_output = model(x_tensor)
78
  baseline_probs = torch.softmax(baseline_output, dim=1)
79
- baseline_prob = baseline_probs[0, 1].item() # Probability of human class
80
 
81
  # Calculate impact of zeroing each feature
82
  shap_values = []
83
  x_zeroed = x_tensor.clone()
84
  for i in range(x_tensor.shape[1]):
85
- x_zeroed[0, i] = 0
 
86
  output = model(x_zeroed)
87
  probs = torch.softmax(output, dim=1)
88
  prob = probs[0, 1].item()
89
- impact = baseline_prob - prob # How much removing the feature changed the prediction
90
  shap_values.append(impact)
91
- x_zeroed[0, i] = x_tensor[0, i] # Restore the original value
92
 
93
  return np.array(shap_values), baseline_prob
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
96
  """Create a bar plot of the most important k-mers."""
97
  plt.rcParams.update({'font.size': 10})
@@ -108,7 +188,7 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
108
  plt.yticks(range(len(values)), features)
109
  plt.xlabel('SHAP value (impact on model output)')
110
  plt.title(f'Top {top_k} Most Influential k-mers')
111
- plt.gca().invert_yaxis() # Most important at top
112
 
113
  return plt.gcf()
114
 
@@ -147,16 +227,14 @@ def visualize_sequence_impacts(sequence, kmers, shap_values, base_prob):
147
  # Plot k-mers with controlled spacing
148
  y_spacing = 0.9 / max(len(display_kmers), 1)
149
  y_position = 0.95
150
- max_seq_display = 100 # Maximum sequence length to show
151
 
152
  for pos, kmer, impact in display_kmers:
153
- # Truncate sequence display if too long
154
  pre_sequence = sequence[max(0, pos-20):pos]
155
- post_sequence = sequence[pos+k:min(pos+k+20, len(sequence))]
156
 
157
  # Add ellipsis if truncated
158
  pre_ellipsis = "..." if pos > 20 else ""
159
- post_ellipsis = "..." if pos+k+20 < len(sequence) else ""
160
 
161
  # Choose color based on impact
162
  color = '#ffcccb' if impact > 0 else '#cce0ff'
@@ -165,9 +243,9 @@ def visualize_sequence_impacts(sequence, kmers, shap_values, base_prob):
165
  # Draw text elements
166
  plt.text(0.01, y_position, f"{pre_ellipsis}{pre_sequence}", fontsize=9)
167
  plt.text(0.01 + len(f"{pre_ellipsis}{pre_sequence}")/50, y_position,
168
- kmer, fontsize=9, bbox=dict(facecolor=color, alpha=0.3, pad=1))
169
  plt.text(0.01 + (len(f"{pre_ellipsis}{pre_sequence}") + len(kmer))/50,
170
- y_position, f"{post_sequence}{post_ellipsis}", fontsize=9)
171
 
172
  # Add impact value
173
  plt.text(0.8, y_position, f"{arrow} {impact:+.3f}", fontsize=9)
@@ -176,10 +254,29 @@ def visualize_sequence_impacts(sequence, kmers, shap_values, base_prob):
176
 
177
  plt.axis('off')
178
 
179
- # Adjust layout with specific margins
180
  plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05)
181
  return fig
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  def predict(file_obj, top_kmers=10, fasta_text=""):
184
  """Main prediction function for Gradio interface."""
185
  # Handle input
@@ -190,25 +287,26 @@ def predict(file_obj, top_kmers=10, fasta_text=""):
190
  with open(file_obj, 'r') as f:
191
  text = f.read()
192
  except Exception as e:
193
- return f"Error reading file: {str(e)}", None, None
194
  else:
195
- return "Please provide a FASTA sequence.", None, None
196
 
197
  # Parse FASTA
198
  sequences = parse_fasta(text)
199
  if not sequences:
200
- return "No valid FASTA sequences found.", None, None
201
 
202
  header, seq = sequences[0]
203
 
204
- # Load model and process sequence
205
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
206
  try:
207
  model = VirusClassifier(256).to(device)
208
- model.load_state_dict(torch.load('model.pt', map_location=device, weights_only=True))
 
209
  scaler = joblib.load('scaler.pkl')
210
  except Exception as e:
211
- return f"Error loading model: {str(e)}", None, None
212
 
213
  # Generate features
214
  freq_vector = sequence_to_kmer_vector(seq)
@@ -218,34 +316,38 @@ def predict(file_obj, top_kmers=10, fasta_text=""):
218
  # Calculate SHAP values and get prediction
219
  shap_values, prob_human = calculate_shap_values(model, x_tensor)
220
 
221
- # Generate result text
222
  results = [
223
  f"Sequence: {header}",
224
  f"Prediction: {'Human' if prob_human > 0.5 else 'Non-human'} Origin",
225
- f"Confidence: {max(prob_human, 1-prob_human):.3f}",
226
  f"Human Probability: {prob_human:.3f}",
227
  "\nTop Contributing k-mers:"
228
  ]
229
-
230
- # Get k-mers for visualization
231
  kmers = [''.join(p) for p in product("ACGT", repeat=4)]
232
 
233
- # Create visualizations
234
  importance_plot = create_importance_bar_plot(shap_values, kmers, top_kmers)
 
 
 
235
  sequence_plot = visualize_sequence_impacts(seq, kmers, shap_values, prob_human)
 
236
 
237
- # Convert plots to images
238
- def fig_to_image(fig):
239
- buf = io.BytesIO()
240
- fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
241
- buf.seek(0)
242
- img = Image.open(buf)
243
- plt.close(fig)
244
- return img
245
 
246
- return "\n".join(results), fig_to_image(importance_plot), fig_to_image(sequence_plot)
 
 
247
 
248
- # Create Gradio interface
249
  css = """
250
  .gradio-container {
251
  font-family: 'IBM Plex Sans', sans-serif;
@@ -283,11 +385,12 @@ with gr.Blocks(css=css) as iface:
283
  results = gr.Textbox(label="Analysis Results", lines=10)
284
  kmer_plot = gr.Image(label="K-mer Importance Plot")
285
  shap_plot = gr.Image(label="Sequence Impact Visualization (SHAP-style)")
 
286
 
287
  submit_btn.click(
288
  predict,
289
  inputs=[file_input, top_k, text_input],
290
- outputs=[results, kmer_plot, shap_plot]
291
  )
292
 
293
  gr.Markdown("""
@@ -298,7 +401,10 @@ with gr.Blocks(css=css) as iface:
298
  - Blue highlights = pushing toward non-human origin
299
  - Arrows (↑/↓) show impact direction
300
  - Values show impact magnitude
 
 
 
301
  """)
302
 
303
  if __name__ == "__main__":
304
- iface.launch()
 
8
  import io
9
  from PIL import Image
10
 
11
+ ###############################################################################
12
+ # 1. MODEL DEFINITION
13
+ ###############################################################################
14
+
15
  class VirusClassifier(nn.Module):
16
  def __init__(self, input_shape: int):
17
  super(VirusClassifier, self).__init__()
 
32
  def forward(self, x):
33
  return self.network(x)
34
 
35
+
36
+ ###############################################################################
37
+ # 2. FASTA PARSING & K-MER FEATURE ENGINEERING
38
+ ###############################################################################
39
+
40
  def parse_fasta(text):
41
  """Parse FASTA formatted text into a list of (header, sequence)."""
42
  sequences = []
 
75
 
76
  return vec
77
 
78
+
79
+ ###############################################################################
80
+ # 3. SHAP-VALUE (ABLATION) CALCULATION
81
+ ###############################################################################
82
+
83
  def calculate_shap_values(model, x_tensor):
84
  """
85
  Calculate SHAP values using a simple ablation approach.
 
90
  # Get baseline prediction
91
  baseline_output = model(x_tensor)
92
  baseline_probs = torch.softmax(baseline_output, dim=1)
93
+ baseline_prob = baseline_probs[0, 1].item() # Probability of 'human' class
94
 
95
  # Calculate impact of zeroing each feature
96
  shap_values = []
97
  x_zeroed = x_tensor.clone()
98
  for i in range(x_tensor.shape[1]):
99
+ orig_value = x_zeroed[0, i].item()
100
+ x_zeroed[0, i] = 0.0
101
  output = model(x_zeroed)
102
  probs = torch.softmax(output, dim=1)
103
  prob = probs[0, 1].item()
104
+ impact = baseline_prob - prob # how much removing the feature changed the prediction
105
  shap_values.append(impact)
106
+ x_zeroed[0, i] = orig_value # restore the original value
107
 
108
  return np.array(shap_values), baseline_prob
109
 
110
+
111
+ ###############################################################################
112
+ # 4. PER-BASE SHAP AGGREGATION (LINEAR HEATMAP)
113
+ ###############################################################################
114
+
115
def compute_positionwise_scores(sequence, shap_values, k=4):
    """
    Return per-base SHAP contributions for `sequence`.

    Each base receives the mean of the SHAP values of every k-mer window
    that covers it. Windows containing anything other than A/C/G/T (for
    example 'N') are skipped; a base covered by no valid window scores 0.

    Args:
        sequence: Nucleotide string. Matching is case-insensitive.
        shap_values: Indexable of per-k-mer values, ordered like
            `product("ACGT", repeat=k)` (lexicographic order).
        k: Length of the k-mers the SHAP values refer to.

    Returns:
        A float32 NumPy array of length `len(sequence)`.
    """
    # Lookup table: k-mer string -> index into shap_values (lexicographic order)
    kmer_index = {''.join(p): idx for idx, p in enumerate(product("ACGT", repeat=k))}

    # Normalise case so lowercase / soft-masked bases are not silently ignored
    sequence = sequence.upper()
    seq_len = len(sequence)

    # Running SHAP totals and number of covering k-mers for every base
    shap_sums = np.zeros(seq_len, dtype=np.float32)
    coverage = np.zeros(seq_len, dtype=np.float32)

    # Slide a window of width k; each valid k-mer contributes to all its bases
    for start in range(seq_len - k + 1):
        idx = kmer_index.get(sequence[start:start + k])
        if idx is None:
            continue
        shap_sums[start:start + k] += shap_values[idx]
        coverage[start:start + k] += 1

    # Divide only where coverage is non-zero; uncovered bases keep the 0 default
    return np.divide(
        shap_sums,
        coverage,
        out=np.zeros_like(shap_sums),
        where=coverage > 0,
    )
145
+
146
def plot_linear_heatmap(shap_means):
    """
    Plot a 1D heatmap of per-base SHAP contributions.

    The colour scale is symmetric around zero so the neutral centre of the
    diverging colormap always corresponds to a contribution of 0:
    negative values (blue) push toward Non-Human, positive values (red)
    push toward Human.

    Args:
        shap_means: 1D array-like of per-base contributions.

    Returns:
        The Matplotlib figure containing the heatmap.
    """
    # Reshape into (1, -1) so that imshow displays it as a single row
    heatmap_data = np.asarray(shap_means).reshape(1, -1)

    # Symmetric colour limits; without them imshow scales to the data min/max
    # and zero no longer sits at the centre of the diverging colormap.
    max_abs = float(np.abs(heatmap_data).max()) if heatmap_data.size else 0.0
    if not np.isfinite(max_abs) or max_abs == 0.0:
        max_abs = 1.0  # all-zero / degenerate data: fall back to a unit range

    fig, ax = plt.subplots(figsize=(12, 2))

    # Diverging colour map (red = positive, blue = negative)
    im = ax.imshow(
        heatmap_data,
        aspect='auto',
        cmap='RdBu_r',
        vmin=-max_abs,
        vmax=max_abs,
    )

    # Horizontal colorbar attached explicitly to this axes
    cbar = plt.colorbar(im, ax=ax, orientation='horizontal', pad=0.2)
    cbar.set_label('SHAP Contribution')

    ax.set_yticks([])  # single row, so hide the y-axis
    ax.set_xlabel('Position in Sequence')
    ax.set_title('Per-base SHAP Heatmap')

    plt.tight_layout()
    return fig
169
+
170
+
171
+ ###############################################################################
172
+ # 5. OTHER PLOTS: BAR PLOT OF TOP-K AND SEQUENCE IMPACT VISUALIZATION
173
+ ###############################################################################
174
+
175
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
176
  """Create a bar plot of the most important k-mers."""
177
  plt.rcParams.update({'font.size': 10})
 
188
  plt.yticks(range(len(values)), features)
189
  plt.xlabel('SHAP value (impact on model output)')
190
  plt.title(f'Top {top_k} Most Influential k-mers')
191
+ plt.gca().invert_yaxis() # most important at top
192
 
193
  return plt.gcf()
194
 
 
227
  # Plot k-mers with controlled spacing
228
  y_spacing = 0.9 / max(len(display_kmers), 1)
229
  y_position = 0.95
 
230
 
231
  for pos, kmer, impact in display_kmers:
 
232
  pre_sequence = sequence[max(0, pos-20):pos]
233
+ post_sequence = sequence[pos+len(kmer):min(pos+len(kmer)+20, len(sequence))]
234
 
235
  # Add ellipsis if truncated
236
  pre_ellipsis = "..." if pos > 20 else ""
237
+ post_ellipsis = "..." if pos+len(kmer)+20 < len(sequence) else ""
238
 
239
  # Choose color based on impact
240
  color = '#ffcccb' if impact > 0 else '#cce0ff'
 
243
  # Draw text elements
244
  plt.text(0.01, y_position, f"{pre_ellipsis}{pre_sequence}", fontsize=9)
245
  plt.text(0.01 + len(f"{pre_ellipsis}{pre_sequence}")/50, y_position,
246
+ kmer, fontsize=9, bbox=dict(facecolor=color, alpha=0.3, pad=1))
247
  plt.text(0.01 + (len(f"{pre_ellipsis}{pre_sequence}") + len(kmer))/50,
248
+ y_position, f"{post_sequence}{post_ellipsis}", fontsize=9)
249
 
250
  # Add impact value
251
  plt.text(0.8, y_position, f"{arrow} {impact:+.3f}", fontsize=9)
 
254
 
255
  plt.axis('off')
256
 
257
+ # Adjust layout
258
  plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05)
259
  return fig
260
 
261
+
262
+ ###############################################################################
263
+ # 6. HELPER FUNCTION: FIG TO IMAGE
264
+ ###############################################################################
265
+
266
def fig_to_image(fig):
    """
    Convert a Matplotlib figure to a PIL Image.

    The figure is rendered as a PNG (150 dpi, tight bounding box) into an
    in-memory buffer and is always closed afterwards, even if rendering
    fails, so figures do not accumulate in pyplot's figure registry.

    Args:
        fig: The Matplotlib figure to render.

    Returns:
        A PIL Image opened from the in-memory PNG data.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    finally:
        # Release the figure whether or not savefig succeeded
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
274
+
275
+
276
+ ###############################################################################
277
+ # 7. MAIN PREDICTION FUNCTION
278
+ ###############################################################################
279
+
280
  def predict(file_obj, top_kmers=10, fasta_text=""):
281
  """Main prediction function for Gradio interface."""
282
  # Handle input
 
287
  with open(file_obj, 'r') as f:
288
  text = f.read()
289
  except Exception as e:
290
+ return f"Error reading file: {str(e)}", None, None, None
291
  else:
292
+ return "Please provide a FASTA sequence.", None, None, None
293
 
294
  # Parse FASTA
295
  sequences = parse_fasta(text)
296
  if not sequences:
297
+ return "No valid FASTA sequences found.", None, None, None
298
 
299
  header, seq = sequences[0]
300
 
301
+ # Load model and scaler
302
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
303
  try:
304
  model = VirusClassifier(256).to(device)
305
+ # NOTE: 'weights_only=True' is a standard torch.load argument (PyTorch >= 1.13) that restricts unpickling; it is omitted here, so only load a trusted 'model.pt'.
306
+ model.load_state_dict(torch.load('model.pt', map_location=device))
307
  scaler = joblib.load('scaler.pkl')
308
  except Exception as e:
309
+ return f"Error loading model: {str(e)}", None, None, None
310
 
311
  # Generate features
312
  freq_vector = sequence_to_kmer_vector(seq)
 
316
  # Calculate SHAP values and get prediction
317
  shap_values, prob_human = calculate_shap_values(model, x_tensor)
318
 
319
+ # Prediction text
320
  results = [
321
  f"Sequence: {header}",
322
  f"Prediction: {'Human' if prob_human > 0.5 else 'Non-human'} Origin",
323
+ f"Confidence: {max(prob_human, 1 - prob_human):.3f}",
324
  f"Human Probability: {prob_human:.3f}",
325
  "\nTop Contributing k-mers:"
326
  ]
327
+
328
+ # Create k-mer lists for visualization
329
  kmers = [''.join(p) for p in product("ACGT", repeat=4)]
330
 
331
+ # 1) K-mer importance bar plot
332
  importance_plot = create_importance_bar_plot(shap_values, kmers, top_kmers)
333
+ importance_img = fig_to_image(importance_plot)
334
+
335
+ # 2) SHAP-style textual sequence impact
336
  sequence_plot = visualize_sequence_impacts(seq, kmers, shap_values, prob_human)
337
+ sequence_img = fig_to_image(sequence_plot)
338
 
339
+ # 3) Linear heatmap across full genome
340
+ shap_means = compute_positionwise_scores(seq, shap_values, k=4)
341
+ heatmap_fig = plot_linear_heatmap(shap_means)
342
+ heatmap_img = fig_to_image(heatmap_fig)
343
+
344
+ return "\n".join(results), importance_img, sequence_img, heatmap_img
345
+
 
346
 
347
+ ###############################################################################
348
+ # 8. BUILD GRADIO INTERFACE
349
+ ###############################################################################
350
 
 
351
  css = """
352
  .gradio-container {
353
  font-family: 'IBM Plex Sans', sans-serif;
 
385
  results = gr.Textbox(label="Analysis Results", lines=10)
386
  kmer_plot = gr.Image(label="K-mer Importance Plot")
387
  shap_plot = gr.Image(label="Sequence Impact Visualization (SHAP-style)")
388
+ heatmap_plot = gr.Image(label="Genome Heatmap")
389
 
390
  submit_btn.click(
391
  predict,
392
  inputs=[file_input, top_k, text_input],
393
+ outputs=[results, kmer_plot, shap_plot, heatmap_plot]
394
  )
395
 
396
  gr.Markdown("""
 
401
  - Blue highlights = pushing toward non-human origin
402
  - Arrows (↑/↓) show impact direction
403
  - Values show impact magnitude
404
+ - **Genome Heatmap**: Per-base SHAP values across the entire sequence
405
+ - Red = push toward human
406
+ - Blue = push toward non-human
407
  """)
408
 
409
  if __name__ == "__main__":
410
+ iface.launch()