File size: 3,562 Bytes
514b8b1
4df31f3
790f088
 
514b8b1
dcff825
4df31f3
dcff825
790f088
 
 
 
 
 
 
 
 
 
988ceee
dcff825
a41b014
988ceee
790f088
988ceee
 
790f088
dcff825
988ceee
790f088
988ceee
790f088
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514b8b1
dcff825
514b8b1
 
dcff825
790f088
a16e363
dcff825
514b8b1
 
790f088
514b8b1
790f088
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514b8b1
 
 
 
 
 
 
 
 
 
 
 
 
bbfef86
4df31f3
a41b014
a16e363
610954a
d237a07
52ea34b
bdea63d
4d41f6e
e99084c
f3de939
dcff825
 
988ceee
e305da6
ca0f653
790f088
cb84f56
9303fde
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
import requests

# Step 1: Load Fine-Tuned Model from Hugging Face Model Hub
model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"

print("Initializing the application...")

try:
    print("Loading the model from Hugging Face Model Hub...")
    # NOTE(review): trust_remote_code=True allows the Hub repo to execute
    # arbitrary Python at load time. CLIPModel is a built-in transformers
    # class, so this flag is presumably unnecessary — confirm the repo ships
    # no custom modeling code and drop it if so.
    model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
    processor = CLIPProcessor.from_pretrained(model_name)
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"Error loading the model or processor: {e}")
    # The app cannot serve requests without the model, so fail startup.
    # Chain the original exception so its traceback is kept as the cause.
    raise RuntimeError(f"Failed to load model: {e}") from e

# Step 2: Define the Inference Function
def classify_image(image):
    """
    Classify an image as 'safe' or 'unsafe' and return probabilities.

    The image is scored against the text prompts "safe" and "unsafe" using the
    module-level CLIP ``model`` and ``processor``. A softmax over the two
    image-text similarity scores yields the probabilities.

    Args:
        image (PIL.Image.Image): Uploaded image.

    Returns:
        dict: On success, ``{"safe": "<p>%", "unsafe": "<p>%"}`` where each
        value is the probability as a percentage string with two decimals.
        On any failure, ``{"Error": "<message>"}`` — exceptions are caught and
        reported in the return value rather than propagated.
    """
    try:
        print("Starting image classification...")

        # Validate input
        if image is None:
            raise ValueError("No image provided. Please upload a valid image.")

        # Duck-typed check: PIL images expose ``convert``; other objects are rejected.
        if not hasattr(image, "convert"):
            raise ValueError("Invalid image format. Please upload a valid image (JPEG, PNG, etc.).")

        # Imported here because it is only needed for no_grad(); torch is
        # already loaded as a dependency of the transformers CLIP model.
        import torch

        # Text prompts the image is compared against; order fixes the
        # column order of the probabilities below.
        categories = ["safe", "unsafe"]

        print("Processing the image...")
        inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)
        # Log shapes only: printing the full pixel tensor floods stdout on every request.
        input_shapes = {name: tuple(tensor.shape) for name, tensor in inputs.items()}
        print(f"Processed input shapes: {input_shapes}")

        print("Running model inference...")
        # Pure inference: disable autograd so no computation graph is built.
        with torch.no_grad():
            outputs = model(**inputs)

        # Image-text similarity scores, one row per image, one column per prompt.
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)  # Convert logits to probabilities
        print(f"Calculated probabilities: {probs}")

        # Single input image -> row 0; columns follow the order of ``categories``.
        safe_prob = probs[0][0].item() * 100
        unsafe_prob = probs[0][1].item() * 100

        return {
            "safe": f"{safe_prob:.2f}%",
            "unsafe": f"{unsafe_prob:.2f}%"
        }

    except Exception as e:
        # Report the failure to the caller instead of raising.
        print(f"Error during classification: {e}")
        return {"Error": str(e)}

# Step 3: Set Up Gradio Interface
def _classify_for_label(image):
    """
    Adapt ``classify_image`` output to the format ``gr.Label`` renders.

    ``gr.Label`` expects a mapping of label -> float confidence in [0, 1] to
    draw its bars. ``classify_image`` instead returns percentage strings such
    as ``"97.23%"``, and reports failures as ``{"Error": message}``. This
    converts the strings to fractions and surfaces failures as a Gradio error.

    Args:
        image (PIL.Image.Image): Uploaded image, passed through unchanged.

    Returns:
        dict: ``{"safe": float, "unsafe": float}`` with values in [0, 1].

    Raises:
        gr.Error: If classification reported an error.
    """
    result = classify_image(image)
    if "Error" in result:
        raise gr.Error(result["Error"])
    return {label: float(value.rstrip("%")) / 100 for label, value in result.items()}


iface = gr.Interface(
    fn=_classify_for_label,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(label="Output"),  # Display probabilities as progress bars
    title="Content Safety Classification",
    description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
)

# Step 4: Test Before Launch
if __name__ == "__main__":
    import io

    print("Testing model locally with a sample image...")
    try:
        # Smoke-test the pipeline on a known sample image before serving.
        url = "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"
        # Bounded wait so a stalled download cannot hang startup; fail loudly
        # on HTTP errors instead of handing an error page to PIL. Using
        # ``content`` (not ``raw``) applies any transfer/content decoding, and
        # the ``with`` block releases the connection.
        with requests.get(url, timeout=30) as response:
            response.raise_for_status()
            test_image = Image.open(io.BytesIO(response.content))

        print("Running local test...")
        result = classify_image(test_image)
        print(f"Local Test Result: {result}")
    except Exception as e:
        # Best-effort check: report the failure and still launch the UI.
        print(f"Error during local test: {e}")

    # Launch Gradio Interface
    print("Launching the Gradio interface...")
    iface.launch()