import json
import os
import time

import numpy as np
import pandas as pd
import torch
from PIL import Image
from tqdm import tqdm

from .vl_checklist import _eval_text_image
DATASET_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/instruct_data/crepe/prod_hard_negatives"


def evaluate_crepe(
    model,
    tokenizer,
    image_processor,
    vis_embed_size=None,
    rank=0,
    world_size=1,
    id=0,
    subset=True,
    debug=False,
    level=4,
    type="swap",
):
    """Evaluate `model` on the CREPE productivity hard-negatives split.

    Each caption is scored with `_eval_text_image`; a prediction counts as
    correct when its rank is below 10. Partial counts are written per rank
    and aggregated on rank 0, which returns the final score (other ranks
    return 0.0).
    """
    if rank == 0:
        tqdm.write(f"level: {level}")
        tqdm.write(f"type: {type}")
    dataset_name = "crepe"
    # Only the media and prebox special tokens are needed by this evaluator.
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
    model.eval().cuda()
    total = 0
    correct = 0
    assert type in ["swap"]
    assert 4 <= level <= 12
    filename = os.path.join(DATASET_ROOT, type, f"prod_vg_hard_negs_{type}_complexity_{level}.csv")
    df = pd.read_csv(filename)
    pbar = tqdm(df.iterrows(), total=len(df), disable=(rank != 0))
    for ii, sample in pbar:
        # Shard rows round-robin across ranks.
        if ii % world_size != rank:
            continue
        text = sample.caption
        image_path = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/VG_100K/{}.jpg".format(sample.image_id)
        x = sample.x
        y = sample.y
        width = sample.width
        height = sample.height
        image = Image.open(image_path).convert("RGB")
        # Crop to the annotated region, then resize to the model input size.
        image = image.crop((x, y, x + width, y + height))
        image = image.resize((224, 224))
        final_rank, final_ranks = _eval_text_image(text, image, model, tokenizer, image_processor, vis_embed_size, media_token_id, prebox_token_id, debug=debug)
        if final_rank is None:
            continue
        # A caption counts as correct when its rank is below 10.
        correct += int((np.array(final_ranks) < 10).sum())
        total += len(final_ranks)
        if debug:
            tqdm.write("="*80)
        pbar.set_description(f"{text} | score: {correct / total:.4f} | {final_rank} | {final_ranks}")


    # Each rank writes its partial [total, correct] counts to a temp file.
    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
        json.dump([total, correct], f)
    if world_size > 1:
        torch.distributed.barrier()
    if rank == 0:
        total = 0
        correct = 0
        print(f"aggregating results on rank {rank}; world size is {world_size}")
        for rank_i in range(world_size):
            part_file = f"{dataset_name}_results_part{rank_i}_{id}.json"
            with open(part_file) as f:
                total_part, correct_part = json.load(f)
            os.remove(part_file)
            total += total_part
            correct += correct_part
        score = correct / total if total > 0 else 0.0
        print("score:", score, "total:", total)
        os.makedirs("eval_results", exist_ok=True)
        # Touch an empty marker file whose name records the run and score.
        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
            pass
    else:
        score = 0.0
    if world_size > 1:
        torch.distributed.barrier()
    return score
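

# A standalone, testable form of the per-rank aggregation performed above.
# This is a hypothetical helper (not part of the original evaluator), assuming
# each input file holds a JSON list [total, correct] as written by each rank.
def _merge_partial_results(paths):
    import json  # local import so the sketch is self-contained
    total = correct = 0
    for path in paths:
        with open(path) as f:
            total_part, correct_part = json.load(f)
        total += total_part
        correct += correct_part
    # Guard against an empty evaluation (no counted samples).
    return correct / total if total else 0.0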