import json
import os
import random
import time

import cv2
import numpy as np
import pandas as pd
import torch
import webdataset as wds
from PIL import Image
from tqdm import tqdm

from .vl_checklist import _eval_text_image

DATASET_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/instruct_data/crepe/prod_hard_negatives"


def evaluate_crepe(
    model,
    tokenizer,
    image_processor,
    vis_embed_size=None,
    rank=0,
    world_size=1,
    id=0,
    subset=True,
    debug=False,
    level=4,
    type="swap",
):
    """Evaluate ``model`` on one CREPE hard-negative CSV and return the score.

    The CSV ``prod_vg_hard_negs_{type}_complexity_{level}.csv`` under
    ``DATASET_ROOT/type`` is read with pandas. Rows are sharded across ranks
    (row index modulo ``world_size``). For each row the referenced image is
    cropped to the row's ``(x, y, width, height)`` box, resized to 224x224 and
    passed with the caption to ``_eval_text_image``. Every entry of the
    returned ``final_ranks`` that is below 10 is counted as correct.

    Each rank writes its ``[total, correct]`` pair to a JSON file in the
    current working directory; rank 0 then merges (and deletes) all part
    files, prints the score and creates an empty marker file in
    ``eval_results/`` whose name embeds the score.

    Args:
        model: Model to evaluate. It is switched to eval mode and moved to
            CUDA; ``model.expr_name`` and ``model.step_num`` are read on
            rank 0 to name the marker file.
        tokenizer: Tokenizer used to look up the ``<|#image#|>`` and
            ``<|#prebox#|>`` special-token ids.
        image_processor: Forwarded unchanged to ``_eval_text_image``.
        vis_embed_size: Forwarded unchanged to ``_eval_text_image``.
        rank: Index of this process among ``world_size`` processes.
        world_size: Number of cooperating processes. When greater than 1,
            ``torch.distributed.barrier()`` is used for synchronisation.
        id: Tag embedded in the per-rank result file names.
        subset: Unused; kept for interface compatibility.
        debug: Forwarded to ``_eval_text_image``; also prints a separator
            line after every scored sample.
        level: Complexity level of the CSV to load; must be in ``[4, 12]``.
        type: Hard-negative type; only ``"swap"`` is accepted.

    Returns:
        ``correct / total`` on rank 0 (``0.0`` if nothing was scored), and
        ``0.0`` on every other rank.
    """
    if rank == 0:
        tqdm.write(f"level: {level}")
        tqdm.write(f"type: {type}")
    dataset_name = "crepe"

    # Only these two special-token ids are consumed below.
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]

    model.eval().cuda()
    total = 0
    correct = 0
    assert type in ["swap"]
    assert 4 <= level <= 12

    filename = os.path.join(DATASET_ROOT, type, f"prod_vg_hard_negs_{type}_complexity_{level}.csv")
    df = pd.read_csv(filename)
    pbar = tqdm(df.iterrows(), disable=(rank != 0))
    for ii, sample in pbar:
        # Shard rows across ranks.
        if ii % world_size != rank:
            continue
        text = sample.caption
        image_path = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/VG_100K/{}.jpg".format(sample.image_id)
        x = sample.x
        y = sample.y
        width = sample.width
        height = sample.height

        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = image.crop((x, y, x + width, y + height))
        image = image.resize((224, 224))

        final_rank, final_ranks = _eval_text_image(
            text,
            image,
            model,
            tokenizer,
            image_processor,
            vis_embed_size,
            media_token_id,
            prebox_token_id,
            debug=debug,
        )
        if final_rank is None:
            continue
        correct += int((np.array(final_ranks) < 10).sum())
        total += len(final_ranks)
        if debug:
            tqdm.write("=" * 80)
        # Guard against an empty final_ranks on the very first scored sample.
        running_score = correct / total if total else 0.0
        pbar.set_description(f"{text} | score: {running_score:.4f} | {final_rank} | {final_ranks}")

    part_file = f"{dataset_name}_results_part{rank}_{id}.json"
    with open(part_file, "w") as f:
        f.write(json.dumps([total, correct]))

    # Make sure every rank has written its part file before rank 0 reads them.
    if world_size > 1:
        torch.distributed.barrier()

    if rank == 0:
        total = 0
        correct = 0
        print(f"evaluate on rank {rank}. world size is {world_size}")
        for rank_i in range(world_size):
            rank_file = f"{dataset_name}_results_part{rank_i}_{id}.json"
            # Close the file before removing it.
            with open(rank_file) as f:
                total_part, correct_part = json.load(f)
            os.remove(rank_file)
            total += total_part
            correct += correct_part
        # Avoid ZeroDivisionError (which would leave other ranks stuck at the
        # final barrier) when no sample was scored.
        score = correct / total if total else 0.0
        print("score:", score, "total:", total)
        # The result is recorded as an empty marker file whose name carries
        # the score; create the directory first so a long run is not lost.
        os.makedirs("eval_results", exist_ok=True)
        marker_name = f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"
        with open(os.path.join("eval_results", marker_name), "w") as f:
            pass
    else:
        score = 0.0

    if world_size > 1:
        torch.distributed.barrier()
    return score