"""Refined CHAIR evaluation for video captions: GPT answers yes/no/maybe
questions derived from parsed caption info to measure coverage (ground-truth
content mentioned by the prediction) and hallucination (predicted content not
supported by the ground truth)."""
import os
import re
import json
import argparse

from collections import defaultdict
from copy import deepcopy
from multiprocessing import Pool
from pathlib import Path

from tqdm import tqdm

from llava.eval.masp_eval.utils import GPTAPIWrapper

system_prompt = "I am ChatGPT, a virtual assistant based on OpenAI's GPT-4 model. I'm designed to understand and generate human-like text based on the input I receive. My main purpose is to assist with information, answer questions, help with tasks that involve natural language processing, and engage in conversations with users. Please note that while I aim to provide accurate and reliable information, I can't guarantee perfection, and it's always a good idea to consult additional resources or professionals when making critical decisions based on the information I provide."

with open('llava/eval/masp_eval/video_chair/prompts/cap_mention.txt', 'r') as file:
    cap_user_prompt = file.read()

# Read the access key from the environment rather than hardcoding a credential;
# the variable name GPT_API_AK is a placeholder, not something the wrapper mandates.
openai_obj = GPTAPIWrapper(ak=os.environ.get('GPT_API_AK'))


def _add(case_res, all_res):
    """Accumulate per-case [hit, total] counts into the running totals."""
    for key, value in case_res.items():
        for idx, count_ in enumerate(value):
            all_res[key][idx] += count_


def save_metric(coverage, hallucination, case_len, output_dir=None):
    """Convert raw [hit, total] counters into percentages and report them."""
    final_metrics = {}
    for name, res in [['coverage', coverage], ['hallucination', hallucination]]:
        combine_counter = [0, 0]
        for cat, counter in res.items():
            final_metrics[name + '_' + cat] = round(counter[0] * 100 / counter[1], 2)
            combine_counter[0] += counter[0]
            combine_counter[1] += counter[1]
            # Hallucination is reported as the share of unsupported content,
            # so invert the "supported" percentage.
            if name == 'hallucination':
                final_metrics[name + '_' + cat] = round(100 - final_metrics[name + '_' + cat], 2)
        final_metrics[name] = round(combine_counter[0] * 100 / combine_counter[1], 2)
        if name == 'hallucination':
            final_metrics[name] = round(100 - final_metrics[name], 2)
    final_metrics['avg_len'] = round(sum(case_len) / len(case_len), 1)

    if output_dir is not None:
        with (output_dir / 'chair_metric_neg.json').open('w') as f:
            json.dump(final_metrics, f, indent=4)

    print(json.dumps(final_metrics, indent=1))
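
# For illustration, the saved metrics take a form like this (category names
# come from format_question; the numbers are made up):
#   {"coverage_subjects": 78.5, "coverage": 80.1,
#    "hallucination_subjects": 12.3, "hallucination": 10.9, "avg_len": 57.0}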


def combine_info(pred_info, gt_info):
    """Join predictions with ground truth on their shared id, keeping only
    cases where both sides have parsed caption info."""
    combined_info = defaultdict(dict)
    # GT files may key cases by either 'object_id' or 'task_id'.
    id_key = 'object_id' if 'object_id' in gt_info[0] else 'task_id'
    for gt in gt_info:
        object_id = gt[id_key]
        if gt['cap_info'] is None:
            continue
        combined_info[object_id]['gt_caption'] = gt['refine_caption']
        combined_info[object_id]['gt_info'] = gt['cap_info']

    for pred in pred_info:
        object_id = pred[id_key]
        if object_id not in combined_info:
            continue
        if pred['cap_info'] is None:
            continue
        combined_info[object_id]['pred_caption'] = pred['masp_inference']
        combined_info[object_id]['pred_info'] = pred['cap_info']

    # Drop cases that are missing either the prediction or the ground truth.
    filtered_ids = [key for key, value in combined_info.items()
                    if 'pred_info' not in value or 'gt_info' not in value]
    for obj_id in filtered_ids:
        del combined_info[obj_id]

    print(f'evaluation cases: {len(combined_info)}')

    return combined_info
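
# Each surviving entry pairs both sides of one case, i.e.:
#   combined_info[object_id] == {'gt_caption': str, 'gt_info': dict,
#                                'pred_caption': str, 'pred_info': dict}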


def format_question(info):
    """Render parsed caption info as a numbered list of verification questions,
    plus a mapping from question index back to the originating item."""
    categories = ['subjects', 'activities', 'locations', 'text_overlays']
    question_id = 0
    question_mapping = {}
    questions = []
    for cat in categories:
        if cat == 'subjects':
            for c_id, character_content in enumerate(info['subjects']):
                questions.append(cat + ':' + character_content['name'])
                question_mapping[question_id] = (cat, c_id)
                question_id += 1
                if 'attributes' not in character_content:
                    continue
                # Each subject attribute becomes its own question.
                for a_id, attr in enumerate(character_content['attributes']):
                    questions.append(character_content['name'] + ':' + attr)
                    question_mapping[question_id] = ('attributes', c_id, a_id)
                    question_id += 1
        else:
            for c_id, cat_attr in enumerate(info[cat]):
                questions.append(cat + ':' + cat_attr)
                question_mapping[question_id] = (cat, c_id)
                question_id += 1

    question_str = ''
    for idx, q in enumerate(questions):
        question_str += f'{idx + 1}. {q}' + '\n'

    return question_str, question_mapping
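
# For a hypothetical info dict with one subject "woman" (attribute "red dress")
# and one activity "dancing", question_str would read:
#   1. subjects:woman
#   2. woman:red dress
#   3. activities:dancing
# and question_mapping would be
#   {0: ('subjects', 0), 1: ('attributes', 0, 0), 2: ('activities', 0)}.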


def parsing_results(gpt_ret, question_mapping):
    """Parse GPT's numbered yes/no/maybe answers into per-category
    [hit, total] counters."""
    gpt_ret = gpt_ret.lower()
    # Expected answer line format: "<number>.<question> - <yes|no|maybe>,<reason>"
    pattern = r'(\d+)\.(.+) - (yes|no|maybe),(.+)'
    matches = re.findall(pattern, gpt_ret)
    collected_answer = defaultdict(lambda: [0, 0])

    for match in matches:
        question_id, question, answer, reason = match
        question_id = int(question_id) - 1
        cat = question_mapping[question_id][0]
        collected_answer[cat][1] += 1
        if 'yes' in answer:
            collected_answer[cat][0] += 1
        elif 'no' in answer:
            pass
        elif 'maybe' in answer:
            # 'maybe' is counted as a hit, the same as 'yes'.
            collected_answer[cat][0] += 1
        else:
            raise NotImplementedError(f'unexpected answer: {answer}')
    return collected_answer
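
# A well-formed answer line from GPT looks like, e.g. (hypothetical content):
#   "3. locations:beach - yes, the caption clearly mentions a beach"
# Lines that do not match the pattern above are silently skipped.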


def process_coverage(data):
    """Ask GPT whether the predicted caption covers each ground-truth item."""
    object_id, case_info = data
    gt_info = case_info['gt_info']

    try:
        question_str, question_mapping = format_question(gt_info)
    except Exception as e:
        print(e)
        return None
    # Fill the prompt template's placeholders with this case's data.
    user_prompt = deepcopy(cap_user_prompt)
    user_prompt = user_prompt.replace("/video caption/", case_info['pred_caption'])
    user_prompt = user_prompt.replace("/question/", question_str)
    gpt_ret, _ = openai_obj.get_completion(user_prompt=user_prompt, system_prompt=system_prompt)
    try:
        coverage_res = parsing_results(gpt_ret, question_mapping)
    except Exception as e:
        print(e)
        print(gpt_ret)
        return None
    sentence_len = len(case_info['pred_caption'].split(' '))
    return (object_id, gpt_ret, dict(coverage_res), sentence_len)


def process_hallucination(data):
    """Ask GPT whether the ground-truth caption supports each predicted item."""
    object_id, case_info = data
    pred_info = case_info['pred_info']

    try:
        question_str, question_mapping = format_question(pred_info)
    except Exception as e:
        print(e)
        return None
    # Same template as coverage, but the questions come from the prediction
    # and are checked against the ground-truth caption.
    user_prompt = deepcopy(cap_user_prompt)
    user_prompt = user_prompt.replace("/video caption/", case_info['gt_caption'])
    user_prompt = user_prompt.replace("/question/", question_str)
    gpt_ret, _ = openai_obj.get_completion(user_prompt=user_prompt, system_prompt=system_prompt)
    try:
        hallucination_res = parsing_results(gpt_ret, question_mapping)
    except Exception as e:
        print(e)
        print(gpt_ret)
        return None

    return (object_id, gpt_ret, dict(hallucination_res))


def compute_refine_chair(pred_file, gt_file, coverage_file, hallucination_file):
    """Run coverage and hallucination evaluation end to end, writing results
    next to the prediction file."""
    coverage_metric = defaultdict(lambda: [0, 0])
    hallucination_metric = defaultdict(lambda: [0, 0])
    case_len = []

    with open(pred_file, 'r', encoding='utf-8') as f:
        pred_info = json.load(f)
    with open(gt_file, 'r', encoding='utf-8') as f:
        gt_info = json.load(f)

    combined_info = combine_info(pred_info, gt_info)
    saved_combined_info = deepcopy(combined_info)
    combine_info_lst = list(combined_info.items())

    # GPT calls are I/O bound, so fan the cases out across worker processes.
    pool = Pool(processes=32)
    print('calculate coverage')
    dict_res_coverage = {}
    for res in tqdm(pool.imap_unordered(process_coverage, combine_info_lst), total=len(combine_info_lst)):
        if res is None:
            continue
        object_id, gpt_ret, coverage_res, sentence_len = res
        _add(coverage_res, coverage_metric)
        case_len.append(sentence_len)
        saved_combined_info[object_id]['coverage_res'] = gpt_ret
        dict_res_coverage[str(object_id)] = coverage_res

    print('calculate hallucination')
    dict_res_hallucination = {}
    for res in tqdm(pool.imap_unordered(process_hallucination, combine_info_lst), total=len(combine_info_lst)):
        if res is None:
            continue
        object_id, gpt_ret, hallucination_res = res
        _add(hallucination_res, hallucination_metric)
        saved_combined_info[object_id]['hallucination_res'] = gpt_ret
        dict_res_hallucination[str(object_id)] = hallucination_res

    pool.close()
    pool.join()

    output_dir = Path(pred_file).parent

    with (output_dir / coverage_file).open('w') as f:
        json.dump(dict_res_coverage, f, indent=4)
    print(f"Saved per-video coverage results in {output_dir}")

    with (output_dir / hallucination_file).open('w') as f:
        json.dump(dict_res_hallucination, f, indent=4)
    print(f"Saved per-video hallucination results in {output_dir}")

    save_metric(coverage_metric, hallucination_metric, case_len, output_dir)
    with (output_dir / 'chair_metric_detailed.json').open('w') as f:
        json.dump(saved_combined_info, f, indent=4)


def print_metrics(hallucination_cap_dict, quiet=False):
    """Print classic CHAIRs/CHAIRi percentages, or return them as a string
    when quiet=True."""
    sentence_metrics = hallucination_cap_dict['overall_metrics']
    metric_string = "%0.01f\t%0.01f" % (sentence_metrics['CHAIRs'] * 100,
                                        sentence_metrics['CHAIRi'] * 100)
    if not quiet:
        print("CHAIRs\tCHAIRi")
        print(metric_string)
        print(sentence_metrics['sentence len'])
        print(sentence_metrics['avg objects'])
    else:
        return metric_string


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--pred_file", type=str, default='/mnt/bn/yukunfeng-nasdrive/xiangchen/model/masp_models/checkpoints/mistral-ablation-v077-ocr/video_chair/vid_top1k_neg_res_non_dup_info.json')
    parser.add_argument("--gt_file", type=str, default='/mnt/bn/yukunfeng-nasdrive/xiangchen/repo/benchmark_data/refine_chair_eval_gt_neg_1k.json')
    parser.add_argument("--coverage_filename", type=str, default='each_video_coverage_detail.json')
    parser.add_argument("--hallucination_filename", type=str, default='each_video_hallucination_detail.json')

    args = parser.parse_args()

    compute_refine_chair(args.pred_file, args.gt_file, args.coverage_filename, args.hallucination_filename)
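
# Example invocation (the script name and paths are illustrative; GPT_API_AK
# is the placeholder environment variable read at module load):
#   GPT_API_AK=your-key python refine_chair_eval.py \
#       --pred_file /path/to/pred_info.json \
#       --gt_file /path/to/gt_info.json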