sguna commited on
Commit
9c3df37
·
verified ·
1 Parent(s): 32cecd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -4
app.py CHANGED
@@ -41,10 +41,8 @@ def classify_image(image):
41
  probs = logits_per_image.softmax(dim=1).detach().numpy() # Convert logits to probabilities
42
 
43
  # Extract probabilities for each category
44
- safe_prob = probs[0][0] # Safe probability
45
- unsafe_prob = probs[0][1] # Unsafe probability
46
- #safe_prob = sum(value if categories[i] in safe_categories else 0.0 for i, value in enumerate(probs[0]))
47
- #unsafe_prob = sum(value if categories[i] in unsafe_categories else 0.0 for i, value in enumerate(probs[0]))
48
 
49
  #debug
50
  for i, value in enumerate(probs[0]):
 
41
  probs = logits_per_image.softmax(dim=1).detach().numpy() # Convert logits to probabilities
42
 
43
  # Extract probabilities for each category
44
+ safe_prob = sum(value if categories[i] in safe_categories else 0.0 for i, value in enumerate(probs[0]))
45
+ unsafe_prob = sum(value if categories[i] in unsafe_categories else 0.0 for i, value in enumerate(probs[0]))
 
 
46
 
47
  #debug
48
  for i, value in enumerate(probs[0]):