File size: 2,183 Bytes
22e1b62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gradio as gr
import torchvision.transforms as transforms
from CNN_model_classifier import predict_cnn
from diffusion_model_classifier import (
    ImageClassifier,
    predict_single_image,
)

# Checkpoint files consumed by the two detectors below.
cnn_model = "CNN/model_checkpoints/blur_jpg_prob0.5.pth"
diffusion_model = (
    "Diffusion/model_checkpoints/"
    "image-classifier-step=7007-val_loss=0.09.ckpt"
)

# Register the bundled sample-image folder as a Gradio static path.
gr.set_static_paths(paths=["samples/"])


# Loaded diffusion classifiers keyed by checkpoint path, so the checkpoint is
# read from disk once instead of on every prediction request.
_diffusion_model_cache = {}


def _load_diffusion_model(checkpoint_path):
    """Return the classifier for *checkpoint_path*, loading it at most once."""
    try:
        return _diffusion_model_cache[checkpoint_path]
    except KeyError:
        model = ImageClassifier.load_from_checkpoint(checkpoint_path)
        _diffusion_model_cache[checkpoint_path] = model
        return model


def get_prediction_diffusion(image):
    """Score *image* with the diffusion-image classifier.

    Args:
        image: The preprocessed image, passed unchanged to
            ``predict_single_image``.

    Returns:
        A ``(is_match, score)`` tuple, where ``score`` is whatever
        ``predict_single_image`` returns and ``is_match`` is
        ``score >= 0.001``.
    """
    # NOTE(review): reusing one model instance assumes predict_single_image
    # does not alter the model's weights between calls -- confirm.
    model = _load_diffusion_model(diffusion_model)

    prediction = predict_single_image(image, model)
    # Echo the raw score to stdout so it shows up in the application log.
    print(prediction)
    return (prediction >= 0.001, prediction)


def get_prediction_cnn(image):
    """Score *image* with the CNN detector.

    Args:
        image: The preprocessed image, passed unchanged to ``predict_cnn``
            together with the ``cnn_model`` checkpoint path.

    Returns:
        A ``(is_match, score)`` tuple, where ``score`` is the value returned
        by ``predict_cnn`` and ``is_match`` is ``score >= 0.5``.
    """
    score = predict_cnn(image, cnn_model)
    is_match = score >= 0.5
    return (is_match, score)


def predict(inp):
    """Run both detectors on an uploaded image and render an HTML report.

    Args:
        inp: The image to analyse, as delivered by the interface's image
            input (declared with ``type="pil"``), or ``None`` when the
            request carries no image.

    Returns:
        An HTML fragment. For a real image it contains the overall verdict
        as a heading followed by a list with each detector's score, tagged
        ``(MATCH)`` when that detector flagged the image. For ``None`` it
        contains a short message asking for an image.
    """
    # Without an image there is nothing to transform; answer with a message
    # rather than letting the preprocessing step raise.
    if inp is None:
        return (
            "<h1>No image provided</h1>"
            "<p>Please upload an image to analyze.</p>"
        )

    # Define the transformations for the image
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),  # Image size expected by ResNet50
            transforms.ToTensor(),
            # Per-channel statistics; these values match the widely used
            # ImageNet mean and standard deviation.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ],
    )
    image_tensor = transform(inp)
    pred_diff, prob_diff = get_prediction_diffusion(image_tensor)
    pred_cnn, prob_cnn = get_prediction_cnn(image_tensor)

    # A positive result from either detector marks the image as generated.
    verdict = (
        "AI Generated" if (pred_diff or pred_cnn) else "No GenAI detected"
    )
    return (
        f"<h1>{verdict}</h1>"
        f"<ul>"
        f"<li>Diffusion detection score: {prob_diff:.2} "
        f"{'(MATCH)' if pred_diff else ''}</li>"
        f"<li>CNN detection score: {prob_cnn:.1%} "
        f"{'(MATCH)' if pred_cnn else ''}</li>"
        f"</ul>"
    )


# Sample rows offered by the interface: [image path, short description].
_sample_rows = [
    ["samples/fake_dalle.jpg", "Generated (Dall-E)"],
    ["samples/fake_midjourney.png", "Generated (MidJourney)"],
    ["samples/fake_stable.jpg", "Generated (Stable Diffusion)"],
    ["samples/fake_cnn.png", "Generated (GAN)"],
    ["samples/real.png", "Organic"],
]

# Single-function interface: a PIL image goes in, the HTML report comes out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.HTML(),
    examples=_sample_rows,
    title="AI-generated image detection",
    description="Demo by NICT & Tokyo Techies ",
)

# Start serving the interface.
demo.launch()