# model1/llava/datasets/promptv1_2_internal_dataset.py
from llava.datasets.builder import DATASETS
from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import FramesTaskDataset
import pickle
from pathlib import Path
import random
import numpy as np
from llava.datasets.prompts import tt_caption_prompt, internvid_prompt
from llava.constants import DEFAULT_VIDEO_TOKEN
from PIL import Image
import json
import torch
import os


class PromptV1Dataset(FramesTaskDataset):
    def __init__(self, anno_path=None, data_args=None, name='promptv1_2_internal', task_types=None):
        self.default_fps = 1.0
        self.task_types = task_types
        # Flatten the raw annotation file into per-task samples before the
        # parent initializer runs.
        self.annotation = self.get_dataset(anno_path)
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

    def __len__(self):
        return len(self.annotation)

    def get_dataset(self, anno_path):
        dataset = []
        anno_path = Path(anno_path)
        with anno_path.open('rb') as f:
            data = json.load(f)
        for info in data:
            for task_type in self.task_types:
                info_task = info.copy()
                if task_type not in info or len(info_task[task_type]) == 0:
                    continue
                if task_type == 'qas' and self.conv_type == 'single':
                    for qa_pair in info_task[task_type]:
                        one_info = info_task.copy()
                        one_info[task_type] = [qa_pair]
                        one_info.update({
                            'task_type': task_type
                        })
                        dataset.append(one_info)
                else:
                    info_task.update({
                        'task_type': task_type
                    })
                    dataset.append(info_task)
        return dataset
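
    # Illustrative sketch of the flattening above (hedged: the exact annotation
    # schema is an assumption inferred from the keys this class reads, namely
    # 'refine_caption', 'qas', 'q' and 'a'; real entries may carry more fields):
    #
    #   raw entry:
    #       {"refine_caption": "a cat sleeping on a sofa",
    #        "qas": [{"q": "What animal is shown?", "a": "A cat."},
    #                {"q": "Where is it?", "a": "On a sofa."}]}
    #
    #   with task_types=['refine_caption', 'qas'] and conv_type='single', this
    #   entry expands into three samples: one tagged task_type='refine_caption'
    #   and two single-QA samples tagged task_type='qas'.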

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        all_convs = []
        if hasattr(self.data_args, 'caption_prompt'):
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt
        if item['task_type'] == 'refine_caption':
            all_convs.append([
                {
                    'from': 'human',
                    'value': random.choice(cap_prompt)
                },
                {
                    'from': 'model',
                    'value': item['refine_caption']
                }
            ])
        else:
            for idx, qa in enumerate(item['qas']):
                all_convs.append([
                    {
                        'from': 'human',
                        'value': qa['q']
                    },
                    {
                        'from': 'model',
                        'value': qa['a']
                    }
                ])
        conversations = []
        random.shuffle(all_convs)
        for idx, conv in enumerate(all_convs):
            if idx == 0:
                conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
            conversations.extend(conv)
        return conversations
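
    # Hedged example of the list text_preprocess returns (the question and
    # answer strings are made up; only the structure mirrors the code above):
    #
    #   [{'from': 'human', 'value': DEFAULT_VIDEO_TOKEN + 'What animal is shown?'},
    #    {'from': 'model', 'value': 'A cat.'}]
    #
    # Only the first human turn carries DEFAULT_VIDEO_TOKEN; any further
    # shuffled conversations are appended without it.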

    # def __getitem__(self, i) -> Dict[str, torch.Tensor]:
    #     item = self.annotation[i]
    #     ret = {
    #         'images': self.vis_preprocess(item['video_path']),
    #         'conversations': self.text_preprocess(item)
    #     }
    #     if 'id' in item:
    #         ret['id'] = item['id']
    #     return ret

    # @staticmethod
    # def _sample_frames(frames, num_segments):
    #     indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
    #     frames = [frames[ind] for ind in indices]
    #     return frames

    # def vis_preprocess(self, vis_path):
    #     image_files = []
    #     for img_path in os.listdir(vis_path):
    #         if img_path.endswith('.jpeg'):
    #             img_idx = int(img_path.split('_')[-1][:-5])
    #             image_files.append((img_idx, img_path))
    #     image_files = sorted(image_files, key=lambda img: img[0])
    #     # TODO: ad hoc fix, only 10 frames
    #     if len(image_files) > 10:
    #         image_files = self._sample_frames(image_files, 10)
    #     if self.num_segments > 0 and len(image_files) > self.num_segments:
    #         image_files = self._sample_frames(image_files, self.num_segments)
    #     images = []
    #     for image_file in image_files:
    #         try:
    #             images.append(Image.open(os.path.join(vis_path, image_file[1])).convert('RGB'))
    #         except Exception as e:
    #             continue
    #     formatted_images = []
    #     for image in images:
    #         im = self.preprocess_image(image)
    #         if isinstance(im, list):
    #             formatted_images.extend(im)
    #         else:
    #             formatted_images.append(im)
    #     return formatted_images
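
    # Worked example of the linspace-based sampling in the commented-out
    # _sample_frames above: with 10 frames and num_segments=4,
    # np.linspace(start=0, stop=9, num=4).astype(int) gives [0, 3, 6, 9],
    # i.e. evenly spaced indices that always keep the first and last frame.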


@DATASETS.register_obj
def promptv1_2_internal(data_args):
    data_cfg = data_configs['promptv1_2_internal']
    task_types = data_args.external_args['task_types']
    return PromptV1Dataset(anno_path=data_cfg['train_data_path'], data_args=data_args, task_types=task_types)
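

# Minimal usage sketch (hedged, kept as a comment so importing this module has
# no side effects). It assumes `data_args` exposes external_args['task_types']
# as the registered builder above requires, and that
# data_configs['promptv1_2_internal']['train_data_path'] points at the
# annotation JSON; constructing such a `data_args` object is left to the
# training pipeline.
#
#   dataset = promptv1_2_internal(data_args)
#   print(len(dataset))
#   convs = dataset.text_preprocess(dataset.annotation[0])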