|
|
|
from llava.datasets.builder import DATASETS |
|
|
|
from typing import Dict, Optional, Sequence, List |
|
from llava.datasets.data_cfgs import data_configs |
|
from llava.datasets.base_dataset import FramesTaskDataset |
|
from llava.datasets.data_cfgs import data_configs |
|
import pickle |
|
from pathlib import Path |
|
import random |
|
import numpy as np |
|
from llava.datasets.prompts import tt_caption_prompt, internvid_prompt |
|
from llava.constants import DEFAULT_VIDEO_TOKEN |
|
from PIL import Image |
|
import json |
|
import torch |
|
import os |
|
|
|
|
|
class LKVideoDataset(FramesTaskDataset):
    """Multi-turn conversation dataset over videos stored as folders of
    pre-extracted ``*.jpeg`` frames.

    Each annotation entry is a dict with at least a ``'video'`` key (path to
    a frame directory) and a ``'conversations'`` key (the multi-turn dialog),
    optionally an ``'id'``.
    """

    def __init__(self, anno_path=None, data_args=None, fps=1.0, conv_type='multi', select_datasets=None, name='lk_video'):
        """
        Args:
            anno_path: path to the JSON annotation file (a list of samples).
            data_args: dataset/runtime configuration forwarded to the base class.
            fps: sampling rate forwarded to the base class.
            conv_type: conversation type; only 'multi' is supported.
            select_datasets: optional collection of sub-dataset names; when
                given, only samples whose video directory's parent folder name
                is in this collection are kept.
            name: registry name of this dataset.
        """
        self.default_fps = 1.0
        self.fps = fps
        self.conv_type = conv_type
        self.select_datasets = select_datasets
        # Load annotations before super().__init__ since the base class may
        # rely on them being present.
        self.annotation = self.get_dataset(anno_path)

        # BUG FIX: the original wrote `in ('multi')`, but ('multi') is just
        # the string 'multi', so `in` did a *substring* check (e.g. 'ti'
        # would have passed). A one-element tuple restores the intended
        # membership test.
        assert self.conv_type in ('multi',), "lk_video conv type must be multi"

        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)

    def __len__(self):
        """Number of (possibly filtered) annotation samples."""
        return len(self.annotation)

    def get_dataset(self, anno_path):
        """Load the JSON annotation list, optionally filtering by sub-dataset.

        The sub-dataset name of a sample is taken to be the name of the
        parent directory of its ``'video'`` path.
        """
        anno_path = Path(anno_path)
        # FIX: JSON is text — open in text mode with an explicit encoding
        # instead of 'rb' (json.load only tolerated binary mode by accident).
        with anno_path.open('r', encoding='utf-8') as f:
            data = json.load(f)

        if self.select_datasets is not None:
            filtered_data = []
            for sample in data:
                video_path = Path(sample['video'])
                dataset_name = video_path.parent.name
                if dataset_name in self.select_datasets:
                    filtered_data.append(sample)
            data = filtered_data

        return data

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Return the raw multi-turn conversation for a sample."""
        return item['conversations']

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return a dict with preprocessed 'images', 'conversations', and,
        when present in the annotation, the sample 'id'."""
        item = self.annotation[i]

        ret = {
            'images': self.vis_preprocess(item['video']),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']

        return ret

    @staticmethod
    def _sample_frames(frames, num_segments):
        """Uniformly subsample `num_segments` items from `frames`
        (endpoints included)."""
        indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)

        frames = [frames[ind] for ind in indices]

        return frames

    def vis_preprocess(self, vis_path):
        """Load, subsample, and preprocess the frames of one video directory.

        Frame files are expected to be named ``*_<index>.jpeg``; they are
        ordered by that numeric index, capped at 10 frames, then further
        subsampled to ``self.num_segments`` if configured.
        """
        image_files = []
        for img_path in os.listdir(vis_path):
            if img_path.endswith('.jpeg'):
                # '<prefix>_<idx>.jpeg' -> numeric frame index
                # ([:-5] strips the '.jpeg' suffix).
                img_idx = int(img_path.split('_')[-1][:-5])
                image_files.append((img_idx, img_path))

        # Sort frames by their numeric index, not lexicographic filename order.
        image_files = sorted(image_files, key=lambda img: img[0])

        # Hard cap of 10 frames, then the configured num_segments cap
        # (presumably set by the base class — TODO confirm) may reduce further.
        if len(image_files) > 10:
            image_files = self._sample_frames(image_files, 10)
        if self.num_segments > 0 and len(image_files) > self.num_segments:
            image_files = self._sample_frames(image_files, self.num_segments)

        images = []
        for image_file in image_files:
            # Best-effort loading: skip unreadable/corrupt frames rather than
            # failing the whole sample (kept from the original behavior).
            try:
                images.append(Image.open(os.path.join(vis_path, image_file[1])).convert('RGB'))
            except Exception:
                continue
        formatted_images = []
        for image in images:
            im = self.preprocess_image(image)
            # preprocess_image may return one tensor or a list of tensors;
            # flatten either case into a single list.
            if isinstance(im, list):
                formatted_images.extend(im)
            else:
                formatted_images.append(im)
        return formatted_images
|
|
|
|
|
@DATASETS.register_obj
def lk_video(data_args):
    """Registry builder for the 'lk_video' dataset.

    Reads 'fps' and 'conv_type' (required) and 'select_datasets' (optional)
    from ``data_args.external_args`` and constructs an LKVideoDataset using
    the configured training annotation path.
    """
    data_cfg = data_configs['lk_video']
    fps = data_args.external_args['fps']
    conv_type = data_args.external_args['conv_type']
    # dict.get replaces the verbose membership-test ternary; the default of
    # None is identical to the original behavior.
    select_datasets = data_args.external_args.get('select_datasets')
    return LKVideoDataset(data_cfg['train_data_path'], data_args, fps, conv_type, select_datasets=select_datasets)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|