# Based on https://github.com/haotian-liu/LLaVA.
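# Evaluates open-ended video question-answer predictions: gpt-3.5-turbo is asked for a yes/no
# correctness judgement and a 0-5 score per sample, and the results are aggregated into
# accuracy and average-score reports (overall and per answer type) as json and csv files.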
import os
import ast
import json
import openai
import argparse
from tqdm import tqdm
from time import sleep
from collections import defaultdict
from multiprocessing.pool import Pool
def parse_args():
parser = argparse.ArgumentParser(description="question-answer-evaluation-using-gpt-3.5")
parser.add_argument("--pred_path", required=True, help="The path to the directory containing the prediction file(s).")
parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
parser.add_argument("--num_chunks", default=1, type=int, help="Result splits")
parser.add_argument("--api_key", required=True, type=str, help="OpenAI API key")
parser.add_argument("--api_type", default=None, type=str, help="OpenAI API type")
parser.add_argument("--api_version", default=None, type=str, help="OpenAI API version")
parser.add_argument("--api_base", default=None, type=str, help="OpenAI API base")
args = parser.parse_args()
return args
def annotate(prediction_set, caption_files, output_dir):
"""
Evaluates question-answer pairs with gpt-3.5-turbo and writes each result
(a yes/no judgement and a 0-5 correctness score) to a per-sample json file in output_dir.
"""
for file in tqdm(caption_files):
key = file[:-5] # Strip file extension
qa_set = prediction_set[key]
question = qa_set['q']
answer = qa_set['a']
pred = qa_set['pred']
try:
# Compute the correctness score
completion = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "system",
"content":
"You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
"Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
"------"
"##INSTRUCTIONS: "
"- Focus on the meaningful match between the predicted answer and the correct answer.\n"
"- Consider synonyms or paraphrases as valid matches.\n"
"- Evaluate the correctness of the prediction compared to the answer."
},
{
"role": "user",
"content":
"Please evaluate the following video-based question-answer pair:\n\n"
f"Question: {question}\n"
f"Correct Answer: {answer}\n"
f"Predicted Answer: {pred}\n\n"
"Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. "
"Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."
"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
"For example, your response should look like this: {'pred': 'yes', 'score': 4.8}."
}
],
temperature=0.002
)
# Convert response to a Python dictionary.
response_message = completion["choices"][0]["message"]["content"]
response_dict = ast.literal_eval(response_message)
result_qa_pair = [response_dict, qa_set]
# Save the question-answer pairs to a json file.
with open(f"{output_dir}/{key}.json", "w") as f:
json.dump(result_qa_pair, f)
sleep(0.5)
except Exception as e:
print(f"Error processing file '{key}': {e}")
sleep(1)
def main():
"""
Main function to control the flow of the program.
"""
# Parse arguments.
args = parse_args()
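# Load predictions. Each prediction file is in JSON Lines format (one JSON object per line),
# either a single pred.json or num_chunks files named "{num_chunks}_{idx}.json".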
if args.num_chunks > 1:
pred_contents = []
for _idx in range(args.num_chunks):
file = os.path.join(args.pred_path, f"{args.num_chunks}_{_idx}.json")
pred_contents += [json.loads(line) for line in open(file)]
else:
file = os.path.join(args.pred_path, f"pred.json")
pred_contents = [json.loads(line) for line in open(file)]
# Dictionary storing a running occurrence index for each video_id, used to make sample ids unique
video_id_counts = {}
new_pred_contents = []
# Iterate through each sample in pred_contents
for sample in pred_contents:
video_id = sample['id']
if video_id in video_id_counts:
video_id_counts[video_id] += 1
else:
video_id_counts[video_id] = 0
# Make the sample id unique by appending the per-video occurrence index
new_sample = sample
new_sample['id'] = f"{video_id}_{video_id_counts[video_id]}"
new_pred_contents.append(new_sample)
# Generate the list of sample ids and the corresponding per-sample result file names
id_list = [x['id'] for x in new_pred_contents]
caption_files = [f"{id}.json" for id in id_list]
output_dir = args.output_dir
# Create the output directory if it does not exist.
os.makedirs(output_dir, exist_ok=True)
# Preparing dictionary of question-answer sets
prediction_set = {}
for sample in new_pred_contents:
sample_id = sample['id']
question = sample['question']
answer = sample['answer']
pred = sample['pred']
qa_set = {"q": question, "a": answer, "pred": pred, "a_type": sample.get('answer_type')}
prediction_set[sample_id] = qa_set
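# prediction_set maps each unique sample id to its question, ground-truth answer,
# model prediction, and (optional) answer type.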
# Configure the OpenAI client from the command-line arguments.
openai.api_key = args.api_key
if args.api_type:
openai.api_type = args.api_type
if args.api_version:
openai.api_version = args.api_version
if args.api_base:
openai.api_base = args.api_base
num_tasks = args.num_tasks
# Retry loop (at most 100 rounds) to keep re-submitting samples until all are processed.
incomplete_lengths = []
for _ in range(100):
try:
# Files that have already been processed.
completed_files = os.listdir(output_dir)
print(f"completed_files: {len(completed_files)}")
# Files that have not been processed yet.
incomplete_files = [f for f in caption_files if f not in completed_files]
print(f"incomplete_files: {len(incomplete_files)}")
incomplete_lengths.append(len(incomplete_files))
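# Stop early if the number of incomplete files has not changed for 5 consecutive rounds,
# i.e. the remaining API calls keep failing and no progress is being made.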
if len(incomplete_lengths) > 5 and len(set(incomplete_lengths[-5:])) <= 1:
print(f"incomplete_lengths: {incomplete_lengths}")
print(f"incomplete_files: {incomplete_files}")
print(f"completed_files: {completed_files}")
print(f"failed for 5 times, break")
break
# Break the loop when there are no incomplete files
if len(incomplete_files) == 0:
break
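# If only a few files remain, run them in a single task.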
if len(incomplete_files) <= num_tasks:
num_tasks = 1
# Split tasks into parts.
part_len = len(incomplete_files) // num_tasks
all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
task_args = [(prediction_set, part, args.output_dir) for part in all_parts]
# Use a pool of workers to process the files in parallel.
with Pool() as pool:
pool.starmap(annotate, task_args)
except Exception as e:
print(f"Error: {e}")
# Combine all the processed files into one
combined_contents = {}
json_path = args.output_json
# Iterate through json files
for file_name in os.listdir(output_dir):
if file_name.endswith(".json"):
file_path = os.path.join(output_dir, file_name)
with open(file_path, "r") as json_file:
content = json.load(json_file)
assert 'pred' in content[0], f"Error: {file_name} does not have key 'pred'"
assert 'score' in content[0], f"Error: {file_name} does not have key 'score'"
combined_contents[file_name[:-5]] = content
# Write combined content to a json file
with open(json_path, "w") as json_file:
json.dump(combined_contents, json_file)
print("All evaluation completed!")
class ScoreMeter:
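"""Accumulates yes/no judgement counts and score sums, overall and per integer score value."""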
def __init__(self):
self.score_sum = 0
self.count = 0
self.yes_count = 0
self.no_count = 0
self.score_dict = {'yes': defaultdict(int), 'no': defaultdict(int)}
def add_score(self, score, pred):
self.score_sum += score
self.count += 1
pred_lower = pred.lower()
if 'yes' in pred_lower:
self.yes_count += 1
self.score_dict['yes'][score] += 1
elif 'no' in pred_lower:
self.no_count += 1
self.score_dict['no'][score] += 1
def get_average_score(self):
res = (self.score_sum / self.count) if self.count else 0
return f"{res:.6f}"
def get_accuracy(self, response_type):
if response_type == 'yes':
res = (self.yes_count / self.count) if self.count else 0
elif response_type == 'no':
res = (self.no_count / self.count) if self.count else 0
else:
res = 0
return f"{res:.6f}"
meter_dic = {'total': ScoreMeter()}
for key, result in combined_contents.items():
# Extract the yes/no judgement and integer score from the GPT evaluation result.
score = int(result[0]['score'])
pred = result[0]['pred']
meter_dic["total"].add_score(score, pred)
if 'a_type' in result[1] and result[1]['a_type'] is not None:
typ = str(result[1]['a_type'])
if typ not in meter_dic:
meter_dic[typ] = ScoreMeter()
meter_dic[typ].add_score(score, pred)
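# For outputs under a 'next' directory (presumably NExT-QA), also aggregate by the first
# character of the answer type, which groups fine-grained types into coarse categories.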
if 'next' in args.output_dir:
typ = typ[0]
if typ not in meter_dic:
meter_dic[typ] = ScoreMeter()
meter_dic[typ].add_score(score, pred)
csv_dic = {'acc': meter_dic["total"].get_accuracy('yes'), 'score': meter_dic["total"].get_average_score()}
output = ""
output += "Yes count: " + str(meter_dic["total"].yes_count) + "\n"
output += "No count: " + str(meter_dic["total"].no_count) + "\n"
output += "Accuracy: " + str(meter_dic["total"].get_accuracy('yes')) + "\n"
output += "Average score: " + str(meter_dic["total"].get_average_score()) + "\n"
output += "\n"
output += "Total Score Yes/No distribution:\n"
for key, value in meter_dic["total"].score_dict.items():
output += f"{key}:\n"
for k in range(0, 6):
v = value[k]
output += f"{k}: {v}\n"
output += "\n"
output += "Answer Type Score distribution:\n"
output += 'Type, Accuracy, Avg_score\n'
key_list = sorted(meter_dic.keys())
for key in key_list:
output += f"{key}, {meter_dic[key].get_accuracy('yes')}, {meter_dic[key].get_average_score()}\n"
csv_dic[key] = meter_dic[key].get_accuracy('yes')
output += "\n"
for k in csv_dic.keys():
output += f"{k}, "
output = output.rstrip(', ') # Remove the trailing comma and space
output += "\n"
for k in csv_dic.keys():
output += str(csv_dic[k]) + ", "
output = output.rstrip(', ') # Remove the trailing comma and space
output += "\n"
print(output)
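# Write the same summary to a csv file alongside the combined json report.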
args.output_csv = args.output_json.replace(".json", ".csv")
with open(args.output_csv, 'w') as f:
f.write(output)
if __name__ == "__main__":
main()