import datasets
import torch
import re
import os
import subprocess
from llava.datasets.builder import DATASETS
from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import ImageTaskDataset
from llava.constants import DEFAULT_IMAGE_TOKEN
from llava.datasets.data_cfgs import data_configs
from llava.utils import master_print


class LKImageDataset(ImageTaskDataset):
    """Image/conversation dataset registered under the name ``lk_image``.

    Each annotation record is indexed with the keys ``'image'`` and
    ``'conversations'``, and optionally ``'id'``.  ``self.annotation`` and
    ``self.vis_preprocess`` are not defined in this class; presumably they
    come from ``ImageTaskDataset`` -- confirm against the base class.
    """

    def __init__(self, anno_path=None, data_args=None, aux_args=None, name='lk_image'):
        """Forward the annotation path, data arguments and name to the base class.

        ``aux_args`` is accepted for signature compatibility but is neither
        stored nor passed on to the base constructor here.
        """
        super().__init__(
            anno_path=anno_path,
            data_args=data_args,
            name=name,
        )

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.annotation)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Return the ``'conversations'`` entry of ``item`` unchanged."""
        return item['conversations']

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build the sample for index ``i``.

        The result always contains ``'images'`` (the output of
        ``self.vis_preprocess`` on the record's ``'image'`` value) and
        ``'conversations'``; ``'id'`` is copied over only when the record
        has one.
        """
        record = self.annotation[i]
        sample = {
            'images': self.vis_preprocess(record['image']),
            'conversations': self.text_preprocess(record),
        }
        # The identifier is optional in the annotation records.
        if 'id' in record:
            sample['id'] = record['id']
        return sample


@DATASETS.register_obj
def lk_image(data_args):
    """Create an ``LKImageDataset`` from the ``'lk_image'`` entry of ``data_configs``.

    The entry's ``'train_data_path'`` is used as the annotation path and the
    whole entry is handed over as ``aux_args``.
    """
    cfg = data_configs['lk_image']
    return LKImageDataset(
        anno_path=cfg['train_data_path'],
        data_args=data_args,
        aux_args=cfg,
    )


# Disabled one-off snippet: rewrites the annotation JSON in place, keeping
# only the entries whose image path exists on disk and printing the paths
# that are missing.
#
# if __name__ == '__main__':
#     import json
#     from tqdm import tqdm
#     with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_im.json') as f:
#         data = json.load(f)
#     filterd_data = []
#     for idx, item in tqdm(enumerate(data)):
#         image_path = item['image']
#         if os.path.exists(image_path):
#             filterd_data.append(item)
#         else:
#             print(image_path)
#     with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_im.json', 'w') as f:
#         json.dump(filterd_data, f)