File size: 2,103 Bytes
79958cf
 
 
 
 
 
 
104362c
f88be9d
79958cf
104362c
79958cf
104362c
79958cf
 
 
 
 
 
 
 
 
 
 
 
 
f88be9d
 
79958cf
 
 
 
 
 
 
 
 
 
 
 
 
104362c
61447d3
f88be9d
61447d3
f88be9d
 
104362c
f88be9d
 
 
79958cf
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import gradio as gr
from PIL import Image
import torch
import torchvision.models as models
from torchvision.transforms import v2 as transforms
import os

# Human-readable labels; predict_image indexes this list with the predicted
# class id, so the order must match the model's output classes.
class_names = ['Fake/AI-Generated Image', "Real/Not an AI-Generated Image"]

# Load the model.
# The object returned by torch.load is used directly as a module (``.eval()``
# is called on it), so the checkpoint holds a fully pickled model rather than
# a state_dict. PyTorch 2.6 switched the default of ``weights_only`` to True,
# which refuses to unpickle such checkpoints; pass ``weights_only=False``
# explicitly so loading behaves the same on older and newer releases.
# SECURITY: this unpickles arbitrary Python objects -- only ever point
# ``weights_path`` at a checkpoint file you trust.
weights_path = "FaKe-ViT-B16.pth"
model = torch.load(
    weights_path,
    map_location=torch.device('cpu'),  # run on CPU regardless of where it was saved
    weights_only=False,
)
model.eval()  # put the module into evaluation mode before serving predictions

# Preprocessing applied to every input image before it is fed to the model.
preprocess = transforms.Compose([
    transforms.Resize(256),      # resize, then ...
    transforms.CenterCrop(224),  # ... crop the central 224x224 region
    transforms.ToTensor(),       # PIL image -> float tensor
    # Per-channel normalisation with three-element mean/std (the values
    # commonly used for ImageNet-pretrained models); expects 3 channels.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Define the prediction function
def predict_image(image):
    """Classify an image as AI-generated or real.

    Args:
        image: The image supplied by the Gradio image input (a PIL image),
            or ``None`` when the form is submitted without an image.

    Returns:
        One of the ``class_names`` labels, or a human-readable error message
        when no image was supplied or the image is not a 3-channel RGB image.
    """
    invalid_rgb_message = (
        "Invalid Image: Image should be in RGB format. Please upload a valid image."
    )

    # Submitting without an image yields None; fail gracefully instead of
    # crashing inside the preprocessing pipeline.
    if image is None:
        return "No image provided. Please upload a valid image."

    # Validate the colour mode *before* preprocessing: the Normalize step is
    # configured with 3-channel statistics, so a channel check performed only
    # afterwards cannot reliably catch non-RGB inputs.
    if getattr(image, "mode", "RGB") != "RGB":
        return invalid_rgb_message

    tensor = preprocess(image)
    # Safety net for inputs that carry no PIL ``mode`` attribute.
    if tensor.shape[0] != 3:
        return invalid_rgb_message

    batch = tensor.unsqueeze(0)  # add the leading batch dimension
    with torch.inference_mode():
        logits = model(batch)

    # softmax is monotonic, so the arg-max of the raw logits already
    # identifies the predicted class.
    predicted = logits.argmax(dim=1).item()
    return class_names[predicted]



# Gradio UI: a single RGB image input (delivered to predict_image as a PIL
# image) and a plain-text output showing the predicted label.
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(image_mode="RGB", type="pil"),
    outputs="text",
    flagging_options=["incorrect prediction"],
    # One path per example; with a single input component a flat list of
    # values is sufficient.
    examples=[
        "images/cheetah.jpg",
        "images/cat.jpg",
        "images/astronaut.jpg",
        "images/mountain.jpg",
        "images/unicorn.jpg",
    ],
    title="<u>FaKe-ViT-B/16: Robust and Fast AI-Generated Image Detection using Vision Transformer(ViT-B/16):</u>",
    description="<p style='font-size: 20px;'>This is a demo to detect AI-Generated images using a fine-tuned Vision Transformer(ViT-B/16). Upload an image and the model will predict whether the image is AI-Generated or Real</p>",
    article="<p style='font-size: 20px;'><b>Paper</b>: 'An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale', Dosovitskiy et al.<br/><b>Dataset</b>: 'Fake or Real competition dataset' at <a href='https://huggingface.co/datasets/mncai/Fake_or_Real_Competition_Dataset'>Fake or Real competition dataset</a></p>",
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()