File size: 2,313 Bytes
93f5629 17c8406 93f5629 e88a32d 12cea06 e88a32d 93f5629 a9d7990 433b282 e88a32d 93f5629 a9d7990 f43015f 93f5629 e88a32d 7820a52 12cea06 a9d7990 12cea06 a9d7990 e88a32d bf857a6 e88a32d bf857a6 e88a32d bf857a6 7820a52 a9d7990 12cea06 e88a32d 93f5629 e88a32d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import gradio as gr
from transformers import pipeline, AutoImageProcessor, Swinv2ForImageClassification
from torchvision import transforms
import torch
from PIL import Image
# Checkpoint shared by the image processor and the classifier weights
MODEL_ID = "haywoodsloan/ai-image-detector-deploy"

# Run on CUDA when torch reports it as available; otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Fetch the preprocessing config and the SwinV2 classifier, then move the
# classifier onto the selected device
image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = Swinv2ForImageClassification.from_pretrained(MODEL_ID)
model = model.to(device)

# Bundle model and processor into a single image-classification pipeline
clf = pipeline(
    task="image-classification",
    model=model,
    image_processor=image_processor,
    device=device,
)

# Labels that predict_image looks up in the pipeline output
class_names = ['artificial', 'real']
def predict_image(img, confidence_threshold):
    """Classify an image as AI-generated or real and describe the verdict.

    The image is converted to RGB if needed, resized to 256x256 and passed
    to the module-level ``clf`` pipeline. The verdict is taken from the
    higher-scoring of the 'artificial' and 'real' classes and is reported
    only when that score reaches ``confidence_threshold``.

    Args:
        img: Image to analyse; must be a ``PIL.Image.Image``.
        confidence_threshold: Minimum score the winning class must reach
            for a verdict to be reported.

    Returns:
        A display string: the AI-generated verdict, the real-photo verdict
        (each with the winning score to four decimals), or an "uncertain"
        message when the winning score is below the threshold.

    Raises:
        ValueError: If ``img`` is not a PIL image.
    """
    if not isinstance(img, Image.Image):
        raise ValueError(f"Expected a PIL Image, but got {type(img)}")

    # Convert the image to RGB if not already
    img_pil = img if img.mode == 'RGB' else img.convert('RGB')

    # Resize the image
    img_pil = transforms.Resize((256, 256))(img_pil)

    # Get the prediction and index the scores by label
    prediction = clf(img_pil)
    result = {pred['label']: pred['score'] for pred in prediction}

    # Ensure the result dictionary contains both class names
    for class_name in class_names:
        result.setdefault(class_name, 0.0)

    artificial_score = result['artificial']
    real_score = result['real']

    # Decide on the more likely class first, then apply the threshold.
    # Testing 'artificial' against the threshold before looking at 'real'
    # would report "AI Generated" for any threshold below 0.5 even when
    # 'real' has the higher score. Ties go to 'artificial'.
    if artificial_score >= real_score:
        if artificial_score >= confidence_threshold:
            return f"⚠️ AI Generated Image, Confidence: {artificial_score:.4f}"
    elif real_score >= confidence_threshold:
        return f"✅ Real Photo, Confidence: {real_score:.4f}"

    return "🤷♂️ Uncertain, not confident enough to call."
# Input components: an upload-only image delivered to the callback as a PIL
# image, and a slider for the confidence threshold
image = gr.Image(type='pil', sources=['upload'], label="Image to Analyze")
confidence_slider = gr.Slider(
    minimum=0.0,
    maximum=1.0,
    value=0.5,
    step=0.01,
    label="Confidence Threshold",
)

# Output component that displays the verdict returned by predict_image
label = gr.Label(num_top_classes=2)

# Wire the components to predict_image and start the app
demo = gr.Interface(
    fn=predict_image,
    inputs=[image, confidence_slider],
    outputs=label,
    title="AI Generated Classification",
)
demo.launch()