import gradio as gr |
import torch |
import numpy as np |
import os |
import shutil |
from PIL import Image |
from transformers import pipeline |
import clip |
from huggingface_hub import hf_hub_download |
import onnxruntime as rt |
import pandas as pd |
import time |
class MLP(torch.nn.Module):
    """Regression head mapping a feature vector to one unbounded scalar.

    WaifuScorer loads a saved state dict into this module, so the child
    indices of ``self.layers`` (and hence keys such as ``layers.0.weight`` or
    ``layers.2.running_mean``) must stay exactly where they are.

    Args:
        input_size: Width of the input feature vector.
        xcol: Stored column name for inputs; ``forward`` does not read it.
        ycol: Stored column name for targets; ``forward`` does not read it.
        batch_norm: When False every BatchNorm1d slot becomes an Identity, so
            the layer indices do not shift.
    """

    def __init__(self, input_size, xcol='emb', ycol='avg_rating', batch_norm=True):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol

        widths = [2048, 512, 256, 128]
        drop_rates = [0.3, 0.3, 0.2, 0.1]
        fan_ins = [input_size] + widths[:-1]

        # Four Linear -> ReLU -> (BatchNorm|Identity) -> Dropout stages,
        # built in the same order as the checkpoint layout expects.
        stack = []
        for fan_in, fan_out, rate in zip(fan_ins, widths, drop_rates):
            stack.extend([
                torch.nn.Linear(fan_in, fan_out),
                torch.nn.ReLU(),
                torch.nn.BatchNorm1d(fan_out) if batch_norm else torch.nn.Identity(),
                torch.nn.Dropout(rate),
            ])

        # Narrowing tail down to a single output, with no norm/dropout.
        stack.extend([
            torch.nn.Linear(widths[-1], 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
        ])
        self.layers = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the raw output."""
        return self.layers(x)
class WaifuScorer:
    """Score images with an MLP head applied to CLIP ViT-L/14 image embeddings.

    Scores are clamped to the range [0, 10].

    Args:
        device: Torch device string; defaults to CUDA when available.
    """

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        model_path = hf_hub_download("Eugeoter/waifu-scorer-v4-beta", "model.pth", cache_dir="models")
        self.mlp = self._load_model(model_path, input_size=768, device=device)
        self.model2, self.preprocess = clip.load("ViT-L/14", device=device)
        # torch.nn.Module exposes no ``dtype`` attribute, so take it from the
        # weights; embeddings are cast to this dtype before the MLP.
        self.dtype = next(self.mlp.parameters()).dtype
        self.mlp.eval()

    def _load_model(self, model_path, input_size=768, device='cuda'):
        """Build an MLP, load the checkpoint at ``model_path`` and move it to ``device``."""
        model = MLP(input_size=input_size)
        # NOTE(review): torch.load unpickles the file and this checkpoint is
        # downloaded from the Hub; consider weights_only=True (torch>=1.13).
        s = torch.load(model_path, map_location=device)
        model.load_state_dict(s)
        model.to(device)
        return model

    def _normalized(self, a, order=2, dim=-1):
        """Divide ``a`` by its L-``order`` norm along ``dim``.

        Zero-norm slices are divided by 1 (left unchanged) instead of producing NaN.
        """
        l2 = a.norm(order, dim, keepdim=True)
        l2[l2 == 0] = 1
        return a / l2

    @torch.no_grad()
    def _encode_images(self, images):
        """Encode one PIL image or a list of them into normalised CLIP embeddings.

        Returns:
            A float32 CPU tensor with one row per image.
        """
        if isinstance(images, Image.Image):
            images = [images]
        image_tensors = [self.preprocess(img).unsqueeze(0) for img in images]
        image_batch = torch.cat(image_tensors).to(self.device)
        image_features = self.model2.encode_image(image_batch)
        im_emb_arr = self._normalized(image_features).cpu().float()
        return im_emb_arr

    @torch.no_grad()
    def score(self, image):
        """Return the waifu score of a single image as a float in [0, 10].

        Args:
            image: PIL image or numpy array (converted with ``Image.fromarray``).
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # The MLP is switched to eval mode in __init__, so BatchNorm uses its
        # running statistics and a batch of one is valid; encoding the image
        # twice would only double the CLIP cost.
        embedding = self._encode_images([image]).to(device=self.device, dtype=self.dtype)
        prediction = self.mlp(embedding)
        return prediction.clamp(0, 10).reshape(-1)[0].item()
class AnimeAestheticPredictor:
    """Wrap the ``skytnt/anime-aesthetic`` ONNX model, run on the CPU provider."""

    def __init__(self):
        model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx", cache_dir="models")
        self.model = rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])

    def predict(self, img):
        """Return the model's scalar prediction for one image.

        The image is scaled to [0, 1] and resized bilinearly so that its longer
        side is 768 px, keeping the aspect ratio. It is then centred on a black
        768x768 canvas and fed to the model input ``"img"`` as a
        (1, 3, 768, 768) float32 array.

        Args:
            img: PIL image, or a numpy array shaped (H, W), (H, W, 3) or
                (H, W, 4) with 0-255 values.
        """
        if not isinstance(img, np.ndarray):
            # Force 3-channel RGB so grayscale / RGBA / palette images work.
            img = np.asarray(img.convert("RGB") if isinstance(img, Image.Image) else img)
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)
        elif img.shape[-1] == 4:
            img = img[..., :3]
        img = img.astype(np.float32) / 255

        s = 768
        h, w = img.shape[:2]
        # Scale the longer side to s; max(1, ...) stops extreme aspect ratios
        # from collapsing the shorter side to zero pixels.
        if h > w:
            h, w = s, max(1, int(s * w / h))
        else:
            h, w = max(1, int(s * h / w)), s
        ph, pw = s - h, s - w

        # Bilinear resize with half-pixel centres (the cv2.INTER_LINEAR
        # convention), done in torch because OpenCV is not a dependency here.
        chw = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))
        resized = torch.nn.functional.interpolate(
            chw.unsqueeze(0), size=(h, w), mode="bilinear", align_corners=False
        )[0].numpy()

        img_input = np.zeros([3, s, s], dtype=np.float32)
        img_input[:, ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = resized
        img_input = img_input[np.newaxis, :]
        pred = self.model.run(None, {"img": img_input})[0].item()
        return pred
class ImageEvaluator:
    """Run every aesthetic model over images and keep the batch results.

    Batch results live in ``self.results_df`` (a pandas DataFrame sorted by
    ``average_score``). They can be rendered as HTML or exported to CSV.
    """

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.setup_models()
        self.results_df = None
        self.temp_dir = "temp_images"
        # Create each folder independently so that a pre-existing "output"
        # directory lacking the HQ/LQ subfolders still works.
        os.makedirs(self.temp_dir, exist_ok=True)
        os.makedirs("output/hq_folder", exist_ok=True)
        os.makedirs("output/lq_folder", exist_ok=True)

    def setup_models(self):
        """Load every scoring model.

        A WaifuScorer load failure is printed and leaves ``self.waifu_scorer``
        as None. Any other load failure propagates.
        """
        print("Setting up models (this may take a few minutes)...")
        self.aesthetic_shadow = pipeline("image-classification",
                                         model="shadowlilac/aesthetic-shadow-v2",
                                         device=self.device)
        try:
            self.waifu_scorer = WaifuScorer(device=self.device)
        except Exception as e:
            print(f"Error loading WaifuScorer: {e}")
            self.waifu_scorer = None
        # Same device as the shadow pipeline, for consistency.
        self.cafe_aesthetic = pipeline("image-classification", "cafeai/cafe_aesthetic", device=self.device)
        self.cafe_style = pipeline("image-classification", "cafeai/cafe_style", device=self.device)
        self.cafe_waifu = pipeline("image-classification", "cafeai/cafe_waifu", device=self.device)
        self.anime_aesthetic = AnimeAestheticPredictor()
        print("All models loaded successfully!")

    def evaluate_image(self, image_path):
        """Evaluate a single image with all models.

        Args:
            image_path: Path to an image file (opened and converted to RGB),
                or an already-loaded image object passed to the models as-is.

        Returns:
            dict with the per-model scores, the top Cafe style/waifu labels and
            ``average_score``. ``waifu_score`` is present only when WaifuScorer
            loaded. A model that raises is recorded as 0 and left out of
            ``average_score``, so a failure does not drag the average down.
        """
        if isinstance(image_path, str):
            with Image.open(image_path) as opened:
                image = opened.convert('RGB')
        else:
            image = image_path

        results = {}
        failed = set()

        shadow_result = self.aesthetic_shadow(images=[image])
        hq_score = next((p['score'] for p in shadow_result[0] if p['label'] == 'hq'), 0)
        results["shadow_hq"] = round(hq_score, 2)

        if self.waifu_scorer:
            try:
                results["waifu_score"] = round(self.waifu_scorer.score(image), 2)
            except Exception as e:
                results["waifu_score"] = 0
                failed.add("waifu_score")
                print(f"Error with WaifuScorer: {e}")

        cafe_aesthetic_result = self.cafe_aesthetic(image, top_k=2)
        results["cafe_aesthetic"] = round(next((item["score"] for item in cafe_aesthetic_result if item["label"] == "aesthetic"), 0), 2)

        cafe_style_result = self.cafe_style(image, top_k=5)
        results["cafe_top_style"] = cafe_style_result[0]["label"]
        results["cafe_top_style_score"] = round(cafe_style_result[0]["score"], 2)

        cafe_waifu_result = self.cafe_waifu(image, top_k=5)
        results["cafe_top_waifu"] = cafe_waifu_result[0]["label"]
        results["cafe_top_waifu_score"] = round(cafe_waifu_result[0]["score"], 2)

        try:
            results["anime_aesthetic"] = round(self.anime_aesthetic.predict(image), 2)
        except Exception as e:
            results["anime_aesthetic"] = 0
            failed.add("anime_aesthetic")
            print(f"Error with Anime Aesthetic: {e}")

        # shadow_hq and cafe_aesthetic are multiplied by 10 before averaging.
        # NOTE(review): anime_aesthetic is averaged unscaled; confirm the ONNX
        # model really outputs 0-10 as the UI labels claim.
        scores = [results["shadow_hq"] * 10]
        if "waifu_score" in results and "waifu_score" not in failed:
            scores.append(results["waifu_score"])
        scores.append(results["cafe_aesthetic"] * 10)
        if "anime_aesthetic" not in failed:
            scores.append(results["anime_aesthetic"])
        results["average_score"] = round(sum(scores) / len(scores), 2)
        return results

    def process_images(self, files, threshold=0.5, progress=None):
        """Process multiple images and return the results dataframe.

        Each file is copied into ``self.temp_dir`` and evaluated. It is then
        copied to ``output/hq_folder`` when its ``shadow_hq`` score is at least
        ``threshold``, otherwise to ``output/lq_folder``. A file that fails to
        copy or evaluate is reported and skipped.

        Args:
            files: Iterable of file paths; None or empty yields an empty frame.
            threshold: Minimum ``shadow_hq`` score for the HQ folder.
            progress: Optional callable ``progress(fraction, description)``.

        Returns:
            The new ``self.results_df``, sorted by ``average_score`` descending.
        """
        files = list(files or [])
        results = []
        total_files = len(files)

        # Clear the previous batch's temp copies (regular files only).
        os.makedirs(self.temp_dir, exist_ok=True)
        for name in os.listdir(self.temp_dir):
            stale = os.path.join(self.temp_dir, name)
            if os.path.isfile(stale):
                os.remove(stale)

        for i, file in enumerate(files):
            filename = os.path.basename(file)
            if progress is not None:
                progress(i / total_files, f"Processing {i+1}/{total_files}: {filename}")
            temp_path = os.path.join(self.temp_dir, filename)
            try:
                shutil.copy(file, temp_path)
                results_dict = self.evaluate_image(temp_path)
            except Exception as e:
                # One unreadable upload should not abort the whole batch.
                print(f"Error processing {filename}: {e}")
                continue
            results_dict["filename"] = filename
            results_dict["path"] = temp_path
            results_dict["is_hq"] = results_dict["shadow_hq"] >= threshold
            destination = "output/hq_folder" if results_dict["is_hq"] else "output/lq_folder"
            os.makedirs(destination, exist_ok=True)
            shutil.copy(temp_path, os.path.join(destination, filename))
            results.append(results_dict)

        self.results_df = pd.DataFrame(results)
        # An empty frame has no "average_score" column to sort by.
        if not self.results_df.empty:
            self.results_df = self.results_df.sort_values(by="average_score", ascending=False)
        if progress is not None:
            progress(1.0, "Processing complete!")
        return self.results_df

    def get_results_html(self):
        """Generate an HTML table of the results with image previews.

        Filenames and paths come from user uploads, so every text value is
        HTML-escaped before being embedded.
        """
        from html import escape

        if self.results_df is None or self.results_df.empty:
            return "<p>No results available. Please process images first.</p>"

        cell = "padding:8px; border:1px solid #ddd;"
        show_waifu = "waifu_score" in self.results_df.columns

        headers = ["Image", "Filename", "Average", "Shadow HQ"]
        if show_waifu:
            headers.append("Waifu")
        headers += ["Cafe", "Anime", "Style"]

        parts = [
            "<h2>Results (Sorted by Average Score)</h2>",
            "<table style='width:100%; border-collapse: collapse;'>",
            "<tr style='background-color:#f0f0f0'>",
        ]
        parts += [f"<th style='{cell}'>{header}</th>" for header in headers]
        parts.append("</tr>")

        for _, row in self.results_df.iterrows():
            row_color = "#e8f5e9" if row["is_hq"] else "#ffebee"
            parts.append(f"<tr style='background-color:{row_color}'>")
            # NOTE(review): whether Gradio serves "file=<relative path>" URLs
            # depends on its version and allowed paths; verify thumbnails load.
            parts.append(f"<td style='{cell}'><img src='file={escape(str(row['path']), quote=True)}' height='100'></td>")
            parts.append(f"<td style='{cell}'>{escape(str(row['filename']))}</td>")
            parts.append(f"<td style='{cell} font-weight:bold;'>{row['average_score']}</td>")
            parts.append(f"<td style='{cell}'>{row['shadow_hq']}</td>")
            if show_waifu:
                parts.append(f"<td style='{cell}'>{row['waifu_score']}</td>")
            parts.append(f"<td style='{cell}'>{row['cafe_aesthetic']}</td>")
            parts.append(f"<td style='{cell}'>{row['anime_aesthetic']}</td>")
            parts.append(f"<td style='{cell}'>{escape(str(row['cafe_top_style']))} ({row['cafe_top_style_score']})</td>")
            parts.append("</tr>")

        parts.append("</table>")
        return "".join(parts)

    def export_results_csv(self, output_path="results.csv"):
        """Write ``results_df`` to ``output_path`` and return a status message."""
        if self.results_df is not None:
            self.results_df.to_csv(output_path, index=False)
            return f"Results exported to {output_path}"
        return "No results to export"
def create_interface():
    """Build the Gradio Blocks UI around one shared ImageEvaluator.

    Constructing the evaluator loads every model, so this call is slow.
    """
    evaluator = ImageEvaluator()
    with gr.Blocks(title="Comprehensive Image Evaluation Tool", theme=gr.themes.Soft()) as app:
        gr.Markdown("""
# 🖼️ Comprehensive Image Evaluation Tool
Upload images to evaluate their aesthetic quality using multiple models:
- **ShadowLilac** - General aesthetic quality (0-1)
- **WaifuScorer** - Anime-style quality score (0-10)
- **CafeAI** - Style classification and aesthetic assessment
- **Anime Aesthetic** - Specialized for anime/manga art (0-10)
The tool will provide an average score and classify images as high or low quality based on your threshold.
""")
        with gr.Row():
            with gr.Column(scale=1):
                # gr.File with file_count="multiple" is the portable spelling;
                # gr.Files already presets file_count, so passing it again can clash.
                input_files = gr.File(label="Upload Images", file_types=["image"], file_count="multiple")
                # Slider bounds are named minimum/maximum (min/max are not Slider parameters).
                threshold = gr.Slider(label="HQ Threshold (ShadowLilac score)", minimum=0, maximum=1, value=0.5, step=0.01)
                process_btn = gr.Button("Process Images", variant="primary")
                export_btn = gr.Button("Export Results to CSV")
                export_msg = gr.Textbox(label="Export Status")
            with gr.Column(scale=2):
                results_html = gr.HTML(label="Results")
        with gr.Row():
            gr.Markdown("""
### Single Image Evaluation
Upload a single image to get detailed evaluation metrics.
""")
        with gr.Row():
            with gr.Column(scale=1):
                single_img = gr.Image(label="Upload Single Image", type="pil")
                single_eval_btn = gr.Button("Evaluate")
            with gr.Column(scale=2):
                shadow_score = gr.Number(label="ShadowLilac HQ Score (0-1)")
                waifu_score = gr.Number(label="Waifu Score (0-10)")
                cafe_aesthetic = gr.Number(label="Cafe Aesthetic Score (0-1)")
                anime_aesthetic = gr.Number(label="Anime Aesthetic Score (0-10)")
                average_score = gr.Number(label="Average Score (0-10)")
                style_label = gr.Label(label="Top Style Categories (Cafe)")

        # A gr.Progress default argument is how Gradio injects a progress tracker.
        def process_images_callback(files, threshold, progress=gr.Progress()):
            if not files:
                return "<p>No images uploaded.</p>"
            # Depending on the Gradio version, uploads arrive either as
            # temp-file objects (with .name) or as plain path strings.
            file_paths = [f if isinstance(f, str) else f.name for f in files]
            evaluator.process_images(file_paths, threshold, progress)
            return evaluator.get_results_html()

        def export_callback():
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            filename = f"results_{timestamp}.csv"
            return evaluator.export_results_csv(filename)

        def evaluate_single(image):
            if image is None:
                # gr.Label accepts a str/number/dict or None; an empty list is rejected.
                return 0, 0, 0, 0, 0, None
            results = evaluator.evaluate_image(image)
            style_data = {
                results["cafe_top_style"]: results["cafe_top_style_score"],
                results["cafe_top_waifu"]: results["cafe_top_waifu_score"]
            }
            return (
                results["shadow_hq"],
                results.get("waifu_score", 0),
                results["cafe_aesthetic"],
                results["anime_aesthetic"],
                results["average_score"],
                style_data
            )

        process_btn.click(
            process_images_callback,
            inputs=[input_files, threshold],
            outputs=[results_html]
        )
        export_btn.click(
            export_callback,
            inputs=[],
            outputs=[export_msg]
        )
        single_eval_btn.click(
            evaluate_single,
            inputs=[single_img],
            outputs=[shadow_score, waifu_score, cafe_aesthetic, anime_aesthetic, average_score, style_label]
        )
    return app
if __name__ == "__main__":
    # Script entry point: constructing the interface also loads every model.
    app = create_interface()
    app.launch()  # hand control to Gradio's web server