Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -34,20 +34,27 @@ def classify_image(image):
|
|
34 |
# Run inference
|
35 |
outputs = model(**inputs)
|
36 |
|
37 |
-
# Extract logits
|
38 |
-
logits_per_image = outputs.logits_per_image #
|
39 |
-
|
|
|
|
|
|
|
|
|
40 |
|
41 |
# Extract probabilities for each category
|
42 |
-
safe_prob = probs[0][0] #
|
43 |
-
unsafe_prob = probs[0][1] #
|
|
|
44 |
|
45 |
# Normalize probabilities to ensure they sum to 100%
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
49 |
|
50 |
-
#
|
|
|
51 |
return {
|
52 |
"safe": round(safe_percentage, 2), # Rounded to 2 decimal places
|
53 |
"unsafe": round(unsafe_percentage, 2)
|
@@ -58,6 +65,7 @@ def classify_image(image):
|
|
58 |
|
59 |
|
60 |
|
|
|
61 |
# Step 3: Set Up Gradio Interface
|
62 |
iface = gr.Interface(
|
63 |
fn=classify_image,
|
|
|
34 |
# Run inference
|
35 |
outputs = model(**inputs)
|
36 |
|
37 |
+
# Extract logits
|
38 |
+
logits_per_image = outputs.logits_per_image # Shape: [1, 2]
|
39 |
+
print(f"Logits: {logits_per_image}")
|
40 |
+
|
41 |
+
# Apply softmax to logits to get probabilities
|
42 |
+
probs = logits_per_image.softmax(dim=1) # Shape: [1, 2]
|
43 |
+
print(f"Softmax probabilities: {probs}")
|
44 |
|
45 |
# Extract probabilities for each category
|
46 |
+
safe_prob = probs[0][0].item() # Extract 'safe' probability
|
47 |
+
unsafe_prob = probs[0][1].item() # Extract 'unsafe' probability
|
48 |
+
print(f"Safe probability: {safe_prob}, Unsafe probability: {unsafe_prob}")
|
49 |
|
50 |
# Normalize probabilities to ensure they sum to 100%
|
51 |
+
total_prob = safe_prob + unsafe_prob
|
52 |
+
print(f"Total probability before normalization: {total_prob}")
|
53 |
+
safe_percentage = (safe_prob / total_prob) * 100
|
54 |
+
unsafe_percentage = (unsafe_prob / total_prob) * 100
|
55 |
|
56 |
+
# Ensure the sum is exactly 100%
|
57 |
+
print(f"Normalized percentages: Safe={safe_percentage}%, Unsafe={unsafe_percentage}%")
|
58 |
return {
|
59 |
"safe": round(safe_percentage, 2), # Rounded to 2 decimal places
|
60 |
"unsafe": round(unsafe_percentage, 2)
|
|
|
65 |
|
66 |
|
67 |
|
68 |
+
|
69 |
# Step 3: Set Up Gradio Interface
|
70 |
iface = gr.Interface(
|
71 |
fn=classify_image,
|