from llava.datasets.builder import DATASETS
from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.data_cfgs import data_configs
import pickle
from pathlib import Path
import random
import numpy as np
from llava.datasets.prompts import tt_caption_prompt, internvid_prompt
from llava.constants import DEFAULT_VIDEO_TOKEN
from PIL import Image
import json
import torch
import os
class LKVideoDataset(FramesTaskDataset):
    """Video dataset built from directories of pre-extracted JPEG frames.

    Each annotation entry is a dict with at least a 'video' key (path to a
    directory containing '<prefix>_<idx>.jpeg' frame files) and a
    'conversations' key (multi-turn dialogue); an optional 'id' key is
    passed through. Frames are ordered by their numeric index, capped at 10
    (ad-hoc limit), then further subsampled to `num_segments` when the
    parent class configured one.
    """

    def __init__(self, anno_path=None, data_args=None, fps=1.0, conv_type='multi', select_datasets=None, name='lk_video'):
        self.default_fps = 1.0
        self.fps = fps
        self.conv_type = conv_type
        # Optional whitelist of dataset (parent-directory) names to keep.
        self.select_datasets = select_datasets
        self.annotation = self.get_dataset(anno_path)
        # TODO: support single
        # Bug fix: `('multi')` is just the string 'multi' (no comma -> no
        # tuple), so the old check was a substring test and would accept
        # e.g. conv_type='ti'. A one-element tuple restores the intended
        # membership test.
        assert self.conv_type in ('multi',), "lk_video conv type must be multi"
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)

    def __len__(self):
        return len(self.annotation)

    def get_dataset(self, anno_path):
        """Load the JSON annotation list from `anno_path`.

        When `self.select_datasets` is set, keep only samples whose frame
        directory's parent-directory name is in that collection.
        """
        anno_path = Path(anno_path)
        # JSON is text; open in 'r' rather than 'rb' (json.load accepts
        # both, but text mode states the intent).
        with anno_path.open('r') as f:
            data = json.load(f)
        if self.select_datasets is not None:
            # The dataset a sample belongs to is encoded as the name of the
            # directory containing its frame folder.
            data = [sample for sample in data
                    if Path(sample['video']).parent.name in self.select_datasets]
        return data

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        # Conversations are stored pre-formatted in the annotation entry.
        return item['conversations']

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.annotation[i]
        ret = {
            'images': self.vis_preprocess(item['video']),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret

    @staticmethod
    def _sample_frames(frames, num_segments):
        """Uniformly subsample `num_segments` items (first and last kept)."""
        indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
        return [frames[ind] for ind in indices]

    def vis_preprocess(self, vis_path):
        """Load, order, subsample, and preprocess the '.jpeg' frames found
        directly under `vis_path`; returns a flat list of processed images."""
        image_files = []
        for img_path in os.listdir(vis_path):
            if img_path.endswith('.jpeg'):
                # Frames are named '<prefix>_<idx>.jpeg'; [:-5] strips '.jpeg'.
                img_idx = int(img_path.split('_')[-1][:-5])
                image_files.append((img_idx, img_path))
        image_files = sorted(image_files, key=lambda img: img[0])
        # TODO: ad-hoc fix, cap at 10 frames
        if len(image_files) > 10:
            image_files = self._sample_frames(image_files, 10)
        if self.num_segments > 0 and len(image_files) > self.num_segments:
            image_files = self._sample_frames(image_files, self.num_segments)
        images = []
        for image_file in image_files:
            # Best-effort: skip unreadable/corrupt frames instead of failing
            # the whole sample.
            try:
                images.append(Image.open(os.path.join(vis_path, image_file[1])).convert('RGB'))
            except Exception:
                continue
        formatted_images = []
        for image in images:
            im = self.preprocess_image(image)
            # preprocess_image may return a single image or a list
            # (e.g. multi-crop); flatten either way.
            if isinstance(im, list):
                formatted_images.extend(im)
            else:
                formatted_images.append(im)
        return formatted_images
@DATASETS.register_obj
def lk_video(data_args):
    """Registry factory: build an LKVideoDataset from the global data config.

    Reads required 'fps' and 'conv_type' from `data_args.external_args`,
    plus an optional 'select_datasets' whitelist (defaults to None when
    absent).
    """
    data_cfg = data_configs['lk_video']
    ext = data_args.external_args
    fps, conv_type = ext['fps'], ext['conv_type']
    # dict.get replaces the `x['k'] if 'k' in x else None` pattern
    # (external_args is used as a dict here and above — TODO confirm).
    select_datasets = ext.get('select_datasets')
    return LKVideoDataset(data_cfg['train_data_path'], data_args, fps, conv_type, select_datasets=select_datasets)
# if __name__ == '__main__':
# import json
# from tqdm import tqdm
# with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_vid.json') as f:
# data = json.load(f)
# filterd_data = []
# for item in tqdm(data):
# image_path = item['video']
# if os.path.exists(image_path):
# filterd_data.append(item)
# else:
# print(image_path)
# with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_vid.json', 'w') as f:
# json.dump(filterd_data, f)