|
|
|
|
|
import os |
|
import ast |
|
import json |
|
import openai |
|
import argparse |
|
from tqdm import tqdm |
|
from time import sleep |
|
from collections import defaultdict |
|
from multiprocessing.pool import Pool |
|
|
|
def parse_args():
    """Parse the command-line options for the GPT-based QA evaluation run.

    Returns:
        argparse.Namespace with pred_path, output_dir, output_json,
        num_tasks, num_chunks, api_key, api_type, api_version, api_base.
    """
    parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
    # Input / output locations.
    parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
    parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
    parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
    # Parallelism / sharding controls.
    parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
    parser.add_argument("--num_chunks", default=1, type=int, help="Result splits")
    # OpenAI client configuration.
    parser.add_argument("--api_key", required=True, type=str, help="OpenAI API key")
    parser.add_argument("--api_type", default=None, type=str, help="OpenAI API type")
    parser.add_argument("--api_version", default=None, type=str, help="OpenAI API version")
    parser.add_argument("--api_base", default=None, type=str, help="OpenAI API base")
    return parser.parse_args()
|
|
|
|
|
def annotate(prediction_set, caption_files, output_dir):
    """
    Evaluate question-answer pairs using the OpenAI chat completion API
    (gpt-3.5-turbo) and return a yes/no verdict plus an integer score.

    Args:
        prediction_set: dict mapping sample key -> {'q', 'a', 'pred', ...}.
        caption_files: list of "<key>.json" file names to process.
        output_dir: directory where per-sample result json files are written.

    For each file, asks the model to judge the predicted answer against the
    correct answer and writes [verdict_dict, qa_set] to
    "<output_dir>/<key>.json". Errors are printed and the item skipped; the
    retry loop in main() re-dispatches any files that never got written.
    """
    for file in tqdm(caption_files):
        key = file[:-5]  # strip the ".json" suffix to recover the sample key
        qa_set = prediction_set[key]
        question = qa_set['q']
        answer = qa_set['a']
        pred = qa_set['pred']
        try:
            completion = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=[
                    {
                        "role": "system",
                        "content":
                            "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
                            "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
                            "------"
                            "##INSTRUCTIONS: "
                            "- Focus on the meaningful match between the predicted answer and the correct answer.\n"
                            "- Consider synonyms or paraphrases as valid matches.\n"
                            "- Evaluate the correctness of the prediction compared to the answer."
                    },
                    {
                        "role": "user",
                        "content":
                            "Please evaluate the following video-based question-answer pair:\n\n"
                            f"Question: {question}\n"
                            f"Correct Answer: {answer}\n"
                            f"Predicted Answer: {pred}\n\n"
                            "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. "
                            # FIX: added the missing space before "DO NOT" — the two
                            # literals previously concatenated into "STRING.DO NOT".
                            "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING. "
                            "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
                            # FIX: the example previously showed a float score (4.8),
                            # contradicting the INTEGER instruction above and inviting
                            # float replies; use an integer example instead.
                            "For example, your response should look like this: {'pred': 'yes', 'score': 4}."
                    }
                ],
                temperature=0.002
            )

            # The model is instructed to reply with a Python-dict literal;
            # literal_eval parses it safely (no arbitrary code execution).
            response_message = completion["choices"][0]["message"]["content"]
            response_dict = ast.literal_eval(response_message)
            result_qa_pair = [response_dict, qa_set]

            # Persist one result file per sample; its existence marks the
            # sample as completed for the retry loop in main().
            with open(f"{output_dir}/{key}.json", "w") as f:
                json.dump(result_qa_pair, f)
            sleep(0.5)  # gentle rate limiting between API calls

        except Exception as e:
            # Deliberate best-effort handling: log and continue; files that
            # never get written are retried by the caller's outer loop.
            print(f"Error processing file '{key}': {e}")
            sleep(1)
|
|
|
|
|
def main():
    """
    Main function to control the flow of the program.

    Loads per-line-json predictions, gives every sample a unique id,
    annotates them in parallel with the OpenAI API (with a retry loop),
    combines the per-sample results into one json, and prints/saves a
    score report (overall and per answer type).
    """

    args = parse_args()

    # Load predictions: either sharded files "<num_chunks>_<idx>.json" or a
    # single "pred.json"; each file holds one JSON object per line (jsonl).
    if args.num_chunks > 1:
        pred_contents = []
        for _idx in range(args.num_chunks):
            file = os.path.join(args.pred_path, f"{args.num_chunks}_{_idx}.json")
            pred_contents += [json.loads(line) for line in open(file)]
    else:
        file = os.path.join(args.pred_path, f"pred.json")
        pred_contents = [json.loads(line) for line in open(file)]

    # Disambiguate duplicate video ids by appending an occurrence counter
    # so every sample gets a unique output-file key "<id>_<n>" (first
    # occurrence is "_0").
    video_id_counts = {}
    new_pred_contents = []

    for sample in pred_contents:
        video_id = sample['id']
        if video_id in video_id_counts:
            video_id_counts[video_id] += 1
        else:
            video_id_counts[video_id] = 0

        # NOTE(review): this aliases the original dict (no copy); pred_contents
        # entries are mutated in place, which is harmless here.
        new_sample = sample
        new_sample['id'] = f"{video_id}_{video_id_counts[video_id]}"
        new_pred_contents.append(new_sample)

    # One annotation file per sample will be written as "<id>.json".
    id_list = [x['id'] for x in new_pred_contents]
    caption_files = [f"{id}.json" for id in id_list]

    output_dir = args.output_dir

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Build the evaluation set: id -> {question, answer, prediction, answer type}.
    prediction_set = {}
    for sample in new_pred_contents:
        id = sample['id']
        question = sample['question']
        answer = sample['answer']
        pred = sample['pred']
        qa_set = {"q": question, "a": answer, "pred": pred, "a_type": sample['answer_type'] if 'answer_type' in sample else None}
        prediction_set[id] = qa_set

    # Configure the module-global OpenAI client from the CLI flags
    # (api_type/version/base are only set when provided, e.g. for Azure).
    openai.api_key = args.api_key
    if args.api_type:
        openai.api_type = args.api_type
    if args.api_version:
        openai.api_version = args.api_version
    if args.api_base:
        openai.api_base = args.api_base
    num_tasks = args.num_tasks

    # Retry loop: keep dispatching annotate() over the not-yet-written files
    # until all are done, progress stalls for 5 consecutive rounds, or 100
    # rounds elapse. Completion is tracked purely by output-file existence.
    incomplete_lengths = []
    for _ in range(100):
        try:

            completed_files = os.listdir(output_dir)
            print(f"completed_files: {len(completed_files)}")

            incomplete_files = [f for f in caption_files if f not in completed_files]
            print(f"incomplete_files: {len(incomplete_files)}")
            incomplete_lengths.append(len(incomplete_files))
            # Stall detection: the last 5 rounds left the same number of
            # incomplete files -> assume permanent failures and stop.
            if len(incomplete_lengths) > 5 and len(set(incomplete_lengths[-5:])) <= 1:
                print(f"incomplete_lengths: {incomplete_lengths}")
                print(f"incomplete_files: {incomplete_files}")
                print(f"completed_files: {completed_files}")
                print(f"failed for 5 times, break")
                break

            if len(incomplete_files) == 0:
                break
            # Avoid empty chunks when few files remain.
            if len(incomplete_files) <= num_tasks:
                num_tasks = 1

            # Split remaining files into roughly equal parts and annotate
            # them in parallel worker processes.
            part_len = len(incomplete_files) // num_tasks
            all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
            task_args = [(prediction_set, part, args.output_dir) for part in all_parts]

            with Pool() as pool:
                pool.starmap(annotate, task_args)

        except Exception as e:
            print(f"Error: {e}")

    # Combine every per-sample annotation file into one dict keyed by
    # sample id (file name minus ".json").
    combined_contents = {}
    json_path = args.output_json

    for file_name in os.listdir(output_dir):
        if file_name.endswith(".json"):
            file_path = os.path.join(output_dir, file_name)
            with open(file_path, "r") as json_file:
                content = json.load(json_file)
                # content is [verdict_dict, qa_set] as written by annotate().
                assert 'pred' in content[0], f"Error: {file_name} don't has key=pred"
                assert 'score' in content[0], f"Error: {file_name} don't has key=score"
                combined_contents[file_name[:-5]] = content

    with open(json_path, "w") as json_file:
        json.dump(combined_contents, json_file)
    print("All evaluation completed!")

    class ScoreMeter:
        # Accumulates yes/no verdicts and scores for one answer-type bucket.
        def __init__(self):
            self.score_sum = 0    # running sum of scores
            self.count = 0        # total samples recorded
            self.yes_count = 0    # samples judged 'yes'
            self.no_count = 0     # samples judged 'no'
            # Per-verdict histogram: score value -> occurrence count.
            self.score_dict = {'yes': defaultdict(int), 'no': defaultdict(int)}

        def add_score(self, score, pred):
            # Record one (score, yes/no-prediction) pair.
            self.score_sum += score
            self.count += 1
            pred_lower = pred.lower()
            # 'yes' is checked first, so a reply containing both substrings
            # counts as yes; replies with neither are counted only in `count`.
            if 'yes' in pred_lower:
                self.yes_count += 1
                self.score_dict['yes'][score] += 1
            elif 'no' in pred_lower:
                self.no_count += 1
                self.score_dict['no'][score] += 1

        def get_average_score(self):
            # Mean score over all recorded samples, as a 6-decimal string.
            res = (self.score_sum / self.count) if self.count else 0
            return f"{res:.6f}"

        def get_accuracy(self, response_type):
            # Fraction of 'yes' (or 'no') verdicts, as a 6-decimal string.
            if response_type == 'yes':
                res = (self.yes_count / self.count) if self.count else 0
            elif response_type == 'no':
                res = (self.no_count / self.count) if self.count else 0
            else:
                res = 0
            return f"{res:.6f}"

    # Aggregate scores: overall, and per answer type when available.
    meter_dic = {'total': ScoreMeter()}
    for key, result in combined_contents.items():

        score_match = result[0]['score']
        score = int(score_match)  # truncates a float score to int
        pred = result[0]['pred']

        meter_dic["total"].add_score(score, pred)
        if 'a_type' in result[1] and result[1]['a_type'] is not None:
            typ = str(result[1]['a_type'])
            if typ not in meter_dic:
                meter_dic[typ] = ScoreMeter()
            meter_dic[typ].add_score(score, pred)

            # NEXT-QA style runs: additionally bucket by the first character
            # of the answer-type string (its coarse category).
            if 'next' in args.output_dir:
                typ = typ[0]
                if typ not in meter_dic:
                    meter_dic[typ] = ScoreMeter()
                meter_dic[typ].add_score(score, pred)

    # Summary row for the CSV: overall accuracy + average score, later
    # extended with one accuracy column per answer type.
    csv_dic = {'acc': meter_dic["total"].get_accuracy('yes'), 'score': meter_dic["total"].get_average_score()}

    # Build the human-readable report.
    output = ""
    output += "Yes count: " + str(meter_dic["total"].yes_count) + "\n"
    output += "No count: " + str(meter_dic["total"].no_count) + "\n"
    output += "Accuracy: " + str(meter_dic["total"].get_accuracy('yes')) + "\n"
    output += "Average score: " + str(meter_dic["total"].get_average_score()) + "\n"
    output += "\n"
    output += "Total Score Yes/No distribution:\n"
    for key, value in meter_dic["total"].score_dict.items():
        output += f"{key}:\n"
        # Scores run 0..5 inclusive; defaultdict yields 0 for absent scores.
        for k in range(0, 6):
            v = value[k]
            output += f"{k}: {v}\n"
    output += "\n"
    output += "Answer Type Score distribution:\n"
    output += 'Type, Accuracy, Avg_score\n'
    key_list = sorted([k for k in meter_dic.keys()])
    for key in key_list:
        output += f"{key}, {meter_dic[key].get_accuracy('yes')}, {meter_dic[key].get_average_score()}\n"
        csv_dic[key] = meter_dic[key].get_accuracy('yes')

    output += "\n"
    # CSV header row ...
    for k in csv_dic.keys():
        output += f"{k}, "
    output = output.rstrip(', ')
    output += "\n"

    # ... and the matching value row.
    for k in csv_dic.keys():
        output += str(csv_dic[k]) + ", "
    output = output.rstrip(', ')
    output += "\n"

    print(output)
    # Write the report next to the combined json, with a .csv extension.
    args.output_csv = args.output_json.replace(".json", ".csv")
    with open(args.output_csv, 'w') as f:
        f.write(output)
|
|
|
# Script entry point: run the full evaluation pipeline.
if __name__ == "__main__":
    main()
|
|
|
|