Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,9 @@ import joblib
|
|
4 |
import numpy as np
|
5 |
from itertools import product
|
6 |
import torch.nn as nn
|
|
|
|
|
|
|
7 |
|
8 |
class VirusClassifier(nn.Module):
|
9 |
def __init__(self, input_shape: int):
|
@@ -83,7 +86,7 @@ def parse_fasta(text):
|
|
83 |
|
84 |
def predict(file_obj):
|
85 |
if file_obj is None:
|
86 |
-
return "Please upload a FASTA file"
|
87 |
|
88 |
# Read the file content
|
89 |
try:
|
@@ -92,7 +95,7 @@ def predict(file_obj):
|
|
92 |
else:
|
93 |
text = file_obj.decode('utf-8')
|
94 |
except Exception as e:
|
95 |
-
return f"Error reading file: {str(e)}"
|
96 |
|
97 |
# Generate k-mer dictionary
|
98 |
k = 4 # k-mer size
|
@@ -114,78 +117,108 @@ def predict(file_obj):
|
|
114 |
# Set model to evaluation mode
|
115 |
model.eval()
|
116 |
except Exception as e:
|
117 |
-
return f"Error loading model: {str(e)}"
|
118 |
|
119 |
-
#
|
120 |
-
|
|
|
|
|
121 |
try:
|
122 |
sequences = parse_fasta(text)
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
|
|
|
|
139 |
kmer_importance = kmer_importance / np.max(np.abs(kmer_importance)) * 0.002
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
Prediction: {pred_label}
|
160 |
Confidence: {float(max(probs[0])):0.4f}
|
161 |
Human probability: {float(probs[0][1]):0.4f}
|
162 |
Non-human probability: {float(probs[0][0]):0.4f}
|
163 |
-
|
164 |
Most influential k-mers (ranked by importance):"""
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
results.append(result)
|
176 |
except Exception as e:
|
177 |
-
return f"Error processing sequences: {str(e)}"
|
178 |
|
179 |
-
return
|
180 |
|
181 |
-
# Create the interface
|
182 |
iface = gr.Interface(
|
183 |
fn=predict,
|
184 |
inputs=gr.File(label="Upload FASTA file", type="binary"),
|
185 |
-
outputs=gr.Textbox(label="Results"),
|
186 |
title="Virus Host Classifier"
|
187 |
)
|
188 |
|
189 |
# Launch the interface
|
190 |
if __name__ == "__main__":
|
191 |
-
iface.launch()
|
|
|
4 |
import numpy as np
|
5 |
from itertools import product
|
6 |
import torch.nn as nn
|
7 |
+
import shap
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import io
|
10 |
|
11 |
class VirusClassifier(nn.Module):
|
12 |
def __init__(self, input_shape: int):
|
|
|
86 |
|
87 |
def predict(file_obj):
|
88 |
if file_obj is None:
|
89 |
+
return "Please upload a FASTA file", None
|
90 |
|
91 |
# Read the file content
|
92 |
try:
|
|
|
95 |
else:
|
96 |
text = file_obj.decode('utf-8')
|
97 |
except Exception as e:
|
98 |
+
return f"Error reading file: {str(e)}", None
|
99 |
|
100 |
# Generate k-mer dictionary
|
101 |
k = 4 # k-mer size
|
|
|
117 |
# Set model to evaluation mode
|
118 |
model.eval()
|
119 |
except Exception as e:
|
120 |
+
return f"Error loading model: {str(e)}", None
|
121 |
|
122 |
+
# Initialize variables to store results and plot
|
123 |
+
results_text = ""
|
124 |
+
plot_image = None
|
125 |
+
|
126 |
try:
|
127 |
sequences = parse_fasta(text)
|
128 |
+
# For simplicity, process only the first sequence for plotting
|
129 |
+
header, seq = sequences[0]
|
130 |
+
|
131 |
+
# Get raw frequency vector and scaled vector
|
132 |
+
raw_freq_vector = sequence_to_kmer_vector(seq)
|
133 |
+
kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
|
134 |
+
X_tensor = torch.FloatTensor(kmer_vector).to(device)
|
135 |
+
|
136 |
+
# Get predictions and feature importance
|
137 |
+
with torch.no_grad():
|
138 |
+
output = model(X_tensor)
|
139 |
+
probs = torch.softmax(output, dim=1)
|
140 |
+
|
141 |
+
importance = model.get_feature_importance(X_tensor)
|
142 |
+
kmer_importance = importance[0].cpu().numpy()
|
143 |
+
|
144 |
+
# Normalize importance scores to original scale
|
145 |
+
if np.max(np.abs(kmer_importance)) != 0:
|
146 |
kmer_importance = kmer_importance / np.max(np.abs(kmer_importance)) * 0.002
|
147 |
+
|
148 |
+
# Get top 10 k-mers based on absolute importance
|
149 |
+
top_k = 10
|
150 |
+
top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
|
151 |
+
important_kmers = [
|
152 |
+
{
|
153 |
+
'kmer': list(kmer_dict.keys())[list(kmer_dict.values()).index(i)],
|
154 |
+
'importance': float(kmer_importance[i]),
|
155 |
+
'frequency': float(raw_freq_vector[i]),
|
156 |
+
'scaled': float(kmer_vector[0][i])
|
157 |
+
}
|
158 |
+
for i in top_indices
|
159 |
+
]
|
160 |
+
|
161 |
+
# Prepare SHAP-like values for waterfall plot
|
162 |
+
top_features = [item['kmer'] for item in important_kmers]
|
163 |
+
top_values = [item['importance'] for item in important_kmers]
|
164 |
+
|
165 |
+
# Combine the rest of the features into an "Others" category
|
166 |
+
others_mask = np.ones_like(kmer_importance, dtype=bool)
|
167 |
+
others_mask[top_indices] = False
|
168 |
+
others_sum = np.sum(kmer_importance[others_mask])
|
169 |
+
|
170 |
+
top_features.append("Others")
|
171 |
+
top_values.append(others_sum)
|
172 |
+
|
173 |
+
explanation = shap.Explanation(
|
174 |
+
values=np.array(top_values),
|
175 |
+
base_values=0,
|
176 |
+
data=np.array([raw_freq_vector[kmer_dict[feat]] if feat != "Others" else np.sum(raw_freq_vector[others_mask]) for feat in top_features]),
|
177 |
+
feature_names=top_features
|
178 |
+
)
|
179 |
+
|
180 |
+
# Generate waterfall plot using SHAP's legacy function
|
181 |
+
fig = shap.plots._waterfall.waterfall_legacy(explanation, show=False)
|
182 |
+
|
183 |
+
# Save plot to a bytes buffer
|
184 |
+
buf = io.BytesIO()
|
185 |
+
fig.savefig(buf, format='png')
|
186 |
+
buf.seek(0)
|
187 |
+
plot_image = buf
|
188 |
+
|
189 |
+
# Format textual results for the first sequence
|
190 |
+
pred_class = 1 if probs[0][1] > probs[0][0] else 0
|
191 |
+
pred_label = 'human' if pred_class == 1 else 'non-human'
|
192 |
+
|
193 |
+
results_text += f"""Sequence: {header}
|
194 |
Prediction: {pred_label}
|
195 |
Confidence: {float(max(probs[0])):0.4f}
|
196 |
Human probability: {float(probs[0][1]):0.4f}
|
197 |
Non-human probability: {float(probs[0][0]):0.4f}
|
|
|
198 |
Most influential k-mers (ranked by importance):"""
|
199 |
+
|
200 |
+
for kmer in important_kmers:
|
201 |
+
results_text += f"\n {kmer['kmer']}: "
|
202 |
+
results_text += f"impact={kmer['importance']:.4f}, "
|
203 |
+
results_text += f"occurrence={kmer['frequency']*100:.2f}% of sequence "
|
204 |
+
if kmer['scaled'] > 0:
|
205 |
+
results_text += f"(appears {abs(kmer['scaled']):.2f}σ more than average)"
|
206 |
+
else:
|
207 |
+
results_text += f"(appears {abs(kmer['scaled']):.2f}σ less than average)"
|
208 |
+
|
|
|
209 |
except Exception as e:
|
210 |
+
return f"Error processing sequences: {str(e)}", None
|
211 |
|
212 |
+
return results_text, plot_image
|
213 |
|
214 |
+
# Create the interface with two outputs: Textbox and Image
|
215 |
iface = gr.Interface(
|
216 |
fn=predict,
|
217 |
inputs=gr.File(label="Upload FASTA file", type="binary"),
|
218 |
+
outputs=[gr.Textbox(label="Results"), gr.Image(label="SHAP Waterfall Plot")],
|
219 |
title="Virus Host Classifier"
|
220 |
)
|
221 |
|
222 |
# Launch the interface
|
223 |
if __name__ == "__main__":
|
224 |
+
iface.launch()
|