hiyata committed
Commit 6c88c65 · verified · 1 Parent(s): b0fba50

Update app.py

Files changed (1):
  app.py +134 -153
app.py CHANGED
@@ -4,10 +4,8 @@ import joblib
 import numpy as np
 from itertools import product
 import torch.nn as nn
-import shap
 import matplotlib.pyplot as plt
 import io
-import json
 from PIL import Image
 
 class VirusClassifier(nn.Module):
@@ -31,16 +29,16 @@ class VirusClassifier(nn.Module):
         return self.network(x)
 
     def get_feature_importance(self, x):
-        """Calculate feature importance using gradient-based method for the human class (index 1)"""
+        """Calculate feature importance using gradient-based method"""
         x.requires_grad_(True)
         output = self.network(x)
         probs = torch.softmax(output, dim=1)
 
-        # We focus on the human class (index 1) probability
+        # Get importance for human class (index 1)
         human_prob = probs[..., 1]
+        if x.grad is not None:
+            x.grad.zero_()
         human_prob.backward()
-
-        # The gradient shows how each feature affects the human probability
         importance = x.grad
 
         return importance, float(human_prob)
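Note on the hunk above: besides trimming the docstring, the commit zeroes any stale input gradient before calling backward(). PyTorch accumulates .grad across calls, so without the reset a second call to get_feature_importance would return the sum of two saliency maps. The method itself is plain gradient saliency: one backward pass from the human-class probability to the input. A minimal self-contained sketch of the same pattern (the toy model and feature count are illustrative, not taken from app.py):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(256, 2))     # toy stand-in for the k-mer classifier
    x = torch.randn(1, 256, requires_grad=True)  # one scaled k-mer frequency vector

    probs = torch.softmax(model(x), dim=1)
    if x.grad is not None:
        x.grad.zero_()                           # clear gradients left over from earlier calls
    probs[..., 1].backward()                     # d(human probability) / d(each input feature)
    importance = x.grad                          # sign = direction of push, magnitude = impact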
@@ -82,6 +80,94 @@ def parse_fasta(text):
     sequences.append((current_header, ''.join(current_sequence)))
     return sequences
 
+def create_visualization(important_kmers, human_prob, title):
+    """Create a comprehensive visualization of k-mer impacts"""
+    fig = plt.figure(figsize=(15, 10))
+
+    # Create grid for subplots
+    gs = plt.GridSpec(2, 1, height_ratios=[1.5, 1], hspace=0.3)
+
+    # 1. Probability Step Plot
+    ax1 = plt.subplot(gs[0])
+    current_prob = 0.5
+    steps = [('Start', current_prob, 0)]
+
+    for kmer in important_kmers:
+        change = kmer['impact'] * (-1 if kmer['direction'] == 'non-human' else 1)
+        current_prob += change
+        steps.append((kmer['kmer'], current_prob, change))
+
+    x = range(len(steps))
+    y = [step[1] for step in steps]
+
+    # Plot steps
+    ax1.step(x, y, 'b-', where='post', label='Probability', linewidth=2)
+    ax1.plot(x, y, 'b.', markersize=10)
+
+    # Add reference line
+    ax1.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')
+
+    # Customize plot
+    ax1.grid(True, linestyle='--', alpha=0.7)
+    ax1.set_ylim(0, 1)
+    ax1.set_ylabel('Human Probability')
+    ax1.set_title(f'K-mer Contributions to Prediction (final prob: {human_prob:.3f})')
+
+    # Add labels for each point
+    for i, (kmer, prob, change) in enumerate(steps):
+        # Add k-mer label
+        ax1.annotate(kmer,
+                     (i, prob),
+                     xytext=(0, 10 if i % 2 == 0 else -20),
+                     textcoords='offset points',
+                     ha='center',
+                     rotation=45)
+
+        # Add change value
+        if i > 0:
+            change_text = f'{change:+.3f}'
+            color = 'green' if change > 0 else 'red'
+            ax1.annotate(change_text,
+                         (i, prob),
+                         xytext=(0, -20 if i % 2 == 0 else 10),
+                         textcoords='offset points',
+                         ha='center',
+                         color=color)
+
+    ax1.legend()
+
+    # 2. K-mer Frequency and Sigma Plot
+    ax2 = plt.subplot(gs[1])
+
+    # Prepare data
+    kmers = [k['kmer'] for k in important_kmers]
+    frequencies = [k['occurrence'] for k in important_kmers]
+    sigmas = [k['sigma'] for k in important_kmers]
+    colors = ['g' if k['direction'] == 'human' else 'r' for k in important_kmers]
+
+    # Create bar plot for frequencies
+    x = np.arange(len(kmers))
+    width = 0.35
+
+    ax2.bar(x - width/2, frequencies, width, label='Frequency (%)', color=colors, alpha=0.6)
+    ax2_twin = ax2.twinx()
+    ax2_twin.bar(x + width/2, sigmas, width, label='σ from mean', color=[c if s > 0 else 'gray' for c, s in zip(colors, sigmas)], alpha=0.3)
+
+    # Customize plot
+    ax2.set_xticks(x)
+    ax2.set_xticklabels(kmers, rotation=45)
+    ax2.set_ylabel('Frequency (%)')
+    ax2_twin.set_ylabel('Standard Deviations (σ) from Mean')
+    ax2.set_title('K-mer Frequencies and Statistical Significance')
+
+    # Add legends
+    lines1, labels1 = ax2.get_legend_handles_labels()
+    lines2, labels2 = ax2_twin.get_legend_handles_labels()
+    ax2.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
+
+    plt.tight_layout()
+    return fig
+
 def predict(file_obj):
     if file_obj is None:
         return "Please upload a FASTA file", None
@@ -119,172 +205,64 @@ def predict(file_obj):
         kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
         X_tensor = torch.FloatTensor(kmer_vector).to(device)
 
-        # Calculate final probabilities first
+        # Get model predictions
         with torch.no_grad():
             output = model(X_tensor)
             probs = torch.softmax(output, dim=1)
-            human_prob = float(probs[0][1])
-
-        # Get feature importance using integrated gradients
-        baseline = torch.zeros_like(X_tensor)  # baseline of zeros
-        steps = 50
-
-        all_importance = []
-        for i in range(steps + 1):
-            alpha = i / steps
-            interpolated = baseline + alpha * (X_tensor - baseline)
-            interpolated.requires_grad_(True)
-
-            output = model(interpolated)
-            probs = torch.softmax(output, dim=1)
-            human_class = probs[..., 1]
-
-            if interpolated.grad is not None:
-                interpolated.grad.zero_()
-            human_class.backward()
-            all_importance.append(interpolated.grad.cpu().numpy())
 
-        # Average the gradients
-        kmer_importance = np.mean(all_importance, axis=0)[0]
-        # Scale to match probability difference
-        target_diff = human_prob - 0.5  # difference from neutral prediction
-        current_sum = np.sum(kmer_importance)
-        if current_sum != 0:  # avoid division by zero
-            kmer_importance = kmer_importance * (target_diff / current_sum)
+        # Get feature importance
+        importance, _ = model.get_feature_importance(X_tensor)
+        kmer_importance = importance[0].cpu().numpy()
 
-        # Get top k-mers by absolute importance
+        # Get top k-mers
         top_k = 10
         top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
-        important_kmers = [
-            {
-                'kmer': list(kmer_dict.keys())[list(kmer_dict.values()).index(i)],
-                'importance': float(kmer_importance[i]),
-                'frequency': float(raw_freq_vector[i]),
-                'scaled': float(kmer_vector[0][i])
-            }
-            for i in top_indices
-        ]
-
-        # Prepare data for SHAP waterfall plot
-        top_features = [item['kmer'] for item in important_kmers]
-        top_values = [item['importance'] for item in important_kmers]
-
-        # Calculate the impact of remaining features
-        others_mask = np.ones_like(kmer_importance, dtype=bool)
-        others_mask[top_indices] = False
-        others_sum = np.sum(kmer_importance[others_mask])
 
-        top_features.append("Others")
-        top_values.append(others_sum)
-
-        # Calculate final probabilities first
-        with torch.no_grad():
-            output = model(X_tensor)
-            probs = torch.softmax(output, dim=1)
-            human_prob = float(probs[0][1])
-
-        # Create SHAP explanation
-        # We'll use the actual probabilities for alignment
-        explanation = shap.Explanation(
-            values=np.array(top_values),
-            base_values=0.5,  # Start from neutral prediction
-            data=np.array([
-                raw_freq_vector[kmer_dict[feat]] if feat != "Others"
-                else np.sum(raw_freq_vector[others_mask])
-                for feat in top_features
-            ]),
-            feature_names=top_features
-        )
-        explanation.expected_value = 0.5  # Start from neutral prediction
-
-        # Calculate step-by-step probabilities
-        current_prob = 0.5  # Start at neutral
-        steps = [('Start', current_prob, 0)]
-
-        # Process each k-mer contribution
-        for kmer in important_kmers:
-            change = kmer['importance']
-            current_prob += change
-            steps.append((kmer['kmer'], current_prob, change))
-
-        # Add final "Others" contribution
-        current_prob += others_sum
-        steps.append(('Others', current_prob, others_sum))
-
-        # Create step plot
-        plt.figure(figsize=(12, 6))
-        x = range(len(steps))
-        y = [step[1] for step in steps]
-
-        # Plot steps
-        plt.step(x, y, 'b-', where='post', label='Probability', linewidth=2)
-        plt.plot(x, y, 'b.', markersize=10)
-
-        # Add reference line
-        plt.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')
-
-        # Customize plot
-        plt.grid(True, linestyle='--', alpha=0.7)
-        plt.ylim(0, 1)
-        plt.ylabel('Human Probability')
-        plt.title(f'K-mer Contributions to Prediction (final prob: {human_prob:.3f})')
-
-        # Add labels for each point
-        for i, (kmer, prob, change) in enumerate(steps):
-            # Add k-mer label
-            plt.annotate(kmer,
-                         (i, prob),
-                         xytext=(0, 10 if i % 2 == 0 else -20),  # Alternate up/down
-                         textcoords='offset points',
-                         ha='center',
-                         rotation=45 if len(kmer) > 5 else 0)
+        important_kmers = []
+        for idx in top_indices:
+            kmer = list(kmer_dict.keys())[list(kmer_dict.values()).index(idx)]
+            imp = float(abs(kmer_importance[idx]))
+            direction = 'human' if kmer_importance[idx] > 0 else 'non-human'
+            freq = float(raw_freq_vector[idx] * 100)  # Convert to percentage
+            sigma = float(kmer_vector[0][idx])
 
-            # Add change value
-            if i > 0:  # Skip first point (Start)
-                change_text = f'{change:+.3f}'
-                color = 'green' if change > 0 else 'red'
-                plt.annotate(change_text,
-                             (i, prob),
-                             xytext=(0, -20 if i % 2 == 0 else 10),
-                             textcoords='offset points',
-                             ha='center',
-                             color=color)
-
-        plt.legend()
-        plt.tight_layout()
+            important_kmers.append({
+                'kmer': kmer,
+                'impact': imp,
+                'direction': direction,
+                'occurrence': freq,
+                'sigma': sigma
+            })
+
 
-        # Save plot
-        buf = io.BytesIO()
-        plt.savefig(buf, format='png', bbox_inches='tight', dpi=300)
-        buf.seek(0)
-        plot_image = Image.open(buf)
-        plt.close()
-
-        # Calculate final probabilities
-        with torch.no_grad():
-            output = model(X_tensor)
-            probs = torch.softmax(output, dim=1)
-
+        # Generate text results
         pred_class = 1 if probs[0][1] > probs[0][0] else 0
         pred_label = 'human' if pred_class == 1 else 'non-human'
+        human_prob = float(probs[0][1])
 
-        # Generate results text
-        results_text += f"""Sequence: {header}
+        results_text = f"""Sequence: {header}
 Prediction: {pred_label}
 Confidence: {float(max(probs[0])):0.4f}
-Human probability: {float(probs[0][1]):0.4f}
+Human probability: {human_prob:0.4f}
 Non-human probability: {float(probs[0][0]):0.4f}
 Most influential k-mers (ranked by importance):"""
 
         for kmer in important_kmers:
-            direction = "human" if kmer['importance'] > 0 else "non-human"
             results_text += f"\n {kmer['kmer']}: "
-            results_text += f"pushes toward {direction} (impact={abs(kmer['importance']):.4f}), "
-            results_text += f"occurrence={kmer['frequency']*100:.2f}% of sequence "
-            if kmer['scaled'] > 0:
-                results_text += f"(appears {abs(kmer['scaled']):.2f}σ more than average)"
-            else:
-                results_text += f"(appears {abs(kmer['scaled']):.2f}σ less than average)"
+            results_text += f"pushes toward {kmer['direction']} (impact={kmer['impact']:.4f}), "
+            results_text += f"occurrence={kmer['occurrence']:.2f}% of sequence "
+            results_text += f"(appears {abs(kmer['sigma']):.2f}σ "
+            results_text += "more" if kmer['sigma'] > 0 else "less"
+            results_text += " than average)"
+
+        # Create visualization
+        fig = create_visualization(important_kmers, human_prob, header)
+
+        # Save plot
+        buf = io.BytesIO()
+        fig.savefig(buf, format='png', bbox_inches='tight', dpi=300)
+        buf.seek(0)
+        plot_image = Image.open(buf)
+        plt.close(fig)
 
     except Exception as e:
         return f"Error processing sequences: {str(e)}", None
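Note on the hunk above: the rewritten predict() drops the integrated-gradients loop (50 interpolation steps), the importance rescaling, and the shap.Explanation construction, calling the model's single-pass get_feature_importance() plus the new create_visualization() instead. The figure still reaches Gradio the same way: rendered to an in-memory PNG and reopened as a PIL image. A self-contained sketch of that round trip (the plotted values are made up for illustration):

    import io
    import matplotlib
    matplotlib.use('Agg')             # headless backend, as on a server without a display
    import matplotlib.pyplot as plt
    from PIL import Image

    fig, ax = plt.subplots()
    ax.step([0, 1, 2], [0.5, 0.62, 0.71], where='post')

    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=300)
    buf.seek(0)
    plot_image = Image.open(buf)      # a PIL image that gr.Image can display directly
    plt.close(fig)                    # free the figure so repeated predictions don't leak memory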
@@ -294,7 +272,10 @@ Most influential k-mers (ranked by importance):"""
 iface = gr.Interface(
     fn=predict,
     inputs=gr.File(label="Upload FASTA file", type="binary"),
-    outputs=[gr.Textbox(label="Results"), gr.Image(label="SHAP Waterfall Plot")],
+    outputs=[
+        gr.Textbox(label="Results"),
+        gr.Image(label="K-mer Analysis Visualization")
+    ],
     title="Virus Host Classifier"
 )
 
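One design note on a pattern the commit keeps: list(kmer_dict.keys())[list(kmer_dict.values()).index(idx)] rebuilds both lists and does a linear scan for every top k-mer. Assuming kmer_dict maps k-mer string to feature index (its apparent role here, given the itertools.product import), precomputing the inverse map makes each lookup O(1). A sketch under that assumption (the 4-mer construction is hypothetical):

    from itertools import product

    # Hypothetical construction: every 4-mer over ACGT mapped to a column index
    kmer_dict = {''.join(p): i for i, p in enumerate(product('ACGT', repeat=4))}

    # Build the inverse once, then index into it directly
    idx_to_kmer = {i: k for k, i in kmer_dict.items()}
    kmer = idx_to_kmer[42]   # equivalent to the double list scan, without the O(n) cost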