# UniWorld-V1/univa/eval/gedit/step2_gedit_bench.py
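"""Step 2 of the GEdit-Bench evaluation: score edited images with VIEScore.

For each task group, (source, edited) image pairs are scored concurrently,
results are written to a per-group CSV, and all groups are merged into one
combined CSV per model. Samples already present in an existing combined CSV
are skipped, so interrupted runs can be resumed.
"""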
import argparse
import csv
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import megfile
from PIL import Image
from datasets import load_dataset
from tqdm import tqdm
from viescore import VIEScore
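# The eleven GEdit-Bench task types; each is filtered and scored as one group.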
GROUPS = [
    "background_change", "color_alter", "material_alter", "motion_change",
    "ps_human", "style_change", "subject-add", "subject-remove",
    "subject-replace", "text_change", "tone_transfer",
]
def process_single_item(item, vie_score, max_retries=10000):
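    """Score one (source, edited) image pair with VIEScore.

    Returns a result dict on success, or None after `max_retries`
    failed attempts (e.g. persistent API or I/O errors).
    """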
instruction = item['instruction']
key = item['key']
instruction_language = item['instruction_language']
intersection_exist = item['Intersection_exist']
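    # Note: the path templates below rely on the module-level globals
    # `source_path`, `save_path`, and `group_name` set in the __main__ block.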
    src_image_path = f"{source_path}/fullset/{group_name}/{instruction_language}/{key}_SRCIMG.png"
    save_path_item = f"{save_path}/fullset/{group_name}/{instruction_language}/{key}.png"
for retry in range(max_retries):
try:
            pil_image_raw = Image.open(megfile.smart_open(src_image_path, 'rb'))
            pil_image_edited = Image.open(megfile.smart_open(save_path_item, 'rb')).convert("RGB").resize(pil_image_raw.size)
text_prompt = instruction
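            # evaluate() returns [semantics, quality, overall];
            # in VIEScore the overall score is derived from the other two.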
score_list = vie_score.evaluate([pil_image_raw, pil_image_edited], text_prompt)
            semantics_score, quality_score, overall_score = score_list
            print(f"semantics_score: {semantics_score}, quality_score: {quality_score}, overall_score: {overall_score}, instruction_language: {instruction_language}, instruction: {instruction}")
            return {
                "source_image": src_image_path,
                "edited_image": save_path_item,
                "instruction": instruction,
                "semantics_score": semantics_score,
                "quality_score": quality_score,
                "intersection_exist": intersection_exist,
                "instruction_language": instruction_language,
            }
except Exception as e:
if retry < max_retries - 1:
                wait_time = (retry + 1) * 2  # Linear backoff: wait 2s, 4s, 6s, ...
print(f"Error processing {save_path_item} (attempt {retry + 1}/{max_retries}): {e}")
print(f"Waiting {wait_time} seconds before retry...")
time.sleep(wait_time)
else:
print(f"Failed to process {save_path_item} after {max_retries} attempts: {e}")
return
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="UniWorld")
parser.add_argument("--save_path", type=str, default="/mnt/data/lb/Remake/UniWorld//eval_output/stage3_ema/Gedit")
parser.add_argument("--backbone", type=str, default="gpt4o", choices=["gpt4o", "qwen25vl"])
parser.add_argument("--source_path", type=str, default="/mnt/workspace/lb/Remake/gedit_bench_eval_images")
args = parser.parse_args()
model_name = args.model_name
save_path_dir = args.save_path
source_path = args.source_path
evaluate_group = [args.model_name]
backbone = args.backbone
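    # task="tie" is VIEScore's text-guided image-editing setting; `key_path`
    # names the env file holding the API key for the chosen backbone.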
vie_score = VIEScore(backbone=backbone, task="tie", key_path='secret_t2.env')
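    # Concurrent scoring requests; kept small, presumably to respect API rate limits.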
max_workers = 5
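    # GEdit-Bench ships all benchmark items in its 'train' split.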
dataset = load_dataset("stepfun-ai/GEdit-Bench")
for model_name in evaluate_group:
save_path = save_path_dir
save_path_new = os.path.join(save_path_dir, backbone, "eval_results_new")
all_csv_list = [] # Store all results for final combined CSV
# Load existing processed samples from final CSV if it exists
processed_samples = set()
final_csv_path = os.path.join(save_path_new, f"{model_name}_combined_gpt_score.csv")
if megfile.smart_exists(final_csv_path):
with megfile.smart_open(final_csv_path, 'r', newline='') as f:
reader = csv.DictReader(f)
for row in reader:
# Create a unique identifier for each sample
sample_key = (row['source_image'], row['edited_image'])
processed_samples.add(sample_key)
print(f"Loaded {len(processed_samples)} processed samples from existing CSV")
for group_name in GROUPS:
group_csv_list = []
group_dataset_list = []
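            # Collect this group's items, skipping the Chinese-instruction split.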
for item in tqdm(dataset['train'], desc=f"Processing {model_name} - {group_name}"):
if item['instruction_language'] == 'cn':
continue
if item['task_type'] == group_name:
group_dataset_list.append(item)
# Load existing group CSV if it exists
group_csv_path = os.path.join(save_path_new, f"{model_name}_{group_name}_gpt_score.csv")
if megfile.smart_exists(group_csv_path):
with megfile.smart_open(group_csv_path, 'r', newline='') as f:
reader = csv.DictReader(f)
group_results = list(reader)
group_csv_list.extend(group_results)
print(f"Loaded existing results for {model_name} - {group_name}")
print(f"Processing group: {group_name}")
print(f"Processing model: {model_name}")
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for item in group_dataset_list:
                    key = item['key']
                    instruction_language = item['instruction_language']
                    sample_prefix = key
                    save_path_fullset_source_image = f"{source_path}/fullset/{group_name}/{instruction_language}/{key}_SRCIMG.png"
                    save_path_fullset_result_image = f"{save_path}/fullset/{group_name}/{instruction_language}/{key}.png"
if not megfile.smart_exists(save_path_fullset_result_image) or not megfile.smart_exists(save_path_fullset_source_image):
print(f"Skipping {sample_prefix}: Source or edited image does not exist")
continue
# Check if this sample has already been processed
sample_key = (save_path_fullset_source_image, save_path_fullset_result_image)
                    if sample_key in processed_samples:
                        print(f"Skipping already processed sample: {sample_prefix}")
                        continue
future = executor.submit(process_single_item, item, vie_score)
futures.append(future)
for future in tqdm(as_completed(futures), total=len(futures), desc=f"Processing {model_name} - {group_name}"):
result = future.result()
if result:
group_csv_list.append(result)
# Save group-specific CSV
group_csv_path = os.path.join(save_path_new, f"{model_name}_{group_name}_gpt_score.csv")
            with megfile.smart_open(group_csv_path, 'w', newline='') as f:
                fieldnames = ["source_image", "edited_image", "instruction", "semantics_score", "quality_score", "intersection_exist", "instruction_language"]
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(group_csv_list)
all_csv_list.extend(group_csv_list)
print(f"Saved group CSV for {group_name}, length: {len(group_csv_list)}")
# After processing all groups, calculate and save combined results
if not all_csv_list:
print(f"Warning: No results for model {model_name}, skipping combined CSV generation")
continue
# Save combined CSV
combined_csv_path = os.path.join(save_path_new, f"{model_name}_combined_gpt_score.csv")
        with megfile.smart_open(combined_csv_path, 'w', newline='') as f:
            fieldnames = ["source_image", "edited_image", "instruction", "semantics_score", "quality_score", "intersection_exist", "instruction_language"]
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(all_csv_list)
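
# Example invocation (paths are illustrative, not the defaults above):
#   python step2_gedit_bench.py \
#       --model_name UniWorld \
#       --backbone gpt4o \
#       --save_path /path/to/eval_output/Gedit \
#       --source_path /path/to/gedit_bench_eval_images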