# General-Level-Scorer / predictors / video_comprehension_qa_caption.py
# General-Level
# Resolve conflict
# 0eb3766
import tqdm
from typing import List, Dict, Any
from dataclasses import dataclass
from abc import ABC, abstractmethod
from PIL import Image
import numpy as np
import os
import json
import argparse
import torch
from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
LlavaOnevisionForConditionalGeneration, AutoProcessor)
# An example of the model
class LLavaOneVisionModel:
    """Example predictor wrapping the Hugging Face LLaVA-OneVision model.

    The model (fp16, moved to CUDA), its tokenizer and its processor are loaded
    once in ``__init__``; ``generate`` answers a chat-style conversation about a
    clip given as a list of frames.
    """

    def __init__(self, model_name="llava-hf/llava-onevision-qwen2-7b-ov-hf"):
        self.model_name = model_name
        # ``.eval().cuda()``: inference mode on the GPU, so a CUDA device is required.
        self.model = LlavaOnevisionForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(model_name)

    def generate(self, conversation, video):
        """Greedily generate an answer (at most 256 new tokens).

        Args:
            conversation: chat messages in the processor's chat-template format;
                the user turn is expected to contain a ``{"type": "video"}`` entry.
            video: frames of one clip (a list of PIL images, as produced by the
                task classes' ``read_video_frames``).

        Returns:
            The decoded assistant reply, without the prompt.
        """
        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
        # The prompt holds a *video* placeholder, so the frames must be passed as
        # ``videos=`` for the processor to expand that placeholder into the
        # matching visual tokens (``images=`` leaves it unexpanded).
        inputs = self.processor(
            videos=video, text=prompt, return_tensors="pt"
        ).to(self.model.device, torch.float16)
        outputs = self.model.generate(**inputs, max_new_tokens=256, do_sample=False)
        # Decode only the newly generated tokens. Splitting the whole transcript on
        # 'assistant\n' raised IndexError when the marker was missing and cut
        # answers that themselves contained it.
        prompt_len = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
@dataclass
class Instance:
    """A single evaluation sample built from one item of ``task_data["data"]``.

    The task classes read ``input["video_file_list"]`` and ``input["prompt"]``
    from ``input``, and ``output["text"]`` as the reference answer.
    """

    # Model-side fields of the sample.
    input: Dict[str, Any]
    # Ground-truth fields of the sample.
    output: Dict[str, Any]
    # Sample identifier as stored in the annotation file.
    id: str
class BaseTask(ABC):
    """Abstract evaluation task: parse annotations, run the model, score results.

    Args:
        task_data: parsed content of one task's annotation file.
        model: predictor used by subclasses (the concrete tasks call
            ``model.generate(conversation, video)``).
    """

    def __init__(self, task_data: Dict[str, Any], model):
        self.task_data = task_data
        self.model = model
        # Subclasses parse the raw annotation here, during construction.
        self.data = self._parse_data(task_data)

    @abstractmethod
    def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
        """Convert raw ``task_data`` into the instances to evaluate."""

    @abstractmethod
    def evaluate(self) -> Dict[str, float]:
        """Return metric name -> value for the predictions of ``run_inference``."""

    @abstractmethod
    def run_inference(self, root_path: str) -> None:
        """Run the model over ``self.data``; media are resolved under ``root_path``.

        The signature matches the concrete tasks and the driver, which call
        ``task.run_inference(root_path=...)``.
        """
def cal_accuracy(predictions: List[str], references: List[str]) -> float:
    """Return the fraction of predictions that contain one of their references.

    A reference may be a single string or a list of acceptable strings. A
    prediction is correct when, for at least one reference, the stripped
    reference is a substring of the stripped prediction. The comparison is
    case-insensitive only when the reference contains the substring "yes" or
    "no" (intended for yes/no questions; being a substring test, it also
    triggers on words such as "not" or "know"); otherwise it is case-sensitive.

    Args:
        predictions: model outputs, one per sample.
        references: ground-truth answer(s), aligned with ``predictions``.

    Returns:
        Accuracy in [0, 1] (divided by ``len(predictions)``); 0.0 when
        ``predictions`` is empty.
    """
    correct = 0
    for pred, ref in zip(predictions, references):
        candidates = [ref] if isinstance(ref, str) else ref
        for r in candidates:
            # Lower-case local copies per reference, so a yes/no-style reference
            # does not change how later references in the same list compare.
            target, answer = r, pred
            if "yes" in r.lower() or "no" in r.lower():
                target, answer = target.lower(), answer.lower()
            if target.strip() in answer.strip():
                correct += 1
                break
    return correct / len(predictions) if predictions else 0.0
class Bleu1_Scorer():
    """Compute corpus BLEU with pycocoevalcap and report BLEU-1 on a 0-100 scale.

    Args:
        predictions: candidate texts, passed to ``compute_score`` as the result
            dict (``BLEUTASK.evaluate`` builds it as id -> [text]).
        references: ground-truth texts, passed as the ground-truth dict
            (id -> list of reference texts).
    """

    def __init__(self, predictions, references):
        # Imported here so pycocoevalcap is only needed when BLEU is computed.
        from pycocoevalcap.bleu.bleu import Bleu
        self.pred = predictions
        self.gt = references
        bleu_names = ['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4']
        self.scorers = [(Bleu(4), bleu_names)]

    def compute_scores(self):
        """Run every scorer, print its scores (x100) and return ``{"Bleu_1": value}``."""
        collected = {}
        for scorer, label in self.scorers:
            print('Computing %s score...' % (scorer.method()))
            corpus_score, per_sample = scorer.compute_score(self.gt, self.pred)
            if not isinstance(label, list):
                collected[label] = corpus_score * 100
                continue
            for value, _, name in zip(corpus_score, per_sample, label):
                print('%s: %0.3f' % (name, value * 100))
            collected['Bleu'] = [value * 100 for value in corpus_score]
        return {"Bleu_1": collected['Bleu'][0]}
class AccTask(BaseTask):
    """Video QA / recognition task scored with ``cal_accuracy`` (substring match).

    Predictions and references are cached in the current working directory as
    ``./predictions_<task>.json`` and ``./references_<task>.json``. When the
    predictions file already exists, inference is skipped and both cached lists
    are loaded instead. The cache is keyed only by task name, not by model.
    """

    def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
        """Record the task name and wrap every annotation item in an ``Instance``."""
        self.task_name = task_data["task"]
        return [Instance(input=d["input"], output=d["output"], id=d["id"])
                for d in task_data["data"]]

    def read_video_frames(self, data_path_list, root_path, max_frames_num=64):
        """Load up to ``max_frames_num`` frames as RGB ``PIL.Image`` objects.

        Args:
            data_path_list: frame file names, relative to ``root_path``, in order.
            root_path: directory containing the frame files.
            max_frames_num: upper bound on returned frames; longer lists are
                subsampled at evenly spaced indices from first to last frame.

        Returns:
            The loaded frames. Missing or unreadable frames are skipped with a
            warning, so the list may be shorter than requested (or empty).
        """
        if len(data_path_list) > max_frames_num:
            frame_idx = np.linspace(0, len(data_path_list) - 1, max_frames_num, dtype=int)
            data_path_list = [data_path_list[i] for i in frame_idx]
        frames = []
        for frame_path in data_path_list:
            path = os.path.join(root_path, frame_path)
            if not os.path.exists(path):
                print(f"Warning: Frame path {path} does not exist.")
                continue
            try:
                # Decode inside ``with`` so the file handle is released at once;
                # a bare ``Image.open`` loads lazily and keeps every file open.
                with Image.open(path) as img:
                    frames.append(img.convert("RGB"))
            except Exception as e:  # best effort: skip bad frames, keep going
                print(f"Warning: Failed to read frame {path}. Error: {e}")
        return frames

    def _build_question(self, prompt):
        """Prefix the QA instruction and append any task-specific output format."""
        question = 'Please answer the following question related to the video. ' + prompt
        other_requirements = ''
        if 'VideoActionCounting' in self.task_name:
            other_requirements = 'The output must consist only of Arabic numerals.'
        if 'VideoActionOrdering' in self.task_name:
            other_requirements = 'The output format must be: [num]->[num]->[num]->[num]. The number represents the index marked in the question. For example: 2->1->3->4, 1->2->3->4, 3->2->1->4...'
        if 'SignLanguageVideoRecognition' in self.task_name:
            other_requirements = 'The output format must be a word.'
        if other_requirements:
            # Separate the requirement from the prompt; plain concatenation
            # glued the two sentences together ("...?The output must...").
            question += ' ' + other_requirements
        return question

    def run_inference(self, root_path):
        """Fill ``self.predictions`` and ``self.references`` for every instance.

        Frames are read from ``<root_path>/<task_name>/videos``; results are
        loaded from / saved to the JSON cache files described on the class.
        """
        pred_file = f'./predictions_{self.task_name}.json'
        ref_file = f'./references_{self.task_name}.json'
        if os.path.exists(pred_file):
            with open(pred_file, 'r', encoding='utf-8') as f:
                self.predictions = json.load(f)
            with open(ref_file, 'r', encoding='utf-8') as f:
                self.references = json.load(f)
            return
        video_root = os.path.join(root_path, self.task_name, 'videos')
        self.predictions = []
        self.references = []
        for inst in tqdm.tqdm(self.data):
            video = self.read_video_frames(inst.input['video_file_list'], video_root, max_frames_num=64)
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": self._build_question(inst.input['prompt'])},
                        {"type": "video"},
                    ],
                },
            ]
            self.predictions.append(self.model.generate(conversation, video))
            self.references.append(inst.output["text"])
        with open(pred_file, 'w', encoding='utf-8') as f:
            json.dump(self.predictions, f)
        with open(ref_file, 'w', encoding='utf-8') as f:
            json.dump(self.references, f)

    def evaluate(self) -> Dict[str, float]:
        """Return ``{"accuracy": percent}``; requires ``run_inference`` first."""
        acc = cal_accuracy(self.predictions, self.references)
        return {"accuracy": acc * 100}
class BLEUTASK(BaseTask):
    """Video captioning / open-answer task scored with BLEU-1 (``Bleu1_Scorer``).

    Predictions and references are cached in the current working directory as
    ``./predictions_<task>.json`` and ``./references_<task>.json``. When the
    predictions file already exists, inference is skipped and both cached lists
    are loaded instead. The cache is keyed only by task name, not by model.
    """

    def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
        """Record the task name and wrap every annotation item in an ``Instance``."""
        self.task_name = task_data["task"]
        return [Instance(input=d["input"], output=d["output"], id=d["id"])
                for d in task_data["data"]]

    def read_video_frames(self, data_path_list, root_path, max_frames_num=64):
        """Load up to ``max_frames_num`` frames as RGB ``PIL.Image`` objects.

        Args:
            data_path_list: frame file names, relative to ``root_path``, in order.
            root_path: directory containing the frame files.
            max_frames_num: upper bound on returned frames; longer lists are
                subsampled at evenly spaced indices from first to last frame.

        Returns:
            The loaded frames. Missing or unreadable frames are skipped with a
            warning, so the list may be shorter than requested (or empty).
        """
        if len(data_path_list) > max_frames_num:
            frame_idx = np.linspace(0, len(data_path_list) - 1, max_frames_num, dtype=int)
            data_path_list = [data_path_list[i] for i in frame_idx]
        frames = []
        for frame_path in data_path_list:
            path = os.path.join(root_path, frame_path)
            if not os.path.exists(path):
                print(f"Warning: Frame path {path} does not exist.")
                continue
            try:
                # Decode inside ``with`` so the file handle is released at once;
                # a bare ``Image.open`` loads lazily and keeps every file open.
                with Image.open(path) as img:
                    frames.append(img.convert("RGB"))
            except Exception as e:  # best effort: skip bad frames, keep going
                print(f"Warning: Failed to read frame {path}. Error: {e}")
        return frames

    def run_inference(self, root_path):
        """Fill ``self.predictions`` and ``self.references`` for every instance.

        Frames are read from ``<root_path>/<task_name>/videos``; results are
        loaded from / saved to the JSON cache files described on the class.
        """
        pred_file = f'./predictions_{self.task_name}.json'
        ref_file = f'./references_{self.task_name}.json'
        if os.path.exists(pred_file):
            with open(pred_file, 'r', encoding='utf-8') as f:
                self.predictions = json.load(f)
            with open(ref_file, 'r', encoding='utf-8') as f:
                self.references = json.load(f)
            return
        video_root = os.path.join(root_path, self.task_name, 'videos')
        self.predictions = []
        self.references = []
        for inst in tqdm.tqdm(self.data):
            video = self.read_video_frames(inst.input['video_file_list'], video_root, max_frames_num=64)
            question = ('Please answer the following question related to the video. '
                        + inst.input['prompt'] + ' The output should be concise. ')
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": question},
                        {"type": "video"},
                    ],
                },
            ]
            self.predictions.append(self.model.generate(conversation, video))
            self.references.append(inst.output["text"])
        with open(pred_file, 'w', encoding='utf-8') as f:
            json.dump(self.predictions, f)
        with open(ref_file, 'w', encoding='utf-8') as f:
            json.dump(self.references, f)

    def evaluate(self) -> Dict[str, float]:
        """Return ``{"Bleu_1": score}`` on lower-cased text; requires ``run_inference`` first."""
        predictions = {}
        references = {}
        for num, (pred, ref) in enumerate(zip(self.predictions, self.references), start=1):
            key = str(num)
            predictions[key] = [pred.lower()]
            # A reference may be one string or a list of acceptable strings (the
            # shape ``cal_accuracy`` also accepts); BLEU takes several per sample.
            refs = [ref] if isinstance(ref, str) else list(ref)
            references[key] = [r.lower() for r in refs]
        return Bleu1_Scorer(predictions, references).compute_scores()
def log_performance(model_name, task_name, metrics, root_path, output_file='performance_log.csv'):
    """Append one ``model, task, metrics`` row to a CSV log.

    Args:
        model_name: identifier of the evaluated model.
        task_name: name of the evaluated task.
        metrics: metric mapping; stored as its ``str()`` representation.
        root_path: directory holding the log; created if it does not exist.
        output_file: CSV file name inside ``root_path``.

    The header row is written when the file is missing or empty.
    """
    import csv
    log_path = os.path.join(root_path, output_file)
    # Create the directory (e.g. '../outcome/') instead of failing on first write.
    os.makedirs(os.path.dirname(log_path) or '.', exist_ok=True)
    needs_header = not os.path.isfile(log_path) or os.path.getsize(log_path) == 0
    row_data = {
        'model': model_name,
        'task': task_name,
        'metrics': str(metrics),
    }
    with open(log_path, mode='a', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=list(row_data))
        if needs_header:
            writer.writeheader()
        writer.writerow(row_data)
if __name__ == "__main__":
    # Task name (the "task" field of each annotation.json) -> evaluation class.
    # BLEUTASK scores with BLEU-1, AccTask with substring accuracy.
    # Built once here rather than on every loop iteration. 56 tasks.
    TASK_MAPPING = {
        "AgricultureVideoQuestionAnswering": BLEUTASK,
        "ArtRecognition": AccTask,
        "ArtsAndCraftsVideoCaptioning": BLEUTASK,
        "AutosAndVehiclesVideoCaptioning": BLEUTASK,
        "BallGameVideoQuestionAnswering": AccTask,
        "BallSportsVideoCaptioning": BLEUTASK,
        "BodyMotionVideoCaptioning": BLEUTASK,
        "BusinessVideoCaptioning": BLEUTASK,
        "ComedyVideoQuestionAnswering": BLEUTASK,
        "DailyLifeAndSkillsVideoCaptioning": BLEUTASK,
        "EducationVideoQuestionAnswering": AccTask,
        "EntertainmentRelatedVideoCaptioning": BLEUTASK,
        "FacialActionVideoCaptioning": BLEUTASK,
        "FacialObjectOperationsVideoCaptioning": BLEUTASK,
        "FinanceVideoCaptioning": BLEUTASK,
        "FoodVideoCaptioning": BLEUTASK,
        "GameVideoQuestionAnswering": BLEUTASK,
        "GeographyVideoQuestionAnswering": BLEUTASK,
        "GymnasticsVideoQuestionAnswering": AccTask,
        "HistoryAndLiteratureVideoCaptioning": BLEUTASK,
        "HumanHumanInteractionVideoCaptioning": BLEUTASK,
        "HumanObjectInteractionVideoCaptioning": BLEUTASK,
        "HumanObjectInteractionVideoQuestionAnswering": BLEUTASK,
        "HumanSurvivalVideoQuestionAnswering": BLEUTASK,
        "HumorVideoCaptioning": BLEUTASK,
        "MilitaryVideoQuestionAnswering": BLEUTASK,
        "MovieAndShowVideoCaptioning": BLEUTASK,
        "MovieVideoQuestionAnswering": BLEUTASK,
        "MusicalInstrumentsVideoCaptioning": BLEUTASK,
        "MusicVideoQuestionAnswering": BLEUTASK,
        "NaturalDisasterVideoRecognition": BLEUTASK,
        "NewsAndDocumentaryVideoCaptioning": BLEUTASK,
        "ObjectColorVideoQuestionAnswering": AccTask,
        "ObjectDirectionVideoQuestionAnswering": BLEUTASK,
        "ObjectLocationVideoQuestionAnswering": AccTask,
        "ObjectMotionVideoQuestionAnswering": AccTask,
        "PersonalCareVideoCaptioning": BLEUTASK,
        "PetsVideoQuestionAnswering": BLEUTASK,
        "PetsVideoRecognition": BLEUTASK,
        "ScienceAndTechnologyVideoCaptioning": BLEUTASK,
        "ScienceVideoQuestionAnswering": BLEUTASK,
        "ScienceVideoRecognition": BLEUTASK,
        "SignLanguageVideoRecognition": AccTask,
        "SportsAndExcerciseVideoCaptioning": BLEUTASK,
        "SportsVideoQuestionAnswering": BLEUTASK,
        "TVShowRecognition": AccTask,
        "VideoActionCounting": AccTask,
        "VideoActionOrdering": AccTask,
        "VideoActionSequencePrediction": BLEUTASK,
        "VideoActionSequenceUnderstanding": BLEUTASK,
        "VideoAnimalRecognition": AccTask,
        "VideoFoodRecognition": AccTask,
        "VideoObjectCounting": BLEUTASK,
        "VideoObjectExistenceRecognition": BLEUTASK,
        "VideoObjectInteractionRecognition": BLEUTASK,
        "VideoSportsRecognition": AccTask,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("--root_path", type=str, default="General-Bench-Openset/video/comprehension")
    parser.add_argument("--model_name", type=str, default="llava-hf/llava-onevision-qwen2-7b-ov-hf")
    parser.add_argument("--output_dir", type=str, default="../outcome/",
                        help="Directory receiving the performance CSV log.")
    parser.add_argument("--tasks", nargs="*", default=None,
                        help="Task directories to evaluate (default: all tasks in TASK_MAPPING).")
    args = parser.parse_args()
    root_path = args.root_path
    model_name = args.model_name

    model = LLavaOneVisionModel(model_name=model_name)  # An example of the model

    # By default every task directory is evaluated; the directory names are the
    # same as the TASK_MAPPING keys, so the list is derived instead of duplicated.
    task_dirs = args.tasks if args.tasks else list(TASK_MAPPING)

    for idx, task_dir in enumerate(task_dirs):
        # Accept names given with or without a trailing '.json'.
        if task_dir.endswith('.json'):
            task_dir = task_dir[:-len('.json')]
        file_path = os.path.join(root_path, task_dir, "annotation.json")
        if not os.path.exists(file_path):
            print(f"Skipping {task_dir}: {file_path} not found.")
            continue
        with open(file_path, 'r', encoding='utf-8') as f:
            task_data = json.load(f)
        task_type = task_data["type"]
        task_name = task_data["task"]
        print(f"Running evaluation for task {idx + 1}: {task_name}")

        task_class = TASK_MAPPING.get(task_name)
        if task_class is None:
            raise NotImplementedError(f"No evaluation class registered for task '{task_name}'.")
        task = task_class(task_data, model)
        task.run_inference(root_path=root_path)
        metrics = task.evaluate()
        print("Task name: ", task_name, "Task type: ", task_type, "Evaluation results:", metrics)
        log_performance(model_name, task_name, metrics, args.output_dir,
                        output_file='video_comprehension_qa_caption_performance_log.csv')