"""Gradio app that scores an uploaded image as 'safe' or 'unsafe'.

A CLIP checkpoint is loaded from the Hugging Face Model Hub at import time.
Each uploaded image is compared against a fixed list of text prompts, and the
per-prompt probabilities are summed into a 'safe' and an 'unsafe' score.
"""

from typing import Dict, Optional

import gradio as gr
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Step 1: Load Fine-Tuned Model from Hugging Face Model Hub
model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"

# Text prompts the image is scored against. The order of CATEGORIES fixes the
# column order of the probabilities computed in classify_image().
SAFE_CATEGORIES = ("safe", "retail product")
UNSAFE_CATEGORIES = ("hate", "sexual", "violent", "self-harm")
CATEGORIES = SAFE_CATEGORIES + UNSAFE_CATEGORIES
_SAFE_SET = frozenset(SAFE_CATEGORIES)
_UNSAFE_SET = frozenset(UNSAFE_CATEGORIES)

print("Initializing the application...")

try:
    print("Loading the model from Hugging Face Model Hub...")
    # CLIPModel is a concrete transformers class, so trust_remote_code is not
    # passed: no code from the model repository needs to be executed.
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"Error loading the model or processor: {e}")
    # Chain the original exception so the root cause stays in the traceback.
    raise RuntimeError(f"Failed to load model: {e}") from e


# Step 2: Define the Inference Function
def classify_image(image: Optional[Image.Image]) -> Dict[str, float]:
    """Classify an image as 'safe' or 'unsafe' and return probabilities.

    Args:
        image: The uploaded image (the interface below requests PIL images),
            or None when nothing was uploaded.

    Returns:
        A dict with the keys "safe" and "unsafe". Each value is the sum of
        the softmax probabilities of the prompts in that group, as a
        fraction (e.g. 0.92), so the two values add up to roughly 1.0.

    Raises:
        gr.Error: If no image was provided or inference failed. Gradio shows
            the message to the user; returning an error string inside the
            result dict would hand gr.Label a non-numeric confidence.
    """
    if image is None:
        raise gr.Error("No image provided. Please upload a valid image.")

    try:
        # Tokenize the prompts and preprocess the image in one call.
        inputs = processor(
            text=list(CATEGORIES),
            images=image,
            return_tensors="pt",
            padding=True,
        )

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)

        # logits_per_image holds one score per text prompt for the single
        # input image, i.e. len(CATEGORIES) values in its first row. Softmax
        # turns them into probabilities; tolist() yields plain Python floats.
        probs = outputs.logits_per_image.softmax(dim=1)[0].tolist()
    except Exception as e:
        raise gr.Error(str(e)) from e

    # Debug output: probability of every individual prompt.
    for label, prob in zip(CATEGORIES, probs):
        print(label, prob)

    safe_prob = sum(p for label, p in zip(CATEGORIES, probs) if label in _SAFE_SET)
    unsafe_prob = sum(p for label, p in zip(CATEGORIES, probs) if label in _UNSAFE_SET)

    # Return raw probabilities, left as fractions rather than percentages.
    return {
        "safe": safe_prob,
        "unsafe": unsafe_prob,
    }


# Step 3: Set Up Gradio Interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    # gr.Label displays the probabilities with a bar-style visualization.
    outputs=gr.Label(num_top_classes=2),
    title="Content Safety Classification",
    description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
)

# Step 4: Launch Gradio Interface
if __name__ == "__main__":
    print("Launching the Gradio interface...")
    iface.launch()

    # Save the fine-tuned model.
    # NOTE(review): this runs only after launch() returns, which presumably
    # means after the server has been stopped. The original layout did not
    # show whether these lines were inside the __main__ guard; they are kept
    # here so that importing this module does not write to disk. Confirm.
    model.save_pretrained("fine-tuned-model")
    processor.save_pretrained("fine-tuned-model")
    print("Model and processor saved locally in the 'fine-tuned-model' directory.")