File size: 5,774 Bytes
6a83074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import argparse
import json
import os
import re
import random
from collections import defaultdict


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--base-dir', type=str)
    parser.add_argument('--gpt4-result', type=str)
    parser.add_argument('--requery-result', type=str)
    parser.add_argument('--our-result', type=str)
    parser.add_argument('--output-result', type=str)
    parser.add_argument('--split', type=str, default='test')
    parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
    return parser.parse_args()


def convert_caps(results):
    fakecaps = []
    for result in results:
        image_id = result['question_id']
        caption = result['text']
        fakecaps.append({"image_id": int(image_id), "caption": caption})
    return fakecaps


def get_pred_idx(prediction, choices, options):
    """
    Get the index (e.g. 2) from the prediction (e.g. 'C')
    """
    if prediction in options[:len(choices)]:
        return options.index(prediction)
    else:
        return random.choice(range(len(choices)))


if __name__ == "__main__":
    args = get_args()

    base_dir = args.base_dir
    split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
    problems = json.load(open(os.path.join(base_dir, "problems.json")))
    our_predictions = [json.loads(line) for line in open(args.our_result)]
    our_predictions = {pred['question_id']: pred for pred in our_predictions}
    split_problems = {idx: problems[idx] for idx in split_indices}

    requery_predictions = [json.loads(line) for line in open(args.requery_result)]
    requery_predictions = {pred['question_id']: pred for pred in requery_predictions}

    gpt4_predictions = json.load(open(args.gpt4_result))['outputs']

    results = defaultdict(lambda: 0)

    sqa_results = {}
    sqa_results['acc'] = None
    sqa_results['correct'] = None
    sqa_results['count'] = None
    sqa_results['results'] = {}
    sqa_results['outputs'] = {}

    for prob_id, prob in split_problems.items():
        if prob_id not in our_predictions:
            assert False
        if prob_id not in gpt4_predictions:
            assert False
        our_pred = our_predictions[prob_id]['text']
        gpt4_pred = gpt4_predictions[prob_id]
        if prob_id not in requery_predictions:
            results['missing_requery'] += 1
            requery_pred = "MISSING"
        else:
            requery_pred = requery_predictions[prob_id]['text']

        pattern = re.compile(r'The answer is ([A-Z]).')
        our_res = pattern.findall(our_pred)
        if len(our_res) == 1:
            our_answer = our_res[0]  # 'A', 'B', ...
        else:
            our_answer = "FAILED"

        requery_res = pattern.findall(requery_pred)
        if len(requery_res) == 1:
            requery_answer = requery_res[0]  # 'A', 'B', ...
        else:
            requery_answer = "FAILED"

        gpt4_res = pattern.findall(gpt4_pred)
        if len(gpt4_res) == 1:
            gpt4_answer = gpt4_res[0]  # 'A', 'B', ...
        else:
            gpt4_answer = "FAILED"

        our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
        gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
        requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options)

        results['total'] += 1

        if gpt4_answer == 'FAILED':
            results['gpt4_failed'] += 1
            if gpt4_pred_idx == prob['answer']:
                results['gpt4_correct'] += 1
            if our_pred_idx == prob['answer']:
                results['gpt4_ourvisual_correct'] += 1
        elif gpt4_pred_idx == prob['answer']:
            results['gpt4_correct'] += 1
            results['gpt4_ourvisual_correct'] += 1

        if our_pred_idx == prob['answer']:
            results['our_correct'] += 1

        if requery_answer == 'FAILED':
            sqa_results['results'][prob_id] = our_pred_idx
            if our_pred_idx == prob['answer']:
                results['requery_correct'] += 1
        else:
            sqa_results['results'][prob_id] = requery_pred_idx
            if requery_pred_idx == prob['answer']:
                results['requery_correct'] += 1
            else:
                print(f"""
Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']}
Our ({our_answer}): {our_pred}
GPT-4 ({gpt4_answer}): {gpt4_pred}
Requery ({requery_answer}): {requery_pred}
print("=====================================")
""")

        if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
            results['correct_upperbound'] += 1

    total = results['total']
    print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%')
    print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%')
    print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
    print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%')
    print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%')
    print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')

    sqa_results['acc'] = results["requery_correct"] / total * 100
    sqa_results['correct'] = results["requery_correct"]
    sqa_results['count'] = total

    with open(args.output_result, 'w') as f:
        json.dump(sqa_results, f, indent=2)