Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -93,7 +93,11 @@ def create_importance_plot(shap_values, kmers, top_k=10):
|
|
93 |
"""
|
94 |
Create horizontal bar plot of feature importance.
|
95 |
"""
|
96 |
-
|
|
|
|
|
|
|
|
|
97 |
fig = plt.figure(figsize=(10, 8))
|
98 |
|
99 |
# Sort by absolute importance
|
@@ -115,8 +119,13 @@ def create_contribution_plot(important_kmers, final_prob):
|
|
115 |
"""
|
116 |
Create waterfall plot showing cumulative feature contributions.
|
117 |
"""
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
base_prob = 0.5
|
122 |
cumulative = [base_prob]
|
@@ -126,15 +135,36 @@ def create_contribution_plot(important_kmers, final_prob):
|
|
126 |
cumulative.append(cumulative[-1] + kmer_info['impact'])
|
127 |
labels.append(kmer_info['kmer'])
|
128 |
|
129 |
-
|
130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
|
|
|
|
|
|
137 |
|
|
|
|
|
138 |
return fig
|
139 |
|
140 |
def predict(file_obj, top_kmers=10, fasta_text=""):
|
@@ -165,7 +195,8 @@ def predict(file_obj, top_kmers=10, fasta_text=""):
|
|
165 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
166 |
try:
|
167 |
model = VirusClassifier(256).to(device)
|
168 |
-
model
|
|
|
169 |
scaler = joblib.load('scaler.pkl')
|
170 |
except Exception as e:
|
171 |
return f"Error loading model: {str(e)}", None, None
|
|
|
93 |
"""
|
94 |
Create horizontal bar plot of feature importance.
|
95 |
"""
|
96 |
+
# Set style directly instead of using seaborn
|
97 |
+
plt.rcParams['figure.facecolor'] = '#ffffff'
|
98 |
+
plt.rcParams['axes.facecolor'] = '#ffffff'
|
99 |
+
plt.rcParams['axes.grid'] = True
|
100 |
+
plt.rcParams['grid.alpha'] = 0.3
|
101 |
fig = plt.figure(figsize=(10, 8))
|
102 |
|
103 |
# Sort by absolute importance
|
|
|
119 |
"""
|
120 |
Create waterfall plot showing cumulative feature contributions.
|
121 |
"""
|
122 |
+
# Set style parameters
|
123 |
+
plt.rcParams['figure.facecolor'] = '#ffffff'
|
124 |
+
plt.rcParams['axes.facecolor'] = '#ffffff'
|
125 |
+
plt.rcParams['axes.grid'] = True
|
126 |
+
plt.rcParams['grid.alpha'] = 0.3
|
127 |
+
|
128 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
129 |
|
130 |
base_prob = 0.5
|
131 |
cumulative = [base_prob]
|
|
|
135 |
cumulative.append(cumulative[-1] + kmer_info['impact'])
|
136 |
labels.append(kmer_info['kmer'])
|
137 |
|
138 |
+
# Plot cumulative line with markers
|
139 |
+
line = ax.plot(range(len(cumulative)), cumulative, '-o',
|
140 |
+
color='#3498db', linewidth=2,
|
141 |
+
marker='o', markersize=8,
|
142 |
+
markerfacecolor='white',
|
143 |
+
markeredgecolor='#3498db',
|
144 |
+
markeredgewidth=2)
|
145 |
+
|
146 |
+
# Add reference line at 0.5
|
147 |
+
ax.axhline(y=0.5, color='#95a5a6', linestyle='--', alpha=0.5)
|
148 |
+
|
149 |
+
# Customize plot
|
150 |
+
ax.set_xticks(range(len(labels)))
|
151 |
+
ax.set_xticklabels(labels, rotation=45, ha='right')
|
152 |
+
ax.set_ylim(0, 1)
|
153 |
+
ax.grid(True, axis='y', linestyle='--', alpha=0.3)
|
154 |
+
ax.set_title('Cumulative Feature Contributions')
|
155 |
+
ax.set_ylabel('Probability of Human Origin')
|
156 |
|
157 |
+
# Add value labels
|
158 |
+
for i, prob in enumerate(cumulative):
|
159 |
+
ax.annotate(f'{prob:.3f}',
|
160 |
+
(i, prob),
|
161 |
+
xytext=(0, 10),
|
162 |
+
textcoords='offset points',
|
163 |
+
ha='center',
|
164 |
+
va='bottom')
|
165 |
|
166 |
+
# Adjust layout to prevent label cutoff
|
167 |
+
plt.tight_layout()
|
168 |
return fig
|
169 |
|
170 |
def predict(file_obj, top_kmers=10, fasta_text=""):
|
|
|
195 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
196 |
try:
|
197 |
model = VirusClassifier(256).to(device)
|
198 |
+
# Load model weights safely
|
199 |
+
model.load_state_dict(torch.load('model.pt', map_location=device, weights_only=True))
|
200 |
scaler = joblib.load('scaler.pkl')
|
201 |
except Exception as e:
|
202 |
return f"Error loading model: {str(e)}", None, None
|