File size: 2,504 Bytes
514b8b1
4df31f3
bdea63d
514b8b1
e99084c
4df31f3
f3de939
36ab993
52ea34b
 
 
9303fde
52ea34b
36ab993
f3de939
4d41f6e
e99084c
a41b014
d237a07
36ab993
 
e99084c
d237a07
9303fde
36ab993
 
9303fde
e99084c
9303fde
36ab993
d237a07
cb84f56
9303fde
36ab993
 
 
d237a07
36ab993
d237a07
36ab993
4d41f6e
 
36ab993
d237a07
e99084c
4d41f6e
 
e99084c
d237a07
cb84f56
 
 
52ea34b
d237a07
36ab993
cb84f56
514b8b1
e99084c
514b8b1
 
a41b014
cb84f56
e99084c
 
cb84f56
a16e363
52ea34b
514b8b1
 
 
cb84f56
514b8b1
 
 
 
 
 
 
 
 
 
 
 
 
bbfef86
4df31f3
a41b014
a16e363
610954a
d237a07
52ea34b
bdea63d
4d41f6e
e99084c
f3de939
cb84f56
9303fde
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Load the fine-tuned CLIP model and its processor once at import time so
# every classification request reuses the same in-memory weights.
model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"
print("Loading the model and processor...")

try:
    model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
    processor = CLIPProcessor.from_pretrained(model_name)
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"Error loading the model or processor: {e}")
    # Chain the original exception (`from e`) so the underlying cause and its
    # full traceback are preserved instead of being replaced by the RuntimeError.
    raise RuntimeError(f"Failed to load model: {e}") from e

# Define the inference function
def classify_image(image):
    """Classify an uploaded image as 'safe' or 'unsafe' using the CLIP model.

    Args:
        image: A PIL image from the Gradio ``Image`` component (``None`` when
            the user submits without uploading anything).

    Returns:
        A ``(predicted_category, probabilities)`` tuple, where
        ``probabilities`` maps each category name to a formatted percentage
        string. On any failure, returns ``("Error: ...", {})`` so the UI
        shows the message instead of crashing.
    """
    try:
        print("Starting image classification...")

        # Validate the input before touching the model.
        if image is None:
            raise ValueError("No image provided. Please upload a valid image.")
        if not hasattr(image, "convert"):
            raise ValueError("Uploaded file is not a valid image format.")

        # Candidate text labels CLIP scores against the image.
        categories = ["safe", "unsafe"]
        print(f"Categories: {categories}")

        # Tokenize the labels and preprocess the image into model tensors.
        inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)

        # Inference only: disable autograd to avoid building a gradient
        # graph, saving memory and compute per request.
        with torch.no_grad():
            outputs = model(**inputs)

        # Softmax over the text candidates yields per-category probabilities.
        probs = outputs.logits_per_image.softmax(dim=1)[0]
        print(f"Probabilities: {probs}")

        # Build the result generically from `categories` so adding a label
        # only requires editing the list above.
        percentages = {cat: probs[i].item() * 100 for i, cat in enumerate(categories)}
        print(", ".join(f"{cat.capitalize()}: {pct:.2f}%" for cat, pct in percentages.items()))

        # Iterate in reversed order so an exact tie resolves to the later
        # category ("unsafe"), matching the original strict `>` comparison.
        predicted_category = max(reversed(categories), key=percentages.get)
        return predicted_category, {cat: f"{pct:.2f}%" for cat, pct in percentages.items()}

    except Exception as e:
        # Broad catch is deliberate: Gradio should display an error string
        # to the end user rather than an unhandled stack trace.
        print(f"Error during classification: {e}")
        return f"Error: {str(e)}", {}

# Gradio front end: one image input, two outputs (predicted label + probability map).
_image_input = gr.Image(type="pil")
_category_output = gr.Textbox(label="Predicted Category")
_probability_output = gr.Label(label="Probabilities")

iface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=[_category_output, _probability_output],
    title="Content Safety Classification",
    description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
)

# Entry point: start the Gradio web server only when run as a script,
# not when this module is imported.
if __name__ == "__main__":
    print("Launching Gradio interface...")
    iface.launch()