File size: 2,839 Bytes
514b8b1
4df31f3
52ea34b
 
514b8b1
4df31f3
 
bbfef86
4df31f3
52ea34b
 
 
 
 
 
 
bbfef86
4df31f3
a41b014
4df31f3
d237a07
4df31f3
 
 
 
 
52ea34b
4df31f3
d237a07
52ea34b
d237a07
 
52ea34b
 
d237a07
52ea34b
d237a07
 
52ea34b
d237a07
 
52ea34b
d237a07
52ea34b
d237a07
52ea34b
d237a07
52ea34b
d237a07
 
 
52ea34b
d237a07
 
 
 
 
52ea34b
d237a07
 
 
 
52ea34b
d237a07
 
 
514b8b1
4df31f3
514b8b1
 
a41b014
52ea34b
a16e363
52ea34b
514b8b1
 
4df31f3
514b8b1
 
 
 
 
 
 
 
 
 
 
 
 
 
bbfef86
4df31f3
a41b014
a16e363
610954a
d237a07
52ea34b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
import torch

# Step 1: Load the fine-tuned CLIP model and its processor from the Hugging Face Model Hub.
# A load failure is reported and then re-raised, so the app never starts without a model.
model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"

print("Loading the fine-tuned model from Hugging Face Model Hub...")
try:
    # NOTE(review): trust_remote_code=True allows the Hub repo to supply executable code;
    # confirm it is actually required when loading through the concrete CLIPModel class.
    model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
    processor = CLIPProcessor.from_pretrained(model_name)
except Exception as load_error:
    print(f"Error loading model or processor: {load_error}")
    raise
else:
    print("Model loaded successfully.")

# Step 2: Define the Inference Function
def classify_image(image):
    """
    Classify an image as 'safe' or 'unsafe' with the corresponding percentage.

    The image is scored against the two text prompts "safe" and "unsafe" using the
    module-level CLIP ``model`` and ``processor``. A softmax over the two
    image-text similarity logits gives the probabilities.

    Args:
        image (PIL.Image.Image | None): The input image, as delivered by
            ``gr.Image(type="pil")``.

    Returns:
        dict: ``{"safe": "NN.NN%", "unsafe": "NN.NN%"}`` on success, or
        ``{"Error": <message>}`` if validation or inference fails. Failures are
        reported through the return value; this function does not raise.
    """
    try:
        # Validate the input before touching the model.
        if image is None:
            raise ValueError("No image provided. Please upload an image.")
        if not hasattr(image, "convert"):
            raise ValueError("Uploaded file is not a valid image. Please upload a valid image (JPEG, PNG).")

        # Text prompts; their order fixes the column order of the logits below
        # (column 0 -> "safe", column 1 -> "unsafe").
        main_categories = ["safe", "unsafe"]

        # Process the image
        print("Processing the image...")
        inputs = processor(text=main_categories, images=image, return_tensors="pt", padding=True)
        print("Inputs processed successfully.")

        # Pure inference: disable autograd so no computation graph (and none of the
        # activation memory it retains) is built on every request.
        with torch.no_grad():
            outputs = model(**inputs)
        print("Model inference completed.")

        # logits_per_image: one row per image, one column per text prompt.
        probs = outputs.logits_per_image.softmax(dim=1)

        # Single input image -> row 0; convert to percentages.
        safe_probability = probs[0, 0].item() * 100
        unsafe_probability = probs[0, 1].item() * 100

        print(f"Safe: {safe_probability:.2f}%, Unsafe: {unsafe_probability:.2f}%")

        return {
            "safe": f"{safe_probability:.2f}%",
            "unsafe": f"{unsafe_probability:.2f}%",
        }

    except Exception as e:
        # UI boundary: report any failure to the caller instead of failing the request.
        print(f"Error during inference: {e}")
        return {"Error": str(e)}

# Step 3: Set Up Gradio Interface
# classify_image returns percentage *strings* ({"safe": "97.12%", "unsafe": "2.88%"})
# or {"Error": "<message>"}. gr.Label only accepts numeric confidences, so it cannot
# render those values; gr.JSON displays the returned dict exactly as produced.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.JSON(label="Output"),
    title="Content Safety Classification",
    description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
)

# Step 4: Launch Gradio Interface only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()