# Retail content-safety image classifier (fine-tuned CLIP) served via Gradio.
import gradio as gr
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
# Load the fine-tuned CLIP model and its processor once at import time so
# that every classification request reuses the same in-memory weights.
model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"
print("Loading the model and processor...")
try:
    model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
    processor = CLIPProcessor.from_pretrained(model_name)
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"Error loading the model or processor: {e}")
    # Chain the original exception (`from e`) so the real cause — e.g. a
    # network failure or a missing checkpoint — stays in the traceback.
    raise RuntimeError(f"Failed to load model: {e}") from e
# Define the inference function
def classify_image(image):
    """Classify an uploaded image as 'safe' or 'unsafe' using the CLIP model.

    Args:
        image: A PIL image supplied by the Gradio ``gr.Image(type="pil")``
            input; may be ``None`` when nothing was uploaded.

    Returns:
        A tuple ``(predicted_category, scores)`` where ``scores`` maps each
        category name to a float probability in [0, 1] — the format
        ``gr.Label`` expects (it renders the percentages itself).
        On failure, returns an error string and an empty dict instead of
        raising, so the UI shows the message rather than crashing.
    """
    try:
        print("Starting image classification...")
        # Validate the input before handing it to the processor.
        if image is None:
            raise ValueError("No image provided. Please upload a valid image.")
        if not hasattr(image, "convert"):
            raise ValueError("Uploaded file is not a valid image format.")

        # Text prompts; their order fixes the meaning of each logit column.
        categories = ["safe", "unsafe"]
        print(f"Categories: {categories}")

        # Tokenize the prompts and preprocess the image as one batch.
        inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)

        # Forward pass: logits_per_image has one score per text prompt.
        outputs = model(**inputs)

        # Softmax over the text dimension -> one probability per category.
        probs = outputs.logits_per_image.softmax(dim=1)[0]

        # Pair each category with its probability instead of hard-coding
        # tensor indices, so extending `categories` keeps this correct.
        scores = {cat: probs[i].item() for i, cat in enumerate(categories)}
        print(", ".join(f"{cat}: {p:.2%}" for cat, p in scores.items()))

        # NOTE: gr.Label needs raw float confidences, not pre-formatted
        # strings like "95.00%" — passing strings breaks its rendering.
        predicted_category = max(scores, key=scores.get)
        return predicted_category, scores
    except Exception as e:
        # Surface the failure in the UI rather than crashing the app.
        print(f"Error during classification: {e}")
        return f"Error: {str(e)}", {}
# Gradio interface
# Wire classify_image to a simple upload UI: one image input and two
# outputs — the predicted category name and a label widget that shows
# the per-class probabilities returned by the function.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),  # deliver uploads to the fn as PIL images
    outputs=[
        gr.Textbox(label="Predicted Category"),
        gr.Label(label="Probabilities"),
    ],
    title="Content Safety Classification",
    description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
)
# Start the web app only when executed as a script (not when imported).
if __name__ == "__main__":
    print("Launching Gradio interface...")
    iface.launch()