|
import gradio as gr |
|
import torch |
|
import os |
|
import numpy as np |
|
import cv2 |
|
import onnxruntime as rt |
|
from PIL import Image |
|
from transformers import pipeline |
|
from huggingface_hub import hf_hub_download |
|
import pandas as pd |
|
import tempfile |
|
import shutil |
|
|
|
|
|
class MLP(torch.nn.Module):
    """Feed-forward regression head mapping an embedding vector to a single value.

    Layout: four hidden stages of Linear -> ReLU -> (BatchNorm1d | Identity) ->
    Dropout with 2048, 512, 256 and 128 units, followed by Linear(128, 32) ->
    ReLU -> Linear(32, 1). The module order fixes the ``state_dict`` keys
    (``layers.0.weight`` ... ``layers.18.bias``), so it must not change or saved
    checkpoints stop loading.

    Args:
        input_size: width of the input feature vector.
        xcol: stored on the instance; not used by this class.
        ycol: stored on the instance; not used by this class.
        batch_norm: when False, every BatchNorm1d slot holds an Identity instead
            (the slot is kept so layer indices stay the same).
    """

    def __init__(self, input_size, xcol='emb', ycol='avg_rating', batch_norm=True):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol

        # (in_features, out_features, dropout probability) for each hidden stage.
        hidden_stages = (
            (self.input_size, 2048, 0.3),
            (2048, 512, 0.3),
            (512, 256, 0.2),
            (256, 128, 0.1),
        )
        stack = []
        for fan_in, fan_out, drop_p in hidden_stages:
            stack.append(torch.nn.Linear(fan_in, fan_out))
            stack.append(torch.nn.ReLU())
            stack.append(torch.nn.BatchNorm1d(fan_out) if batch_norm else torch.nn.Identity())
            stack.append(torch.nn.Dropout(drop_p))
        stack.extend([
            torch.nn.Linear(128, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
        ])
        self.layers = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.layers(x)
|
|
|
|
|
class WaifuScorer(object):
    """Score images with CLIP ViT-L/14 image embeddings fed through an MLP head.

    Args:
        model_path: local path to an MLP ``state_dict`` file, or a Hugging Face
            Hub reference of the form ``"<user>/<repo>/<filename>"`` which is
            downloaded with ``hf_hub_download``. Defaults to
            ``"Eugeoter/waifu-scorer-v4-beta/model.pth"``.
        device: torch device for both CLIP and the MLP.
        cache_dir: passed through to ``hf_hub_download``.
        verbose: print extra progress information.

    Raises:
        ValueError: if ``model_path`` is not an existing file and does not have
            at least three ``/``-separated parts.
    """

    def __init__(self, model_path=None, device='cuda', cache_dir=None, verbose=False):
        self.verbose = verbose

        # Imported lazily so the rest of the module can be imported without the
        # `clip` package installed.
        import clip

        if model_path is None:
            model_path = "Eugeoter/waifu-scorer-v4-beta/model.pth"
            if self.verbose:
                print(f"model path not set, switch to default: `{model_path}`")

        if not os.path.isfile(model_path):
            parts = model_path.split("/")
            if len(parts) < 3:
                raise ValueError(
                    f"model_path {model_path!r} is neither an existing file nor a "
                    "Hugging Face reference of the form '<user>/<repo>/<filename>'"
                )
            username, repo_id, model_name = parts[-3], parts[-2], parts[-1]
            model_path = hf_hub_download(f"{username}/{repo_id}", model_name, cache_dir=cache_dir)

        print(f"Loading WaifuScorer model from `{model_path}`")

        # 768 matches the CLIP ViT-L/14 image-embedding width loaded below.
        self.mlp = MLP(input_size=768)
        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        s = torch.load(model_path, map_location=device)
        self.mlp.load_state_dict(s)
        self.mlp.to(device)

        self.model2, self.preprocess = clip.load("ViT-L/14", device=device)
        self.device = device
        self.dtype = torch.float32
        self.mlp.eval()

    @torch.no_grad()
    def __call__(self, images):
        """Return one score (a float clamped to [0, 10]) per input image.

        Args:
            images: a single PIL image or a sequence of PIL images.

        Returns:
            A list of floats, in input order; ``[]`` for an empty sequence.
        """
        if isinstance(images, Image.Image):
            images = [images]
        n = len(images)
        if n == 0:
            # torch.cat below cannot concatenate an empty list.
            return []
        if n == 1:
            # A single image is duplicated into a batch of two (presumably to
            # sidestep batch-size-1 issues); the extra score is dropped below.
            images = list(images) * 2

        image_tensors = [self.preprocess(img).unsqueeze(0) for img in images]
        image_batch = torch.cat(image_tensors).to(self.device)
        image_features = self.model2.encode_image(image_batch)

        # L2-normalise each embedding; zero norms are replaced by 1 to avoid
        # division by zero.
        l2 = image_features.norm(2, dim=-1, keepdim=True)
        l2[l2 == 0] = 1
        im_emb_arr = (image_features / l2).to(device=self.device, dtype=self.dtype)

        predictions = self.mlp(im_emb_arr)
        scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
        return scores[:n]
|
|
|
|
|
def load_aesthetic_predictor_v2_5():
    """Return a placeholder for the Aesthetic Predictor V2.5 model.

    No real V2.5 model is loaded here. Instead of fabricating a number, the
    placeholder's ``inference`` raises ``NotImplementedError``. Callers that
    catch exceptions can then report "no score" rather than showing a random
    value as if it were a prediction.

    Returns:
        An object exposing ``inference(image)``.
    """

    class AestheticPredictorV2_5:
        """Placeholder with the ``inference(image)`` interface of the real predictor."""

        def __init__(self):
            print("Loading Aesthetic Predictor V2.5...")

        def inference(self, image):
            """Always raises: there is no model to run.

            Raises:
                NotImplementedError: always.
            """
            raise NotImplementedError(
                "Aesthetic Predictor V2.5 is not implemented; no score is available"
            )

    return AestheticPredictorV2_5()
|
|
|
|
|
def load_anime_aesthetic_model():
    """Download skytnt/anime-aesthetic ``model.onnx`` and open it on the CPU.

    Returns:
        An ``onnxruntime.InferenceSession`` using ``CPUExecutionProvider``.
    """
    onnx_file = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
    return rt.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])
|
|
|
|
|
def predict_anime_aesthetic(img, model): |
|
img = np.array(img).astype(np.float32) / 255 |
|
s = 768 |
|
h, w = img.shape[:-1] |
|
h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s) |
|
ph, pw = s - h, s - w |
|
img_input = np.zeros([s, s, 3], dtype=np.float32) |
|
img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(img, (w, h)) |
|
img_input = np.transpose(img_input, (2, 0, 1)) |
|
img_input = img_input[np.newaxis, :] |
|
pred = model.run(None, {"img": img_input})[0].item() |
|
return pred |
|
|
|
|
|
class ImageEvaluationTool:
    """Load all aesthetic/quality models once and run them over images.

    Each model is evaluated independently; a failure in one model is printed
    and recorded as ``None`` for that model's result keys, so the others still
    report.
    """

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")

        print("Loading models... This may take some time.")

        print("Loading Aesthetic Shadow model...")
        self.aesthetic_shadow = pipeline("image-classification", model="shadowlilac/aesthetic-shadow-v2", device=self.device)

        try:
            print("Loading Waifu Scorer model...")
            self.waifu_scorer = WaifuScorer(device=self.device, verbose=True)
        except Exception as e:
            # Best-effort: the tool still works without this model; evaluate_image
            # reports None for it.
            print(f"Error loading Waifu Scorer: {e}")
            self.waifu_scorer = None

        print("Loading Aesthetic Predictor V2.5...")
        self.aesthetic_predictor_v2_5 = load_aesthetic_predictor_v2_5()

        print("Loading Cafe Aesthetic models...")
        # Pass the device explicitly, like the Aesthetic Shadow pipeline above,
        # rather than relying on the transformers default placement.
        self.cafe_aesthetic = pipeline("image-classification", "cafeai/cafe_aesthetic", device=self.device)
        self.cafe_style = pipeline("image-classification", "cafeai/cafe_style", device=self.device)
        self.cafe_waifu = pipeline("image-classification", "cafeai/cafe_waifu", device=self.device)

        print("Loading Anime Aesthetic model...")
        self.anime_aesthetic = load_anime_aesthetic_model()

        print("All models loaded successfully!")

        # Holds thumbnails written by process_images; removed by cleanup().
        self.temp_dir = tempfile.mkdtemp()

    def evaluate_image(self, image):
        """Evaluate a single image with all models.

        Args:
            image: a PIL image or an array accepted by ``Image.fromarray``.

        Returns:
            dict with keys ``aesthetic_shadow``, ``waifu_scorer``,
            ``aesthetic_predictor_v2_5``, ``cafe_aesthetic_good``,
            ``cafe_aesthetic_bad``, ``cafe_style``, ``cafe_waifu`` and
            ``anime_aesthetic``. A key is ``None`` if its model failed.
        """
        results = {}

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Normalise RGBA / grayscale / palette input once. predict_anime_aesthetic
        # builds a 3-channel buffer, and the other models are fed the same image.
        if image.mode != "RGB":
            image = image.convert("RGB")

        try:
            shadow_result = self.aesthetic_shadow(images=[image])[0]
            label_scores = {p['label']: p['score'] for p in shadow_result}
            results['aesthetic_shadow'] = round(label_scores['hq'], 2)
        except Exception as e:
            print(f"Error in Aesthetic Shadow: {e}")
            results['aesthetic_shadow'] = None

        if self.waifu_scorer:
            try:
                waifu_score = self.waifu_scorer([image])[0]
                results['waifu_scorer'] = round(waifu_score, 2)
            except Exception as e:
                print(f"Error in Waifu Scorer: {e}")
                results['waifu_scorer'] = None
        else:
            results['waifu_scorer'] = None

        try:
            v2_5_score = self.aesthetic_predictor_v2_5.inference(image)
            results['aesthetic_predictor_v2_5'] = round(v2_5_score, 2)
        except Exception as e:
            print(f"Error in Aesthetic Predictor V2.5: {e}")
            results['aesthetic_predictor_v2_5'] = None

        try:
            cafe_aesthetic_result = self.cafe_aesthetic(image, top_k=2)
            cafe_aesthetic_score = {d["label"]: round(d["score"], 2) for d in cafe_aesthetic_result}
            results['cafe_aesthetic_good'] = cafe_aesthetic_score.get('good', 0)
            results['cafe_aesthetic_bad'] = cafe_aesthetic_score.get('bad', 0)

            cafe_style_result = self.cafe_style(image, top_k=1)
            results['cafe_style'] = cafe_style_result[0]["label"]

            cafe_waifu_result = self.cafe_waifu(image, top_k=1)
            results['cafe_waifu'] = cafe_waifu_result[0]["label"]
        except Exception as e:
            print(f"Error in Cafe Aesthetic: {e}")
            results['cafe_aesthetic_good'] = None
            results['cafe_aesthetic_bad'] = None
            results['cafe_style'] = None
            results['cafe_waifu'] = None

        try:
            img_array = np.array(image)
            anime_score = predict_anime_aesthetic(img_array, self.anime_aesthetic)
            results['anime_aesthetic'] = round(anime_score, 2)
        except Exception as e:
            print(f"Error in Anime Aesthetic: {e}")
            results['anime_aesthetic'] = None

        return results

    def process_images(self, image_files):
        """Process multiple image files and return results.

        Args:
            image_files: iterable of file paths; ``None`` is treated as empty.

        Returns:
            One dict per successfully processed file: ``file_name``,
            ``thumbnail`` (path of a JPEG thumbnail of at most 200x200 in
            ``self.temp_dir``) plus the keys from ``evaluate_image``. Files
            that fail are printed and skipped.
        """
        results = []

        for i, file_path in enumerate(image_files or []):
            try:
                # `with` closes the source file handle even if convert() fails.
                with Image.open(file_path) as src:
                    img = src.convert("RGB")

                eval_results = self.evaluate_image(img)

                # A unique name per thumbnail, so a later call never overwrites
                # a file that an earlier result still points to.
                fd, thumbnail_path = tempfile.mkstemp(
                    prefix=f"thumbnail_{i}_", suffix=".jpg", dir=self.temp_dir
                )
                os.close(fd)
                img.thumbnail((200, 200))
                img.save(thumbnail_path)

                results.append({
                    'file_name': os.path.basename(file_path),
                    'thumbnail': thumbnail_path,
                    **eval_results,
                })
            except Exception as e:
                print(f"Error processing {file_path}: {e}")

        return results

    def cleanup(self):
        """Clean up temporary files (removes ``self.temp_dir`` if it exists)."""
        if os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)
|
|
|
|
|
|
|
def create_interface(): |
|
evaluator = ImageEvaluationTool() |
|
|
|
with gr.Blocks(theme=gr.themes.Soft()) as demo: |
|
gr.Markdown(""" |
|
# Comprehensive Image Evaluation Tool |
|
|
|
Upload images to evaluate them using multiple aesthetic and quality prediction models: |
|
|
|
- **Aesthetic Shadow**: Evaluates high-quality vs low-quality images |
|
- **Waifu Scorer**: Rates anime/illustration quality from 0-10 |
|
- **Aesthetic Predictor V2.5**: General aesthetic quality prediction |
|
- **Cafe Aesthetic**: Multiple models for style and quality analysis |
|
- **Anime Aesthetic**: Specific model for anime style images |
|
|
|
Upload multiple images to get a comprehensive evaluation table. |
|
""") |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
input_images = gr.Files(label="Upload Images") |
|
process_btn = gr.Button("Evaluate Images", variant="primary") |
|
clear_btn = gr.Button("Clear Results") |
|
|
|
with gr.Column(scale=2): |
|
output_gallery = gr.Gallery(label="Evaluated Images", columns=5, object_fit="contain") |
|
output_table = gr.Dataframe(label="Evaluation Results") |
|
|
|
def process_images(files): |
|
|
|
file_paths = [f.name for f in files] |
|
|
|
|
|
results = evaluator.process_images(file_paths) |
|
|
|
|
|
gallery_images = [{"image": r["thumbnail"], "label": f"{r['file_name']}"} for r in results] |
|
|
|
|
|
table_data = [] |
|
for r in results: |
|
table_data.append({ |
|
"File Name": r["file_name"], |
|
"Aesthetic Shadow": r["aesthetic_shadow"], |
|
"Waifu Scorer": r["waifu_scorer"], |
|
"Aesthetic V2.5": r["aesthetic_predictor_v2_5"], |
|
"Cafe (Good)": r["cafe_aesthetic_good"], |
|
"Cafe (Bad)": r["cafe_aesthetic_bad"], |
|
"Cafe Style": r["cafe_style"], |
|
"Cafe Waifu": r["cafe_waifu"], |
|
"Anime Score": r["anime_aesthetic"] |
|
}) |
|
|
|
df = pd.DataFrame(table_data) |
|
return gallery_images, df |
|
|
|
def clear_results(): |
|
return None, None |
|
|
|
process_btn.click(process_images, inputs=[input_images], outputs=[output_gallery, output_table]) |
|
clear_btn.click(clear_results, inputs=[], outputs=[output_gallery, output_table]) |
|
|
|
|
|
demo.load(lambda: None, inputs=None, outputs=None) |
|
|
|
gr.Markdown(""" |
|
### Notes |
|
- The evaluation may take some time depending on the number and size of images |
|
- For best results, use high-quality images |
|
- Scores are on different scales depending on the model |
|
""") |
|
|
|
return demo |
|
|
|
|
|
if __name__ == "__main__":
    # Build the UI, enable request queuing, then start the Gradio server.
    interface = create_interface()
    interface.queue().launch()