|
import gradio as gr |
|
from transformers import CLIPModel, CLIPProcessor |
|
from PIL import Image |
|
|
|
|
|
# Hugging Face Hub id of the fine-tuned CLIP checkpoint used for classification.
model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"

print("Loading the model and processor...")

# The model and processor are loaded once at import time and shared by every
# classification request; a load failure aborts start-up.
try:
    # NOTE(review): trust_remote_code=True allows code shipped in the Hub repo
    # to execute during loading. CLIPModel is a built-in transformers class, so
    # the flag is likely unnecessary here — confirm and drop it if possible.
    model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
    processor = CLIPProcessor.from_pretrained(model_name)
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"Error loading the model or processor: {e}")
    # Chain explicitly so the underlying loading error is reported as the
    # direct cause of the start-up failure.
    raise RuntimeError(f"Failed to load model: {e}") from e
|
|
|
|
|
def classify_image(image):
    """Classify an image as "safe" or "unsafe" with the fine-tuned CLIP model.

    The image is scored against the two text prompts "safe" and "unsafe" and
    the image-to-text logits are softmaxed into a two-way distribution.

    Args:
        image: The uploaded image (the Gradio input is configured with
            ``type="pil"``). ``None`` or any object without a ``convert``
            method is rejected.

    Returns:
        A ``(predicted_category, probabilities)`` tuple. ``predicted_category``
        is "safe" or "unsafe" (a tie resolves to "unsafe"). ``probabilities``
        maps each category to a float in [0, 1], which is the format
        ``gr.Label`` expects. On any failure, the first element is an
        ``"Error: ..."`` message and the second is an empty dict.
    """
    try:
        print("Starting image classification...")

        if image is None:
            raise ValueError("No image provided. Please upload a valid image.")
        if not hasattr(image, "convert"):
            raise ValueError("Uploaded file is not a valid image format.")

        # torch is already required by the CLIP model; imported here so the
        # validation path above stays dependency-free.
        import torch

        categories = ["safe", "unsafe"]
        print(f"Categories: {categories}")

        # Normalise RGBA / palette / greyscale uploads to 3-channel RGB.
        rgb_image = image.convert("RGB")

        inputs = processor(text=categories, images=rgb_image, return_tensors="pt", padding=True)
        # Log tensor shapes rather than full tensors; pixel_values is large.
        input_shapes = {name: tuple(tensor.shape) for name, tensor in inputs.items()}
        print(f"Processed input shapes: {input_shapes}")

        # Inference only: disable autograd so no computation graph is built.
        with torch.no_grad():
            outputs = model(**inputs)

        # Row 0 is the single image; columns follow the order of `categories`.
        logits_per_image = outputs.logits_per_image
        print(f"Logits per image: {logits_per_image}")

        probs = logits_per_image.softmax(dim=1)
        print(f"Probabilities: {probs}")

        safe_score = probs[0][0].item()
        unsafe_score = probs[0][1].item()
        print(f"Safe: {safe_score * 100:.2f}%, Unsafe: {unsafe_score * 100:.2f}%")

        predicted_category = "safe" if safe_score > unsafe_score else "unsafe"

        # gr.Label requires numeric confidences (it renders the percentages
        # itself); pre-formatted strings fail validation and sort incorrectly.
        return predicted_category, {"safe": safe_score, "unsafe": unsafe_score}

    except Exception as e:
        # UI boundary: surface the failure in the output textbox instead of raising.
        print(f"Error during classification: {e}")
        return f"Error: {str(e)}", {}
|
|
|
|
|
# Web UI wiring: a single PIL image goes in; classify_image's
# (category, probabilities) tuple fills a textbox and a label widget.
_image_input = gr.Image(type="pil")
_category_output = gr.Textbox(label="Predicted Category")
_probability_output = gr.Label(label="Probabilities")

iface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=[_category_output, _probability_output],
    title="Content Safety Classification",
    description=(
        "Upload an image to classify it as 'safe' or 'unsafe' "
        "with corresponding probabilities."
    ),
)
|
|
|
def main():
    """Start the Gradio server hosting the content-safety interface."""
    print("Launching Gradio interface...")
    iface.launch()


if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|