VOIDER commited on
Commit
dcbe972
·
verified ·
1 Parent(s): d924e11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -137
app.py CHANGED
@@ -187,75 +187,68 @@ def predict_anime_aesthetic(img, model):
187
  # Image Evaluation Tool #
188
  #####################################
189
 
190
- class ImageEvaluationTool:
191
- """Evaluation tool to process images through multiple aesthetic models and generate logs and HTML outputs."""
192
  def __init__(self):
193
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
194
  print(f"Using device: {self.device}")
195
  print("Loading models... This may take some time.")
196
 
197
- # Load models with progress logs
198
  print("Loading Aesthetic Shadow model...")
199
- self.aesthetic_shadow = pipeline("image-classification", model="NeoChen1024/aesthetic-shadow-v2-backup", device=self.device)
200
  print("Loading Waifu Scorer model...")
201
- self.waifu_scorer = WaifuScorer(device=self.device, verbose=True)
202
  print("Loading Aesthetic Predictor V2.5...")
203
- self.aesthetic_predictor = load_aesthetic_predictor_v2_5()
204
  print("Loading Anime Aesthetic model...")
205
- self.anime_aesthetic = load_anime_aesthetic_model()
206
  print("All models loaded successfully!")
207
 
208
- self.temp_dir = tempfile.mkdtemp()
209
- self.results = [] # Store final results for sorting and display
210
  self.available_models = {
211
- "aesthetic_shadow": {"name": "Aesthetic Shadow", "process": self._process_aesthetic_shadow},
212
- "waifu_scorer": {"name": "Waifu Scorer", "process": self._process_waifu_scorer},
213
- "aesthetic_predictor_v2_5": {"name": "Aesthetic V2.5", "process": self._process_aesthetic_predictor_v2_5},
214
- "anime_aesthetic": {"name": "Anime Score", "process": self._process_anime_aesthetic},
215
  }
 
 
 
216
 
217
- def image_to_base64(self, image: Image.Image) -> str:
218
- """Convert PIL Image to base64 encoded JPEG string."""
219
- buffered = BytesIO()
220
- image.save(buffered, format="JPEG")
221
- return base64.b64encode(buffered.getvalue()).decode('utf-8')
222
-
223
- def auto_tune_batch_size(self, images: list) -> int:
224
- """Automatically determine the optimal batch size for processing."""
225
- batch_size = 1
226
- max_batch = len(images)
227
- test_image = images[0:1]
228
- while batch_size <= max_batch:
229
- try:
230
- if "aesthetic_shadow" in self.available_models and self.available_models["aesthetic_shadow"]['selected']: # Check if model is available and selected
231
- _ = self.aesthetic_shadow(test_image * batch_size)
232
- if "waifu_scorer" in self.available_models and self.available_models["waifu_scorer"]['selected']: # Check if model is available and selected
233
- _ = self.waifu_scorer(test_image * batch_size)
234
- if "aesthetic_predictor_v2_5" in self.available_models and self.available_models["aesthetic_predictor_v2_5"]['selected']: # Check if model is available and selected
235
- _ = self.aesthetic_predictor.inference(test_image * batch_size)
236
- batch_size *= 2
237
- if batch_size > max_batch:
238
- break
239
- except Exception:
240
  break
241
- optimal = max(1, batch_size // 2)
242
- if optimal > 64:
243
- optimal = 64
244
- print("Capped optimal batch size to 64")
245
- print(f"Optimal batch size determined: {optimal}")
246
- return optimal
247
-
248
- async def process_images_evaluation_with_logs(self, file_paths: list, auto_batch: bool, manual_batch_size: int, selected_models):
249
- """Asynchronously process images and yield updates with logs, HTML table, and progress bar."""
250
- self.results = []
 
 
 
 
 
 
 
 
 
 
 
251
  log_events = []
252
  images = []
253
  file_names = []
 
254
 
255
- # Update available models based on selection
256
- for model_key in self.available_models:
257
- self.available_models[model_key]['selected'] = model_key in selected_models
258
-
259
  total_files = len(file_paths)
260
  log_events.append(f"Starting to load {total_files} images...")
261
  for f in file_paths:
@@ -268,14 +261,9 @@ class ImageEvaluationTool:
268
 
269
  if not images:
270
  log_events.append("No valid images loaded.")
271
- progress_percentage = 0 # Define progress_percentage here for no images case
272
- progress_html = self._generate_progress_html(progress_percentage)
273
- yield ("<p>No images loaded.</p>", "", self._format_logs(log_events), progress_html, manual_batch_size)
274
- return
275
 
276
- yield ("<p>Images loaded. Determining batch size...</p>", "", self._format_logs(log_events),
277
- self._generate_progress_html(0), manual_batch_size)
278
- await asyncio.sleep(0.1)
279
 
280
  try:
281
  manual_batch_size = int(manual_batch_size) if manual_batch_size is not None else 1
@@ -285,9 +273,6 @@ class ImageEvaluationTool:
285
 
286
  optimal_batch = self.auto_tune_batch_size(images) if auto_batch else manual_batch_size
287
  log_events.append(f"Using batch size: {optimal_batch}")
288
- yield ("<p>Processing images in batches...</p>", "", self._format_logs(log_events),
289
- self._generate_progress_html(0), optimal_batch)
290
- await asyncio.sleep(0.1)
291
 
292
  total_images = len(images)
293
  for i in range(0, total_images, optimal_batch):
@@ -298,36 +283,18 @@ class ImageEvaluationTool:
298
 
299
  batch_results = {}
300
 
301
- # Aesthetic Shadow processing
302
- if self.available_models['aesthetic_shadow']['selected']:
303
- batch_results['aesthetic_shadow'] = await self._process_aesthetic_shadow(batch_images, log_events)
304
- else:
305
- batch_results['aesthetic_shadow'] = [None] * len(batch_images)
306
-
307
- # Waifu Scorer processing
308
- if self.available_models['waifu_scorer']['selected']:
309
- batch_results['waifu_scorer'] = await self._process_waifu_scorer(batch_images, log_events)
310
- else:
311
- batch_results['waifu_scorer'] = [None] * len(batch_images)
312
-
313
- # Aesthetic Predictor V2.5 processing
314
- if self.available_models['aesthetic_predictor_v2_5']['selected']:
315
- batch_results['aesthetic_predictor_v2_5'] = await self._process_aesthetic_predictor_v2_5(batch_images, log_events)
316
- else:
317
- batch_results['aesthetic_predictor_v2_5'] = [None] * len(batch_images)
318
-
319
- # Anime Aesthetic processing (single image)
320
- if self.available_models['anime_aesthetic']['selected']:
321
- batch_results['anime_aesthetic'] = await self._process_anime_aesthetic(batch_images, log_events)
322
- else:
323
- batch_results['anime_aesthetic'] = [None] * len(batch_images)
324
 
325
-
326
- # Combine results
327
  for j in range(len(batch_images)):
328
  scores_to_average = []
329
- for model_key in self.available_models:
330
- if self.available_models[model_key]['selected']: # Only consider selected models
331
  score = batch_results[model_key][j]
332
  if score is not None:
333
  scores_to_average.append(score)
@@ -340,28 +307,49 @@ class ImageEvaluationTool:
340
  'img_data': self.image_to_base64(thumbnail), # Keep this for the HTML display
341
  'final_score': final_score,
342
  }
343
- for model_key in self.available_models: # Add model scores to result
344
  if self.available_models[model_key]['selected']:
345
  result[model_key] = batch_results[model_key][j]
 
346
 
347
- self.results.append(result)
348
- self.sort_results() # Sort results after adding new result
349
- progress_percentage = min(100, ((i + len(batch_images)) / total_images) * 100) # Define progress_percentage here
350
- yield (f"<p>Processed batch {batch_index}.</p>", self.generate_html_table(self.results, selected_models), # Update table immediately
351
- self._format_logs(log_events[-10:]), self._generate_progress_html(progress_percentage), optimal_batch)
352
- await asyncio.sleep(0.1)
353
 
354
 
355
- log_events.append("All images processed.")
356
- self.sort_results() # Final sort after all images processed
357
- html_table = self.generate_html_table(self.results, selected_models) # Pass selected models to final table generation
358
- final_progress = self._generate_progress_html(100)
359
- yield ("<p>All images processed.</p>", html_table,
360
- self._format_logs(log_events[-10:]), final_progress, optimal_batch)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
362
  async def _process_aesthetic_shadow(self, batch_images, log_events):
363
  try:
364
- shadow_results = self.aesthetic_shadow(batch_images)
365
  log_events.append("Aesthetic Shadow processed for batch.")
366
  except Exception as e:
367
  log_events.append(f"Error in Aesthetic Shadow: {e}")
@@ -379,7 +367,7 @@ class ImageEvaluationTool:
379
 
380
  async def _process_waifu_scorer(self, batch_images, log_events):
381
  try:
382
- waifu_scores = self.waifu_scorer(batch_images)
383
  waifu_scores = [float(np.clip(s, 0.0, 10.0)) if s is not None else None for s in waifu_scores]
384
  log_events.append("Waifu Scorer processed for batch.")
385
  except Exception as e:
@@ -389,7 +377,7 @@ class ImageEvaluationTool:
389
 
390
  async def _process_aesthetic_predictor_v2_5(self, batch_images, log_events):
391
  try:
392
- v2_5_scores = self.aesthetic_predictor.inference(batch_images)
393
  v2_5_scores = [float(np.round(np.clip(s, 0.0, 10.0), 4)) if s is not None else None for s in v2_5_scores]
394
  log_events.append("Aesthetic Predictor V2.5 processed for batch.")
395
  except Exception as e:
@@ -401,7 +389,7 @@ class ImageEvaluationTool:
401
  anime_scores = []
402
  for j, img in enumerate(batch_images):
403
  try:
404
- score = predict_anime_aesthetic(img, self.anime_aesthetic)
405
  anime_scores.append(float(np.clip(score * 10.0, 0.0, 10.0)))
406
  log_events.append(f"Anime Aesthetic processed for image {j + 1}.")
407
  except Exception as e:
@@ -424,7 +412,7 @@ class ImageEvaluationTool:
424
  """Format log events into an HTML string."""
425
  return "<div style='max-height:300px; overflow-y:auto;'>" + "<br>".join(logs) + "</div>"
426
 
427
- def sort_results(self, sort_by: str = "Final Score") -> list:
428
  """Sort results based on the specified column."""
429
  key_map = {
430
  "Final Score": "final_score",
@@ -436,8 +424,8 @@ class ImageEvaluationTool:
436
  }
437
  key = key_map.get(sort_by, "final_score")
438
  reverse = sort_by != "File Name"
439
- self.results.sort(key=lambda r: r.get(key) if r.get(key) is not None else (-float('inf') if not reverse else float('inf')), reverse=reverse)
440
- return self.results
441
 
442
  def generate_html_table(self, results: list, selected_models) -> str:
443
  """Generate an HTML table to display the evaluation results."""
@@ -503,17 +491,25 @@ class ImageEvaluationTool:
503
 
504
 
505
  def cleanup(self):
506
- """Clean up temporary directories."""
507
  if os.path.exists(self.temp_dir):
508
  shutil.rmtree(self.temp_dir)
 
 
 
 
 
 
 
509
 
510
 
511
  #####################################
512
  # Interface #
513
  #####################################
514
 
 
 
515
  def create_interface():
516
- evaluator = ImageEvaluationTool()
517
  sort_options = ["Final Score", "File Name", "Aesthetic Shadow", "Waifu Scorer", "Aesthetic V2.5", "Anime Score"]
518
  model_options = ["aesthetic_shadow", "waifu_scorer", "aesthetic_predictor_v2_5", "anime_aesthetic"]
519
 
@@ -556,12 +552,13 @@ def create_interface():
556
  status_html = gr.HTML(label="Status")
557
  output_html = gr.HTML(label="Evaluation Results")
558
  download_file_output = gr.File() # Initialize gr.File component without filename
 
559
 
560
  # Function to convert results to CSV format, excluding 'img_data'.
561
- def results_to_csv(selected_models):
562
  import csv
563
  import io
564
- if not evaluator.results:
565
  return None # Return None when no results are available
566
  output = io.StringIO()
567
  fieldnames = ['file_name', 'final_score'] # Base fieldnames
@@ -571,7 +568,7 @@ def create_interface():
571
 
572
  writer = csv.DictWriter(output, fieldnames=fieldnames)
573
  writer.writeheader()
574
- for res in evaluator.results:
575
  row_dict = {'file_name': res['file_name'], 'final_score': res['final_score']} # Base data
576
  for model_key in selected_models: # Add selected model scores
577
  if model_key in selected_models: # Double check before accessing res[model_key]
@@ -583,33 +580,49 @@ def create_interface():
583
  def update_batch_size_interactivity(auto_batch):
584
  return gr.update(interactive=not auto_batch)
585
 
586
- async def process_images_and_update(files, auto_batch, manual_batch, selected_models):
587
  file_paths = [f.name for f in files]
588
- async for status, table, logs, progress, updated_batch in evaluator.process_images_evaluation_with_logs(file_paths, auto_batch, manual_batch, selected_models):
589
- yield status, table, logs, progress, gr.update(value=updated_batch, interactive=not auto_batch)
590
 
591
- def update_table_sort(sort_by_column, selected_models):
592
- sorted_results = evaluator.sort_results(sort_by_column)
593
- return evaluator.generate_html_table(sorted_results, selected_models)
 
 
 
 
 
 
 
 
 
 
 
 
594
 
595
- def update_table_model_selection(selected_models):
 
 
 
 
 
 
 
596
  # Recalculate final scores based on selected models
597
- for result in evaluator.results:
598
  scores_to_average = []
599
- for model_key in evaluator.available_models:
600
- if model_key in selected_models and evaluator.available_models[model_key]['selected']: # consider only selected models from checkbox group and available_models
601
  score = result.get(model_key)
602
  if score is not None:
603
  scores_to_average.append(score)
604
  final_score = float(np.clip(np.mean(scores_to_average), 0.0, 10.0)) if scores_to_average else None
605
  result['final_score'] = final_score
606
 
607
- sorted_results = evaluator.sort_results() # Keep sorting by Final Score when models change
608
- return evaluator.generate_html_table(sorted_results, selected_models)
609
 
610
 
611
  def clear_results():
612
- evaluator.results = []
613
  return (gr.update(value=""),
614
  gr.update(value=""),
615
  gr.update(value=""),
@@ -618,10 +631,11 @@ def create_interface():
618
  <div style='width:0%;background-color:#4CAF50;padding:5px 0;text-align:center;'>0%</div>
619
  </div>
620
  """),
621
- gr.update(value=1))
 
622
 
623
- def download_results_csv_trigger(selected_models): # Changed function name to avoid conflict and clarify purpose
624
- csv_content = results_to_csv(selected_models)
625
  if csv_content is None:
626
  return None # Indicate no file to download
627
 
@@ -633,6 +647,10 @@ def create_interface():
633
  return temp_file_path # Return the path to the temporary file
634
 
635
 
 
 
 
 
636
  auto_batch_checkbox.change(
637
  update_batch_size_interactivity,
638
  inputs=[auto_batch_checkbox],
@@ -641,30 +659,30 @@ def create_interface():
641
 
642
  process_btn.click(
643
  process_images_and_update,
644
- inputs=[input_images, auto_batch_checkbox, batch_size_input, model_checkboxes],
645
- outputs=[status_html, output_html, log_window, progress_bar, batch_size_input]
646
  )
647
  sort_dropdown.change(
648
  update_table_sort,
649
- inputs=[sort_dropdown, model_checkboxes],
650
- outputs=[output_html]
651
  )
652
  model_checkboxes.change( # Added change event for model checkboxes
653
  update_table_model_selection,
654
- inputs=[model_checkboxes],
655
- outputs=[output_html]
656
  )
657
  clear_btn.click(
658
  clear_results,
659
  inputs=[],
660
- outputs=[status_html, output_html, log_window, progress_bar, batch_size_input]
661
  )
662
  download_csv.click(
663
  download_results_csv_trigger, # Call the trigger function
664
- inputs=[model_checkboxes],
665
  outputs=[download_file_output] # Output is now the gr.File component
666
  )
667
- demo.load(lambda: update_table_sort("Final Score", model_options), inputs=None, outputs=[output_html]) # Initial sort and table render
668
  gr.Markdown("""
669
  ### Notes
670
  - Select models to use for evaluation using the checkboxes.
 
187
  # Image Evaluation Tool #
188
  #####################################
189
 
190
+ class ModelManager:
191
+ """Manages model loading and processing requests using a queue."""
192
  def __init__(self):
193
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
194
  print(f"Using device: {self.device}")
195
  print("Loading models... This may take some time.")
196
 
197
+ # Load models once during initialization
198
  print("Loading Aesthetic Shadow model...")
199
+ self.aesthetic_shadow_model = pipeline("image-classification", model="NeoChen1024/aesthetic-shadow-v2-backup", device=self.device)
200
  print("Loading Waifu Scorer model...")
201
+ self.waifu_scorer_model = WaifuScorer(device=self.device, verbose=True)
202
  print("Loading Aesthetic Predictor V2.5...")
203
+ self.aesthetic_predictor_model = load_aesthetic_predictor_v2_5()
204
  print("Loading Anime Aesthetic model...")
205
+ self.anime_aesthetic_model = load_anime_aesthetic_model()
206
  print("All models loaded successfully!")
207
 
 
 
208
  self.available_models = {
209
+ "aesthetic_shadow": {"name": "Aesthetic Shadow", "process": self._process_aesthetic_shadow, "model": self.aesthetic_shadow_model},
210
+ "waifu_scorer": {"name": "Waifu Scorer", "process": self._process_waifu_scorer, "model": self.waifu_scorer_model},
211
+ "aesthetic_predictor_v2_5": {"name": "Aesthetic V2.5", "process": self._process_aesthetic_predictor_v2_5, "model": self.aesthetic_predictor_model},
212
+ "anime_aesthetic": {"name": "Anime Score", "process": self._process_anime_aesthetic, "model": self.anime_aesthetic_model},
213
  }
214
+ self.processing_queue: asyncio.Queue = asyncio.Queue()
215
+ self.worker_task = asyncio.create_task(self._worker())
216
+ self.temp_dir = tempfile.mkdtemp()
217
 
218
+ async def _worker(self):
219
+ """Background worker to process image evaluation requests from the queue."""
220
+ while True:
221
+ request = await self.processing_queue.get()
222
+ if request is None: # Shutdown signal
223
+ self.processing_queue.task_done()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  break
225
+ try:
226
+ results = await self._process_request(request)
227
+ request['results_future'].set_result(results) # Fulfill the future with results
228
+ except Exception as e:
229
+ request['results_future'].set_exception(e) # Set exception if processing fails
230
+ finally:
231
+ self.processing_queue.task_done()
232
+
233
+ async def submit_request(self, request_data):
234
+ """Submit a new image processing request to the queue."""
235
+ results_future = asyncio.Future() # Future to hold the results
236
+ request = {**request_data, 'results_future': results_future}
237
+ await self.processing_queue.put(request)
238
+ return await results_future # Wait for and return results
239
+
240
+ async def _process_request(self, request):
241
+ """Process a single image evaluation request."""
242
+ file_paths = request['file_paths']
243
+ auto_batch = request['auto_batch']
244
+ manual_batch_size = request['manual_batch_size']
245
+ selected_models = request['selected_models']
246
  log_events = []
247
  images = []
248
  file_names = []
249
+ final_results = []
250
 
251
+ # Prepare images and file names
 
 
 
252
  total_files = len(file_paths)
253
  log_events.append(f"Starting to load {total_files} images...")
254
  for f in file_paths:
 
261
 
262
  if not images:
263
  log_events.append("No valid images loaded.")
264
+ return [], log_events, 0, manual_batch_size
 
 
 
265
 
266
+ log_events.append("Images loaded. Determining batch size...")
 
 
267
 
268
  try:
269
  manual_batch_size = int(manual_batch_size) if manual_batch_size is not None else 1
 
273
 
274
  optimal_batch = self.auto_tune_batch_size(images) if auto_batch else manual_batch_size
275
  log_events.append(f"Using batch size: {optimal_batch}")
 
 
 
276
 
277
  total_images = len(images)
278
  for i in range(0, total_images, optimal_batch):
 
283
 
284
  batch_results = {}
285
 
286
+ # Process selected models
287
+ for model_key in selected_models:
288
+ if self.available_models[model_key]['selected']: # Ensure model is selected
289
+ batch_results[model_key] = await self.available_models[model_key]['process'](self, batch_images, log_events)
290
+ else:
291
+ batch_results[model_key] = [None] * len(batch_images)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
+ # Combine results and create final results list
 
294
  for j in range(len(batch_images)):
295
  scores_to_average = []
296
+ for model_key in selected_models:
297
+ if self.available_models[model_key]['selected']: # Ensure model is selected
298
  score = batch_results[model_key][j]
299
  if score is not None:
300
  scores_to_average.append(score)
 
307
  'img_data': self.image_to_base64(thumbnail), # Keep this for the HTML display
308
  'final_score': final_score,
309
  }
310
+ for model_key in selected_models: # Add model scores to result
311
  if self.available_models[model_key]['selected']:
312
  result[model_key] = batch_results[model_key][j]
313
+ final_results.append(result)
314
 
315
+ log_events.append("All images processed.")
316
+ return final_results, log_events, 100, optimal_batch
 
 
 
 
317
 
318
 
319
+ def image_to_base64(self, image: Image.Image) -> str:
320
+ """Convert PIL Image to base64 encoded JPEG string."""
321
+ buffered = BytesIO()
322
+ image.save(buffered, format="JPEG")
323
+ return base64.b64encode(buffered.getvalue()).decode('utf-8')
324
+
325
+ def auto_tune_batch_size(self, images: list) -> int:
326
+ """Automatically determine the optimal batch size for processing."""
327
+ batch_size = 1
328
+ max_batch = len(images)
329
+ test_image = images[0:1]
330
+ while batch_size <= max_batch:
331
+ try:
332
+ if "aesthetic_shadow" in self.available_models and self.available_models["aesthetic_shadow"]['selected']: # Check if model is available and selected
333
+ _ = self.available_models["aesthetic_shadow"]['model'](test_image * batch_size)
334
+ if "waifu_scorer" in self.available_models and self.available_models["waifu_scorer"]['selected']: # Check if model is available and selected
335
+ _ = self.available_models["waifu_scorer"]['model'](test_image * batch_size)
336
+ if "aesthetic_predictor_v2_5" in self.available_models and self.available_models["aesthetic_predictor_v2_5"]['selected']: # Check if model is available and selected
337
+ _ = self.available_models["aesthetic_predictor_v2_5"]['model'].inference(test_image * batch_size)
338
+ batch_size *= 2
339
+ if batch_size > max_batch:
340
+ break
341
+ except Exception:
342
+ break
343
+ optimal = max(1, batch_size // 2)
344
+ if optimal > 64:
345
+ optimal = 64
346
+ print("Capped optimal batch size to 64")
347
+ print(f"Optimal batch size determined: {optimal}")
348
+ return optimal
349
 
350
  async def _process_aesthetic_shadow(self, batch_images, log_events):
351
  try:
352
+ shadow_results = self.available_models["aesthetic_shadow"]['model'](batch_images)
353
  log_events.append("Aesthetic Shadow processed for batch.")
354
  except Exception as e:
355
  log_events.append(f"Error in Aesthetic Shadow: {e}")
 
367
 
368
  async def _process_waifu_scorer(self, batch_images, log_events):
369
  try:
370
+ waifu_scores = self.available_models["waifu_scorer"]['model'](batch_images)
371
  waifu_scores = [float(np.clip(s, 0.0, 10.0)) if s is not None else None for s in waifu_scores]
372
  log_events.append("Waifu Scorer processed for batch.")
373
  except Exception as e:
 
377
 
378
  async def _process_aesthetic_predictor_v2_5(self, batch_images, log_events):
379
  try:
380
+ v2_5_scores = self.available_models["aesthetic_predictor_v2_5"]['model'].inference(batch_images)
381
  v2_5_scores = [float(np.round(np.clip(s, 0.0, 10.0), 4)) if s is not None else None for s in v2_5_scores]
382
  log_events.append("Aesthetic Predictor V2.5 processed for batch.")
383
  except Exception as e:
 
389
  anime_scores = []
390
  for j, img in enumerate(batch_images):
391
  try:
392
+ score = predict_anime_aesthetic(img, self.available_models["anime_aesthetic"]['model'])
393
  anime_scores.append(float(np.clip(score * 10.0, 0.0, 10.0)))
394
  log_events.append(f"Anime Aesthetic processed for image {j + 1}.")
395
  except Exception as e:
 
412
  """Format log events into an HTML string."""
413
  return "<div style='max-height:300px; overflow-y:auto;'>" + "<br>".join(logs) + "</div>"
414
 
415
+ def sort_results(self, results, sort_by: str = "Final Score") -> list:
416
  """Sort results based on the specified column."""
417
  key_map = {
418
  "Final Score": "final_score",
 
424
  }
425
  key = key_map.get(sort_by, "final_score")
426
  reverse = sort_by != "File Name"
427
+ results.sort(key=lambda r: r.get(key) if r.get(key) is not None else (-float('inf') if not reverse else float('inf')), reverse=reverse)
428
+ return results
429
 
430
  def generate_html_table(self, results: list, selected_models) -> str:
431
  """Generate an HTML table to display the evaluation results."""
 
491
 
492
 
493
  def cleanup(self):
494
+ """Clean up temporary directories and shutdown worker."""
495
  if os.path.exists(self.temp_dir):
496
  shutil.rmtree(self.temp_dir)
497
+ asyncio.run(self.shutdown()) # Shutdown worker gracefully
498
+
499
+ async def shutdown(self):
500
+ """Send shutdown signal to worker and wait for it to finish."""
501
+ await self.processing_queue.put(None) # Send shutdown signal
502
+ await self.worker_task # Wait for worker task to complete
503
+ await self.processing_queue.join() # Wait for queue to be empty
504
 
505
 
506
  #####################################
507
  # Interface #
508
  #####################################
509
 
510
+ model_manager = ModelManager() # Initialize ModelManager once outside the interface function
511
+
512
  def create_interface():
 
513
  sort_options = ["Final Score", "File Name", "Aesthetic Shadow", "Waifu Scorer", "Aesthetic V2.5", "Anime Score"]
514
  model_options = ["aesthetic_shadow", "waifu_scorer", "aesthetic_predictor_v2_5", "anime_aesthetic"]
515
 
 
552
  status_html = gr.HTML(label="Status")
553
  output_html = gr.HTML(label="Evaluation Results")
554
  download_file_output = gr.File() # Initialize gr.File component without filename
555
+ global_results_state = gr.State([]) # Initialize a global state to hold results
556
 
557
  # Function to convert results to CSV format, excluding 'img_data'.
558
+ def results_to_csv(results, selected_models): # Take results as input
559
  import csv
560
  import io
561
+ if not results:
562
  return None # Return None when no results are available
563
  output = io.StringIO()
564
  fieldnames = ['file_name', 'final_score'] # Base fieldnames
 
568
 
569
  writer = csv.DictWriter(output, fieldnames=fieldnames)
570
  writer.writeheader()
571
+ for res in results:
572
  row_dict = {'file_name': res['file_name'], 'final_score': res['final_score']} # Base data
573
  for model_key in selected_models: # Add selected model scores
574
  if model_key in selected_models: # Double check before accessing res[model_key]
 
580
  def update_batch_size_interactivity(auto_batch):
581
  return gr.update(interactive=not auto_batch)
582
 
583
+ async def process_images_and_update(files, auto_batch, manual_batch, selected_models, current_results):
584
  file_paths = [f.name for f in files]
 
 
585
 
586
+ # Prepare request data for the ModelManager
587
+ request_data = {
588
+ 'file_paths': file_paths,
589
+ 'auto_batch': auto_batch,
590
+ 'manual_batch_size': manual_batch,
591
+ 'selected_models': {model: {'selected': model in selected_models} for model in model_options} # Pass model selections
592
+ }
593
+ # Submit request and get results from ModelManager
594
+ results, logs, progress_percent, updated_batch = await model_manager.submit_request(request_data)
595
+
596
+ updated_results = current_results + results # Append new results to current results
597
+
598
+ html_table = model_manager.generate_html_table(updated_results, selected_models)
599
+ progress_html = model_manager._generate_progress_html(progress_percent)
600
+ log_html = model_manager._format_logs(logs[-10:])
601
 
602
+ return status_html, html_table, log_html, progress_html, gr.update(value=updated_batch, interactive=not auto_batch), updated_results
603
+
604
+
605
+ def update_table_sort(sort_by_column, selected_models, current_results):
606
+ sorted_results = model_manager.sort_results(current_results, sort_by_column)
607
+ return model_manager.generate_html_table(sorted_results, selected_models), sorted_results # Return sorted results
608
+
609
+ def update_table_model_selection(selected_models, current_results):
610
  # Recalculate final scores based on selected models
611
+ for result in current_results:
612
  scores_to_average = []
613
+ for model_key in model_options: # Use model_options here, not available_models from manager in UI context
614
+ if model_key in selected_models and model_key in model_manager.available_models and model_manager.available_models[model_key]['selected']: # consider only selected models from checkbox group and available_models
615
  score = result.get(model_key)
616
  if score is not None:
617
  scores_to_average.append(score)
618
  final_score = float(np.clip(np.mean(scores_to_average), 0.0, 10.0)) if scores_to_average else None
619
  result['final_score'] = final_score
620
 
621
+ sorted_results = model_manager.sort_results(current_results, "Final Score") # Keep sorting by Final Score when models change
622
+ return model_manager.generate_html_table(sorted_results, selected_models), sorted_results
623
 
624
 
625
  def clear_results():
 
626
  return (gr.update(value=""),
627
  gr.update(value=""),
628
  gr.update(value=""),
 
631
  <div style='width:0%;background-color:#4CAF50;padding:5px 0;text-align:center;'>0%</div>
632
  </div>
633
  """),
634
+ gr.update(value=1),
635
+ []) # Clear results state
636
 
637
+ def download_results_csv_trigger(selected_models, current_results): # Changed function name to avoid conflict and clarify purpose
638
+ csv_content = results_to_csv(current_results, selected_models)
639
  if csv_content is None:
640
  return None # Indicate no file to download
641
 
 
647
  return temp_file_path # Return the path to the temporary file
648
 
649
 
650
+ # Set initial selection state for models in ModelManager (important!)
651
+ for model_key in model_options:
652
+ model_manager.available_models[model_key]['selected'] = True # Default to all selected initially
653
+
654
  auto_batch_checkbox.change(
655
  update_batch_size_interactivity,
656
  inputs=[auto_batch_checkbox],
 
659
 
660
  process_btn.click(
661
  process_images_and_update,
662
+ inputs=[input_images, auto_batch_checkbox, batch_size_input, model_checkboxes, global_results_state],
663
+ outputs=[status_html, output_html, log_window, progress_bar, batch_size_input, global_results_state]
664
  )
665
  sort_dropdown.change(
666
  update_table_sort,
667
+ inputs=[sort_dropdown, model_checkboxes, global_results_state],
668
+ outputs=[output_html, global_results_state]
669
  )
670
  model_checkboxes.change( # Added change event for model checkboxes
671
  update_table_model_selection,
672
+ inputs=[model_checkboxes, global_results_state],
673
+ outputs=[output_html, global_results_state]
674
  )
675
  clear_btn.click(
676
  clear_results,
677
  inputs=[],
678
+ outputs=[status_html, output_html, log_window, progress_bar, batch_size_input, global_results_state]
679
  )
680
  download_csv.click(
681
  download_results_csv_trigger, # Call the trigger function
682
+ inputs=[model_checkboxes, global_results_state],
683
  outputs=[download_file_output] # Output is now the gr.File component
684
  )
685
+ demo.load(lambda: update_table_sort("Final Score", model_options, []), inputs=None, outputs=[output_html, global_results_state]) # Initial sort and table render, pass empty initial results
686
  gr.Markdown("""
687
  ### Notes
688
  - Select models to use for evaluation using the checkboxes.