File size: 3,822 Bytes
7f4b1c2
 
 
 
 
 
 
a131fad
7f4b1c2
 
 
 
a131fad
076763e
 
 
a131fad
076763e
 
 
 
 
 
 
 
758ef7d
5b8ca90
a131fad
 
7f4b1c2
a131fad
758ef7d
 
a131fad
 
 
 
 
 
 
 
 
 
7f4b1c2
a131fad
 
7f4b1c2
a131fad
7f4b1c2
 
a131fad
 
68f66aa
a131fad
 
7f4b1c2
 
a131fad
7f4b1c2
a131fad
7f4b1c2
 
 
a131fad
 
 
 
 
7f4b1c2
a131fad
 
7f4b1c2
 
da117ae
 
 
 
 
 
 
 
 
 
 
c7c3a50
da117ae
 
 
 
 
 
 
c7c3a50
 
da117ae
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
import os
import torch

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Class labels, indexed by the classifier's output position.
# NOTE(review): order presumably follows the training folder structure — confirm.
classes_name = ['AI-generated Image', 'Real Image']

# MobileNetV3-Large skeleton without pretrained weights; the real weights
# come from the checkpoint loaded below.
main_model = models.mobilenet_v3_large(weights=None)

# Replace the last classifier layer with a 2-output head
# (one output per entry in classes_name).
num_ftrs = main_model.classifier[3].in_features
main_model.classifier[3] = nn.Linear(num_ftrs, 2)

# Restore the trained parameters. weights_only=True is passed so torch.load
# uses its restricted (weights-only) unpickler for the checkpoint.
main_model.load_state_dict(
    torch.load(
        'best_model3_mobilenetv3_large.pth',
        map_location=device,
        weights_only=True,
    )
)

# Move to the selected device and switch to inference mode.
main_model = main_model.to(device).eval()

def convert_to_rgb(image):
    """
    Return ``image`` as a 3-channel 'RGB' image.

    Images already in 'RGB' mode are returned unchanged. Every other mode
    (palette 'P', 'RGBA', grayscale 'L', 'LA', 'CMYK', ...) is converted with
    ``image.convert('RGB')`` so that the 3-channel normalization applied
    afterwards never receives a 1-, 2- or 4-channel tensor.

    Args:
        image: An object exposing a ``mode`` attribute and a ``convert``
            method (a PIL image in this file).

    Returns:
        The same object when it is already 'RGB', otherwise the result of
        ``image.convert('RGB')``.
    """
    # Previously only 'P' and 'RGBA' were converted; any other non-RGB mode
    # slipped through and failed later in the pipeline.
    if image.mode != 'RGB':
        return image.convert('RGB')
    return image

# Preprocessing pipeline applied to every input image before inference.
# NOTE(review): the original author states this matches the training-time
# preprocessing — confirm against the training script.
preprocess = transforms.Compose(
    [
        # Force 3-channel input first so the later steps see RGB data.
        transforms.Lambda(convert_to_rgb),
        # Fixed 224x224 input size; resizing is done here rather than in the UI.
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        # ImageNet per-channel mean / standard deviation.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def classify_image(image):
    """
    Classify an uploaded image as AI-generated or real.

    Args:
        image: Pixel data accepted by ``PIL.Image.fromarray`` (supplied by the
            Gradio image input), or ``None`` when no image was provided.

    Returns:
        str: The predicted class name with its softmax confidence, or a short
        prompt asking for an image when ``image`` is ``None``.
    """
    # The UI hands over None when the form is submitted without an image;
    # Image.fromarray(None) would raise, so answer with a message instead.
    if image is None:
        return "Please upload an image."

    # Wrap the raw array in a PIL image so the torchvision pipeline can use it.
    pil_image = Image.fromarray(image)

    # Preprocess, add the batch dimension, and move to the model's device.
    input_image = preprocess(pil_image).unsqueeze(0).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        output = main_model(input_image)
        # output[0] is the single batch entry; softmax turns logits into probabilities.
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        confidence, predicted_class = torch.max(probabilities, 0)

    # Convert the 0-dim tensors to plain Python values before indexing/formatting.
    main_prediction = classes_name[predicted_class.item()]
    main_confidence = confidence.item()

    return f"Image is : {main_prediction} (Confidence: {main_confidence:.4f})"

# Gradio UI components: an RGB image input and a plain text output.
image_input = gr.Image(image_mode="RGB")
output_text = gr.Textbox()

# Markdown description shown under the title (dataset and source credits).
_description = (
    "Upload an art image From 6 websites, collecting data from this to detect if it's AI-generated or a real image. take care image jpg or png only.\n\n"
    "### Main Dataset Used:\n"
    "- [AI-generated Images vs Real Images (Kaggle)](https://www.kaggle.com/datasets/tristanzhang32/ai-generated-images-vs-real-images)\n\n"
    "**Fake Images Collected From:**\n"
    "- 10,000 from [Stable Diffusion (OpenArt AI)](https://www.openart.ai)\n"
    "- 10,000 from [MidJourney (Imagine.Art)](https://www.imagine.art)\n"
    "- 10,000 from [DALL·E (OpenAI)](https://openai.com/dall-e-2)\n\n"
    "**Real Images Collected From:**\n"
    "- 7,500 from [WikiArt](https://www.wikiart.org)\n"
    "- 22,500 from [Pexels](https://www.pexels.com) and [Unsplash but take care image jpg or png only ](https://unsplash.com)\n"
)

# Wire classify_image to the components above and start the app.
demo = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=[output_text],
    title="Detect AI-generated Image",
    description=_description,
    theme="default",
)
demo.launch()