"""Gradio app that classifies an uploaded image as 'safe' or 'unsafe'.

A fine-tuned CLIP checkpoint scores the image against the two text prompts
in ``CATEGORIES``; the softmax over those scores is shown in the UI.
"""

import logging
from typing import Dict, Optional, Tuple

import gradio as gr
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Hub id of the checkpoint to load.
model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"

# Text prompts the image is scored against. The order fixes the column order
# of ``logits_per_image`` (column 0 -> "safe", column 1 -> "unsafe").
CATEGORIES = ["safe", "unsafe"]

# Load the model and processor once at import time so every request reuses them.
logger.info("Loading the model and processor...")
try:
    # NOTE: ``trust_remote_code=True`` was dropped. CLIPModel is a concrete
    # class shipped with transformers, so no Hub-hosted code has to be
    # executed to build it.
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
except Exception as e:
    logger.exception("Error loading the model or processor: %s", e)
    # Chain the cause so the original traceback is preserved.
    raise RuntimeError(f"Failed to load model: {e}") from e
else:
    logger.info("Model and processor loaded successfully.")


def classify_image(image: Optional[Image.Image]) -> Tuple[str, Dict[str, float]]:
    """Classify ``image`` as "safe" or "unsafe".

    Args:
        image: The uploaded image as delivered by ``gr.Image(type="pil")``,
            or ``None`` when nothing was uploaded.

    Returns:
        A ``(predicted_category, confidences)`` pair. ``confidences`` maps
        each category to its softmax probability in ``[0, 1]``, which is the
        ``dict[str, float]`` form ``gr.Label`` expects. On any failure the
        pair is ``("Error: <message>", {})`` so the UI shows the problem
        instead of crashing.
    """
    try:
        logger.info("Starting image classification...")

        # Validate image input
        if image is None:
            raise ValueError("No image provided. Please upload a valid image.")
        if not hasattr(image, "convert"):
            raise ValueError("Uploaded file is not a valid image format.")

        logger.info("Categories: %s", CATEGORIES)

        # Tokenize the prompts and preprocess the image into model tensors.
        inputs = processor(
            text=CATEGORIES, images=image, return_tensors="pt", padding=True
        )
        # Full tensor dumps are only useful when debugging.
        logger.debug("Processed inputs: %s", inputs)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)
        logger.debug("Model outputs: %s", outputs)

        # One image in, so row 0 holds the scores for the two prompts.
        probs = outputs.logits_per_image.softmax(dim=1)
        logger.debug("Probabilities: %s", probs)
        safe_prob, unsafe_prob = probs[0].tolist()
        logger.info(
            "Safe: %.2f%%, Unsafe: %.2f%%", safe_prob * 100, unsafe_prob * 100
        )

        # Ties resolve to "unsafe", the conservative choice.
        predicted_category = "safe" if safe_prob > unsafe_prob else "unsafe"
        return predicted_category, {"safe": safe_prob, "unsafe": unsafe_prob}

    except Exception as e:
        # UI boundary: report the failure to the user rather than raising.
        logger.exception("Error during classification: %s", e)
        return f"Error: {str(e)}", {}


# Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Predicted Category"),
        gr.Label(label="Probabilities"),
    ],
    title="Content Safety Classification",
    description=(
        "Upload an image to classify it as 'safe' or 'unsafe' "
        "with corresponding probabilities."
    ),
)

if __name__ == "__main__":
    logger.info("Launching Gradio interface...")
    iface.launch()