# Source: model1/llava/datasets/lk_image_dataset.py
# Provenance: uploaded via huggingface_hub (user "multitensor", revision bbfa6f6, verified).
import datasets
import torch
import re
import os
import subprocess
from llava.datasets.builder import DATASETS
from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import ImageTaskDataset
from llava.constants import DEFAULT_IMAGE_TOKEN
from llava.datasets.data_cfgs import data_configs
from llava.utils import master_print
class LKImageDataset(ImageTaskDataset):
    """Image-conversation dataset keyed by a JSON annotation file.

    Each annotation record is expected to carry an 'image' path and a
    'conversations' list; an optional 'id' field is passed through to the
    returned sample.
    """

    def __init__(self, anno_path=None, data_args=None, aux_args=None, name='lk_image'):
        """Initialize the dataset.

        Args:
            anno_path: path to the annotation file (consumed by the base class).
            data_args: data-loading configuration (consumed by the base class).
            aux_args: auxiliary dataset config (e.g. the data_cfg dict supplied
                by the lk_image() factory).
            name: registry name for this dataset.
        """
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)
        # Fix: aux_args was previously accepted but silently discarded.
        # Keep it so callers/subclasses can inspect the dataset config.
        self.aux_args = aux_args

    def __len__(self):
        # NOTE(review): self.annotation is presumably populated from anno_path
        # by the ImageTaskDataset base class — confirm against its __init__.
        return len(self.annotation)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Return the raw conversation turns of one annotation record."""
        return item['conversations']

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build one training sample: preprocessed image plus conversations.

        Propagates the optional 'id' field when the record has one.
        """
        item = self.annotation[i]
        vis_path = item['image']
        ret = {
            'images': self.vis_preprocess(vis_path),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret
@DATASETS.register_obj
def lk_image(data_args):
    """Registered factory for the 'lk_image' dataset.

    Looks up the static configuration for this dataset and constructs an
    LKImageDataset from its training annotation path, forwarding the whole
    config as auxiliary arguments.
    """
    cfg = data_configs['lk_image']
    return LKImageDataset(
        anno_path=cfg['train_data_path'],
        data_args=data_args,
        aux_args=cfg,
    )
# One-off cleanup script (kept commented out for reference): drops annotation
# entries whose image file is missing on disk, rewriting the JSON in place.
# if __name__ == '__main__':
#     import json
#     from tqdm import tqdm
#     with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_im.json') as f:
#         data = json.load(f)
#     filtered_data = []
#     for idx, item in tqdm(enumerate(data)):
#         image_path = item['image']
#         if os.path.exists(image_path):
#             filtered_data.append(item)
#         else:
#             print(image_path)
#     with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_im.json', 'w') as f:
#         json.dump(filtered_data, f)