VOIDER commited on
Commit
7f7c3a3
·
verified ·
1 Parent(s): 842de2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -138
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import gradio as gr
2
- from PIL import Image, PngImagePlugin
3
  import io
4
  import os
5
  import pandas as pd
6
  import torch
7
- # from transformers import pipeline as transformers_pipeline , AutoModelForImageClassification, CLIPImageProcessor # ImageReward пока отключен
8
- from transformers import pipeline as transformers_pipeline , CLIPImageProcessor # Убрали AutoModelForImageClassification
9
  import open_clip
10
  import re
11
  import matplotlib.pyplot as plt
@@ -14,6 +13,7 @@ from collections import defaultdict
14
  import numpy as np
15
  import logging
16
  import time
 
17
 
18
  # --- ONNX Related Imports and Setup ---
19
  try:
@@ -84,49 +84,25 @@ def get_onnx_session_and_meta(repo_id, model_subfolder, current_log_list):
84
  onnx_sessions_cache[cache_key] = (None, [], None)
85
  return None, [], None
86
 
87
- # 1. ImageReward - ВРЕМЕННО ОТКЛЮЧЕНО
88
  reward_processor, reward_model = None, None
89
  print("INFO: THUDM/ImageReward is temporarily disabled due to loading issues.")
90
- # try:
91
- # print("INFO: Loading THUDM/ImageReward model...")
92
- # # reward_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
93
- # # reward_model = AutoModelForImageClassification.from_pretrained("THUDM/ImageReward", trust_remote_code=True).to(DEVICE) # Попытка с trust_remote_code
94
- # # reward_model.eval()
95
- # # print("INFO: THUDM/ImageReward loaded successfully.")
96
- # except Exception as e:
97
- # print(f"ERROR: Failed to load THUDM/ImageReward: {e}")
98
-
99
-
100
- ANIME_AESTHETIC_REPO = "deepghs/anime_aesthetic"
101
- ANIME_AESTHETIC_SUBFOLDER = "swinv2pv3_v0_448_ls0.2_x"
102
- ANIME_AESTHETIC_IMG_SIZE = (448, 448)
103
- ANIME_AESTHETIC_LABEL_WEIGHTS = {"normal": 0.0, "slight": 1.0, "moderate": 2.0, "strong": 3.0, "extreme": 4.0}
104
  print("INFO: MANIQA (honklers/maniqa-nr) is currently disabled.")
105
-
106
  clip_model_instance, clip_preprocess, clip_tokenizer = None, None, None
107
  try:
108
- clip_model_name = 'ViT-L-14'
109
- print(f"INFO: Loading CLIP model {clip_model_name} (laion2b_s32b_b82k)...")
110
- clip_model_instance, _, clip_preprocess_val = open_clip.create_model_and_transforms(
111
- clip_model_name, pretrained='laion2b_s32b_b82k', device=DEVICE
112
- )
113
- clip_preprocess = clip_preprocess_val
114
- clip_tokenizer = open_clip.get_tokenizer(clip_model_name)
115
- clip_model_instance.eval()
116
- print(f"INFO: CLIP model {clip_model_name} (laion2b_s32b_b82k) loaded successfully.")
117
- except Exception as e:
118
- print(f"ERROR: Failed to load CLIP model {clip_model_name} (laion2b_s32b_b82k): {e}")
119
-
120
  sdxl_detector_pipe = None
121
  try:
122
  print("INFO: Loading Organika/sdxl-detector model...")
123
  sdxl_detector_pipe = transformers_pipeline("image-classification", model="Organika/sdxl-detector", device=torch.device(DEVICE).index if DEVICE=="cuda" else -1)
124
  print("INFO: Organika/sdxl-detector loaded successfully.")
125
- except Exception as e:
126
- print(f"ERROR: Failed to load Organika/sdxl-detector: {e}")
127
-
128
- ANIME_AI_CHECK_REPO = "deepghs/anime_ai_check"
129
- ANIME_AI_CHECK_SUBFOLDER = "caformer_s36_plus_sce"
130
  ANIME_AI_CHECK_IMG_SIZE = (384, 384)
131
 
132
  def extract_sd_parameters(image_pil, filename_for_log, current_log_list):
@@ -152,8 +128,7 @@ def extract_sd_parameters(image_pil, filename_for_log, current_log_list):
152
  negative_prompt = parameters_str[neg_prompt_index + len("Negative prompt:"):end_of_neg].strip()
153
  params_part = parameters_str[end_of_neg:].strip() if end_of_neg < len(parameters_str) else ""
154
  elif steps_meta_index != -1:
155
- prompt = parameters_str[:steps_meta_index].strip()
156
- params_part = parameters_str[steps_meta_index:]
157
  else:
158
  prompt = parameters_str.strip(); params_part = ""
159
  if params_part:
@@ -172,35 +147,14 @@ def extract_sd_parameters(image_pil, filename_for_log, current_log_list):
172
  if model_name == "N/A" and "Checkpoint" in other_params_dict: model_name = other_params_dict["Checkpoint"]
173
  if model_name == "N/A" and "model" in other_params_dict: model_name = other_params_dict["model"]
174
  current_log_list.append(f"DEBUG [{filename_for_log}]: Parsed Prompt: {prompt[:50]}... | Model: {model_name}")
175
- except Exception as e:
176
- current_log_list.append(f"ERROR [{filename_for_log}]: Failed to parse metadata: {e}")
177
  return prompt, negative_prompt, model_name, model_hash, other_params_dict
178
 
179
  @torch.no_grad()
180
- def get_image_reward(image_pil, filename_for_log, current_log_list):
181
- # current_log_list.append(f"INFO [{filename_for_log}]: ImageReward model not loaded (disabled), skipping.")
182
- return "N/A (Disabled)" # Временно отключено
183
- # if not reward_model or not reward_processor:
184
- # current_log_list.append(f"INFO [{filename_for_log}]: ImageReward model not loaded, skipping.")
185
- # return "N/A"
186
- # t_start = time.time()
187
- # current_log_list.append(f"DEBUG [{filename_for_log}]: Starting ImageReward score (PyTorch Device: {DEVICE})...")
188
- # try:
189
- # inputs = reward_processor(images=image_pil, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
190
- # outputs = reward_model(**inputs)
191
- # score = round(outputs.logits.item(), 4)
192
- # t_end = time.time()
193
- # current_log_list.append(f"DEBUG [{filename_for_log}]: ImageReward score: {score} (took {t_end - t_start:.2f}s)")
194
- # return score
195
- # except Exception as e:
196
- # current_log_list.append(f"ERROR [{filename_for_log}]: ImageReward scoring failed: {e}")
197
- # return "Error"
198
-
199
  def get_anime_aesthetic_score_deepghs(image_pil, filename_for_log, current_log_list):
200
  session, labels, meta = get_onnx_session_and_meta(ANIME_AESTHETIC_REPO, ANIME_AESTHETIC_SUBFOLDER, current_log_list)
201
- if not session or not labels:
202
- current_log_list.append(f"INFO [{filename_for_log}]: AnimeAesthetic ONNX model not loaded, skipping.")
203
- return "N/A"
204
  t_start = time.time(); current_log_list.append(f"DEBUG [{filename_for_log}]: Starting AnimeAesthetic (ONNX) score...")
205
  try:
206
  input_data = _img_preprocess_for_onnx(image_pil.copy(), size=ANIME_AESTHETIC_IMG_SIZE)
@@ -209,24 +163,15 @@ def get_anime_aesthetic_score_deepghs(image_pil, filename_for_log, current_log_l
209
  scores = onnx_output[0]; exp_scores = np.exp(scores - np.max(scores)); probabilities = exp_scores / np.sum(exp_scores)
210
  weighted_score = sum(probabilities[i] * ANIME_AESTHETIC_LABEL_WEIGHTS.get(label, 0.0) for i, label in enumerate(labels))
211
  score = round(weighted_score, 4); t_end = time.time()
212
- current_log_list.append(f"DEBUG [{filename_for_log}]: AnimeAesthetic (ONNX) score: {score} (took {t_end - t_start:.2f}s)")
213
- return score
214
- except Exception as e:
215
- current_log_list.append(f"ERROR [{filename_for_log}]: AnimeAesthetic (ONNX) scoring failed: {e}"); return "Error"
216
-
217
  @torch.no_grad()
218
  def get_maniqa_score(image_pil, filename_for_log, current_log_list):
219
- current_log_list.append(f"INFO [{filename_for_log}]: MANIQA is disabled.")
220
- return "N/A (Disabled)"
221
-
222
  @torch.no_grad()
223
  def calculate_clip_score_value(image_pil, prompt_text, filename_for_log, current_log_list):
224
- if not clip_model_instance or not clip_preprocess or not clip_tokenizer:
225
- current_log_list.append(f"INFO [{filename_for_log}]: CLIP model not loaded, skipping CLIPScore.")
226
- return "N/A"
227
- if not prompt_text or prompt_text == "N/A":
228
- current_log_list.append(f"INFO [{filename_for_log}]: Empty prompt, skipping CLIPScore.")
229
- return "N/A (Empty Prompt)"
230
  t_start = time.time(); current_log_list.append(f"DEBUG [{filename_for_log}]: Starting CLIPScore (PyTorch Device: {DEVICE})...")
231
  try:
232
  image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
@@ -236,32 +181,22 @@ def calculate_clip_score_value(image_pil, prompt_text, filename_for_log, current
236
  text_features_norm = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
237
  score_val = (text_features_norm @ image_features_norm.T).squeeze().item() * 100.0
238
  score = round(score_val, 2); t_end = time.time()
239
- current_log_list.append(f"DEBUG [{filename_for_log}]: CLIPScore: {score} (took {t_end - t_start:.2f}s)")
240
- return score
241
- except Exception as e:
242
- current_log_list.append(f"ERROR [{filename_for_log}]: CLIPScore calculation failed: {e}"); return "Error"
243
-
244
  @torch.no_grad()
245
  def get_sdxl_detection_score(image_pil, filename_for_log, current_log_list):
246
- if not sdxl_detector_pipe:
247
- current_log_list.append(f"INFO [{filename_for_log}]: SDXL_Detector model not loaded, skipping.")
248
- return "N/A"
249
  t_start = time.time(); current_log_list.append(f"DEBUG [{filename_for_log}]: Starting SDXL_Detector score (Device: {sdxl_detector_pipe.device})...")
250
  try:
251
  result = sdxl_detector_pipe(image_pil.copy()); ai_score_val = 0.0
252
  for item in result:
253
  if item['label'].lower() == 'artificial': ai_score_val = item['score']; break
254
  score = round(ai_score_val, 4); t_end = time.time()
255
- current_log_list.append(f"DEBUG [{filename_for_log}]: SDXL_Detector AI Prob: {score} (took {t_end - t_start:.2f}s)")
256
- return score
257
- except Exception as e:
258
- current_log_list.append(f"ERROR [{filename_for_log}]: SDXL_Detector scoring failed: {e}"); return "Error"
259
-
260
  def get_anime_ai_check_score_deepghs(image_pil, filename_for_log, current_log_list):
261
  session, labels, meta = get_onnx_session_and_meta(ANIME_AI_CHECK_REPO, ANIME_AI_CHECK_SUBFOLDER, current_log_list)
262
- if not session or not labels:
263
- current_log_list.append(f"INFO [{filename_for_log}]: AnimeAI_Check ONNX model not loaded, skipping.")
264
- return "N/A"
265
  t_start = time.time(); current_log_list.append(f"DEBUG [{filename_for_log}]: Starting AnimeAI_Check (ONNX) score...")
266
  try:
267
  input_data = _img_preprocess_for_onnx(image_pil.copy(), size=ANIME_AI_CHECK_IMG_SIZE)
@@ -272,43 +207,38 @@ def get_anime_ai_check_score_deepghs(image_pil, filename_for_log, current_log_li
272
  for i, label in enumerate(labels):
273
  if label.lower() == 'ai': ai_prob_val = probabilities[i]; break
274
  score = round(ai_prob_val, 4); t_end = time.time()
275
- current_log_list.append(f"DEBUG [{filename_for_log}]: AnimeAI_Check (ONNX) AI Prob: {score} (took {t_end - t_start:.2f}s)")
276
- return score
277
- except Exception as e:
278
- current_log_list.append(f"ERROR [{filename_for_log}]: AnimeAI_Check (ONNX) scoring failed: {e}"); return "Error"
279
 
280
  def process_images_generator(files, progress=gr.Progress(track_tqdm=True)):
281
  if not files:
282
- yield pd.DataFrame(), None, None, None, None, "Please upload some images.", "No files to process."
 
 
 
283
  return
284
 
285
  all_results = []
286
  log_accumulator = [f"INFO: Starting processing for {len(files)} images..."]
287
- # Начальный yield для лога и статуса
288
- yield (pd.DataFrame(all_results), None, None,
289
- gr.File(visible=False), gr.File(visible=False), # Скрываем кнопки скачивания вначале
290
  "Processing...", "\n".join(log_accumulator))
291
 
292
  for i, file_obj in enumerate(files):
293
- filename_for_log = "Unknown File"
294
- current_img_total_time_start = time.time()
295
  try:
296
  filename_for_log = os.path.basename(getattr(file_obj, 'name', f"file_{i}_{int(time.time())}"))
297
  log_accumulator.append(f"--- Processing image {i+1}/{len(files)}: {filename_for_log} ---")
298
-
299
- # Используем progress(float, desc=...)
300
  progress( (i + 0.1) / len(files), desc=f"Img {i+1}/{len(files)}: Loading {filename_for_log}")
301
- # Немедленно обновляем UI с логом перед тяжелой загрузкой изображения
302
- yield (pd.DataFrame(all_results), None, None,
303
  gr.File(visible=False), gr.File(visible=False),
304
- f"Loading image {i+1}/{len(files)}: {filename_for_log}",
305
- "\n".join(log_accumulator))
306
 
307
  img = Image.open(getattr(file_obj, 'name', str(file_obj)))
308
  if img.mode != "RGB": img = img.convert("RGB")
309
  progress( (i + 0.3) / len(files), desc=f"Img {i+1}/{len(files)}: Scoring {filename_for_log}")
310
-
311
-
312
  prompt, neg_prompt, model_n, model_h, other_p = extract_sd_parameters(img, filename_for_log, log_accumulator)
313
  reward = get_image_reward(img, filename_for_log, log_accumulator)
314
  anime_aes_deepghs = get_anime_aesthetic_score_deepghs(img, filename_for_log, log_accumulator)
@@ -318,7 +248,6 @@ def process_images_generator(files, progress=gr.Progress(track_tqdm=True)):
318
  anime_ai_chk_deepghs = get_anime_ai_check_score_deepghs(img, filename_for_log, log_accumulator)
319
  current_img_total_time_end = time.time()
320
  log_accumulator.append(f"INFO [{filename_for_log}]: Finished all scores (total for image: {current_img_total_time_end - current_img_total_time_start:.2f}s)")
321
-
322
  all_results.append({
323
  "Filename": filename_for_log, "Prompt": prompt if prompt else "N/A", "Model Name": model_n, "Model Hash": model_h,
324
  "ImageReward": reward, "AnimeAesthetic_dg": anime_aes_deepghs, "MANIQA_TQ": maniqa,
@@ -326,10 +255,10 @@ def process_images_generator(files, progress=gr.Progress(track_tqdm=True)):
326
  })
327
  df_so_far = pd.DataFrame(all_results)
328
  progress( (i + 1.0) / len(files), desc=f"Img {i+1}/{len(files)}: Done {filename_for_log}")
329
- yield (df_so_far, None, None,
 
330
  gr.File(visible=False), gr.File(visible=False),
331
- f"Processed image {i+1}/{len(files)}: {filename_for_log}",
332
- "\n".join(log_accumulator))
333
  except Exception as e:
334
  log_accumulator.append(f"CRITICAL ERROR processing {filename_for_log}: {e}")
335
  print(f"CRITICAL ERROR processing {filename_for_log}: {e}")
@@ -339,21 +268,21 @@ def process_images_generator(files, progress=gr.Progress(track_tqdm=True)):
339
  "CLIPScore": "Error", "SDXL_Detector_AI_Prob": "Error", "AnimeAI_Check_dg_Prob": "Error"
340
  })
341
  df_so_far = pd.DataFrame(all_results)
342
- yield (df_so_far, None, None,
 
343
  gr.File(visible=False), gr.File(visible=False),
344
- f"Error on image {i+1}/{len(files)}: {filename_for_log}",
345
- "\n".join(log_accumulator))
346
 
347
  log_accumulator.append("--- Generating final plots and download files ---")
348
  progress(1.0, desc="Generating final plots...")
349
- yield (pd.DataFrame(all_results), None, None,
 
350
  gr.File(visible=False), gr.File(visible=False),
351
- "Generating final plots...",
352
- "\n".join(log_accumulator))
353
 
354
  df = pd.DataFrame(all_results)
355
  plot_model_avg_scores_buffer, plot_prompt_clip_scores_buffer = None, None
356
- csv_file_path_out, json_file_path_out = None, None # Будем возвращать пути к файлам
357
 
358
  if not df.empty:
359
  numeric_cols = ["ImageReward", "AnimeAesthetic_dg", "MANIQA_TQ", "CLIPScore"]
@@ -381,35 +310,36 @@ def process_images_generator(files, progress=gr.Progress(track_tqdm=True)):
381
  plt.tight_layout(); plot_prompt_clip_scores_buffer = io.BytesIO(); fig2.savefig(plot_prompt_clip_scores_buffer, format="png"); plot_prompt_clip_scores_buffer.seek(0); plt.close(fig2)
382
  log_accumulator.append("INFO: Prompt CLIP scores plot generated.")
383
  except Exception as e: log_accumulator.append(f"ERROR: Failed to generate prompt CLIP scores plot: {e}")
384
-
385
- # Сохраняем файлы во временные файлы и возвращаем пути
386
  try:
387
  with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv", encoding='utf-8') as tmp_csv:
388
- df.to_csv(tmp_csv, index=False)
389
- csv_file_path_out = tmp_csv.name
390
  with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json", encoding='utf-8') as tmp_json:
391
- df.to_json(tmp_json, orient='records', indent=4)
392
- json_file_path_out = tmp_json.name
393
  log_accumulator.append("INFO: CSV and JSON data prepared for download.")
394
- except Exception as e:
395
- log_accumulator.append(f"ERROR preparing download files: {e}")
396
-
397
 
398
  final_status = f"Finished processing {len(all_results)} images."
399
  log_accumulator.append(final_status)
400
 
 
 
 
 
 
 
 
 
 
401
  yield (
402
  df,
403
- gr.Image(value=plot_model_avg_scores_buffer, visible=plot_model_avg_scores_buffer is not None),
404
- gr.Image(value=plot_prompt_clip_scores_buffer, visible=plot_prompt_clip_scores_buffer is not None),
405
- gr.File(value=csv_file_path_out, visible=csv_file_path_out is not None), # Убрали file_name
406
- gr.File(value=json_file_path_out, visible=json_file_path_out is not None), # Убрали file_name
407
  final_status,
408
  "\n".join(log_accumulator)
409
  )
410
 
411
- import tempfile # Для gr.File
412
-
413
  with gr.Blocks(css="footer {display: none !important}") as demo:
414
  gr.Markdown("# AI Image Model Evaluation Tool")
415
  gr.Markdown("Upload PNG images (ideally with Stable Diffusion metadata) to evaluate them...")
@@ -423,8 +353,8 @@ with gr.Blocks(css="footer {display: none !important}") as demo:
423
  "MANIQA_TQ", "CLIPScore", "SDXL_Detector_AI_Prob", "AnimeAI_Check_dg_Prob"
424
  ], wrap=True)
425
  with gr.Row():
426
- download_csv_button = gr.File(label="Download CSV Results", interactive=False) # Будет обновляться из yield
427
- download_json_button = gr.File(label="Download JSON Results", interactive=False) # Будет обновляться из yield
428
  gr.Markdown("## Visualizations")
429
  with gr.Row():
430
  plot_output_model_avg = gr.Image(label="Average Scores per Model", type="pil", interactive=False)
 
1
  import gradio as gr
2
+ from PIL import Image, PngImagePlugin # Убедимся, что Image из PIL импортирован
3
  import io
4
  import os
5
  import pandas as pd
6
  import torch
7
+ from transformers import pipeline as transformers_pipeline , CLIPImageProcessor
 
8
  import open_clip
9
  import re
10
  import matplotlib.pyplot as plt
 
13
  import numpy as np
14
  import logging
15
  import time
16
+ import tempfile
17
 
18
  # --- ONNX Related Imports and Setup ---
19
  try:
 
84
  onnx_sessions_cache[cache_key] = (None, [], None)
85
  return None, [], None
86
 
 
87
  reward_processor, reward_model = None, None
88
  print("INFO: THUDM/ImageReward is temporarily disabled due to loading issues.")
89
+ ANIME_AESTHETIC_REPO = "deepghs/anime_aesthetic"; ANIME_AESTHETIC_SUBFOLDER = "swinv2pv3_v0_448_ls0.2_x"
90
+ ANIME_AESTHETIC_IMG_SIZE = (448, 448); ANIME_AESTHETIC_LABEL_WEIGHTS = {"normal": 0.0, "slight": 1.0, "moderate": 2.0, "strong": 3.0, "extreme": 4.0}
 
 
 
 
 
 
 
 
 
 
 
 
91
  print("INFO: MANIQA (honklers/maniqa-nr) is currently disabled.")
 
92
  clip_model_instance, clip_preprocess, clip_tokenizer = None, None, None
93
  try:
94
+ clip_model_name = 'ViT-L-14'; print(f"INFO: Loading CLIP model {clip_model_name} (laion2b_s32b_b82k)...")
95
+ clip_model_instance, _, clip_preprocess_val = open_clip.create_model_and_transforms(clip_model_name, pretrained='laion2b_s32b_b82k', device=DEVICE)
96
+ clip_preprocess = clip_preprocess_val; clip_tokenizer = open_clip.get_tokenizer(clip_model_name)
97
+ clip_model_instance.eval(); print(f"INFO: CLIP model {clip_model_name} (laion2b_s32b_b82k) loaded successfully.")
98
+ except Exception as e: print(f"ERROR: Failed to load CLIP model {clip_model_name} (laion2b_s32b_b82k): {e}")
 
 
 
 
 
 
 
99
  sdxl_detector_pipe = None
100
  try:
101
  print("INFO: Loading Organika/sdxl-detector model...")
102
  sdxl_detector_pipe = transformers_pipeline("image-classification", model="Organika/sdxl-detector", device=torch.device(DEVICE).index if DEVICE=="cuda" else -1)
103
  print("INFO: Organika/sdxl-detector loaded successfully.")
104
+ except Exception as e: print(f"ERROR: Failed to load Organika/sdxl-detector: {e}")
105
+ ANIME_AI_CHECK_REPO = "deepghs/anime_ai_check"; ANIME_AI_CHECK_SUBFOLDER = "caformer_s36_plus_sce"
 
 
 
106
  ANIME_AI_CHECK_IMG_SIZE = (384, 384)
107
 
108
  def extract_sd_parameters(image_pil, filename_for_log, current_log_list):
 
128
  negative_prompt = parameters_str[neg_prompt_index + len("Negative prompt:"):end_of_neg].strip()
129
  params_part = parameters_str[end_of_neg:].strip() if end_of_neg < len(parameters_str) else ""
130
  elif steps_meta_index != -1:
131
+ prompt = parameters_str[:steps_meta_index].strip(); params_part = parameters_str[steps_meta_index:]
 
132
  else:
133
  prompt = parameters_str.strip(); params_part = ""
134
  if params_part:
 
147
  if model_name == "N/A" and "Checkpoint" in other_params_dict: model_name = other_params_dict["Checkpoint"]
148
  if model_name == "N/A" and "model" in other_params_dict: model_name = other_params_dict["model"]
149
  current_log_list.append(f"DEBUG [{filename_for_log}]: Parsed Prompt: {prompt[:50]}... | Model: {model_name}")
150
+ except Exception as e: current_log_list.append(f"ERROR [{filename_for_log}]: Failed to parse metadata: {e}")
 
151
  return prompt, negative_prompt, model_name, model_hash, other_params_dict
152
 
153
  @torch.no_grad()
154
+ def get_image_reward(image_pil, filename_for_log, current_log_list): return "N/A (Disabled)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  def get_anime_aesthetic_score_deepghs(image_pil, filename_for_log, current_log_list):
156
  session, labels, meta = get_onnx_session_and_meta(ANIME_AESTHETIC_REPO, ANIME_AESTHETIC_SUBFOLDER, current_log_list)
157
+ if not session or not labels: current_log_list.append(f"INFO [{filename_for_log}]: AnimeAesthetic ONNX model not loaded, skipping."); return "N/A"
 
 
158
  t_start = time.time(); current_log_list.append(f"DEBUG [{filename_for_log}]: Starting AnimeAesthetic (ONNX) score...")
159
  try:
160
  input_data = _img_preprocess_for_onnx(image_pil.copy(), size=ANIME_AESTHETIC_IMG_SIZE)
 
163
  scores = onnx_output[0]; exp_scores = np.exp(scores - np.max(scores)); probabilities = exp_scores / np.sum(exp_scores)
164
  weighted_score = sum(probabilities[i] * ANIME_AESTHETIC_LABEL_WEIGHTS.get(label, 0.0) for i, label in enumerate(labels))
165
  score = round(weighted_score, 4); t_end = time.time()
166
+ current_log_list.append(f"DEBUG [{filename_for_log}]: AnimeAesthetic (ONNX) score: {score} (took {t_end - t_start:.2f}s)"); return score
167
+ except Exception as e: current_log_list.append(f"ERROR [{filename_for_log}]: AnimeAesthetic (ONNX) scoring failed: {e}"); return "Error"
 
 
 
168
  @torch.no_grad()
169
  def get_maniqa_score(image_pil, filename_for_log, current_log_list):
170
+ current_log_list.append(f"INFO [{filename_for_log}]: MANIQA is disabled."); return "N/A (Disabled)"
 
 
171
  @torch.no_grad()
172
  def calculate_clip_score_value(image_pil, prompt_text, filename_for_log, current_log_list):
173
+ if not clip_model_instance or not clip_preprocess or not clip_tokenizer: current_log_list.append(f"INFO [{filename_for_log}]: CLIP model not loaded, skipping CLIPScore."); return "N/A"
174
+ if not prompt_text or prompt_text == "N/A": current_log_list.append(f"INFO [{filename_for_log}]: Empty prompt, skipping CLIPScore."); return "N/A (Empty Prompt)"
 
 
 
 
175
  t_start = time.time(); current_log_list.append(f"DEBUG [{filename_for_log}]: Starting CLIPScore (PyTorch Device: {DEVICE})...")
176
  try:
177
  image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
 
181
  text_features_norm = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
182
  score_val = (text_features_norm @ image_features_norm.T).squeeze().item() * 100.0
183
  score = round(score_val, 2); t_end = time.time()
184
+ current_log_list.append(f"DEBUG [{filename_for_log}]: CLIPScore: {score} (took {t_end - t_start:.2f}s)"); return score
185
+ except Exception as e: current_log_list.append(f"ERROR [{filename_for_log}]: CLIPScore calculation failed: {e}"); return "Error"
 
 
 
186
  @torch.no_grad()
187
  def get_sdxl_detection_score(image_pil, filename_for_log, current_log_list):
188
+ if not sdxl_detector_pipe: current_log_list.append(f"INFO [{filename_for_log}]: SDXL_Detector model not loaded, skipping."); return "N/A"
 
 
189
  t_start = time.time(); current_log_list.append(f"DEBUG [{filename_for_log}]: Starting SDXL_Detector score (Device: {sdxl_detector_pipe.device})...")
190
  try:
191
  result = sdxl_detector_pipe(image_pil.copy()); ai_score_val = 0.0
192
  for item in result:
193
  if item['label'].lower() == 'artificial': ai_score_val = item['score']; break
194
  score = round(ai_score_val, 4); t_end = time.time()
195
+ current_log_list.append(f"DEBUG [{filename_for_log}]: SDXL_Detector AI Prob: {score} (took {t_end - t_start:.2f}s)"); return score
196
+ except Exception as e: current_log_list.append(f"ERROR [{filename_for_log}]: SDXL_Detector scoring failed: {e}"); return "Error"
 
 
 
197
  def get_anime_ai_check_score_deepghs(image_pil, filename_for_log, current_log_list):
198
  session, labels, meta = get_onnx_session_and_meta(ANIME_AI_CHECK_REPO, ANIME_AI_CHECK_SUBFOLDER, current_log_list)
199
+ if not session or not labels: current_log_list.append(f"INFO [{filename_for_log}]: AnimeAI_Check ONNX model not loaded, skipping."); return "N/A"
 
 
200
  t_start = time.time(); current_log_list.append(f"DEBUG [{filename_for_log}]: Starting AnimeAI_Check (ONNX) score...")
201
  try:
202
  input_data = _img_preprocess_for_onnx(image_pil.copy(), size=ANIME_AI_CHECK_IMG_SIZE)
 
207
  for i, label in enumerate(labels):
208
  if label.lower() == 'ai': ai_prob_val = probabilities[i]; break
209
  score = round(ai_prob_val, 4); t_end = time.time()
210
+ current_log_list.append(f"DEBUG [{filename_for_log}]: AnimeAI_Check (ONNX) AI Prob: {score} (took {t_end - t_start:.2f}s)"); return score
211
+ except Exception as e: current_log_list.append(f"ERROR [{filename_for_log}]: AnimeAI_Check (ONNX) scoring failed: {e}"); return "Error"
 
 
212
 
213
  def process_images_generator(files, progress=gr.Progress(track_tqdm=True)):
214
  if not files:
215
+ yield (pd.DataFrame(),
216
+ gr.Image(visible=False), gr.Image(visible=False),
217
+ gr.File(visible=False), gr.File(visible=False),
218
+ "Please upload some images.", "No files to process.")
219
  return
220
 
221
  all_results = []
222
  log_accumulator = [f"INFO: Starting processing for {len(files)} images..."]
223
+ yield (pd.DataFrame(all_results),
224
+ gr.Image(visible=False), gr.Image(visible=False),
225
+ gr.File(visible=False), gr.File(visible=False),
226
  "Processing...", "\n".join(log_accumulator))
227
 
228
  for i, file_obj in enumerate(files):
229
+ filename_for_log = "Unknown File"; current_img_total_time_start = time.time()
 
230
  try:
231
  filename_for_log = os.path.basename(getattr(file_obj, 'name', f"file_{i}_{int(time.time())}"))
232
  log_accumulator.append(f"--- Processing image {i+1}/{len(files)}: {filename_for_log} ---")
 
 
233
  progress( (i + 0.1) / len(files), desc=f"Img {i+1}/{len(files)}: Loading {filename_for_log}")
234
+ yield (pd.DataFrame(all_results),
235
+ gr.Image(visible=False), gr.Image(visible=False),
236
  gr.File(visible=False), gr.File(visible=False),
237
+ f"Loading image {i+1}/{len(files)}: {filename_for_log}", "\n".join(log_accumulator))
 
238
 
239
  img = Image.open(getattr(file_obj, 'name', str(file_obj)))
240
  if img.mode != "RGB": img = img.convert("RGB")
241
  progress( (i + 0.3) / len(files), desc=f"Img {i+1}/{len(files)}: Scoring {filename_for_log}")
 
 
242
  prompt, neg_prompt, model_n, model_h, other_p = extract_sd_parameters(img, filename_for_log, log_accumulator)
243
  reward = get_image_reward(img, filename_for_log, log_accumulator)
244
  anime_aes_deepghs = get_anime_aesthetic_score_deepghs(img, filename_for_log, log_accumulator)
 
248
  anime_ai_chk_deepghs = get_anime_ai_check_score_deepghs(img, filename_for_log, log_accumulator)
249
  current_img_total_time_end = time.time()
250
  log_accumulator.append(f"INFO [{filename_for_log}]: Finished all scores (total for image: {current_img_total_time_end - current_img_total_time_start:.2f}s)")
 
251
  all_results.append({
252
  "Filename": filename_for_log, "Prompt": prompt if prompt else "N/A", "Model Name": model_n, "Model Hash": model_h,
253
  "ImageReward": reward, "AnimeAesthetic_dg": anime_aes_deepghs, "MANIQA_TQ": maniqa,
 
255
  })
256
  df_so_far = pd.DataFrame(all_results)
257
  progress( (i + 1.0) / len(files), desc=f"Img {i+1}/{len(files)}: Done {filename_for_log}")
258
+ yield (df_so_far,
259
+ gr.Image(visible=False), gr.Image(visible=False),
260
  gr.File(visible=False), gr.File(visible=False),
261
+ f"Processed image {i+1}/{len(files)}: {filename_for_log}", "\n".join(log_accumulator))
 
262
  except Exception as e:
263
  log_accumulator.append(f"CRITICAL ERROR processing {filename_for_log}: {e}")
264
  print(f"CRITICAL ERROR processing {filename_for_log}: {e}")
 
268
  "CLIPScore": "Error", "SDXL_Detector_AI_Prob": "Error", "AnimeAI_Check_dg_Prob": "Error"
269
  })
270
  df_so_far = pd.DataFrame(all_results)
271
+ yield (df_so_far,
272
+ gr.Image(visible=False), gr.Image(visible=False),
273
  gr.File(visible=False), gr.File(visible=False),
274
+ f"Error on image {i+1}/{len(files)}: {filename_for_log}", "\n".join(log_accumulator))
 
275
 
276
  log_accumulator.append("--- Generating final plots and download files ---")
277
  progress(1.0, desc="Generating final plots...")
278
+ yield (pd.DataFrame(all_results),
279
+ gr.Image(visible=False), gr.Image(visible=False),
280
  gr.File(visible=False), gr.File(visible=False),
281
+ "Generating final plots...", "\n".join(log_accumulator))
 
282
 
283
  df = pd.DataFrame(all_results)
284
  plot_model_avg_scores_buffer, plot_prompt_clip_scores_buffer = None, None
285
+ csv_file_path_out, json_file_path_out = None, None
286
 
287
  if not df.empty:
288
  numeric_cols = ["ImageReward", "AnimeAesthetic_dg", "MANIQA_TQ", "CLIPScore"]
 
310
  plt.tight_layout(); plot_prompt_clip_scores_buffer = io.BytesIO(); fig2.savefig(plot_prompt_clip_scores_buffer, format="png"); plot_prompt_clip_scores_buffer.seek(0); plt.close(fig2)
311
  log_accumulator.append("INFO: Prompt CLIP scores plot generated.")
312
  except Exception as e: log_accumulator.append(f"ERROR: Failed to generate prompt CLIP scores plot: {e}")
 
 
313
  try:
314
  with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv", encoding='utf-8') as tmp_csv:
315
+ df.to_csv(tmp_csv, index=False); csv_file_path_out = tmp_csv.name
 
316
  with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json", encoding='utf-8') as tmp_json:
317
+ df.to_json(tmp_json, orient='records', indent=4); json_file_path_out = tmp_json.name
 
318
  log_accumulator.append("INFO: CSV and JSON data prepared for download.")
319
+ except Exception as e: log_accumulator.append(f"ERROR preparing download files: {e}")
 
 
320
 
321
  final_status = f"Finished processing {len(all_results)} images."
322
  log_accumulator.append(final_status)
323
 
324
+ # Преобразуем BytesIO в PIL.Image перед передачей в gr.Image
325
+ pil_plot_model_avg = Image.open(plot_model_avg_scores_buffer) if plot_model_avg_scores_buffer and plot_model_avg_scores_buffer.getbuffer().nbytes > 0 else None
326
+ pil_plot_prompt_clip = Image.open(plot_prompt_clip_scores_buffer) if plot_prompt_clip_scores_buffer and plot_prompt_clip_scores_buffer.getbuffer().nbytes > 0 else None
327
+ if pil_plot_model_avg or pil_plot_prompt_clip:
328
+ log_accumulator.append("INFO: Plots converted to PIL Images for display.")
329
+ else:
330
+ log_accumulator.append("INFO: No plots were generated or plots are empty.")
331
+
332
+
333
  yield (
334
  df,
335
+ gr.Image(value=pil_plot_model_avg, visible=pil_plot_model_avg is not None),
336
+ gr.Image(value=pil_plot_prompt_clip, visible=pil_plot_prompt_clip is not None),
337
+ gr.File(value=csv_file_path_out, visible=csv_file_path_out is not None),
338
+ gr.File(value=json_file_path_out, visible=json_file_path_out is not None),
339
  final_status,
340
  "\n".join(log_accumulator)
341
  )
342
 
 
 
343
  with gr.Blocks(css="footer {display: none !important}") as demo:
344
  gr.Markdown("# AI Image Model Evaluation Tool")
345
  gr.Markdown("Upload PNG images (ideally with Stable Diffusion metadata) to evaluate them...")
 
353
  "MANIQA_TQ", "CLIPScore", "SDXL_Detector_AI_Prob", "AnimeAI_Check_dg_Prob"
354
  ], wrap=True)
355
  with gr.Row():
356
+ download_csv_button = gr.File(label="Download CSV Results", interactive=False)
357
+ download_json_button = gr.File(label="Download JSON Results", interactive=False)
358
  gr.Markdown("## Visualizations")
359
  with gr.Row():
360
  plot_output_model_avg = gr.Image(label="Average Scores per Model", type="pil", interactive=False)