hiyata commited on
Commit
0d6258f
·
verified ·
1 Parent(s): 82425ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -90,7 +90,7 @@ def calculate_shap_values(model, x_tensor):
90
  # Baseline
91
  baseline_output = model(x_tensor)
92
  baseline_probs = torch.softmax(baseline_output, dim=1)
93
- baseline_prob = baseline_probs[0, 1].item() # Probability of 'human' class
94
 
95
  # Zeroing each feature to measure impact
96
  shap_values = []
@@ -100,7 +100,7 @@ def calculate_shap_values(model, x_tensor):
100
  x_zeroed[0, i] = 0.0
101
  output = model(x_zeroed)
102
  probs = torch.softmax(output, dim=1)
103
- prob = probs[0, 1].item()
104
  impact = baseline_prob - prob
105
  shap_values.append(impact)
106
  x_zeroed[0, i] = original_val # restore
@@ -354,6 +354,7 @@ def analyze_sequence(file_path, top_k=10, fasta_text="", window_size=500):
354
  with torch.no_grad():
355
  output = model(x_tensor)
356
  probs = torch.softmax(output, dim=1)
 
357
  pred_human = probs[0, 1].item()
358
 
359
  # Calculate SHAP values
@@ -814,4 +815,4 @@ if __name__ == "__main__":
814
  server_port=7860, # Default Gradio port
815
  show_api=False, # Hide API docs
816
  debug=False # Set to True for debugging
817
- )
 
90
  # Baseline
91
  baseline_output = model(x_tensor)
92
  baseline_probs = torch.softmax(baseline_output, dim=1)
93
+ baseline_prob = baseline_probs[0, 1].item() # Probability of 'human'
94
 
95
  # Zeroing each feature to measure impact
96
  shap_values = []
 
100
  x_zeroed[0, i] = 0.0
101
  output = model(x_zeroed)
102
  probs = torch.softmax(output, dim=1)
103
+ prob = probs[0, 1].item() # Probability of 'human'
104
  impact = baseline_prob - prob
105
  shap_values.append(impact)
106
  x_zeroed[0, i] = original_val # restore
 
354
  with torch.no_grad():
355
  output = model(x_tensor)
356
  probs = torch.softmax(output, dim=1)
357
+ # Using index 1 for probability of human
358
  pred_human = probs[0, 1].item()
359
 
360
  # Calculate SHAP values
 
815
  server_port=7860, # Default Gradio port
816
  show_api=False, # Hide API docs
817
  debug=False # Set to True for debugging
818
+ )