VOIDER committed on
Commit d924e11 · verified · 1 Parent(s): 3801ded

Update app.py

Files changed (1):
  1. app.py +500 -455
app.py CHANGED
@@ -1,27 +1,32 @@
- import gradio as gr
- import torch
  import os
- import numpy as np
  import cv2
  import onnxruntime as rt
  from PIL import Image
  from transformers import pipeline
  from huggingface_hub import hf_hub_download
- import pandas as pd
- import tempfile
- import shutil
- import base64
- from io import BytesIO

  # Import necessary function from aesthetic_predictor_v2_5
  from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip

  class MLP(torch.nn.Module):
-     def __init__(self, input_size, xcol='emb', ycol='avg_rating', batch_norm=True):
          super().__init__()
          self.input_size = input_size
-         self.xcol = xcol
-         self.ycol = ycol
          self.layers = torch.nn.Sequential(
              torch.nn.Linear(self.input_size, 2048),
              torch.nn.ReLU(),
@@ -44,29 +49,37 @@ class MLP(torch.nn.Module):
              torch.nn.Linear(32, 1)
          )

-     def forward(self, x):
          return self.layers(x)

- class WaifuScorer(object):
-     def __init__(self, model_path=None, device='cuda', cache_dir=None, verbose=False):
          self.verbose = verbose

          try:
-             import clip
-
              if model_path is None:
                  model_path = "Eugeoter/waifu-scorer-v3/model.pth"
              if self.verbose:
-                 print(f"model path not set, switch to default: `{model_path}`")

              if not os.path.isfile(model_path):
-                 split = model_path.split("/")
-                 username, repo_id, model_name = split[-3], split[-2], split[-1]
                  model_path = hf_hub_download(f"{username}/{repo_id}", model_name, cache_dir=cache_dir)

-             print(f"Loading WaifuScorer model from `{model_path}`")

              self.mlp = MLP(input_size=768)
              if model_path.endswith(".safetensors"):
                  from safetensors.torch import load_file
                  state_dict = load_file(model_path)
@@ -74,42 +87,44 @@ class WaifuScorer(object):
                  state_dict = torch.load(model_path, map_location=device)
              self.mlp.load_state_dict(state_dict)
              self.mlp.to(device)
-
-             self.model2, self.preprocess = clip.load("ViT-L/14", device=device)
-             self.device = device
-             self.dtype = torch.float32
              self.mlp.eval()
              self.available = True
          except Exception as e:
              print(f"Unable to initialize WaifuScorer: {e}")
-             self.available = False

      @torch.no_grad()
      def __call__(self, images):
          if not self.available:
-             return [None] * (1 if not isinstance(images, list) else len(images))
-
          if isinstance(images, Image.Image):
              images = [images]
          n = len(images)
          if n == 1:
-             images = images*2

          image_tensors = [self.preprocess(img).unsqueeze(0) for img in images]
          image_batch = torch.cat(image_tensors).to(self.device)
-         image_features = self.model2.encode_image(image_batch)
-
-         l2 = image_features.norm(2, dim=-1, keepdim=True)
-         l2[l2 == 0] = 1
-         im_emb_arr = (image_features / l2).to(device=self.device, dtype=self.dtype)
-
-         predictions = self.mlp(im_emb_arr)
          scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
-
          return scores[:n]

  def load_aesthetic_predictor_v2_5():
-     class AestheticPredictorV2_5_Impl: # Renamed class to avoid confusion
          def __init__(self):
              print("Loading Aesthetic Predictor V2.5...")
              self.model, self.preprocessor = convert_v2_5_from_siglip(
@@ -119,516 +134,546 @@ def load_aesthetic_predictor_v2_5():
              if torch.cuda.is_available():
                  self.model = self.model.to(torch.bfloat16).cuda()

-         def inference(self, image: Image.Image) -> float:
-             # preprocess image
-             pixel_values = self.preprocessor(
-                 images=image.convert("RGB"), return_tensors="pt"
-             ).pixel_values
-
-             if torch.cuda.is_available():
-                 pixel_values = pixel_values.to(torch.bfloat16).cuda()
-
-             # predict aesthetic score
-             with torch.inference_mode():
-                 score = self.model(pixel_values).logits.squeeze().float().cpu().numpy()

-             return score

-     return AestheticPredictorV2_5_Impl() # Return an instance of the implementation class

  def load_anime_aesthetic_model():
      model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
-     model = rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])
-     return model

  def predict_anime_aesthetic(img, model):
-     img = np.array(img).astype(np.float32) / 255
      s = 768
-     h, w = img.shape[:-1]
-     h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
-     ph, pw = s - h, s - w
-     img_input = np.zeros([s, s, 3], dtype=np.float32)
-     img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(img, (w, h))
-     img_input = np.transpose(img_input, (2, 0, 1))
-     img_input = img_input[np.newaxis, :]
-     pred = model.run(None, {"img": img_input})[0].item()
      return pred

  class ImageEvaluationTool:
      def __init__(self):
          self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
          print(f"Using device: {self.device}")
-
          print("Loading models... This may take some time.")

          print("Loading Aesthetic Shadow model...")
          self.aesthetic_shadow = pipeline("image-classification", model="NeoChen1024/aesthetic-shadow-v2-backup", device=self.device)
-
          print("Loading Waifu Scorer model...")
          self.waifu_scorer = WaifuScorer(device=self.device, verbose=True)
-
          print("Loading Aesthetic Predictor V2.5...")
-         self.aesthetic_predictor_v2_5 = load_aesthetic_predictor_v2_5()
-
          print("Loading Anime Aesthetic model...")
          self.anime_aesthetic = load_anime_aesthetic_model()
-
          print("All models loaded successfully!")

          self.temp_dir = tempfile.mkdtemp()

-     def evaluate_image(self, image):
-         results = {}

-         if not isinstance(image, Image.Image):
-             image = Image.fromarray(image)

-         try:
-             shadow_result = self.aesthetic_shadow(images=[image])[0]
-             hq_score = [p for p in shadow_result if p['label'] == 'hq'][0]['score']
-             # Scale aesthetic_shadow to 0-10 and clamp
-             aesthetic_shadow_score = np.clip(hq_score * 10.0, 0.0, 10.0)
-             results['aesthetic_shadow'] = aesthetic_shadow_score
-         except Exception as e:
-             print(f"Error in Aesthetic Shadow: {e}")
-             results['aesthetic_shadow'] = None

          try:
-             waifu_score = self.waifu_scorer([image])[0]
-             # Clamp waifu_score
-             waifu_score_clamped = np.clip(waifu_score, 0.0, 10.0)
-             results['waifu_scorer'] = waifu_score_clamped
-         except Exception as e:
-             print(f"Error in Waifu Scorer: {e}")
-             results['waifu_scorer'] = None

-         try:
-             v2_5_score = self.aesthetic_predictor_v2_5.inference(image)
-             # Clamp v2.5 score
-             v2_5_score_clamped = np.clip(v2_5_score, 0.0, 10.0)
-             results['aesthetic_predictor_v2_5'] = float(np.round(v2_5_score_clamped, 4)) # Keep 4 decimal places after clamping
-         except Exception as e:
-             print(f"Error in Aesthetic Predictor V2.5: {e}")
-             results['aesthetic_predictor_v2_5'] = None

-         try:
-             img_array = np.array(image)
-             anime_score = predict_anime_aesthetic(img_array, self.anime_aesthetic)
-             # Scale Anime Score to 0-10 and clamp
-             anime_score_scaled = np.clip(anime_score * 10.0, 0.0, 10.0)
-             results['anime_aesthetic'] = anime_score_scaled
-         except Exception as e:
-             print(f"Error in Anime Aesthetic: {e}")
-             results['anime_aesthetic'] = None

-         # Calculate Final Score (simple average of available scores)
-         valid_scores = [v for v in results.values() if v is not None]
-         if valid_scores:
-             final_score = np.mean(valid_scores)
-             results['final_score'] = np.clip(final_score, 0.0, 10.0) # Clamp final score too
-         else:
-             results['final_score'] = None

-         return results

-     def image_to_base64(self, image):
-         buffered = BytesIO()
-         image.save(buffered, format="JPEG")
-         return base64.b64encode(buffered.getvalue()).decode('utf-8')

-     def process_single_image(self, file_path):
          try:
-             img = Image.open(file_path).convert("RGB")
-             eval_results = self.evaluate_image(img)
-             thumbnail = img.copy()
-             thumbnail.thumbnail((200, 200))
-             img_base64 = self.image_to_base64(thumbnail)
-             result = {
-                 'file_name': os.path.basename(file_path),
-                 'img_data': img_base64,
-                 **eval_results
-             }
-             return result
          except Exception as e:
-             print(f"Error processing {file_path}: {e}")
-             return None
-
-     def process_images_evaluation(self, image_files): # Renamed and now for evaluation only
-         results = []
-
-         for i, file_path in enumerate(image_files):
              try:
-                 img = Image.open(file_path).convert("RGB")
-                 eval_results = self.evaluate_image(img)
-
-                 thumbnail = img.copy()
-                 thumbnail.thumbnail((200, 200))
-
-                 img_base64 = self.image_to_base64(thumbnail)

-                 result = {
-                     'file_name': os.path.basename(file_path),
-                     'img_data': img_base64,
-                     **eval_results
-                 }
-                 results.append(result)

              except Exception as e:
-                 print(f"Error processing {file_path}: {e}")
-
-         return results
-
-     def sort_results(self, results, sort_by="Final Score"): # New function for sorting
-         def sort_key(res): # Define a sorting key function
-             sort_value = res.get(sort_by.lower().replace(" ", "_"), None) # Handle spaces and case
-             if sort_value is None: # Put N/A at the end
-                 return -float('inf') if sort_by == "File Name" else float('inf') # File Name sort N/A at end alphabetically
-             return sort_value
-
-         results.sort(key=sort_key, reverse=sort_by != "File Name") # Sort results, reverse for score columns
-         return results

-     def generate_html_table(self, results):
-         html = """
          <style>
-         .results-table {
-             width: 100%;
-             border-collapse: collapse;
-             margin: 20px 0;
-             font-family: Arial, sans-serif;
-             background-color: transparent;
-         }
-
-         .results-table th,
-         .results-table td {
-             color: #eee;
-             border: 1px solid #ddd;
-             padding: 8px;
-             text-align: center;
-             background-color: transparent;
-         }
-
-         .results-table th {
-             font-weight: bold;
-         }
-
-         .results-table tr:nth-child(even) {
-             background-color: transparent;
-         }
-
-         .results-table tr:hover {
-             background-color: rgba(255, 255, 255, 0.1);
-         }
-
-         .image-preview {
-             max-width: 150px;
-             max-height: 150px;
-             display: block;
-             margin: 0 auto;
-         }
-
-         .good-score {
-             color: #0f0;
-             font-weight: bold;
-         }
-         .bad-score {
-             color: #f00;
-             font-weight: bold;
-         }
-         .medium-score {
-             color: orange;
-             font-weight: bold;
-         }
          </style>
-
          <table class="results-table">
          <thead>
          <tr>
          <th>Image</th>
          <th>File Name</th>
-         <th>Aesthetic Shadow</th>
-         <th>Waifu Scorer</th>
-         <th>Aesthetic V2.5</th>
-         <th>Anime Score</th>
-         <th>Final Score</th>
-         </tr>
-         </thead>
-         <tbody>
          """

          for result in results:
-             html += "<tr>"
-             html += f'<td><img src="data:image/jpeg;base64,{result["img_data"]}" class="image-preview"></td>'
-             html += f'<td>{result["file_name"]}</td>'
-
-             score = result["aesthetic_shadow"]
-             score_class = "good-score" if score and score >= 7 else "medium-score" if score and score >= 4 else "bad-score"
-             html += f'<td class="{score_class}">{score if score is not None else "N/A":.4f}</td>' # Format to 4 decimal places
-
-             score = result["waifu_scorer"]
-             score_class = "good-score" if score and score >= 7 else "medium-score" if score and score >= 5 else "bad-score"
-             html += f'<td class="{score_class}">{score if score is not None else "N/A":.4f}</td>' # Format to 4 decimal places
-
-             score = result["aesthetic_predictor_v2_5"]
-             score_class = "good-score" if score and score >= 7 else "medium-score" if score and score >= 5 else "bad-score"
-             html += f'<td class="{score_class}">{score if score is not None else "N/A":.4f}</td>' # Format to 4 decimal places
-
-             score = result["anime_aesthetic"]
-             score_class = "good-score" if score and score >= 7 else "medium-score" if score and score >= 5 else "bad-score"
-             html += f'<td class="{score_class}">{score if score is not None else "N/A":.4f}</td>' # Format to 4 decimal places
-
-             score = result["final_score"]
-             score_class = "good-score" if score and score >= 7 else "medium-score" if score and score >= 5 else "bad-score"
-             html += f'<td class="{score_class}">{score if score is not None else "N/A":.4f}</td>' # Format to 4 decimal places
-
-
-             html += "</tr>"
-
-         html += """
-         </tbody>
-         </table>
-         """

-         return html

      def cleanup(self):
          if os.path.exists(self.temp_dir):
              shutil.rmtree(self.temp_dir)

- # Global variable to store evaluation results
- global_results = None

- def create_interface():
-     global global_results # Use the global variable

      evaluator = ImageEvaluationTool()
-     sort_options = ["Final Score", "File Name", "Aesthetic Shadow", "Waifu Scorer", "Aesthetic V2.5", "Anime Score"] # Sort options

      with gr.Blocks(theme=gr.themes.Soft()) as demo:
          gr.Markdown("""
          # Comprehensive Image Evaluation Tool

-         Upload images to evaluate them using multiple aesthetic and quality prediction models:
-
-         - **Aesthetic Shadow**: Evaluates high-quality vs low-quality images (scaled to 0-10)
-         - **Waifu Scorer**: Rates anime/illustration quality from 0-10
-         - **Aesthetic Predictor V2.5**: General aesthetic quality prediction (clamped to 0-10)
-         - **Anime Aesthetic**: Specific model for anime style images (scaled and clamped to 0-10)
-         - **Final Score**: Average of available scores (clamped to 0-10)
-
-         Upload multiple images to get a comprehensive evaluation table. Scores are clamped to the range 0.0000 - 10.0000.
          """)

          with gr.Row():
              with gr.Column(scale=1):
-                 input_images = gr.Files(label="Upload Images")
-                 sort_dropdown = gr.Dropdown(sort_options, value="Final Score", label="Sort by") # Dropdown for sorting
                  process_btn = gr.Button("Evaluate Images", variant="primary")
                  clear_btn = gr.Button("Clear Results")

              with gr.Column(scale=2):
-                 progress_html = gr.HTML(label="Progress") # Keep progress_html if you want to show initial progress
                  output_html = gr.HTML(label="Evaluation Results")
-
-         def process_images_and_update(files):
-             global global_results
              file_paths = [f.name for f in files]
-             total_files = len(file_paths)
-             results = []
-
-             if not file_paths:
-                 global_results = []
-                 yield "<p>No files uploaded.</p>", gr.update()
-                 return
-
-             # Helper function to generate a styled progress bar HTML snippet.
-             def generate_progress_bar(percentage):
-                 return f"""
-                 <div style="background-color: #ddd; border-radius: 5px; width: 100%; margin: 10px 0;">
-                     <div style="width: {percentage:.1f}%; background-color: #4CAF50; text-align: center; padding: 5px 0; border-radius: 5px;">
-                         {percentage:.1f}%
-                     </div>
-                 </div>
-                 """
-
-             total_models = 4 # Total number of model steps per image.
-
-             for i, file_path in enumerate(file_paths):
-                 file_name = os.path.basename(file_path)
-                 try:
-                     img = Image.open(file_path).convert("RGB")
-                 except Exception as e:
-                     yield f"<p>Error opening {file_name}: {e}</p>", gr.update()
-                     continue
-
-                 # Update overall progress before starting this image.
-                 overall_percent = (i / total_files) * 100
-                 overall_bar = generate_progress_bar(overall_percent)
-                 progress_html = f"""
-                 <p>{file_name}: Starting evaluation...</p>
-                 <p>Overall Progress:</p>
-                 {overall_bar}
-                 """
-                 yield progress_html, gr.update()
-
-                 # === Model Step 1: Aesthetic Shadow ===
-                 model_index = 0
-                 sub_percent = ((model_index + 1) / total_models) * 100
-                 sub_bar = generate_progress_bar(sub_percent)
-                 progress_html = f"""
-                 <p>{file_name}: Evaluating <strong>Aesthetic Shadow</strong>...</p>
-                 <p>Current Image Progress:</p>
-                 {sub_bar}
-                 <p>Overall Progress:</p>
-                 {overall_bar}
-                 """
-                 yield progress_html, gr.update()
-                 try:
-                     shadow_result = evaluator.aesthetic_shadow(images=[img])[0]
-                     hq_score = [p for p in shadow_result if p['label'] == 'hq'][0]['score']
-                     aesthetic_shadow_score = np.clip(hq_score * 10.0, 0.0, 10.0)
-                 except Exception as e:
-                     print(f"Error in Aesthetic Shadow for {file_name}: {e}")
-                     aesthetic_shadow_score = None
-                 yield f"<p>{file_name}: <strong>Aesthetic Shadow</strong> evaluation complete.</p>", gr.update()
-
-                 # === Model Step 2: Waifu Scorer ===
-                 model_index = 1
-                 sub_percent = ((model_index + 1) / total_models) * 100
-                 sub_bar = generate_progress_bar(sub_percent)
-                 progress_html = f"""
-                 <p>{file_name}: Evaluating <strong>Waifu Scorer</strong>...</p>
-                 <p>Current Image Progress:</p>
-                 {sub_bar}
-                 <p>Overall Progress:</p>
-                 {overall_bar}
-                 """
-                 yield progress_html, gr.update()
-                 try:
-                     waifu_score = evaluator.waifu_scorer([img])[0]
-                     waifu_score = np.clip(waifu_score, 0.0, 10.0)
-                 except Exception as e:
-                     print(f"Error in Waifu Scorer for {file_name}: {e}")
-                     waifu_score = None
-                 yield f"<p>{file_name}: <strong>Waifu Scorer</strong> evaluation complete.</p>", gr.update()
-
-                 # === Model Step 3: Aesthetic Predictor V2.5 ===
-                 model_index = 2
-                 sub_percent = ((model_index + 1) / total_models) * 100
-                 sub_bar = generate_progress_bar(sub_percent)
-                 progress_html = f"""
-                 <p>{file_name}: Evaluating <strong>Aesthetic Predictor V2.5</strong>...</p>
-                 <p>Current Image Progress:</p>
-                 {sub_bar}
-                 <p>Overall Progress:</p>
-                 {overall_bar}
-                 """
-                 yield progress_html, gr.update()
-                 try:
-                     v2_5_score = evaluator.aesthetic_predictor_v2_5.inference(img)
-                     v2_5_score = float(np.round(np.clip(v2_5_score, 0.0, 10.0), 4))
-                 except Exception as e:
-                     print(f"Error in Aesthetic Predictor V2.5 for {file_name}: {e}")
-                     v2_5_score = None
-                 yield f"<p>{file_name}: <strong>Aesthetic Predictor V2.5</strong> evaluation complete.</p>", gr.update()
-
-                 # === Model Step 4: Anime Aesthetic ===
-                 model_index = 3
-                 sub_percent = ((model_index + 1) / total_models) * 100
-                 sub_bar = generate_progress_bar(sub_percent)
-                 progress_html = f"""
-                 <p>{file_name}: Evaluating <strong>Anime Aesthetic</strong>...</p>
-                 <p>Current Image Progress:</p>
-                 {sub_bar}
-                 <p>Overall Progress:</p>
-                 {overall_bar}
-                 """
-                 yield progress_html, gr.update()
-                 try:
-                     img_array = np.array(img)
-                     anime_score = predict_anime_aesthetic(img_array, evaluator.anime_aesthetic)
-                     anime_score = np.clip(anime_score * 10.0, 0.0, 10.0)
-                 except Exception as e:
-                     print(f"Error in Anime Aesthetic for {file_name}: {e}")
-                     anime_score = None
-                 yield f"<p>{file_name}: <strong>Anime Aesthetic</strong> evaluation complete.</p>", gr.update()
-
-                 # === Final Score Calculation and Results Collection ===
-                 valid_scores = [v for v in [aesthetic_shadow_score, waifu_score, v2_5_score, anime_score] if v is not None]
-                 final_score = np.clip(np.mean(valid_scores), 0.0, 10.0) if valid_scores else None
-
-                 # Create a thumbnail and store the evaluation results.
-                 thumbnail = img.copy()
-                 thumbnail.thumbnail((200, 200))
-                 img_base64 = evaluator.image_to_base64(thumbnail)
-                 result = {
-                     'file_name': file_name,
-                     'img_data': img_base64,
-                     'aesthetic_shadow': aesthetic_shadow_score,
-                     'waifu_scorer': waifu_score,
-                     'aesthetic_predictor_v2_5': v2_5_score,
-                     'anime_aesthetic': anime_score,
-                     'final_score': final_score
-                 }
-                 results.append(result)
-
-                 # Update overall progress for the processed file.
-                 overall_percent = ((i + 1) / total_files) * 100
-                 overall_bar = generate_progress_bar(overall_percent)
-                 progress_html = f"""
-                 <p>{file_name}: Evaluation complete. ({i + 1}/{total_files} images processed.)</p>
-                 <p>Overall Progress:</p>
-                 {overall_bar}
-                 """
-                 # Sort the results by Final Score and update the table.
-                 sorted_results = evaluator.sort_results(results.copy(), sort_by="Final Score")
-                 html_table = evaluator.generate_html_table(sorted_results)
-                 yield progress_html, html_table
-
-             # Sort final results by Final Score.
-             global_results = evaluator.sort_results(results, sort_by="Final Score")
-             yield "<p>All images processed.</p>", evaluator.generate_html_table(global_results)
-
-         def update_table_sort(sort_by_column): # New function for sorting update
-             global global_results
-             if global_results is None:
-                 return "No images evaluated yet." # Or handle case when no images are evaluated
-             sorted_results = evaluator.sort_results(global_results, sort_by=sort_by_column)
-             html_table = evaluator.generate_html_table(sorted_results)
-             return html_table

-         def clear_results():
-             global global_results
-             global_results = None # Clear stored results
-             return gr.update(value=""), gr.update(value="")


          process_btn.click(
              process_images_and_update,
-             inputs=[input_images],
-             outputs=[progress_html, output_html]
          )
-         sort_dropdown.change( # Only update table on sort change
              update_table_sort,
-             inputs=[sort_dropdown],
-             outputs=[output_html] # Only update output_html
          )
          clear_btn.click(
              clear_results,
              inputs=[],
-             outputs=[progress_html, output_html]
          )
-
-         demo.load(lambda: None, inputs=None, outputs=None)
-
          gr.Markdown("""
          ### Notes
-         - The evaluation may take some time depending on the number and size of images
-         - For best results, use high-quality images
-         - Scores are color-coded: green for good (>=7), orange for medium (>=5), and red for poor scores (<5, or <4 for Aesthetic Shadow)
-         - Some models may fail for certain image types, shown as "N/A" in the results
-         - "Final Score" is a simple average of available model scores.
-         - Table is sortable by clicking the dropdown above the "Evaluate Images" button. Default sort is by "Final Score". Sorting happens instantly without re-evaluating images.
          """)

      return demo

  import os
+ import shutil
+ import tempfile
+ import base64
+ import asyncio
+ from io import BytesIO
+
  import cv2
+ import numpy as np
+ import torch
  import onnxruntime as rt
  from PIL import Image
+ import gradio as gr
  from transformers import pipeline
  from huggingface_hub import hf_hub_download

  # Import necessary function from aesthetic_predictor_v2_5
  from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip

+
+ #####################################
+ # Model Definitions #
+ #####################################
+
  class MLP(torch.nn.Module):
+     """A simple multi-layer perceptron for image feature regression."""
+     def __init__(self, input_size: int, batch_norm: bool = True):
          super().__init__()
          self.input_size = input_size
          self.layers = torch.nn.Sequential(
              torch.nn.Linear(self.input_size, 2048),
              torch.nn.ReLU(),

              torch.nn.Linear(32, 1)
          )

+     def forward(self, x: torch.Tensor) -> torch.Tensor:
          return self.layers(x)

+
+ class WaifuScorer:
+     """WaifuScorer model that uses CLIP for feature extraction and a custom MLP for scoring."""
+     def __init__(self, model_path: str = None, device: str = 'cuda', cache_dir: str = None, verbose: bool = False):
          self.verbose = verbose
+         self.device = device
+         self.dtype = torch.float32
+         self.available = False

          try:
+             import clip # local import to avoid dependency issues
+             # Set default model path if not provided
              if model_path is None:
                  model_path = "Eugeoter/waifu-scorer-v3/model.pth"
              if self.verbose:
+                 print(f"Model path not provided. Using default: {model_path}")

+             # Download model if not found locally
              if not os.path.isfile(model_path):
+                 username, repo_id, model_name = model_path.split("/")[-3:]
                  model_path = hf_hub_download(f"{username}/{repo_id}", model_name, cache_dir=cache_dir)

+             if self.verbose:
+                 print(f"Loading WaifuScorer model from: {model_path}")

+             # Initialize MLP model
              self.mlp = MLP(input_size=768)
+             # Load state dict
              if model_path.endswith(".safetensors"):
                  from safetensors.torch import load_file
                  state_dict = load_file(model_path)

                  state_dict = torch.load(model_path, map_location=device)
              self.mlp.load_state_dict(state_dict)
              self.mlp.to(device)
              self.mlp.eval()
+
+             # Load CLIP model for image preprocessing and feature extraction
+             self.clip_model, self.preprocess = clip.load("ViT-L/14", device=device)
              self.available = True
          except Exception as e:
              print(f"Unable to initialize WaifuScorer: {e}")

      @torch.no_grad()
      def __call__(self, images):
          if not self.available:
+             return [None] * (len(images) if isinstance(images, list) else 1)
          if isinstance(images, Image.Image):
              images = [images]
          n = len(images)
+         # Ensure at least two images for CLIP model compatibility
          if n == 1:
+             images = images * 2

          image_tensors = [self.preprocess(img).unsqueeze(0) for img in images]
          image_batch = torch.cat(image_tensors).to(self.device)
+         image_features = self.clip_model.encode_image(image_batch)
+         # Normalize features
+         norm = image_features.norm(2, dim=-1, keepdim=True)
+         norm[norm == 0] = 1
+         im_emb = (image_features / norm).to(device=self.device, dtype=self.dtype)
+         predictions = self.mlp(im_emb)
          scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
          return scores[:n]

+
+ #####################################
+ # Aesthetic Predictor Functions #
+ #####################################
+
  def load_aesthetic_predictor_v2_5():
+     """Load and return an instance of Aesthetic Predictor V2.5 with batch processing support."""
+     class AestheticPredictorV2_5_Impl:
          def __init__(self):
              print("Loading Aesthetic Predictor V2.5...")
              self.model, self.preprocessor = convert_v2_5_from_siglip(

              if torch.cuda.is_available():
                  self.model = self.model.to(torch.bfloat16).cuda()

+         def inference(self, image):
+             if isinstance(image, list):
+                 images_rgb = [img.convert("RGB") for img in image]
+                 pixel_values = self.preprocessor(images=images_rgb, return_tensors="pt").pixel_values
+                 if torch.cuda.is_available():
+                     pixel_values = pixel_values.to(torch.bfloat16).cuda()
+                 with torch.inference_mode():
+                     scores = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
+                 if scores.ndim == 0:
+                     scores = np.array([scores])
+                 return scores.tolist()
+             else:
+                 pixel_values = self.preprocessor(images=image.convert("RGB"), return_tensors="pt").pixel_values
+                 if torch.cuda.is_available():
+                     pixel_values = pixel_values.to(torch.bfloat16).cuda()
+                 with torch.inference_mode():
+                     score = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
+                 return score

+     return AestheticPredictorV2_5_Impl()


  def load_anime_aesthetic_model():
+     """Load and return the Anime Aesthetic ONNX model."""
      model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
+     return rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])
+

  def predict_anime_aesthetic(img, model):
+     """Predict Anime Aesthetic score for a single image."""
+     img_np = np.array(img).astype(np.float32) / 255.0
      s = 768
+     h, w = img_np.shape[:2]
+     if h > w:
+         new_h, new_w = s, int(s * w / h)
+     else:
+         new_h, new_w = int(s * h / w), s
+     resized = cv2.resize(img_np, (new_w, new_h))
+     # Center the resized image in a square canvas
+     canvas = np.zeros((s, s, 3), dtype=np.float32)
+     pad_h = (s - new_h) // 2
+     pad_w = (s - new_w) // 2
+     canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
+     # Prepare input for model
+     input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
+     pred = model.run(None, {"img": input_tensor})[0].item()
      return pred

+
+ #####################################
+ # Image Evaluation Tool #
+ #####################################
+
  class ImageEvaluationTool:
+     """Evaluation tool to process images through multiple aesthetic models and generate logs and HTML outputs."""
      def __init__(self):
          self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
          print(f"Using device: {self.device}")
          print("Loading models... This may take some time.")

+         # Load models with progress logs
          print("Loading Aesthetic Shadow model...")
          self.aesthetic_shadow = pipeline("image-classification", model="NeoChen1024/aesthetic-shadow-v2-backup", device=self.device)
          print("Loading Waifu Scorer model...")
          self.waifu_scorer = WaifuScorer(device=self.device, verbose=True)
          print("Loading Aesthetic Predictor V2.5...")
+         self.aesthetic_predictor = load_aesthetic_predictor_v2_5()
          print("Loading Anime Aesthetic model...")
          self.anime_aesthetic = load_anime_aesthetic_model()
          print("All models loaded successfully!")

          self.temp_dir = tempfile.mkdtemp()
+         self.results = [] # Store final results for sorting and display
+         self.available_models = {
+             "aesthetic_shadow": {"name": "Aesthetic Shadow", "process": self._process_aesthetic_shadow},
+             "waifu_scorer": {"name": "Waifu Scorer", "process": self._process_waifu_scorer},
+             "aesthetic_predictor_v2_5": {"name": "Aesthetic V2.5", "process": self._process_aesthetic_predictor_v2_5},
+             "anime_aesthetic": {"name": "Anime Score", "process": self._process_anime_aesthetic},
+         }
+
+     def image_to_base64(self, image: Image.Image) -> str:
+         """Convert PIL Image to base64 encoded JPEG string."""
+         buffered = BytesIO()
+         image.save(buffered, format="JPEG")
+         return base64.b64encode(buffered.getvalue()).decode('utf-8')

+     def auto_tune_batch_size(self, images: list) -> int:
+         """Automatically determine the optimal batch size for processing."""
+         batch_size = 1
+         max_batch = len(images)
+         test_image = images[0:1]
+         while batch_size <= max_batch:
+             try:
+                 if "aesthetic_shadow" in self.available_models and self.available_models["aesthetic_shadow"]['selected']: # Check if model is available and selected
+                     _ = self.aesthetic_shadow(test_image * batch_size)
+                 if "waifu_scorer" in self.available_models and self.available_models["waifu_scorer"]['selected']: # Check if model is available and selected
+                     _ = self.waifu_scorer(test_image * batch_size)
+                 if "aesthetic_predictor_v2_5" in self.available_models and self.available_models["aesthetic_predictor_v2_5"]['selected']: # Check if model is available and selected
+                     _ = self.aesthetic_predictor.inference(test_image * batch_size)
+                 batch_size *= 2
+                 if batch_size > max_batch:
+                     break
+             except Exception:
+                 break
+         optimal = max(1, batch_size // 2)
+         if optimal > 64:
+             optimal = 64
+             print("Capped optimal batch size to 64")
+         print(f"Optimal batch size determined: {optimal}")
+         return optimal
+
+     async def process_images_evaluation_with_logs(self, file_paths: list, auto_batch: bool, manual_batch_size: int, selected_models):
+         """Asynchronously process images and yield updates with logs, HTML table, and progress bar."""
+         self.results = []
+         log_events = []
+         images = []
+         file_names = []
+
+         # Update available models based on selection
+         for model_key in self.available_models:
+             self.available_models[model_key]['selected'] = model_key in selected_models
+
+         total_files = len(file_paths)
+         log_events.append(f"Starting to load {total_files} images...")
+         for f in file_paths:
+             try:
+                 img = Image.open(f).convert("RGB")
+                 images.append(img)
+                 file_names.append(os.path.basename(f))
+             except Exception as e:
+                 log_events.append(f"Error opening {f}: {e}")

+         if not images:
+             log_events.append("No valid images loaded.")
+             progress_percentage = 0 # Define progress_percentage here for no images case
+             progress_html = self._generate_progress_html(progress_percentage)
+             yield ("<p>No images loaded.</p>", "", self._format_logs(log_events), progress_html, manual_batch_size)
+             return

+         yield ("<p>Images loaded. Determining batch size...</p>", "", self._format_logs(log_events),
+                self._generate_progress_html(0), manual_batch_size)
+         await asyncio.sleep(0.1)

          try:
+             manual_batch_size = int(manual_batch_size) if manual_batch_size is not None else 1
+         except ValueError:
+             manual_batch_size = 1
+             log_events.append("Invalid manual batch size. Defaulting to 1.")
+
+         optimal_batch = self.auto_tune_batch_size(images) if auto_batch else manual_batch_size
+         log_events.append(f"Using batch size: {optimal_batch}")
+         yield ("<p>Processing images in batches...</p>", "", self._format_logs(log_events),
+                self._generate_progress_html(0), optimal_batch)
+         await asyncio.sleep(0.1)
+
+         total_images = len(images)
+         for i in range(0, total_images, optimal_batch):
+             batch_images = images[i:i+optimal_batch]
+             batch_file_names = file_names[i:i+optimal_batch]
+             batch_index = i // optimal_batch + 1
+             log_events.append(f"Processing batch {batch_index}: images {i+1} to {min(i+optimal_batch, total_images)}")
+
+             batch_results = {}
+
+             # Aesthetic Shadow processing
+             if self.available_models['aesthetic_shadow']['selected']:
+                 batch_results['aesthetic_shadow'] = await self._process_aesthetic_shadow(batch_images, log_events)
+             else:
+                 batch_results['aesthetic_shadow'] = [None] * len(batch_images)

+             # Waifu Scorer processing
+             if self.available_models['waifu_scorer']['selected']:
+                 batch_results['waifu_scorer'] = await self._process_waifu_scorer(batch_images, log_events)
+             else:
+                 batch_results['waifu_scorer'] = [None] * len(batch_images)

+             # Aesthetic Predictor V2.5 processing
+             if self.available_models['aesthetic_predictor_v2_5']['selected']:
+                 batch_results['aesthetic_predictor_v2_5'] = await self._process_aesthetic_predictor_v2_5(batch_images, log_events)
+             else:
+                 batch_results['aesthetic_predictor_v2_5'] = [None] * len(batch_images)

+             # Anime Aesthetic processing (single image)
+             if self.available_models['anime_aesthetic']['selected']:
+                 batch_results['anime_aesthetic'] = await self._process_anime_aesthetic(batch_images, log_events)
+             else:
+                 batch_results['anime_aesthetic'] = [None] * len(batch_images)


+             # Combine results
+             for j in range(len(batch_images)):
+                 scores_to_average = []
+                 for model_key in self.available_models:
+                     if self.available_models[model_key]['selected']: # Only consider selected models
+                         score = batch_results[model_key][j]
+                         if score is not None:
+                             scores_to_average.append(score)

+                 final_score = float(np.clip(np.mean(scores_to_average), 0.0, 10.0)) if scores_to_average else None
+                 thumbnail = batch_images[j].copy()
+                 thumbnail.thumbnail((200, 200))
+                 result = {
+                     'file_name': batch_file_names[j],
+                     'img_data': self.image_to_base64(thumbnail), # Keep this for the HTML display
+                     'final_score': final_score,
+                 }
+                 for model_key in self.available_models: # Add model scores to result
+                     if self.available_models[model_key]['selected']:
+                         result[model_key] = batch_results[model_key][j]
+
+                 self.results.append(result)
+             self.sort_results() # Sort results after adding new result
+             progress_percentage = min(100, ((i + len(batch_images)) / total_images) * 100) # Define progress_percentage here
+             yield (f"<p>Processed batch {batch_index}.</p>", self.generate_html_table(self.results, selected_models), # Update table immediately
+                    self._format_logs(log_events[-10:]), self._generate_progress_html(progress_percentage), optimal_batch)
+             await asyncio.sleep(0.1)
+
+
+         log_events.append("All images processed.")
+         self.sort_results() # Final sort after all images processed
+         html_table = self.generate_html_table(self.results, selected_models) # Pass selected models to final table generation
+         final_progress = self._generate_progress_html(100)
+         yield ("<p>All images processed.</p>", html_table,
+                self._format_logs(log_events[-10:]), final_progress, optimal_batch)
+
+     async def _process_aesthetic_shadow(self, batch_images, log_events):
          try:
+             shadow_results = self.aesthetic_shadow(batch_images)
+             log_events.append("Aesthetic Shadow processed for batch.")
          except Exception as e:
+             log_events.append(f"Error in Aesthetic Shadow: {e}")
+             shadow_results = [None] * len(batch_images)
+         aesthetic_shadow_scores = []
+         for res in shadow_results:
              try:
+                 hq_score = next(p for p in res if p['label'] == 'hq')['score']
+                 score = float(np.clip(hq_score * 10.0, 0.0, 10.0))
+             except Exception:
+                 score = None
+             aesthetic_shadow_scores.append(score)
+         log_events.append("Aesthetic Shadow scores computed for batch.")
+         return aesthetic_shadow_scores
+
+     async def _process_waifu_scorer(self, batch_images, log_events):
+         try:
+             waifu_scores = self.waifu_scorer(batch_images)
+             waifu_scores = [float(np.clip(s, 0.0, 10.0)) if s is not None else None for s in waifu_scores]
+             log_events.append("Waifu Scorer processed for batch.")
+         except Exception as e:
+             log_events.append(f"Error in Waifu Scorer: {e}")
+             waifu_scores = [None] * len(batch_images)
+         return waifu_scores

+     async def _process_aesthetic_predictor_v2_5(self, batch_images, log_events):
+         try:
+             v2_5_scores = self.aesthetic_predictor.inference(batch_images)
+             v2_5_scores = [float(np.round(np.clip(s, 0.0, 10.0), 4)) if s is not None else None for s in v2_5_scores]
+             log_events.append("Aesthetic Predictor V2.5 processed for batch.")
+         except Exception as e:
+             log_events.append(f"Error in Aesthetic Predictor V2.5: {e}")
+             v2_5_scores = [None] * len(batch_images)
+         return v2_5_scores

+     async def _process_anime_aesthetic(self, batch_images, log_events):
+         anime_scores = []
+         for j, img in enumerate(batch_images):
+             try:
+                 score = predict_anime_aesthetic(img, self.anime_aesthetic)
+                 anime_scores.append(float(np.clip(score * 10.0, 0.0, 10.0)))
+                 log_events.append(f"Anime Aesthetic processed for image {j + 1}.")
              except Exception as e:
+                 log_events.append(f"Error in Anime Aesthetic for image {j + 1}: {e}")
+                 anime_scores.append(None)
+         return anime_scores
+
+
+     def _generate_progress_html(self, percentage: float) -> str:
+         """Generate HTML for a progress bar given a percentage."""
+         return f"""
+         <div style="width:100%;background-color:#ddd; border-radius:5px;">
+           <div style="width:{percentage:.1f}%; background-color:#4CAF50; text-align:center; padding:5px 0; border-radius:5px;">
+             {percentage:.1f}%
+           </div>
+         </div>
+         """

+     def _format_logs(self, logs: list) -> str:
+         """Format log events into an HTML string."""
+         return "<div style='max-height:300px; overflow-y:auto;'>" + "<br>".join(logs) + "</div>"
+
+     def sort_results(self, sort_by: str = "Final Score") -> list:
+         """Sort results based on the specified column."""
+         key_map = {
+             "Final Score": "final_score",
+             "File Name": "file_name",
+             "Aesthetic Shadow": "aesthetic_shadow",
+             "Waifu Scorer": "waifu_scorer",
+             "Aesthetic V2.5": "aesthetic_predictor_v2_5",
+             "Anime Score": "anime_aesthetic"
+         }
+         key = key_map.get(sort_by, "final_score")
+         reverse = sort_by != "File Name"
+         self.results.sort(key=lambda r: r.get(key) if r.get(key) is not None else (-float('inf') if not reverse else float('inf')), reverse=reverse)
+         return self.results
+
+     def generate_html_table(self, results: list, selected_models) -> str:
+         """Generate an HTML table to display the evaluation results."""
+         table_html = """
          <style>
+         .results-table { width: 100%; border-collapse: collapse; margin: 20px 0; font-family: Arial, sans-serif; }
+         .results-table th, .results-table td { color: #eee; border: 1px solid #ddd; padding: 8px; text-align: center; }
+         .results-table th { font-weight: bold; }
+         .results-table tr:nth-child(even) { background-color: transparent; }
+         .results-table tr:hover { background-color: rgba(255, 255, 255, 0.1); }
+         .image-preview { max-width: 150px; max-height: 150px; display: block; margin: 0 auto; }
+         .good-score { color: #0f0; font-weight: bold; }
+         .bad-score { color: #f00; font-weight: bold; }
+         .medium-score { color: orange; font-weight: bold; }
          </style>
          <table class="results-table">
          <thead>
          <tr>
          <th>Image</th>
          <th>File Name</th>
          """
+         visible_models = [] # Keep track of visible model columns
+         if "aesthetic_shadow" in selected_models:
+             table_html += "<th>Aesthetic Shadow</th>"
+             visible_models.append("aesthetic_shadow")
+         if "waifu_scorer" in selected_models:
+             table_html += "<th>Waifu Scorer</th>"
+             visible_models.append("waifu_scorer")
+         if "aesthetic_predictor_v2_5" in selected_models:
+             table_html += "<th>Aesthetic V2.5</th>"
+             visible_models.append("aesthetic_predictor_v2_5")
+         if "anime_aesthetic" in selected_models:
+             table_html += "<th>Anime Score</th>"
+             visible_models.append("anime_aesthetic")
+         table_html += "<th>Final Score</th>"
+         table_html += "</tr></thead><tbody>"

          for result in results:
+             table_html += "<tr>"
+             table_html += f'<td><img src="data:image/jpeg;base64,{result["img_data"]}" class="image-preview"></td>'
+             table_html += f'<td>{result["file_name"]}</td>'
+             for model_key in visible_models: # Iterate through visible models only
+                 score = result.get(model_key)
+                 table_html += self._format_score_cell(score)
+
+             score = result.get("final_score")
+             table_html += self._format_score_cell(score)
+             table_html += "</tr>"
+         table_html += """</tbody></table>"""
+         return table_html
+
+     def _format_score_cell(self, score):
+         score_str = f"{score:.4f}" if isinstance(score, (int, float)) else "N/A"
+         score_class = ""
+         if isinstance(score, (int, float)):
+             if score >= 7:
+                 score_class = "good-score"
+             elif score >= 5:
+                 score_class = "medium-score"
+             else:
+                 score_class = "bad-score"
+         return f'<td class="{score_class}">{score_str}</td>'


      def cleanup(self):
+         """Clean up temporary directories."""
          if os.path.exists(self.temp_dir):
              shutil.rmtree(self.temp_dir)


+ #####################################
+ # Interface #
+ #####################################

+ def create_interface():
      evaluator = ImageEvaluationTool()
+     sort_options = ["Final Score", "File Name", "Aesthetic Shadow", "Waifu Scorer", "Aesthetic V2.5", "Anime Score"]
+     model_options = ["aesthetic_shadow", "waifu_scorer", "aesthetic_predictor_v2_5", "anime_aesthetic"]

      with gr.Blocks(theme=gr.themes.Soft()) as demo:
          gr.Markdown("""
          # Comprehensive Image Evaluation Tool

+         Upload images to evaluate them using multiple aesthetic and quality prediction models.
+
+         **New features:**
+         - **Dynamic Final Score:** Final score recalculates on model selection changes.
+         - **Model Selection:** Choose which models to use for evaluation.
+         - **Dynamic Table Updates:** Table updates automatically based on model selection.
+         - **Automatic Sorting:** Table is automatically sorted by 'Final Score'.
+         - **Detailed Logs:** See major processing events (limited to the last 10).
+         - **Progress Bar:** Visual indication of processing status.
+         - **Asynchronous Updates:** Streaming status and logs during processing.
+         - **Batch Size Controls:** Choose manual batch size or let the tool auto-detect it.
+         - **Download Results:** Export the evaluation results as CSV.
          """)

          with gr.Row():
              with gr.Column(scale=1):
+                 input_images = gr.Files(label="Upload Images", file_count="multiple")
+                 model_checkboxes = gr.CheckboxGroup(model_options, label="Select Models", value=model_options, info="Choose models for evaluation.")
+                 auto_batch_checkbox = gr.Checkbox(label="Automatic Batch Size Detection", value=False, info="Enable to automatically determine the optimal batch size.")
+                 batch_size_input = gr.Number(label="Batch Size", value=1, interactive=True, info="Manually specify the batch size if auto-detection is disabled.")
+                 sort_dropdown = gr.Dropdown(sort_options, value="Final Score", label="Sort by", info="Select the column to sort results by.")
                  process_btn = gr.Button("Evaluate Images", variant="primary")
                  clear_btn = gr.Button("Clear Results")
+                 download_csv = gr.Button("Download CSV", variant="secondary")

              with gr.Column(scale=2):
+                 progress_bar = gr.HTML(label="Progress Bar", value="""
+                 <div style='width:100%;background-color:#ddd;'>
+                   <div style='width:0%;background-color:#4CAF50;padding:5px 0;text-align:center;'>0%</div>
+                 </div>
+                 """)
+                 log_window = gr.HTML(label="Detailed Logs", value="<div style='max-height:300px; overflow-y:auto;'>Logs will appear here...</div>")
+                 status_html = gr.HTML(label="Status")
                  output_html = gr.HTML(label="Evaluation Results")
+                 download_file_output = gr.File() # Initialize gr.File component without filename
+
+         # Function to convert results to CSV format, excluding 'img_data'.
+         def results_to_csv(selected_models):
+             import csv
+             import io
+             if not evaluator.results:
+                 return None # Return None when no results are available
+             output = io.StringIO()
+             fieldnames = ['file_name', 'final_score'] # Base fieldnames
+             for model_key in selected_models: # Add selected model names as fieldnames
+                 if model_key in selected_models: # Double check if model_key is indeed in selected_models list
+                     fieldnames.append(model_key)
+
+             writer = csv.DictWriter(output, fieldnames=fieldnames)
+             writer.writeheader()
+             for res in evaluator.results:
+                 row_dict = {'file_name': res['file_name'], 'final_score': res['final_score']} # Base data
+                 for model_key in selected_models: # Add selected model scores
+                     if model_key in selected_models: # Double check before accessing res[model_key]
+                         row_dict[model_key] = res.get(model_key, 'N/A') # Use get with default 'N/A' if model not in result (shouldn't happen but for safety)
+                 writer.writerow(row_dict)
+             return output.getvalue()
+
+
+         def update_batch_size_interactivity(auto_batch):
+             return gr.update(interactive=not auto_batch)
+
+         async def process_images_and_update(files, auto_batch, manual_batch, selected_models):
              file_paths = [f.name for f in files]
+             async for status, table, logs, progress, updated_batch in evaluator.process_images_evaluation_with_logs(file_paths, auto_batch, manual_batch, selected_models):
+                 yield status, table, logs, progress, gr.update(value=updated_batch, interactive=not auto_batch)

+         def update_table_sort(sort_by_column, selected_models):
+             sorted_results = evaluator.sort_results(sort_by_column)
+             return evaluator.generate_html_table(sorted_results, selected_models)

+         def update_table_model_selection(selected_models):
+             # Recalculate final scores based on selected models
+             for result in evaluator.results:
+                 scores_to_average = []
+                 for model_key in evaluator.available_models:
+                     if model_key in selected_models and evaluator.available_models[model_key]['selected']: # consider only selected models from checkbox group and available_models
+                         score = result.get(model_key)
+                         if score is not None:
+                             scores_to_average.append(score)
+                 final_score = float(np.clip(np.mean(scores_to_average), 0.0, 10.0)) if scores_to_average else None
+                 result['final_score'] = final_score
+
+             sorted_results = evaluator.sort_results() # Keep sorting by Final Score when models change
+             return evaluator.generate_html_table(sorted_results, selected_models)
+
+
+         def clear_results():
+             evaluator.results = []
+             return (gr.update(value=""),
+                     gr.update(value=""),
+                     gr.update(value=""),
+                     gr.update(value="""
+                     <div style='width:100%;background-color:#ddd;'>
+                       <div style='width:0%;background-color:#4CAF50;padding:5px 0;text-align:center;'>0%</div>
+                     </div>
+                     """),
+                     gr.update(value=1))
+
+         def download_results_csv_trigger(selected_models): # Changed function name to avoid conflict and clarify purpose
+             csv_content = results_to_csv(selected_models)
+             if csv_content is None:
+                 return None # Indicate no file to download
+
+             # Create a temporary file to save the CSV data
+             with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_file:
+                 tmp_file.write(csv_content.encode())
+                 temp_file_path = tmp_file.name # Get the path to the temporary file
+
+             return temp_file_path # Return the path to the temporary file
+
+
+         auto_batch_checkbox.change(
+             update_batch_size_interactivity,
+             inputs=[auto_batch_checkbox],
+             outputs=[batch_size_input]
+         )

          process_btn.click(
              process_images_and_update,
+             inputs=[input_images, auto_batch_checkbox, batch_size_input, model_checkboxes],
+             outputs=[status_html, output_html, log_window, progress_bar, batch_size_input]
          )
+         sort_dropdown.change(
              update_table_sort,
+             inputs=[sort_dropdown, model_checkboxes],
+             outputs=[output_html]
+         )
+         model_checkboxes.change( # Added change event for model checkboxes
+             update_table_model_selection,
+             inputs=[model_checkboxes],
+             outputs=[output_html]
          )
          clear_btn.click(
              clear_results,
              inputs=[],
+             outputs=[status_html, output_html, log_window, progress_bar, batch_size_input]
          )
+         download_csv.click(
+             download_results_csv_trigger, # Call the trigger function
+             inputs=[model_checkboxes],
+             outputs=[download_file_output] # Output is now the gr.File component
+         )
+         demo.load(lambda: update_table_sort("Final Score", model_options), inputs=None, outputs=[output_html]) # Initial sort and table render
          gr.Markdown("""
          ### Notes
+         - Select models to use for evaluation using the checkboxes.
+         - The 'Final Score' recalculates dynamically when models are selected/deselected.
+         - The table updates automatically when models are selected/deselected and is always sorted by 'Final Score'.
+         - The log window displays the most recent 10 events.
+         - The progress bar shows overall processing status.
+         - When 'Automatic Batch Size Detection' is enabled, the batch size field becomes disabled.
+         - Use the download button to export your evaluation results as CSV.
          """)

      return demo
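
Note: the visible hunks end at `return demo`, so this commit does not show how the interface is actually launched. A minimal launcher sketch, assuming the usual Gradio pattern; the `__main__` block below is an assumption for illustration and is not part of the diff:

    # Hypothetical entry point for app.py; not shown in this commit's hunks.
    if __name__ == "__main__":
        demo = create_interface()
        demo.queue()   # queueing lets the async generator callbacks stream progress updates
        demo.launch()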