|
import os |
|
import random |
|
from dataclasses import dataclass |
|
from threading import Lock |
|
from typing import List, Optional |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from skimage.feature import graycomatrix, graycoprops |
|
from torchvision import transforms |
|
|
|
# Number of images (rounds) in one game; load_images draws them from the
# real and fake image folders.
NUM_ROUNDS = 10

# SuSy's guess is "Real" when the mean 'Authentic' score returned by
# process_image exceeds this value, otherwise "Fake".
PROB_THRESHOLD = 0.3

# TorchScript SuSy detector. It is pinned to the CPU because process_image
# builds CPU input tensors and calls .numpy() on the model output; a
# checkpoint saved on a GPU would otherwise be restored onto CUDA and fail.
model = torch.jit.load("SuSy.pt", map_location="cpu")
|
|
|
@dataclass
class GameResults:
    """Outcome of a single round: both guesses and the ground truth.

    Every field holds "Real" or "Fake", as produced by submit_guess.
    """

    # The label the player picked with the Real/Fake buttons.
    user_guess: str
    # SuSy's label after thresholding its 'Authentic' score.
    model_guess: str
    # The true label, derived from the folder the image was loaded from.
    correct_answer: str
|
|
|
class GameState:
    """Lock-protected state of one human-vs-SuSy game.

    Public methods acquire ``self.lock``; underscore-prefixed helpers assume
    the caller already holds it. ``self.lock`` is a plain, non-reentrant
    ``threading.Lock``, so a method that holds the lock must never call
    another public (locking) method -- doing so deadlocks.
    """

    def __init__(self):
        self.lock = Lock()
        self._reset()

    def _reset(self):
        """Restore a fresh, inactive game. Caller must hold the lock (or be __init__)."""
        self.user_score = 0
        self.model_score = 0
        self.current_round = 0  # 0-based index into game_images
        self.total_rounds = NUM_ROUNDS
        self.game_images: List[str] = []
        self.is_game_active = False
        self.last_results: Optional["GameResults"] = None
        # True between start_submission() and finish_submission(); blocks
        # a second guess from being processed for the same round.
        self.processing_submission = False

    def reset(self):
        """Reset the game while holding the lock."""
        with self.lock:
            self._reset()

    def start_new_game(self) -> tuple[bool, Optional[str]]:
        """Start a new game if none is active.

        Returns:
            ``(True, first_image_path)`` on success; ``(False, None)`` if a
            game is already active, no images were loaded, or loading raised.

        NOTE(review): a game that is abandoned mid-way keeps
        ``is_game_active`` True, so this refuses to start a new one until the
        state is reset -- confirm this is intended for a shared instance.
        """
        with self.lock:
            if self.is_game_active:
                return False, None

            try:
                self._reset()
                self.game_images = load_images()
                if not self.game_images:
                    return False, None

                self.is_game_active = True
                return True, self.game_images[0]
            except Exception as e:
                print(f"Error starting new game: {e}")
                self._reset()
                return False, None

    def _can_submit_guess_unlocked(self) -> bool:
        """Return whether a guess may be processed now. Caller must hold the lock."""
        return (
            self.is_game_active
            and not self.processing_submission
            and self.current_round < self.total_rounds
            and len(self.game_images) > self.current_round
        )

    def can_submit_guess(self) -> bool:
        """Return whether a guess may be processed now."""
        with self.lock:
            return self._can_submit_guess_unlocked()

    def start_submission(self) -> bool:
        """Atomically check eligibility and claim the current round.

        Returns:
            True if the caller now owns the round and must later call
            finish_submission(); False if no guess can be processed.
        """
        with self.lock:
            # Use the unlocked check: calling can_submit_guess() here would
            # try to re-acquire the non-reentrant lock and deadlock.
            if not self._can_submit_guess_unlocked():
                return False
            self.processing_submission = True
            return True

    def finish_submission(self, results: "GameResults"):
        """Score a round, advance to the next one, and end the game after the last round."""
        with self.lock:
            if results.user_guess == results.correct_answer:
                self.user_score += 1
            if results.model_guess == results.correct_answer:
                self.model_score += 1

            self.last_results = results
            self.current_round += 1
            self.processing_submission = False

            if self.current_round >= self.total_rounds:
                self.is_game_active = False

    def get_current_image(self) -> Optional[str]:
        """Return the current round's image path, or None if there is no active round."""
        with self.lock:
            if not self.is_game_active or self.current_round >= len(self.game_images):
                return None
            return self.game_images[self.current_round]

    def get_game_state(self):
        """Return a dict snapshot of the scoreboard fields, taken under the lock."""
        with self.lock:
            return {
                'is_active': self.is_game_active,
                'current_round': self.current_round,
                'total_rounds': self.total_rounds,
                'user_score': self.user_score,
                'model_score': self.model_score,
                'last_results': self.last_results
            }

    def get_game_over_message(self) -> str:
        """Return the end-of-game HTML message for a player win, loss, or tie."""
        with self.lock:
            if self.user_score > self.model_score:
                return """
                <div style='text-align: center; margin-top: 20px; font-size: 1.2em;'>
                🎉 Congratulations! You won! 🎉<br>
                You've outperformed SuSy in detecting AI-generated images.<br>
                Click 'Start New Game' to play again.
                </div>
                """
            elif self.user_score < self.model_score:
                return """
                <div style='text-align: center; margin-top: 20px; font-size: 1.2em;'>
                Better luck next time! SuSy won this round.<br>
                Keep practicing to improve your detection skills.<br>
                Click 'Start New Game' to try again.
                </div>
                """
            else:
                return """
                <div style='text-align: center; margin-top: 20px; font-size: 1.2em;'>
                It's a tie! You matched SuSy's performance!<br>
                You're getting good at this.<br>
                Click 'Start New Game' to play again.
                </div>
                """
|
|
|
def process_image(image, top_k_patches=5, patch_size=224):
    """Run SuSy on an image by averaging over its highest-contrast patches.

    The image is tiled into non-overlapping ``patch_size`` x ``patch_size``
    patches; pixels left over at the right/bottom edges are ignored. Each
    patch is scored by the GLCM "contrast" property of its grayscale version
    (distance 5, angle 0), and the ``top_k_patches`` highest-scoring patches
    are passed to the model as one batch scaled to [0, 1].

    Args:
        image: PIL image. It is converted to RGB first, so grayscale, palette
            and RGBA inputs are accepted.
        top_k_patches: number of highest-contrast patches sent to the model.
        patch_size: side length, in pixels, of each square patch.

    Returns:
        dict mapping class name -> model output averaged over the selected
        patches, ordered from highest to lowest value.

    Raises:
        ValueError: if the image is smaller than one patch in either dimension.
    """
    # The patch buffer below is (N, H, W, 3); non-RGB modes would not fit it.
    image = image.convert("RGB")
    width, height = image.size

    num_patches_x = width // patch_size
    num_patches_y = height // patch_size
    if num_patches_x == 0 or num_patches_y == 0:
        # Otherwise the model would receive an empty batch.
        raise ValueError(
            f"Image of size {width}x{height} is smaller than one "
            f"{patch_size}x{patch_size} patch"
        )

    patches = np.zeros((num_patches_x * num_patches_y, patch_size, patch_size, 3), dtype=np.uint8)
    for i in range(num_patches_x):
        for j in range(num_patches_y):
            x = i * patch_size
            y = j * patch_size
            patch = image.crop((x, y, x + patch_size, y + patch_size))
            patches[i * num_patches_y + j] = np.array(patch)

    # Built once: the transform is identical for every patch.
    to_grayscale = transforms.Compose([transforms.PILToTensor(), transforms.Grayscale()])
    contrast_scores = []
    for patch in patches:
        grayscale_patch = to_grayscale(Image.fromarray(patch)).squeeze(0).numpy()
        glcm = graycomatrix(grayscale_patch, [5], [0], 256, symmetric=True, normed=True)
        contrast_scores.append(graycoprops(glcm, "contrast")[0, 0])

    # Highest-contrast (most textured) patches first.
    sorted_indices = np.argsort(contrast_scores)[::-1]
    top_patches = patches[sorted_indices[:top_k_patches]]
    # NHWC uint8 -> NCHW float in [0, 1].
    top_patches = torch.from_numpy(np.transpose(top_patches, (0, 3, 1, 2))) / 255.0

    model.eval()
    with torch.no_grad():
        preds = model(top_patches)

    classes = ['Authentic', 'DALL·E 3', 'Stable Diffusion 1.x', 'MJ V5/V6', 'MJ V1/V2', 'Stable Diffusion XL']
    # NOTE(review): assumes preds is (num_patches, len(classes)) -- confirm.
    mean_probs = preds.mean(dim=0).numpy()

    class_probs = dict(zip(classes, mean_probs))
    return dict(sorted(class_probs.items(), key=lambda item: item[1], reverse=True))
|
|
|
def load_images(num_images=None, real_image_folder="real_images", fake_image_folder="fake_images"):
    """Pick a shuffled list of image paths for one game.

    ``num_images // 2`` paths are drawn from ``real_image_folder`` and the
    rest from ``fake_image_folder``, so the list always has exactly
    ``num_images`` entries, even when that number is odd. Hidden entries
    (names starting with ".") and anything that is not a regular file are
    skipped.

    Args:
        num_images: number of paths to return; defaults to NUM_ROUNDS.
        real_image_folder: directory of authentic images.
        fake_image_folder: directory of AI-generated images.

    Returns:
        List of ``os.path.join(folder, filename)`` paths in random order.

    Raises:
        ValueError: if a folder has fewer usable files than required.
    """
    if num_images is None:
        num_images = NUM_ROUNDS
    num_real = num_images // 2
    num_fake = num_images - num_real

    def _candidate_paths(folder):
        """Return paths of the visible regular files directly inside *folder*."""
        visible = (os.path.join(folder, name) for name in os.listdir(folder) if not name.startswith("."))
        return [path for path in visible if os.path.isfile(path)]

    selected_images = []
    for folder, count in ((real_image_folder, num_real), (fake_image_folder, num_fake)):
        candidates = _candidate_paths(folder)
        if len(candidates) < count:
            raise ValueError(f"Need {count} images in {folder!r} but found {len(candidates)}")
        selected_images.extend(random.sample(candidates, count))

    random.shuffle(selected_images)
    return selected_images
|
|
|
def create_score_html() -> str:
    """Render the scoreboard panel as an HTML string.

    Uses a locked snapshot of the module-level ``game_state``. The panel shows
    both scores and a 1-based round counter capped at the total. Once at
    least one round has been played, it also shows the previous round's
    guesses and the correct answer.
    """
    snapshot = game_state.get_game_state()
    total = snapshot['total_rounds']
    last = snapshot['last_results']

    if last:
        results_html = f"""
        <div style='margin-top: 1rem; padding: 1rem; background-color: #e0e0e0; border-radius: 8px; color: #333;'>
        <h4 style='color: #333; margin-bottom: 0.5rem;'>Last Round Results:</h4>
        <p style='color: #333;'>Your guess: {last.user_guess}</p>
        <p style='color: #333;'>Model's guess: {last.model_guess}</p>
        <p style='color: #333;'>Correct answer: {last.correct_answer}</p>
        </div>
        """
    else:
        results_html = ""

    # current_round is 0-based; display it 1-based without passing the total.
    shown_round = min(snapshot['current_round'] + 1, total)

    return f"""
    <div style='padding: 1rem; background-color: #f0f0f0; border-radius: 8px; color: #333;'>
    <h3 style='margin-bottom: 1rem; color: #333;'>Score Board</h3>
    <div style='display: flex; justify-content: space-around;'>
    <div>
    <h4 style='color: #333;'>You</h4>
    <p style='font-size: 1.5rem; color: #333;'>{snapshot['user_score']}</p>
    </div>
    <div>
    <h4 style='color: #333;'>AI Model</h4>
    <p style='font-size: 1.5rem; color: #333;'>{snapshot['model_score']}</p>
    </div>
    </div>
    <div style='margin-top: 1rem;'>
    <p style='color: #333;'>Round: {shown_round}/{total}</p>
    </div>
    {results_html}
    </div>
    """
|
|
|
# The single GameState instance read and mutated by every event handler in
# this module (start_game, submit_guess, create_score_html).
game_state: GameState = GameState()
|
|
|
def start_game():
    """Handle a click on "Start New Game".

    Returns:
        Six Gradio updates, in the order the click handler's outputs are
        wired: image, start button, real button, fake button, scoreboard
        HTML, feedback panel. On success, the first image and the guess
        buttons are shown and the start button and feedback are hidden. On
        any failure, every output is left unchanged.
    """
    started, image_path = game_state.start_new_game()

    if not (started and image_path):
        print("Failed to start new game")
        return [gr.update()] * 6

    try:
        first_image = Image.open(image_path)
        image_update = gr.update(value=first_image, visible=True)
        start_update = gr.update(visible=False)
        real_update = gr.update(visible=True, interactive=True)
        fake_update = gr.update(visible=True, interactive=True)
        score_html = create_score_html()
        feedback_update = gr.update(visible=False)
        return (image_update, start_update, real_update, fake_update, score_html, feedback_update)
    except Exception as e:
        print(f"Error starting game: {e}")
        # Roll back so the next click can start cleanly.
        game_state.reset()
        return [gr.update()] * 6
|
|
|
def _abort_submission():
    """Clear the in-progress flag set by start_submission() without scoring a round."""
    with game_state.lock:
        game_state.processing_submission = False


def _is_real_image_path(path: str) -> bool:
    """Return True if *path* has a ``real_images`` directory component."""
    return "real_images" in os.path.normpath(path).split(os.sep)


def submit_guess(user_guess: str):
    """Score the player's guess for the current image and advance the game.

    Args:
        user_guess: "Real" or "Fake", as sent by the guess buttons.

    Returns:
        Six Gradio updates in the order the button handlers' outputs are
        wired: image, start button, real button, fake button, scoreboard
        HTML, feedback panel. When no guess can be processed, or an error
        occurs, every output is left unchanged.
    """
    # start_submission() checks eligibility and claims the round atomically,
    # so a separate can_submit_guess() pre-check would only add a race.
    if not game_state.start_submission():
        return [gr.update()] * 6

    try:
        current_image_path = game_state.get_current_image()
        if not current_image_path:
            _abort_submission()
            return [gr.update()] * 6

        current_image = Image.open(current_image_path)
        # The model pipeline expects RGB pixels; convert so grayscale, palette
        # and RGBA files do not fail inside process_image.
        model_prediction = process_image(current_image.convert("RGB"))
        model_guess = "Real" if model_prediction['Authentic'] > PROB_THRESHOLD else "Fake"
        # Match a whole path component: a substring test would label e.g.
        # "fake_images/not_real_images.png" as real.
        correct_answer = "Real" if _is_real_image_path(current_image_path) else "Fake"

        results = GameResults(user_guess, model_guess, correct_answer)
        game_state.finish_submission(results)

        # Read through the locked snapshot rather than the bare attribute.
        if not game_state.get_game_state()['is_active']:
            return (
                gr.update(value=None, visible=False),
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
                create_score_html(),
                gr.update(visible=True, value=game_state.get_game_over_message())
            )

        next_image_path = game_state.get_current_image()
        if not next_image_path:
            return [gr.update()] * 6

        next_image = Image.open(next_image_path)

        return (
            gr.update(value=next_image, visible=True),
            gr.update(visible=False),
            gr.update(visible=True, interactive=True),
            gr.update(visible=True, interactive=True),
            create_score_html(),
            gr.update(visible=False)
        )
    except Exception as e:
        print(f"Error processing guess: {e}")
        # Release the round under the lock so the player can retry.
        _abort_submission()
        return [gr.update()] * 6
|
|
|
|
|
# Page styles passed to gr.Blocks(css=...). Each selector targets an elem_id
# or elem_classes value set on a component in the UI layout:
#   #game-container  - the outer column (centred, max 1200px wide)
#   #start-button    - the column holding "Start New Game" (max 200px wide)
#   #guess-buttons   - the row with the Real/Fake buttons (centred flex row)
#   .guess-button    - each guess button (min 120px wide)
#   .image-container - the image display, capped at 640px tall
custom_css = """
#game-container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
}
#start-button {
    max-width: 200px;
    margin: 0 auto;
}
#guess-buttons {
    display: flex;
    gap: 10px;
    justify-content: center;
    margin-top: 20px;
}
.guess-button {
    min-width: 120px;
}
.image-container img {
    max-height: 640px !important;
    width: auto !important;
    object-fit: contain !important;
}
"""
|
|
|
|
|
with gr.Blocks(css=custom_css) as iface:
    with gr.Column(elem_id="game-container"):
        # Intro banner: logo, one-line pitch, and links to paper/model/code/dataset.
        gr.HTML("""
        <table style="border-collapse: collapse; border: none; padding: 20px;">
        <tr style="border: none;">
        <td style="border: none; vertical-align: top; padding-right: 30px; padding-left: 30px;">
        <img src="https://cdn-uploads.huggingface.co/production/uploads/62f7a16192950415b637e201/NobqlpFbFkTyBi1LsT9JE.png" alt="SuSy Logo" width="120" style="margin-bottom: 10px;">
        </td>
        <td style="border: none; vertical-align: top; padding: 10px;">
        <p style="margin-bottom: 15px;">Compete against SuSy to spot AI-Generated images! SuSy can distinguish between authentic images and those generated by DALL·E, Midjourney and Stable Diffusion.</p>
        <p style="margin-top: 15px;">Learn more about SuSy: <a href="https://arxiv.org/abs/2409.14128">Present and Future Generalization of Synthetic Image Detectors</a></p>
        <p style="margin-top: 15px;">
        Enter the SuSy-verse!
        <a href="https://huggingface.co/HPAI-BSC/SuSy">Model</a> |
        <a href="https://github.com/HPAI-BSC/SuSy">Code</a> |
        <a href="https://huggingface.co/datasets/HPAI-BSC/SuSy-Dataset">Dataset</a>
        </p>
        </td>
        </tr>
        </table>
        """)

        with gr.Row():
            # Left: the image under test plus the guess buttons, all hidden
            # until a game starts.
            with gr.Column(scale=2):
                image_display = gr.Image(
                    type="pil",
                    label="Current Image",
                    interactive=False,
                    visible=False,
                    elem_classes=["image-container"],
                )
                with gr.Row(elem_id="guess-buttons"):
                    real_button = gr.Button("Real", visible=False, variant="primary", elem_classes=["guess-button"])
                    fake_button = gr.Button("Fake", visible=False, variant="primary", elem_classes=["guess-button"])

            # Right: running scoreboard.
            with gr.Column(scale=1):
                score_display = gr.HTML()

        with gr.Row():
            with gr.Column(elem_id="start-button"):
                start_button = gr.Button("Start New Game", variant="primary", size="sm")

        # End-of-game message, shown only after the last round.
        feedback_display = gr.Markdown(visible=False)

    # start_game and submit_guess both return six updates in exactly this order.
    game_outputs = [
        image_display,
        start_button,
        real_button,
        fake_button,
        score_display,
        feedback_display,
    ]

    start_button.click(fn=start_game, outputs=game_outputs)
    real_button.click(fn=lambda: submit_guess("Real"), outputs=game_outputs)
    fake_button.click(fn=lambda: submit_guess("Fake"), outputs=game_outputs)
|
|
|
|
|
# Start the web server only when this file is run as a script, so that
# importing the module (e.g. from tests or another app) does not launch a
# server as a side effect.
if __name__ == "__main__":
    iface.launch()
|
|