VOIDER committed
Commit 713959a · verified · 1 Parent(s): 3bb6d30

Update app.py

Files changed (1)
  1. app.py +258 -153
app.py CHANGED
@@ -4,166 +4,169 @@ import io
  import os
  import pandas as pd
  import torch
- from transformers import pipeline as transformers_pipeline, AutoModelForImageClassification, CLIPImageProcessor  # Changed for ImageReward
- # from torchvision import transforms
- from torchmetrics.functional.multimodal import clip_score
- import open_clip  # Changed for open_clip
  import re
  import matplotlib.pyplot as plt
  import json
  from collections import defaultdict
  import numpy as np
  import logging

  # --- ONNX Related Imports and Setup ---
  try:
      import onnxruntime
  except ImportError:
-     print("onnxruntime not found. Please ensure it's in requirements.txt")
      onnxruntime = None

  from huggingface_hub import hf_hub_download

  try:
      from imgutils.data import rgb_encode
      IMGUTILS_AVAILABLE = True
-     print("imgutils.data.rgb_encode found and will be used.")
  except ImportError:
-     print("imgutils.data.rgb_encode not found. Using a basic fallback for preprocessing deepghs models.")
      IMGUTILS_AVAILABLE = False
-     def rgb_encode(image: Image.Image, order_='CHW'):  # Simple stub
-         img_arr = np.array(image.convert("RGB"))  # Make sure it is RGB
          if order_ == 'CHW':
              img_arr = np.transpose(img_arr, (2, 0, 1))
-         # This stub returns uint8 0-255, as expected downstream
          return img_arr.astype(np.uint8)

-
  # --- Model Configuration and Loading ---
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- print(f"Using device: {DEVICE}")
- ONNX_DEVICE = "CUDAExecutionProvider" if DEVICE == "cuda" and onnxruntime and "CUDAExecutionProvider" in onnxruntime.get_available_providers() else "CPUExecutionProvider"
- print(f"Using ONNX device: {ONNX_DEVICE}")

  # --- Helper for ONNX models (deepghs) ---
  @torch.no_grad()
  def _img_preprocess_for_onnx(image: Image.Image, size: tuple = (384, 384), normalize_mean=0.5, normalize_std=0.5):
      image = image.resize(size, Image.Resampling.BILINEAR)
-     data_uint8 = rgb_encode(image, order_='CHW')  # (C, H, W), uint8, 0-255
      data_float01 = data_uint8.astype(np.float32) / 255.0
-
      mean = np.array([normalize_mean] * 3, dtype=np.float32).reshape((3, 1, 1))
      std = np.array([normalize_std] * 3, dtype=np.float32).reshape((3, 1, 1))
-
      normalized_data = (data_float01 - mean) / std
      return normalized_data[None, ...].astype(np.float32)
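
For reference: with normalize_mean = normalize_std = 0.5, this helper maps pixels from [0, 255] to [-1, 1] and returns a (1, 3, H, W) float32 batch. A minimal self-contained sketch of the same transform (an illustration assuming only PIL and numpy, not code from this commit):

import numpy as np
from PIL import Image

def preprocess_sketch(path, size=(384, 384)):
    # Replicates (x / 255 - 0.5) / 0.5, i.e. [0, 255] -> [-1, 1], as (1, 3, H, W).
    img = Image.open(path).convert("RGB").resize(size, Image.Resampling.BILINEAR)
    chw = np.transpose(np.array(img), (2, 0, 1)).astype(np.float32) / 255.0
    return ((chw - 0.5) / 0.5)[None, ...].astype(np.float32)
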
  onnx_sessions_cache = {}
-
- def get_onnx_session_and_meta(repo_id, model_subfolder):
      cache_key = f"{repo_id}/{model_subfolder}"
      if cache_key in onnx_sessions_cache:
          return onnx_sessions_cache[cache_key]

      if not onnxruntime:
-         # raise ImportError("ONNX Runtime is not available.")  # Do not crash, just return None
-         print("ONNX Runtime is not available for get_onnx_session_and_meta")
-         onnx_sessions_cache[cache_key] = (None, [], None)
          return None, [], None

-
      try:
          model_path = hf_hub_download(repo_id, filename=f"{model_subfolder}/model.onnx")
          meta_path = hf_hub_download(repo_id, filename=f"{model_subfolder}/meta.json")

          options = onnxruntime.SessionOptions()
          options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-         if ONNX_DEVICE == "CPUExecutionProvider" and hasattr(os, 'cpu_count'):  # hasattr for safety
              options.intra_op_num_threads = os.cpu_count()

-         session = onnxruntime.InferenceSession(model_path, options, providers=[ONNX_DEVICE])
-
-         with open(meta_path, 'r') as f:
-             meta = json.load(f)
-
          labels = meta.get('labels', [])
          onnx_sessions_cache[cache_key] = (session, labels, meta)
          return session, labels, meta
      except Exception as e:
-         print(f"Error loading ONNX model {repo_id}/{model_subfolder}: {e}")
          onnx_sessions_cache[cache_key] = (None, [], None)
          return None, [], None

  # 1. ImageReward
  try:
-     # THUDM/ImageReward uses a CLIPImageProcessor
-     reward_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")  # Typical processor for such models
      reward_model = AutoModelForImageClassification.from_pretrained("THUDM/ImageReward").to(DEVICE)
      reward_model.eval()
-     print("THUDM/ImageReward loaded successfully.")
  except Exception as e:
-     print(f"Error loading THUDM/ImageReward: {e}")
-     reward_processor, reward_model = None, None

- # 2. Anime Aesthetic (deepghs ONNX)
  ANIME_AESTHETIC_REPO = "deepghs/anime_aesthetic"
  ANIME_AESTHETIC_SUBFOLDER = "swinv2pv3_v0_448_ls0.2_x"
  ANIME_AESTHETIC_IMG_SIZE = (448, 448)
  ANIME_AESTHETIC_LABEL_WEIGHTS = {"normal": 0.0, "slight": 1.0, "moderate": 2.0, "strong": 3.0, "extreme": 4.0}
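
These weights turn the classifier output into a single score: later code softmaxes the ONNX logits into probabilities over the labels normal..extreme and takes the expectation sum(p_i * w_i), giving a value between 0.0 and 4.0. A small worked example with hypothetical logits (not output from this commit):

import numpy as np

logits = np.array([2.0, 1.0, 0.5, 0.2, 0.1])     # hypothetical model output
p = np.exp(logits - logits.max()); p /= p.sum()  # numerically stable softmax
weights = [0.0, 1.0, 2.0, 3.0, 4.0]              # normal .. extreme
score = float(np.dot(p, weights))                # ~1.0 on the 0-4 scale here
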
  # 3. MANIQA (Technical Quality) - TEMPORARILY DISABLED
- maniqa_pipe = None
- print("MANIQA (honklers/maniqa-nr) is temporarily disabled due to loading issues. Will look for alternatives.")
- # try:
- #     maniqa_pipe = transformers_pipeline("image-classification", model="honklers/maniqa-nr", device=torch.device(DEVICE).index if DEVICE=="cuda" else -1)
- # except Exception as e:
- #     print(f"Error loading honklers/maniqa-nr: {e}")
- #     maniqa_pipe = None

  # 4. CLIP Score (laion/CLIP-ViT-L-14-laion2B-s32B-b82K) - open_clip
  try:
      clip_model_name = 'ViT-L-14'
-     # For open_clip, `pretrained` is usually a dataset name or combination
-     # `laion2b_s32b_b82k` is one of the weight sets for ViT-L-14
-     clip_model_instance, clip_preprocess_train, clip_preprocess_val = open_clip.create_model_and_transforms(
-         clip_model_name,
-         pretrained='laion2b_s32b_b82k',  # The correct pretrained tag for open_clip
-         device=DEVICE
      )
-     clip_preprocess = clip_preprocess_val  # Use the inference preprocess
      clip_tokenizer = open_clip.get_tokenizer(clip_model_name)
      clip_model_instance.eval()
-     print(f"CLIP model {clip_model_name} (laion2b_s32b_b82k) loaded successfully.")
  except Exception as e:
-     print(f"Error loading CLIP model {clip_model_name} (laion2b_s32b_b82k): {e}")
-     clip_model_instance, clip_preprocess, clip_tokenizer = None, None, None

  # 5. AI Detectors
  # Organika/sdxl-detector - Transformers pipeline
  try:
      sdxl_detector_pipe = transformers_pipeline("image-classification", model="Organika/sdxl-detector", device=torch.device(DEVICE).index if DEVICE=="cuda" else -1)
-     print("Organika/sdxl-detector loaded successfully.")
  except Exception as e:
-     print(f"Error loading Organika/sdxl-detector: {e}")
-     sdxl_detector_pipe = None

- # deepghs/anime_ai_check - ONNX
  ANIME_AI_CHECK_REPO = "deepghs/anime_ai_check"
  ANIME_AI_CHECK_SUBFOLDER = "caformer_s36_plus_sce"
  ANIME_AI_CHECK_IMG_SIZE = (384, 384)

- # --- Metadata extraction functions (unchanged) ---
- def extract_sd_parameters(image_pil):
      if image_pil is None: return "", "N/A", "N/A", "N/A", {}
      parameters_str = image_pil.info.get("parameters", "")
-     if not parameters_str: return "", "N/A", "N/A", "N/A", {}
      prompt, negative_prompt, model_name, model_hash, other_params_dict = "", "N/A", "N/A", "N/A", {}
      try:
          neg_prompt_index = parameters_str.find("Negative prompt:")
          steps_meta_index = parameters_str.find("Steps:")
          if neg_prompt_index != -1:
              prompt = parameters_str[:neg_prompt_index].strip()
-             params_part_start_index = steps_meta_index if steps_meta_index > neg_prompt_index else -1
              if params_part_start_index != -1:
                  negative_prompt = parameters_str[neg_prompt_index + len("Negative prompt:"):params_part_start_index].strip()
                  params_part = parameters_str[params_part_start_index:]
@@ -176,10 +179,10 @@ def extract_sd_parameters(image_pil):
              prompt = parameters_str[:steps_meta_index].strip()
              params_part = parameters_str[steps_meta_index:]
          else:
-             prompt = parameters_str.strip()
-             params_part = ""

-         if params_part:
              params_list = [p.strip() for p in params_part.split(",")]
              temp_other_params = {}
              for param_val_str in params_list:
@@ -187,165 +190,220 @@ def extract_sd_parameters(image_pil):
                  if len(parts) == 2:
                      key, value = parts[0].strip(), parts[1].strip()
                      temp_other_params[key] = value
-                     if key == "Model": model_name = value
-                     elif key == "Model hash": model_hash = value
              for k,v in temp_other_params.items():
-                 if k not in ["Model", "Model hash"]: other_params_dict[k] = v

          if model_name == "N/A" and model_hash != "N/A": model_name = f"hash_{model_hash}"
-         # Fallback for model name if only Checkpoint is present (e.g. from ComfyUI)
          if model_name == "N/A" and "Checkpoint" in other_params_dict: model_name = other_params_dict["Checkpoint"]
-         if model_name == "N/A" and "model" in other_params_dict: model_name = other_params_dict["model"]  # Another common key
-
      except Exception as e:
-         print(f"Error parsing metadata: {e}")
      return prompt, negative_prompt, model_name, model_hash, other_params_dict

- # --- Scoring functions ---
  @torch.no_grad()
- def get_image_reward(image_pil):
-     if not reward_model or not reward_processor: return "N/A"
      try:
-         # ImageReward expects specific preprocessing, often CLIP-like
          inputs = reward_processor(images=image_pil, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
          outputs = reward_model(**inputs)
-         return round(outputs.logits.item(), 4)
      except Exception as e:
-         print(f"Error in ImageReward: {e}")
          return "Error"

- def get_anime_aesthetic_score_deepghs(image_pil):
-     session, labels, meta = get_onnx_session_and_meta(ANIME_AESTHETIC_REPO, ANIME_AESTHETIC_SUBFOLDER)
-     if not session or not labels: return "N/A"
      try:
          input_data = _img_preprocess_for_onnx(image_pil.copy(), size=ANIME_AESTHETIC_IMG_SIZE)
          input_name = session.get_inputs()[0].name
          output_name = session.get_outputs()[0].name
          onnx_output, = session.run([output_name], {input_name: input_data})
          scores = onnx_output[0]
-         exp_scores = np.exp(scores - np.max(scores))
-         probabilities = exp_scores / np.sum(exp_scores)
          weighted_score = sum(probabilities[i] * ANIME_AESTHETIC_LABEL_WEIGHTS.get(label, 0.0) for i, label in enumerate(labels))
-         return round(weighted_score, 4)
      except Exception as e:
-         print(f"Error in Anime Aesthetic (ONNX): {e}")
          return "Error"

  @torch.no_grad()
- def get_maniqa_score(image_pil):  # Temporarily returns N/A
-     # if not maniqa_pipe: return "N/A"
-     # try:
-     #     result = maniqa_pipe(image_pil.copy())
-     #     score = 0.0
-     #     for item in result:
-     #         if item['label'].lower() == 'good quality': score = item['score']; break
-     #     return round(score, 4)
-     # except Exception as e:
-     #     print(f"Error in MANIQA: {e}")
-     #     return "Error"
      return "N/A (Disabled)"

-
  @torch.no_grad()
- def calculate_clip_score_value(image_pil, prompt_text):
-     if not clip_model_instance or not clip_preprocess or not clip_tokenizer or not prompt_text or prompt_text == "N/A":
          return "N/A"
      try:
          image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
-         # Make sure prompt_text is a string, not None or anything else
-         text_for_tokenizer = str(prompt_text) if prompt_text else ""
-         if not text_for_tokenizer: return "N/A (Empty Prompt)"
-
          text_input = clip_tokenizer([text_for_tokenizer]).to(DEVICE)
-
          image_features = clip_model_instance.encode_image(image_input)
          text_features = clip_model_instance.encode_text(text_input)
          image_features_norm = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
          text_features_norm = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
-         score = (text_features_norm @ image_features_norm.T).squeeze().item() * 100.0
-         return round(score, 2)
      except Exception as e:
-         print(f"Error in CLIP Score: {e}")
          return "Error"

  @torch.no_grad()
- def get_sdxl_detection_score(image_pil):
-     if not sdxl_detector_pipe: return "N/A"
      try:
          result = sdxl_detector_pipe(image_pil.copy())
-         ai_score = 0.0
          for item in result:
-             if item['label'].lower() == 'artificial': ai_score = item['score']; break
-         return round(ai_score, 4)
      except Exception as e:
-         print(f"Error in SDXL Detector: {e}")
          return "Error"

- def get_anime_ai_check_score_deepghs(image_pil):
-     session, labels, meta = get_onnx_session_and_meta(ANIME_AI_CHECK_REPO, ANIME_AI_CHECK_SUBFOLDER)
-     if not session or not labels: return "N/A"
      try:
          input_data = _img_preprocess_for_onnx(image_pil.copy(), size=ANIME_AI_CHECK_IMG_SIZE)
          input_name = session.get_inputs()[0].name
          output_name = session.get_outputs()[0].name
          onnx_output, = session.run([output_name], {input_name: input_data})
          scores = onnx_output[0]
-         exp_scores = np.exp(scores - np.max(scores))
-         probabilities = exp_scores / np.sum(exp_scores)
-         ai_prob = 0.0
          for i, label in enumerate(labels):
-             if label.lower() == 'ai': ai_prob = probabilities[i]; break
-         return round(ai_prob, 4)
      except Exception as e:
-         print(f"Error in Anime AI Check (ONNX): {e}")
          return "Error"

- # --- Main processing function ---
- def process_images(files, progress=gr.Progress(track_tqdm=True)):
      if not files:
-         return pd.DataFrame(), None, None, None, None, "Please upload some images."

      all_results = []
      for i, file_obj in enumerate(files):
-         filename = "Unknown File"
          try:
-             # file_obj.name may be an absolute path on the server
-             filename = os.path.basename(getattr(file_obj, 'name', f"file_{i}"))
              img = Image.open(getattr(file_obj, 'name', str(file_obj)))
              if img.mode != "RGB": img = img.convert("RGB")

-             prompt, neg_prompt, model_n, model_h, other_p = extract_sd_parameters(img)

-             reward = get_image_reward(img)
-             anime_aes_deepghs = get_anime_aesthetic_score_deepghs(img)
-             maniqa = get_maniqa_score(img)  # Will be N/A (Disabled)
-             clip_val = calculate_clip_score_value(img, prompt)
-             sdxl_detect = get_sdxl_detection_score(img)
-             anime_ai_chk_deepghs = get_anime_ai_check_score_deepghs(img)

              all_results.append({
-                 "Filename": filename, "Prompt": prompt if prompt else "N/A", "Model Name": model_n, "Model Hash": model_h,
                  "ImageReward": reward, "AnimeAesthetic_dg": anime_aes_deepghs, "MANIQA_TQ": maniqa,
                  "CLIPScore": clip_val, "SDXL_Detector_AI_Prob": sdxl_detect, "AnimeAI_Check_dg_Prob": anime_ai_chk_deepghs,
              })
          except Exception as e:
-             print(f"CRITICAL: Failed to process {filename}: {e}")
              all_results.append({
-                 "Filename": filename, "Prompt": "Error", "Model Name": "Error", "Model Hash": "Error",
                  "ImageReward": "Error", "AnimeAesthetic_dg": "Error", "MANIQA_TQ": "Error",
                  "CLIPScore": "Error", "SDXL_Detector_AI_Prob": "Error", "AnimeAI_Check_dg_Prob": "Error"
              })

      df = pd.DataFrame(all_results)
      plot_model_avg_scores_buffer, plot_prompt_clip_scores_buffer = None, None
      csv_buffer_val, json_buffer_val = "", ""

      if not df.empty:
-         numeric_cols = ["ImageReward", "AnimeAesthetic_dg", "MANIQA_TQ", "CLIPScore"]
          for col in numeric_cols: df[col] = pd.to_numeric(df[col], errors='coerce')

-         # Plot 1
          df_model_plot = df[(df["Model Name"] != "N/A") & (df["Model Name"].notna())]
          if not df_model_plot.empty and df_model_plot["Model Name"].nunique() > 0:
              try:
@@ -355,57 +413,104 @@ def process_images(files, progress=gr.Progress(track_tqdm=True)):
                  ax1.set_title("Average Scores per Model"); ax1.set_ylabel("Average Score")
                  ax1.tick_params(axis='x', rotation=45, labelsize=8); plt.tight_layout()
                  plot_model_avg_scores_buffer = io.BytesIO(); fig1.savefig(plot_model_avg_scores_buffer, format="png"); plot_model_avg_scores_buffer.seek(0); plt.close(fig1)
-             except Exception as e: print(f"Error generating model average scores plot: {e}")

-         # Plot 2
          df_prompt_plot = df[(df["Prompt"] != "N/A") & (df["Prompt"].notna()) & (df["CLIPScore"].notna())]
          if not df_prompt_plot.empty and df_prompt_plot["Prompt"].nunique() > 0 :
              try:
                  df_prompt_plot["Short Prompt"] = df_prompt_plot["Prompt"].apply(lambda x: (str(x)[:30] + '...') if len(str(x)) > 33 else str(x))
                  prompt_clip_scores = df_prompt_plot.groupby("Short Prompt")["CLIPScore"].mean().sort_values(ascending=False)
-                 if not prompt_clip_scores.empty and len(prompt_clip_scores) >= 1 :  # Changed to >=1 for single prompts
                      fig2, ax2 = plt.subplots(figsize=(12, max(7, min(len(prompt_clip_scores)*0.5, 15))))
                      prompt_clip_scores.head(20).plot(kind="barh", ax=ax2)
                      ax2.set_title("Average CLIPScore per Prompt (Top 20 unique prompts)"); ax2.set_xlabel("Average CLIPScore")
                      plt.tight_layout(); plot_prompt_clip_scores_buffer = io.BytesIO(); fig2.savefig(plot_prompt_clip_scores_buffer, format="png"); plot_prompt_clip_scores_buffer.seek(0); plt.close(fig2)
-             except Exception as e: print(f"Error generating prompt CLIP scores plot: {e}")

          csv_b = io.StringIO(); df.to_csv(csv_b, index=False); csv_buffer_val = csv_b.getvalue()
          json_b = io.StringIO(); df.to_json(json_b, orient='records', indent=4); json_buffer_val = json_b.getvalue()

-     return (
          df,
          gr.Image(value=plot_model_avg_scores_buffer, type="pil", visible=plot_model_avg_scores_buffer is not None),
          gr.Image(value=plot_prompt_clip_scores_buffer, type="pil", visible=plot_prompt_clip_scores_buffer is not None),
          gr.File(value=csv_buffer_val or None, label="Download CSV Results", visible=bool(csv_buffer_val), file_name="evaluation_results.csv"),
          gr.File(value=json_buffer_val or None, label="Download JSON Results", visible=bool(json_buffer_val), file_name="evaluation_results.json"),
-         f"Processed {len(all_results)} images.",
      )

  # --- Gradio Interface ---
  with gr.Blocks(css="footer {display: none !important}") as demo:
      gr.Markdown("# AI Image Model Evaluation Tool")
      gr.Markdown("Upload PNG images (ideally with Stable Diffusion metadata) to evaluate them...")
-     with gr.Row(): image_uploader = gr.Files(label="Upload Images (PNG)", file_count="multiple", file_types=["image"])
      process_button = gr.Button("Evaluate Images", variant="primary")
-     status_textbox = gr.Textbox(label="Status", interactive=False)
      gr.Markdown("## Evaluation Results Table")
-     results_table = gr.DataFrame(headers=[  # max_rows removed
          "Filename", "Prompt", "Model Name", "Model Hash", "ImageReward", "AnimeAesthetic_dg",
          "MANIQA_TQ", "CLIPScore", "SDXL_Detector_AI_Prob", "AnimeAI_Check_dg_Prob"
      ], wrap=True)
      with gr.Row():
          download_csv_button = gr.File(label="Download CSV Results", interactive=False)
          download_json_button = gr.File(label="Download JSON Results", interactive=False)
      gr.Markdown("## Visualizations")
      with gr.Row():
          plot_output_model_avg = gr.Image(label="Average Scores per Model", type="pil", interactive=False)
          plot_output_prompt_clip = gr.Image(label="Average CLIPScore per Prompt", type="pil", interactive=False)
-     process_button.click(fn=process_images, inputs=[image_uploader], outputs=[
-         results_table, plot_output_model_avg, plot_output_prompt_clip,
-         download_csv_button, download_json_button, status_textbox
-     ])
      gr.Markdown("""**Metric Explanations:** ... (unchanged)""")

  if __name__ == "__main__":
-     demo.launch(debug=True)

  import os
  import pandas as pd
  import torch
+ from transformers import pipeline as transformers_pipeline, AutoModelForImageClassification, CLIPImageProcessor
+ import open_clip
  import re
  import matplotlib.pyplot as plt
  import json
  from collections import defaultdict
  import numpy as np
  import logging
+ import time  # For timing measurements

  # --- ONNX Related Imports and Setup ---
  try:
      import onnxruntime
  except ImportError:
+     print("WARNING: onnxruntime not found. ONNX models will not be available.")
      onnxruntime = None

  from huggingface_hub import hf_hub_download

+ # imgutils for rgb_encode
  try:
      from imgutils.data import rgb_encode
      IMGUTILS_AVAILABLE = True
+     print("INFO: imgutils.data.rgb_encode found and will be used for deepghs models.")
  except ImportError:
+     print("WARNING: imgutils.data.rgb_encode not found. Using a basic fallback for preprocessing deepghs models.")
      IMGUTILS_AVAILABLE = False
+     def rgb_encode(image: Image.Image, order_='CHW'):
+         img_arr = np.array(image.convert("RGB"))
          if order_ == 'CHW':
              img_arr = np.transpose(img_arr, (2, 0, 1))
          return img_arr.astype(np.uint8)

  # --- Model Configuration and Loading ---
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"INFO: PyTorch Device: {DEVICE}")
+ ONNX_EXECUTION_PROVIDER = "CUDAExecutionProvider" if DEVICE == "cuda" and onnxruntime and "CUDAExecutionProvider" in onnxruntime.get_available_providers() else "CPUExecutionProvider"
+ if onnxruntime:
+     print(f"INFO: ONNX Execution Provider: {ONNX_EXECUTION_PROVIDER}")
+ else:
+     print("INFO: ONNX Runtime not available, ONNX models will be skipped.")
+

  # --- Helper for ONNX models (deepghs) ---
  @torch.no_grad()
  def _img_preprocess_for_onnx(image: Image.Image, size: tuple = (384, 384), normalize_mean=0.5, normalize_std=0.5):
      image = image.resize(size, Image.Resampling.BILINEAR)
+     data_uint8 = rgb_encode(image, order_='CHW')
      data_float01 = data_uint8.astype(np.float32) / 255.0
      mean = np.array([normalize_mean] * 3, dtype=np.float32).reshape((3, 1, 1))
      std = np.array([normalize_std] * 3, dtype=np.float32).reshape((3, 1, 1))
      normalized_data = (data_float01 - mean) / std
      return normalized_data[None, ...].astype(np.float32)

  onnx_sessions_cache = {}
+ def get_onnx_session_and_meta(repo_id, model_subfolder, current_log_list):
      cache_key = f"{repo_id}/{model_subfolder}"
      if cache_key in onnx_sessions_cache:
          return onnx_sessions_cache[cache_key]

      if not onnxruntime:
+         msg = f"ERROR: ONNX Runtime not available for get_onnx_session_and_meta ({cache_key}). Skipping."
+         print(msg)
+         current_log_list.append(msg)
+         onnx_sessions_cache[cache_key] = (None, [], None)  # Cache the error state
          return None, [], None

      try:
+         msg = f"INFO: Loading ONNX model {repo_id}/{model_subfolder}..."
+         print(msg); current_log_list.append(msg)
          model_path = hf_hub_download(repo_id, filename=f"{model_subfolder}/model.onnx")
          meta_path = hf_hub_download(repo_id, filename=f"{model_subfolder}/meta.json")

          options = onnxruntime.SessionOptions()
          options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+         if ONNX_EXECUTION_PROVIDER == "CPUExecutionProvider" and hasattr(os, 'cpu_count'):
              options.intra_op_num_threads = os.cpu_count()

+         session = onnxruntime.InferenceSession(model_path, options, providers=[ONNX_EXECUTION_PROVIDER])
+         with open(meta_path, 'r') as f: meta = json.load(f)
          labels = meta.get('labels', [])
+
+         msg = f"INFO: ONNX model {cache_key} loaded successfully with provider {ONNX_EXECUTION_PROVIDER}."
+         print(msg); current_log_list.append(msg)
          onnx_sessions_cache[cache_key] = (session, labels, meta)
          return session, labels, meta
      except Exception as e:
+         msg = f"ERROR: Failed to load ONNX model {cache_key}: {e}"
+         print(msg); current_log_list.append(msg)
          onnx_sessions_cache[cache_key] = (None, [], None)
          return None, [], None
  return None, [], None
98
 
99
+ # --- Модели PyTorch и Transformers ---
100
  # 1. ImageReward
101
+ reward_processor, reward_model = None, None
102
  try:
103
+ print("INFO: Loading THUDM/ImageReward model...")
104
+ reward_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
105
  reward_model = AutoModelForImageClassification.from_pretrained("THUDM/ImageReward").to(DEVICE)
106
  reward_model.eval()
107
+ print("INFO: THUDM/ImageReward loaded successfully.")
108
  except Exception as e:
109
+ print(f"ERROR: Failed to load THUDM/ImageReward: {e}")
 
110
 
111
+ # 2. Anime Aesthetic (deepghs ONNX) - Константы
112
  ANIME_AESTHETIC_REPO = "deepghs/anime_aesthetic"
113
  ANIME_AESTHETIC_SUBFOLDER = "swinv2pv3_v0_448_ls0.2_x"
114
  ANIME_AESTHETIC_IMG_SIZE = (448, 448)
115
  ANIME_AESTHETIC_LABEL_WEIGHTS = {"normal": 0.0, "slight": 1.0, "moderate": 2.0, "strong": 3.0, "extreme": 4.0}
116
 
117
  # 3. MANIQA (Technical Quality) - ВРЕМЕННО ОТКЛЮЧЕНО
118
+ # maniqa_pipe = None (уже объявлено в глобальной области видимости неявно)
119
+ print("INFO: MANIQA (honklers/maniqa-nr) is currently disabled.")
 
 
 
 
 
120
 
121
  # 4. CLIP Score (laion/CLIP-ViT-L-14-laion2B-s32B-b82K) - open_clip
122
+ clip_model_instance, clip_preprocess, clip_tokenizer = None, None, None
123
  try:
124
  clip_model_name = 'ViT-L-14'
125
+ print(f"INFO: Loading CLIP model {clip_model_name} (laion2b_s32b_b82k)...")
126
+ clip_model_instance, _, clip_preprocess_val = open_clip.create_model_and_transforms(
127
+ clip_model_name, pretrained='laion2b_s32b_b82k', device=DEVICE
 
 
 
128
  )
129
+ clip_preprocess = clip_preprocess_val
130
  clip_tokenizer = open_clip.get_tokenizer(clip_model_name)
131
  clip_model_instance.eval()
132
+ print(f"INFO: CLIP model {clip_model_name} (laion2b_s32b_b82k) loaded successfully.")
133
  except Exception as e:
134
+ print(f"ERROR: Failed to load CLIP model {clip_model_name} (laion2b_s32b_b82k): {e}")
 
135
 
136
  # 5. AI Detectors
137
  # Organika/sdxl-detector - Transformers pipeline
138
+ sdxl_detector_pipe = None
139
  try:
140
+ print("INFO: Loading Organika/sdxl-detector model...")
141
  sdxl_detector_pipe = transformers_pipeline("image-classification", model="Organika/sdxl-detector", device=torch.device(DEVICE).index if DEVICE=="cuda" else -1)
142
+ print("INFO: Organika/sdxl-detector loaded successfully.")
143
  except Exception as e:
144
+ print(f"ERROR: Failed to load Organika/sdxl-detector: {e}")
 
145
 
146
+ # deepghs/anime_ai_check - ONNX - Константы
147
  ANIME_AI_CHECK_REPO = "deepghs/anime_ai_check"
148
  ANIME_AI_CHECK_SUBFOLDER = "caformer_s36_plus_sce"
149
  ANIME_AI_CHECK_IMG_SIZE = (384, 384)
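
All of the loaders above follow the same defensive pattern: the module-level handle is initialized to None, the load is attempted inside try/except, and the scoring functions below check for None and return "N/A" instead of raising, so one missing model does not take the whole app down.
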
+ # --- Metadata extraction (logic unchanged, logging added) ---
+ def extract_sd_parameters(image_pil, filename_for_log, current_log_list):
+     # ... (the rest of extract_sd_parameters is unchanged)
      if image_pil is None: return "", "N/A", "N/A", "N/A", {}
      parameters_str = image_pil.info.get("parameters", "")
+     if not parameters_str:
+         current_log_list.append(f"DEBUG [{filename_for_log}]: No metadata found in image.")
+         return "", "N/A", "N/A", "N/A", {}
+
+     current_log_list.append(f"DEBUG [{filename_for_log}]: Raw metadata: {parameters_str[:100]}...")  # Log the beginning
      prompt, negative_prompt, model_name, model_hash, other_params_dict = "", "N/A", "N/A", "N/A", {}
+     # ... (the parsing continues as before)
      try:
          neg_prompt_index = parameters_str.find("Negative prompt:")
          steps_meta_index = parameters_str.find("Steps:")
          if neg_prompt_index != -1:
              prompt = parameters_str[:neg_prompt_index].strip()
+             params_part_start_index = steps_meta_index if steps_meta_index != -1 and steps_meta_index > neg_prompt_index else -1
              if params_part_start_index != -1:
                  negative_prompt = parameters_str[neg_prompt_index + len("Negative prompt:"):params_part_start_index].strip()
                  params_part = parameters_str[params_part_start_index:]

              prompt = parameters_str[:steps_meta_index].strip()
              params_part = parameters_str[steps_meta_index:]
          else:
+             prompt = parameters_str.strip()  # The whole text is the prompt
+             params_part = ""  # No parameters block

+         if params_part:  # If there is a parameters block after Steps:
              params_list = [p.strip() for p in params_part.split(",")]
              temp_other_params = {}
              for param_val_str in params_list:

                  if len(parts) == 2:
                      key, value = parts[0].strip(), parts[1].strip()
                      temp_other_params[key] = value
+                     if key.lower() == "model": model_name = value
+                     elif key.lower() == "model hash": model_hash = value
              for k,v in temp_other_params.items():
+                 if k.lower() not in ["model", "model hash"]: other_params_dict[k] = v

          if model_name == "N/A" and model_hash != "N/A": model_name = f"hash_{model_hash}"
          if model_name == "N/A" and "Checkpoint" in other_params_dict: model_name = other_params_dict["Checkpoint"]
+         if model_name == "N/A" and "model" in other_params_dict: model_name = other_params_dict["model"]
+         current_log_list.append(f"DEBUG [{filename_for_log}]: Parsed Prompt: {prompt[:50]}... | Model: {model_name}")
      except Exception as e:
+         current_log_list.append(f"ERROR [{filename_for_log}]: Failed to parse metadata: {e}")
      return prompt, negative_prompt, model_name, model_hash, other_params_dict
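
To make the parsing concrete, here is a hypothetical A1111-style "parameters" string and roughly what the function should extract from it (illustrative only, not data from this commit):

# Hypothetical PNG-info payload:
#   masterpiece, 1girl, looking at viewer
#   Negative prompt: lowres, bad anatomy
#   Steps: 28, Sampler: Euler a, CFG scale: 7, Seed: 12345, Model hash: abc123def, Model: myModel_v2
# Expected result, roughly:
#   prompt          -> "masterpiece, 1girl, looking at viewer"
#   negative_prompt -> "lowres, bad anatomy"
#   model_name      -> "myModel_v2", model_hash -> "abc123def"
#   other params    -> {"Steps": "28", "Sampler": "Euler a", "CFG scale": "7", "Seed": "12345"}
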
+ # --- Scoring functions (logging and timing added) ---
  @torch.no_grad()
+ def get_image_reward(image_pil, filename_for_log, current_log_list):
+     if not reward_model or not reward_processor:
+         current_log_list.append(f"INFO [{filename_for_log}]: ImageReward model not loaded, skipping.")
+         return "N/A"
+     t_start = time.time()
+     current_log_list.append(f"DEBUG [{filename_for_log}]: Starting ImageReward score (PyTorch Device: {DEVICE})...")
      try:
          inputs = reward_processor(images=image_pil, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
          outputs = reward_model(**inputs)
+         score = round(outputs.logits.item(), 4)
+         t_end = time.time()
+         current_log_list.append(f"DEBUG [{filename_for_log}]: ImageReward score: {score} (took {t_end - t_start:.2f}s)")
+         return score
      except Exception as e:
+         current_log_list.append(f"ERROR [{filename_for_log}]: ImageReward scoring failed: {e}")
          return "Error"

+ def get_anime_aesthetic_score_deepghs(image_pil, filename_for_log, current_log_list):
+     session, labels, meta = get_onnx_session_and_meta(ANIME_AESTHETIC_REPO, ANIME_AESTHETIC_SUBFOLDER, current_log_list)
+     if not session or not labels:
+         current_log_list.append(f"INFO [{filename_for_log}]: AnimeAesthetic ONNX model not loaded, skipping.")
+         return "N/A"
+     t_start = time.time()
+     current_log_list.append(f"DEBUG [{filename_for_log}]: Starting AnimeAesthetic (ONNX) score...")
      try:
          input_data = _img_preprocess_for_onnx(image_pil.copy(), size=ANIME_AESTHETIC_IMG_SIZE)
          input_name = session.get_inputs()[0].name
          output_name = session.get_outputs()[0].name
          onnx_output, = session.run([output_name], {input_name: input_data})
          scores = onnx_output[0]
+         exp_scores = np.exp(scores - np.max(scores)); probabilities = exp_scores / np.sum(exp_scores)
          weighted_score = sum(probabilities[i] * ANIME_AESTHETIC_LABEL_WEIGHTS.get(label, 0.0) for i, label in enumerate(labels))
+         score = round(weighted_score, 4)
+         t_end = time.time()
+         current_log_list.append(f"DEBUG [{filename_for_log}]: AnimeAesthetic (ONNX) score: {score} (took {t_end - t_start:.2f}s)")
+         return score
      except Exception as e:
+         current_log_list.append(f"ERROR [{filename_for_log}]: AnimeAesthetic (ONNX) scoring failed: {e}")
          return "Error"

  @torch.no_grad()
+ def get_maniqa_score(image_pil, filename_for_log, current_log_list):
+     current_log_list.append(f"INFO [{filename_for_log}]: MANIQA is disabled.")
      return "N/A (Disabled)"

  @torch.no_grad()
+ def calculate_clip_score_value(image_pil, prompt_text, filename_for_log, current_log_list):
+     if not clip_model_instance or not clip_preprocess or not clip_tokenizer:
+         current_log_list.append(f"INFO [{filename_for_log}]: CLIP model not loaded, skipping CLIPScore.")
          return "N/A"
+     if not prompt_text or prompt_text == "N/A":
+         current_log_list.append(f"INFO [{filename_for_log}]: Empty prompt, skipping CLIPScore.")
+         return "N/A (Empty Prompt)"
+
+     t_start = time.time()
+     current_log_list.append(f"DEBUG [{filename_for_log}]: Starting CLIPScore (PyTorch Device: {DEVICE})...")
      try:
          image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
+         text_for_tokenizer = str(prompt_text)
          text_input = clip_tokenizer([text_for_tokenizer]).to(DEVICE)
          image_features = clip_model_instance.encode_image(image_input)
          text_features = clip_model_instance.encode_text(text_input)
          image_features_norm = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
          text_features_norm = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
+         score_val = (text_features_norm @ image_features_norm.T).squeeze().item() * 100.0
+         score = round(score_val, 2)
+         t_end = time.time()
+         current_log_list.append(f"DEBUG [{filename_for_log}]: CLIPScore: {score} (took {t_end - t_start:.2f}s)")
+         return score
      except Exception as e:
+         current_log_list.append(f"ERROR [{filename_for_log}]: CLIPScore calculation failed: {e}")
          return "Error"
  @torch.no_grad()
+ def get_sdxl_detection_score(image_pil, filename_for_log, current_log_list):
+     if not sdxl_detector_pipe:
+         current_log_list.append(f"INFO [{filename_for_log}]: SDXL_Detector model not loaded, skipping.")
+         return "N/A"
+     t_start = time.time()
+     current_log_list.append(f"DEBUG [{filename_for_log}]: Starting SDXL_Detector score (Device for pipeline: {sdxl_detector_pipe.device})...")
      try:
          result = sdxl_detector_pipe(image_pil.copy())
+         ai_score_val = 0.0
          for item in result:
+             if item['label'].lower() == 'artificial': ai_score_val = item['score']; break
+         score = round(ai_score_val, 4)
+         t_end = time.time()
+         current_log_list.append(f"DEBUG [{filename_for_log}]: SDXL_Detector AI Prob: {score} (took {t_end - t_start:.2f}s)")
+         return score
      except Exception as e:
+         current_log_list.append(f"ERROR [{filename_for_log}]: SDXL_Detector scoring failed: {e}")
          return "Error"

+ def get_anime_ai_check_score_deepghs(image_pil, filename_for_log, current_log_list):
+     session, labels, meta = get_onnx_session_and_meta(ANIME_AI_CHECK_REPO, ANIME_AI_CHECK_SUBFOLDER, current_log_list)
+     if not session or not labels:
+         current_log_list.append(f"INFO [{filename_for_log}]: AnimeAI_Check ONNX model not loaded, skipping.")
+         return "N/A"
+     t_start = time.time()
+     current_log_list.append(f"DEBUG [{filename_for_log}]: Starting AnimeAI_Check (ONNX) score...")
      try:
          input_data = _img_preprocess_for_onnx(image_pil.copy(), size=ANIME_AI_CHECK_IMG_SIZE)
          input_name = session.get_inputs()[0].name
          output_name = session.get_outputs()[0].name
          onnx_output, = session.run([output_name], {input_name: input_data})
          scores = onnx_output[0]
+         exp_scores = np.exp(scores - np.max(scores)); probabilities = exp_scores / np.sum(exp_scores)
+         ai_prob_val = 0.0
          for i, label in enumerate(labels):
+             if label.lower() == 'ai': ai_prob_val = probabilities[i]; break
+         score = round(ai_prob_val, 4)
+         t_end = time.time()
+         current_log_list.append(f"DEBUG [{filename_for_log}]: AnimeAI_Check (ONNX) AI Prob: {score} (took {t_end - t_start:.2f}s)")
+         return score
      except Exception as e:
+         current_log_list.append(f"ERROR [{filename_for_log}]: AnimeAI_Check (ONNX) scoring failed: {e}")
          return "Error"

+ # --- Main processing function (now a generator) ---
+ def process_images_generator(files, progress=gr.Progress(track_tqdm=True)):
      if not files:
+         yield pd.DataFrame(), None, None, None, None, "Please upload some images.", "No files to process."
+         return

      all_results = []
+     log_accumulator = [f"INFO: Starting processing for {len(files)} images..."]
+     yield pd.DataFrame(), None, None, None, None, "Processing...", "\n".join(log_accumulator)
+
      for i, file_obj in enumerate(files):
+         filename_for_log = "Unknown File"
+         current_img_total_time_start = time.time()
          try:
+             filename_for_log = os.path.basename(getattr(file_obj, 'name', f"file_{i}_{time.time()}"))
+             log_accumulator.append(f"--- Processing image {i+1}/{len(files)}: {filename_for_log} ---")
+
+             # Update the UI before starting on this file
+             progress((i + 1) / len(files), desc=f"Img {i+1}/{len(files)}: {filename_for_log}")
+             yield (pd.DataFrame(all_results), None, None, None, None,
+                    f"Processing image {i+1}/{len(files)}: {filename_for_log}",
+                    "\n".join(log_accumulator))
+
              img = Image.open(getattr(file_obj, 'name', str(file_obj)))
              if img.mode != "RGB": img = img.convert("RGB")

+             prompt, neg_prompt, model_n, model_h, other_p = extract_sd_parameters(img, filename_for_log, log_accumulator)
+
+             reward = get_image_reward(img, filename_for_log, log_accumulator)
+             anime_aes_deepghs = get_anime_aesthetic_score_deepghs(img, filename_for_log, log_accumulator)
+             maniqa = get_maniqa_score(img, filename_for_log, log_accumulator)
+             clip_val = calculate_clip_score_value(img, prompt, filename_for_log, log_accumulator)
+             sdxl_detect = get_sdxl_detection_score(img, filename_for_log, log_accumulator)
+             anime_ai_chk_deepghs = get_anime_ai_check_score_deepghs(img, filename_for_log, log_accumulator)
+
+             current_img_total_time_end = time.time()
+             log_accumulator.append(f"INFO [{filename_for_log}]: Finished all scores (total for image: {current_img_total_time_end - current_img_total_time_start:.2f}s)")

              all_results.append({
+                 "Filename": filename_for_log, "Prompt": prompt if prompt else "N/A", "Model Name": model_n, "Model Hash": model_h,
                  "ImageReward": reward, "AnimeAesthetic_dg": anime_aes_deepghs, "MANIQA_TQ": maniqa,
                  "CLIPScore": clip_val, "SDXL_Detector_AI_Prob": sdxl_detect, "AnimeAI_Check_dg_Prob": anime_ai_chk_deepghs,
              })
+
+             # Update the UI with current results after each file.
+             # Plots and download files are only generated at the end,
+             # but the partial df can be passed along to refresh the table.
+             df_so_far = pd.DataFrame(all_results)
+             yield (df_so_far, None, None, None, None,  # no plots or files yet
+                    f"Processed image {i+1}/{len(files)}: {filename_for_log}",
+                    "\n".join(log_accumulator))
+
          except Exception as e:
+             log_accumulator.append(f"CRITICAL ERROR processing {filename_for_log}: {e}")
+             print(f"CRITICAL ERROR processing {filename_for_log}: {e}")
              all_results.append({
+                 "Filename": filename_for_log, "Prompt": "Critical Error", "Model Name": "Error", "Model Hash": "Error",
                  "ImageReward": "Error", "AnimeAesthetic_dg": "Error", "MANIQA_TQ": "Error",
                  "CLIPScore": "Error", "SDXL_Detector_AI_Prob": "Error", "AnimeAI_Check_dg_Prob": "Error"
              })
+             df_so_far = pd.DataFrame(all_results)
+             yield (df_so_far, None, None, None, None,
+                    f"Error on image {i+1}/{len(files)}: {filename_for_log}",
+                    "\n".join(log_accumulator))
+
+     log_accumulator.append("--- Generating final plots and download files ---")
+     yield (pd.DataFrame(all_results), None, None, None, None,
+            "Generating final plots...",
+            "\n".join(log_accumulator))

      df = pd.DataFrame(all_results)
      plot_model_avg_scores_buffer, plot_prompt_clip_scores_buffer = None, None
      csv_buffer_val, json_buffer_val = "", ""

      if not df.empty:
+         numeric_cols = ["ImageReward", "AnimeAesthetic_dg", "MANIQA_TQ", "CLIPScore"]  # MANIQA_TQ will be NaN, which is fine
          for col in numeric_cols: df[col] = pd.to_numeric(df[col], errors='coerce')

          df_model_plot = df[(df["Model Name"] != "N/A") & (df["Model Name"].notna())]
          if not df_model_plot.empty and df_model_plot["Model Name"].nunique() > 0:
              try:

                  ax1.set_title("Average Scores per Model"); ax1.set_ylabel("Average Score")
                  ax1.tick_params(axis='x', rotation=45, labelsize=8); plt.tight_layout()
                  plot_model_avg_scores_buffer = io.BytesIO(); fig1.savefig(plot_model_avg_scores_buffer, format="png"); plot_model_avg_scores_buffer.seek(0); plt.close(fig1)
+                 log_accumulator.append("INFO: Model average scores plot generated.")
+             except Exception as e: log_accumulator.append(f"ERROR: Failed to generate model average scores plot: {e}")

          df_prompt_plot = df[(df["Prompt"] != "N/A") & (df["Prompt"].notna()) & (df["CLIPScore"].notna())]
          if not df_prompt_plot.empty and df_prompt_plot["Prompt"].nunique() > 0 :
              try:
                  df_prompt_plot["Short Prompt"] = df_prompt_plot["Prompt"].apply(lambda x: (str(x)[:30] + '...') if len(str(x)) > 33 else str(x))
                  prompt_clip_scores = df_prompt_plot.groupby("Short Prompt")["CLIPScore"].mean().sort_values(ascending=False)
+                 if not prompt_clip_scores.empty and len(prompt_clip_scores) >= 1 :
                      fig2, ax2 = plt.subplots(figsize=(12, max(7, min(len(prompt_clip_scores)*0.5, 15))))
                      prompt_clip_scores.head(20).plot(kind="barh", ax=ax2)
                      ax2.set_title("Average CLIPScore per Prompt (Top 20 unique prompts)"); ax2.set_xlabel("Average CLIPScore")
                      plt.tight_layout(); plot_prompt_clip_scores_buffer = io.BytesIO(); fig2.savefig(plot_prompt_clip_scores_buffer, format="png"); plot_prompt_clip_scores_buffer.seek(0); plt.close(fig2)
+                 log_accumulator.append("INFO: Prompt CLIP scores plot generated.")
+             except Exception as e: log_accumulator.append(f"ERROR: Failed to generate prompt CLIP scores plot: {e}")

          csv_b = io.StringIO(); df.to_csv(csv_b, index=False); csv_buffer_val = csv_b.getvalue()
          json_b = io.StringIO(); df.to_json(json_b, orient='records', indent=4); json_buffer_val = json_b.getvalue()
+         log_accumulator.append("INFO: CSV and JSON data prepared for download.")

+     final_status = f"Finished processing {len(all_results)} images."
+     log_accumulator.append(final_status)
+
+     yield (
          df,
          gr.Image(value=plot_model_avg_scores_buffer, type="pil", visible=plot_model_avg_scores_buffer is not None),
          gr.Image(value=plot_prompt_clip_scores_buffer, type="pil", visible=plot_prompt_clip_scores_buffer is not None),
          gr.File(value=csv_buffer_val or None, label="Download CSV Results", visible=bool(csv_buffer_val), file_name="evaluation_results.csv"),
          gr.File(value=json_buffer_val or None, label="Download JSON Results", visible=bool(json_buffer_val), file_name="evaluation_results.json"),
+         final_status,
+         "\n".join(log_accumulator)
      )
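
Because the click handler is now a generator, Gradio streams each yield to the UI, and every yield must supply one value per component listed in outputs. A minimal standalone sketch of this streaming pattern (hypothetical components, not part of this commit):

import time
import gradio as gr

def countdown(n):
    # Each yield replaces the textbox contents; the final yield is the last state.
    for k in range(int(n), 0, -1):
        yield f"{k}..."
        time.sleep(1)
    yield "Done!"

with gr.Blocks() as sketch:
    seconds = gr.Number(value=3, label="Seconds")
    status = gr.Textbox(label="Status")
    gr.Button("Start").click(fn=countdown, inputs=[seconds], outputs=[status])

# sketch.queue().launch()  # queue() is required to stream generator output
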
  # --- Gradio Interface ---
  with gr.Blocks(css="footer {display: none !important}") as demo:
      gr.Markdown("# AI Image Model Evaluation Tool")
      gr.Markdown("Upload PNG images (ideally with Stable Diffusion metadata) to evaluate them...")
+
+     with gr.Row():
+         image_uploader = gr.Files(
+             label="Upload Images (PNG)",
+             file_count="multiple",
+             file_types=["image"]
+         )
+
      process_button = gr.Button("Evaluate Images", variant="primary")
+
+     status_textbox = gr.Textbox(label="Overall Status", interactive=False)
+
+     log_output_textbox = gr.Textbox(label="Detailed Logs", lines=15, interactive=False, autoscroll=True)  # New log panel
+
      gr.Markdown("## Evaluation Results Table")
+     results_table = gr.DataFrame(headers=[
          "Filename", "Prompt", "Model Name", "Model Hash", "ImageReward", "AnimeAesthetic_dg",
          "MANIQA_TQ", "CLIPScore", "SDXL_Detector_AI_Prob", "AnimeAI_Check_dg_Prob"
      ], wrap=True)
+
      with gr.Row():
          download_csv_button = gr.File(label="Download CSV Results", interactive=False)
          download_json_button = gr.File(label="Download JSON Results", interactive=False)
+
      gr.Markdown("## Visualizations")
      with gr.Row():
          plot_output_model_avg = gr.Image(label="Average Scores per Model", type="pil", interactive=False)
          plot_output_prompt_clip = gr.Image(label="Average CLIPScore per Prompt", type="pil", interactive=False)
+
+     process_button.click(
+         fn=process_images_generator,  # changed to the generator
+         inputs=[image_uploader],
+         outputs=[
+             results_table,
+             plot_output_model_avg,
+             plot_output_prompt_clip,
+             download_csv_button,
+             download_json_button,
+             status_textbox,
+             log_output_textbox  # added output for the logs
+         ]
+     )
+
      gr.Markdown("""**Metric Explanations:** ... (unchanged)""")

  if __name__ == "__main__":
+     # Load models at startup (outside any Gradio callback)
+     print("--- Initializing models, please wait... ---")
+     # Call the ONNX loaders so the sessions are cached at startup where possible.
+     # This logs only to the server console at launch, not to the UI,
+     # but it shows whether the models load at all.
+     initial_dummy_logs = []
+     if onnxruntime:
+         get_onnx_session_and_meta(ANIME_AESTHETIC_REPO, ANIME_AESTHETIC_SUBFOLDER, initial_dummy_logs)
+         get_onnx_session_and_meta(ANIME_AI_CHECK_REPO, ANIME_AI_CHECK_SUBFOLDER, initial_dummy_logs)
+     if initial_dummy_logs:
+         print("--- Initial ONNX loading attempts log: ---")
+         for log_line in initial_dummy_logs: print(log_line)
+         print("-----------------------------------------")
+     print("--- Model initialization attempt complete. Launching Gradio. ---")
+
+     demo.queue().launch(debug=True)  # queue() is important for generators