Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ import torch.nn as nn
|
|
7 |
import shap
|
8 |
import matplotlib.pyplot as plt
|
9 |
import io
|
|
|
10 |
|
11 |
class VirusClassifier(nn.Module):
|
12 |
def __init__(self, input_shape: int):
|
@@ -44,20 +45,15 @@ class VirusClassifier(nn.Module):
|
|
44 |
|
45 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
46 |
"""Convert sequence to k-mer frequency vector"""
|
47 |
-
# Generate all possible k-mers
|
48 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
49 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
50 |
-
|
51 |
-
# Initialize vector
|
52 |
vec = np.zeros(len(kmers), dtype=np.float32)
|
53 |
|
54 |
-
# Count k-mers
|
55 |
for i in range(len(sequence) - k + 1):
|
56 |
kmer = sequence[i:i+k]
|
57 |
if kmer in kmer_dict:
|
58 |
vec[kmer_dict[kmer]] += 1
|
59 |
|
60 |
-
# Convert to frequencies
|
61 |
total_kmers = len(sequence) - k + 1
|
62 |
if total_kmers > 0:
|
63 |
vec = vec / total_kmers
|
@@ -88,7 +84,6 @@ def predict(file_obj):
|
|
88 |
if file_obj is None:
|
89 |
return "Please upload a FASTA file", None
|
90 |
|
91 |
-
# Read the file content
|
92 |
try:
|
93 |
if isinstance(file_obj, str):
|
94 |
text = file_obj
|
@@ -97,43 +92,31 @@ def predict(file_obj):
|
|
97 |
except Exception as e:
|
98 |
return f"Error reading file: {str(e)}", None
|
99 |
|
100 |
-
|
101 |
-
k = 4 # k-mer size
|
102 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
103 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
104 |
|
105 |
-
# Load model and scaler
|
106 |
try:
|
107 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
108 |
-
model = VirusClassifier(256).to(device)
|
109 |
-
|
110 |
-
# Load model with explicit map_location
|
111 |
state_dict = torch.load('model.pt', map_location=device)
|
112 |
model.load_state_dict(state_dict)
|
113 |
-
|
114 |
-
# Load scaler
|
115 |
scaler = joblib.load('scaler.pkl')
|
116 |
-
|
117 |
-
# Set model to evaluation mode
|
118 |
model.eval()
|
119 |
except Exception as e:
|
120 |
return f"Error loading model: {str(e)}", None
|
121 |
|
122 |
-
# Initialize variables to store results and plot
|
123 |
results_text = ""
|
124 |
plot_image = None
|
125 |
|
126 |
try:
|
127 |
sequences = parse_fasta(text)
|
128 |
-
# For simplicity, process only the first sequence for plotting
|
129 |
header, seq = sequences[0]
|
130 |
|
131 |
-
# Get raw frequency vector and scaled vector
|
132 |
raw_freq_vector = sequence_to_kmer_vector(seq)
|
133 |
kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
|
134 |
X_tensor = torch.FloatTensor(kmer_vector).to(device)
|
135 |
|
136 |
-
# Get predictions and feature importance
|
137 |
with torch.no_grad():
|
138 |
output = model(X_tensor)
|
139 |
probs = torch.softmax(output, dim=1)
|
@@ -141,11 +124,9 @@ def predict(file_obj):
|
|
141 |
importance = model.get_feature_importance(X_tensor)
|
142 |
kmer_importance = importance[0].cpu().numpy()
|
143 |
|
144 |
-
# Normalize importance scores to original scale
|
145 |
if np.max(np.abs(kmer_importance)) != 0:
|
146 |
kmer_importance = kmer_importance / np.max(np.abs(kmer_importance)) * 0.002
|
147 |
|
148 |
-
# Get top 10 k-mers based on absolute importance
|
149 |
top_k = 10
|
150 |
top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
|
151 |
important_kmers = [
|
@@ -158,11 +139,9 @@ def predict(file_obj):
|
|
158 |
for i in top_indices
|
159 |
]
|
160 |
|
161 |
-
# Prepare SHAP-like values for waterfall plot
|
162 |
top_features = [item['kmer'] for item in important_kmers]
|
163 |
top_values = [item['importance'] for item in important_kmers]
|
164 |
|
165 |
-
# Combine the rest of the features into an "Others" category
|
166 |
others_mask = np.ones_like(kmer_importance, dtype=bool)
|
167 |
others_mask[top_indices] = False
|
168 |
others_sum = np.sum(kmer_importance[others_mask])
|
@@ -176,19 +155,15 @@ def predict(file_obj):
|
|
176 |
data=np.array([raw_freq_vector[kmer_dict[feat]] if feat != "Others" else np.sum(raw_freq_vector[others_mask]) for feat in top_features]),
|
177 |
feature_names=top_features
|
178 |
)
|
179 |
-
# Manually set expected_value to satisfy waterfall_legacy requirements
|
180 |
explanation.expected_value = 0
|
181 |
|
182 |
-
# Generate waterfall plot using SHAP's legacy function
|
183 |
fig = shap.plots._waterfall.waterfall_legacy(explanation, show=False)
|
184 |
|
185 |
-
# Save plot to a bytes buffer
|
186 |
buf = io.BytesIO()
|
187 |
fig.savefig(buf, format='png')
|
188 |
buf.seek(0)
|
189 |
-
plot_image = buf
|
190 |
|
191 |
-
# Format textual results for the first sequence
|
192 |
pred_class = 1 if probs[0][1] > probs[0][0] else 0
|
193 |
pred_label = 'human' if pred_class == 1 else 'non-human'
|
194 |
|
@@ -213,7 +188,6 @@ Most influential k-mers (ranked by importance):"""
|
|
213 |
|
214 |
return results_text, plot_image
|
215 |
|
216 |
-
# Create the interface with two outputs: Textbox and Image
|
217 |
iface = gr.Interface(
|
218 |
fn=predict,
|
219 |
inputs=gr.File(label="Upload FASTA file", type="binary"),
|
@@ -221,6 +195,5 @@ iface = gr.Interface(
|
|
221 |
title="Virus Host Classifier"
|
222 |
)
|
223 |
|
224 |
-
# Launch the interface
|
225 |
if __name__ == "__main__":
|
226 |
-
iface.launch()
|
|
|
7 |
import shap
|
8 |
import matplotlib.pyplot as plt
|
9 |
import io
|
10 |
+
from PIL import Image # Import PIL for image handling
|
11 |
|
12 |
class VirusClassifier(nn.Module):
|
13 |
def __init__(self, input_shape: int):
|
|
|
45 |
|
46 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
47 |
"""Convert sequence to k-mer frequency vector"""
|
|
|
48 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
49 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
|
|
|
|
50 |
vec = np.zeros(len(kmers), dtype=np.float32)
|
51 |
|
|
|
52 |
for i in range(len(sequence) - k + 1):
|
53 |
kmer = sequence[i:i+k]
|
54 |
if kmer in kmer_dict:
|
55 |
vec[kmer_dict[kmer]] += 1
|
56 |
|
|
|
57 |
total_kmers = len(sequence) - k + 1
|
58 |
if total_kmers > 0:
|
59 |
vec = vec / total_kmers
|
|
|
84 |
if file_obj is None:
|
85 |
return "Please upload a FASTA file", None
|
86 |
|
|
|
87 |
try:
|
88 |
if isinstance(file_obj, str):
|
89 |
text = file_obj
|
|
|
92 |
except Exception as e:
|
93 |
return f"Error reading file: {str(e)}", None
|
94 |
|
95 |
+
k = 4
|
|
|
96 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
97 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
98 |
|
|
|
99 |
try:
|
100 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
101 |
+
model = VirusClassifier(256).to(device)
|
|
|
|
|
102 |
state_dict = torch.load('model.pt', map_location=device)
|
103 |
model.load_state_dict(state_dict)
|
|
|
|
|
104 |
scaler = joblib.load('scaler.pkl')
|
|
|
|
|
105 |
model.eval()
|
106 |
except Exception as e:
|
107 |
return f"Error loading model: {str(e)}", None
|
108 |
|
|
|
109 |
results_text = ""
|
110 |
plot_image = None
|
111 |
|
112 |
try:
|
113 |
sequences = parse_fasta(text)
|
|
|
114 |
header, seq = sequences[0]
|
115 |
|
|
|
116 |
raw_freq_vector = sequence_to_kmer_vector(seq)
|
117 |
kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
|
118 |
X_tensor = torch.FloatTensor(kmer_vector).to(device)
|
119 |
|
|
|
120 |
with torch.no_grad():
|
121 |
output = model(X_tensor)
|
122 |
probs = torch.softmax(output, dim=1)
|
|
|
124 |
importance = model.get_feature_importance(X_tensor)
|
125 |
kmer_importance = importance[0].cpu().numpy()
|
126 |
|
|
|
127 |
if np.max(np.abs(kmer_importance)) != 0:
|
128 |
kmer_importance = kmer_importance / np.max(np.abs(kmer_importance)) * 0.002
|
129 |
|
|
|
130 |
top_k = 10
|
131 |
top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
|
132 |
important_kmers = [
|
|
|
139 |
for i in top_indices
|
140 |
]
|
141 |
|
|
|
142 |
top_features = [item['kmer'] for item in important_kmers]
|
143 |
top_values = [item['importance'] for item in important_kmers]
|
144 |
|
|
|
145 |
others_mask = np.ones_like(kmer_importance, dtype=bool)
|
146 |
others_mask[top_indices] = False
|
147 |
others_sum = np.sum(kmer_importance[others_mask])
|
|
|
155 |
data=np.array([raw_freq_vector[kmer_dict[feat]] if feat != "Others" else np.sum(raw_freq_vector[others_mask]) for feat in top_features]),
|
156 |
feature_names=top_features
|
157 |
)
|
|
|
158 |
explanation.expected_value = 0
|
159 |
|
|
|
160 |
fig = shap.plots._waterfall.waterfall_legacy(explanation, show=False)
|
161 |
|
|
|
162 |
buf = io.BytesIO()
|
163 |
fig.savefig(buf, format='png')
|
164 |
buf.seek(0)
|
165 |
+
plot_image = Image.open(buf) # Convert BytesIO to PIL Image
|
166 |
|
|
|
167 |
pred_class = 1 if probs[0][1] > probs[0][0] else 0
|
168 |
pred_label = 'human' if pred_class == 1 else 'non-human'
|
169 |
|
|
|
188 |
|
189 |
return results_text, plot_image
|
190 |
|
|
|
191 |
iface = gr.Interface(
|
192 |
fn=predict,
|
193 |
inputs=gr.File(label="Upload FASTA file", type="binary"),
|
|
|
195 |
title="Virus Host Classifier"
|
196 |
)
|
197 |
|
|
|
198 |
if __name__ == "__main__":
|
199 |
+
iface.launch(share=True) # Set share=True to create a public link
|