import os
import shutil
import tempfile
import asyncio
from io import BytesIO, StringIO
import csv
from pathlib import Path
import logging

import cv2
import numpy as np
import pandas as pd
import torch
import onnxruntime as rt
from PIL import Image
import gradio as gr
from transformers import pipeline, AutoProcessor, AutoModelForImageClassification
from huggingface_hub import hf_hub_download

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
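# Aesthetic Predictor V2.5 loader.
# Tries to pull the SigLIP-based classifier and its preprocessor from the Hub;
# if that fails (offline, missing repo, etc.) lightweight mock objects are used
# instead so the rest of the app can still run.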
def convert_v2_5_from_siglip(repo_id="unum-cloud/siglip-base-patch16-224-aesthetic-v2.5", low_cpu_mem_usage=True, trust_remote_code=True):
    logger.info(f"Loading model and preprocessor from Hugging Face Hub: {repo_id}")
    try:
        processor = AutoProcessor.from_pretrained(repo_id, low_cpu_mem_usage=low_cpu_mem_usage, trust_remote_code=trust_remote_code)
        model = AutoModelForImageClassification.from_pretrained(repo_id, low_cpu_mem_usage=low_cpu_mem_usage, trust_remote_code=trust_remote_code)
        logger.info("Successfully loaded model and preprocessor from Hugging Face Hub.")
    except Exception as e:
        logger.warning(f"Failed to load from {repo_id} due to: {e}. Using fallback mock objects.")

        class MockProcessorOutput:
            # Mimics the BatchFeature the real processor returns, which exposes
            # pixel_values as an attribute (the downstream wrapper relies on this).
            def __init__(self, pixel_values):
                self.pixel_values = pixel_values

            def __getitem__(self, key):
                return getattr(self, key)

        class MockProcessor:
            def __call__(self, images, return_tensors="pt"):
                num_images = len(images) if isinstance(images, list) else 1
                return MockProcessorOutput(pixel_values=torch.randn(num_images, 3, 224, 224))

        class MockModel:
            def __init__(self):
                self._parameters = {"dummy": torch.nn.Parameter(torch.empty(0))}

            def __call__(self, pixel_values):
                bs = pixel_values.shape[0]

                class Output:
                    def __init__(self, logits_val):
                        self.logits = logits_val

                return Output(logits_val=torch.rand(bs, 1) * 10)

            def to(self, *args, **kwargs): return self
            def cuda(self, *args, **kwargs): return self
            def bfloat16(self, *args, **kwargs): return self

        processor = MockProcessor()
        model = MockModel()
        logger.info("Using fallback mock model and preprocessor for Aesthetic Predictor V2.5.")
    return model, processor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
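# MLP regression head used by WaifuScorer: maps a 768-dim CLIP ViT-L/14 image
# embedding to a single scalar aesthetic score.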
class MLP(torch.nn.Module):
    def __init__(self, input_size: int, batch_norm: bool = True):
        super().__init__()
        self.input_size = input_size
        layers = [
            torch.nn.Linear(self.input_size, 2048), torch.nn.ReLU(),
            torch.nn.BatchNorm1d(2048) if batch_norm else torch.nn.Identity(), torch.nn.Dropout(0.3),
            torch.nn.Linear(2048, 512), torch.nn.ReLU(),
            torch.nn.BatchNorm1d(512) if batch_norm else torch.nn.Identity(), torch.nn.Dropout(0.3),
            torch.nn.Linear(512, 256), torch.nn.ReLU(),
            torch.nn.BatchNorm1d(256) if batch_norm else torch.nn.Identity(), torch.nn.Dropout(0.2),
            torch.nn.Linear(256, 128), torch.nn.ReLU(),
            torch.nn.BatchNorm1d(128) if batch_norm else torch.nn.Identity(), torch.nn.Dropout(0.1),
            torch.nn.Linear(128, 32), torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
        ]
        self.layers = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
|
|
|
|
|
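# WaifuScorer: CLIP ViT-L/14 image embeddings fed through the MLP head above,
# producing scores in the 0-10 range. The checkpoint is fetched from the Hub
# when no local path is supplied; on any failure the scorer simply marks
# itself unavailable instead of raising.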
class WaifuScorer:
    def __init__(self, model_path: str = None, device: str = 'cuda', cache_dir: str = None, verbose: bool = False):
        self.verbose = verbose
        self.device = device
        self.dtype = torch.float32
        self.available = False
        self.clip_model = None
        self.preprocess = None
        self.mlp = None

        try:
            import clip
            if model_path is None:
                model_path = "Eugeoter/waifu-scorer-v3/model.pth"
                if self.verbose:
                    logger.info(f"WaifuScorer model path not provided. Using default: {model_path}")

            if not Path(model_path).is_file():
                try:
                    parts = model_path.split("/")
                    if len(parts) >= 3:
                        repo_id_str = "/".join(parts[:-1])
                        filename = parts[-1]
                        model_path_resolved = hf_hub_download(repo_id=repo_id_str, filename=filename, cache_dir=cache_dir)
                    else:
                        model_path_resolved = hf_hub_download(repo_id=model_path, filename="model.pth", cache_dir=cache_dir)
                except Exception as e:
                    logger.error(f"Failed to download WaifuScorer model from HF Hub ({model_path}): {e}")
                    logger.info("Attempting to download specific WaifuScorer model Eugeoter/waifu-scorer-v3/model.pth")
                    model_path_resolved = hf_hub_download("Eugeoter/waifu-scorer-v3", "model.pth", cache_dir=cache_dir)
                model_path = model_path_resolved

            if self.verbose:
                logger.info(f"Loading WaifuScorer model from: {model_path}")

            self.mlp = MLP(input_size=768)
            if str(model_path).endswith(".safetensors"):
                from safetensors.torch import load_file
                state_dict = load_file(model_path, device=device)
            else:
                state_dict = torch.load(model_path, map_location=device)

            # Strip a possible DataParallel "module." prefix from checkpoint keys.
            if any(key.startswith("module.") for key in state_dict.keys()):
                state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

            self.mlp.load_state_dict(state_dict)
            self.mlp.to(device=self.device, dtype=self.dtype)
            self.mlp.eval()

            self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device)
            self.available = True
            logger.info("WaifuScorer initialized successfully.")
        except ImportError:
            logger.error("OpenAI CLIP library not found. WaifuScorer will be unavailable. Please install with 'pip install openai-clip'")
        except Exception as e:
            logger.error(f"Unable to initialize WaifuScorer: {e}")

    @torch.no_grad()
    def __call__(self, images: list[Image.Image]) -> list[float | None]:
        if not self.available:
            return [None] * len(images)
        if not images:
            return []

        original_n = len(images)
        # A single image is duplicated so the model always sees a batch of at
        # least two; the extra score is dropped before returning.
        processed_images = images if len(images) > 1 else images * 2

        try:
            image_tensors = [self.preprocess(img).unsqueeze(0) for img in processed_images]
            image_batch = torch.cat(image_tensors).to(self.device)
            image_features = self.clip_model.encode_image(image_batch)

            # L2-normalise the embeddings, guarding against zero norms.
            norm = image_features.norm(p=2, dim=-1, keepdim=True)
            norm = torch.where(norm == 0, torch.tensor(1.0, device=norm.device, dtype=norm.dtype), norm)
            im_emb = (image_features / norm).to(device=self.device, dtype=self.dtype)

            predictions = self.mlp(im_emb)
            scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
            return scores[:original_n]
        except Exception as e:
            logger.error(f"Error during WaifuScorer prediction: {e}")
            return [None] * original_n
|
|
|
|
|
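# Thin wrapper around the SigLIP-based Aesthetic Predictor V2.5 (or its mock
# fallback) handling dtype/device placement and clamping scores to [0, 10].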
class AestheticPredictorV2_5_Wrapper:
    def __init__(self, device: str):
        self.device = device
        self.model, self.preprocessor = convert_v2_5_from_siglip(
            low_cpu_mem_usage=True, trust_remote_code=True
        )
        if self.device == 'cuda' and torch.cuda.is_available():
            self.model = self.model.to(torch.bfloat16).cuda()
        logger.info("Aesthetic Predictor V2.5 Wrapper initialized.")

    @torch.no_grad()
    def inference(self, images: list[Image.Image]) -> list[float | None]:
        if not images:
            return []
        try:
            images_rgb = [img.convert("RGB") for img in images]
            pixel_values = self.preprocessor(images=images_rgb, return_tensors="pt").pixel_values
            if self.device == 'cuda' and torch.cuda.is_available():
                pixel_values = pixel_values.to(torch.bfloat16).cuda()

            scores_tensor = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
            # A single image squeezes to a 0-d array; normalise to a list either way.
            if scores_tensor.ndim == 0:
                scores = [scores_tensor.item()]
            else:
                scores = scores_tensor.tolist()
            return [round(max(0.0, min(s, 10.0)), 4) for s in scores]
        except Exception as e:
            logger.error(f"Error during Aesthetic Predictor V2.5 inference: {e}")
            return [None] * len(images)
|
|
|
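# The Anime Aesthetic model (skytnt/anime-aesthetic) runs through ONNX Runtime,
# using the CUDA execution provider when available and falling back to CPU.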
def load_anime_aesthetic_onnx_model(cache_dir: str = None) -> rt.InferenceSession | None:
    try:
        model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx", cache_dir=cache_dir)
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if torch.cuda.is_available() else ['CPUExecutionProvider']
        session = rt.InferenceSession(model_path, providers=providers)
        logger.info(f"Anime Aesthetic ONNX model loaded with providers: {session.get_providers()}")
        return session
    except Exception as e:
        logger.error(f"Failed to load Anime Aesthetic ONNX model: {e}")
        return None


def preprocess_anime_aesthetic_batch(images_pil: list[Image.Image], target_size: int = 768) -> np.ndarray | None:
    """Resize each image to fit target_size, pad to a square canvas, and stack into an NCHW float32 batch."""
    if not images_pil:
        return None
    batch_canvases = []
    try:
        for img_pil in images_pil:
            img_np = np.array(img_pil.convert("RGB")).astype(np.float32) / 255.0
            h, w = img_np.shape[:2]
            # Preserve aspect ratio: the longer side becomes target_size.
            if h > w:
                new_h, new_w = target_size, int(target_size * w / h)
            else:
                new_h, new_w = int(target_size * h / w), target_size

            resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_AREA)
            canvas = np.zeros((target_size, target_size, 3), dtype=np.float32)
            pad_h = (target_size - new_h) // 2
            pad_w = (target_size - new_w) // 2
            canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized
            batch_canvases.append(canvas)

        input_tensor_batch = np.array(batch_canvases, dtype=np.float32)
        input_tensor_batch = np.transpose(input_tensor_batch, (0, 3, 1, 2))  # NHWC -> NCHW
        return input_tensor_batch
    except Exception as e:
        logger.error(f"Error during Anime Aesthetic preprocessing: {e}")
        return None
|
|
|
|
|
|
|
|
|
|
|
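# ModelManager wires everything together: it loads whichever scorers initialise
# successfully, exposes an asyncio queue plus worker for evaluation requests,
# and tracks temporary files that need cleanup on shutdown.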
class ModelManager: |
|
def __init__(self, cache_dir: str = None): |
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
logger.info(f"Using device: {self.device}") |
|
self.cache_dir = cache_dir |
|
self.models = {} |
|
self.model_configs = {} |
|
self._load_all_models() |
|
|
|
self.processing_queue: asyncio.Queue = asyncio.Queue() |
|
self.worker_task = None |
|
self._temp_files_to_clean = [] |
|
|
|
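    # Each successfully loaded model is registered in self.models and described
    # in self.model_configs (display name + per-batch processing coroutine).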
def _load_all_models(self): |
|
logger.info("Loading Aesthetic Shadow model...") |
|
try: |
|
self.models["aesthetic_shadow"] = pipeline("image-classification", model="NeoChen1024/aesthetic-shadow-v2-backup", device=0 if self.device == 'cuda' else -1) |
|
self.model_configs["aesthetic_shadow"] = {"name": "Aesthetic Shadow", "process_func": self._process_aesthetic_shadow} |
|
logger.info("Aesthetic Shadow model loaded.") |
|
except Exception as e: |
|
logger.error(f"Failed to load Aesthetic Shadow model: {e}") |
|
|
|
logger.info("Loading Waifu Scorer model...") |
|
try: |
|
ws = WaifuScorer(device=self.device, cache_dir=self.cache_dir, verbose=True) |
|
if ws.available: |
|
self.models["waifu_scorer"] = ws |
|
self.model_configs["waifu_scorer"] = {"name": "Waifu Scorer", "process_func": self._process_waifu_scorer} |
|
logger.info("Waifu Scorer model loaded.") |
|
else: |
|
logger.warning("Waifu Scorer model is not available.") |
|
except Exception as e: |
|
logger.error(f"Failed to load Waifu Scorer model: {e}") |
|
|
|
logger.info("Loading Aesthetic Predictor V2.5...") |
|
try: |
|
ap_v25 = AestheticPredictorV2_5_Wrapper(device=self.device) |
|
self.models["aesthetic_predictor_v2_5"] = ap_v25 |
|
self.model_configs["aesthetic_predictor_v2_5"] = {"name": "Aesthetic V2.5", "process_func": self._process_aesthetic_predictor_v2_5} |
|
logger.info("Aesthetic Predictor V2.5 loaded.") |
|
except Exception as e: |
|
logger.error(f"Failed to load Aesthetic Predictor V2.5: {e}") |
|
|
|
logger.info("Loading Anime Aesthetic model...") |
|
try: |
|
aa_model = load_anime_aesthetic_onnx_model(cache_dir=self.cache_dir) |
|
if aa_model: |
|
self.models["anime_aesthetic"] = aa_model |
|
self.model_configs["anime_aesthetic"] = {"name": "Anime Score", "process_func": self._process_anime_aesthetic} |
|
logger.info("Anime Aesthetic model loaded.") |
|
else: |
|
logger.warning("Anime Aesthetic ONNX model failed to load and will be unavailable.") |
|
except Exception as e: |
|
logger.error(f"Failed to load Anime Aesthetic model: {e}") |
|
|
|
logger.info(f"Available models for processing: {list(self.model_configs.keys())}") |
|
|
|
|
|
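    # Gradio handlers never call the models directly; they enqueue a request and
    # await a future that resolves to an async generator of progress events.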
async def start_worker_if_not_running(self): |
|
if self.worker_task is None or self.worker_task.done(): |
|
self.worker_task = asyncio.create_task(self._worker()) |
|
logger.info("Async worker started.") |
|
|
|
async def _worker(self): |
|
while True: |
|
request = await self.processing_queue.get() |
|
if request is None: |
|
self.processing_queue.task_done() |
|
logger.info("Async worker received shutdown signal.") |
|
break |
|
|
|
future = request.get('future') |
|
try: |
|
if request['type'] == 'run_evaluation_generator': |
|
|
|
|
|
gen = self.run_evaluation_generator(**request['params']) |
|
future.set_result(gen) |
|
else: |
|
logger.warning(f"Unknown request type in worker: {request.get('type')}") |
|
if future: future.set_exception(ValueError("Unknown request type")) |
|
except Exception as e: |
|
logger.error(f"Error in worker processing request: {e}", exc_info=True) |
|
if future: future.set_exception(e) |
|
finally: |
|
self.processing_queue.task_done() |
|
|
|
async def submit_evaluation_request(self, file_paths, auto_batch, manual_batch_size, selected_model_keys): |
|
await self.start_worker_if_not_running() |
|
future = asyncio.Future() |
|
request_item = { |
|
'type': 'run_evaluation_generator', |
|
'params': { |
|
'file_paths': file_paths, |
|
'auto_batch': auto_batch, |
|
'manual_batch_size': manual_batch_size, |
|
'selected_model_keys': selected_model_keys, |
|
}, |
|
'future': future |
|
} |
|
await self.processing_queue.put(request_item) |
|
return await future |
|
|
|
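    # Doubles a trial batch size until any selected model fails (typically OOM),
    # then keeps the last size that worked, capped at 64.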
def auto_tune_batch_size(self, images: list[Image.Image], selected_model_keys: list[str]) -> int: |
|
if not images or not selected_model_keys: |
|
return 1 |
|
|
|
batch_size = 1 |
|
max_possible_batch = len(images) |
|
test_image_pil = [images[0].copy()] |
|
|
|
logger.info(f"Auto-tuning batch size with selected models: {selected_model_keys}") |
|
|
|
optimal_batch_size = 1 |
|
while batch_size <= max_possible_batch: |
|
current_test_batch = test_image_pil * batch_size |
|
try: |
|
logger.debug(f"Testing batch size: {batch_size}") |
|
if "aesthetic_shadow" in selected_model_keys and "aesthetic_shadow" in self.models: |
|
_ = self.models["aesthetic_shadow"](current_test_batch, batch_size=batch_size) |
|
if "waifu_scorer" in selected_model_keys and "waifu_scorer" in self.models: |
|
_ = self.models["waifu_scorer"](current_test_batch) |
|
if "aesthetic_predictor_v2_5" in selected_model_keys and "aesthetic_predictor_v2_5" in self.models: |
|
_ = self.models["aesthetic_predictor_v2_5"].inference(current_test_batch) |
|
if "anime_aesthetic" in selected_model_keys and "anime_aesthetic" in self.models: |
|
processed_input = preprocess_anime_aesthetic_batch(current_test_batch) |
|
if processed_input is None: raise ValueError("Anime aesthetic preprocessing failed for test batch") |
|
_ = self.models["anime_aesthetic"].run(None, {"img": processed_input}) |
|
|
|
optimal_batch_size = batch_size |
|
if batch_size * 2 > max_possible_batch : |
|
if max_possible_batch > batch_size: |
|
|
|
pass |
|
break |
|
batch_size *= 2 |
|
|
|
except Exception as e: |
|
logger.warning(f"Auto-tune failed at batch size {batch_size} for at least one model: {e}") |
|
break |
|
|
|
|
|
final_optimal_batch = min(optimal_batch_size, max_possible_batch, 64) |
|
logger.info(f"Optimal batch size determined: {final_optimal_batch}") |
|
return max(1, final_optimal_batch) |
|
|
|
|
|
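    # Async generator driving one evaluation run. It yields typed dict events
    # ("log_update", "progress", "batch_size_update", "partial_results_df_rows",
    # "final_results_state") that the UI handler turns into component updates.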
async def run_evaluation_generator(self, file_paths: list[str], auto_batch: bool, |
|
manual_batch_size: int, selected_model_keys: list[str]): |
|
|
|
log_messages = [] |
|
def _log(msg): |
|
log_messages.append(msg) |
|
logger.info(msg) |
|
|
|
_log("Starting image evaluation...") |
|
yield {"type": "log_update", "messages": log_messages[-20:]} |
|
yield {"type": "progress", "value": 0.0, "desc": "Initiating..."} |
|
|
|
images_pil = [] |
|
file_names = [] |
|
for f_path_str in file_paths: |
|
try: |
|
p = Path(f_path_str) |
|
img = Image.open(p).convert("RGB") |
|
images_pil.append(img) |
|
file_names.append(p.name) |
|
_log(f"Loaded image: {p.name}") |
|
except Exception as e: |
|
_log(f"Error opening {f_path_str}: {e}") |
|
|
|
yield {"type": "log_update", "messages": log_messages[-20:]} |
|
|
|
if not images_pil: |
|
_log("No valid images loaded. Aborting.") |
|
yield {"type": "log_update", "messages": log_messages[-20:]} |
|
yield {"type": "progress", "value": 1.0, "desc": "No images loaded"} |
|
yield {"type": "final_results_state", "data": []} |
|
return |
|
|
|
actual_batch_size = 1 |
|
if auto_batch: |
|
_log("Auto-tuning batch size...") |
|
yield {"type": "log_update", "messages": log_messages[-20:]} |
|
yield {"type": "progress", "value": 0.05, "desc": "Auto-tuning batch size..."} |
|
actual_batch_size = self.auto_tune_batch_size(images_pil, selected_model_keys) |
|
_log(f"Auto-detected batch size: {actual_batch_size}") |
|
else: |
|
actual_batch_size = int(manual_batch_size) if manual_batch_size > 0 else 1 |
|
_log(f"Using manual batch size: {actual_batch_size}") |
|
|
|
yield {"type": "batch_size_update", "value": actual_batch_size} |
|
yield {"type": "log_update", "messages": log_messages[-20:]} |
|
|
|
all_results_for_state = [] |
|
dataframe_rows_so_far = [] |
|
|
|
total_images = len(images_pil) |
|
processed_count = 0 |
|
|
|
for i in range(0, total_images, actual_batch_size): |
|
batch_images_pil = images_pil[i:i+actual_batch_size] |
|
batch_file_names = file_names[i:i+actual_batch_size] |
|
num_in_batch = len(batch_images_pil) |
|
_log(f"Processing batch {i//actual_batch_size + 1}/{ (total_images + actual_batch_size -1) // actual_batch_size }: images {i+1} to {i+num_in_batch}") |
|
yield {"type": "log_update", "messages": log_messages[-20:]} |
|
|
|
batch_model_scores = {key: [None] * num_in_batch for key in self.model_configs.keys()} |
|
|
|
for model_key in selected_model_keys: |
|
if model_key in self.models and model_key in self.model_configs: |
|
_log(f" Running {self.model_configs[model_key]['name']} for batch...") |
|
yield {"type": "log_update", "messages": log_messages[-20:]} |
|
try: |
|
scores = await self.model_configs[model_key]['process_func'](batch_images_pil) |
|
batch_model_scores[model_key] = scores |
|
_log(f" {self.model_configs[model_key]['name']} scores: {scores}") |
|
except Exception as e: |
|
_log(f" Error processing batch with {self.model_configs[model_key]['name']}: {e}") |
|
batch_model_scores[model_key] = [None] * num_in_batch |
|
yield {"type": "log_update", "messages": log_messages[-20:]} |
|
|
|
|
|
current_batch_df_rows = [] |
|
for j in range(num_in_batch): |
|
result_item_state = {'file_name': batch_file_names[j]} |
|
|
|
|
|
thumbnail = batch_images_pil[j].copy() |
|
thumbnail.thumbnail((150, 150)) |
|
result_item_df_row = [thumbnail, batch_file_names[j]] |
|
|
|
|
|
current_image_scores = [] |
|
for model_key in self.model_configs.keys(): |
|
score = batch_model_scores[model_key][j] |
|
result_item_state[model_key] = score |
|
if model_key in selected_model_keys: |
|
result_item_df_row.append(f"{score:.4f}" if isinstance(score, (float, int)) else "N/A") |
|
if isinstance(score, (float, int)) and model_key in selected_model_keys: |
|
current_image_scores.append(score) |
|
|
|
final_score = None |
|
if current_image_scores: |
|
final_score_val = float(np.mean([s for s in current_image_scores if s is not None])) |
|
final_score = float(np.clip(final_score_val, 0.0, 10.0)) |
|
|
|
result_item_state['final_score'] = final_score |
|
result_item_df_row.append(f"{final_score:.4f}" if final_score is not None else "N/A") |
|
|
|
all_results_for_state.append(result_item_state) |
|
current_batch_df_rows.append(result_item_df_row) |
|
|
|
dataframe_rows_so_far.extend(current_batch_df_rows) |
|
|
|
processed_count += num_in_batch |
|
progress_value = processed_count / total_images |
|
yield {"type": "partial_results_df_rows", "data": dataframe_rows_so_far, "selected_model_keys": selected_model_keys} |
|
yield {"type": "progress", "value": progress_value, "desc": f"Processed {processed_count}/{total_images}"} |
|
|
|
_log("All images processed.") |
|
yield {"type": "log_update", "messages": log_messages[-20:]} |
|
yield {"type": "progress", "value": 1.0, "desc": "Completed!"} |
|
yield {"type": "final_results_state", "data": all_results_for_state} |
|
|
|
|
|
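    # Per-model batch processors: each returns one score per input image,
    # normalised to the 0-10 range, or None where scoring failed.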
async def _process_aesthetic_shadow(self, batch_images: list[Image.Image]) -> list[float | None]: |
|
model = self.models.get("aesthetic_shadow") |
|
if not model: return [None] * len(batch_images) |
|
results = model(batch_images, batch_size=len(batch_images)) |
|
scores = [] |
|
for res_group in results: |
|
|
|
current_res_list = res_group if isinstance(res_group, list) else [res_group] |
|
try: |
|
hq_score_item = next(p for p in current_res_list if p['label'] == 'hq') |
|
score = float(np.clip(hq_score_item['score'] * 10.0, 0.0, 10.0)) |
|
except (StopIteration, KeyError, TypeError): |
|
score = None |
|
scores.append(score) |
|
return scores |
|
|
|
async def _process_waifu_scorer(self, batch_images: list[Image.Image]) -> list[float | None]: |
|
model = self.models.get("waifu_scorer") |
|
if not model: return [None] * len(batch_images) |
|
raw_scores = model(batch_images) |
|
return [float(np.clip(s, 0.0, 10.0)) if s is not None else None for s in raw_scores] |
|
|
|
async def _process_aesthetic_predictor_v2_5(self, batch_images: list[Image.Image]) -> list[float | None]: |
|
model = self.models.get("aesthetic_predictor_v2_5") |
|
if not model: return [None] * len(batch_images) |
|
|
|
return model.inference(batch_images) |
|
|
|
async def _process_anime_aesthetic(self, batch_images: list[Image.Image]) -> list[float | None]: |
|
model = self.models.get("anime_aesthetic") |
|
if not model: return [None] * len(batch_images) |
|
|
|
input_data = preprocess_anime_aesthetic_batch(batch_images) |
|
if input_data is None: |
|
return [None] * len(batch_images) |
|
|
|
try: |
|
preds = model.run(None, {"img": input_data})[0] |
|
scores = [float(np.clip(p.item() * 10.0, 0.0, 10.0)) for p in preds] |
|
return scores |
|
except Exception as e: |
|
logger.error(f"Error during Anime Aesthetic ONNX prediction: {e}") |
|
return [None] * len(batch_images) |
|
|
|
def add_temp_file_for_cleanup(self, file_path: str): |
|
self._temp_files_to_clean.append(file_path) |
|
|
|
async def shutdown_worker(self): |
|
if self.worker_task and not self.worker_task.done(): |
|
logger.info("Attempting to shutdown worker...") |
|
await self.processing_queue.put(None) |
|
try: |
|
await asyncio.wait_for(self.worker_task, timeout=5.0) |
|
logger.info("Worker task finished.") |
|
except asyncio.TimeoutError: |
|
logger.warning("Worker task did not finish in time. Cancelling...") |
|
self.worker_task.cancel() |
|
except Exception as e: |
|
logger.error(f"Exception during worker shutdown: {e}") |
|
await self.processing_queue.join() |
|
logger.info("Processing queue joined.") |
|
self.worker_task = None |
|
|
|
|
|
    def cleanup(self):
        logger.info("Running cleanup...")

        if self.worker_task:
            # Shut the worker down on whichever event loop is available.
            try:
                loop = asyncio.get_running_loop()
                loop.create_task(self.shutdown_worker())
            except RuntimeError:
                try:
                    asyncio.run(self.shutdown_worker())
                except RuntimeError as e:
                    logger.error(f"RuntimeError during cleanup's shutdown_worker: {e}. May need manual loop management.")

        for f_path in self._temp_files_to_clean:
            try:
                os.remove(f_path)
                logger.info(f"Removed temp file: {f_path}")
            except OSError as e:
                logger.error(f"Error removing temp file {f_path}: {e}")
        self._temp_files_to_clean.clear()
        logger.info("Cleanup finished.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
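# Single shared ModelManager instance; all models are loaded once at startup.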
model_manager = ModelManager(cache_dir=".model_cache") |
|
|
|
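# Gradio UI: model selection, batch-size controls, a sortable results table,
# CSV export, and live log/progress updates streamed from the async generator.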
def create_interface(): |
|
|
|
|
|
AVAILABLE_MODEL_KEYS = [k for k in model_manager.model_configs.keys() if k in model_manager.models] |
|
AVAILABLE_MODEL_NAMES_MAP = {k: model_manager.model_configs[k]['name'] for k in AVAILABLE_MODEL_KEYS} |
|
|
|
|
|
MODEL_CHOICES_FOR_CHECKBOX = [(AVAILABLE_MODEL_NAMES_MAP[k], k) for k in AVAILABLE_MODEL_KEYS] |
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo: |
|
gr.Markdown(""" |
|
# Comprehensive Image Evaluation Tool (Refactored) |
|
Upload images to evaluate them using multiple aesthetic and quality prediction models. |
|
Results are displayed in a sortable table with image previews. |
|
""") |
|
|
|
|
|
|
|
|
|
results_state = gr.State([]) |
|
|
|
selected_models_state = gr.State(AVAILABLE_MODEL_KEYS) |
|
|
|
log_messages_state = gr.State([]) |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
input_images = gr.Files(label="Upload Images", file_count="multiple", type="filepath") |
|
|
|
if not MODEL_CHOICES_FOR_CHECKBOX: |
|
gr.Markdown("## No models loaded successfully. Please check logs.") |
|
model_checkboxes = None |
|
else: |
|
model_checkboxes = gr.CheckboxGroup( |
|
choices=MODEL_CHOICES_FOR_CHECKBOX, |
|
label="Select Models", |
|
value=AVAILABLE_MODEL_KEYS, |
|
info="Choose models for evaluation. Final score is an average of selected model scores." |
|
) |
|
|
|
auto_batch_checkbox = gr.Checkbox(label="Automatic Batch Size Detection", value=True) |
|
batch_size_input = gr.Number(label="Manual Batch Size", value=8, minimum=1, precision=0, interactive=False) |
|
|
|
process_btn = gr.Button("Evaluate Images", variant="primary", interactive=bool(MODEL_CHOICES_FOR_CHECKBOX)) |
|
clear_btn = gr.Button("Clear Results") |
|
download_csv_btn = gr.Button("Download Results as CSV", variant="secondary") |
|
|
|
with gr.Column(scale=3): |
|
progress_tracker = gr.Progress(label="Processing Progress") |
|
log_output = gr.Textbox(label="Logs", lines=10, max_lines=20, interactive=False, autoscroll=True) |
|
|
|
|
|
initial_df_headers = ['Image', 'File Name'] + [AVAILABLE_MODEL_NAMES_MAP[k] for k in AVAILABLE_MODEL_KEYS] + ['Final Score'] |
|
results_dataframe = gr.DataFrame( |
|
headers=initial_df_headers, |
|
datatype=['pil'] + ['str'] * (len(initial_df_headers) -1) , |
|
label="Evaluation Results", |
|
interactive=True, |
|
row_count=(10, "dynamic"), |
|
col_count=(len(initial_df_headers), "fixed"), |
|
wrap=True, |
|
) |
|
|
|
download_file_provider = gr.File(label="Download Link", visible=False) |
|
|
|
|
|
def update_batch_size_interactive(auto_detect_enabled: bool): |
|
return gr.Number.update(interactive=not auto_detect_enabled) |
|
|
|
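        # Streams UI updates while the evaluation generator runs: logs, progress,
        # the detected batch size, and incremental DataFrame rows.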
        async def handle_process_images_ui(
            files_list: list | None,  # file paths (or temp-file objects) from gr.Files
            auto_batch_flag: bool,
            manual_batch_val: int,
            selected_model_keys_from_ui: list[str],
            progress_instance: gr.Progress = gr.Progress()
        ):
|
if not files_list: |
|
yield { |
|
log_output: "No files uploaded. Please select images first.", |
|
progress_tracker: gr.Progress(0.0, "Idle. No files."), |
|
results_dataframe: gr.DataFrame.update(value=None), |
|
results_state: [], |
|
selected_models_state: selected_model_keys_from_ui, |
|
log_messages_state: ["No files uploaded. Please select images first."] |
|
} |
|
return |
|
|
|
|
|
yield { selected_models_state: selected_model_keys_from_ui, log_messages_state: [] } |
|
|
|
|
|
            actual_file_paths = [f if isinstance(f, str) else f.name for f in files_list]
|
|
|
current_log_list = [] |
|
|
|
|
|
|
|
|
|
|
|
|
|
evaluation_generator = await model_manager.submit_evaluation_request( |
|
actual_file_paths, auto_batch_flag, manual_batch_val, selected_model_keys_from_ui |
|
) |
|
|
|
dataframe_update_value = None |
|
final_results_for_app_state = [] |
|
|
|
async for event in evaluation_generator: |
|
outputs_to_yield = {} |
|
if event["type"] == "log_update": |
|
current_log_list = event["messages"] |
|
outputs_to_yield[log_output] = "\n".join(current_log_list) |
|
elif event["type"] == "progress": |
|
|
|
progress_instance(event["value"], desc=event.get("desc")) |
|
elif event["type"] == "batch_size_update": |
|
outputs_to_yield[batch_size_input] = gr.Number.update(value=event["value"]) |
|
elif event["type"] == "partial_results_df_rows": |
|
|
|
|
|
dynamic_headers = ['Image', 'File Name'] + \ |
|
[AVAILABLE_MODEL_NAMES_MAP[k] for k in event["selected_model_keys"] if k in AVAILABLE_MODEL_NAMES_MAP] + \ |
|
['Final Score'] |
|
dataframe_update_value = pd.DataFrame(event["data"], columns=dynamic_headers) if event["data"] else None |
|
outputs_to_yield[results_dataframe] = gr.DataFrame.update(value=dataframe_update_value, headers=dynamic_headers) |
|
|
|
elif event["type"] == "final_results_state": |
|
final_results_for_app_state = event["data"] |
|
|
|
if outputs_to_yield: |
|
yield outputs_to_yield |
|
|
|
|
|
yield { |
|
results_state: final_results_for_app_state, |
|
log_messages_state: current_log_list, |
|
|
|
} |
|
|
|
|
|
def handle_clear_results_ui(): |
|
|
|
return { |
|
input_images: None, |
|
log_output: "Results cleared.", |
|
results_dataframe: gr.DataFrame.update(value=None, headers=initial_df_headers), |
|
progress_tracker: gr.Progress(0.0, "Idle"), |
|
results_state: [], |
|
|
|
batch_size_input: gr.Number.update(value=8), |
|
log_messages_state: ["Results cleared."] |
|
} |
|
|
|
|
|
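        # Recomputes each image's final score as the mean of the currently selected
        # models' scores and rebuilds the results table with matching columns.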
def handle_model_selection_or_state_change_ui( |
|
current_selected_keys: list[str], |
|
current_full_results: list[dict] |
|
): |
|
if not current_full_results: |
|
dynamic_headers = ['Image', 'File Name'] + \ |
|
[AVAILABLE_MODEL_NAMES_MAP[k] for k in current_selected_keys if k in AVAILABLE_MODEL_NAMES_MAP] + \ |
|
['Final Score'] |
|
return { |
|
results_dataframe: gr.DataFrame.update(value=None, headers=dynamic_headers), |
|
selected_models_state: current_selected_keys, |
|
results_state: current_full_results |
|
} |
|
|
|
new_df_rows = [] |
|
updated_full_results = [] |
|
|
|
for res_item_dict in current_full_results: |
|
|
|
scores_to_avg = [] |
|
for mk in current_selected_keys: |
|
if mk in res_item_dict and isinstance(res_item_dict[mk], (float, int)): |
|
scores_to_avg.append(res_item_dict[mk]) |
|
|
|
new_final_score = None |
|
if scores_to_avg: |
|
new_final_score_val = float(np.mean(scores_to_avg)) |
|
new_final_score = float(np.clip(new_final_score_val, 0.0, 10.0)) |
|
|
|
|
|
res_item_dict['final_score'] = new_final_score |
|
updated_full_results.append(res_item_dict.copy()) |
|
|
df_row = [res_item_dict.get('thumbnail_pil_placeholder', "N/A"), res_item_dict['file_name']] |
|
for mk_cfg in AVAILABLE_MODEL_KEYS: |
|
if mk_cfg in current_selected_keys: |
|
score = res_item_dict.get(mk_cfg) |
|
df_row.append(f"{score:.4f}" if isinstance(score, (float, int)) else "N/A") |
|
|
|
df_row.append(f"{new_final_score:.4f}" if new_final_score is not None else "N/A") |
|
new_df_rows.append(df_row) |
|
|
|
dynamic_headers = ['Image', 'File Name'] + \ |
|
[AVAILABLE_MODEL_NAMES_MAP[k] for k in current_selected_keys if k in AVAILABLE_MODEL_NAMES_MAP] + \ |
|
['Final Score'] |
|
|
|
import pandas as pd |
|
df_value = pd.DataFrame(new_df_rows, columns=dynamic_headers) if new_df_rows else None |
|
|
|
return { |
|
results_dataframe: gr.DataFrame.update(value=df_value, headers=dynamic_headers), |
|
selected_models_state: current_selected_keys, |
|
results_state: updated_full_results |
|
} |
|
|
|
|
|
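        # Writes the current results (selected-model columns only) to a temporary
        # CSV file and exposes it through the hidden gr.File component.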
def handle_download_csv_ui(current_full_results: list[dict], current_selected_keys: list[str]): |
|
if not current_full_results: |
|
|
|
return gr.File.update(value=None, visible=False) |
|
|
|
|
|
csv_output = StringIO() |
|
|
|
fieldnames = ['File Name'] + \ |
|
[AVAILABLE_MODEL_NAMES_MAP[k] for k in current_selected_keys if k in AVAILABLE_MODEL_NAMES_MAP] + \ |
|
['Final Score'] |
|
|
|
writer = csv.DictWriter(csv_output, fieldnames=fieldnames, extrasaction='ignore') |
|
writer.writeheader() |
|
|
|
for res_item in current_full_results: |
|
row_to_write = {'File Name': res_item['file_name']} |
|
final_score_val = res_item.get('final_score') |
|
row_to_write['Final Score'] = f"{final_score_val:.4f}" if final_score_val is not None else "N/A" |
|
|
|
for key in current_selected_keys: |
|
if key in AVAILABLE_MODEL_NAMES_MAP: |
|
model_display_name = AVAILABLE_MODEL_NAMES_MAP[key] |
|
score_val = res_item.get(key) |
|
row_to_write[model_display_name] = f"{score_val:.4f}" if isinstance(score_val, (float, int)) else "N/A" |
|
writer.writerow(row_to_write) |
|
|
|
csv_content = csv_output.getvalue() |
|
csv_output.close() |
|
|
|
|
|
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv", encoding='utf-8') as tmp_file: |
|
tmp_file.write(csv_content) |
|
temp_file_path = tmp_file.name |
|
|
|
model_manager.add_temp_file_for_cleanup(temp_file_path) |
|
|
|
return gr.File.update(value=temp_file_path, visible=True, label="results.csv") |
|
|
|
|
|
|
|
auto_batch_checkbox.change( |
|
fn=update_batch_size_interactive, |
|
inputs=[auto_batch_checkbox], |
|
outputs=[batch_size_input] |
|
) |
|
|
|
|
|
if model_checkboxes: |
|
process_btn.click( |
|
fn=handle_process_images_ui, |
|
inputs=[input_images, auto_batch_checkbox, batch_size_input, model_checkboxes], |
|
outputs=[ |
|
log_output, progress_tracker, results_dataframe, batch_size_input, |
|
results_state, selected_models_state, log_messages_state |
|
] |
|
) |
|
|
|
model_checkboxes.change( |
|
fn=handle_model_selection_or_state_change_ui, |
|
inputs=[model_checkboxes, results_state], |
|
outputs=[results_dataframe, selected_models_state, results_state] |
|
) |
|
|
|
clear_btn.click( |
|
fn=handle_clear_results_ui, |
|
outputs=[ |
|
input_images, log_output, results_dataframe, progress_tracker, |
|
results_state, batch_size_input, log_messages_state |
|
] |
|
) |
|
|
|
download_csv_btn.click( |
|
fn=handle_download_csv_ui, |
|
inputs=[results_state, selected_models_state], |
|
outputs=[download_file_provider] |
|
) |
|
|
|
|
|
async def initial_load_setup(): |
|
await model_manager.start_worker_if_not_running() |
|
|
|
|
|
return {selected_models_state: AVAILABLE_MODEL_KEYS, log_messages_state: ["Application loaded. Ready."]} |
|
|
|
demo.load( |
|
fn=initial_load_setup, |
|
outputs=[selected_models_state, log_messages_state] |
|
) |
|
|
|
demo.unload(model_manager.cleanup) |
|
|
|
|
|
gr.Markdown(""" |
|
### Notes |
|
- **Model Selection**: Dynamically choose models for evaluation. The 'Final Score' and displayed columns update accordingly. |
|
- **Native Table**: Results are shown in a native Gradio DataFrame, allowing sorting by clicking column headers. |
|
- **Batching**: Automatic batch size detection is enabled by default. You can switch to manual batch sizing. |
|
- **CSV Export**: Download the current results (respecting selected models for columns) as a CSV file. |
|
- **Asynchronous Processing**: Image evaluation runs in the background, providing live updates for logs and progress. |
|
""") |
|
return demo |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # The app defines its own convert_v2_5_from_siglip (with a mock fallback) above,
    # so a separate aesthetic_predictor_v2_5.py module is optional. The stub below is
    # kept only as documentation of what that module would need to provide.
    if not Path("aesthetic_predictor_v2_5.py").exists():
        stub_content = """
# Placeholder for aesthetic_predictor_v2_5.py
# This file needs to contain the actual 'convert_v2_5_from_siglip' function.
# The main script uses a basic stub if this file is missing or fails to import.
def convert_v2_5_from_siglip(*args, **kwargs):
    raise NotImplementedError("This is a placeholder. Implement convert_v2_5_from_siglip here or ensure the main script's stub is used.")
"""
        # Path("aesthetic_predictor_v2_5.py").write_text(stub_content)  # uncomment to write the stub to disk
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Register cleanup before launching; launch() blocks until the app shuts down.
    import atexit
    atexit.register(model_manager.cleanup)

    app_interface = create_interface()
    app_interface.queue().launch(debug=True, share=False)