"""Gradio mini-game in which the player competes against the SuSy detector.

Each round shows one image taken from ``real_images/`` or ``fake_images/``.
The player and the model both guess "Real" or "Fake" and the scoreboard is
updated accordingly.
"""

import os
import random

import gradio as gr
import numpy as np
import torch
from PIL import Image
from skimage.feature import graycomatrix, graycoprops
from torchvision import transforms

NUM_ROUNDS = 10
PROB_THRESHOLD = 0.3

# Folders the game samples from; the parent folder of a file is its label.
REAL_IMAGE_FOLDER = "real_images"
FAKE_IMAGE_FOLDER = "fake_images"

# Labels paired, in order, with the entries of the model output.
CLASS_NAMES = ['Authentic', 'DALL·E 3', 'Stable Diffusion 1.x', 'MJ V5/V6', 'MJ V1/V2', 'Stable Diffusion XL']

# Load the model once and switch it to inference mode.
model = torch.jit.load("SuSy.pt")
model.eval()

# Built once instead of once per patch: PIL image -> single-channel tensor.
_TO_GRAYSCALE = transforms.Compose([transforms.PILToTensor(), transforms.Grayscale()])


def process_image(image):
    """Classify a PIL image with SuSy and return the class probabilities.

    The image is cut into non-overlapping 224x224 patches, the patches are
    ranked by GLCM contrast, and the five highest-ranked patches are fed to
    the model. The per-patch outputs are averaged.

    Args:
        image: A ``PIL.Image.Image`` in any mode and of any size.

    Returns:
        A dict mapping each name in ``CLASS_NAMES`` to its mean model output,
        ordered from highest to lowest value.
        (Assumes the model returns one row of six class scores per patch --
        TODO confirm against the SuSy model card.)
    """
    # Set Parameters
    top_k_patches = 5
    patch_size = 224

    # The patch buffer below is 3-channel, so RGBA / grayscale / palette
    # images must be converted first or the assignment into it fails.
    image = image.convert("RGB")

    # Get the image dimensions
    width, height = image.size

    # An image smaller than one patch would produce zero patches and an empty
    # model batch; upscale it so at least one full patch fits.
    shortest_side = min(width, height)
    if shortest_side < patch_size:
        scale = patch_size / shortest_side
        image = image.resize(
            (
                max(patch_size, round(width * scale)),
                max(patch_size, round(height * scale)),
            )
        )
        width, height = image.size

    # Calculate the number of patches
    num_patches_x = width // patch_size
    num_patches_y = height // patch_size

    # Divide the image in patches
    patches = np.zeros((num_patches_x * num_patches_y, patch_size, patch_size, 3), dtype=np.uint8)
    for i in range(num_patches_x):
        for j in range(num_patches_y):
            x = i * patch_size
            y = j * patch_size
            patch = image.crop((x, y, x + patch_size, y + patch_size))
            patches[i * num_patches_y + j] = np.array(patch)

    # Compute the most relevant patches. The score is the GLCM "contrast"
    # property at distance 5, angle 0.
    contrast_scores = []
    for patch in patches:
        grayscale_patch = _TO_GRAYSCALE(Image.fromarray(patch)).squeeze(0).numpy()
        glcm = graycomatrix(grayscale_patch, [5], [0], levels=256, symmetric=True, normed=True)
        contrast_scores.append(graycoprops(glcm, "contrast")[0, 0])

    # Sort patch indices by their score, highest first
    sorted_indices = np.argsort(contrast_scores)[::-1]

    # Extract top k patches and convert them to an NCHW float tensor in [0, 1]
    top_patches = patches[sorted_indices[:top_k_patches]]
    top_patches = torch.from_numpy(np.transpose(top_patches, (0, 3, 1, 2))) / 255.0

    # Predict patches
    with torch.no_grad():
        preds = model(top_patches)

    # Average over patches and pair each value with its class name
    mean_probs = preds.mean(dim=0).numpy()
    class_probs = dict(zip(CLASS_NAMES, mean_probs))

    # Sort probabilities in descending order
    sorted_probs = dict(sorted(class_probs.items(), key=lambda item: item[1], reverse=True))
    return sorted_probs


class GameState:
    """Mutable state of one game: scores, round counter and image list."""

    def __init__(self):
        self.user_score = 0
        self.model_score = 0
        self.current_round = 0  # index into game_images of the image on screen
        self.total_rounds = NUM_ROUNDS
        self.game_images = []  # file paths, one per round
        self.is_game_active = False
        self.last_results = None  # dict describing the previous round, or None
        self.waiting_for_input = True  # False while a guess is being processed

    def reset(self):
        """Restore every attribute to its initial value."""
        self.__init__()

    def get_game_over_message(self):
        """Return the end-of-game message matching the final scores."""
        if self.user_score > self.model_score:
            return """
🎉 Congratulations! You won! 🎉
You've outperformed SuSy in detecting AI-generated images.
Click 'Start New Game' to play again.
"""
        elif self.user_score < self.model_score:
            return """
Better luck next time! SuSy won this round.
Keep practicing to improve your detection skills.
Click 'Start New Game' to try again.
"""
        else:
            return """
It's a tie! You matched SuSy's performance!
You're getting good at this.
Click 'Start New Game' to play again.
"""


# NOTE(review): a single module-level state object is shared by every browser
# session connected to this app -- confirm that is acceptable for deployment.
game_state = GameState()


def _list_image_files(folder):
    """Return paths of the regular, non-hidden files directly inside *folder*."""
    paths = []
    for name in os.listdir(folder):
        path = os.path.join(folder, name)
        # Skip entries such as .DS_Store / .gitkeep and sub-directories, which
        # cannot be opened as images.
        if not name.startswith(".") and os.path.isfile(path):
            paths.append(path)
    return paths


def load_images():
    """Pick the images for one game.

    Returns:
        A shuffled list of file paths holding ``NUM_ROUNDS // 2`` real and
        ``NUM_ROUNDS // 2`` fake images.

    Raises:
        ValueError: If either folder holds fewer than ``NUM_ROUNDS // 2``
            usable files.
    """
    per_class = NUM_ROUNDS // 2
    real_images = _list_image_files(REAL_IMAGE_FOLDER)
    fake_images = _list_image_files(FAKE_IMAGE_FOLDER)

    for folder, found in ((REAL_IMAGE_FOLDER, real_images), (FAKE_IMAGE_FOLDER, fake_images)):
        if len(found) < per_class:
            raise ValueError(
                f"Need at least {per_class} images in '{folder}', found {len(found)}"
            )

    selected_images = random.sample(real_images, per_class) + random.sample(fake_images, per_class)
    random.shuffle(selected_images)
    return selected_images


def create_score_html():
    """Render the scoreboard, round counter and last-round summary."""
    results_html = ""
    if game_state.last_results:
        results_html = f"""

Last Round Results:

Your guess: {game_state.last_results['user_guess']}

Model's guess: {game_state.last_results['model_guess']}

Correct answer: {game_state.last_results['correct_answer']}

"""
    # current_round is zero-based; clamp so the final screen shows N/N.
    current_display_round = min(game_state.current_round + 1, game_state.total_rounds)
    return f"""

Score Board

You

{game_state.user_score}

AI Model

{game_state.model_score}

Round: {current_display_round}/{game_state.total_rounds}

{results_html}
"""


def _round_updates(image):
    """Return the six component updates that show *image* and the guess buttons.

    Order: image, start button, real button, fake button, score, feedback.
    """
    return (
        gr.update(value=image, visible=True),
        gr.update(visible=False),
        gr.update(visible=True, interactive=True),
        gr.update(visible=True, interactive=True),
        create_score_html(),
        gr.update(visible=False),
    )


def start_game():
    """Reset the state, draw a new image set and show the first round."""
    game_state.reset()
    game_state.game_images = load_images()
    # With an odd NUM_ROUNDS fewer images than NUM_ROUNDS are selected; play
    # exactly as many rounds as there are images to avoid an IndexError.
    game_state.total_rounds = len(game_state.game_images)
    game_state.is_game_active = True
    game_state.waiting_for_input = True
    current_image = Image.open(game_state.game_images[0])
    return _round_updates(current_image)


def submit_guess(user_guess):
    """Score one round and advance to the next image or the final screen.

    Args:
        user_guess: ``"Real"`` or ``"Fake"``.

    Returns:
        Six component updates in the order image, start button, real button,
        fake button, score, feedback. When no game is active or another guess
        is still being processed, six no-op updates are returned.
    """
    if not game_state.is_game_active or not game_state.waiting_for_input:
        return [gr.update()] * 6

    # Ignore further clicks until this round has been scored, so that a double
    # click cannot score the same image twice.
    game_state.waiting_for_input = False
    try:
        image_path = game_state.game_images[game_state.current_round]

        # Compute Model Guess
        with Image.open(image_path) as current_image:
            model_prediction = process_image(current_image)
        model_guess = "Real" if model_prediction['Authentic'] > PROB_THRESHOLD else "Fake"

        # The label is the containing folder, not a substring of the path, so
        # a file name that happens to contain "real_images" cannot flip it.
        parent_folder = os.path.basename(os.path.dirname(image_path))
        correct_answer = "Real" if parent_folder == REAL_IMAGE_FOLDER else "Fake"

        # Update scores
        if user_guess == correct_answer:
            game_state.user_score += 1
        if model_guess == correct_answer:
            game_state.model_score += 1

        # Store last results for display
        game_state.last_results = {
            'user_guess': user_guess,
            'model_guess': model_guess,
            'correct_answer': correct_answer
        }

        game_state.current_round += 1
    finally:
        # Always re-arm the guard, even if inference raised, so the game
        # cannot get stuck ignoring every click.
        game_state.waiting_for_input = True

    # Check if game is over
    if game_state.current_round >= game_state.total_rounds:
        game_state.is_game_active = False
        return (
            gr.update(value=None, visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            create_score_html(),
            gr.update(visible=True, value=game_state.get_game_over_message())
        )

    next_image = Image.open(game_state.game_images[game_state.current_round])
    return _round_updates(next_image)


# Custom CSS
custom_css = """ #game-container { max-width: 1200px; margin: 0 auto; padding: 20px; } #start-button { max-width: 200px; margin: 0 auto; } 
#guess-buttons { display: flex; gap: 10px; justify-content: center; margin-top: 20px; } .guess-button { min-width: 120px; } .image-container img { max-height: 768px !important; width: auto !important; object-fit: contain !important; } """

# Define Gradio interface
# NOTE(review): the container nesting below is reconstructed from statement
# order -- confirm it matches the intended layout.
with gr.Blocks(css=custom_css) as iface:
    with gr.Column(elem_id="game-container"):
        gr.HTML("""
SuSy Logo

Compete against SuSy to spot AI-Generated images! SuSy can distinguish between authentic images and those generated by DALL·E, Midjourney and Stable Diffusion.

Learn more about SuSy: Present and Future Generalization of Synthetic Image Detectors

Enter the SuSy-verse! Model | Code | Dataset

""")

        with gr.Row():
            with gr.Column(scale=2):
                image_display = gr.Image(
                    type="pil",
                    label="Current Image",
                    interactive=False,
                    visible=False,
                    elem_classes=["image-container"]
                )
                with gr.Row(elem_id="guess-buttons"):
                    real_button = gr.Button(
                        "Real",
                        visible=False,
                        variant="primary",
                        elem_classes=["guess-button"]
                    )
                    fake_button = gr.Button(
                        "Fake",
                        visible=False,
                        variant="primary",
                        elem_classes=["guess-button"]
                    )

            with gr.Column(scale=1):
                score_display = gr.HTML()

        with gr.Row():
            with gr.Column(elem_id="start-button"):
                start_button = gr.Button("Start New Game", variant="primary", size="sm")

        feedback_display = gr.Markdown(visible=False)

    # Every handler updates the same six components, in this order.
    game_outputs = [
        image_display,
        start_button,
        real_button,
        fake_button,
        score_display,
        feedback_display
    ]

    # Event handlers
    start_button.click(fn=start_game, outputs=game_outputs)
    real_button.click(fn=lambda: submit_guess("Real"), outputs=game_outputs)
    fake_button.click(fn=lambda: submit_guess("Fake"), outputs=game_outputs)

# Launch the interface
iface.launch()