hiyata committed on
Commit
7e92f7c
·
verified ·
1 Parent(s): 555d484

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -129
app.py CHANGED
@@ -29,9 +29,7 @@ class VirusClassifier(nn.Module):
29
  return self.network(x)
30
 
31
  def parse_fasta(text):
32
- """
33
- Parses FASTA formatted text into a list of (header, sequence).
34
- """
35
  sequences = []
36
  current_header = None
37
  current_sequence = []
@@ -52,9 +50,7 @@ def parse_fasta(text):
52
  return sequences
53
 
54
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
55
- """
56
- Convert a sequence to a k-mer frequency vector.
57
- """
58
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
59
  kmer_dict = {km: i for i, km in enumerate(kmers)}
60
  vec = np.zeros(len(kmers), dtype=np.float32)
@@ -72,130 +68,130 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
72
 
73
  def calculate_shap_values(model, x_tensor):
74
  """
75
- Calculate SHAP-like values using a simple ablation approach.
 
76
  """
77
  model.eval()
78
  with torch.no_grad():
 
79
  baseline_output = model(x_tensor)
80
- baseline_prob = torch.softmax(baseline_output, dim=1)[0, 1].item()
 
81
 
 
82
  shap_values = []
 
83
  for i in range(x_tensor.shape[1]):
84
- perturbed_input = x_tensor.clone()
85
- perturbed_input[0, i] = 0 # Ablate feature
86
- output = model(perturbed_input)
87
- prob = torch.softmax(output, dim=1)[0, 1].item()
88
- shap_values.append(baseline_prob - prob)
 
 
89
 
90
  return np.array(shap_values), baseline_prob
91
 
92
- def create_importance_plot(shap_values, kmers, top_k=10):
93
- """
94
- Create horizontal bar plot of feature importance.
95
- """
96
- # Set style directly instead of using seaborn
97
- plt.rcParams['figure.facecolor'] = '#ffffff'
98
- plt.rcParams['axes.facecolor'] = '#ffffff'
99
- plt.rcParams['axes.grid'] = True
100
- plt.rcParams['grid.alpha'] = 0.3
101
- fig = plt.figure(figsize=(10, 8))
102
 
103
  # Sort by absolute importance
104
  indices = np.argsort(np.abs(shap_values))[-top_k:]
105
  values = shap_values[indices]
106
  features = [kmers[i] for i in indices]
107
 
108
- colors = ['#2ecc71' if v > 0 else '#e74c3c' for v in values]
109
 
110
  plt.barh(range(len(values)), values, color=colors)
111
  plt.yticks(range(len(values)), features)
112
- plt.xlabel('Impact on Prediction (SHAP value)')
113
  plt.title(f'Top {top_k} Most Influential k-mers')
114
- plt.gca().invert_yaxis()
115
 
116
- return fig
117
 
118
- def create_contribution_plot(important_kmers, final_prob):
119
  """
120
- Create waterfall plot showing cumulative feature contributions.
 
121
  """
122
- # Set style parameters
123
- plt.rcParams['figure.facecolor'] = '#ffffff'
124
- plt.rcParams['axes.facecolor'] = '#ffffff'
125
- plt.rcParams['axes.grid'] = True
126
- plt.rcParams['grid.alpha'] = 0.3
127
-
128
- fig, ax = plt.subplots(figsize=(12, 6))
129
 
130
- base_prob = 0.5
131
- cumulative = [base_prob]
132
- labels = ['Base']
 
 
 
 
133
 
134
- for kmer_info in important_kmers:
135
- cumulative.append(cumulative[-1] + kmer_info['impact'])
136
- labels.append(kmer_info['kmer'])
137
 
138
- # Plot cumulative line with markers
139
- line = ax.plot(range(len(cumulative)), cumulative, '-o',
140
- color='#3498db', linewidth=2,
141
- marker='o', markersize=8,
142
- markerfacecolor='white',
143
- markeredgecolor='#3498db',
144
- markeredgewidth=2)
145
 
146
- # Add reference line at 0.5
147
- ax.axhline(y=0.5, color='#95a5a6', linestyle='--', alpha=0.5)
148
 
149
- # Customize plot
150
- ax.set_xticks(range(len(labels)))
151
- ax.set_xticklabels(labels, rotation=45, ha='right')
152
- ax.set_ylim(0, 1)
153
- ax.grid(True, axis='y', linestyle='--', alpha=0.3)
154
- ax.set_title('Cumulative Feature Contributions')
155
- ax.set_ylabel('Probability of Human Origin')
156
 
157
- # Add value labels
158
- for i, prob in enumerate(cumulative):
159
- ax.annotate(f'{prob:.3f}',
160
- (i, prob),
161
- xytext=(0, 10),
162
- textcoords='offset points',
163
- ha='center',
164
- va='bottom')
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
- # Adjust layout to prevent label cutoff
167
  plt.tight_layout()
168
  return fig
169
 
170
  def predict(file_obj, top_kmers=10, fasta_text=""):
171
- """
172
- Main prediction function for the Gradio interface.
173
- """
174
  # Handle input
175
  if fasta_text.strip():
176
  text = fasta_text.strip()
177
  elif file_obj is not None:
178
  try:
179
- # File input will be a filepath since we specified type="filepath"
180
  with open(file_obj, 'r') as f:
181
  text = f.read()
182
  except Exception as e:
183
- return f"Error reading file: {str(e)}\nPlease ensure you're uploading a valid FASTA text file.", None, None
184
  else:
185
- return "Please provide a FASTA sequence either by file upload or text input.", None, None
186
 
187
  # Parse FASTA
188
  sequences = parse_fasta(text)
189
  if not sequences:
190
- return "No valid FASTA sequences found in input.", None, None
191
 
192
  header, seq = sequences[0]
193
 
194
- # Process sequence
195
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
196
  try:
197
  model = VirusClassifier(256).to(device)
198
- # Load model weights safely
199
  model.load_state_dict(torch.load('model.pt', map_location=device, weights_only=True))
200
  scaler = joblib.load('scaler.pkl')
201
  except Exception as e:
@@ -206,42 +202,24 @@ def predict(file_obj, top_kmers=10, fasta_text=""):
206
  scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
207
  x_tensor = torch.FloatTensor(scaled_vector).to(device)
208
 
209
- # Calculate SHAP values and predictions
210
- shap_values, human_prob = calculate_shap_values(model, x_tensor)
211
 
212
- # Generate k-mer information
213
- kmers = [''.join(p) for p in product("ACGT", repeat=4)]
214
- important_indices = np.argsort(np.abs(shap_values))[-top_kmers:]
215
-
216
- important_kmers = []
217
- for idx in important_indices:
218
- important_kmers.append({
219
- 'kmer': kmers[idx],
220
- 'impact': shap_values[idx],
221
- 'frequency': freq_vector[idx] * 100,
222
- 'significance': scaled_vector[0][idx]
223
- })
224
-
225
- # Format results text
226
  results = [
227
  f"Sequence: {header}",
228
- f"Prediction: {'Human' if human_prob > 0.5 else 'Non-human'} Origin",
229
- f"Confidence: {max(human_prob, 1-human_prob):.3f}",
230
- f"Human Probability: {human_prob:.3f}",
231
- "\nTop Contributing k-mers:",
232
  ]
233
-
234
- for kmer in important_kmers:
235
- direction = "β†’ Human" if kmer['impact'] > 0 else "β†’ Non-human"
236
- results.append(
237
- f"β€’ {kmer['kmer']}: {direction} "
238
- f"(impact: {kmer['impact']:.3f}, "
239
- f"freq: {kmer['frequency']:.2f}%)"
240
- )
241
 
242
- # Generate plots
243
- shap_plot = create_importance_plot(shap_values, kmers, top_kmers)
244
- contribution_plot = create_contribution_plot(important_kmers, human_prob)
 
 
 
245
 
246
  # Convert plots to images
247
  def fig_to_image(fig):
@@ -252,30 +230,19 @@ def predict(file_obj, top_kmers=10, fasta_text=""):
252
  plt.close(fig)
253
  return img
254
 
255
- return "\n".join(results), fig_to_image(shap_plot), fig_to_image(contribution_plot)
256
 
257
  # Create Gradio interface
258
  css = """
259
  .gradio-container {
260
  font-family: 'IBM Plex Sans', sans-serif;
261
  }
262
- .interpretation-container {
263
- margin-top: 20px;
264
- padding: 15px;
265
- border-radius: 8px;
266
- background-color: #f8f9fa;
267
- }
268
  """
269
 
270
  with gr.Blocks(css=css) as iface:
271
  gr.Markdown("""
272
  # Virus Host Classifier
273
- This tool predicts whether a viral sequence is likely of human or non-human origin using k-mer frequency analysis.
274
-
275
- ### Instructions
276
- 1. Upload a FASTA file or paste your sequence in FASTA format
277
- 2. Adjust the number of top k-mers to display (default: 10)
278
- 3. View the prediction results and feature importance visualizations
279
  """)
280
 
281
  with gr.Row():
@@ -283,7 +250,7 @@ with gr.Blocks(css=css) as iface:
283
  file_input = gr.File(
284
  label="Upload FASTA file",
285
  file_types=[".fasta", ".fa", ".txt"],
286
- type="filepath" # Changed to filepath which is one of the valid options
287
  )
288
  text_input = gr.Textbox(
289
  label="Or paste FASTA sequence",
@@ -292,7 +259,7 @@ with gr.Blocks(css=css) as iface:
292
  )
293
  top_k = gr.Slider(
294
  minimum=5,
295
- maximum=20,
296
  value=10,
297
  step=1,
298
  label="Number of top k-mers to display"
@@ -301,20 +268,23 @@ with gr.Blocks(css=css) as iface:
301
 
302
  with gr.Column(scale=2):
303
  results = gr.Textbox(label="Analysis Results", lines=10)
304
- shap_plot = gr.Image(label="Feature Importance Plot")
305
- contribution_plot = gr.Image(label="Cumulative Contribution Plot")
306
 
307
  submit_btn.click(
308
  predict,
309
  inputs=[file_input, top_k, text_input],
310
- outputs=[results, shap_plot, contribution_plot]
311
  )
312
 
313
  gr.Markdown("""
314
- ### About
315
- - Uses 4-mer frequencies as sequence features
316
- - Employs SHAP-like values for feature importance interpretation
317
- - Visualizes cumulative feature contributions to the final prediction
 
 
 
318
  """)
319
 
320
  if __name__ == "__main__":
 
29
  return self.network(x)
30
 
31
  def parse_fasta(text):
32
+ """Parse FASTA formatted text into a list of (header, sequence)."""
 
 
33
  sequences = []
34
  current_header = None
35
  current_sequence = []
 
50
  return sequences
51
 
52
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
53
+ """Convert a sequence to a k-mer frequency vector."""
 
 
54
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
55
  kmer_dict = {km: i for i, km in enumerate(kmers)}
56
  vec = np.zeros(len(kmers), dtype=np.float32)
 
68
 
69
  def calculate_shap_values(model, x_tensor):
70
  """
71
+ Calculate SHAP values using a simple ablation approach.
72
+ Returns shap values and model prediction.
73
  """
74
  model.eval()
75
  with torch.no_grad():
76
+ # Get baseline prediction
77
  baseline_output = model(x_tensor)
78
+ baseline_probs = torch.softmax(baseline_output, dim=1)
79
+ baseline_prob = baseline_probs[0, 1].item() # Probability of human class
80
 
81
+ # Calculate impact of zeroing each feature
82
  shap_values = []
83
+ x_zeroed = x_tensor.clone()
84
  for i in range(x_tensor.shape[1]):
85
+ x_zeroed[0, i] = 0
86
+ output = model(x_zeroed)
87
+ probs = torch.softmax(output, dim=1)
88
+ prob = probs[0, 1].item()
89
+ impact = baseline_prob - prob # How much removing the feature changed the prediction
90
+ shap_values.append(impact)
91
+ x_zeroed[0, i] = x_tensor[0, i] # Restore the original value
92
 
93
  return np.array(shap_values), baseline_prob
94
 
95
def create_importance_bar_plot(shap_values, kmers, top_k=10):
    """
    Draw a horizontal bar chart of the k-mers with the largest absolute impact.

    Bars are ordered by descending absolute value, so the most influential
    k-mer is the top bar. Positive values are drawn in light red, all other
    values in light blue.

    Args:
        shap_values: 1-D sequence of per-k-mer impact scores.
        kmers: Sequence of k-mer labels, index-aligned with ``shap_values``.
        top_k: Maximum number of bars to draw. Cast to ``int`` because a UI
            slider may deliver a float.

    Returns:
        The matplotlib Figure containing the chart.
    """
    shap_values = np.asarray(shap_values)
    top_k = max(0, min(int(top_k), len(shap_values)))

    # Descending by absolute importance. Slicing from the front (rather than
    # ``[-top_k:]``) also keeps ``top_k == 0`` from selecting every feature.
    order = np.argsort(np.abs(shap_values))[::-1][:top_k]
    values = shap_values[order]
    features = [kmers[i] for i in order]
    colors = ['#ff9999' if v > 0 else '#99ccff' for v in values]

    # Process-wide font setting, kept from the original implementation.
    plt.rcParams.update({'font.size': 10})
    fig, ax = plt.subplots(figsize=(10, 6))

    positions = range(len(values))
    ax.barh(positions, values, color=colors)
    ax.set_yticks(positions)
    ax.set_yticklabels(features)
    ax.set_xlabel('SHAP value (impact on model output)')
    ax.set_title(f'Top {top_k} Most Influential k-mers')
    # barh puts position 0 at the bottom; inverting moves the first (most
    # important) bar to the top.
    ax.invert_yaxis()

    return fig
114
 
115
def visualize_sequence_impacts(sequence, kmers, shap_values, base_prob):
    """
    Render the sequence once per high-impact k-mer occurrence, highlighted.

    Every position of ``sequence`` whose k-length window is a known k-mer is
    scored with that k-mer's value from ``shap_values``. The 30 occurrences
    with the largest absolute score are drawn as text rows: the full sequence
    with the k-mer boxed (light red when the impact is positive, light blue
    otherwise) followed by an arrow and the signed impact.

    Args:
        sequence: Sequence string to annotate.
        kmers: K-mer labels, index-aligned with ``shap_values``. The k-mer
            length is taken from the first entry (4 if ``kmers`` is empty).
        shap_values: Per-k-mer impact scores.
        base_prob: Probability shown as the "base value" caption.

    Returns:
        The matplotlib Figure containing the visualization.
    """
    max_rows = 30    # only the most impactful occurrences are drawn
    row_step = 0.03  # vertical distance between consecutive rows

    # Derive k from the vocabulary so window size and labels cannot disagree.
    k = len(kmers[0]) if len(kmers) else 4
    kmer_index = {km: i for i, km in enumerate(kmers)}

    # Collect (position, k-mer, impact) for every window that is a known k-mer.
    occurrences = []
    for pos in range(len(sequence) - k + 1):
        window = sequence[pos:pos + k]
        idx = kmer_index.get(window)
        if idx is not None:
            occurrences.append((pos, window, shap_values[idx]))

    # Largest absolute impact first; the sort is stable, so ties keep
    # sequence order.
    occurrences.sort(key=lambda item: abs(item[2]), reverse=True)
    top = occurrences[:max_rows]

    fig, ax = plt.subplots(figsize=(20, max(10, len(top) * 0.3)))
    ax.text(0.01, 1.02, f"base value = {base_prob:.3f}",
            transform=ax.transAxes, fontsize=12)

    # NOTE(review): horizontal offsets are a character-count heuristic, not
    # measured text widths, and the whole sequence is drawn on every row;
    # long sequences will presumably overflow the axes - confirm with
    # realistic input.
    x_scale = len(sequence) * 1.5
    for row, (pos, kmer, impact) in enumerate(top):
        y = 1 - row * row_step
        prefix = sequence[:pos]
        suffix = sequence[pos + k:]

        if impact > 0:
            color, arrow = '#ffcccb', '↑'  # light red
        else:
            color, arrow = '#cce0ff', '↓'  # light blue

        ax.text(0.01, y, prefix, fontsize=10)
        ax.text(0.01 + len(prefix) / x_scale, y, kmer, fontsize=10,
                bbox=dict(facecolor=color, alpha=0.3, pad=2))
        ax.text(0.01 + (len(prefix) + len(kmer)) / x_scale, y, suffix,
                fontsize=10)
        ax.text(0.8, y, f"{arrow} {impact:+.3f}", fontsize=10)

    ax.axis('off')
    fig.tight_layout()
    return fig
169
 
170
  def predict(file_obj, top_kmers=10, fasta_text=""):
171
+ """Main prediction function for Gradio interface."""
 
 
172
  # Handle input
173
  if fasta_text.strip():
174
  text = fasta_text.strip()
175
  elif file_obj is not None:
176
  try:
 
177
  with open(file_obj, 'r') as f:
178
  text = f.read()
179
  except Exception as e:
180
+ return f"Error reading file: {str(e)}", None, None
181
  else:
182
+ return "Please provide a FASTA sequence.", None, None
183
 
184
  # Parse FASTA
185
  sequences = parse_fasta(text)
186
  if not sequences:
187
+ return "No valid FASTA sequences found.", None, None
188
 
189
  header, seq = sequences[0]
190
 
191
+ # Load model and process sequence
192
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
193
  try:
194
  model = VirusClassifier(256).to(device)
 
195
  model.load_state_dict(torch.load('model.pt', map_location=device, weights_only=True))
196
  scaler = joblib.load('scaler.pkl')
197
  except Exception as e:
 
202
  scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
203
  x_tensor = torch.FloatTensor(scaled_vector).to(device)
204
 
205
+ # Calculate SHAP values and get prediction
206
+ shap_values, prob_human = calculate_shap_values(model, x_tensor)
207
 
208
+ # Generate result text
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  results = [
210
  f"Sequence: {header}",
211
+ f"Prediction: {'Human' if prob_human > 0.5 else 'Non-human'} Origin",
212
+ f"Confidence: {max(prob_human, 1-prob_human):.3f}",
213
+ f"Human Probability: {prob_human:.3f}",
214
+ "\nTop Contributing k-mers:"
215
  ]
 
 
 
 
 
 
 
 
216
 
217
+ # Get k-mers for visualization
218
+ kmers = [''.join(p) for p in product("ACGT", repeat=4)]
219
+
220
+ # Create visualizations
221
+ importance_plot = create_importance_bar_plot(shap_values, kmers, top_kmers)
222
+ sequence_plot = visualize_sequence_impacts(seq, kmers, shap_values, prob_human)
223
 
224
  # Convert plots to images
225
  def fig_to_image(fig):
 
230
  plt.close(fig)
231
  return img
232
 
233
+ return "\n".join(results), fig_to_image(importance_plot), fig_to_image(sequence_plot)
234
 
235
  # Create Gradio interface
236
  css = """
237
  .gradio-container {
238
  font-family: 'IBM Plex Sans', sans-serif;
239
  }
 
 
 
 
 
 
240
  """
241
 
242
  with gr.Blocks(css=css) as iface:
243
  gr.Markdown("""
244
  # Virus Host Classifier
245
+ Predicts whether a viral sequence is of human or non-human origin using k-mer analysis.
 
 
 
 
 
246
  """)
247
 
248
  with gr.Row():
 
250
  file_input = gr.File(
251
  label="Upload FASTA file",
252
  file_types=[".fasta", ".fa", ".txt"],
253
+ type="filepath"
254
  )
255
  text_input = gr.Textbox(
256
  label="Or paste FASTA sequence",
 
259
  )
260
  top_k = gr.Slider(
261
  minimum=5,
262
+ maximum=30,
263
  value=10,
264
  step=1,
265
  label="Number of top k-mers to display"
 
268
 
269
  with gr.Column(scale=2):
270
  results = gr.Textbox(label="Analysis Results", lines=10)
271
+ kmer_plot = gr.Image(label="K-mer Importance Plot")
272
+ shap_plot = gr.Image(label="Sequence Impact Visualization (SHAP-style)")
273
 
274
  submit_btn.click(
275
  predict,
276
  inputs=[file_input, top_k, text_input],
277
+ outputs=[results, kmer_plot, shap_plot]
278
  )
279
 
280
  gr.Markdown("""
281
+ ### Visualization Guide
282
+ - **K-mer Importance Plot**: Shows the most influential k-mers and their SHAP values
283
+ - **Sequence Impact Visualization**: Shows the sequence with highlighted k-mers:
284
+ - Red highlights = pushing toward human origin
285
+ - Blue highlights = pushing toward non-human origin
286
+ - Arrows (↑/↓) show impact direction
287
+ - Values show impact magnitude
288
  """)
289
 
290
  if __name__ == "__main__":