|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
import os |
|
import shutil |
|
from PIL import Image |
|
from transformers import pipeline |
|
import clip |
|
from huggingface_hub import hf_hub_download |
|
import onnxruntime as rt |
|
import pandas as pd |
|
import time |
|
|
|
|
|
class MLP(torch.nn.Module):
    """Fully connected regression head that maps an embedding to one value.

    The network is four hidden stages (Linear -> ReLU -> BatchNorm1d or
    Identity -> Dropout) with widths 2048, 512, 256 and 128, followed by
    Linear(128, 32) -> ReLU -> Linear(32, 1).

    Args:
        input_size: Number of features in each input vector.
        xcol: Stored on the instance as ``self.xcol``; not used by the network.
        ycol: Stored on the instance as ``self.ycol``; not used by the network.
        batch_norm: When False, every BatchNorm1d slot holds an Identity
            module instead, so the positions inside ``self.layers`` (and
            therefore the state-dict keys) do not shift.
    """

    def __init__(self, input_size, xcol='emb', ycol='avg_rating', batch_norm=True):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol

        # (output width, dropout probability) of each hidden stage, in order.
        hidden_stages = ((2048, 0.3), (512, 0.3), (256, 0.2), (128, 0.1))

        modules = []
        width_in = self.input_size
        for width_out, drop_p in hidden_stages:
            modules.append(torch.nn.Linear(width_in, width_out))
            modules.append(torch.nn.ReLU())
            if batch_norm:
                modules.append(torch.nn.BatchNorm1d(width_out))
            else:
                modules.append(torch.nn.Identity())
            modules.append(torch.nn.Dropout(drop_p))
            width_in = width_out

        # Output head: no normalisation or dropout after the last stage.
        modules.append(torch.nn.Linear(width_in, 32))
        modules.append(torch.nn.ReLU())
        modules.append(torch.nn.Linear(32, 1))

        self.layers = torch.nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.layers(x)
|
|
|
class WaifuScorer:
    """Score images by feeding CLIP ViT-L/14 image embeddings to an MLP head.

    Args:
        device: Torch device string used for both the CLIP model and the MLP.
            The default is computed once, when the class is defined.
    """

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        model_path = hf_hub_download("Eugeoter/waifu-scorer-v4-beta", "model.pth", cache_dir="models")
        self.mlp = self._load_model(model_path, input_size=768, device=device)
        self.model2, self.preprocess = clip.load("ViT-L/14", device=device)
        # torch.nn.Module exposes no `dtype` attribute, so `self.mlp.dtype`
        # raised AttributeError and the scorer could never be constructed.
        # Read the dtype from the loaded weights instead.
        self.dtype = next(self.mlp.parameters()).dtype
        self.mlp.eval()

    def _load_model(self, model_path, input_size=768, device='cuda'):
        """Build an ``MLP``, load the state dict at ``model_path`` and move it to ``device``."""
        model = MLP(input_size=input_size)
        state = torch.load(model_path, map_location=device)
        model.load_state_dict(state)
        model.to(device)
        return model

    def _normalized(self, a, order=2, dim=-1):
        """Return ``a`` divided by its norm along ``dim``.

        Vectors whose norm is zero are divided by 1 instead, so they stay
        all-zero rather than becoming NaN.
        """
        l2 = a.norm(order, dim, keepdim=True)
        l2[l2 == 0] = 1
        return a / l2

    @torch.no_grad()
    def _encode_images(self, images):
        """Return normalised CLIP embeddings as a float32 CPU tensor.

        Args:
            images: A single PIL image or a list of PIL images.
        """
        if isinstance(images, Image.Image):
            images = [images]
        image_tensors = [self.preprocess(img).unsqueeze(0) for img in images]
        image_batch = torch.cat(image_tensors).to(self.device)
        image_features = self.model2.encode_image(image_batch)
        im_emb_arr = self._normalized(image_features).cpu().float()
        return im_emb_arr

    @torch.no_grad()
    def score(self, image):
        """Return the MLP prediction for ``image`` clamped to the range [0, 10].

        Args:
            image: A PIL image, or a numpy array accepted by ``Image.fromarray``.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # The image is submitted twice so the MLP receives a batch of two;
        # only the first prediction is returned.
        images = [image, image]
        images = self._encode_images(images).to(device=self.device, dtype=self.dtype)
        predictions = self.mlp(images)
        scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
        return scores[0]
|
|
|
class AnimeAestheticPredictor:
    """Run the ``skytnt/anime-aesthetic`` ONNX model on the CPU."""

    def __init__(self):
        model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx", cache_dir="models")
        self.model = rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])

    def predict(self, img):
        """Return the model's scalar prediction for ``img``.

        The image is scaled so its longer side is 768 pixels, centred on a
        black 768x768 canvas, scaled to [0, 1] and passed to the model as a
        (1, 3, 768, 768) float32 array under the input name ``"img"``.

        Args:
            img: A PIL image, or an array of 0-255 pixel values accepted by
                ``Image.fromarray`` after conversion to uint8.
        """
        # The previous implementation called cv2.resize, but cv2 is never
        # imported in this file, so every call raised NameError. Pillow is
        # already imported and does the resize here. Its bilinear filter is
        # not bit-identical to OpenCV's, so scores may differ marginally
        # from the reference pipeline.
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.asarray(img).astype(np.uint8))
        # The canvas below has exactly three channels.
        img = img.convert("RGB")

        s = 768
        w, h = img.size  # PIL reports (width, height)
        if h > w:
            h, w = s, max(1, int(s * w / h))
        else:
            h, w = max(1, int(s * h / w)), s
        ph, pw = s - h, s - w

        resized = np.asarray(img.resize((w, h), Image.BILINEAR), dtype=np.float32) / 255
        img_input = np.zeros([s, s, 3], dtype=np.float32)
        img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = resized
        img_input = np.transpose(img_input, (2, 0, 1))  # HWC -> CHW
        img_input = img_input[np.newaxis, :]  # add the batch axis
        pred = self.model.run(None, {"img": img_input})[0].item()
        return pred
|
|
|
class ImageEvaluator:
    """Load every scoring model once and evaluate images with all of them.

    Attributes set by ``__init__``:
        device: ``'cuda'`` when available, otherwise ``'cpu'``.
        results_df: DataFrame from the latest ``process_images`` call, or
            None before the first call.
        temp_dir: Folder that receives a working copy of every processed file.
    """

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.setup_models()
        self.results_df = None
        self.temp_dir = "temp_images"
        # Create each folder unconditionally. Previously the two output
        # sub-folders were only created when "output" itself was missing, so
        # a pre-existing empty "output" directory made the later copy fail.
        os.makedirs(self.temp_dir, exist_ok=True)
        os.makedirs("output/hq_folder", exist_ok=True)
        os.makedirs("output/lq_folder", exist_ok=True)

    def setup_models(self):
        """Instantiate all models and store them on the instance."""
        print("Setting up models (this may take a few minutes)...")

        self.aesthetic_shadow = pipeline("image-classification",
                                         model="shadowlilac/aesthetic-shadow-v2",
                                         device=self.device)

        # WaifuScorer is optional: when it cannot be loaded the evaluator
        # keeps working and simply omits that score.
        try:
            self.waifu_scorer = WaifuScorer(device=self.device)
        except Exception as e:
            print(f"Error loading WaifuScorer: {e}")
            self.waifu_scorer = None

        self.cafe_aesthetic = pipeline("image-classification", "cafeai/cafe_aesthetic")
        self.cafe_style = pipeline("image-classification", "cafeai/cafe_style")
        self.cafe_waifu = pipeline("image-classification", "cafeai/cafe_waifu")

        self.anime_aesthetic = AnimeAestheticPredictor()

        print("All models loaded successfully!")

    def evaluate_image(self, image_path):
        """Evaluate a single image with all models.

        Args:
            image_path: A path string (opened and converted to RGB) or an
                already loaded image object, which is used as-is.

        Returns:
            dict with keys ``shadow_hq``, ``cafe_aesthetic``,
            ``cafe_top_style``, ``cafe_top_style_score``, ``cafe_top_waifu``,
            ``cafe_top_waifu_score``, ``anime_aesthetic``, ``average_score``
            and, when the WaifuScorer is loaded, ``waifu_score``. Numeric
            values are rounded to two decimals.
        """
        if isinstance(image_path, str):
            image = Image.open(image_path).convert('RGB')
        else:
            image = image_path

        results = {}

        shadow_result = self.aesthetic_shadow(images=[image])
        results["shadow_hq"] = round([p for p in shadow_result[0] if p['label'] == 'hq'][0]['score'], 2)

        if self.waifu_scorer:
            try:
                results["waifu_score"] = round(self.waifu_scorer.score(image), 2)
            except Exception as e:
                # Best effort: a failing scorer is recorded as 0, which
                # still counts towards the average below.
                results["waifu_score"] = 0
                print(f"Error with WaifuScorer: {e}")

        cafe_aesthetic_result = self.cafe_aesthetic(image, top_k=2)
        results["cafe_aesthetic"] = round(next((item["score"] for item in cafe_aesthetic_result if item["label"] == "aesthetic"), 0), 2)

        cafe_style_result = self.cafe_style(image, top_k=5)
        results["cafe_top_style"] = cafe_style_result[0]["label"]
        results["cafe_top_style_score"] = round(cafe_style_result[0]["score"], 2)

        cafe_waifu_result = self.cafe_waifu(image, top_k=5)
        results["cafe_top_waifu"] = cafe_waifu_result[0]["label"]
        results["cafe_top_waifu_score"] = round(cafe_waifu_result[0]["score"], 2)

        try:
            results["anime_aesthetic"] = round(self.anime_aesthetic.predict(image), 2)
        except Exception as e:
            results["anime_aesthetic"] = 0
            print(f"Error with Anime Aesthetic: {e}")

        # shadow_hq and cafe_aesthetic are multiplied by 10 before averaging;
        # the waifu and anime scores are averaged unscaled.
        # NOTE(review): this treats anime_aesthetic as a 0-10 value; confirm
        # the ONNX model's actual output range.
        scores = [results["shadow_hq"] * 10]
        if self.waifu_scorer:
            scores.append(results["waifu_score"])
        scores.append(results["cafe_aesthetic"] * 10)
        scores.append(results["anime_aesthetic"])

        results["average_score"] = round(sum(scores) / len(scores), 2)

        return results

    def process_images(self, files, threshold=0.5, progress=None):
        """Process multiple images and return results dataframe.

        Each file is copied into ``self.temp_dir``, evaluated, and copied
        again into ``output/hq_folder`` or ``output/lq_folder`` depending on
        whether its ``shadow_hq`` score reaches ``threshold``.

        Args:
            files: Iterable of image file paths; None or empty is allowed.
            threshold: Minimum ``shadow_hq`` score for an image to be "HQ".
            progress: Optional callable invoked as ``progress(fraction, text)``.

        Returns:
            DataFrame sorted by ``average_score`` in descending order (empty
            when no files were given); also stored as ``self.results_df``.
        """
        files = list(files or [])
        total_files = len(files)
        results = []

        # Remove working copies left over from a previous run. Only regular
        # files are deleted so a stray sub-directory cannot abort the batch.
        for entry in os.listdir(self.temp_dir):
            stale = os.path.join(self.temp_dir, entry)
            if os.path.isfile(stale):
                os.remove(stale)

        for i, file in enumerate(files):
            filename = os.path.basename(file)
            if progress is not None:
                progress(i / total_files, f"Processing {i+1}/{total_files}: {filename}")

            temp_path = os.path.join(self.temp_dir, filename)
            shutil.copy(file, temp_path)

            results_dict = self.evaluate_image(temp_path)
            results_dict["filename"] = filename
            results_dict["path"] = temp_path
            results_dict["is_hq"] = results_dict["shadow_hq"] >= threshold

            destination = "output/hq_folder" if results_dict["is_hq"] else "output/lq_folder"
            shutil.copy(temp_path, os.path.join(destination, filename))

            results.append(results_dict)

        self.results_df = pd.DataFrame(results)
        # An empty frame has no "average_score" column; sorting it would
        # raise KeyError.
        if not self.results_df.empty:
            self.results_df = self.results_df.sort_values(by="average_score", ascending=False)

        if progress is not None:
            progress(1.0, "Processing complete!")

        return self.results_df

    def get_results_html(self):
        """Generate HTML with results and image previews.

        Returns:
            An HTML table with one row per processed image (green background
            for HQ rows, red for the rest), or a short notice when no batch
            has been processed yet.
        """
        # Imported here so the method is self-contained. File names and
        # paths come from user uploads and must be escaped before being
        # embedded in markup.
        from html import escape

        if self.results_df is None:
            return "<p>No results available. Please process images first.</p>"

        cell = "padding:8px; border:1px solid #ddd;"
        has_waifu = "waifu_score" in self.results_df.columns

        headers = ["Image", "Filename", "Average", "Shadow HQ"]
        if has_waifu:
            headers.append("Waifu")
        headers += ["Cafe", "Anime", "Style"]

        # Collect fragments and join once instead of repeated string +=.
        parts = [
            "<h2>Results (Sorted by Average Score)</h2>",
            "<table style='width:100%; border-collapse: collapse;'>",
            "<tr style='background-color:#f0f0f0'>",
        ]
        parts.extend(f"<th style='{cell}'>{title}</th>" for title in headers)
        parts.append("</tr>")

        for _, row in self.results_df.iterrows():
            row_color = "#e8f5e9" if row["is_hq"] else "#ffebee"
            img_src = escape(str(row['path']), quote=True)
            file_label = escape(str(row['filename']))
            style_label = escape(str(row['cafe_top_style']))

            parts.append(f"<tr style='background-color:{row_color}'>")
            parts.append(f"<td style='{cell}'><img src='file={img_src}' height='100'></td>")
            parts.append(f"<td style='{cell}'>{file_label}</td>")
            parts.append(f"<td style='{cell} font-weight:bold;'>{row['average_score']}</td>")
            parts.append(f"<td style='{cell}'>{row['shadow_hq']}</td>")
            if has_waifu:
                parts.append(f"<td style='{cell}'>{row['waifu_score']}</td>")
            parts.append(f"<td style='{cell}'>{row['cafe_aesthetic']}</td>")
            parts.append(f"<td style='{cell}'>{row['anime_aesthetic']}</td>")
            parts.append(f"<td style='{cell}'>{style_label} ({row['cafe_top_style_score']})</td>")
            parts.append("</tr>")

        parts.append("</table>")
        return "".join(parts)

    def export_results_csv(self, output_path="results.csv"):
        """Export results to CSV file.

        Args:
            output_path: Destination path of the CSV file.

        Returns:
            A status message describing what happened.
        """
        if self.results_df is not None:
            self.results_df.to_csv(output_path, index=False)
            return f"Results exported to {output_path}"
        return "No results to export"
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks application and return it.

    Constructs one ``ImageEvaluator`` (which loads every model) and wires
    three actions to it: batch processing of uploaded files, CSV export of
    the latest batch, and evaluation of a single image.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    evaluator = ImageEvaluator()

    with gr.Blocks(title="Comprehensive Image Evaluation Tool", theme=gr.themes.Soft()) as app:
        gr.Markdown("""
        # 🖼️ Comprehensive Image Evaluation Tool

        Upload images to evaluate their aesthetic quality using multiple models:

        - **ShadowLilac** - General aesthetic quality (0-1)
        - **WaifuScorer** - Anime-style quality score (0-10)
        - **CafeAI** - Style classification and aesthetic assessment
        - **Anime Aesthetic** - Specialized for anime/manga art (0-10)

        The tool will provide an average score and classify images as high or low quality based on your threshold.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                input_files = gr.Files(label="Upload Images", file_types=["image"], file_count="multiple")
                # gr.Slider takes `minimum`/`maximum`; the former `min`/`max`
                # keywords are not Slider parameters.
                threshold = gr.Slider(label="HQ Threshold (ShadowLilac score)", minimum=0, maximum=1, value=0.5, step=0.01)
                process_btn = gr.Button("Process Images", variant="primary")
                export_btn = gr.Button("Export Results to CSV")
                export_msg = gr.Textbox(label="Export Status")

            with gr.Column(scale=2):
                results_html = gr.HTML(label="Results")

        with gr.Row():
            gr.Markdown("""
            ### Single Image Evaluation
            Upload a single image to get detailed evaluation metrics.
            """)

        with gr.Row():
            with gr.Column(scale=1):
                single_img = gr.Image(label="Upload Single Image", type="pil")
                single_eval_btn = gr.Button("Evaluate")

            with gr.Column(scale=2):
                shadow_score = gr.Number(label="ShadowLilac HQ Score (0-1)")
                waifu_score = gr.Number(label="Waifu Score (0-10)")
                cafe_aesthetic = gr.Number(label="Cafe Aesthetic Score (0-1)")
                anime_aesthetic = gr.Number(label="Anime Aesthetic Score (0-10)")
                average_score = gr.Number(label="Average Score (0-10)")
                style_label = gr.Label(label="Top Style Categories (Cafe)")

        # A gr.Progress default argument is how Gradio injects a progress
        # tracker into an event handler.
        def process_images_callback(files, threshold, progress=gr.Progress()):
            """Evaluate the uploaded files and return the results table HTML."""
            # `files` is None when the button is pressed before any upload.
            if not files:
                return "<p>No images uploaded. Please select one or more images first.</p>"
            # Uploads may arrive as plain path strings or as objects
            # exposing the path through `.name`.
            file_paths = [getattr(f, "name", f) for f in files]
            evaluator.process_images(file_paths, threshold, progress)
            return evaluator.get_results_html()

        def export_callback():
            """Write the latest batch to a timestamped CSV and return a status message."""
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            filename = f"results_{timestamp}.csv"
            return evaluator.export_results_csv(filename)

        def evaluate_single(image):
            """Return the five numeric scores and the label data for one image."""
            if image is None:
                # An empty dict clears the Label; a list is not a value the
                # Label component accepts.
                return 0, 0, 0, 0, 0, {}

            results = evaluator.evaluate_image(image)

            style_data = {
                results["cafe_top_style"]: results["cafe_top_style_score"],
                results["cafe_top_waifu"]: results["cafe_top_waifu_score"],
            }

            return (
                results["shadow_hq"],
                results.get("waifu_score", 0),  # absent when WaifuScorer failed to load
                results["cafe_aesthetic"],
                results["anime_aesthetic"],
                results["average_score"],
                style_data,
            )

        process_btn.click(
            process_images_callback,
            inputs=[input_files, threshold],
            outputs=[results_html]
        )

        export_btn.click(
            export_callback,
            inputs=[],
            outputs=[export_msg]
        )

        single_eval_btn.click(
            evaluate_single,
            inputs=[single_img],
            outputs=[shadow_score, waifu_score, cafe_aesthetic, anime_aesthetic, average_score, style_label]
        )

    return app
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI (this loads every model) and start
    # the Gradio server with its default launch settings.
    app = create_interface()
    app.launch()