|
import gradio as gr |
|
from transformers import pipeline, AutoImageProcessor, Swinv2ForImageClassification |
|
from torchvision import transforms |
|
import torch |
|
from PIL import Image |
|
|
|
|
|
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hugging Face Hub checkpoint used for both the image processor and the
# SwinV2 classification weights.
MODEL_ID = "haywoodsloan/ai-image-detector-deploy"

image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = Swinv2ForImageClassification.from_pretrained(MODEL_ID).to(device)

# Image-classification pipeline wrapping the model and processor above;
# predict_image calls it once per uploaded image.
clf = pipeline(
    task="image-classification",
    model=model,
    image_processor=image_processor,
    device=device,
)

# Labels predict_image reads from the pipeline output (missing ones default
# to a score of 0.0 there).
class_names = ['artificial', 'real']
|
|
|
def predict_image(img, confidence_threshold):
    """Classify an image as AI-generated ("artificial") or "real".

    Args:
        img: The image to classify; must be a ``PIL.Image.Image``.
        confidence_threshold: Minimum score (inclusive) the top-scoring class
            must reach for a definite label to be reported.

    Returns:
        ``"Label: <class>, Confidence: <score>"`` (score to 4 decimals) for the
        highest-scoring class when its score is at least
        ``confidence_threshold``; otherwise ``"Uncertain Classification"``.

    Raises:
        ValueError: If ``img`` is not a PIL image (for example when nothing
            was uploaded).
    """
    if not isinstance(img, Image.Image):
        raise ValueError(f"Expected a PIL Image, but got {type(img)}")

    # Normalize to 3-channel RGB; images in any other mode are converted.
    img_pil = img if img.mode == 'RGB' else img.convert('RGB')

    # Fixed-size resize kept from the original preprocessing; the pipeline's
    # image processor then applies its own preprocessing on top.
    # NOTE(review): this ignores aspect ratio — confirm it is intended.
    img_pil = transforms.Resize((256, 256))(img_pil)

    prediction = clf(img_pil)

    # Every expected class starts at 0.0 so a label absent from the pipeline
    # output cannot raise KeyError below; returned scores override defaults.
    scores = dict.fromkeys(class_names, 0.0)
    scores.update({pred['label']: pred['score'] for pred in prediction})

    # Report the highest-scoring class, not the first one that clears the
    # threshold: with a threshold <= 0.5 both classes can clear it and the
    # weaker one must not win. max() keeps the earliest entry of class_names
    # on ties, preserving the original preference for 'artificial'.
    best = max(class_names, key=lambda name: scores[name])
    if scores[best] >= confidence_threshold:
        return f"Label: {best}, Confidence: {scores[best]:.4f}"
    return "Uncertain Classification"
|
|
|
|
|
# Inputs: an upload-only image delivered to predict_image as a PIL image,
# and the threshold slider passed as its second argument.
image = gr.Image(label="Image to Analyze", sources=['upload'], type='pil')
confidence_slider = gr.Slider(
    minimum=0.0,
    maximum=1.0,
    value=0.5,
    step=0.01,
    label="Confidence Threshold",
)

# Output component that displays predict_image's return value.
label = gr.Label(num_top_classes=2)

demo = gr.Interface(
    fn=predict_image,
    inputs=[image, confidence_slider],
    outputs=label,
    title="AI Generated Classification",
)

# Start the Gradio server as soon as this module runs.
demo.launch()