VOIDER commited on
Commit
2642664
·
verified ·
1 Parent(s): 87b041e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +394 -0
app.py CHANGED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import os
5
+ import shutil
6
+ from PIL import Image
7
+ from transformers import pipeline
8
+ import clip
9
+ from huggingface_hub import hf_hub_download
10
+ import onnxruntime as rt
11
+ import pandas as pd
12
+ import time
13
+
14
+ # Utility class for Waifu Scorer
15
+ class MLP(torch.nn.Module):
16
+ def __init__(self, input_size, xcol='emb', ycol='avg_rating', batch_norm=True):
17
+ super().__init__()
18
+ self.input_size = input_size
19
+ self.xcol = xcol
20
+ self.ycol = ycol
21
+ self.layers = torch.nn.Sequential(
22
+ torch.nn.Linear(self.input_size, 2048),
23
+ torch.nn.ReLU(),
24
+ torch.nn.BatchNorm1d(2048) if batch_norm else torch.nn.Identity(),
25
+ torch.nn.Dropout(0.3),
26
+ torch.nn.Linear(2048, 512),
27
+ torch.nn.ReLU(),
28
+ torch.nn.BatchNorm1d(512) if batch_norm else torch.nn.Identity(),
29
+ torch.nn.Dropout(0.3),
30
+ torch.nn.Linear(512, 256),
31
+ torch.nn.ReLU(),
32
+ torch.nn.BatchNorm1d(256) if batch_norm else torch.nn.Identity(),
33
+ torch.nn.Dropout(0.2),
34
+ torch.nn.Linear(256, 128),
35
+ torch.nn.ReLU(),
36
+ torch.nn.BatchNorm1d(128) if batch_norm else torch.nn.Identity(),
37
+ torch.nn.Dropout(0.1),
38
+ torch.nn.Linear(128, 32),
39
+ torch.nn.ReLU(),
40
+ torch.nn.Linear(32, 1)
41
+ )
42
+
43
+ def forward(self, x):
44
+ return self.layers(x)
45
+
46
+ class WaifuScorer:
47
+ def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
48
+ self.device = device
49
+ model_path = hf_hub_download("Eugeoter/waifu-scorer-v4-beta", "model.pth", cache_dir="models")
50
+ self.mlp = self._load_model(model_path, input_size=768, device=device)
51
+ self.model2, self.preprocess = clip.load("ViT-L/14", device=device)
52
+ self.dtype = self.mlp.dtype
53
+ self.mlp.eval()
54
+
55
+ def _load_model(self, model_path, input_size=768, device='cuda'):
56
+ model = MLP(input_size=input_size)
57
+ s = torch.load(model_path, map_location=device)
58
+ model.load_state_dict(s)
59
+ model.to(device)
60
+ return model
61
+
62
+ def _normalized(self, a, order=2, dim=-1):
63
+ l2 = a.norm(order, dim, keepdim=True)
64
+ l2[l2 == 0] = 1
65
+ return a / l2
66
+
67
+ @torch.no_grad()
68
+ def _encode_images(self, images):
69
+ if isinstance(images, Image.Image):
70
+ images = [images]
71
+ image_tensors = [self.preprocess(img).unsqueeze(0) for img in images]
72
+ image_batch = torch.cat(image_tensors).to(self.device)
73
+ image_features = self.model2.encode_image(image_batch)
74
+ im_emb_arr = self._normalized(image_features).cpu().float()
75
+ return im_emb_arr
76
+
77
+ @torch.no_grad()
78
+ def score(self, image):
79
+ if isinstance(image, np.ndarray):
80
+ image = Image.fromarray(image)
81
+ images = [image, image] # batch norm needs at least 2 images
82
+ images = self._encode_images(images).to(device=self.device, dtype=self.dtype)
83
+ predictions = self.mlp(images)
84
+ scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
85
+ return scores[0] # Return first score only
86
+
87
+ class AnimeAestheticPredictor:
88
+ def __init__(self):
89
+ model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx", cache_dir="models")
90
+ self.model = rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])
91
+
92
+ def predict(self, img):
93
+ if isinstance(img, Image.Image):
94
+ img = np.array(img)
95
+ img = img.astype(np.float32) / 255
96
+ s = 768
97
+ h, w = img.shape[:-1]
98
+ h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
99
+ ph, pw = s - h, s - w
100
+ img_input = np.zeros([s, s, 3], dtype=np.float32)
101
+ img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(img, (w, h))
102
+ img_input = np.transpose(img_input, (2, 0, 1))
103
+ img_input = img_input[np.newaxis, :]
104
+ pred = self.model.run(None, {"img": img_input})[0].item()
105
+ return pred
106
+
107
+ class ImageEvaluator:
108
+ def __init__(self):
109
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
110
+ self.setup_models()
111
+ self.results_df = None
112
+ self.temp_dir = "temp_images"
113
+ if not os.path.exists(self.temp_dir):
114
+ os.makedirs(self.temp_dir)
115
+ if not os.path.exists("output"):
116
+ os.makedirs("output/hq_folder", exist_ok=True)
117
+ os.makedirs("output/lq_folder", exist_ok=True)
118
+
119
+ def setup_models(self):
120
+ # Initialize all models
121
+ print("Setting up models (this may take a few minutes)...")
122
+
123
+ # ShadowLilac's aesthetic model
124
+ self.aesthetic_shadow = pipeline("image-classification",
125
+ model="shadowlilac/aesthetic-shadow-v2",
126
+ device=self.device)
127
+
128
+ # WaifuScorer model
129
+ try:
130
+ self.waifu_scorer = WaifuScorer(device=self.device)
131
+ except Exception as e:
132
+ print(f"Error loading WaifuScorer: {e}")
133
+ self.waifu_scorer = None
134
+
135
+ # CafeAI models
136
+ self.cafe_aesthetic = pipeline("image-classification", "cafeai/cafe_aesthetic")
137
+ self.cafe_style = pipeline("image-classification", "cafeai/cafe_style")
138
+ self.cafe_waifu = pipeline("image-classification", "cafeai/cafe_waifu")
139
+
140
+ # Anime Aesthetic model
141
+ self.anime_aesthetic = AnimeAestheticPredictor()
142
+
143
+ print("All models loaded successfully!")
144
+
145
+ def evaluate_image(self, image_path):
146
+ """Evaluate a single image with all models"""
147
+ if isinstance(image_path, str):
148
+ image = Image.open(image_path).convert('RGB')
149
+ else:
150
+ image = image_path
151
+
152
+ results = {}
153
+
154
+ # ShadowLilac evaluation
155
+ shadow_result = self.aesthetic_shadow(images=[image])
156
+ results["shadow_hq"] = round([p for p in shadow_result[0] if p['label'] == 'hq'][0]['score'], 2)
157
+
158
+ # WaifuScorer evaluation
159
+ if self.waifu_scorer:
160
+ try:
161
+ results["waifu_score"] = round(self.waifu_scorer.score(image), 2)
162
+ except Exception as e:
163
+ results["waifu_score"] = 0
164
+ print(f"Error with WaifuScorer: {e}")
165
+
166
+ # CafeAI evaluations
167
+ cafe_aesthetic_result = self.cafe_aesthetic(image, top_k=2)
168
+ results["cafe_aesthetic"] = round(next((item["score"] for item in cafe_aesthetic_result if item["label"] == "aesthetic"), 0), 2)
169
+
170
+ # Get top style
171
+ cafe_style_result = self.cafe_style(image, top_k=5)
172
+ results["cafe_top_style"] = cafe_style_result[0]["label"]
173
+ results["cafe_top_style_score"] = round(cafe_style_result[0]["score"], 2)
174
+
175
+ # Get top waifu style if applicable
176
+ cafe_waifu_result = self.cafe_waifu(image, top_k=5)
177
+ results["cafe_top_waifu"] = cafe_waifu_result[0]["label"]
178
+ results["cafe_top_waifu_score"] = round(cafe_waifu_result[0]["score"], 2)
179
+
180
+ # Anime aesthetic evaluation
181
+ try:
182
+ results["anime_aesthetic"] = round(self.anime_aesthetic.predict(image), 2)
183
+ except Exception as e:
184
+ results["anime_aesthetic"] = 0
185
+ print(f"Error with Anime Aesthetic: {e}")
186
+
187
+ # Calculate average score
188
+ scores = [results["shadow_hq"] * 10] # Scale to 0-10
189
+ if self.waifu_scorer:
190
+ scores.append(results["waifu_score"])
191
+ scores.append(results["cafe_aesthetic"] * 10) # Scale to 0-10
192
+ scores.append(results["anime_aesthetic"])
193
+
194
+ results["average_score"] = round(sum(scores) / len(scores), 2)
195
+
196
+ return results
197
+
198
+ def process_images(self, files, threshold=0.5, progress=None):
199
+ """Process multiple images and return results dataframe"""
200
+ results = []
201
+ total_files = len(files)
202
+
203
+ # Clean temp directory
204
+ for f in os.listdir(self.temp_dir):
205
+ os.remove(os.path.join(self.temp_dir, f))
206
+
207
+ # Process each file and save a copy to temp directory
208
+ for i, file in enumerate(files):
209
+ if progress is not None:
210
+ progress(i / total_files, f"Processing {i+1}/{total_files}: {os.path.basename(file)}")
211
+
212
+ # Copy file to temp directory with clean name
213
+ filename = os.path.basename(file)
214
+ temp_path = os.path.join(self.temp_dir, filename)
215
+ shutil.copy(file, temp_path)
216
+
217
+ # Evaluate the image
218
+ results_dict = self.evaluate_image(temp_path)
219
+ results_dict["filename"] = filename
220
+ results_dict["path"] = temp_path
221
+ results_dict["is_hq"] = results_dict["shadow_hq"] >= threshold
222
+
223
+ # Copy to output directory based on HQ threshold
224
+ destination = "output/hq_folder" if results_dict["is_hq"] else "output/lq_folder"
225
+ shutil.copy(temp_path, os.path.join(destination, filename))
226
+
227
+ results.append(results_dict)
228
+
229
+ # Create dataframe and sort by average score
230
+ self.results_df = pd.DataFrame(results)
231
+ self.results_df = self.results_df.sort_values(by="average_score", ascending=False)
232
+
233
+ if progress is not None:
234
+ progress(1.0, "Processing complete!")
235
+
236
+ return self.results_df
237
+
238
+ def get_results_html(self):
239
+ """Generate HTML with results and image previews"""
240
+ if self.results_df is None:
241
+ return "<p>No results available. Please process images first.</p>"
242
+
243
+ html = "<h2>Results (Sorted by Average Score)</h2>"
244
+ html += "<table style='width:100%; border-collapse: collapse;'>"
245
+ html += "<tr style='background-color:#f0f0f0'>"
246
+ html += "<th style='padding:8px; border:1px solid #ddd;'>Image</th>"
247
+ html += "<th style='padding:8px; border:1px solid #ddd;'>Filename</th>"
248
+ html += "<th style='padding:8px; border:1px solid #ddd;'>Average</th>"
249
+ html += "<th style='padding:8px; border:1px solid #ddd;'>Shadow HQ</th>"
250
+ if "waifu_score" in self.results_df.columns:
251
+ html += "<th style='padding:8px; border:1px solid #ddd;'>Waifu</th>"
252
+ html += "<th style='padding:8px; border:1px solid #ddd;'>Cafe</th>"
253
+ html += "<th style='padding:8px; border:1px solid #ddd;'>Anime</th>"
254
+ html += "<th style='padding:8px; border:1px solid #ddd;'>Style</th>"
255
+ html += "</tr>"
256
+
257
+ for _, row in self.results_df.iterrows():
258
+ # Determine row color based on HQ status
259
+ row_color = "#e8f5e9" if row["is_hq"] else "#ffebee"
260
+
261
+ html += f"<tr style='background-color:{row_color}'>"
262
+ # Image thumbnail
263
+ html += f"<td style='padding:8px; border:1px solid #ddd;'><img src='file={row['path']}' height='100'></td>"
264
+ # Filename
265
+ html += f"<td style='padding:8px; border:1px solid #ddd;'>{row['filename']}</td>"
266
+ # Average score
267
+ html += f"<td style='padding:8px; border:1px solid #ddd; font-weight:bold;'>{row['average_score']}</td>"
268
+ # Shadow HQ score
269
+ html += f"<td style='padding:8px; border:1px solid #ddd;'>{row['shadow_hq']}</td>"
270
+ # Waifu score
271
+ if "waifu_score" in self.results_df.columns:
272
+ html += f"<td style='padding:8px; border:1px solid #ddd;'>{row['waifu_score']}</td>"
273
+ # Cafe aesthetic
274
+ html += f"<td style='padding:8px; border:1px solid #ddd;'>{row['cafe_aesthetic']}</td>"
275
+ # Anime aesthetic
276
+ html += f"<td style='padding:8px; border:1px solid #ddd;'>{row['anime_aesthetic']}</td>"
277
+ # Top style
278
+ html += f"<td style='padding:8px; border:1px solid #ddd;'>{row['cafe_top_style']} ({row['cafe_top_style_score']})</td>"
279
+ html += "</tr>"
280
+
281
+ html += "</table>"
282
+ return html
283
+
284
+ def export_results_csv(self, output_path="results.csv"):
285
+ """Export results to CSV file"""
286
+ if self.results_df is not None:
287
+ self.results_df.to_csv(output_path, index=False)
288
+ return f"Results exported to {output_path}"
289
+ return "No results to export"
290
+
291
+ # Create Gradio interface
292
+ def create_interface():
293
+ evaluator = ImageEvaluator()
294
+
295
+ with gr.Blocks(title="Comprehensive Image Evaluation Tool", theme=gr.themes.Soft()) as app:
296
+ gr.Markdown("""
297
+ # 🖼️ Comprehensive Image Evaluation Tool
298
+
299
+ Upload images to evaluate their aesthetic quality using multiple models:
300
+
301
+ - **ShadowLilac** - General aesthetic quality (0-1)
302
+ - **WaifuScorer** - Anime-style quality score (0-10)
303
+ - **CafeAI** - Style classification and aesthetic assessment
304
+ - **Anime Aesthetic** - Specialized for anime/manga art (0-10)
305
+
306
+ The tool will provide an average score and classify images as high or low quality based on your threshold.
307
+ """)
308
+
309
+ with gr.Row():
310
+ with gr.Column(scale=1):
311
+ input_files = gr.Files(label="Upload Images", file_types=["image"], file_count="multiple")
312
+ threshold = gr.Slider(label="HQ Threshold (ShadowLilac score)", min=0, max=1, value=0.5, step=0.01)
313
+ process_btn = gr.Button("Process Images", variant="primary")
314
+ progress_bar = gr.Progress()
315
+ export_btn = gr.Button("Export Results to CSV")
316
+ export_msg = gr.Textbox(label="Export Status")
317
+
318
+ with gr.Column(scale=2):
319
+ results_html = gr.HTML(label="Results")
320
+
321
+ with gr.Row():
322
+ gr.Markdown("""
323
+ ### Single Image Evaluation
324
+ Upload a single image to get detailed evaluation metrics.
325
+ """)
326
+
327
+ with gr.Row():
328
+ with gr.Column(scale=1):
329
+ single_img = gr.Image(label="Upload Single Image", type="pil")
330
+ single_eval_btn = gr.Button("Evaluate")
331
+
332
+ with gr.Column(scale=2):
333
+ shadow_score = gr.Number(label="ShadowLilac HQ Score (0-1)")
334
+ waifu_score = gr.Number(label="Waifu Score (0-10)")
335
+ cafe_aesthetic = gr.Number(label="Cafe Aesthetic Score (0-1)")
336
+ anime_aesthetic = gr.Number(label="Anime Aesthetic Score (0-10)")
337
+ average_score = gr.Number(label="Average Score (0-10)")
338
+ style_label = gr.Label(label="Top Style Categories (Cafe)")
339
+
340
+ def process_images_callback(files, threshold, progress=progress_bar):
341
+ file_paths = [f.name for f in files]
342
+ evaluator.process_images(file_paths, threshold, progress)
343
+ return evaluator.get_results_html()
344
+
345
+ def export_callback():
346
+ timestamp = time.strftime("%Y%m%d-%H%M%S")
347
+ filename = f"results_{timestamp}.csv"
348
+ return evaluator.export_results_csv(filename)
349
+
350
+ def evaluate_single(image):
351
+ if image is None:
352
+ return 0, 0, 0, 0, 0, []
353
+
354
+ results = evaluator.evaluate_image(image)
355
+
356
+ # Prepare style labels
357
+ style_data = {
358
+ results["cafe_top_style"]: results["cafe_top_style_score"],
359
+ results["cafe_top_waifu"]: results["cafe_top_waifu_score"]
360
+ }
361
+
362
+ return (
363
+ results["shadow_hq"],
364
+ results["waifu_score"] if "waifu_score" in results else 0,
365
+ results["cafe_aesthetic"],
366
+ results["anime_aesthetic"],
367
+ results["average_score"],
368
+ style_data
369
+ )
370
+
371
+ # Set up event handlers
372
+ process_btn.click(
373
+ process_images_callback,
374
+ inputs=[input_files, threshold],
375
+ outputs=[results_html]
376
+ )
377
+
378
+ export_btn.click(
379
+ export_callback,
380
+ inputs=[],
381
+ outputs=[export_msg]
382
+ )
383
+
384
+ single_eval_btn.click(
385
+ evaluate_single,
386
+ inputs=[single_img],
387
+ outputs=[shadow_score, waifu_score, cafe_aesthetic, anime_aesthetic, average_score, style_label]
388
+ )
389
+
390
+ return app
391
+
392
+ if __name__ == "__main__":
393
+ app = create_interface()
394
+ app.launch()