import gradio as gr
import numpy as np
import torch
import random
from PIL import Image
from skimage.feature import graycomatrix, graycoprops
from torchvision import transforms
import os
# Number of rounds a game is meant to last.
# NOTE(review): confirm the game logic actually reads this constant.
NUM_ROUNDS = 5 # Adjust the number of game rounds here
# Cut-off applied to the model's 'Authentic' score when deciding Real vs Fake.
# NOTE(review): confirm the game logic actually reads this constant.
PROB_THRESHOLD = 0.5 # Adjust the probability threshold for model prediction here
# Load the model
# The TorchScript file is deserialized once, at import time; the relative path
# is resolved against the current working directory.
model = torch.jit.load("SuSy.pt")
def process_image(image):
    """Score a PIL image against the detector's classes.

    The image is cut into non-overlapping 224x224 patches, the patches with
    the highest GLCM contrast are kept (at most five), and the module-level
    ``model`` is evaluated on them. The model outputs are averaged over the
    selected patches and treated as per-class probabilities.

    Args:
        image: A ``PIL.Image.Image`` in any mode. It is converted to RGB, and
            upscaled if either side is shorter than one patch, so that at
            least one patch can always be extracted.

    Returns:
        dict: class label -> mean model output, ordered from highest to
        lowest value.
    """
    # Set Parameters
    top_k_patches = 5
    patch_size = 224

    # The patch buffer below is hard-wired to 3 channels, so normalise
    # RGBA / grayscale / palette inputs up front.
    image = image.convert("RGB")
    width, height = image.size

    # An image smaller than a single patch would yield zero patches and an
    # empty model batch; stretch the short side(s) up to the patch size.
    if width < patch_size or height < patch_size:
        image = image.resize((max(width, patch_size), max(height, patch_size)))
        width, height = image.size

    # Calculate the number of patches
    num_patches_x = width // patch_size
    num_patches_y = height // patch_size

    # Divide the image in patches (column-major: x outer, y inner)
    patches = np.zeros(
        (num_patches_x * num_patches_y, patch_size, patch_size, 3),
        dtype=np.uint8,
    )
    for i in range(num_patches_x):
        for j in range(num_patches_y):
            x = i * patch_size
            y = j * patch_size
            patch = image.crop((x, y, x + patch_size, y + patch_size))
            patches[i * num_patches_y + j] = np.array(patch)

    # Rank patches by GLCM contrast. The transform pipeline does not depend
    # on the patch, so build it once instead of once per iteration.
    to_grayscale = transforms.Compose(
        [transforms.PILToTensor(), transforms.Grayscale()]
    )
    contrast_scores = []
    for patch in patches:
        grayscale_patch = to_grayscale(Image.fromarray(patch)).squeeze(0).numpy()
        glcm = graycomatrix(
            grayscale_patch, [5], [0], 256, symmetric=True, normed=True
        )
        contrast_scores.append(graycoprops(glcm, "contrast")[0, 0])

    # Sort patch indices by score, highest first
    sorted_indices = np.argsort(contrast_scores)[::-1]

    # Extract top k patches and convert them to an NCHW float tensor in [0, 1]
    top_patches = patches[sorted_indices[:top_k_patches]]
    top_patches = torch.from_numpy(np.transpose(top_patches, (0, 3, 1, 2))) / 255.0

    # Predict patches
    model.eval()
    with torch.no_grad():
        preds = model(top_patches)

    # Process results
    classes = [
        'Authentic',
        'DALL·E 3',
        'Stable Diffusion 1.x',
        'MJ V5/V6',
        'MJ V1/V2',
        'Stable Diffusion XL',
    ]
    mean_probs = preds.mean(dim=0).numpy()

    # Pair each label with its averaged score, then order by score descending
    class_probs = dict(zip(classes, mean_probs))
    return dict(
        sorted(class_probs.items(), key=lambda item: item[1], reverse=True)
    )
class GameState:
    """Mutable progress and score record for the game currently being played."""

    def __init__(self):
        # Construction and reset share one code path.
        self.reset()

    def reset(self):
        """Put every field back to its start-of-game value."""
        # Scores and progress
        self.user_score = 0
        self.model_score = 0
        self.current_round = 0
        self.total_rounds = 2
        # A fresh list on every reset so games never share image lists
        self.game_images = []
        # Flow-control flags and the previous round's summary
        self.is_game_active = False
        self.last_results = None
        self.waiting_for_input = True


# Single shared state object used by the event handlers
game_state = GameState()
def _list_image_candidates(folder):
    """Return paths of the regular, non-hidden files directly inside *folder*."""
    candidates = []
    for name in os.listdir(folder):
        path = os.path.join(folder, name)
        # Sub-directories and dot-files (.DS_Store, .gitkeep, ...) cannot be
        # opened as images, so they must never be drawn.
        if not name.startswith(".") and os.path.isfile(path):
            candidates.append(path)
    return candidates


def load_images(num_real=1, num_fake=1):
    """Draw a shuffled mix of real and fake image paths for one game.

    Args:
        num_real: How many paths to draw from the ``real_images`` folder.
        num_fake: How many paths to draw from the ``fake_images`` folder.

    Returns:
        list[str]: ``num_real + num_fake`` distinct paths in random order.

    Raises:
        ValueError: If a folder holds fewer usable files than requested.
        FileNotFoundError: If either folder does not exist.
    """
    real_image_folder = "real_images"
    fake_image_folder = "fake_images"

    selected_images = []
    for folder, count in ((real_image_folder, num_real), (fake_image_folder, num_fake)):
        available = _list_image_candidates(folder)
        if len(available) < count:
            raise ValueError(
                f"Need {count} image(s) in '{folder}' but found {len(available)}"
            )
        selected_images += random.sample(available, count)

    random.shuffle(selected_images)
    return selected_images
def create_score_html():
    """Build the scoreboard text from the shared ``game_state``.

    Returns:
        str: Both scores, the round counter and, once a round has been
        played, a summary of the previous round's guesses and answer.
    """
    state = game_state
    last = state.last_results

    # The summary section stays empty until at least one round is recorded.
    if last:
        results_html = f"""
Last Round Results:
Your guess: {last['user_guess']}
Model's guess: {last['model_guess']}
Correct answer: {last['correct_answer']}
"""
    else:
        results_html = ""

    # current_round is zero-based; cap the displayed number so it never
    # exceeds the total once the final round has been submitted.
    current_display_round = min(state.current_round + 1, state.total_rounds)

    return f"""
Score Board
You
{state.user_score}
AI Model
{state.model_score}
Round: {current_display_round}/{state.total_rounds}
{results_html}
"""
def start_game():
    """Reset the shared state, load a new image set and show the first round.

    Returns:
        tuple: Six Gradio updates in the order the click handler's outputs
        are wired: image, start button, guess radio, submit button, score
        HTML, feedback message.
    """
    game_state.reset()
    game_state.game_images = load_images()
    # Rounds index straight into game_images, so never schedule more rounds
    # than there are images to show.
    game_state.total_rounds = min(
        game_state.total_rounds, len(game_state.game_images)
    )
    game_state.is_game_active = True
    game_state.waiting_for_input = True

    current_image = Image.open(game_state.game_images[0])
    return (
        gr.update(value=current_image, visible=True),  # Show image
        gr.update(visible=False),  # Hide start button
        gr.update(interactive=True, visible=True, value=None),  # Show radio buttons
        gr.update(visible=True, interactive=True),  # Show submit button
        create_score_html(),
        gr.update(visible=False),  # Hide feedback
    )
def submit_guess(user_guess):
    """Score the current round and advance to the next one (or end the game).

    Args:
        user_guess: The selected radio value ("Real" or "Fake"), or ``None``
            when nothing has been selected.

    Returns:
        Six Gradio updates in the order the click handler's outputs are
        wired: image, start button, guess radio, submit button, score HTML,
        feedback message. When no game is running or no guess was made, six
        no-op updates are returned.
    """
    if (
        not game_state.is_game_active
        or not game_state.waiting_for_input
        or user_guess is None
    ):
        return [gr.update()] * 6  # Return no updates if invalid state

    image_path = game_state.game_images[game_state.current_round]
    # The scored image is only needed for inference, so release the file
    # handle as soon as the prediction is done.
    with Image.open(image_path) as current_image:
        model_prediction = process_image(current_image)

    # Ground truth comes from the folder the file lives in; comparing the
    # parent directory name avoids false matches on file names that happen
    # to contain "real_images".
    folder_name = os.path.basename(os.path.dirname(image_path))
    correct_answer = "Real" if folder_name == "real_images" else "Fake"

    # Determine model's guess using the module-level threshold
    model_guess = "Real" if model_prediction['Authentic'] > PROB_THRESHOLD else "Fake"

    # Update scores
    if user_guess == correct_answer:
        game_state.user_score += 1
    if model_guess == correct_answer:
        game_state.model_score += 1

    # Store last results for display
    game_state.last_results = {
        'user_guess': user_guess,
        'model_guess': model_guess,
        'correct_answer': correct_answer,
    }
    game_state.current_round += 1
    game_state.waiting_for_input = True

    # Check if game is over
    if game_state.current_round >= game_state.total_rounds:
        game_state.is_game_active = False
        return (
            gr.update(value=None, visible=False),  # Hide image
            gr.update(visible=True),  # Show start button
            gr.update(interactive=False, visible=False, value=None),  # Hide radio
            gr.update(visible=False),  # Hide submit button
            create_score_html(),
            gr.update(
                visible=True,
                value="Game Over! Click 'Start New Game' to play again.",
            ),
        )

    # Continue to next round
    next_image = Image.open(game_state.game_images[game_state.current_round])
    return (
        gr.update(value=next_image, visible=True),  # Show next image
        gr.update(visible=False),  # Keep start button hidden
        gr.update(interactive=True, visible=True, value=None),  # Reset radio
        gr.update(visible=True, interactive=True),  # Show submit button
        create_score_html(),
        gr.update(visible=False),  # Keep feedback hidden
    )
# Custom CSS
# Stylesheet handed to gr.Blocks(css=...). The two selectors match the
# elem_id values given to layout columns in the interface definition:
# "game-container" (the outer column) and "start-button" (the column that
# wraps the start button).
custom_css = """
#game-container {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
}
#start-button {
max-width: 200px;
margin: 0 auto;
}
"""
# Define Gradio interface
# NOTE(review): the nesting below was reconstructed from the `with` statements;
# confirm that feedback_display belongs inside the game-container column.
with gr.Blocks(css=custom_css) as iface:
    with gr.Column(elem_id="game-container"):
        gr.Markdown("# Real or Fake Image Challenge")
        gr.Markdown("Can you beat the AI at detecting synthetic images?")

        with gr.Row():
            # Left side: the image under test plus the guess controls
            with gr.Column(scale=2):
                image_display = gr.Image(
                    type="pil", label="Current Image", interactive=False, visible=False
                )
                guess_input = gr.Radio(
                    choices=["Real", "Fake"],
                    label="Your Guess",
                    interactive=False,
                    visible=False,
                )
                submit_button = gr.Button("Submit Guess", visible=False, variant="primary")
            # Right side: scoreboard
            with gr.Column(scale=1):
                score_display = gr.HTML()

        with gr.Row():
            with gr.Column(elem_id="start-button"):
                start_button = gr.Button("Start New Game", variant="primary", size="sm")

        feedback_display = gr.Markdown(visible=False)

    # Event handlers
    # Both callbacks return their updates in this same component order.
    _handler_outputs = [
        image_display,
        start_button,
        guess_input,
        submit_button,
        score_display,
        feedback_display,
    ]
    start_button.click(fn=start_game, outputs=_handler_outputs)
    submit_button.click(fn=submit_guess, inputs=[guess_input], outputs=_handler_outputs)

# Launch the interface
iface.launch()