sguna commited on
Commit
a35a781
·
verified ·
1 Parent(s): c46febd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -26,7 +26,9 @@ def classify_image(image):
26
  raise ValueError("No image provided. Please upload a valid image.")
27
 
28
  # Define categories
29
- categories = ["safe", "unsafe"]
 
 
30
 
31
  # Process the image
32
  inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)
@@ -39,12 +41,17 @@ def classify_image(image):
39
  probs = logits_per_image.softmax(dim=1).detach().numpy() # Convert logits to probabilities
40
 
41
  # Extract probabilities for each category
42
- safe_prob = probs[0][0] # Safe probability
43
  unsafe_prob = probs[0][1] # Unsafe probability
 
 
 
 
 
 
44
 
45
  # Return raw probabilities
46
  return {
47
- "unknown" : 0.2,
48
  "safe": safe_prob, # Leave as a fraction (e.g., 0.92)
49
  "unsafe": unsafe_prob # Leave as a fraction (e.g., 0.08)
50
  }
 
26
  raise ValueError("No image provided. Please upload a valid image.")
27
 
28
  # Define categories
29
+ unsafe_categories = ["hate", "sexual", "violent", "self-harm"]
30
+ safe_categories = ["safe", "retail product"]
31
+ categories = safe_categories + unsafe_categories
32
 
33
  # Process the image
34
  inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)
 
41
  probs = logits_per_image.softmax(dim=1).detach().numpy() # Convert logits to probabilities
42
 
43
  # Extract probabilities for each category
44
+ #safe_prob = probs[0][0] # Safe probability
45
  unsafe_prob = probs[0][1] # Unsafe probability
46
+ safe_prob = sum(value if categories[i] in safe_categories else 0.0 for i, value in enumerate(probs[0]))
47
+ unsafe_prob = sum(value if categories[i] in unsafe_categoriessafe_categories else 0.0 for i, value in enumerate(probs[0]))
48
+
49
+ #debug
50
+ for i, value in enumerate(probs[0]):
51
+ print(categories[i], value)
52
 
53
  # Return raw probabilities
54
  return {
 
55
  "safe": safe_prob, # Leave as a fraction (e.g., 0.92)
56
  "unsafe": unsafe_prob # Leave as a fraction (e.g., 0.08)
57
  }