Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -90,7 +90,7 @@ def calculate_shap_values(model, x_tensor):
|
|
90 |
# Baseline
|
91 |
baseline_output = model(x_tensor)
|
92 |
baseline_probs = torch.softmax(baseline_output, dim=1)
|
93 |
-
baseline_prob = baseline_probs[0, 1].item() # Probability of 'human'
|
94 |
|
95 |
# Zeroing each feature to measure impact
|
96 |
shap_values = []
|
@@ -100,7 +100,7 @@ def calculate_shap_values(model, x_tensor):
|
|
100 |
x_zeroed[0, i] = 0.0
|
101 |
output = model(x_zeroed)
|
102 |
probs = torch.softmax(output, dim=1)
|
103 |
-
prob = probs[0, 1].item()
|
104 |
impact = baseline_prob - prob
|
105 |
shap_values.append(impact)
|
106 |
x_zeroed[0, i] = original_val # restore
|
@@ -354,6 +354,7 @@ def analyze_sequence(file_path, top_k=10, fasta_text="", window_size=500):
|
|
354 |
with torch.no_grad():
|
355 |
output = model(x_tensor)
|
356 |
probs = torch.softmax(output, dim=1)
|
|
|
357 |
pred_human = probs[0, 1].item()
|
358 |
|
359 |
# Calculate SHAP values
|
@@ -814,4 +815,4 @@ if __name__ == "__main__":
|
|
814 |
server_port=7860, # Default Gradio port
|
815 |
show_api=False, # Hide API docs
|
816 |
debug=False # Set to True for debugging
|
817 |
-
)
|
|
|
90 |
# Baseline
|
91 |
baseline_output = model(x_tensor)
|
92 |
baseline_probs = torch.softmax(baseline_output, dim=1)
|
93 |
+
baseline_prob = baseline_probs[0, 1].item() # Probability of 'human'
|
94 |
|
95 |
# Zeroing each feature to measure impact
|
96 |
shap_values = []
|
|
|
100 |
x_zeroed[0, i] = 0.0
|
101 |
output = model(x_zeroed)
|
102 |
probs = torch.softmax(output, dim=1)
|
103 |
+
prob = probs[0, 1].item() # Probability of 'human'
|
104 |
impact = baseline_prob - prob
|
105 |
shap_values.append(impact)
|
106 |
x_zeroed[0, i] = original_val # restore
|
|
|
354 |
with torch.no_grad():
|
355 |
output = model(x_tensor)
|
356 |
probs = torch.softmax(output, dim=1)
|
357 |
+
# Using index 1 for probability of human
|
358 |
pred_human = probs[0, 1].item()
|
359 |
|
360 |
# Calculate SHAP values
|
|
|
815 |
server_port=7860, # Default Gradio port
|
816 |
show_api=False, # Hide API docs
|
817 |
debug=False # Set to True for debugging
|
818 |
+
)
|