"""Gradio demo that classifies an image as 'safe' or 'unsafe' with a fine-tuned CLIP model."""

import gradio as gr
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
import torch

# Step 1: Load Fine-Tuned Model from Hugging Face Model Hub
model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"

# Text prompts scored against the image. Their order fixes the column order of
# the model's logits_per_image, so it also fixes which probability maps to
# which label in classify_image().
CATEGORIES = ("safe", "unsafe")

print("Loading the fine-tuned model from Hugging Face Model Hub...")
try:
    model = CLIPModel.from_pretrained(model_name, trust_remote_code=True)
    processor = CLIPProcessor.from_pretrained(model_name)
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model or processor: {str(e)}")
    raise

# Make inference mode explicit (disables dropout etc.).
model.eval()


# Step 2: Define the Inference Function
def classify_image(image):
    """Classify an image as 'safe' or 'unsafe'.

    Args:
        image (PIL.Image.Image | None): The uploaded image, as delivered by
            ``gr.Image(type="pil")``; ``None`` when nothing was uploaded.

    Returns:
        dict[str, float]: Probability in [0, 1] for each category, e.g.
        ``{"safe": 0.93, "unsafe": 0.07}``. ``gr.Label`` requires float
        confidences and renders them as percentages itself.

    Raises:
        gr.Error: If no valid image was supplied or inference fails. Gradio
            shows the message to the user instead of a broken label.
    """
    # Validate input up front so the user gets a clear message.
    if image is None:
        raise gr.Error("No image provided. Please upload an image.")
    if not hasattr(image, "convert"):
        raise gr.Error(
            "Uploaded file is not a valid image. Please upload a valid image (JPEG, PNG)."
        )

    try:
        print("Processing the image...")
        # Normalise RGBA / greyscale / palette uploads to 3-channel RGB.
        inputs = processor(
            text=list(CATEGORIES),
            images=image.convert("RGB"),
            return_tensors="pt",
            padding=True,
        )
        print("Inputs processed successfully.")

        # Pure inference: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        print("Model inference completed.")

        # logits_per_image holds image-text similarity scores, one row per image
        # and one column per text prompt; softmax over the prompt axis gives one
        # probability per category for the single input image.
        probs = outputs.logits_per_image.softmax(dim=1)[0]
    except Exception as e:
        print(f"Error during inference: {str(e)}")
        raise gr.Error(f"Error during inference: {e}") from e

    result = {label: prob.item() for label, prob in zip(CATEGORIES, probs)}
    print(", ".join(f"{label.capitalize()}: {prob:.2%}" for label, prob in result.items()))
    return result


# Step 3: Set Up Gradio Interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(label="Output"),  # Renders the {label: probability} dict as percentage bars
    title="Content Safety Classification",
    description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
)

# Step 4: Launch Gradio Interface
if __name__ == "__main__":
    iface.launch()