suyccc
commited on
Commit
·
ab49e08
1
Parent(s):
9f1b525
Add main file
Browse files- README.md +1 -2
- data/answer.enc +0 -0
- main.py +169 -0
README.md
CHANGED
@@ -5,10 +5,9 @@ colorFrom: red
|
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.9.1
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
short_description: This is a automated evaluation for VMCBench test and dev set
|
12 |
---
|
13 |
-
|
14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.9.1
|
8 |
+
app_file: main.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
short_description: This is a automated evaluation for VMCBench test and dev set
|
12 |
---
|
|
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
data/answer.enc
ADDED
The diff for this file is too large to render.
See raw diff
|
|
main.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import numpy as np
|
7 |
+
from cryptography.fernet import Fernet
|
8 |
+
|
9 |
+
random.seed(0)
|
10 |
+
|
11 |
+
# Helper function to load and decrypt the encrypted answer.json
|
12 |
+
def load_and_decrypt_answer(secret_key):
|
13 |
+
try:
|
14 |
+
# Read encrypted answer file
|
15 |
+
with open("data/answer.enc", "rb") as enc_file:
|
16 |
+
encrypted_data = enc_file.read()
|
17 |
+
|
18 |
+
# Initialize Fernet cipher with the secret key
|
19 |
+
cipher = Fernet(secret_key.encode())
|
20 |
+
|
21 |
+
# Decrypt the file
|
22 |
+
decrypted_data = cipher.decrypt(encrypted_data).decode("utf-8")
|
23 |
+
|
24 |
+
# Parse JSON
|
25 |
+
return json.loads(decrypted_data)
|
26 |
+
except Exception as e:
|
27 |
+
raise ValueError(f"Failed to decrypt answer file: {str(e)}")
|
28 |
+
|
29 |
+
def parse_multi_choice_response(response, all_choices, index2ans):
|
30 |
+
# (Code unchanged)
|
31 |
+
response = str(response)
|
32 |
+
for char in [',', '.', '!', '?', ';', ':', "'"]:
|
33 |
+
response = response.strip(char)
|
34 |
+
response = " " + response + " " # add space to avoid partial match
|
35 |
+
|
36 |
+
index_ans = True
|
37 |
+
ans_with_brack = False
|
38 |
+
candidates = []
|
39 |
+
for choice in all_choices: # e.g., (A) (B) (C) (D)
|
40 |
+
if f'({choice})' in response or f'{choice}. ' in response:
|
41 |
+
candidates.append(choice)
|
42 |
+
ans_with_brack = True
|
43 |
+
|
44 |
+
if len(candidates) == 0:
|
45 |
+
for choice in all_choices: # e.g., A B C D
|
46 |
+
if f' {choice} ' in response:
|
47 |
+
candidates.append(choice)
|
48 |
+
|
49 |
+
if len(candidates) == 0 and len(response.split()) > 5:
|
50 |
+
for index, ans in index2ans.items():
|
51 |
+
if ans.lower() in response.lower():
|
52 |
+
candidates.append(index)
|
53 |
+
index_ans = False
|
54 |
+
|
55 |
+
if len(candidates) == 0:
|
56 |
+
pred_index = random.choice(all_choices)
|
57 |
+
elif len(candidates) > 1:
|
58 |
+
start_indexes = []
|
59 |
+
if index_ans:
|
60 |
+
if ans_with_brack:
|
61 |
+
for can in candidates:
|
62 |
+
index = response.rfind(f'({can})')
|
63 |
+
start_indexes.append(index)
|
64 |
+
else:
|
65 |
+
for can in candidates:
|
66 |
+
index = response.rfind(f" {can} ")
|
67 |
+
start_indexes.append(index)
|
68 |
+
else:
|
69 |
+
for can in candidates:
|
70 |
+
index = response.lower().rfind(index2ans[can].lower())
|
71 |
+
start_indexes.append(index)
|
72 |
+
pred_index = candidates[np.argmax(start_indexes)]
|
73 |
+
else:
|
74 |
+
pred_index = candidates[0]
|
75 |
+
return pred_index
|
76 |
+
|
77 |
+
def get_mc_score(row, use_parse = True):
|
78 |
+
if use_parse:
|
79 |
+
if pd.isna(row["A"]):
|
80 |
+
return False
|
81 |
+
response = row["prediction"]
|
82 |
+
all_choices = []
|
83 |
+
for i in range(9):
|
84 |
+
if chr(65+i) in row and pd.isna(row[chr(65+i)])== False:
|
85 |
+
all_choices.append(chr(65+i))
|
86 |
+
index2ans = {index: row[index] for index in all_choices}
|
87 |
+
pred_index = parse_multi_choice_response(response, all_choices, index2ans)
|
88 |
+
else:
|
89 |
+
pred_index = row["output"]
|
90 |
+
return pred_index == row["answer"]
|
91 |
+
|
92 |
+
def process_json(file):
|
93 |
+
try:
|
94 |
+
data = json.load(open(file))
|
95 |
+
except json.JSONDecodeError:
|
96 |
+
return "Error: Invalid JSON format. Please upload a valid JSON file."
|
97 |
+
|
98 |
+
if not isinstance(data, list):
|
99 |
+
return "Error: JSON must be a list of records."
|
100 |
+
|
101 |
+
required_fields = ['index', 'prediction']
|
102 |
+
for record in data:
|
103 |
+
if not all(field in record for field in required_fields):
|
104 |
+
return f"Error: Each record must contain the following fields: {', '.join(required_fields)}"
|
105 |
+
|
106 |
+
# Decrypt answer.json
|
107 |
+
try:
|
108 |
+
secret_key = os.getenv("SECRET_KEY")
|
109 |
+
answer_data = load_and_decrypt_answer(secret_key)
|
110 |
+
except ValueError as e:
|
111 |
+
return str(e)
|
112 |
+
|
113 |
+
# Convert to DataFrame
|
114 |
+
df = pd.DataFrame(data)
|
115 |
+
df = df[['index', 'prediction']]
|
116 |
+
answer_df = pd.DataFrame(answer_data)
|
117 |
+
df = df.merge(answer_df, on="index", how="left")
|
118 |
+
|
119 |
+
# Example categories
|
120 |
+
general_datasets = ["SEEDBench", "MMStar", "A-OKVQA", "VizWiz", "MMVet",
|
121 |
+
"VQAv2", "OKVQA"]
|
122 |
+
reason_datasets = ["MMMU", "MathVista", "ScienceQA", "RealWorldQA", "GQA", "MathVision"]
|
123 |
+
ocr_datasets = ["TextVQA", "OCRVQA"]
|
124 |
+
doc_datasets = ["AI2D", "ChartQA","DocVQA", "InfoVQA", "TableVQABench"]
|
125 |
+
try:
|
126 |
+
score = df.apply(get_mc_score, axis=1) * 100
|
127 |
+
df['score'] = score.round(2)
|
128 |
+
except Exception as e:
|
129 |
+
return f"Error during scoring: {str(e)}"
|
130 |
+
|
131 |
+
# Calculate metrics for each category
|
132 |
+
results = {}
|
133 |
+
for category in df['category'].unique():
|
134 |
+
category_df = df[df['category'] == category]
|
135 |
+
category_result = category_df['score'].mean()
|
136 |
+
results[category] = category_result
|
137 |
+
results['General'] = np.array([results[category] for category in general_datasets]).mean()
|
138 |
+
results['Reasoning'] = np.array([results[category] for category in reason_datasets]).mean()
|
139 |
+
results['OCR'] = np.array([results[category] for category in ocr_datasets]).mean()
|
140 |
+
results['Doc & Chart'] = np.array([results[category] for category in doc_datasets]).mean()
|
141 |
+
results['Overall'] = np.array([results[category] for category in df['category'].unique()]).mean()
|
142 |
+
|
143 |
+
return json.dumps(results, indent=4)
|
144 |
+
|
145 |
+
def main_gradio():
|
146 |
+
example_json = '''[
|
147 |
+
{
|
148 |
+
"index": 1,
|
149 |
+
"prediction": "A"
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"index": 2,
|
153 |
+
"prediction": "The answer is C. cat"
|
154 |
+
}
|
155 |
+
]'''
|
156 |
+
|
157 |
+
interface = gr.Interface(
|
158 |
+
fn=process_json,
|
159 |
+
inputs=gr.File(label="Upload JSON File"),
|
160 |
+
outputs=gr.Textbox(label="Evaluation Results", interactive=False),
|
161 |
+
title="Automated Evaluation for VMCBench",
|
162 |
+
description=f"Upload a JSON file containing question index and model prediction to evaluate the performance.\n\n"
|
163 |
+
f"Example JSON format:\n\n{example_json}\n\n"
|
164 |
+
"Each record should contain the fields: 'index', 'prediction'."
|
165 |
+
)
|
166 |
+
interface.launch(share=True)
|
167 |
+
|
168 |
+
if __name__ == "__main__":
|
169 |
+
main_gradio()
|