import gradio as gr
from PIL import Image, PngImagePlugin
import io
import os
import pandas as pd
import torch
from transformers import pipeline as transformers_pipeline, AutoImageProcessor, AutoModelForImageClassification
# from torchvision import transforms  # Less relevant for the ONNX pipeline
from torchmetrics.functional.multimodal import clip_score
from open_clip import create_model_from_pretrained, get_tokenizer
import re
import matplotlib.pyplot as plt
import json
from collections import defaultdict
import numpy as np
import logging  # For ONNX logging
import tempfile  # Used below to materialize CSV/JSON downloads as real files
# --- ONNX Related Imports and Setup ---
try:
import onnxruntime
except ImportError:
print("onnxruntime not found. Please ensure it's in requirements.txt")
onnxruntime = None
from huggingface_hub import hf_hub_download
# imgutils for rgb_encode (if installed)
try:
    from imgutils.data import rgb_encode  # Assuming this is the correct import
except ImportError:
    print("imgutils.data.rgb_encode not found. Preprocessing for deepghs might be limited.")
    def rgb_encode(image, order_='CHW'):  # Simple fallback if imgutils is unavailable
        img_arr = np.array(image)
        if order_ == 'CHW':
            img_arr = np.transpose(img_arr, (2, 0, 1))
        return img_arr.astype(np.float32) / 255.0  # Basic 0-1 normalization unless otherwise specified
# --- Model Configuration and Loading ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
ONNX_DEVICE = "CUDAExecutionProvider" if DEVICE == "cuda" and onnxruntime and "CUDAExecutionProvider" in onnxruntime.get_available_providers() else "CPUExecutionProvider"
print(f"Using ONNX device: {ONNX_DEVICE}")
# --- Helper for ONNX models (deepghs) ---
@torch.no_grad()
def _img_preprocess_for_onnx(image: Image.Image, size: tuple = (384, 384), normalize_mean=0.5, normalize_std=0.5):
    image = image.resize(size, Image.Resampling.BILINEAR)  # Uses the newer Resampling enum
    data = rgb_encode(image, order_='CHW')  # (C, H, W), float32; imgutils commonly returns values in [0, 1]
    # Normalization: ((data / 255.0) - mean) / std if data is in 0-255,
    # or (data - mean) / std if rgb_encode already returns [0, 1].
    # We assume rgb_encode returns float32 in the [0, 1] range.
    mean = np.array([normalize_mean] * 3, dtype=np.float32).reshape((3, 1, 1))
    std = np.array([normalize_std] * 3, dtype=np.float32).reshape((3, 1, 1))
    normalized_data = (data - mean) / std
    return normalized_data[None, ...].astype(np.float32)  # Add batch dimension
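# A minimal sanity check of the preprocessing contract assumed above
# (illustrative only; shapes follow the CHW + batch-dimension convention):
#   dummy = Image.new("RGB", (640, 480))
#   batch = _img_preprocess_for_onnx(dummy, size=(448, 448))
#   assert batch.shape == (1, 3, 448, 448) and batch.dtype == np.float32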
onnx_sessions_cache = {}  # Cache for ONNX sessions and metadata
def get_onnx_session_and_meta(repo_id, model_subfolder):
cache_key = f"{repo_id}/{model_subfolder}"
if cache_key in onnx_sessions_cache:
return onnx_sessions_cache[cache_key]
if not onnxruntime:
raise ImportError("ONNX Runtime is not available.")
try:
model_path = hf_hub_download(repo_id, filename=f"{model_subfolder}/model.onnx")
meta_path = hf_hub_download(repo_id, filename=f"{model_subfolder}/meta.json")
options = onnxruntime.SessionOptions()
options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
if ONNX_DEVICE == "CPUExecutionProvider":
options.intra_op_num_threads = os.cpu_count()
session = onnxruntime.InferenceSession(model_path, options, providers=[ONNX_DEVICE])
with open(meta_path, 'r') as f:
meta = json.load(f)
labels = meta.get('labels', [])
onnx_sessions_cache[cache_key] = (session, labels, meta)
return session, labels, meta
except Exception as e:
print(f"Error loading ONNX model {repo_id}/{model_subfolder}: {e}")
        onnx_sessions_cache[cache_key] = (None, [], None)  # Cache the failure
return None, [], None
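# Usage sketch for the loader above (an assumption on layout: the deepghs repos
# expose "<subfolder>/model.onnx" and "<subfolder>/meta.json"):
#   session, labels, meta = get_onnx_session_and_meta(
#       "deepghs/anime_aesthetic", "swinv2pv3_v0_448_ls0.2_x")
#   if session is not None:
#       batch = _img_preprocess_for_onnx(img, size=(448, 448))
#       logits, = session.run(None, {session.get_inputs()[0].name: batch})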
# 1. ImageReward
try:
reward_processor = AutoImageProcessor.from_pretrained("THUDM/ImageReward")
reward_model = AutoModelForImageClassification.from_pretrained("THUDM/ImageReward").to(DEVICE)
reward_model.eval()
except Exception as e:
print(f"Error loading THUDM/ImageReward: {e}")
reward_processor, reward_model = None, None
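# Hedged note: THUDM/ImageReward is a BLIP-based reward model and is not published
# in the standard transformers image-classification format, so the AutoModel load
# above may fail and fall back to None. A commonly used alternative is the
# `image-reward` pip package (assumption: it is listed in requirements.txt):
#   import ImageReward as RM
#   rm_model = RM.load("ImageReward-v1.0")
#   score = rm_model.score("a photo of a cat", "/path/to/image.png")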
# 2. Anime Aesthetic (deepghs ONNX)
# Model: deepghs/anime_aesthetic, subfolder: swinv2pv3_v0_448_ls0.2_x
ANIME_AESTHETIC_REPO = "deepghs/anime_aesthetic"
ANIME_AESTHETIC_SUBFOLDER = "swinv2pv3_v0_448_ls0.2_x"
ANIME_AESTHETIC_IMG_SIZE = (448, 448)
# Labels from meta.json: ["normal", "slight", "moderate", "strong", "extreme"]
# Weights for the weighted sum:
ANIME_AESTHETIC_LABEL_WEIGHTS = {"normal": 0.0, "slight": 1.0, "moderate": 2.0, "strong": 3.0, "extreme": 4.0}
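# The aesthetic score computed below is the expected value of these weights under
# the softmax probabilities. For example, probabilities [0.1, 0.2, 0.4, 0.2, 0.1]
# over (normal, slight, moderate, strong, extreme) give
# 0.1*0 + 0.2*1 + 0.4*2 + 0.2*3 + 0.1*4 = 2.0, i.e. roughly "moderate".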
# 3. MANIQA (Technical Quality) - Transformers pipeline
try:
    # Note: torch.device("cuda").index is None, not 0, so use an explicit GPU index
    maniqa_pipe = transformers_pipeline("image-classification", model="honklers/maniqa-nr", device=0 if DEVICE == "cuda" else -1)
except Exception as e:
print(f"Error loading honklers/maniqa-nr: {e}")
maniqa_pipe = None
# 4. CLIP Score (laion/CLIP-ViT-L-14-laion2B-s32B-b82K) - open_clip
try:
clip_model_name = 'ViT-L-14'
    clip_pretrained = 'laion2b_s32b_b82k'  # laion2B-s32B-b82K
    # open_clip's create_model_from_pretrained returns a (model, preprocess) pair, not a 3-tuple
    clip_model_instance, clip_preprocess = create_model_from_pretrained(clip_model_name, pretrained=clip_pretrained, device=DEVICE)
clip_tokenizer = get_tokenizer(clip_model_name)
clip_model_instance.eval()
except Exception as e:
print(f"Error loading CLIP model {clip_model_name} ({clip_pretrained}): {e}")
clip_model_instance, clip_preprocess, clip_tokenizer = None, None, None
# 5. AI Detectors
# Organika/sdxl-detector - Transformers pipeline
try:
    sdxl_detector_pipe = transformers_pipeline("image-classification", model="Organika/sdxl-detector", device=0 if DEVICE == "cuda" else -1)
except Exception as e:
print(f"Error loading Organika/sdxl-detector: {e}")
sdxl_detector_pipe = None
# deepghs/anime_ai_check - ONNX
# Model: deepghs/anime_ai_check, subfolder: caformer_s36_plus_sce
ANIME_AI_CHECK_REPO = "deepghs/anime_ai_check"
ANIME_AI_CHECK_SUBFOLDER = "caformer_s36_plus_sce"
ANIME_AI_CHECK_IMG_SIZE = (384, 384)  # Assumption, unless documented otherwise
# --- Metadata Extraction Functions (unchanged) ---
def extract_sd_parameters(image_pil):
if image_pil is None:
return "", "N/A", "N/A", "N/A", {}
parameters_str = image_pil.info.get("parameters", "")
if not parameters_str:
return "", "N/A", "N/A", "N/A", {}
prompt = ""
negative_prompt = ""
model_name = "N/A"
model_hash = "N/A"
other_params_dict = {}
    neg_prompt_index = parameters_str.find("Negative prompt:")
    steps_meta_index = parameters_str.find("Steps:")  # Start of the parameters block
    if neg_prompt_index != -1:
        prompt = parameters_str[:neg_prompt_index].strip()
        # If "Steps:" appears after "Negative prompt:", the negative prompt sits between them
        if steps_meta_index != -1 and steps_meta_index > neg_prompt_index:
            negative_prompt = parameters_str[neg_prompt_index + len("Negative prompt:"):steps_meta_index].strip()
            params_part = parameters_str[steps_meta_index:]
        else:
            # "Steps:" was not found (or precedes "Negative prompt:"), so look for it
            # in the remainder after "Negative prompt:"; if it is absent there too,
            # the negative prompt runs to the end and there is no parameters block.
            search_params_in_rest = parameters_str[neg_prompt_index + len("Negative prompt:"):]
            actual_steps_index_in_rest = search_params_in_rest.find("Steps:")
            if actual_steps_index_in_rest != -1:
                negative_prompt = search_params_in_rest[:actual_steps_index_in_rest].strip()
                params_part = search_params_in_rest[actual_steps_index_in_rest:]
            else:  # No "Steps:" after "Negative prompt:"
                negative_prompt = search_params_in_rest.strip()  # Treat the rest as the negative prompt
                params_part = ""  # No parameters block
    else:  # "Negative prompt:" was not found
        if steps_meta_index != -1:  # If "Steps:" is present, the prompt is everything before it
            prompt = parameters_str[:steps_meta_index].strip()
            params_part = parameters_str[steps_meta_index:]
        else:  # Neither "Negative prompt:" nor "Steps:" found; the whole text is the prompt
            prompt = parameters_str.strip()
            params_part = ""
    if not prompt and not negative_prompt and not params_part:  # If everything is empty, the string may be just parameters
        params_part = parameters_str
if params_part:
params_list = [p.strip() for p in params_part.split(",")]
temp_other_params = {}
for param_val_str in params_list:
parts = param_val_str.split(':', 1)
if len(parts) == 2:
key, value = parts[0].strip(), parts[1].strip()
temp_other_params[key] = value
if key == "Model": model_name = value
elif key == "Model hash": model_hash = value
        # Copy everything except "Model" and "Model hash" into other_params_dict
for k,v in temp_other_params.items():
if k not in ["Model", "Model hash"]:
other_params_dict[k] = v
if model_name == "N/A" and model_hash != "N/A": model_name = f"hash_{model_hash}"
if model_name == "N/A" and "Checkpoint" in other_params_dict: model_name = other_params_dict["Checkpoint"]
return prompt, negative_prompt, model_name, model_hash, other_params_dict
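# Illustrative example of the AUTOMATIC1111-style "parameters" PNG text chunk this
# parser targets (values are made up):
#   masterpiece, 1girl, scenery
#   Negative prompt: lowres, bad anatomy
#   Steps: 28, Sampler: Euler a, CFG scale: 7, Seed: 12345, Model hash: abcd1234, Model: myModel_v1
# extract_sd_parameters would return:
#   prompt="masterpiece, 1girl, scenery", negative_prompt="lowres, bad anatomy",
#   model_name="myModel_v1", model_hash="abcd1234",
#   other_params={"Steps": "28", "Sampler": "Euler a", "CFG scale": "7", "Seed": "12345"}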
# --- Scoring Functions (updated for deepghs) ---
@torch.no_grad()
def get_image_reward(image_pil):
if not reward_model or not reward_processor: return "N/A"
try:
inputs = reward_processor(images=image_pil, return_tensors="pt").to(DEVICE)
outputs = reward_model(**inputs)
return round(outputs.logits.item(), 4)
except Exception as e:
print(f"Error in ImageReward: {e}")
return "Error"
def get_anime_aesthetic_score_deepghs(image_pil):
session, labels, meta = get_onnx_session_and_meta(ANIME_AESTHETIC_REPO, ANIME_AESTHETIC_SUBFOLDER)
if not session or not labels: return "N/A"
try:
input_data = _img_preprocess_for_onnx(image_pil, size=ANIME_AESTHETIC_IMG_SIZE)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
onnx_output, = session.run([output_name], {input_name: input_data})
        scores = onnx_output[0]  # Expected to be an array of probabilities/logits
        # Apply softmax in case these are logits (ONNX classification models usually return logits)
        exp_scores = np.exp(scores - np.max(scores))  # Subtract the max for softmax stability
probabilities = exp_scores / np.sum(exp_scores)
weighted_score = 0.0
for i, label in enumerate(labels):
if label in ANIME_AESTHETIC_LABEL_WEIGHTS:
weighted_score += probabilities[i] * ANIME_AESTHETIC_LABEL_WEIGHTS[label]
return round(weighted_score, 4)
except Exception as e:
print(f"Error in Anime Aesthetic (ONNX): {e}")
return "Error"
@torch.no_grad()
def get_maniqa_score(image_pil):
if not maniqa_pipe: return "N/A"
try:
        result = maniqa_pipe(image_pil.copy())
        score = 0.0
        # Look for the label corresponding to high quality.
        # honklers/maniqa-nr may use 'LABEL_0'/'LABEL_1' or 'Good Quality'/'Bad Quality';
        # check the model card. Here we assume the pipeline outputs something like
        # [{'label': 'Bad Quality', 'score': 0.9}, {'label': 'Good Quality', 'score': 0.1}]
        # and extract the 'Good Quality' score.
        for item in result:
            if item['label'].lower() == 'good quality':  # or another positive label
                score = item['score']
                break
            # If there is no 'Good Quality' label but something like LABEL_1 (positive):
            # elif item['label'] == 'LABEL_1':
            #     score = item['score']
            #     break
        if score == 0.0 and result:
            # 'Good Quality' was not found even though the pipeline returned results.
            # Taking the maximum score with unknown labels would be risky, so keep 0.0
            # unless a positive label is identified from the model card.
            pass
return round(score, 4)
except Exception as e:
print(f"Error in MANIQA: {e}")
return "Error"
@torch.no_grad()
def calculate_clip_score_value(image_pil, prompt_text):  # Renamed to avoid clashing with torchmetrics' clip_score
if not clip_model_instance or not clip_preprocess or not clip_tokenizer or not prompt_text or prompt_text == "N/A":
return "N/A"
try:
image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
text_input = clip_tokenizer([str(prompt_text)]).to(DEVICE)
image_features = clip_model_instance.encode_image(image_input)
text_features = clip_model_instance.encode_text(text_input)
image_features_norm = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
text_features_norm = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
score = (text_features_norm @ image_features_norm.T).squeeze().item() * 100.0
return round(score, 2)
except Exception as e:
print(f"Error in CLIP Score: {e}")
return "Error"
@torch.no_grad()
def get_sdxl_detection_score(image_pil):
if not sdxl_detector_pipe: return "N/A"
try:
result = sdxl_detector_pipe(image_pil.copy())
ai_score = 0.0
        # Organika/sdxl-detector labels: 'artificial', 'real'
for item in result:
if item['label'].lower() == 'artificial':
ai_score = item['score']
break
return round(ai_score, 4)
except Exception as e:
print(f"Error in SDXL Detector: {e}")
return "Error"
def get_anime_ai_check_score_deepghs(image_pil):
session, labels, meta = get_onnx_session_and_meta(ANIME_AI_CHECK_REPO, ANIME_AI_CHECK_SUBFOLDER)
if not session or not labels: return "N/A"
try:
input_data = _img_preprocess_for_onnx(image_pil, size=ANIME_AI_CHECK_IMG_SIZE)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
onnx_output, = session.run([output_name], {input_name: input_data})
scores = onnx_output[0]
exp_scores = np.exp(scores - np.max(scores))
probabilities = exp_scores / np.sum(exp_scores)
ai_prob = 0.0
for i, label in enumerate(labels):
            if label.lower() == 'ai':  # Look for the 'ai' label
ai_prob = probabilities[i]
break
return round(ai_prob, 4)
except Exception as e:
print(f"Error in Anime AI Check (ONNX): {e}")
return "Error"
# --- Main Processing Function ---
def process_images(files, progress=gr.Progress(track_tqdm=True)):
if not files:
return pd.DataFrame(), None, None, None, None, "Please upload some images."
all_results = []
    # progress(0, desc="Starting processing...")  # track_tqdm handles this
for i, file_obj in enumerate(files):
try:
            # On HF Spaces, file_obj may be a temp-file path string or an object with a .name attribute
            filename = os.path.basename(getattr(file_obj, 'name', str(file_obj)))  # getattr for compatibility
            # progress((i+1)/len(files), desc=f"Processing {filename}")  # track_tqdm handles this
            img = Image.open(getattr(file_obj, 'name', str(file_obj)))
if img.mode != "RGB":
img = img.convert("RGB")
prompt, neg_prompt, model_n, model_h, other_p = extract_sd_parameters(img)
            # Scores
reward = get_image_reward(img.copy())
anime_aes_deepghs = get_anime_aesthetic_score_deepghs(img.copy())
maniqa = get_maniqa_score(img.copy())
            clip_val = calculate_clip_score_value(img.copy(), prompt)  # Renamed function (see above)
sdxl_detect = get_sdxl_detection_score(img.copy())
anime_ai_chk_deepghs = get_anime_ai_check_score_deepghs(img.copy())
result_entry = {
"Filename": filename,
"Prompt": prompt if prompt else "N/A",
"Model Name": model_n,
"Model Hash": model_h,
"ImageReward": reward,
"AnimeAesthetic_dg": anime_aes_deepghs, # dg = deepghs
"MANIQA_TQ": maniqa,
"CLIPScore": clip_val,
"SDXL_Detector_AI_Prob": sdxl_detect,
"AnimeAI_Check_dg_Prob": anime_ai_chk_deepghs, # dg = deepghs
}
all_results.append(result_entry)
except Exception as e:
print(f"Failed to process {getattr(file_obj, 'name', str(file_obj))}: {e}")
all_results.append({
"Filename": os.path.basename(getattr(file_obj, 'name', str(file_obj))) if file_obj else "Unknown File",
"Prompt": "Error", "Model Name": "Error", "Model Hash": "Error",
"ImageReward": "Error", "AnimeAesthetic_dg": "Error", "MANIQA_TQ": "Error",
"CLIPScore": "Error", "SDXL_Detector_AI_Prob": "Error", "AnimeAI_Check_dg_Prob": "Error"
})
df = pd.DataFrame(all_results)
plot_model_avg_scores_buffer = None
if "Model Name" in df.columns and df["Model Name"].nunique() > 0 and df["Model Name"].count() > 0 :
numeric_cols = ["ImageReward", "AnimeAesthetic_dg", "MANIQA_TQ", "CLIPScore"]
for col in numeric_cols: df[col] = pd.to_numeric(df[col], errors='coerce')
try:
            # Filter out "N/A" model names before grouping for the plot
df_for_plot = df[df["Model Name"] != "N/A"]
if not df_for_plot.empty and df_for_plot["Model Name"].nunique() > 0 :
model_avg_scores = df_for_plot.groupby("Model Name")[numeric_cols].mean().dropna(how='all')
if not model_avg_scores.empty:
fig1, ax1 = plt.subplots(figsize=(12, 7))
model_avg_scores.plot(kind="bar", ax=ax1)
ax1.set_title("Average Scores per Model")
ax1.set_ylabel("Average Score")
ax1.tick_params(axis='x', rotation=45, labelsize=8)
plt.tight_layout()
plot_model_avg_scores_buffer = io.BytesIO()
fig1.savefig(plot_model_avg_scores_buffer, format="png")
plot_model_avg_scores_buffer.seek(0)
plt.close(fig1)
except Exception as e: print(f"Error generating model average scores plot: {e}")
plot_prompt_clip_scores_buffer = None
if "Prompt" in df.columns and "CLIPScore" in df.columns and df["Prompt"].nunique() > 0:
df["CLIPScore"] = pd.to_numeric(df["CLIPScore"], errors='coerce')
        df_prompt_plot = df[df["Prompt"] != "N/A"].dropna(subset=["CLIPScore"]).copy()  # .copy() avoids SettingWithCopyWarning below
if not df_prompt_plot.empty and df_prompt_plot["Prompt"].nunique() > 0:
try:
                # Truncate long prompts for the chart
                df_prompt_plot["Short Prompt"] = df_prompt_plot["Prompt"].apply(lambda x: (x[:30] + '...') if len(x) > 33 else x)
prompt_clip_scores = df_prompt_plot.groupby("Short Prompt")["CLIPScore"].mean().sort_values(ascending=False)
if not prompt_clip_scores.empty and len(prompt_clip_scores) > 1 :
                    fig2, ax2 = plt.subplots(figsize=(12, max(7, min(len(prompt_clip_scores)*0.5, 15))))  # Cap the figure height
prompt_clip_scores.head(20).plot(kind="barh", ax=ax2)
ax2.set_title("Average CLIPScore per Prompt (Top 20 unique prompts)")
ax2.set_xlabel("Average CLIPScore")
plt.tight_layout()
plot_prompt_clip_scores_buffer = io.BytesIO()
fig2.savefig(plot_prompt_clip_scores_buffer, format="png")
plot_prompt_clip_scores_buffer.seek(0)
plt.close(fig2)
except Exception as e: print(f"Error generating prompt CLIP scores plot: {e}")
    # gr.File expects a file path (it has no file_name parameter), so materialize
    # the CSV/JSON exports as real files in the temp directory.
    csv_file_path = None
    if not df.empty:
        csv_file_path = os.path.join(tempfile.gettempdir(), "evaluation_results.csv")
        df.to_csv(csv_file_path, index=False)
    json_file_path = None
    if not df.empty:
        json_file_path = os.path.join(tempfile.gettempdir(), "evaluation_results.json")
        df.to_json(json_file_path, orient='records', indent=4)
    # gr.Image likewise expects a PIL image (or a path), not a raw BytesIO, so decode the PNG buffers.
    model_plot_img = Image.open(plot_model_avg_scores_buffer) if plot_model_avg_scores_buffer else None
    prompt_plot_img = Image.open(plot_prompt_clip_scores_buffer) if plot_prompt_clip_scores_buffer else None
    return (
        df,
        gr.Image(value=model_plot_img, visible=model_plot_img is not None),
        gr.Image(value=prompt_plot_img, visible=prompt_plot_img is not None),
        gr.File(value=csv_file_path, label="Download CSV Results", visible=csv_file_path is not None),
        gr.File(value=json_file_path, label="Download JSON Results", visible=json_file_path is not None),
        f"Processed {len(all_results)} images.",
    )
# --- Gradio Interface ---
with gr.Blocks(css="footer {display: none !important}") as demo:
gr.Markdown("# AI Image Model Evaluation Tool")
gr.Markdown(
"Upload PNG images (ideally with Stable Diffusion metadata) to evaluate them using various metrics. "
"Results will be displayed in a table and visualized in charts."
)
with gr.Row():
image_uploader = gr.Files(
label="Upload Images (PNG)",
file_count="multiple",
file_types=["image"],
)
process_button = gr.Button("Evaluate Images", variant="primary")
status_textbox = gr.Textbox(label="Status", interactive=False)
gr.Markdown("## Evaluation Results Table")
results_table = gr.DataFrame(headers=[
"Filename", "Prompt", "Model Name", "Model Hash",
"ImageReward", "AnimeAesthetic_dg", "MANIQA_TQ", "CLIPScore",
"SDXL_Detector_AI_Prob", "AnimeAI_Check_dg_Prob"
    ], wrap=True, max_rows=10)  # Limit how many rows are shown initially
with gr.Row():
        download_csv_button = gr.File(label="Download CSV Results", interactive=False)  # visibility is driven by the process output
download_json_button = gr.File(label="Download JSON Results", interactive=False)
gr.Markdown("## Visualizations")
with gr.Row():
plot_output_model_avg = gr.Image(label="Average Scores per Model", type="pil", interactive=False)
plot_output_prompt_clip = gr.Image(label="Average CLIPScore per Prompt", type="pil", interactive=False)
process_button.click(
fn=process_images,
inputs=[image_uploader],
outputs=[
results_table,
plot_output_model_avg,
plot_output_prompt_clip,
download_csv_button,
download_json_button,
status_textbox
]
)
gr.Markdown(
"""
**Metric Explanations:**
- **ImageReward:** General aesthetic and prompt alignment score (higher is better). From THUDM.
- **AnimeAesthetic_dg:** Aesthetic level for anime style (0-4, higher is better quality level: normal, slight, moderate, strong, extreme). From deepghs (ONNX).
- **MANIQA_TQ:** Technical Quality score (no-reference), higher indicates better quality (less noise/artifacts). Based on MANIQA.
- **CLIPScore:** Semantic similarity between the image and its prompt (0-100, higher is better). Uses LAION's CLIP.
- **SDXL_Detector_AI_Prob:** Estimated probability that the image is AI-generated (higher means more likely AI). From Organika.
- **AnimeAI_Check_dg_Prob:** Estimated probability that an anime-style image is AI-generated (higher means more likely AI). From deepghs (ONNX).
*Processing can take time, especially for many images or on CPU.*
"""
)
if __name__ == "__main__":
demo.launch(debug=True)
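# Hedged note on dependencies: based on the imports above, a requirements.txt for
# this Space would need roughly the following (package names are assumptions; pin
# versions as appropriate):
#   gradio, Pillow, pandas, torch, transformers, torchmetrics, open_clip_torch,
#   matplotlib, numpy, onnxruntime, huggingface_hub, dghs-imgutils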