File size: 3,305 Bytes
514b8b1
4df31f3
bdea63d
514b8b1
4df31f3
 
bbfef86
36ab993
 
52ea34b
36ab993
52ea34b
 
9303fde
52ea34b
36ab993
9303fde
bbfef86
4df31f3
a41b014
4df31f3
36ab993
4df31f3
 
bdea63d
4df31f3
 
cb84f56
 
4df31f3
d237a07
36ab993
 
 
d237a07
9303fde
36ab993
 
9303fde
36ab993
9303fde
36ab993
d237a07
cb84f56
36ab993
9303fde
36ab993
 
 
 
d237a07
36ab993
d237a07
36ab993
d237a07
 
36ab993
d237a07
36ab993
9303fde
 
d237a07
cb84f56
 
36ab993
cb84f56
 
 
52ea34b
d237a07
36ab993
cb84f56
514b8b1
4df31f3
514b8b1
 
a41b014
cb84f56
 
 
 
a16e363
52ea34b
514b8b1
 
4df31f3
514b8b1
cb84f56
514b8b1
 
 
 
 
 
 
 
 
 
 
 
 
bbfef86
4df31f3
a41b014
a16e363
610954a
d237a07
52ea34b
bdea63d
cb84f56
9303fde
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
from transformers import CLIPModel, CLIPProcessor
from PIL import Image

# Step 1: Load Fine-Tuned Model from Hugging Face Model Hub
model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"

print("Initializing the application...")

try:
    print("Loading the model from Hugging Face Model Hub...")
    # NOTE(review): CLIPModel is a built-in transformers class, so
    # trust_remote_code is probably unnecessary here; where it is honored it
    # lets code from the model repository execute locally. Confirm it is
    # actually needed before keeping it.
    model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
    processor = CLIPProcessor.from_pretrained(model_name)
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"Error loading the model or processor: {e}")
    # Chain the original exception so the underlying download/deserialization
    # traceback is preserved rather than reduced to a one-line message.
    raise RuntimeError(f"Failed to load model: {e}") from e

# Step 2: Define the Inference Function
def classify_image(image):
    """
    Classify an image as 'safe' or 'unsafe' and return probabilities.

    Uses the module-level ``model`` and ``processor`` to score the image
    against the two text prompts "safe" and "unsafe".

    Args:
        image (PIL.Image.Image): Uploaded image.

    Returns:
        tuple[str, dict[str, float]]: The predicted category ("safe" or
        "unsafe") and a ``{"safe": p, "unsafe": q}`` mapping of probabilities
        in [0, 1]. Numeric values are what ``gr.Label`` needs to draw its
        confidence bars. Ties resolve to "unsafe". On any failure the result
        is ``("Error: <message>", {})``.
    """
    try:
        print("Starting image classification...")

        # Reject missing or non-image input before touching the model.
        if image is None:
            raise ValueError("No image provided. Please upload a valid image.")
        if not hasattr(image, "convert"):
            raise ValueError("Uploaded file is not a valid image format.")

        # Text prompts; their order fixes the column order of the logits below.
        categories = ["safe", "unsafe"]
        print(f"Categories: {categories}")

        print("Processing the image with the processor...")
        inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)
        # Log shapes only: printing the full pixel tensor floods the console.
        input_shapes = {name: tuple(tensor.shape) for name, tensor in inputs.items()}
        print(f"Processed input shapes: {input_shapes}")

        print("Running model inference...")
        # transformers needs torch for return_tensors="pt", so it is available.
        import torch

        # Inference only: skip building an autograd graph on every request.
        with torch.no_grad():
            outputs = model(**inputs)

        # Per the CLIP docs, logits_per_image has one row per image and one
        # column per text prompt; softmax across the prompts.
        probs = outputs.logits_per_image.softmax(dim=1)
        print(f"Probabilities: {probs}")

        safe_prob = probs[0][0].item()
        unsafe_prob = probs[0][1].item()

        predicted_category = "safe" if safe_prob > unsafe_prob else "unsafe"
        print(f"Predicted category: {predicted_category}")

        # gr.Label expects numeric confidences in [0, 1]; pre-formatted
        # strings such as "97.00%" break its bars and its class sorting.
        return predicted_category, {"safe": safe_prob, "unsafe": unsafe_prob}

    except Exception as e:
        # UI boundary: report the failure in the output box instead of crashing.
        print(f"Error during classification: {e}")
        return f"Error: {str(e)}", {}

# Step 3: Build the Gradio UI around classify_image.
# One image input (delivered to the function as a PIL image) and two outputs,
# matching the function's (category, probabilities) return pair.
iface = gr.Interface(
    title="Content Safety Classification",
    description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        # First return value: the winning label as plain text.
        gr.Textbox(label="Predicted Category"),
        # Second return value: per-class mapping; gr.Label draws confidence
        # bars when the values are numeric.
        gr.Label(label="Probabilities"),
    ],
)

# Step 4: Serve the app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    print("Launching Gradio interface...")
    iface.launch()