Spaces:
Running
Running
File size: 2,183 Bytes
22e1b62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import functools

import gradio as gr
import torchvision.transforms as transforms

from CNN_model_classifier import predict_cnn
from diffusion_model_classifier import (
    ImageClassifier,
    predict_single_image,
)
# Directories Gradio is allowed to serve as static files; the example images
# referenced by the Interface below live under samples/.
_STATIC_DIRS = ["samples/"]
gr.set_static_paths(paths=_STATIC_DIRS)
# Model checkpoint locations. These are relative paths, presumably resolved
# against the directory the app is launched from — confirm for deployments.

# Checkpoint consumed by ImageClassifier.load_from_checkpoint.
diffusion_model = (
    "Diffusion/model_checkpoints/"
    "image-classifier-step=7007-val_loss=0.09.ckpt"
)

# Weights file handed to predict_cnn.
cnn_model = (
    "CNN/model_checkpoints/"
    "blur_jpg_prob0.5.pth"
)
@functools.lru_cache(maxsize=1)
def _load_diffusion_model(checkpoint_path):
    """Load the diffusion-based ImageClassifier from *checkpoint_path* once.

    Later calls with the same path return the already-loaded instance instead
    of re-reading the checkpoint from disk.
    """
    return ImageClassifier.load_from_checkpoint(checkpoint_path)


def get_prediction_diffusion(image, threshold=0.001):
    """Score *image* with the diffusion-based classifier.

    Args:
        image: preprocessed image tensor as produced by ``predict``
            (presumably shape (3, 224, 224) — confirm against
            ``predict_single_image``).
        threshold: score at or above which the image is flagged. The default
            of 0.001 is the value the demo has always used.

    Returns:
        ``(flagged, score)`` where ``flagged`` is ``score >= threshold`` and
        ``score`` is whatever ``predict_single_image`` returns.
    """
    # The original reloaded the checkpoint on every request. That is slow and
    # allocates a new model each time, so reuse a cached instance instead.
    model = _load_diffusion_model(diffusion_model)
    prediction = predict_single_image(image, model)
    print(prediction)  # raw score echoed to stdout (the Space's log)
    return (prediction >= threshold, prediction)
def get_prediction_cnn(image, threshold=0.5):
    """Score *image* with the CNN-based detector.

    Args:
        image: preprocessed image tensor as produced by ``predict``.
        threshold: score at or above which the image is flagged. The default
            of 0.5 is the value the demo has always used.

    Returns:
        ``(flagged, score)`` where ``flagged`` is ``score >= threshold`` and
        ``score`` is whatever ``predict_cnn`` returns.
    """
    # NOTE(review): predict_cnn receives the weights *path*. If it loads the
    # weights on every call, it deserves the same caching as the diffusion
    # model — verify in CNN_model_classifier.
    prediction = predict_cnn(image, cnn_model)
    return (prediction >= threshold, prediction)
# Preprocessing shared by both detectors. It is built once at import time
# rather than on every request.
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # Image size expected by ResNet50
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ],
)


def predict(inp):
    """Run both detectors on *inp* and return an HTML report.

    Args:
        inp: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        An HTML string with the overall verdict followed by one list item per
        detector, each showing its score and a "(MATCH)" marker when it fired.

    Raises:
        gr.Error: when no image was provided, so Gradio shows a readable
            message instead of a stack trace from the transform pipeline.
    """
    if inp is None:
        raise gr.Error("Please upload an image first.")

    # Normalize() carries three-channel statistics. RGBA, grayscale or palette
    # images would not match them, so force three RGB channels first.
    image_tensor = _PREPROCESS(inp.convert("RGB"))

    pred_diff, prob_diff = get_prediction_diffusion(image_tensor)
    pred_cnn, prob_cnn = get_prediction_cnn(image_tensor)

    # Either detector firing is enough to flag the image.
    verdict = "AI Generated" if (pred_diff or pred_cnn) else "No GenAI detected"

    # `.2` renders the diffusion score with two significant digits (general
    # format); the CNN score is rendered as a percentage.
    return (
        f"<h1>{verdict}</h1>"
        "<ul>"
        f"<li>Diffusion detection score: {prob_diff:.2} "
        f"{'(MATCH)' if pred_diff else ''}</li>"
        f"<li>CNN detection score: {prob_cnn:.1%} "
        f"{'(MATCH)' if pred_cnn else ''}</li>"
        "</ul>"
    )
# Example inputs offered in the UI, each paired with a human-readable origin.
# NOTE(review): the Interface has a single input component, so Gradio
# presumably ignores the second (label) column — confirm. Newer Gradio
# releases expose `example_labels` for captions.
_EXAMPLES = [
    ["samples/fake_dalle.jpg", "Generated (Dall-E)"],
    ["samples/fake_midjourney.png", "Generated (MidJourney)"],
    ["samples/fake_stable.jpg", "Generated (Stable Diffusion)"],
    ["samples/fake_cnn.png", "Generated (GAN)"],
    ["samples/real.png", "Organic"],
]

# Single-image-in, HTML-report-out interface wrapping predict().
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.HTML(),
    title="AI-generated image detection",
    description="Demo by NICT & Tokyo Techies ",
    examples=_EXAMPLES,
)

demo.launch()
|