File size: 3,743 Bytes
21bd14f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228987b
 
 
 
 
 
 
 
 
 
 
ff5cbd0
228987b
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
import os
import torch

# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the main classifier: a MobileNetV3-Large created without pretrained
# weights, because every weight is overwritten by the fine-tuned checkpoint
# loaded below.
main_model = models.mobilenet_v3_large(weights=None)

# Replace the Linear layer at classifier[3] with a 2-way Linear layer so the
# architecture matches the checkpoint, which was trained with 2 output classes
# (see classes_name below).
num_ftrs = main_model.classifier[3].in_features
main_model.classifier[3] = nn.Linear(num_ftrs, 2)

# Load the fine-tuned weights from 'best_model3_mobilenetv3_large.pth'.
# map_location remaps stored tensors onto `device` (so a GPU-saved checkpoint
# still loads on a CPU-only host); weights_only=True limits torch.load to
# tensors/primitive containers rather than arbitrary pickled objects.
main_model.load_state_dict(torch.load('best_model3_mobilenetv3_large.pth', map_location=device, weights_only=True))
main_model = main_model.to(device)
# Switch layers such as Dropout/BatchNorm to inference behavior.
main_model.eval()

# Human-readable labels, indexed by the model's output class index. The order
# comes from the training folder structure and must match it exactly.
classes_name = ['AI-generated Image', 'Real Image']

def convert_to_rgb(image):
    """
    Return *image* as a 3-channel 'RGB' PIL image.

    Palette images ('P' / 'PA') are converted to 'RGBA' first, so that palette
    transparency is expanded into a real alpha channel. Converting a 'P' image
    with byte-string transparency directly to 'RGB' makes Pillow emit a
    warning. The alpha channel is then dropped by converting to 'RGB'.

    Every other non-RGB mode ('RGBA', 'L', 'LA', 'CMYK', ...) is converted
    straight to 'RGB'. The downstream Normalize step uses three per-channel
    means/stds, so single- or four-channel tensors would fail there.

    Images already in 'RGB' mode are returned unchanged (same object).

    Args:
        image: A PIL image (any object exposing ``.mode`` and ``.convert``).

    Returns:
        The image in 'RGB' mode.
    """
    if image.mode == 'RGB':
        return image
    if image.mode in ('P', 'PA'):
        # Expand palette transparency into an alpha channel before dropping it.
        image = image.convert('RGBA')
    return image.convert('RGB')

# Inference-time preprocessing. This must stay identical to the pipeline used
# during training:
#   1. normalize the image mode with convert_to_rgb,
#   2. resize to the fixed 224x224 model input (the Gradio component therefore
#      does not request a shape),
#   3. convert to a channels-first float tensor,
#   4. standardize each channel with the ImageNet mean/std.
preprocess = transforms.Compose(
    [
        transforms.Lambda(convert_to_rgb),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def classify_image(image):
    """
    Classify an uploaded image as AI-generated or real.

    Args:
        image: The value delivered by the Gradio ``gr.Image`` input. With the
            component's default ``type`` this is a NumPy array; a PIL image is
            also accepted. Gradio passes ``None`` when the user submits
            without uploading anything.

    Returns:
        A human-readable string with the predicted label and its softmax
        confidence. If ``image`` is None, a prompt asking the user to upload
        an image is returned instead.
    """
    if image is None:
        # Without this guard Image.fromarray(None) raises and the UI shows a
        # generic error instead of a helpful message.
        return "Please upload an image."

    # Gradio hands over a NumPy array by default; wrap it as a PIL image so the
    # torchvision preprocessing pipeline can consume it.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Add a leading batch dimension and move the tensor to the model's device.
    input_batch = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = main_model(input_batch)
        # Softmax over the class dimension of the single item in the batch.
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)
        confidence, predicted_class = torch.max(probabilities, 0)

    # .item() turns the 0-dim tensors into plain Python numbers.
    main_prediction = classes_name[predicted_class.item()]
    main_confidence = confidence.item()

    return f"Image is : {main_prediction} (Confidence: {main_confidence:.4f})"

# Gradio UI. gr.Image keeps its default type, so classify_image receives a
# NumPy array; image_mode="RGB" asks Gradio to deliver RGB pixels. No fixed
# shape is requested because resizing happens in the preprocessing pipeline.
image_input = gr.Image(image_mode="RGB")
output_text = gr.Textbox()

# Keep a module-level handle to the Interface so Gradio tooling (e.g. reload
# mode, which looks for `demo`) can find it; launching stays unconditional.
demo = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=[output_text],
    title="Detect AI-generated Image",
    description=(
        "Upload an art image to detect whether it is AI-generated or a real image. "
        "The training data was collected from the 6 websites listed below.\n\n"
        "### Main Dataset Used:\n"
        "- [AI-generated Images vs Real Images (Kaggle)](https://www.kaggle.com/datasets/tristanzhang32/ai-generated-images-vs-real-images)\n\n"
        "**Fake Images Collected From:**\n"
        "- 10,000 from [Stable Diffusion (OpenArt AI)](https://www.openart.ai)\n"
        "- 10,000 from [MidJourney (Imagine.Art)](https://www.imagine.art)\n"
        "- 10,000 from [DALL·E (OpenAI)](https://openai.com/dall-e-2)\n\n"
        "**Real Images Collected From:**\n"
        "- 22,500 from [Pexels](https://www.pexels.com) and [Unsplash](https://unsplash.com)\n"
        "- 7,500 from [WikiArt](https://www.wikiart.org)\n"
    ),
    theme="default",
)
demo.launch()