import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
import os
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the main classifier checkpoint (best_model3_mobilenetv3_large.pth)
main_model = models.mobilenet_v3_large(weights=None)  # architecture only; weights come from the checkpoint below
# MobileNetV3-Large's head is main_model.classifier; classifier[3] is the final Linear layer
num_ftrs = main_model.classifier[3].in_features
main_model.classifier[3] = nn.Linear(num_ftrs, 2)  # 2 classes: AI-generated Image, Real Image
main_model.load_state_dict(
    torch.load('best_model3_mobilenetv3_large.pth', map_location=device, weights_only=True)
)
main_model = main_model.to(device)
main_model.eval()
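# Optional sanity check (a hedged sketch): in torchvision's MobileNetV3-Large,
# classifier[3] is Linear(1280 -> 1000), so num_ftrs should be 1280 and the new
# head maps 1280 features to 2 logits. Uncomment to verify the loaded model:
# with torch.no_grad():
#     dummy = torch.zeros(1, 3, 224, 224, device=device)
#     print(main_model(dummy).shape)  # expected: torch.Size([1, 2])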
# Class names, in the order implied by the training folder structure
classes_name = ['AI-generated Image', 'Real Image']
def convert_to_rgb(image):
    """
    Convert 'P' (palette) and 'RGBA' images to 'RGB' so transparency does not
    interfere with the tensor conversion; other modes pass through unchanged.
    """
    if image.mode in ('P', 'RGBA'):
        return image.convert('RGB')
    return image
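# Illustrative example: a palette-mode image is flattened to plain RGB.
# p_img = Image.new('P', (4, 4))
# assert convert_to_rgb(p_img).mode == 'RGB'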
# Preprocessing transformations (must match those used during training)
preprocess = transforms.Compose([
    transforms.Lambda(convert_to_rgb),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet statistics
])
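# Example (a quick sketch): the pipeline maps any PIL image to a normalized
# float tensor of shape [3, 224, 224], the input the model expects. Uncomment to check:
# sample = Image.new('RGB', (640, 480))
# print(preprocess(sample).shape)  # torch.Size([3, 224, 224])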
def classify_image(image):
    # Gradio delivers the upload as a NumPy array; convert it to a PIL image
    image = Image.fromarray(image)
    # Preprocess and add a batch dimension
    input_image = preprocess(image).unsqueeze(0).to(device)
    # Run inference with the main classifier
    with torch.no_grad():
        output = main_model(input_image)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        confidence, predicted_class = torch.max(probabilities, 0)
    main_prediction = classes_name[predicted_class]
    main_confidence = confidence.item()
    return f"Image is: {main_prediction} (confidence: {main_confidence:.4f})"
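# Example usage outside the UI (the file path 'sample.jpg' is hypothetical):
# import numpy as np
# print(classify_image(np.array(Image.open('sample.jpg'))))
# -> e.g. "Image is: Real Image (confidence: 0.9731)"  # value illustrative only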
# Gradio interface
image_input = gr.Image(image_mode="RGB")
output_text = gr.Textbox()
gr.Interface(
fn=classify_image,
inputs=image_input,
outputs=[output_text],
title="Detect AI-generated Image",
description=(
"Upload an art image From 6 websites, collecting data from this to detect if it's AI-generated or a real image.\n\n"
"### Main Dataset Used:\n"
"- [AI-generated Images vs Real Images (Kaggle)](https://www.kaggle.com/datasets/tristanzhang32/ai-generated-images-vs-real-images)\n\n"
"**Fake Images Collected From:**\n"
"- 10,000 from [Stable Diffusion (OpenArt AI)](https://www.openart.ai)\n"
"- 10,000 from [MidJourney (Imagine.Art)](https://www.imagine.art)\n"
"- 10,000 from [DALL·E (OpenAI)](https://openai.com/dall-e-2)\n\n"
"**Real Images Collected From:**\n"
"- 22,500 from [Pexels](https://www.pexels.com) and [Unsplash](https://unsplash.com)\n"
"- 7,500 from [WikiArt](https://www.wikiart.org)\n"
),
theme="default"
).launch()