import gradio as gr
from PIL import Image, PngImagePlugin
import io
import os
import pandas as pd
import torch
from transformers import pipeline as transformers_pipeline, AutoImageProcessor, AutoModelForImageClassification
# from torchvision import transforms  # less relevant for the ONNX pipeline
from torchmetrics.functional.multimodal import clip_score
from open_clip import create_model_from_pretrained, get_tokenizer
import re
import matplotlib.pyplot as plt
import json
from collections import defaultdict
import numpy as np
import logging  # for ONNX-related logging
import tempfile  # for writing downloadable CSV/JSON result files

# --- ONNX Related Imports and Setup ---
try:
    import onnxruntime
except ImportError:
    print("onnxruntime not found. Please ensure it's in requirements.txt")
    onnxruntime = None

from huggingface_hub import hf_hub_download

# imgutils for rgb_encode (if installed)
try:
    from imgutils.data import rgb_encode  # assuming this is the correct import
except ImportError:
    print("imgutils.data.rgb_encode not found. Preprocessing for deepghs might be limited.")
    def rgb_encode(image, order_='CHW'):  # simple fallback stub if imgutils is missing
        img_arr = np.array(image)
        if order_ == 'CHW':
            img_arr = np.transpose(img_arr, (2, 0, 1))
        return img_arr.astype(np.float32) / 255.0  # basic [0, 1] normalization, matching the imgutils convention

# --- Model Configuration and Loading ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
ONNX_DEVICE = "CUDAExecutionProvider" if DEVICE == "cuda" and onnxruntime and "CUDAExecutionProvider" in onnxruntime.get_available_providers() else "CPUExecutionProvider"
print(f"Using ONNX device: {ONNX_DEVICE}")


# --- Helper for ONNX models (deepghs) ---
def _img_preprocess_for_onnx(image: Image.Image, size: tuple = (384, 384), normalize_mean=0.5, normalize_std=0.5):
    image = image.resize(size, Image.Resampling.BILINEAR)
    data = rgb_encode(image, order_='CHW')  # (C, H, W) float32 in the [0, 1] range (imgutils convention)

    # Normalize as (data - mean) / std, assuming rgb_encode already yields [0, 1];
    # if the data were in 0-255, it would first need dividing by 255.
    mean = np.array([normalize_mean] * 3, dtype=np.float32).reshape((3, 1, 1))
    std = np.array([normalize_std] * 3, dtype=np.float32).reshape((3, 1, 1))

    normalized_data = (data - mean) / std
    return normalized_data[None, ...].astype(np.float32)  # add batch dimension -> (1, C, H, W)
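
# Usage sketch for the preprocessor above (illustrative only, not executed; the
# shape expectation is an assumption based on the usual deepghs ONNX input
# signature of (batch, channel, height, width) float32):
#
#   img = Image.open("sample.png").convert("RGB")
#   batch = _img_preprocess_for_onnx(img, size=(448, 448))
#   assert batch.shape == (1, 3, 448, 448) and batch.dtype == np.float32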

onnx_sessions_cache = {}  # cache of ONNX sessions and their metadata

def get_onnx_session_and_meta(repo_id, model_subfolder):
    cache_key = f"{repo_id}/{model_subfolder}"
    if cache_key in onnx_sessions_cache:
        return onnx_sessions_cache[cache_key]

    if not onnxruntime:
        raise ImportError("ONNX Runtime is not available.")

    try:
        model_path = hf_hub_download(repo_id, filename=f"{model_subfolder}/model.onnx")
        meta_path = hf_hub_download(repo_id, filename=f"{model_subfolder}/meta.json")

        options = onnxruntime.SessionOptions()
        options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        if ONNX_DEVICE == "CPUExecutionProvider":
             options.intra_op_num_threads = os.cpu_count()

        session = onnxruntime.InferenceSession(model_path, options, providers=[ONNX_DEVICE])
        
        with open(meta_path, 'r') as f:
            meta = json.load(f)
        
        labels = meta.get('labels', [])
        onnx_sessions_cache[cache_key] = (session, labels, meta)
        return session, labels, meta
    except Exception as e:
        print(f"Error loading ONNX model {repo_id}/{model_subfolder}: {e}")
        onnx_sessions_cache[cache_key] = (None, [], None)  # cache the failure so we don't retry per image
        return None, [], None
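
# Assumed meta.json layout (inferred from the `labels` lookup above; only the
# "labels" key is relied on here, everything else is passed through untouched):
#
#   {"labels": ["normal", "slight", "moderate", "strong", "extreme"], ...}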


# 1. ImageReward
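# NOTE: THUDM/ImageReward is normally consumed through the dedicated
# `image-reward` package (e.g. RM.load("ImageReward-v1.0")); loading it with
# the transformers Auto* classes below is an assumption and may fail at
# startup, in which case the try/except disables this metric gracefully.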
try:
    reward_processor = AutoImageProcessor.from_pretrained("THUDM/ImageReward")
    reward_model = AutoModelForImageClassification.from_pretrained("THUDM/ImageReward").to(DEVICE)
    reward_model.eval()
except Exception as e:
    print(f"Error loading THUDM/ImageReward: {e}")
    reward_processor, reward_model = None, None

# 2. Anime Aesthetic (deepghs ONNX)
# Model: deepghs/anime_aesthetic, subfolder: swinv2pv3_v0_448_ls0.2_x
ANIME_AESTHETIC_REPO = "deepghs/anime_aesthetic"
ANIME_AESTHETIC_SUBFOLDER = "swinv2pv3_v0_448_ls0.2_x"
ANIME_AESTHETIC_IMG_SIZE = (448, 448)
# Labels from meta.json: ["normal", "slight", "moderate", "strong", "extreme"]
# Weights for the weighted sum:
ANIME_AESTHETIC_LABEL_WEIGHTS = {"normal": 0.0, "slight": 1.0, "moderate": 2.0, "strong": 3.0, "extreme": 4.0}

# 3. MANIQA (Technical Quality) - Transformers pipeline
try:
    maniqa_pipe = transformers_pipeline("image-classification", model="honklers/maniqa-nr", device=0 if DEVICE == "cuda" else -1)
except Exception as e:
    print(f"Error loading honklers/maniqa-nr: {e}")
    maniqa_pipe = None

# 4. CLIP Score (laion/CLIP-ViT-L-14-laion2B-s32B-b82K) - open_clip
try:
    clip_model_name = 'ViT-L-14'
    clip_pretrained = 'laion2b_s32b_b82k' # laion2B-s32B-b82K
    clip_model_instance, clip_preprocess = create_model_from_pretrained(clip_model_name, pretrained=clip_pretrained, device=DEVICE)
    clip_tokenizer = get_tokenizer(clip_model_name)
    clip_model_instance.eval()
except Exception as e:
    print(f"Error loading CLIP model {clip_model_name} ({clip_pretrained}): {e}")
    clip_model_instance, clip_preprocess, clip_tokenizer = None, None, None

# 5. AI Detectors
# Organika/sdxl-detector - Transformers pipeline
try:
    sdxl_detector_pipe = transformers_pipeline("image-classification", model="Organika/sdxl-detector", device=0 if DEVICE == "cuda" else -1)
except Exception as e:
    print(f"Error loading Organika/sdxl-detector: {e}")
    sdxl_detector_pipe = None

# deepghs/anime_ai_check - ONNX
# Model: deepghs/anime_ai_check, subfolder: caformer_s36_plus_sce
ANIME_AI_CHECK_REPO = "deepghs/anime_ai_check"
ANIME_AI_CHECK_SUBFOLDER = "caformer_s36_plus_sce"
ANIME_AI_CHECK_IMG_SIZE = (384, 384)  # assumption; adjust if the model card states otherwise

# --- Metadata extraction functions ---
def extract_sd_parameters(image_pil):
    if image_pil is None:
        return "", "N/A", "N/A", "N/A", {}
    
    parameters_str = image_pil.info.get("parameters", "")
    if not parameters_str:
        return "", "N/A", "N/A", "N/A", {}

    prompt = ""
    negative_prompt = ""
    model_name = "N/A"
    model_hash = "N/A"
    other_params_dict = {}

    neg_prompt_index = parameters_str.find("Negative prompt:")
    steps_meta_index = parameters_str.find("Steps:")  # start of the key/value parameter block

    if neg_prompt_index != -1:
        prompt = parameters_str[:neg_prompt_index].strip()
        # If "Steps:" appears after "Negative prompt:", the negative prompt sits between them.
        if steps_meta_index != -1 and steps_meta_index > neg_prompt_index:
            negative_prompt = parameters_str[neg_prompt_index + len("Negative prompt:"):steps_meta_index].strip()
            params_part = parameters_str[steps_meta_index:]
        else:
            # "Steps:" is missing or precedes "Negative prompt:"; search the remainder
            # after "Negative prompt:" for the parameter block.
            search_params_in_rest = parameters_str[neg_prompt_index + len("Negative prompt:"):]
            actual_steps_index_in_rest = search_params_in_rest.find("Steps:")
            if actual_steps_index_in_rest != -1:
                negative_prompt = search_params_in_rest[:actual_steps_index_in_rest].strip()
                params_part = search_params_in_rest[actual_steps_index_in_rest:]
            else:  # no "Steps:" after "Negative prompt:"
                negative_prompt = search_params_in_rest.strip()  # treat the rest as the negative prompt
                params_part = ""  # no parameter block

    else:  # "Negative prompt:" not found
        if steps_meta_index != -1:  # prompt is everything before "Steps:"
            prompt = parameters_str[:steps_meta_index].strip()
            params_part = parameters_str[steps_meta_index:]
        else:  # neither "Negative prompt:" nor "Steps:"; the whole text is the prompt
            prompt = parameters_str.strip()
            params_part = ""

    if not prompt and not negative_prompt and not params_part:  # all empty: maybe the chunk is just parameters
        params_part = parameters_str

    if params_part:
        params_list = [p.strip() for p in params_part.split(",")]
        temp_other_params = {}
        for param_val_str in params_list:
            parts = param_val_str.split(':', 1)
            if len(parts) == 2:
                key, value = parts[0].strip(), parts[1].strip()
                temp_other_params[key] = value
                if key == "Model": model_name = value
                elif key == "Model hash": model_hash = value
        
        # Copy everything except "Model" and "Model hash" into other_params_dict
        for k,v in temp_other_params.items():
            if k not in ["Model", "Model hash"]:
                other_params_dict[k] = v
    
    if model_name == "N/A" and model_hash != "N/A": model_name = f"hash_{model_hash}"
    if model_name == "N/A" and "Checkpoint" in other_params_dict: model_name = other_params_dict["Checkpoint"]

    return prompt, negative_prompt, model_name, model_hash, other_params_dict
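
# Worked example of what the parser above returns for a typical (hypothetical)
# A1111-style "parameters" chunk; the exact key set varies between UIs:
#
#   masterpiece, 1girl, looking at viewer
#   Negative prompt: lowres, bad anatomy
#   Steps: 28, Sampler: Euler a, CFG scale: 7, Seed: 42, Model hash: abc12345, Model: myModelV3
#
#   prompt          -> "masterpiece, 1girl, looking at viewer"
#   negative_prompt -> "lowres, bad anatomy"
#   model_name      -> "myModelV3", model_hash -> "abc12345"
#   other_params    -> {"Steps": "28", "Sampler": "Euler a", "CFG scale": "7", "Seed": "42"}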

# --- Scoring functions (updated for the deepghs ONNX models) ---

@torch.no_grad()
def get_image_reward(image_pil):
    if not reward_model or not reward_processor: return "N/A"
    try:
        inputs = reward_processor(images=image_pil, return_tensors="pt").to(DEVICE)
        outputs = reward_model(**inputs)
        return round(outputs.logits.item(), 4)
    except Exception as e:
        print(f"Error in ImageReward: {e}")
        return "Error"

def get_anime_aesthetic_score_deepghs(image_pil):
    session, labels, meta = get_onnx_session_and_meta(ANIME_AESTHETIC_REPO, ANIME_AESTHETIC_SUBFOLDER)
    if not session or not labels: return "N/A"
    try:
        input_data = _img_preprocess_for_onnx(image_pil, size=ANIME_AESTHETIC_IMG_SIZE)
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name
        
        onnx_output, = session.run([output_name], {input_name: input_data})
        
        scores = onnx_output[0]  # array of logits (or probabilities) per label
        # Apply softmax in case these are logits (ONNX classification heads usually emit logits);
        # subtracting the max keeps the exponentials numerically stable.
        exp_scores = np.exp(scores - np.max(scores))
        probabilities = exp_scores / np.sum(exp_scores)

        weighted_score = 0.0
        for i, label in enumerate(labels):
            if label in ANIME_AESTHETIC_LABEL_WEIGHTS:
                weighted_score += probabilities[i] * ANIME_AESTHETIC_LABEL_WEIGHTS[label]
        return round(weighted_score, 4)
    except Exception as e:
        print(f"Error in Anime Aesthetic (ONNX): {e}")
        return "Error"

@torch.no_grad()
def get_maniqa_score(image_pil):
    if not maniqa_pipe: return "N/A"
    try:
        result = maniqa_pipe(image_pil.copy())
        score = 0.0
        # Look for the label that denotes high quality. honklers/maniqa-nr may use
        # 'LABEL_0'/'LABEL_1' or 'Good Quality'/'Bad Quality' -- check the model card.
        # Here we assume output like
        # [{'label': 'Bad Quality', 'score': 0.9}, {'label': 'Good Quality', 'score': 0.1}]
        # and take the 'Good Quality' probability.
        for item in result:
            if item['label'].lower() == 'good quality':  # or whichever positive label the model uses
                score = item['score']
                break
        # If no positive label is found, score stays 0.0; guessing from unknown labels
        # (e.g. taking the max score) would be unreliable.

        return round(score, 4)
    except Exception as e:
        print(f"Error in MANIQA: {e}")
        return "Error"

@torch.no_grad()
def calculate_clip_score_value(image_pil, prompt_text):  # named to avoid clashing with torchmetrics' clip_score
    if not clip_model_instance or not clip_preprocess or not clip_tokenizer or not prompt_text or prompt_text == "N/A":
        return "N/A"
    try:
        image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
        text_input = clip_tokenizer([str(prompt_text)]).to(DEVICE)
        
        image_features = clip_model_instance.encode_image(image_input)
        text_features = clip_model_instance.encode_text(text_input)
        
        image_features_norm = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
        text_features_norm = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
        score = (text_features_norm @ image_features_norm.T).squeeze().item() * 100.0
        return round(score, 2)
    except Exception as e:
        print(f"Error in CLIP Score: {e}")
        return "Error"

@torch.no_grad()
def get_sdxl_detection_score(image_pil):
    if not sdxl_detector_pipe: return "N/A"
    try:
        result = sdxl_detector_pipe(image_pil.copy())
        ai_score = 0.0
        # Organika/sdxl-detector labels: 'artificial', 'real'
        for item in result:
            if item['label'].lower() == 'artificial':
                ai_score = item['score']
                break
        return round(ai_score, 4)
    except Exception as e:
        print(f"Error in SDXL Detector: {e}")
        return "Error"

def get_anime_ai_check_score_deepghs(image_pil):
    session, labels, meta = get_onnx_session_and_meta(ANIME_AI_CHECK_REPO, ANIME_AI_CHECK_SUBFOLDER)
    if not session or not labels: return "N/A"
    try:
        input_data = _img_preprocess_for_onnx(image_pil, size=ANIME_AI_CHECK_IMG_SIZE)
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name

        onnx_output, = session.run([output_name], {input_name: input_data})
        
        scores = onnx_output[0]
        exp_scores = np.exp(scores - np.max(scores))
        probabilities = exp_scores / np.sum(exp_scores)

        ai_prob = 0.0
        for i, label in enumerate(labels):
            if label.lower() == 'ai':  # look for the 'ai' label
                ai_prob = probabilities[i]
                break
        return round(ai_prob, 4)
    except Exception as e:
        print(f"Error in Anime AI Check (ONNX): {e}")
        return "Error"

# --- Main processing function ---
def process_images(files, progress=gr.Progress(track_tqdm=True)):
    if not files:
        return pd.DataFrame(), None, None, None, None, "Please upload some images."

    all_results = []
    
    # progress(0, desc="Starting processing...")  # track_tqdm handles this

    for i, file_obj in enumerate(files):
        try:
            # On HF Spaces, file_obj may be a temp-file path or an object with a .name attribute
            filename = os.path.basename(getattr(file_obj, 'name', str(file_obj)))  # getattr for compatibility
            # progress((i+1)/len(files), desc=f"Processing {filename}")  # handled by track_tqdm
            
            img = Image.open(getattr(file_obj, 'name', str(file_obj)))
            if img.mode != "RGB":
                img = img.convert("RGB")

            prompt, neg_prompt, model_n, model_h, other_p = extract_sd_parameters(img)

            # Scores
            reward = get_image_reward(img.copy())
            anime_aes_deepghs = get_anime_aesthetic_score_deepghs(img.copy())
            maniqa = get_maniqa_score(img.copy())
            clip_val = calculate_clip_score_value(img.copy(), prompt)
            sdxl_detect = get_sdxl_detection_score(img.copy())
            anime_ai_chk_deepghs = get_anime_ai_check_score_deepghs(img.copy())

            result_entry = {
                "Filename": filename,
                "Prompt": prompt if prompt else "N/A",
                "Model Name": model_n,
                "Model Hash": model_h,
                "ImageReward": reward,
                "AnimeAesthetic_dg": anime_aes_deepghs, # dg = deepghs
                "MANIQA_TQ": maniqa,
                "CLIPScore": clip_val,
                "SDXL_Detector_AI_Prob": sdxl_detect,
                "AnimeAI_Check_dg_Prob": anime_ai_chk_deepghs, # dg = deepghs
            }
            all_results.append(result_entry)

        except Exception as e:
            print(f"Failed to process {getattr(file_obj, 'name', str(file_obj))}: {e}")
            all_results.append({
                "Filename": os.path.basename(getattr(file_obj, 'name', str(file_obj))) if file_obj else "Unknown File",
                "Prompt": "Error", "Model Name": "Error", "Model Hash": "Error",
                "ImageReward": "Error", "AnimeAesthetic_dg": "Error", "MANIQA_TQ": "Error",
                "CLIPScore": "Error", "SDXL_Detector_AI_Prob": "Error", "AnimeAI_Check_dg_Prob": "Error"
            })

    df = pd.DataFrame(all_results)
    
    plot_model_avg_scores_buffer = None
    if "Model Name" in df.columns and df["Model Name"].nunique() > 0 and df["Model Name"].count() > 0 :
        numeric_cols = ["ImageReward", "AnimeAesthetic_dg", "MANIQA_TQ", "CLIPScore"]
        for col in numeric_cols: df[col] = pd.to_numeric(df[col], errors='coerce')
        try:
            # Filter out "N/A" model names before grouping for the plot
            df_for_plot = df[df["Model Name"] != "N/A"]
            if not df_for_plot.empty and df_for_plot["Model Name"].nunique() > 0 :
                model_avg_scores = df_for_plot.groupby("Model Name")[numeric_cols].mean().dropna(how='all')
                if not model_avg_scores.empty:
                    fig1, ax1 = plt.subplots(figsize=(12, 7))
                    model_avg_scores.plot(kind="bar", ax=ax1)
                    ax1.set_title("Average Scores per Model")
                    ax1.set_ylabel("Average Score")
                    ax1.tick_params(axis='x', rotation=45, labelsize=8)
                    plt.tight_layout()
                    plot_model_avg_scores_buffer = io.BytesIO()
                    fig1.savefig(plot_model_avg_scores_buffer, format="png")
                    plot_model_avg_scores_buffer.seek(0)
                    plt.close(fig1)
        except Exception as e: print(f"Error generating model average scores plot: {e}")

    plot_prompt_clip_scores_buffer = None
    if "Prompt" in df.columns and "CLIPScore" in df.columns and df["Prompt"].nunique() > 0:
        df["CLIPScore"] = pd.to_numeric(df["CLIPScore"], errors='coerce')
        df_prompt_plot = df[df["Prompt"] != "N/A"].dropna(subset=["CLIPScore"]).copy()  # .copy() avoids SettingWithCopyWarning below
        if not df_prompt_plot.empty and df_prompt_plot["Prompt"].nunique() > 0:
            try:
                # Shorten long prompts so the axis labels stay readable
                df_prompt_plot["Short Prompt"] = df_prompt_plot["Prompt"].apply(lambda x: (x[:30] + '...') if len(x) > 33 else x)
                prompt_clip_scores = df_prompt_plot.groupby("Short Prompt")["CLIPScore"].mean().sort_values(ascending=False)
                if not prompt_clip_scores.empty and len(prompt_clip_scores) > 1 :
                    fig2, ax2 = plt.subplots(figsize=(12, max(7, min(len(prompt_clip_scores)*0.5, 15))))  # cap the figure height
                    prompt_clip_scores.head(20).plot(kind="barh", ax=ax2)
                    ax2.set_title("Average CLIPScore per Prompt (Top 20 unique prompts)")
                    ax2.set_xlabel("Average CLIPScore")
                    plt.tight_layout()
                    plot_prompt_clip_scores_buffer = io.BytesIO()
                    fig2.savefig(plot_prompt_clip_scores_buffer, format="png")
                    plot_prompt_clip_scores_buffer.seek(0)
                    plt.close(fig2)
            except Exception as e: print(f"Error generating prompt CLIP scores plot: {e}")

    csv_buffer_val = ""
    if not df.empty:
        csv_buffer = io.StringIO()
        df.to_csv(csv_buffer, index=False)
        csv_buffer_val = csv_buffer.getvalue()
    
    json_buffer_val = ""
    if not df.empty:
        json_buffer = io.StringIO()
        df.to_json(json_buffer, orient='records', indent=4)
        json_buffer_val = json_buffer.getvalue()

    return (
        df,
        gr.Image(value=plot_model_avg_scores_buffer, type="pil", visible=plot_model_avg_scores_buffer is not None),
        gr.Image(value=plot_prompt_clip_scores_buffer, type="pil", visible=plot_prompt_clip_scores_buffer is not None),
        gr.File(value=csv_buffer_val if csv_buffer_val else None, label="Download CSV Results", visible=bool(csv_buffer_val), file_name="evaluation_results.csv"),
        gr.File(value=json_buffer_val if json_buffer_val else None, label="Download JSON Results", visible=bool(json_buffer_val), file_name="evaluation_results.json"),
        f"Processed {len(all_results)} images.",
    )

# --- Gradio Interface ---
with gr.Blocks(css="footer {display: none !important}") as demo:
    gr.Markdown("# AI Image Model Evaluation Tool")
    gr.Markdown(
        "Upload PNG images (ideally with Stable Diffusion metadata) to evaluate them using various metrics. "
        "Results will be displayed in a table and visualized in charts."
    )

    with gr.Row():
        image_uploader = gr.Files(
            label="Upload Images (PNG)",
            file_count="multiple",
            file_types=["image"],
        )

    process_button = gr.Button("Evaluate Images", variant="primary")
    status_textbox = gr.Textbox(label="Status", interactive=False)

    gr.Markdown("## Evaluation Results Table")
    results_table = gr.DataFrame(headers=[
        "Filename", "Prompt", "Model Name", "Model Hash", 
        "ImageReward", "AnimeAesthetic_dg", "MANIQA_TQ", "CLIPScore", 
        "SDXL_Detector_AI_Prob", "AnimeAI_Check_dg_Prob"
    ], wrap=True, max_rows=10)  # limit the rows shown initially
    
    with gr.Row():
        download_csv_button = gr.File(label="Download CSV Results", interactive=False)  # visibility is set by the process_images outputs
        download_json_button = gr.File(label="Download JSON Results", interactive=False)

    gr.Markdown("## Visualizations")
    with gr.Row():
        plot_output_model_avg = gr.Image(label="Average Scores per Model", type="pil", interactive=False)
        plot_output_prompt_clip = gr.Image(label="Average CLIPScore per Prompt", type="pil", interactive=False)
        
    process_button.click(
        fn=process_images,
        inputs=[image_uploader],
        outputs=[
            results_table, 
            plot_output_model_avg, 
            plot_output_prompt_clip,
            download_csv_button,
            download_json_button,
            status_textbox
        ]
    )

    gr.Markdown(
        """
        **Metric Explanations:**
        - **ImageReward:** General aesthetic and prompt alignment score (higher is better). From THUDM.
        - **AnimeAesthetic_dg:** Aesthetic level for anime style (0-4, higher is better quality level: normal, slight, moderate, strong, extreme). From deepghs (ONNX).
        - **MANIQA_TQ:** Technical Quality score (no-reference), higher indicates better quality (less noise/artifacts). Based on MANIQA.
        - **CLIPScore:** Semantic similarity between the image and its prompt (0-100, higher is better). Uses LAION's CLIP.
        - **SDXL_Detector_AI_Prob:** Estimated probability that the image is AI-generated (higher means more likely AI). From Organika.
        - **AnimeAI_Check_dg_Prob:** Estimated probability that an anime-style image is AI-generated (higher means more likely AI). From deepghs (ONNX).
        
        *Processing can take time, especially for many images or on CPU.*
        """
    )

if __name__ == "__main__":
    demo.launch(debug=True)