Update app.py
app.py
CHANGED
This commit rewrites app.py. The previous version (634 lines) ran the same four models (Aesthetic Shadow, Waifu Scorer, Aesthetic Predictor V2.5, and the Anime Aesthetic ONNX model) over uploaded images one file at a time, yielding inline HTML progress bars between the four per-image steps, keeping state in a module-level global_results list, and styling the results table with a long per-rule CSS block. The rewrite below replaces that flow with batched, asynchronous processing, per-model selection, streaming logs, and CSV export.
import os
import shutil
import tempfile
import base64
import asyncio
from io import BytesIO

import cv2
import numpy as np
import torch
import onnxruntime as rt
from PIL import Image
import gradio as gr
from transformers import pipeline
from huggingface_hub import hf_hub_download

# Import necessary function from aesthetic_predictor_v2_5
from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip


#####################################
#         Model Definitions         #
#####################################

class MLP(torch.nn.Module):
    """A simple multi-layer perceptron for image feature regression."""
    def __init__(self, input_size: int, batch_norm: bool = True):
        super().__init__()
        self.input_size = input_size
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(self.input_size, 2048),
            torch.nn.ReLU(),
            # ... (intermediate layers not shown in this diff) ...
            torch.nn.Linear(32, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


class WaifuScorer:
    """WaifuScorer model that uses CLIP for feature extraction and a custom MLP for scoring."""
    def __init__(self, model_path: str = None, device: str = 'cuda', cache_dir: str = None, verbose: bool = False):
        self.verbose = verbose
        self.device = device
        self.dtype = torch.float32
        self.available = False

        try:
            import clip  # local import to avoid dependency issues
            # Set default model path if not provided
            if model_path is None:
                model_path = "Eugeoter/waifu-scorer-v3/model.pth"
                if self.verbose:
                    print(f"Model path not provided. Using default: {model_path}")

            # Download model if not found locally
            if not os.path.isfile(model_path):
                username, repo_id, model_name = model_path.split("/")[-3:]
                model_path = hf_hub_download(f"{username}/{repo_id}", model_name, cache_dir=cache_dir)

            if self.verbose:
                print(f"Loading WaifuScorer model from: {model_path}")

            # Initialize MLP model and load its weights
            self.mlp = MLP(input_size=768)
            if model_path.endswith(".safetensors"):
                from safetensors.torch import load_file
                state_dict = load_file(model_path)
            else:
                state_dict = torch.load(model_path, map_location=device)
            self.mlp.load_state_dict(state_dict)
            self.mlp.to(device)
            self.mlp.eval()

            # Load CLIP model for image preprocessing and feature extraction
            self.clip_model, self.preprocess = clip.load("ViT-L/14", device=device)
            self.available = True
        except Exception as e:
            print(f"Unable to initialize WaifuScorer: {e}")

    @torch.no_grad()
    def __call__(self, images):
        if not self.available:
            return [None] * (len(images) if isinstance(images, list) else 1)
        if isinstance(images, Image.Image):
            images = [images]
        n = len(images)
        # Ensure at least two images for CLIP model compatibility
        if n == 1:
            images = images * 2

        image_tensors = [self.preprocess(img).unsqueeze(0) for img in images]
        image_batch = torch.cat(image_tensors).to(self.device)
        image_features = self.clip_model.encode_image(image_batch)
        # L2-normalize the features before scoring
        norm = image_features.norm(2, dim=-1, keepdim=True)
        norm[norm == 0] = 1
        im_emb = (image_features / norm).to(device=self.device, dtype=self.dtype)
        predictions = self.mlp(im_emb)
        scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
        return scores[:n]
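

# Illustrative sketch (not from the commit): minimal WaifuScorer usage,
# assuming the `clip` package is installed and the default checkpoint
# downloads successfully; "sample.jpg" is a hypothetical path.
def _demo_waifu_scorer():
    scorer = WaifuScorer(device='cpu', verbose=True)
    img = Image.open("sample.jpg").convert("RGB")
    return scorer(img)  # [score] with score in [0, 10], or [None] if unavailable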


#####################################
#   Aesthetic Predictor Functions   #
#####################################

def load_aesthetic_predictor_v2_5():
    """Load and return an instance of Aesthetic Predictor V2.5 with batch processing support."""
    class AestheticPredictorV2_5_Impl:
        def __init__(self):
            print("Loading Aesthetic Predictor V2.5...")
            self.model, self.preprocessor = convert_v2_5_from_siglip(
                # ... (arguments not shown in this diff) ...
            )
            if torch.cuda.is_available():
                self.model = self.model.to(torch.bfloat16).cuda()

        def inference(self, image):
            if isinstance(image, list):
                images_rgb = [img.convert("RGB") for img in image]
                pixel_values = self.preprocessor(images=images_rgb, return_tensors="pt").pixel_values
                if torch.cuda.is_available():
                    pixel_values = pixel_values.to(torch.bfloat16).cuda()
                with torch.inference_mode():
                    scores = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
                # squeeze() collapses a single-image batch to a 0-d array; restore the list shape
                if scores.ndim == 0:
                    scores = np.array([scores])
                return scores.tolist()
            else:
                pixel_values = self.preprocessor(images=image.convert("RGB"), return_tensors="pt").pixel_values
                if torch.cuda.is_available():
                    pixel_values = pixel_values.to(torch.bfloat16).cuda()
                with torch.inference_mode():
                    score = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
                return score

    return AestheticPredictorV2_5_Impl()


def load_anime_aesthetic_model():
    """Load and return the Anime Aesthetic ONNX model."""
    model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
    return rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])


def predict_anime_aesthetic(img, model):
    """Predict the Anime Aesthetic score for a single image."""
    img_np = np.array(img).astype(np.float32) / 255.0
    s = 768
    h, w = img_np.shape[:2]
    # Scale the longer side to s while preserving the aspect ratio
    if h > w:
        new_h, new_w = s, int(s * w / h)
    else:
        new_h, new_w = int(s * h / w), s
    resized = cv2.resize(img_np, (new_w, new_h))
    # Center the resized image in a square canvas
    canvas = np.zeros((s, s, 3), dtype=np.float32)
    pad_h = (s - new_h) // 2
    pad_w = (s - new_w) // 2
    canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
    # Prepare input for the model (NCHW layout)
    input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
    pred = model.run(None, {"img": input_tensor})[0].item()
    return pred
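

# Illustrative sketch (not from the commit): end-to-end Anime Aesthetic
# scoring for one file. The letterbox step above scales the longer side to
# 768 px; for a hypothetical 1024x512 input that gives new_h=384, new_w=768,
# with 192 px of zero padding above and below the image.
def _demo_anime_aesthetic(path="sample.jpg"):  # hypothetical path
    model = load_anime_aesthetic_model()
    img = Image.open(path).convert("RGB")
    return predict_anime_aesthetic(img, model)  # raw model output, scaled x10 later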


#####################################
#       Image Evaluation Tool       #
#####################################

class ImageEvaluationTool:
    """Evaluation tool to process images through multiple aesthetic models and generate logs and HTML outputs."""
    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")
        print("Loading models... This may take some time.")

        # Load models with progress logs
        print("Loading Aesthetic Shadow model...")
        self.aesthetic_shadow = pipeline("image-classification", model="NeoChen1024/aesthetic-shadow-v2-backup", device=self.device)
        print("Loading Waifu Scorer model...")
        self.waifu_scorer = WaifuScorer(device=self.device, verbose=True)
        print("Loading Aesthetic Predictor V2.5...")
        self.aesthetic_predictor = load_aesthetic_predictor_v2_5()
        print("Loading Anime Aesthetic model...")
        self.anime_aesthetic = load_anime_aesthetic_model()
        print("All models loaded successfully!")

        self.temp_dir = tempfile.mkdtemp()
        self.results = []  # Store final results for sorting and display
        self.available_models = {
            "aesthetic_shadow": {"name": "Aesthetic Shadow", "process": self._process_aesthetic_shadow},
            "waifu_scorer": {"name": "Waifu Scorer", "process": self._process_waifu_scorer},
            "aesthetic_predictor_v2_5": {"name": "Aesthetic V2.5", "process": self._process_aesthetic_predictor_v2_5},
            "anime_aesthetic": {"name": "Anime Score", "process": self._process_anime_aesthetic},
        }

    def image_to_base64(self, image: Image.Image) -> str:
        """Convert a PIL Image to a base64-encoded JPEG string."""
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')
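
    # Illustrative note: the base64 string is embedded straight into the
    # results table as a data URI, e.g.
    #   f'<img src="data:image/jpeg;base64,{self.image_to_base64(thumbnail)}">'
    # so no temporary image files are needed for display.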

    def auto_tune_batch_size(self, images: list) -> int:
        """Automatically determine the optimal batch size for processing."""
        batch_size = 1
        max_batch = len(images)
        test_image = images[0:1]
        while batch_size <= max_batch:
            try:
                # Probe each available, selected model with the current batch size
                if "aesthetic_shadow" in self.available_models and self.available_models["aesthetic_shadow"]['selected']:
                    _ = self.aesthetic_shadow(test_image * batch_size)
                if "waifu_scorer" in self.available_models and self.available_models["waifu_scorer"]['selected']:
                    _ = self.waifu_scorer(test_image * batch_size)
                if "aesthetic_predictor_v2_5" in self.available_models and self.available_models["aesthetic_predictor_v2_5"]['selected']:
                    _ = self.aesthetic_predictor.inference(test_image * batch_size)
                batch_size *= 2
                if batch_size > max_batch:
                    break
            except Exception:
                break
        optimal = max(1, batch_size // 2)
        if optimal > 64:
            optimal = 64
            print("Capped optimal batch size to 64")
        print(f"Optimal batch size determined: {optimal}")
        return optimal
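
    # Worked example for the probe above (hypothetical run): with 10 images
    # and every trial succeeding, batch_size tests 1, 2, 4, 8, then doubles to
    # 16 > 10 and stops; max(1, 16 // 2) = 8 is returned. If a trial raises
    # (e.g. CUDA out of memory), the last doubling is likewise halved back.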

    async def process_images_evaluation_with_logs(self, file_paths: list, auto_batch: bool, manual_batch_size: int, selected_models):
        """Asynchronously process images and yield updates with logs, HTML table, and progress bar."""
        self.results = []
        log_events = []
        images = []
        file_names = []

        # Update available models based on selection
        for model_key in self.available_models:
            self.available_models[model_key]['selected'] = model_key in selected_models

        total_files = len(file_paths)
        log_events.append(f"Starting to load {total_files} images...")
        for f in file_paths:
            try:
                img = Image.open(f).convert("RGB")
                images.append(img)
                file_names.append(os.path.basename(f))
            except Exception as e:
                log_events.append(f"Error opening {f}: {e}")

        if not images:
            log_events.append("No valid images loaded.")
            progress_html = self._generate_progress_html(0)
            yield ("<p>No images loaded.</p>", "", self._format_logs(log_events), progress_html, manual_batch_size)
            return

        yield ("<p>Images loaded. Determining batch size...</p>", "", self._format_logs(log_events),
               self._generate_progress_html(0), manual_batch_size)
        await asyncio.sleep(0.1)

        try:
            manual_batch_size = int(manual_batch_size) if manual_batch_size is not None else 1
        except ValueError:
            manual_batch_size = 1
            log_events.append("Invalid manual batch size. Defaulting to 1.")

        optimal_batch = self.auto_tune_batch_size(images) if auto_batch else manual_batch_size
        log_events.append(f"Using batch size: {optimal_batch}")
        yield ("<p>Processing images in batches...</p>", "", self._format_logs(log_events),
               self._generate_progress_html(0), optimal_batch)
        await asyncio.sleep(0.1)

        total_images = len(images)
        for i in range(0, total_images, optimal_batch):
            batch_images = images[i:i+optimal_batch]
            batch_file_names = file_names[i:i+optimal_batch]
            batch_index = i // optimal_batch + 1
            log_events.append(f"Processing batch {batch_index}: images {i+1} to {min(i+optimal_batch, total_images)}")

            # Run each selected model over the batch; unselected models yield None scores
            batch_results = {}
            if self.available_models['aesthetic_shadow']['selected']:
                batch_results['aesthetic_shadow'] = await self._process_aesthetic_shadow(batch_images, log_events)
            else:
                batch_results['aesthetic_shadow'] = [None] * len(batch_images)

            if self.available_models['waifu_scorer']['selected']:
                batch_results['waifu_scorer'] = await self._process_waifu_scorer(batch_images, log_events)
            else:
                batch_results['waifu_scorer'] = [None] * len(batch_images)

            if self.available_models['aesthetic_predictor_v2_5']['selected']:
                batch_results['aesthetic_predictor_v2_5'] = await self._process_aesthetic_predictor_v2_5(batch_images, log_events)
            else:
                batch_results['aesthetic_predictor_v2_5'] = [None] * len(batch_images)

            # Anime Aesthetic runs image by image
            if self.available_models['anime_aesthetic']['selected']:
                batch_results['anime_aesthetic'] = await self._process_anime_aesthetic(batch_images, log_events)
            else:
                batch_results['anime_aesthetic'] = [None] * len(batch_images)

            # Combine results for this batch
            for j in range(len(batch_images)):
                scores_to_average = []
                for model_key in self.available_models:
                    if self.available_models[model_key]['selected']:
                        score = batch_results[model_key][j]
                        if score is not None:
                            scores_to_average.append(score)

                final_score = float(np.clip(np.mean(scores_to_average), 0.0, 10.0)) if scores_to_average else None
                thumbnail = batch_images[j].copy()
                thumbnail.thumbnail((200, 200))
                result = {
                    'file_name': batch_file_names[j],
                    'img_data': self.image_to_base64(thumbnail),  # used by the HTML display
                    'final_score': final_score,
                }
                for model_key in self.available_models:
                    if self.available_models[model_key]['selected']:
                        result[model_key] = batch_results[model_key][j]

                self.results.append(result)

            self.sort_results()  # Keep the table sorted as new results arrive
            progress_percentage = min(100, ((i + len(batch_images)) / total_images) * 100)
            yield (f"<p>Processed batch {batch_index}.</p>", self.generate_html_table(self.results, selected_models),
                   self._format_logs(log_events[-10:]), self._generate_progress_html(progress_percentage), optimal_batch)
            await asyncio.sleep(0.1)

        log_events.append("All images processed.")
        self.sort_results()  # Final sort after all images are processed
        html_table = self.generate_html_table(self.results, selected_models)
        final_progress = self._generate_progress_html(100)
        yield ("<p>All images processed.</p>", html_table,
               self._format_logs(log_events[-10:]), final_progress, optimal_batch)

    async def _process_aesthetic_shadow(self, batch_images, log_events):
        try:
            shadow_results = self.aesthetic_shadow(batch_images)
            log_events.append("Aesthetic Shadow processed for batch.")
        except Exception as e:
            log_events.append(f"Error in Aesthetic Shadow: {e}")
            shadow_results = [None] * len(batch_images)
        aesthetic_shadow_scores = []
        for res in shadow_results:
            try:
                hq_score = next(p for p in res if p['label'] == 'hq')['score']
                score = float(np.clip(hq_score * 10.0, 0.0, 10.0))
            except Exception:
                score = None
            aesthetic_shadow_scores.append(score)
        log_events.append("Aesthetic Shadow scores computed for batch.")
        return aesthetic_shadow_scores

    async def _process_waifu_scorer(self, batch_images, log_events):
        try:
            waifu_scores = self.waifu_scorer(batch_images)
            waifu_scores = [float(np.clip(s, 0.0, 10.0)) if s is not None else None for s in waifu_scores]
            log_events.append("Waifu Scorer processed for batch.")
        except Exception as e:
            log_events.append(f"Error in Waifu Scorer: {e}")
            waifu_scores = [None] * len(batch_images)
        return waifu_scores

    async def _process_aesthetic_predictor_v2_5(self, batch_images, log_events):
        try:
            v2_5_scores = self.aesthetic_predictor.inference(batch_images)
            v2_5_scores = [float(np.round(np.clip(s, 0.0, 10.0), 4)) if s is not None else None for s in v2_5_scores]
            log_events.append("Aesthetic Predictor V2.5 processed for batch.")
        except Exception as e:
            log_events.append(f"Error in Aesthetic Predictor V2.5: {e}")
            v2_5_scores = [None] * len(batch_images)
        return v2_5_scores

    async def _process_anime_aesthetic(self, batch_images, log_events):
        anime_scores = []
        for j, img in enumerate(batch_images):
            try:
                score = predict_anime_aesthetic(img, self.anime_aesthetic)
                anime_scores.append(float(np.clip(score * 10.0, 0.0, 10.0)))
                log_events.append(f"Anime Aesthetic processed for image {j + 1}.")
            except Exception as e:
                log_events.append(f"Error in Anime Aesthetic for image {j + 1}: {e}")
                anime_scores.append(None)
        return anime_scores
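
    # Note: the _process_* helpers above are declared async so the streaming
    # generator can await them uniformly; their bodies run synchronously, so
    # the heavy model work still happens between UI yields.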

    def _generate_progress_html(self, percentage: float) -> str:
        """Generate HTML for a progress bar given a percentage."""
        return f"""
        <div style="width:100%;background-color:#ddd; border-radius:5px;">
            <div style="width:{percentage:.1f}%; background-color:#4CAF50; text-align:center; padding:5px 0; border-radius:5px;">
                {percentage:.1f}%
            </div>
        </div>
        """

    def _format_logs(self, logs: list) -> str:
        """Format log events into an HTML string."""
        return "<div style='max-height:300px; overflow-y:auto;'>" + "<br>".join(logs) + "</div>"

    def sort_results(self, sort_by: str = "Final Score") -> list:
        """Sort results by the specified column; entries with a missing score sort last."""
        key_map = {
            "Final Score": "final_score",
            "File Name": "file_name",
            "Aesthetic Shadow": "aesthetic_shadow",
            "Waifu Scorer": "waifu_scorer",
            "Aesthetic V2.5": "aesthetic_predictor_v2_5",
            "Anime Score": "anime_aesthetic"
        }
        key = key_map.get(sort_by, "final_score")
        reverse = sort_by != "File Name"  # scores sort descending, file names ascending
        missing = float('-inf') if reverse else float('inf')  # always push None values to the end
        self.results.sort(key=lambda r: r.get(key) if r.get(key) is not None else missing, reverse=reverse)
        return self.results
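
    # Illustrative sketch of the None handling above (hypothetical data):
    #   results like [{'final_score': 7.2}, {'final_score': None}, {'final_score': 3.1}]
    # sort by "Final Score" to 7.2, 3.1, None: missing scores are keyed to
    # -inf in a descending sort, so they always sink to the bottom.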

    def generate_html_table(self, results: list, selected_models) -> str:
        """Generate an HTML table to display the evaluation results."""
        table_html = """
        <style>
            .results-table { width: 100%; border-collapse: collapse; margin: 20px 0; font-family: Arial, sans-serif; }
            .results-table th, .results-table td { color: #eee; border: 1px solid #ddd; padding: 8px; text-align: center; }
            .results-table th { font-weight: bold; }
            .results-table tr:nth-child(even) { background-color: transparent; }
            .results-table tr:hover { background-color: rgba(255, 255, 255, 0.1); }
            .image-preview { max-width: 150px; max-height: 150px; display: block; margin: 0 auto; }
            .good-score { color: #0f0; font-weight: bold; }
            .bad-score { color: #f00; font-weight: bold; }
            .medium-score { color: orange; font-weight: bold; }
        </style>
        <table class="results-table">
            <thead>
                <tr>
                    <th>Image</th>
                    <th>File Name</th>
        """
        visible_models = []  # Track which model columns are shown
        if "aesthetic_shadow" in selected_models:
            table_html += "<th>Aesthetic Shadow</th>"
            visible_models.append("aesthetic_shadow")
        if "waifu_scorer" in selected_models:
            table_html += "<th>Waifu Scorer</th>"
            visible_models.append("waifu_scorer")
        if "aesthetic_predictor_v2_5" in selected_models:
            table_html += "<th>Aesthetic V2.5</th>"
            visible_models.append("aesthetic_predictor_v2_5")
        if "anime_aesthetic" in selected_models:
            table_html += "<th>Anime Score</th>"
            visible_models.append("anime_aesthetic")
        table_html += "<th>Final Score</th>"
        table_html += "</tr></thead><tbody>"

        for result in results:
            table_html += "<tr>"
            table_html += f'<td><img src="data:image/jpeg;base64,{result["img_data"]}" class="image-preview"></td>'
            table_html += f'<td>{result["file_name"]}</td>'
            for model_key in visible_models:  # Only the visible model columns
                score = result.get(model_key)
                table_html += self._format_score_cell(score)

            score = result.get("final_score")
            table_html += self._format_score_cell(score)
            table_html += "</tr>"
        table_html += "</tbody></table>"
        return table_html

    def _format_score_cell(self, score):
        score_str = f"{score:.4f}" if isinstance(score, (int, float)) else "N/A"
        score_class = ""
        if isinstance(score, (int, float)):
            if score >= 7:
                score_class = "good-score"
            elif score >= 5:
                score_class = "medium-score"
            else:
                score_class = "bad-score"
        return f'<td class="{score_class}">{score_str}</td>'

    def cleanup(self):
        """Clean up temporary directories."""
        if os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)


#####################################
#             Interface             #
#####################################

def create_interface():
    evaluator = ImageEvaluationTool()
    sort_options = ["Final Score", "File Name", "Aesthetic Shadow", "Waifu Scorer", "Aesthetic V2.5", "Anime Score"]
    model_options = ["aesthetic_shadow", "waifu_scorer", "aesthetic_predictor_v2_5", "anime_aesthetic"]

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # Comprehensive Image Evaluation Tool

        Upload images to evaluate them using multiple aesthetic and quality prediction models.

        **New features:**
        - **Dynamic Final Score:** The final score recalculates on model selection changes.
        - **Model Selection:** Choose which models to use for evaluation.
        - **Dynamic Table Updates:** The table updates automatically based on model selection.
        - **Automatic Sorting:** The table is automatically sorted by 'Final Score'.
        - **Detailed Logs:** See major processing events (limited to the last 10).
        - **Progress Bar:** Visual indication of processing status.
        - **Asynchronous Updates:** Streaming status and logs during processing.
        - **Batch Size Controls:** Choose a manual batch size or let the tool auto-detect it.
        - **Download Results:** Export the evaluation results as CSV.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                input_images = gr.Files(label="Upload Images", file_count="multiple")
                model_checkboxes = gr.CheckboxGroup(model_options, label="Select Models", value=model_options, info="Choose models for evaluation.")
                auto_batch_checkbox = gr.Checkbox(label="Automatic Batch Size Detection", value=False, info="Enable to automatically determine the optimal batch size.")
                batch_size_input = gr.Number(label="Batch Size", value=1, interactive=True, info="Manually specify the batch size if auto-detection is disabled.")
                sort_dropdown = gr.Dropdown(sort_options, value="Final Score", label="Sort by", info="Select the column to sort results by.")
                process_btn = gr.Button("Evaluate Images", variant="primary")
                clear_btn = gr.Button("Clear Results")
                download_csv = gr.Button("Download CSV", variant="secondary")

            with gr.Column(scale=2):
                progress_bar = gr.HTML(label="Progress Bar", value="""
                <div style='width:100%;background-color:#ddd;'>
                    <div style='width:0%;background-color:#4CAF50;padding:5px 0;text-align:center;'>0%</div>
                </div>
                """)
                log_window = gr.HTML(label="Detailed Logs", value="<div style='max-height:300px; overflow-y:auto;'>Logs will appear here...</div>")
                status_html = gr.HTML(label="Status")
                output_html = gr.HTML(label="Evaluation Results")
                download_file_output = gr.File()  # Receives the generated CSV file path

        # Convert results to CSV, excluding the embedded 'img_data' thumbnails.
        def results_to_csv(selected_models):
            import csv
            import io
            if not evaluator.results:
                return None  # Nothing to export yet
            output = io.StringIO()
            fieldnames = ['file_name', 'final_score'] + list(selected_models)
            writer = csv.DictWriter(output, fieldnames=fieldnames)
            writer.writeheader()
            for res in evaluator.results:
                row_dict = {'file_name': res['file_name'], 'final_score': res['final_score']}
                for model_key in selected_models:
                    row_dict[model_key] = res.get(model_key, 'N/A')
                writer.writerow(row_dict)
            return output.getvalue()
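
        # Illustrative note (hypothetical data): results_to_csv returns the
        # whole CSV as one string, e.g.
        #   "file_name,final_score,waifu_scorer\r\nfoo.png,7.2,6.8\r\n"
        # which download_results_csv_trigger below writes to a temporary file.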

        def update_batch_size_interactivity(auto_batch):
            return gr.update(interactive=not auto_batch)

        async def process_images_and_update(files, auto_batch, manual_batch, selected_models):
            file_paths = [f.name for f in files]
            async for status, table, logs, progress, updated_batch in evaluator.process_images_evaluation_with_logs(file_paths, auto_batch, manual_batch, selected_models):
                yield status, table, logs, progress, gr.update(value=updated_batch, interactive=not auto_batch)

        def update_table_sort(sort_by_column, selected_models):
            sorted_results = evaluator.sort_results(sort_by_column)
            return evaluator.generate_html_table(sorted_results, selected_models)

        def update_table_model_selection(selected_models):
            # Recalculate final scores from the currently selected models only
            for result in evaluator.results:
                scores_to_average = []
                for model_key in evaluator.available_models:
                    if model_key in selected_models:
                        score = result.get(model_key)
                        if score is not None:
                            scores_to_average.append(score)
                final_score = float(np.clip(np.mean(scores_to_average), 0.0, 10.0)) if scores_to_average else None
                result['final_score'] = final_score

            sorted_results = evaluator.sort_results()  # Keep sorting by Final Score when models change
            return evaluator.generate_html_table(sorted_results, selected_models)

        def clear_results():
            evaluator.results = []
            return (gr.update(value=""),
                    gr.update(value=""),
                    gr.update(value=""),
                    gr.update(value="""
                    <div style='width:100%;background-color:#ddd;'>
                        <div style='width:0%;background-color:#4CAF50;padding:5px 0;text-align:center;'>0%</div>
                    </div>
                    """),
                    gr.update(value=1))

        def download_results_csv_trigger(selected_models):
            csv_content = results_to_csv(selected_models)
            if csv_content is None:
                return None  # No file to download

            # Write the CSV data to a temporary file for the gr.File output
            with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_file:
                tmp_file.write(csv_content.encode())
                temp_file_path = tmp_file.name

            return temp_file_path

        auto_batch_checkbox.change(
            update_batch_size_interactivity,
            inputs=[auto_batch_checkbox],
            outputs=[batch_size_input]
        )
        process_btn.click(
            process_images_and_update,
            inputs=[input_images, auto_batch_checkbox, batch_size_input, model_checkboxes],
            outputs=[status_html, output_html, log_window, progress_bar, batch_size_input]
        )
        sort_dropdown.change(
            update_table_sort,
            inputs=[sort_dropdown, model_checkboxes],
            outputs=[output_html]
        )
        model_checkboxes.change(
            update_table_model_selection,
            inputs=[model_checkboxes],
            outputs=[output_html]
        )
        clear_btn.click(
            clear_results,
            inputs=[],
            outputs=[status_html, output_html, log_window, progress_bar, batch_size_input]
        )
        download_csv.click(
            download_results_csv_trigger,
            inputs=[model_checkboxes],
            outputs=[download_file_output]
        )
        demo.load(lambda: update_table_sort("Final Score", model_options), inputs=None, outputs=[output_html])  # Initial table render
        gr.Markdown("""
        ### Notes
        - Select the models to use for evaluation with the checkboxes.
        - The 'Final Score' recalculates dynamically when models are selected or deselected.
        - The table updates automatically on selection changes and is always sorted by 'Final Score'.
        - The log window displays the most recent 10 events.
        - The progress bar shows overall processing status.
        - When 'Automatic Batch Size Detection' is enabled, the batch size field becomes disabled.
        - Use the download button to export your evaluation results as CSV.
        """)

    return demo
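

# Assumed entry point (not shown in this diff): a Gradio Space typically ends
# app.py with a main guard like the following.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()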