# Content_safety / app.py
# Last updated by Dileep7729 in commit 52ea34b ("Update app.py", verified); file size 2.84 kB.
import gradio as gr
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
import torch
# Step 1: Load Fine-Tuned Model from Hugging Face Model Hub
# The checkpoint is loaded with the concrete CLIPModel / CLIPProcessor classes.
# `trust_remote_code` is only honoured by the transformers Auto* classes; on a
# concrete class it is ignored (with a warning), so it is not passed here.
model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"
print("Loading the fine-tuned model from Hugging Face Model Hub...")
try:
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    print("Model loaded successfully.")
except Exception as e:
    # Report the failure, then re-raise: the app cannot serve requests
    # without the model and processor, so startup must fail loudly.
    print(f"Error loading model or processor: {str(e)}")
    raise
# Step 2: Define the Inference Function
def classify_image(image):
    """
    Classify an image as 'safe' or 'unsafe'.

    Args:
        image (PIL.Image.Image): The input image, as supplied by the
            ``gr.Image(type="pil")`` input component. May be ``None`` when
            the user submits without uploading anything.

    Returns:
        dict | str: On success, a ``{"safe": p, "unsafe": q}`` mapping of
        probabilities in ``[0, 1]``. This is the ``dict[str, float]`` format
        that ``gr.Label`` renders as labelled confidence bars (it formats
        the percentages itself). On failure, a human-readable error message
        string, which ``gr.Label`` displays as-is.
    """
    try:
        # Validate the input before touching the model.
        if image is None:
            raise ValueError("No image provided. Please upload an image.")
        if not hasattr(image, "convert"):
            raise ValueError(
                "Uploaded file is not a valid image. "
                "Please upload a valid image (JPEG, PNG)."
            )

        # Text prompts scored against the image; their order defines the
        # column order of logits_per_image below.
        main_categories = ["safe", "unsafe"]

        print("Processing the image...")
        inputs = processor(text=main_categories, images=image, return_tensors="pt", padding=True)
        print("Inputs processed successfully.")

        # Inference only: disable autograd so no computation graph is built
        # on every request (saves memory and time).
        with torch.no_grad():
            outputs = model(**inputs)
        print("Model inference completed.")

        # Softmax of the image-text similarity scores over dim=1 gives one
        # probability per prompt; row 0 is the single input image.
        probs = outputs.logits_per_image.softmax(dim=1)[0].tolist()
        scores = dict(zip(main_categories, probs))
        print(f"Safe: {scores['safe'] * 100:.2f}%, Unsafe: {scores['unsafe'] * 100:.2f}%")

        # Return raw float confidences: gr.Label does not accept percentage
        # strings such as "97.31%" as confidence values.
        return scores
    except Exception as e:
        # UI boundary: surface the problem to the user as a message instead
        # of failing the request.
        print(f"Error during inference: {str(e)}")
        return f"Error: {e}"
# Step 3: Set Up Gradio Interface
# Input component: the uploaded image is handed to classify_image as a PIL image.
image_input = gr.Image(type="pil")
# Output component: gr.Label renders classify_image's return value.
label_output = gr.Label(label="Output")

iface = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=label_output,
    title="Content Safety Classification",
    description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
)

# Step 4: Launch Gradio Interface
# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()