File size: 2,953 Bytes
514b8b1
4df31f3
790f088
514b8b1
dcff825
4df31f3
dcff825
790f088
 
 
 
 
 
 
 
 
 
988ceee
dcff825
a41b014
988ceee
790f088
988ceee
790f088
 
 
 
 
 
 
d5de525
790f088
 
d5de525
790f088
ab3b271
f77a486
 
 
 
 
 
 
790f088
ab3b271
f77a486
 
 
790f088
ab3b271
f77a486
 
 
 
a2a2ce4
f77a486
 
ca68fdd
d5de525
 
790f088
 
 
934fd4a
 
ab3b271
d5de525
f77a486
934fd4a
 
 
 
45892bc
934fd4a
 
 
 
1b54fc6
934fd4a
 
 
 
7c06143
514b8b1
 
 
 
 
 
 
 
 
 
 
 
bbfef86
4df31f3
a41b014
a16e363
610954a
d237a07
52ea34b
bdea63d
4d41f6e
e99084c
f3de939
dcff825
 
988ceee
e305da6
ca0f653
790f088
1aaff03
1b54fc6
ca68fdd
45892bc
cb84f56
9303fde
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Step 1: Load Fine-Tuned Model from Hugging Face Model Hub
# `model` and `processor` are module-level globals used by classify_image().
model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"

print("Initializing the application...")

try:
    print("Loading the model from Hugging Face Model Hub...")
    # NOTE(review): trust_remote_code=True permits executing code shipped with
    # the Hub repository. CLIPModel is a built-in transformers class, so the
    # flag is presumably unnecessary here -- confirm and drop it if so.
    model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
    processor = CLIPProcessor.from_pretrained(model_name)
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"Error loading the model or processor: {e}")
    # Abort start-up: the app cannot serve requests without the model.
    # Chain explicitly so the traceback reports the original failure as the cause.
    raise RuntimeError(f"Failed to load model: {e}") from e

# Step 2: Define the Inference Function
def classify_image(image):
    """
    Classify an image as 'safe' or 'unsafe' and return class probabilities.

    Args:
        image: The uploaded image (a PIL image, given the
            ``gr.Image(type="pil")`` input component), or ``None`` when the
            user submitted without uploading anything.

    Returns:
        dict: ``{"safe": p_safe, "unsafe": p_unsafe}``. The probabilities are
        in [0, 1] and come from a softmax, so they sum to 1 up to rounding.
        ``gr.Label`` expects confidences in this range and renders them as
        percentages itself.

    Raises:
        gr.Error: If no image was provided or preprocessing/inference fails.
            Gradio displays the message to the user. Returning an error dict
            would instead be fed to ``gr.Label`` as a bogus confidence.
    """
    if image is None:
        raise gr.Error("No image provided. Please upload a valid image.")

    # Text prompts scored against the image. Their order defines the column
    # order of logits_per_image: column 0 = safe, column 1 = unsafe.
    categories = ["safe", "unsafe"]

    try:
        inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)

        # Pure inference: skip autograd graph construction to save memory/time.
        with torch.no_grad():
            outputs = model(**inputs)

        # One row per image (a single image here), one column per category.
        probs = outputs.logits_per_image.softmax(dim=1)
    except Exception as e:
        # UI boundary: surface any processing/model failure to the user.
        raise gr.Error(str(e)) from e

    print(f"Softmax probabilities: {probs}")

    # Softmax rows already sum to 1, so no further normalization is needed.
    safe_prob, unsafe_prob = probs[0].tolist()
    return {
        "safe": round(safe_prob, 4),
        "unsafe": round(unsafe_prob, 4),
    }




# Step 3: Build the Gradio UI.
# A PIL image goes in. gr.Label renders each class with a confidence bar, and
# num_top_classes=2 keeps both classes visible.
iface = gr.Interface(
    title="Content Safety Classification",
    description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=2),
)

# Step 4: Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    print("Launching the Gradio interface...")
    iface.launch()