import gradio as gr
import requests
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Step 1: Load Fine-Tuned Model from Hugging Face Model Hub
model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"

print("Initializing the application...")

try:
    print("Loading the model from Hugging Face Model Hub...")
    model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
    processor = CLIPProcessor.from_pretrained(model_name)
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"Error loading the model or processor: {e}")
    raise RuntimeError(f"Failed to load model: {e}")
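
# Optional sketch (assumption, not part of the original flow): if a CUDA GPU is
# available, the model could be moved to it for faster inference. The processed
# inputs would then also need to be moved to the same device before each
# forward pass, e.g.:
#
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)
# inputs = {k: v.to(device) for k, v in inputs.items()}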

# Step 2: Minimal Test Case to Verify Model and Processor
try:
    print("Running a minimal test case with the model...")

    # Test Image URL
    url = "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"
    image = Image.open(requests.get(url, stream=True).raw)

    # Define test categories
    test_categories = ["safe", "unsafe"]

    # Process the image
    test_inputs = processor(text=test_categories, images=image, return_tensors="pt", padding=True)
    print(f"Test inputs processed: {test_inputs}")

    # Perform inference (gradients are not needed at inference time)
    with torch.no_grad():
        test_outputs = model(**test_inputs)
    print(f"Test outputs: {test_outputs}")

    # Check probabilities
    test_logits = test_outputs.logits_per_image
    test_probs = test_logits.softmax(dim=1)
    print(f"Test probabilities: {test_probs}")

except Exception as e:
    print(f"Error during the minimal test case: {e}")
    raise RuntimeError(f"Test case failed: {e}")

# Step 3: Define the Inference Function
def classify_image(image):
    """
    Classify an image as 'safe' or 'unsafe' and return probabilities.

    Args:
        image (PIL.Image.Image): Uploaded image.
    
    Returns:
        tuple: The predicted category ("safe" or "unsafe") and a dict mapping
        each category to its probability as a float in [0, 1].
    """
    try:
        print("Starting image classification...")

        # Check if the image is valid
        if image is None:
            raise ValueError("No image provided. Please upload a valid image.")
        if not hasattr(image, "convert"):
            raise ValueError("Uploaded file is not a valid image format.")

        # Define main categories
        categories = ["safe", "unsafe"]
        print(f"Categories: {categories}")

        # Process the image
        print("Processing the image with the processor...")
        inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)
        print(f"Processed inputs: {inputs}")

        # Perform inference (gradients are not needed at inference time)
        print("Running model inference...")
        with torch.no_grad():
            outputs = model(**inputs)
        print(f"Model outputs: {outputs}")

        # Calculate probabilities
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)
        print(f"Probabilities: {probs}")

        # Extract the probability for each category
        safe_prob = probs[0][0].item()
        unsafe_prob = probs[0][1].item()

        # Determine the predicted category
        predicted_category = "safe" if safe_prob > unsafe_prob else "unsafe"
        print(f"Predicted category: {predicted_category}")

        # Return the predicted category and the probabilities. gr.Label expects
        # a dict of label -> float confidence, so the values are left as floats
        # in [0, 1] instead of pre-formatted percentage strings.
        return predicted_category, {"safe": safe_prob, "unsafe": unsafe_prob}

    except Exception as e:
        print(f"Error during classification: {e}")
        return f"Error: {str(e)}", {}

# Step 4: Set Up Gradio Interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Predicted Category"),  # Display the predicted category prominently
        gr.Label(label="Probabilities"),        # Display probabilities with a progress bar
    ],
    title="Content Safety Classification",
    description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
)

# Step 5: Launch Gradio Interface
if __name__ == "__main__":
    print("Launching Gradio interface...")
    iface.launch()
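
# Usage sketch (assumption, not part of the app itself): once the interface is
# running on the default local port, it can also be queried programmatically
# with the `gradio_client` package. The api_name below is Gradio's default for
# a single Interface, and newer gradio_client versions may require wrapping the
# file path with `handle_file(...)`.
#
# from gradio_client import Client
# client = Client("http://127.0.0.1:7860/")
# category, probabilities = client.predict("path/to/image.png", api_name="/predict")
# print(category, probabilities)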