# llava/eval/masp_eval/video_chair/cal_video_chair.py
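"""GPT-assisted video CHAIR evaluation.

Given a prediction file and a ground-truth file of structured captions, this script
asks GPT yes/no/maybe questions to measure (a) coverage: how many ground-truth items
the predicted caption mentions, and (b) hallucination: how many predicted items are
not supported by the ground-truth caption.
"""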
import os
import re
import json
import argparse
from tqdm import tqdm
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from multiprocessing import Pool
from llava.eval.masp_eval.utils import GPTAPIWrapper
system_prompt = (
    "I am ChatGPT, a virtual assistant based on OpenAI's GPT-4 model. I'm designed to "
    "understand and generate human-like text based on the input I receive. My main purpose "
    "is to assist with information, answer questions, help with tasks that involve natural "
    "language processing, and engage in conversations with users. Please note that while I "
    "aim to provide accurate and reliable information, I can't guarantee perfection, and "
    "it's always a good idea to consult additional resources or professionals when making "
    "critical decisions based on the information I provide."
)
with open('llava/eval/masp_eval/video_chair/prompts/cap_mention.txt', 'r') as file:
    cap_user_prompt = file.read()
# Read the API key from the environment instead of hardcoding a credential.
# OPENAI_AK is an assumed variable name; adjust to your deployment.
openai_obj = GPTAPIWrapper(ak=os.environ.get("OPENAI_AK", ""))
def _add(case_res, all_res):
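    """Accumulate one case's per-category [hit, total] counts into the running totals."""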
for key, value in case_res.items():
for idx, count_ in enumerate(value):
all_res[key][idx] += count_
def save_metric(coverage, hallucination, case_len, output_dir=None):
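    """Convert [hit, total] counters into percentage metrics; hallucination is reported
    as 100 minus the grounded-mention rate. Writes chair_metric_neg.json when
    output_dir is given."""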
final_metrics = {}
    for name, res in [['coverage', coverage], ['hallucination', hallucination]]:
        combine_counter = [0, 0]
        for cat, counter in res.items():
            final_metrics[f'{name}_{cat}'] = round(counter[0] * 100 / counter[1], 2)
            combine_counter[0] += counter[0]
            combine_counter[1] += counter[1]
            if name == 'hallucination':
                # A "yes" answer means the mention is grounded, so hallucination is the complement.
                final_metrics[f'{name}_{cat}'] = round(100 - final_metrics[f'{name}_{cat}'], 2)
        final_metrics[name] = round(combine_counter[0] * 100 / combine_counter[1], 2)
        if name == 'hallucination':
            final_metrics[name] = round(100 - final_metrics[name], 2)
final_metrics['avg_len'] = round(sum(case_len) / len(case_len), 1)
if output_dir is not None:
with (output_dir / 'chair_metric_neg.json').open('w') as f:
json.dump(final_metrics, f, indent=4)
print(json.dumps(final_metrics, indent=1))
def combine_info(pred_info, gt_info):
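    """Join predictions and ground truth on object_id/task_id, keeping only cases
    where both sides have successfully parsed cap_info."""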
combined_info = defaultdict(dict)
if 'object_id' in gt_info[0]:
id_key = 'object_id'
else:
id_key = 'task_id'
for gt in gt_info:
object_id = gt[id_key]
if gt['cap_info'] is None:
continue
combined_info[object_id]['gt_caption'] = gt['refine_caption']
combined_info[object_id]['gt_info'] = gt['cap_info']
for pred in pred_info:
object_id = pred[id_key]
if object_id not in combined_info:
# print(pred)
continue
if pred['cap_info'] is None:
continue
combined_info[object_id]['pred_caption'] = pred['masp_inference']
combined_info[object_id]['pred_info'] = pred['cap_info']
filtered_ids = []
for key, value in combined_info.items():
if ('pred_info' not in value) or ('gt_info' not in value):
filtered_ids.append(key)
for obj_id in filtered_ids:
del combined_info[obj_id]
print(f'evaluation cases: {len(combined_info)}')
return combined_info
def format_question(info):
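    """Flatten structured caption info (subjects plus their attributes, activities,
    locations, text_overlays) into a numbered question list, returning the question
    string and a mapping from question index back to its category indices."""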
categories = ['subjects', 'activities', 'locations', 'text_overlays']
question_id = 0
question_mapping = {}
questions = []
for cat in categories:
if cat == 'subjects':
for c_id, character_content in enumerate(info['subjects']):
questions.append(cat + ':' + character_content['name'])
question_mapping[question_id] = (cat, c_id)
question_id += 1
if 'attributes' not in character_content:
continue
for a_id, attr in enumerate(character_content['attributes']):
questions.append(character_content['name'] + ':' + attr)
question_mapping[question_id] = ('attributes', c_id, a_id)
question_id += 1
else:
for c_id, cat_attr in enumerate(info[cat]):
questions.append(cat + ':' + cat_attr)
question_mapping[question_id] = (cat, c_id)
question_id += 1
question_str = ''
for idx, q in enumerate(questions):
question_str += f'{idx+1}. {q}' + '\n'
return question_str, question_mapping
def parsing_results(gpt_ret, question_mapping):
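    """Parse GPT's numbered answers into per-category [yes_or_maybe, total] counts.

    Expected answer lines look like (per the regex below):
        1. subjects:person - yes, the caption clearly mentions a person
    A 'maybe' answer is counted the same as 'yes'.
    """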
gpt_ret = gpt_ret.lower()
pattern = r'(\d+)\.(.+) - (yes|no|maybe),(.+)'
# Find all matches in the text
matches = re.findall(pattern, gpt_ret)
    collected_answer = defaultdict(lambda: [0, 0])
    # Tally each matched answer into per-category [yes_or_maybe, total] counts
    for match in matches:
question_id, question, answer, reason = match
question_id = int(question_id) - 1
cat = question_mapping[question_id][0]
collected_answer[cat][1] += 1
if 'yes' in answer:
collected_answer[cat][0] += 1
elif 'no' in answer:
pass
elif 'maybe' in answer:
collected_answer[cat][0] += 1
        else:
            raise NotImplementedError(f'unexpected answer: {answer}')
return collected_answer
def process_coverage(data):
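    """Worker: ask GPT which ground-truth items the predicted caption covers.
    Returns (object_id, raw GPT reply, per-category counts, caption word count),
    or None if question building or answer parsing fails."""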
object_id = data[0]
case_info = data[1]
gt_info = case_info['gt_info']
try:
question_str, question_mapping = format_question(gt_info)
except Exception as e:
print(e)
return None
    user_prompt = cap_user_prompt.replace("/video caption/", case_info['pred_caption'])
    user_prompt = user_prompt.replace("/question/", question_str)
gpt_ret, _ = openai_obj.get_completion(user_prompt=user_prompt, system_prompt=system_prompt)
try:
coverage_res = parsing_results(gpt_ret, question_mapping)
except Exception as e:
print(e)
print(gpt_ret)
return None
sentence_len = len(case_info['pred_caption'].split(' '))
return (object_id, gpt_ret, dict(coverage_res), sentence_len)
def process_hallucination(data):
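    """Worker: ask GPT which predicted items are supported by the ground-truth caption;
    unsupported items count as hallucinations. Returns (object_id, raw GPT reply,
    per-category counts), or None on failure."""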
object_id = data[0]
case_info = data[1]
pred_info = case_info['pred_info']
try:
question_str, question_mapping = format_question(pred_info)
except Exception as e:
print(e)
return None
    user_prompt = cap_user_prompt.replace("/video caption/", case_info['gt_caption'])
    user_prompt = user_prompt.replace("/question/", question_str)
gpt_ret, _ = openai_obj.get_completion(user_prompt=user_prompt, system_prompt=system_prompt)
try:
hallucination_res = parsing_results(gpt_ret, question_mapping)
except Exception as e:
print(e)
print(gpt_ret)
return None
return (object_id, gpt_ret, dict(hallucination_res))
def compute_refine_chair(pred_file, gt_file, coverage_file, hallucination_file):
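    """Run coverage and hallucination evaluation over all joined cases with a
    multiprocessing pool, then save per-video details and aggregate metrics
    next to the prediction file."""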
coverage_metric = defaultdict(lambda:[0,0])
hallucination_metric = defaultdict(lambda:[0,0])
case_len = []
with open(pred_file, 'r', encoding='utf-8') as f:
pred_info = json.load(f)
with open(gt_file, 'r', encoding='utf-8') as f:
gt_info = json.load(f)
combined_info = combine_info(pred_info, gt_info)
saved_combined_info = deepcopy(combined_info)
combine_info_lst = list(combined_info.items())
pool = Pool(processes=32)
print('calculate coverage')
dict_res_coverage = {}
for res in tqdm(pool.imap_unordered(process_coverage, combine_info_lst), total=len(combine_info_lst)):
if res is None:
continue
object_id, gpt_ret, coverage_res, sentence_len = res
_add(coverage_res, coverage_metric)
case_len.append(sentence_len)
saved_combined_info[object_id]['coverage_res'] = gpt_ret
dict_res_coverage[str(object_id)] = coverage_res
print('calculate hallucination')
dict_res_hallucination = {}
for res in tqdm(pool.imap_unordered(process_hallucination, combine_info_lst), total=len(combine_info_lst)):
if res is None:
continue
object_id, gpt_ret, hallucination_res = res
_add(hallucination_res, hallucination_metric)
saved_combined_info[object_id]['hallucination_res'] = gpt_ret
dict_res_hallucination[str(object_id)] = hallucination_res
pool.close()
pool.join()
output_dir = Path(pred_file).parent
with (output_dir / coverage_file).open('w') as f:
json.dump(dict_res_coverage, f, indent=4)
print(f"Saving coverage result for each video in {output_dir}")
with (output_dir / hallucination_file).open('w') as f:
json.dump(dict_res_hallucination, f, indent=4)
print(f"Saving hallucination result for each video in {output_dir}")
save_metric(coverage_metric, hallucination_metric, case_len, output_dir)
with (output_dir / 'chair_metric_detailed.json').open('w') as f:
json.dump(saved_combined_info, f, indent=4)
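# Legacy helper for COCO-style CHAIR metrics (CHAIRs/CHAIRi); not invoked by this
# script's entry point.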
def print_metrics(hallucination_cap_dict, quiet=False):
sentence_metrics = hallucination_cap_dict['overall_metrics']
metric_string = "%0.01f\t%0.01f" %(sentence_metrics['CHAIRs']*100,
sentence_metrics['CHAIRi']*100)
if not quiet:
print("CHAIRs\tCHAIRi")
print(metric_string)
print(sentence_metrics['sentence len'])
print(sentence_metrics['avg objects'])
else:
return metric_string
# Usage: python3 llava/eval/masp_eval/video_chair/cal_video_chair.py --pred_file <pred.json> --gt_file <gt.json>
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--pred_file", type=str, default='/mnt/bn/yukunfeng-nasdrive/xiangchen/model/masp_models/checkpoints/mistral-ablation-v077-ocr/video_chair/vid_top1k_neg_res_non_dup_info.json')
parser.add_argument("--gt_file", type=str, default='/mnt/bn/yukunfeng-nasdrive/xiangchen/repo/benchmark_data/refine_chair_eval_gt_neg_1k.json')
parser.add_argument("--coverage_filename", type=str, default='each_video_coverage_detail.json')
parser.add_argument("--hallucination_filename", type=str, default='each_video_halluciantion_detail.json')
# parser.add_argument("--gt_file", type=str, default='/mnt/bn/yukunfeng-nasdrive/xiangchen/repo/benchmark_data/refine_chair_eval_gt.json')
args = parser.parse_args()
compute_refine_chair(args.pred_file, args.gt_file, args.coverage_filename, args.hallucination_filename)