pabberpe commited on
Commit
0042ac6
1 Parent(s): e8f1766

Add Lock to handle User Guesses

Browse files
Files changed (1) hide show
  1. app.py +180 -118
app.py CHANGED
@@ -1,11 +1,15 @@
 
 
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import torch
4
- import random
5
  from PIL import Image
6
  from skimage.feature import graycomatrix, graycoprops
7
  from torchvision import transforms
8
- import os
9
 
10
  NUM_ROUNDS = 10
11
  PROB_THRESHOLD = 0.3
@@ -13,6 +17,99 @@ PROB_THRESHOLD = 0.3
13
  # Load the model
14
  model = torch.jit.load("SuSy.pt")
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def process_image(image):
17
  # Set Parameters
18
  top_k_patches = 5
@@ -66,47 +163,6 @@ def process_image(image):
66
 
67
  return sorted_probs
68
 
69
-
70
- class GameState:
71
- def __init__(self):
72
- self.user_score = 0
73
- self.model_score = 0
74
- self.current_round = 0
75
- self.total_rounds = NUM_ROUNDS
76
- self.game_images = []
77
- self.is_game_active = False
78
- self.last_results = None
79
- self.waiting_for_input = True
80
-
81
- def reset(self):
82
- self.__init__()
83
-
84
- def get_game_over_message(self):
85
- if self.user_score > self.model_score:
86
- return """
87
- <div style='text-align: center; margin-top: 20px; font-size: 1.2em;'>
88
- 🎉 Congratulations! You won! 🎉<br>
89
- You've outperformed SuSy in detecting AI-generated images.<br>
90
- Click 'Start New Game' to play again.
91
- </div>
92
- """
93
- elif self.user_score < self.model_score:
94
- return """
95
- <div style='text-align: center; margin-top: 20px; font-size: 1.2em;'>
96
- Better luck next time! SuSy won this round.<br>
97
- Keep practicing to improve your detection skills.<br>
98
- Click 'Start New Game' to try again.
99
- </div>
100
- """
101
- else:
102
- return """
103
- <div style='text-align: center; margin-top: 20px; font-size: 1.2em;'>
104
- It's a tie! You matched SuSy's performance!<br>
105
- You're getting good at this.<br>
106
- Click 'Start New Game' to play again.
107
- </div>
108
- """
109
-
110
  game_state = GameState()
111
 
112
  def load_images():
@@ -119,45 +175,45 @@ def load_images():
119
  return selected_images
120
 
121
  def create_score_html():
122
- results_html = ""
123
- if game_state.last_results:
124
- results_html = f"""
125
- <div style='margin-top: 1rem; padding: 1rem; background-color: #e0e0e0; border-radius: 8px; color: #333;'>
126
- <h4 style='color: #333; margin-bottom: 0.5rem;'>Last Round Results:</h4>
127
- <p style='color: #333;'>Your guess: {game_state.last_results['user_guess']}</p>
128
- <p style='color: #333;'>Model's guess: {game_state.last_results['model_guess']}</p>
129
- <p style='color: #333;'>Correct answer: {game_state.last_results['correct_answer']}</p>
130
- </div>
131
- """
 
132
 
133
- current_display_round = min(game_state.current_round + 1, game_state.total_rounds)
134
-
135
- return f"""
136
- <div style='padding: 1rem; background-color: #f0f0f0; border-radius: 8px; color: #333;'>
137
- <h3 style='margin-bottom: 1rem; color: #333;'>Score Board</h3>
138
- <div style='display: flex; justify-content: space-around;'>
139
- <div>
140
- <h4 style='color: #333;'>You</h4>
141
- <p style='font-size: 1.5rem; color: #333;'>{game_state.user_score}</p>
 
 
 
 
 
142
  </div>
143
- <div>
144
- <h4 style='color: #333;'>AI Model</h4>
145
- <p style='font-size: 1.5rem; color: #333;'>{game_state.model_score}</p>
146
  </div>
 
147
  </div>
148
- <div style='margin-top: 1rem;'>
149
- <p style='color: #333;'>Round: {current_display_round}/{game_state.total_rounds}</p>
150
- </div>
151
- {results_html}
152
- </div>
153
- """
154
 
155
  def start_game():
156
- game_state.reset()
157
- game_state.game_images = load_images()
158
- game_state.is_game_active = True
159
- game_state.waiting_for_input = True
160
- current_image = Image.open(game_state.game_images[0])
161
 
162
  return (
163
  gr.update(value=current_image, visible=True),
@@ -168,54 +224,60 @@ def start_game():
168
  gr.update(visible=False)
169
  )
170
 
171
- def submit_guess(user_guess):
172
- if not game_state.is_game_active or not game_state.waiting_for_input:
 
173
  return [gr.update()] * 6
174
 
175
- # Compute Model Guess
176
- current_image = Image.open(game_state.game_images[game_state.current_round])
177
- model_prediction = process_image(current_image)
178
- model_guess = "Real" if model_prediction['Authentic'] > PROB_THRESHOLD else "Fake"
179
- correct_answer = "Real" if "real_images" in game_state.game_images[game_state.current_round] else "Fake"
180
-
181
- # Update scores
182
- if user_guess == correct_answer:
183
- game_state.user_score += 1
184
- if model_guess == correct_answer:
185
- game_state.model_score += 1
186
-
187
- # Store last results for display
188
- game_state.last_results = {
189
- 'user_guess': user_guess,
190
- 'model_guess': model_guess,
191
- 'correct_answer': correct_answer
192
- }
193
-
194
- game_state.current_round += 1
195
- game_state.waiting_for_input = True
196
 
197
- # Check if game is over
198
- if game_state.current_round >= game_state.total_rounds:
199
- game_state.is_game_active = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  return (
201
- gr.update(value=None, visible=False),
202
- gr.update(visible=True),
203
- gr.update(visible=False),
204
  gr.update(visible=False),
 
 
205
  create_score_html(),
206
- gr.update(visible=True, value=game_state.get_game_over_message())
207
  )
208
-
209
- next_image = Image.open(game_state.game_images[game_state.current_round])
210
-
211
- return (
212
- gr.update(value=next_image, visible=True),
213
- gr.update(visible=False),
214
- gr.update(visible=True, interactive=True),
215
- gr.update(visible=True, interactive=True),
216
- create_score_html(),
217
- gr.update(visible=False)
218
- )
219
 
220
  # Custom CSS
221
  custom_css = """
@@ -238,7 +300,7 @@ custom_css = """
238
  min-width: 120px;
239
  }
240
  .image-container img {
241
- max-height: 768px !important;
242
  width: auto !important;
243
  object-fit: contain !important;
244
  }
 
1
+ import os
2
+ import random
3
+ from dataclasses import dataclass
4
+ from threading import Lock
5
+ from typing import List, Optional
6
+
7
  import gradio as gr
8
  import numpy as np
9
  import torch
 
10
  from PIL import Image
11
  from skimage.feature import graycomatrix, graycoprops
12
  from torchvision import transforms
 
13
 
14
  NUM_ROUNDS = 10
15
  PROB_THRESHOLD = 0.3
 
17
  # Load the model
18
  model = torch.jit.load("SuSy.pt")
19
 
20
+ @dataclass
21
+ class GameResults:
22
+ user_guess: str
23
+ model_guess: str
24
+ correct_answer: str
25
+
26
+ class GameState:
27
+ def __init__(self):
28
+ self.lock = Lock()
29
+ self.reset()
30
+
31
+ def reset(self):
32
+ with self.lock:
33
+ self.user_score = 0
34
+ self.model_score = 0
35
+ self.current_round = 0
36
+ self.total_rounds = NUM_ROUNDS
37
+ self.game_images: List[str] = []
38
+ self.is_game_active = False
39
+ self.last_results: Optional[GameResults] = None
40
+ self.processing_submission = False
41
+
42
+ def start_new_game(self) -> bool:
43
+ with self.lock:
44
+ if self.is_game_active:
45
+ return False
46
+ self.reset()
47
+ self.game_images = load_images()
48
+ self.is_game_active = True
49
+ return True
50
+
51
+ def can_submit_guess(self) -> bool:
52
+ with self.lock:
53
+ return (
54
+ self.is_game_active and
55
+ not self.processing_submission and
56
+ self.current_round < self.total_rounds
57
+ )
58
+
59
+ def start_submission(self) -> bool:
60
+ with self.lock:
61
+ if not self.can_submit_guess():
62
+ return False
63
+ self.processing_submission = True
64
+ return True
65
+
66
+ def finish_submission(self, results: GameResults):
67
+ with self.lock:
68
+ if results.user_guess == results.correct_answer:
69
+ self.user_score += 1
70
+ if results.model_guess == results.correct_answer:
71
+ self.model_score += 1
72
+
73
+ self.last_results = results
74
+ self.current_round += 1
75
+ self.processing_submission = False
76
+
77
+ if self.current_round >= self.total_rounds:
78
+ self.is_game_active = False
79
+
80
+ def get_current_image(self) -> Optional[str]:
81
+ with self.lock:
82
+ if not self.is_game_active or self.current_round >= len(self.game_images):
83
+ return None
84
+ return self.game_images[self.current_round]
85
+
86
+ def get_game_over_message(self) -> str:
87
+ with self.lock:
88
+ if self.user_score > self.model_score:
89
+ return """
90
+ <div style='text-align: center; margin-top: 20px; font-size: 1.2em;'>
91
+ 🎉 Congratulations! You won! 🎉<br>
92
+ You've outperformed SuSy in detecting AI-generated images.<br>
93
+ Click 'Start New Game' to play again.
94
+ </div>
95
+ """
96
+ elif self.user_score < self.model_score:
97
+ return """
98
+ <div style='text-align: center; margin-top: 20px; font-size: 1.2em;'>
99
+ Better luck next time! SuSy won this round.<br>
100
+ Keep practicing to improve your detection skills.<br>
101
+ Click 'Start New Game' to try again.
102
+ </div>
103
+ """
104
+ else:
105
+ return """
106
+ <div style='text-align: center; margin-top: 20px; font-size: 1.2em;'>
107
+ It's a tie! You matched SuSy's performance!<br>
108
+ You're getting good at this.<br>
109
+ Click 'Start New Game' to play again.
110
+ </div>
111
+ """
112
+
113
  def process_image(image):
114
  # Set Parameters
115
  top_k_patches = 5
 
163
 
164
  return sorted_probs
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  game_state = GameState()
167
 
168
  def load_images():
 
175
  return selected_images
176
 
177
  def create_score_html():
178
+ with game_state.lock:
179
+ results_html = ""
180
+ if game_state.last_results:
181
+ results_html = f"""
182
+ <div style='margin-top: 1rem; padding: 1rem; background-color: #e0e0e0; border-radius: 8px; color: #333;'>
183
+ <h4 style='color: #333; margin-bottom: 0.5rem;'>Last Round Results:</h4>
184
+ <p style='color: #333;'>Your guess: {game_state.last_results.user_guess}</p>
185
+ <p style='color: #333;'>Model's guess: {game_state.last_results.model_guess}</p>
186
+ <p style='color: #333;'>Correct answer: {game_state.last_results.correct_answer}</p>
187
+ </div>
188
+ """
189
 
190
+ current_display_round = min(game_state.current_round + 1, game_state.total_rounds)
191
+
192
+ return f"""
193
+ <div style='padding: 1rem; background-color: #f0f0f0; border-radius: 8px; color: #333;'>
194
+ <h3 style='margin-bottom: 1rem; color: #333;'>Score Board</h3>
195
+ <div style='display: flex; justify-content: space-around;'>
196
+ <div>
197
+ <h4 style='color: #333;'>You</h4>
198
+ <p style='font-size: 1.5rem; color: #333;'>{game_state.user_score}</p>
199
+ </div>
200
+ <div>
201
+ <h4 style='color: #333;'>AI Model</h4>
202
+ <p style='font-size: 1.5rem; color: #333;'>{game_state.model_score}</p>
203
+ </div>
204
  </div>
205
+ <div style='margin-top: 1rem;'>
206
+ <p style='color: #333;'>Round: {current_display_round}/{game_state.total_rounds}</p>
 
207
  </div>
208
+ {results_html}
209
  </div>
210
+ """
 
 
 
 
 
211
 
212
  def start_game():
213
+ if not game_state.start_new_game():
214
+ return [gr.update()] * 6
215
+
216
+ current_image = Image.open(game_state.get_current_image())
 
217
 
218
  return (
219
  gr.update(value=current_image, visible=True),
 
224
  gr.update(visible=False)
225
  )
226
 
227
+ def submit_guess(user_guess: str):
228
+ # Early return if we can't submit a guess
229
+ if not game_state.can_submit_guess():
230
  return [gr.update()] * 6
231
 
232
+ # Mark submission as being processed
233
+ if not game_state.start_submission():
234
+ return [gr.update()] * 6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
+ try:
237
+ # Get current image and process it
238
+ current_image_path = game_state.get_current_image()
239
+ if not current_image_path:
240
+ return [gr.update()] * 6
241
+
242
+ current_image = Image.open(current_image_path)
243
+ model_prediction = process_image(current_image)
244
+ model_guess = "Real" if model_prediction['Authentic'] > PROB_THRESHOLD else "Fake"
245
+ correct_answer = "Real" if "real_images" in current_image_path else "Fake"
246
+
247
+ # Update game state with results
248
+ results = GameResults(user_guess, model_guess, correct_answer)
249
+ game_state.finish_submission(results)
250
+
251
+ # Check if game is over
252
+ if not game_state.is_game_active:
253
+ return (
254
+ gr.update(value=None, visible=False),
255
+ gr.update(visible=True),
256
+ gr.update(visible=False),
257
+ gr.update(visible=False),
258
+ create_score_html(),
259
+ gr.update(visible=True, value=game_state.get_game_over_message())
260
+ )
261
+
262
+ # Get next image for the next round
263
+ next_image_path = game_state.get_current_image()
264
+ if not next_image_path:
265
+ return [gr.update()] * 6
266
+
267
+ next_image = Image.open(next_image_path)
268
+
269
  return (
270
+ gr.update(value=next_image, visible=True),
 
 
271
  gr.update(visible=False),
272
+ gr.update(visible=True, interactive=True),
273
+ gr.update(visible=True, interactive=True),
274
  create_score_html(),
275
+ gr.update(visible=False)
276
  )
277
+ except Exception as e:
278
+ # If any error occurs, reset the processing flag
279
+ game_state.processing_submission = False
280
+ raise e
 
 
 
 
 
 
 
281
 
282
  # Custom CSS
283
  custom_css = """
 
300
  min-width: 120px;
301
  }
302
  .image-container img {
303
+ max-height: 640px !important;
304
  width: auto !important;
305
  object-fit: contain !important;
306
  }