import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the main classifier (best_model3_mobilenetv3_large.pth)
main_model = models.mobilenet_v3_large(weights=None)  # weights=None: our fine-tuned weights are loaded below
num_ftrs = main_model.classifier[3].in_features
main_model.classifier[3] = nn.Linear(num_ftrs, 2)  # 2 classes: AI-generated Image, Real Image
main_model.load_state_dict(torch.load('best_model3_mobilenetv3_large.pth', map_location=device, weights_only=True))
main_model = main_model.to(device)
main_model.eval()
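# Optional sanity check (a sketch, not part of the original pipeline):
# in MobileNetV3-Large, classifier[3] is the final Linear layer (1280 inputs),
# so after the replacement above it should emit 2 logits.
# assert main_model.classifier[3].out_features == 2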
# Class names for the classifier, matching the training dataset's folder structure
classes_name = ['AI-generated Image', 'Real Image']
def convert_to_rgb(image):
    """
    Convert 'P' (palette) and 'RGBA' images to plain 'RGB'.
    This avoids transparency-related issues, matching the handling
    used during model training.
    """
    if image.mode in ('P', 'RGBA'):
        return image.convert('RGB')
    return image
# Preprocessing transforms (the same pipeline used during training)
preprocess = transforms.Compose([
    transforms.Lambda(convert_to_rgb),
    transforms.Resize((224, 224)),  # resize to the model's expected input size
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet statistics
])
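# Note: preprocess(img) returns a float tensor of shape (3, 224, 224);
# unsqueeze(0) in classify_image below adds the batch dimension, giving (1, 3, 224, 224).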
def classify_image(image):
    # Gradio delivers the upload as a NumPy array; convert it to a PIL image
    image = Image.fromarray(image)
    # Preprocess and add a batch dimension
    input_image = preprocess(image).unsqueeze(0).to(device)
    # Run inference with the main classifier
    with torch.no_grad():
        output = main_model(input_image)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        confidence, predicted_class = torch.max(probabilities, 0)
    # Main classifier result
    main_prediction = classes_name[predicted_class]
    main_confidence = confidence.item()
    return f"Image is: {main_prediction} (Confidence: {main_confidence:.4f})"
# Gradio interface
image_input = gr.Image(image_mode="RGB")
output_text = gr.Textbox()
gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=[output_text],
    title="Detect AI-generated Image",
    description=(
        "Upload an art image to detect whether it is AI-generated or real. "
        "The model was trained on images collected from the six sources listed below. "
        "Note: JPG and PNG images only.\n\n"
        "### Main Dataset Used:\n"
        "- [AI-generated Images vs Real Images (Kaggle)](https://www.kaggle.com/datasets/tristanzhang32/ai-generated-images-vs-real-images)\n\n"
        "**Fake Images Collected From:**\n"
        "- 10,000 from [Stable Diffusion (OpenArt AI)](https://www.openart.ai)\n"
        "- 10,000 from [MidJourney (Imagine.Art)](https://www.imagine.art)\n"
        "- 10,000 from [DALL·E (OpenAI)](https://openai.com/dall-e-2)\n\n"
        "**Real Images Collected From:**\n"
        "- 7,500 from [WikiArt](https://www.wikiart.org)\n"
        "- 22,500 from [Pexels](https://www.pexels.com) and [Unsplash](https://unsplash.com)\n"
    ),
    theme="default"
).launch()
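# Note: launch() also accepts share=True to create a temporary public URL,
# e.g. gr.Interface(...).launch(share=True), if you want to share the demo.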