Dileep7729 commited on
Commit
f77a486
·
verified ·
1 Parent(s): d5de525

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -9
app.py CHANGED
@@ -34,20 +34,27 @@ def classify_image(image):
34
  # Run inference
35
  outputs = model(**inputs)
36
 
37
- # Extract logits and apply softmax
38
- logits_per_image = outputs.logits_per_image # Image-text similarity scores
39
- probs = logits_per_image.softmax(dim=1).detach().numpy() # Convert logits to probabilities
 
 
 
 
40
 
41
  # Extract probabilities for each category
42
- safe_prob = probs[0][0] # Safe probability
43
- unsafe_prob = probs[0][1] # Unsafe probability
 
44
 
45
  # Normalize probabilities to ensure they sum to 100%
46
- total = safe_prob + unsafe_prob
47
- safe_percentage = (safe_prob / total) * 100
48
- unsafe_percentage = (unsafe_prob / total) * 100
 
49
 
50
- # Return results as percentages
 
51
  return {
52
  "safe": round(safe_percentage, 2), # Rounded to 2 decimal places
53
  "unsafe": round(unsafe_percentage, 2)
@@ -58,6 +65,7 @@ def classify_image(image):
58
 
59
 
60
 
 
61
  # Step 3: Set Up Gradio Interface
62
  iface = gr.Interface(
63
  fn=classify_image,
 
34
  # Run inference
35
  outputs = model(**inputs)
36
 
37
+ # Extract logits
38
+ logits_per_image = outputs.logits_per_image # Shape: [1, 2]
39
+ print(f"Logits: {logits_per_image}")
40
+
41
+ # Apply softmax to logits to get probabilities
42
+ probs = logits_per_image.softmax(dim=1) # Shape: [1, 2]
43
+ print(f"Softmax probabilities: {probs}")
44
 
45
  # Extract probabilities for each category
46
+ safe_prob = probs[0][0].item() # Extract 'safe' probability
47
+ unsafe_prob = probs[0][1].item() # Extract 'unsafe' probability
48
+ print(f"Safe probability: {safe_prob}, Unsafe probability: {unsafe_prob}")
49
 
50
  # Normalize probabilities to ensure they sum to 100%
51
+ total_prob = safe_prob + unsafe_prob
52
+ print(f"Total probability before normalization: {total_prob}")
53
+ safe_percentage = (safe_prob / total_prob) * 100
54
+ unsafe_percentage = (unsafe_prob / total_prob) * 100
55
 
56
+ # Ensure the sum is exactly 100%
57
+ print(f"Normalized percentages: Safe={safe_percentage}%, Unsafe={unsafe_percentage}%")
58
  return {
59
  "safe": round(safe_percentage, 2), # Rounded to 2 decimal places
60
  "unsafe": round(unsafe_percentage, 2)
 
65
 
66
 
67
 
68
+
69
  # Step 3: Set Up Gradio Interface
70
  iface = gr.Interface(
71
  fn=classify_image,