# model1/llava/datasets/gpt4v_public_dataset.py
from llava.datasets.builder import DATASETS
from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import FramesTaskDataset
import pickle
from pathlib import Path
import random
import numpy as np
from llava.datasets.prompts import tt_caption_prompt, internvid_prompt
from llava.constants import DEFAULT_VIDEO_TOKEN
from PIL import Image
import json
import torch
import os


class GPT4VPublicDataset(FramesTaskDataset):
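    # One annotation item holds the frame directory plus one or more task fields
    # ('summary', 'detail', 'qa_pairs'); get_dataset expands every requested
    # task type into its own training sample.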
def __init__(self, anno_path=None, data_args=None, fps=1.0, conv_type='single', task_types=None, sample_method='uniform', name='gpt4v_public'):
self.default_fps = 1.0
self.fps = fps
self.conv_type = conv_type
self.task_types = task_types
self.annotation = self.get_dataset(anno_path)
self.sample_method = sample_method
        assert self.conv_type in ('single', 'multi'), "gpt4v_public conv_type must be 'single' or 'multi'"
        assert self.sample_method in ('sequential', 'uniform'), "gpt4v_public sample_method must be 'sequential' or 'uniform'"
# assert hasattr(self.data_args, 'task_types') , "gpt4v_public must have key 'task_types' in yaml config"
# master_print(f"Finished loading dataset {name} {len(self.annotation)} samples...")
super().__init__(anno_path=anno_path,
data_args=data_args,
fps=fps,
name=name)
def __len__(self):
return len(self.annotation)
def get_dataset(self, anno_path):
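        # Load the JSON annotation file and expand it into per-task samples.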
dataset = []
anno_path = Path(anno_path)
with anno_path.open('rb') as f:
data = json.load(f)
for info in data:
filtered_qa = []
if 'qa_pairs' not in info:
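                # No pre-extracted QA pairs: fall back to scanning the raw
                # conversation lines, pairing each line containing 'C' with the
                # immediately following line containing 'A'.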
index = 0
while index < len(info['conversation']):
if len(info['conversation'][index].strip()) == 0:
index += 1
continue
if 'C' in info['conversation'][index]:
if index+1 < len(info['conversation']) and 'A' in info['conversation'][index+1]:
filtered_qa.append(
[info['conversation'][index], info['conversation'][index+1]]
)
index += 2
else:
index += 1
continue
else:
# print(info['conversation'][index])
index += 1
continue
else:
for qa in info['qa_pairs']:
if len(qa[0]) == 0 or len(qa[1]) == 0:
continue
filtered_qa.append(qa)
info['qa_pairs'] = filtered_qa
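            # Emit one sample per requested task type; in 'single' conv mode each
            # QA pair additionally becomes its own sample.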
for task_type in self.task_types:
info_task = info.copy()
if len(info_task[task_type]) == 0:
continue
if task_type == 'qa_pairs' and self.conv_type == 'single':
for qa_pair in info_task[task_type]:
one_info = info_task.copy()
one_info[task_type] = [qa_pair]
one_info.update({
'task_type': task_type
})
dataset.append(one_info)
else:
info_task.update({
'task_type': task_type
})
dataset.append(info_task)
return dataset
# @staticmethod
# def _sample_frames(frames, num_segments):
# indices = list(range(num_segments))
# frames = [frames[ind] for ind in indices]
# return frames
def text_preprocess(self, item) -> List[Dict[str, str]]:
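        # Build the conversation turns for one sample according to its task type
        # (summary / detail / qa_pairs), then shuffle and flatten them.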
all_convs = []
# TODO: different prompt for summary and detail
if item['task_type'] == 'summary':
summary = ''
if isinstance(item['summary'], list):
for s in item['summary']:
if len(s.strip()) != 0:
summary = s
break
else:
summary = item['summary']
all_convs.append([
{
'from': 'human',
'value': random.choice(internvid_prompt)
},
{
'from': 'model',
'value': summary
}
])
elif item['task_type'] == 'detail':
detail = ''
if isinstance(item['detail'], list):
for s in item['detail']:
if len(s.strip()) != 0:
detail = s
break
else:
detail = item['detail']
all_convs.append([
{
'from': 'human',
'value': random.choice(tt_caption_prompt)
},
{
'from': 'model',
'value': detail
}
])
else:
for qa in item['qa_pairs']:
all_convs.append([
{
'from': 'human',
'value': qa[0]
},
{
'from': 'model',
'value': qa[1]
}
])
conversations = []
random.shuffle(all_convs)
for idx, conv in enumerate(all_convs):
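            # Prepend the video token only to the first human turn.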
if idx == 0:
conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
conversations.extend(conv)
return conversations
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
item = self.annotation[i]
ret = {
'images': self.vis_preprocess(item['vis_path']),
'conversations': self.text_preprocess(item)
}
if 'id' in item:
ret['id'] = item['id']
return ret
def _sample_frames(self, frames, num_segments, preprocess=False):
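        # Subsample `frames` down to `num_segments`. During the preprocess pass the
        # configured sample_method is honoured (uniform spacing or the first frames
        # in order); otherwise frames are always sampled uniformly.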
if preprocess:
if self.sample_method == 'uniform':
indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
elif self.sample_method == 'sequential':
                # 'sequential': take the first num_segments frames in order
                indices = range(num_segments)
else:
raise NotImplementedError
frames = [frames[ind] for ind in indices]
else:
indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
frames = [frames[ind] for ind in indices]
return frames
def vis_preprocess(self, vis_path):
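        # Collect the .jpeg frames under vis_path, order them by frame index,
        # subsample, and run the image preprocessor.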
image_files = []
for img_path in os.listdir(vis_path):
if img_path.endswith('.jpeg'):
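                # frame files are assumed to be named '<prefix>_<index>.jpeg';
                # recover the integer index for ordering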
img_idx = int(img_path.split('_')[-1][:-5])
image_files.append((img_idx, img_path))
image_files = sorted(image_files, key=lambda img: img[0])
        # TODO: ad-hoc fix, keep at most 10 frames
if len(image_files) > 10:
image_files = self._sample_frames(image_files, 10, preprocess=True)
if self.num_segments > 0 and len(image_files) > self.num_segments:
image_files = self._sample_frames(image_files, self.num_segments)
images = []
for image_file in image_files:
try:
images.append(Image.open(os.path.join(vis_path, image_file[1])).convert('RGB'))
            except Exception:
                # skip frames that fail to load or decode
                continue
formatted_images = []
for image in images:
im = self.preprocess_image(image)
if isinstance(im, list):
formatted_images.extend(im)
else:
formatted_images.append(im)
return formatted_images
@DATASETS.register_obj
def gpt4v_public(data_args):
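    # Builder registered under 'gpt4v_public'; reads paths and sampling options
    # from the dataset config and data_args.external_args.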
data_cfg = data_configs['gpt4v_public']
if 'train_data_path' in data_args.external_args:
data_cfg['train_data_path'] = data_args.external_args['train_data_path']
anno_path = data_cfg['train_data_path']
    fps = data_args.external_args['fps']
    conv_type = data_args.external_args['conv_type']
    task_types = data_args.external_args['task_types']
if 'sample_method' in data_args.external_args:
sample_method = data_args.external_args['sample_method']
else:
sample_method = 'uniform'
return GPT4VPublicDataset(anno_path, data_args, fps, conv_type, task_types, sample_method)
if __name__ == '__main__':
pass
# import pickle
# from tqdm import tqdm
# file_paths = ['/mnt/bn/algo-masp-nas-2/xianyang/clean_annotations/annotations/webvid10m',
# '/mnt/bn/algo-masp-nas-2/xianyang/clean_annotations/annotations/webvid2m']
# frame_paths = ['/mnt/bn/algo-masp-nas-2/xianyang/clean_annotations/frames/webvid10m',
# '/mnt/bn/algo-masp-nas-2/xianyang/clean_annotations/frames/webvid2m']
# data = []
# for file_path, frame_path in zip(file_paths, frame_paths):
# file_path = Path(file_path)
# for pkl_path in tqdm(file_path.glob('*')):
# with pkl_path.open('rb') as f:
# info = pickle.load(f)
# pkl_name = pkl_path.name[:-4]
# frame_folder_path = Path(frame_path) / pkl_name
# info['vis_path'] = str(frame_folder_path)
# if os.path.exists(info['vis_path']):
# data.append(info)
# with open ('/mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_500k.json', 'w') as f:
# json.dump(data, f)
# if frame_path.exists():
# print(1)
# with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images.json') as f:
# data = json.load(f)
# data_im = []
# data_vid = []
# for sample in data:
# if 'image' in sample:
# data_im.append(sample)
# else:
# data_vid.append(sample)
# with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_im.json', 'w') as f:
# json.dump(data_im, f)
# with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_vid.json', 'w') as f:
# json.dump(data_vid, f)