pabberpe commited on
Commit
ed5b21a
1 Parent(s): 4b40bf3

Rewrite Game State handling

Browse files
Files changed (1) hide show
  1. app.py +101 -96
app.py CHANGED
@@ -1,9 +1,7 @@
1
  import os
2
  import random
3
  from dataclasses import dataclass
4
- from threading import Lock
5
  from typing import List, Optional
6
-
7
  import gradio as gr
8
  import numpy as np
9
  import torch
@@ -25,8 +23,6 @@ class GameResults:
25
 
26
  class GameState:
27
  def __init__(self):
28
- # Lock only used for round transitions
29
- self.round_lock = Lock()
30
  self.reset()
31
 
32
  def reset(self):
@@ -55,25 +51,23 @@ class GameState:
55
  )
56
 
57
  def start_submission(self) -> bool:
58
- with self.round_lock:
59
- if not self.can_submit_guess():
60
- return False
61
- self.processing_submission = True
62
- return True
63
 
64
  def finish_submission(self, results: GameResults):
65
- with self.round_lock:
66
- if results.user_guess == results.correct_answer:
67
- self.user_score += 1
68
- if results.model_guess == results.correct_answer:
69
- self.model_score += 1
70
-
71
- self.last_results = results
72
- self.current_round += 1
73
- self.processing_submission = False
74
-
75
- if self.current_round >= self.total_rounds:
76
- self.is_game_active = False
77
 
78
  def get_current_image(self) -> Optional[str]:
79
  if not self.is_game_active or self.current_round >= len(self.game_images):
@@ -106,6 +100,40 @@ class GameState:
106
  </div>
107
  """
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  def process_image(image):
110
  # Set Parameters
111
  top_k_patches = 5
@@ -159,8 +187,6 @@ def process_image(image):
159
 
160
  return sorted_probs
161
 
162
- game_state = GameState()
163
-
164
  def load_images() -> List[str]:
165
  real_image_folder = "real_images"
166
  fake_image_folder = "fake_images"
@@ -170,69 +196,37 @@ def load_images() -> List[str]:
170
  random.shuffle(selected_images)
171
  return selected_images
172
 
173
- def create_score_html() -> str:
174
- results_html = ""
175
- if game_state.last_results:
176
- results_html = f"""
177
- <div style='margin-top: 1rem; padding: 1rem; background-color: #e0e0e0; border-radius: 8px; color: #333;'>
178
- <h4 style='color: #333; margin-bottom: 0.5rem;'>Last Round Results:</h4>
179
- <p style='color: #333;'>Your guess: {game_state.last_results.user_guess}</p>
180
- <p style='color: #333;'>Model's guess: {game_state.last_results.model_guess}</p>
181
- <p style='color: #333;'>Correct answer: {game_state.last_results.correct_answer}</p>
182
- </div>
183
- """
184
-
185
- current_display_round = min(game_state.current_round + 1, game_state.total_rounds)
186
 
187
- return f"""
188
- <div style='padding: 1rem; background-color: #f0f0f0; border-radius: 8px; color: #333;'>
189
- <h3 style='margin-bottom: 1rem; color: #333;'>Score Board</h3>
190
- <div style='display: flex; justify-content: space-around;'>
191
- <div>
192
- <h4 style='color: #333;'>You</h4>
193
- <p style='font-size: 1.5rem; color: #333;'>{game_state.user_score}</p>
194
- </div>
195
- <div>
196
- <h4 style='color: #333;'>AI Model</h4>
197
- <p style='font-size: 1.5rem; color: #333;'>{game_state.model_score}</p>
198
- </div>
199
- </div>
200
- <div style='margin-top: 1rem;'>
201
- <p style='color: #333;'>Round: {current_display_round}/{game_state.total_rounds}</p>
202
- </div>
203
- {results_html}
204
- </div>
205
- """
206
-
207
- def start_game():
208
- if not game_state.start_new_game():
209
- return [gr.update()] * 6
210
 
211
- current_image = Image.open(game_state.get_current_image())
212
 
213
- return (
 
214
  gr.update(value=current_image, visible=True),
215
  gr.update(visible=False),
216
  gr.update(visible=True, interactive=True),
217
  gr.update(visible=True, interactive=True),
218
- create_score_html(),
219
- gr.update(visible=False)
220
- )
221
 
222
- def submit_guess(user_guess: str):
223
- # Early return if we can't submit a guess
224
- if not game_state.can_submit_guess():
225
- return [gr.update()] * 6
226
 
227
- # Mark submission as being processed
228
- if not game_state.start_submission():
229
- return [gr.update()] * 6
230
 
231
  try:
232
- # Get current image and process it
233
- current_image_path = game_state.get_current_image()
234
  if not current_image_path:
235
- return [gr.update()] * 6
236
 
237
  current_image = Image.open(current_image_path)
238
  model_prediction = process_image(current_image)
@@ -241,37 +235,38 @@ def submit_guess(user_guess: str):
241
 
242
  # Update game state with results
243
  results = GameResults(user_guess, model_guess, correct_answer)
244
- game_state.finish_submission(results)
245
 
246
  # Check if game is over
247
- if not game_state.is_game_active:
248
- return (
 
249
  gr.update(value=None, visible=False),
250
  gr.update(visible=True),
251
  gr.update(visible=False),
252
  gr.update(visible=False),
253
- create_score_html(),
254
- gr.update(visible=True, value=game_state.get_game_over_message())
255
- )
256
 
257
- # Get next image for the next round
258
- next_image_path = game_state.get_current_image()
259
  if not next_image_path:
260
- return [gr.update()] * 6
261
-
262
  next_image = Image.open(next_image_path)
263
 
264
- return (
 
265
  gr.update(value=next_image, visible=True),
266
  gr.update(visible=False),
267
  gr.update(visible=True, interactive=True),
268
  gr.update(visible=True, interactive=True),
269
- create_score_html(),
270
  gr.update(visible=False)
271
- )
 
272
  except Exception as e:
273
- # If any error occurs, reset the processing flag
274
- game_state.processing_submission = False
275
  raise e
276
 
277
  # Custom CSS
@@ -303,6 +298,9 @@ custom_css = """
303
 
304
  # Define Gradio interface
305
  with gr.Blocks(css=custom_css) as iface:
 
 
 
306
  with gr.Column(elem_id="game-container"):
307
  gr.HTML("""
308
  <table style="border-collapse: collapse; border: none; padding: 20px;">
@@ -359,7 +357,9 @@ with gr.Blocks(css=custom_css) as iface:
359
  # Event handlers
360
  start_button.click(
361
  fn=start_game,
 
362
  outputs=[
 
363
  image_display,
364
  start_button,
365
  real_button,
@@ -370,28 +370,33 @@ with gr.Blocks(css=custom_css) as iface:
370
  )
371
 
372
  real_button.click(
373
- fn=lambda: submit_guess("Real"),
 
374
  outputs=[
 
375
  image_display,
376
  start_button,
377
  real_button,
378
  fake_button,
379
  score_display,
380
- feedback_display
381
- ]
382
  )
383
-
384
  fake_button.click(
385
- fn=lambda: submit_guess("Fake"),
 
386
  outputs=[
 
387
  image_display,
388
  start_button,
389
  real_button,
390
  fake_button,
391
  score_display,
392
- feedback_display
393
- ]
394
  )
395
 
 
396
  # Launch the interface
397
  iface.launch()
 
1
  import os
2
  import random
3
  from dataclasses import dataclass
 
4
  from typing import List, Optional
 
5
  import gradio as gr
6
  import numpy as np
7
  import torch
 
23
 
24
  class GameState:
25
  def __init__(self):
 
 
26
  self.reset()
27
 
28
  def reset(self):
 
51
  )
52
 
53
  def start_submission(self) -> bool:
54
+ if not self.can_submit_guess():
55
+ return False
56
+ self.processing_submission = True
57
+ return True
 
58
 
59
  def finish_submission(self, results: GameResults):
60
+ if results.user_guess == results.correct_answer:
61
+ self.user_score += 1
62
+ if results.model_guess == results.correct_answer:
63
+ self.model_score += 1
64
+
65
+ self.last_results = results
66
+ self.current_round += 1
67
+ self.processing_submission = False
68
+
69
+ if self.current_round >= self.total_rounds:
70
+ self.is_game_active = False
 
71
 
72
  def get_current_image(self) -> Optional[str]:
73
  if not self.is_game_active or self.current_round >= len(self.game_images):
 
100
  </div>
101
  """
102
 
103
+ def create_score_html(game_state: GameState):
104
+ results_html = ""
105
+ if game_state.last_results:
106
+ results_html = f"""
107
+ <div style='margin-top: 1rem; padding: 1rem; background-color: #e0e0e0; border-radius: 8px; color: #333;'>
108
+ <h4 style='color: #333; margin-bottom: 0.5rem;'>Last Round Results:</h4>
109
+ <p style='color: #333;'>Your guess: {game_state.last_results.user_guess}</p>
110
+ <p style='color: #333;'>Model's guess: {game_state.last_results.model_guess}</p>
111
+ <p style='color: #333;'>Correct answer: {game_state.last_results.correct_answer}</p>
112
+ </div>
113
+ """
114
+
115
+ current_display_round = min(game_state.current_round + 1, game_state.total_rounds)
116
+
117
+ return f"""
118
+ <div style='padding: 1rem; background-color: #f0f0f0; border-radius: 8px; color: #333;'>
119
+ <h3 style='margin-bottom: 1rem; color: #333;'>Score Board</h3>
120
+ <div style='display: flex; justify-content: space-around;'>
121
+ <div>
122
+ <h4 style='color: #333;'>You</h4>
123
+ <p style='font-size: 1.5rem; color: #333;'>{game_state.user_score}</p>
124
+ </div>
125
+ <div>
126
+ <h4 style='color: #333;'>AI Model</h4>
127
+ <p style='font-size: 1.5rem; color: #333;'>{game_state.model_score}</p>
128
+ </div>
129
+ </div>
130
+ <div style='margin-top: 1rem;'>
131
+ <p style='color: #333;'>Round: {current_display_round}/{game_state.total_rounds}</p>
132
+ </div>
133
+ {results_html}
134
+ </div>
135
+ """
136
+
137
  def process_image(image):
138
  # Set Parameters
139
  top_k_patches = 5
 
187
 
188
  return sorted_probs
189
 
 
 
190
  def load_images() -> List[str]:
191
  real_image_folder = "real_images"
192
  fake_image_folder = "fake_images"
 
196
  random.shuffle(selected_images)
197
  return selected_images
198
 
199
+ def start_game(state: Optional[GameState]):
200
+ # Initialize new game state if none exists
201
+ if state is None:
202
+ state = GameState()
 
 
 
 
 
 
 
 
 
203
 
204
+ if not state.start_new_game():
205
+ return [state] + [gr.update()] * 6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
+ current_image = Image.open(state.get_current_image())
208
 
209
+ return [
210
+ state,
211
  gr.update(value=current_image, visible=True),
212
  gr.update(visible=False),
213
  gr.update(visible=True, interactive=True),
214
  gr.update(visible=True, interactive=True),
215
+ create_score_html(state),
216
+ gr.update(visible=False),
217
+ ]
218
 
219
+ def submit_guess(user_guess: str, state: GameState):
220
+ if not state.can_submit_guess():
221
+ return [state] + [gr.update()] * 6
 
222
 
223
+ if not state.start_submission():
224
+ return [state] + [gr.update()] * 6
 
225
 
226
  try:
227
+ current_image_path = state.get_current_image()
 
228
  if not current_image_path:
229
+ return [state] + [gr.update()] * 6
230
 
231
  current_image = Image.open(current_image_path)
232
  model_prediction = process_image(current_image)
 
235
 
236
  # Update game state with results
237
  results = GameResults(user_guess, model_guess, correct_answer)
238
+ state.finish_submission(results)
239
 
240
  # Check if game is over
241
+ if not state.is_game_active:
242
+ return [
243
+ state,
244
  gr.update(value=None, visible=False),
245
  gr.update(visible=True),
246
  gr.update(visible=False),
247
  gr.update(visible=False),
248
+ create_score_html(state),
249
+ gr.update(visible=True, value=state.get_game_over_message())
250
+ ]
251
 
252
+ # Get next image
253
+ next_image_path = state.get_current_image()
254
  if not next_image_path:
255
+ return [state] + [gr.update()] * 6
 
256
  next_image = Image.open(next_image_path)
257
 
258
+ return [
259
+ state,
260
  gr.update(value=next_image, visible=True),
261
  gr.update(visible=False),
262
  gr.update(visible=True, interactive=True),
263
  gr.update(visible=True, interactive=True),
264
+ create_score_html(state),
265
  gr.update(visible=False)
266
+ ]
267
+
268
  except Exception as e:
269
+ state.processing_submission = False
 
270
  raise e
271
 
272
  # Custom CSS
 
298
 
299
  # Define Gradio interface
300
  with gr.Blocks(css=custom_css) as iface:
301
+ # State variable for the game
302
+ state = gr.State(None)
303
+
304
  with gr.Column(elem_id="game-container"):
305
  gr.HTML("""
306
  <table style="border-collapse: collapse; border: none; padding: 20px;">
 
357
  # Event handlers
358
  start_button.click(
359
  fn=start_game,
360
+ inputs=[state],
361
  outputs=[
362
+ state,
363
  image_display,
364
  start_button,
365
  real_button,
 
370
  )
371
 
372
  real_button.click(
373
+ fn=lambda state: submit_guess("Real", state),
374
+ inputs=[state],
375
  outputs=[
376
+ state,
377
  image_display,
378
  start_button,
379
  real_button,
380
  fake_button,
381
  score_display,
382
+ feedback_display,
383
+ ],
384
  )
385
+
386
  fake_button.click(
387
+ fn=lambda state: submit_guess("Fake", state),
388
+ inputs=[state],
389
  outputs=[
390
+ state,
391
  image_display,
392
  start_button,
393
  real_button,
394
  fake_button,
395
  score_display,
396
+ feedback_display,
397
+ ],
398
  )
399
 
400
+
401
  # Launch the interface
402
  iface.launch()