File size: 2,313 Bytes
93f5629
17c8406
93f5629
e88a32d
12cea06
e88a32d
 
 
93f5629
a9d7990
433b282
 
e88a32d
 
93f5629
a9d7990
f43015f
93f5629
e88a32d
7820a52
 
 
 
 
 
 
 
 
 
 
12cea06
a9d7990
 
12cea06
a9d7990
 
 
 
 
 
 
 
 
e88a32d
 
bf857a6
e88a32d
bf857a6
e88a32d
bf857a6
7820a52
a9d7990
12cea06
e88a32d
93f5629
e88a32d
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr
from transformers import pipeline, AutoImageProcessor, Swinv2ForImageClassification
from torchvision import transforms
import torch
from PIL import Image

# Run on CUDA when torch reports it as available; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Checkpoint shared by the image processor and the classification model.
_MODEL_ID = "haywoodsloan/ai-image-detector-deploy"

# Load the processor and the model, move the model onto the chosen device,
# and wrap both in an image-classification pipeline.
image_processor = AutoImageProcessor.from_pretrained(_MODEL_ID)
model = Swinv2ForImageClassification.from_pretrained(_MODEL_ID).to(device)
clf = pipeline(
    task="image-classification",
    model=model,
    image_processor=image_processor,
    device=device,
)

# Labels that predict_image expects to find in the classifier output.
class_names = ['artificial', 'real']

def predict_image(img, confidence_threshold):
    """Classify an image as AI-generated or real and return a verdict string.

    The image is converted to RGB if needed, resized to 256x256 and passed to
    the module-level ``clf`` pipeline. The higher-scoring of the two classes
    in ``class_names`` is reported, provided its score reaches
    ``confidence_threshold``.

    Args:
        img: Image to analyze; must be a ``PIL.Image.Image``.
        confidence_threshold: Minimum score the winning class must reach for
            a verdict to be reported.

    Returns:
        One of three human-readable strings: an "AI Generated Image" verdict
        with its score, a "Real Photo" verdict with its score, or an
        "Uncertain" message when the winning class is below the threshold.

    Raises:
        ValueError: If ``img`` is not a PIL image (this includes ``None``).
    """
    print(f"Type of img: {type(img)}")  # Debugging statement
    if not isinstance(img, Image.Image):
        raise ValueError(f"Expected a PIL Image, but got {type(img)}")

    # Convert the image to RGB if not already
    img_pil = img if img.mode == 'RGB' else img.convert('RGB')

    # Resize the image to a fixed 256x256 before classification
    img_pil = transforms.Resize((256, 256))(img_pil)

    # Get the prediction: an iterable of {'label': ..., 'score': ...} dicts
    prediction = clf(img_pil)

    # Start every expected class at 0.0 so a label missing from the
    # prediction still has a score, then overlay the actual scores.
    result = {class_name: 0.0 for class_name in class_names}
    result.update({pred['label']: pred['score'] for pred in prediction})

    artificial_score = result['artificial']
    real_score = result['real']

    # Pick the stronger class first, then apply the threshold. Checking
    # 'artificial' against the threshold alone would report "AI Generated"
    # whenever the threshold is low enough for both classes to pass, even
    # when 'real' has the higher score. Ties go to 'artificial'.
    if artificial_score >= real_score:
        if artificial_score >= confidence_threshold:
            return f"⚠️ AI Generated Image, Confidence: {artificial_score:.4f}"
    elif real_score >= confidence_threshold:
        return f"✅ Real Photo, Confidence: {real_score:.4f}"

    return "🤷‍♂️ Uncertain, not confident enough to call."
        
# Build the Gradio UI: an uploaded PIL image plus a threshold slider go in,
# and the verdict from predict_image is shown in a label component.
image = gr.Image(
    type='pil',  # predict_image rejects anything that is not a PIL image
    sources=['upload'],
    label="Image to Analyze",
)
confidence_slider = gr.Slider(
    minimum=0.0,
    maximum=1.0,
    value=0.5,
    step=0.01,
    label="Confidence Threshold",
)
label = gr.Label(num_top_classes=2)

# Wire the components to the prediction function and start the app.
demo = gr.Interface(
    title="AI Generated Classification",
    fn=predict_image,
    inputs=[image, confidence_slider],
    outputs=label,
)
demo.launch()