hiyata commited on
Commit
4a7c026
·
verified ·
1 Parent(s): 2897f12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -57
app.py CHANGED
@@ -4,6 +4,9 @@ import joblib
4
  import numpy as np
5
  from itertools import product
6
  import torch.nn as nn
 
 
 
7
 
8
  class VirusClassifier(nn.Module):
9
  def __init__(self, input_shape: int):
@@ -83,7 +86,7 @@ def parse_fasta(text):
83
 
84
  def predict(file_obj):
85
  if file_obj is None:
86
- return "Please upload a FASTA file"
87
 
88
  # Read the file content
89
  try:
@@ -92,7 +95,7 @@ def predict(file_obj):
92
  else:
93
  text = file_obj.decode('utf-8')
94
  except Exception as e:
95
- return f"Error reading file: {str(e)}"
96
 
97
  # Generate k-mer dictionary
98
  k = 4 # k-mer size
@@ -114,78 +117,108 @@ def predict(file_obj):
114
  # Set model to evaluation mode
115
  model.eval()
116
  except Exception as e:
117
- return f"Error loading model: {str(e)}\nFull traceback: {str(e.__traceback__)}"
118
 
119
- # Get predictions
120
- results = []
 
 
121
  try:
122
  sequences = parse_fasta(text)
123
- for header, seq in sequences:
124
- # Get raw frequency vector and scaled vector
125
- raw_freq_vector = sequence_to_kmer_vector(seq)
126
- kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
127
- X_tensor = torch.FloatTensor(kmer_vector).to(device)
128
-
129
- # Get predictions and feature importance
130
- with torch.no_grad():
131
- output = model(X_tensor)
132
- probs = torch.softmax(output, dim=1)
133
-
134
- # Calculate feature importance
135
- importance = model.get_feature_importance(X_tensor)
136
- kmer_importance = importance[0].cpu().numpy()
137
-
138
- # Normalize importance scores to original scale
 
 
139
  kmer_importance = kmer_importance / np.max(np.abs(kmer_importance)) * 0.002
140
-
141
- # Get top 10 k-mers based on absolute importance
142
- top_k = 10
143
- top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
144
- important_kmers = [
145
- {
146
- 'kmer': list(kmer_dict.keys())[list(kmer_dict.values()).index(i)],
147
- 'importance': float(kmer_importance[i]),
148
- 'frequency': float(raw_freq_vector[i]),
149
- 'scaled': float(kmer_vector[0][i])
150
- }
151
- for i in top_indices
152
- ]
153
-
154
- # Format results
155
- pred_class = 1 if probs[0][1] > probs[0][0] else 0
156
- pred_label = 'human' if pred_class == 1 else 'non-human'
157
-
158
- result = f"""Sequence: {header}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  Prediction: {pred_label}
160
  Confidence: {float(max(probs[0])):0.4f}
161
  Human probability: {float(probs[0][1]):0.4f}
162
  Non-human probability: {float(probs[0][0]):0.4f}
163
-
164
  Most influential k-mers (ranked by importance):"""
165
-
166
- for kmer in important_kmers:
167
- result += f"\n {kmer['kmer']}: "
168
- result += f"impact={kmer['importance']:.4f}, "
169
- result += f"occurrence={kmer['frequency']*100:.2f}% of sequence "
170
- if kmer['scaled'] > 0:
171
- result += f"(appears {abs(kmer['scaled']):.2f}σ more than average)"
172
- else:
173
- result += f"(appears {abs(kmer['scaled']):.2f}σ less than average)"
174
-
175
- results.append(result)
176
  except Exception as e:
177
- return f"Error processing sequences: {str(e)}"
178
 
179
- return "\n\n".join(results)
180
 
181
- # Create the interface
182
  iface = gr.Interface(
183
  fn=predict,
184
  inputs=gr.File(label="Upload FASTA file", type="binary"),
185
- outputs=gr.Textbox(label="Results"),
186
  title="Virus Host Classifier"
187
  )
188
 
189
  # Launch the interface
190
  if __name__ == "__main__":
191
- iface.launch()
 
4
  import numpy as np
5
  from itertools import product
6
  import torch.nn as nn
7
+ import shap
8
+ import matplotlib.pyplot as plt
9
+ import io
10
 
11
  class VirusClassifier(nn.Module):
12
  def __init__(self, input_shape: int):
 
86
 
87
  def predict(file_obj):
88
  if file_obj is None:
89
+ return "Please upload a FASTA file", None
90
 
91
  # Read the file content
92
  try:
 
95
  else:
96
  text = file_obj.decode('utf-8')
97
  except Exception as e:
98
+ return f"Error reading file: {str(e)}", None
99
 
100
  # Generate k-mer dictionary
101
  k = 4 # k-mer size
 
117
  # Set model to evaluation mode
118
  model.eval()
119
  except Exception as e:
120
+ return f"Error loading model: {str(e)}", None
121
 
122
+ # Initialize variables to store results and plot
123
+ results_text = ""
124
+ plot_image = None
125
+
126
  try:
127
  sequences = parse_fasta(text)
128
+ # For simplicity, process only the first sequence for plotting
129
+ header, seq = sequences[0]
130
+
131
+ # Get raw frequency vector and scaled vector
132
+ raw_freq_vector = sequence_to_kmer_vector(seq)
133
+ kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
134
+ X_tensor = torch.FloatTensor(kmer_vector).to(device)
135
+
136
+ # Get predictions and feature importance
137
+ with torch.no_grad():
138
+ output = model(X_tensor)
139
+ probs = torch.softmax(output, dim=1)
140
+
141
+ importance = model.get_feature_importance(X_tensor)
142
+ kmer_importance = importance[0].cpu().numpy()
143
+
144
+ # Normalize importance scores to original scale
145
+ if np.max(np.abs(kmer_importance)) != 0:
146
  kmer_importance = kmer_importance / np.max(np.abs(kmer_importance)) * 0.002
147
+
148
+ # Get top 10 k-mers based on absolute importance
149
+ top_k = 10
150
+ top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
151
+ important_kmers = [
152
+ {
153
+ 'kmer': list(kmer_dict.keys())[list(kmer_dict.values()).index(i)],
154
+ 'importance': float(kmer_importance[i]),
155
+ 'frequency': float(raw_freq_vector[i]),
156
+ 'scaled': float(kmer_vector[0][i])
157
+ }
158
+ for i in top_indices
159
+ ]
160
+
161
+ # Prepare SHAP-like values for waterfall plot
162
+ top_features = [item['kmer'] for item in important_kmers]
163
+ top_values = [item['importance'] for item in important_kmers]
164
+
165
+ # Combine the rest of the features into an "Others" category
166
+ others_mask = np.ones_like(kmer_importance, dtype=bool)
167
+ others_mask[top_indices] = False
168
+ others_sum = np.sum(kmer_importance[others_mask])
169
+
170
+ top_features.append("Others")
171
+ top_values.append(others_sum)
172
+
173
+ explanation = shap.Explanation(
174
+ values=np.array(top_values),
175
+ base_values=0,
176
+ data=np.array([raw_freq_vector[kmer_dict[feat]] if feat != "Others" else np.sum(raw_freq_vector[others_mask]) for feat in top_features]),
177
+ feature_names=top_features
178
+ )
179
+
180
+ # Generate waterfall plot using SHAP's legacy function
181
+ fig = shap.plots._waterfall.waterfall_legacy(explanation, show=False)
182
+
183
+ # Save plot to a bytes buffer
184
+ buf = io.BytesIO()
185
+ fig.savefig(buf, format='png')
186
+ buf.seek(0)
187
+ plot_image = buf
188
+
189
+ # Format textual results for the first sequence
190
+ pred_class = 1 if probs[0][1] > probs[0][0] else 0
191
+ pred_label = 'human' if pred_class == 1 else 'non-human'
192
+
193
+ results_text += f"""Sequence: {header}
194
  Prediction: {pred_label}
195
  Confidence: {float(max(probs[0])):0.4f}
196
  Human probability: {float(probs[0][1]):0.4f}
197
  Non-human probability: {float(probs[0][0]):0.4f}
 
198
  Most influential k-mers (ranked by importance):"""
199
+
200
+ for kmer in important_kmers:
201
+ results_text += f"\n {kmer['kmer']}: "
202
+ results_text += f"impact={kmer['importance']:.4f}, "
203
+ results_text += f"occurrence={kmer['frequency']*100:.2f}% of sequence "
204
+ if kmer['scaled'] > 0:
205
+ results_text += f"(appears {abs(kmer['scaled']):.2f}σ more than average)"
206
+ else:
207
+ results_text += f"(appears {abs(kmer['scaled']):.2f}σ less than average)"
208
+
 
209
  except Exception as e:
210
+ return f"Error processing sequences: {str(e)}", None
211
 
212
+ return results_text, plot_image
213
 
214
+ # Create the interface with two outputs: Textbox and Image
215
  iface = gr.Interface(
216
  fn=predict,
217
  inputs=gr.File(label="Upload FASTA file", type="binary"),
218
+ outputs=[gr.Textbox(label="Results"), gr.Image(label="SHAP Waterfall Plot")],
219
  title="Virus Host Classifier"
220
  )
221
 
222
  # Launch the interface
223
  if __name__ == "__main__":
224
+ iface.launch()