import gradio as gr
import numpy as np
import torch
import random
from PIL import Image
from skimage.feature import graycomatrix, graycoprops
from torchvision import transforms
import os
# Number of images (rounds) per game; load_images splits them between the
# real and fake image folders.
NUM_ROUNDS = 10
# submit_guess makes the model answer "Real" when its mean 'Authentic'
# probability is strictly greater than this value, otherwise "Fake".
PROB_THRESHOLD = 0.3
# Load the SuSy TorchScript model once at import time. map_location="cpu"
# pins every weight to the CPU: inference is fed CPU tensors and the output is
# converted with .numpy(), so a checkpoint saved from a GPU would otherwise
# fail to load on CPU-only hosts or hit a device mismatch at inference.
model = torch.jit.load("SuSy.pt", map_location="cpu")
def process_image(image):
    """Classify a PIL image with the SuSy model using its most textured patches.

    The image is tiled into non-overlapping 224x224 crops; any remainder at the
    right/bottom edges is discarded. Each crop is scored by the GLCM contrast of
    its grayscale version (distance 5, angle 0, 256 levels, symmetric,
    normalized). The 5 highest-contrast crops are batched through the global
    `model`, and the per-class outputs are averaged over that batch.

    NOTE(review): the averaged outputs are treated as class probabilities,
    which assumes the TorchScript model already ends in a softmax — confirm.

    Args:
        image: A PIL image of any mode; it is converted to RGB before tiling.

    Returns:
        dict mapping each class name to its mean model output (numpy scalar),
        ordered from highest to lowest value.

    Raises:
        ValueError: if the image is smaller than one patch in either dimension.
    """
    top_k_patches = 5
    patch_size = 224

    # The patch buffer and the model expect 3 channels. RGBA/L/P images (common
    # for PNGs) would otherwise not fit into a (..., 3) array.
    image = image.convert("RGB")
    width, height = image.size
    num_patches_x = width // patch_size
    num_patches_y = height // patch_size
    if num_patches_x == 0 or num_patches_y == 0:
        # With no full patch the model would see an empty batch, and averaging
        # over it would silently produce NaN probabilities.
        raise ValueError(
            f"Image of size {width}x{height} is smaller than one "
            f"{patch_size}x{patch_size} patch"
        )

    # Tile column by column (x outer, y inner), i.e. patch index
    # i * num_patches_y + j, shape (num_patches, patch_size, patch_size, 3).
    patches = np.stack([
        np.asarray(image.crop((x, y, x + patch_size, y + patch_size)))
        for x in range(0, num_patches_x * patch_size, patch_size)
        for y in range(0, num_patches_y * patch_size, patch_size)
    ])

    # Score every patch by GLCM contrast; the transform is built once, not per patch.
    to_grayscale = transforms.Compose([transforms.PILToTensor(), transforms.Grayscale()])
    contrast_scores = []
    for patch in patches:
        grayscale_patch = to_grayscale(Image.fromarray(patch)).squeeze(0).numpy()
        glcm = graycomatrix(grayscale_patch, [5], [0], 256, symmetric=True, normed=True)
        contrast_scores.append(graycoprops(glcm, "contrast")[0, 0])

    # Keep the highest-contrast patches (descending order).
    ranked_indices = np.argsort(contrast_scores)[::-1]
    top_patches = patches[ranked_indices[:top_k_patches]]
    # NHWC uint8 -> NCHW float scaled to [0, 1].
    batch = torch.from_numpy(np.transpose(top_patches, (0, 3, 1, 2))) / 255.0

    model.eval()
    with torch.no_grad():
        preds = model(batch)

    classes = ['Authentic', 'DALL·E 3', 'Stable Diffusion 1.x', 'MJ V5/V6', 'MJ V1/V2', 'Stable Diffusion XL']
    # Average over the patch batch to get one score per class.
    mean_probs = preds.mean(dim=0).numpy()
    class_probs = dict(zip(classes, mean_probs))
    return dict(sorted(class_probs.items(), key=lambda item: item[1], reverse=True))
class GameState:
    """Progress and scores for a single human-vs-SuSy game."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Put every field back to its start-of-game value."""
        # Points earned by the player and by the model.
        self.user_score = 0
        self.model_score = 0
        # Zero-based index of the round to be played next.
        self.current_round = 0
        self.total_rounds = NUM_ROUNDS
        # Image paths for this game; empty until a game is started.
        self.game_images = []
        self.is_game_active = False
        # Guesses and correct answer from the most recent round, or None.
        self.last_results = None
        self.waiting_for_input = True

    def get_game_over_message(self):
        """Return the end-of-game text for a player win, model win, or tie."""
        margin = self.user_score - self.model_score
        if margin > 0:
            return """
🎉 Congratulations! You won! 🎉
You've outperformed SuSy in detecting AI-generated images.
Click 'Start New Game' to play again.
"""
        if margin < 0:
            return """
Better luck next time! SuSy won this round.
Keep practicing to improve your detection skills.
Click 'Start New Game' to try again.
"""
        return """
It's a tie! You matched SuSy's performance!
You're getting good at this.
Click 'Start New Game' to play again.
"""
# One module-level GameState used by start_game, submit_guess and
# create_score_html.
# NOTE(review): as a process-wide global it appears to be shared by every
# connected browser session, so concurrent players would overwrite each
# other's game; per-session state (e.g. gr.State) would avoid that.
game_state = GameState()
def load_images(num_rounds=None):
    """Pick a shuffled list of image paths for one game.

    Draws ``num_rounds // 2`` paths from ``real_images`` and the remaining
    ``num_rounds - num_rounds // 2`` from ``fake_images``. The list therefore
    always has exactly one image per round, even for an odd round count.

    Args:
        num_rounds: Total number of images to return. Defaults to NUM_ROUNDS.

    Returns:
        list[str]: shuffled paths of the form "<folder>/<filename>".

    Raises:
        ValueError: (from random.sample) if a folder holds fewer usable files
            than requested.
    """
    if num_rounds is None:
        num_rounds = NUM_ROUNDS
    real_image_folder = "real_images"
    fake_image_folder = "fake_images"

    def _list_images(folder):
        # Skip hidden entries (.DS_Store, .gitkeep, ...) and subdirectories,
        # which Image.open cannot load. Sorting makes seeded runs reproducible,
        # since os.listdir order is arbitrary.
        return [
            os.path.join(folder, name)
            for name in sorted(os.listdir(folder))
            if not name.startswith(".") and os.path.isfile(os.path.join(folder, name))
        ]

    num_real = num_rounds // 2
    num_fake = num_rounds - num_real
    selected_images = (
        random.sample(_list_images(real_image_folder), num_real)
        + random.sample(_list_images(fake_image_folder), num_fake)
    )
    random.shuffle(selected_images)
    return selected_images
def create_score_html():
    """Render the scoreboard text shown next to the image.

    Reads the global game_state and shows:
    - both scores;
    - a 1-based round counter capped at total_rounds, so it still reads e.g.
      "10/10" once the game has ended;
    - once a round has been scored, the guesses and correct answer from
      last_results.
    """
    state = game_state
    last = state.last_results
    if not last:
        results_html = ""
    else:
        results_html = f"""
Last Round Results:
Your guess: {last['user_guess']}
Model's guess: {last['model_guess']}
Correct answer: {last['correct_answer']}
"""
    shown_round = state.current_round + 1
    if shown_round > state.total_rounds:
        shown_round = state.total_rounds
    return f"""
Score Board
You
{state.user_score}
AI Model
{state.model_score}
Round: {shown_round}/{state.total_rounds}
{results_html}
"""
def start_game():
    """Start a fresh game and return the UI updates for its first round.

    Resets the global game_state, samples a new list of images, and shows the
    first image with the Real/Fake buttons enabled. The start button and any
    previous end-of-game feedback are hidden.

    Returns:
        tuple of six values for (image_display, start_button, real_button,
        fake_button, score_display, feedback_display), in that order.
    """
    state = game_state
    state.reset()
    state.game_images = load_images()
    state.is_game_active = True
    state.waiting_for_input = True

    first_image = Image.open(state.game_images[0])
    guess_button_state = dict(visible=True, interactive=True)
    return (
        gr.update(value=first_image, visible=True),  # image_display
        gr.update(visible=False),                    # start_button
        gr.update(**guess_button_state),             # real_button
        gr.update(**guess_button_state),             # fake_button
        create_score_html(),                         # score_display
        gr.update(visible=False),                    # feedback_display
    )
def submit_guess(user_guess):
    """Score the player's guess for the current image and advance the game.

    process_image also classifies the current image. The model guesses "Real"
    when the 'Authentic' entry exceeds PROB_THRESHOLD, otherwise "Fake". Each
    side earns a point when its guess matches the ground truth, which is taken
    from the name of the folder the image was loaded from.

    Args:
        user_guess: "Real" or "Fake", i.e. the button the player clicked.

    Returns:
        Six values for (image_display, start_button, real_button, fake_button,
        score_display, feedback_display). If no game is active, all six are
        no-op updates.
    """
    # Ignore clicks when no game is running (e.g. after the final round).
    if not game_state.is_game_active or not game_state.waiting_for_input:
        return [gr.update()] * 6

    image_path = game_state.game_images[game_state.current_round]

    # Compute the model's guess on the image the player just judged.
    model_prediction = process_image(Image.open(image_path))
    model_guess = "Real" if model_prediction['Authentic'] > PROB_THRESHOLD else "Fake"

    # Ground truth comes from the image's parent folder. A substring test on
    # the whole path would mislabel e.g. "fake_images/real_images_01.png".
    parent_folder = os.path.basename(os.path.dirname(os.path.normpath(image_path)))
    correct_answer = "Real" if parent_folder == "real_images" else "Fake"

    # Update scores
    if user_guess == correct_answer:
        game_state.user_score += 1
    if model_guess == correct_answer:
        game_state.model_score += 1

    # Store last results for the scoreboard.
    game_state.last_results = {
        'user_guess': user_guess,
        'model_guess': model_guess,
        'correct_answer': correct_answer
    }
    game_state.current_round += 1
    game_state.waiting_for_input = True

    if game_state.current_round >= game_state.total_rounds:
        # Game over: hide the image and guess buttons, show the start button
        # again, and display the win/lose/tie message.
        game_state.is_game_active = False
        return (
            gr.update(value=None, visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            create_score_html(),
            gr.update(visible=True, value=game_state.get_game_over_message())
        )

    next_image = Image.open(game_state.game_images[game_state.current_round])
    return (
        gr.update(value=next_image, visible=True),
        gr.update(visible=False),
        gr.update(visible=True, interactive=True),
        gr.update(visible=True, interactive=True),
        create_score_html(),
        gr.update(visible=False)
    )
# Custom CSS passed to gr.Blocks(css=...). The selectors target the elem_id /
# elem_classes values set on the layout components:
# - #game-container: the outer column;
# - #start-button: the column wrapping "Start New Game";
# - #guess-buttons: the row holding Real/Fake;
# - .guess-button: each of those buttons;
# - .image-container: the image display. Its height is capped at 768px, and
#   width:auto plus object-fit:contain preserve the aspect ratio.
custom_css = """
#game-container {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
}
#start-button {
max-width: 200px;
margin: 0 auto;
}
#guess-buttons {
display: flex;
gap: 10px;
justify-content: center;
margin-top: 20px;
}
.guess-button {
min-width: 120px;
}
.image-container img {
max-height: 768px !important;
width: auto !important;
object-fit: contain !important;
}
"""
# Define Gradio interface
with gr.Blocks(css=custom_css) as iface:
    with gr.Column(elem_id="game-container"):
        # NOTE(review): this header HTML is empty — confirm whether a title or
        # instructions block was intended here.
        gr.HTML("""
""")
        with gr.Row():
            # Left: the image under judgement plus the Real/Fake buttons.
            with gr.Column(scale=2):
                image_display = gr.Image(
                    type="pil",
                    label="Current Image",
                    interactive=False,
                    visible=False,
                    elem_classes=["image-container"]
                )
                with gr.Row(elem_id="guess-buttons"):
                    real_button = gr.Button(
                        "Real",
                        visible=False,
                        variant="primary",
                        elem_classes=["guess-button"]
                    )
                    fake_button = gr.Button(
                        "Fake",
                        visible=False,
                        variant="primary",
                        elem_classes=["guess-button"]
                    )
            # Right: the scoreboard.
            with gr.Column(scale=1):
                score_display = gr.HTML()
        with gr.Row():
            with gr.Column(elem_id="start-button"):
                start_button = gr.Button("Start New Game", variant="primary", size="sm")
        feedback_display = gr.Markdown(visible=False)

    # Event handlers. All three write to the same six components in this
    # order; start_game and submit_guess return their updates position by
    # position to match it.
    game_outputs = [
        image_display,
        start_button,
        real_button,
        fake_button,
        score_display,
        feedback_display
    ]
    start_button.click(fn=start_game, outputs=game_outputs)
    real_button.click(fn=lambda: submit_guess("Real"), outputs=game_outputs)
    fake_button.click(fn=lambda: submit_guess("Fake"), outputs=game_outputs)

# Launch the interface only when run as a script, so importing this module
# does not start a server as a side effect.
if __name__ == "__main__":
    iface.launch()