# model1/llava/datasets/sharegpt4v_dataset.py
import os
import re
import subprocess

import datasets
import torch

from typing import Dict, List

# NOTE: re and subprocess are only needed by the one-off download helper that is
# commented out at the bottom of this file.
from llava.constants import DEFAULT_IMAGE_TOKEN
from llava.datasets.base_dataset import ImageTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.utils import master_print


class ShareGPT4VDataset(ImageTaskDataset):
    """Image-caption dataset backed by the ShareGPT4V annotations on the Hugging Face Hub."""

    def __init__(self, anno_path=None, data_args=None, aux_args=None, name='sharegpt4v'):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)
        self.annotation = datasets.load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")['train']
        self.aux_args = aux_args
        master_print(f"Finished loading dataset {name} {len(self.annotation)} samples...")
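        # Each annotation item is expected to carry at least 'image' (a relative path such as
        # 'coco/...'), 'conversations' (a human/model caption pair) and, usually, 'id';
        # __getitem__ below relies on exactly these fields.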

    def __len__(self):
        return len(self.annotation)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Turn one ShareGPT4V record into a single-turn human/model conversation."""
        captions = item['conversations']
        conversations = []
        conv = [
            {
                'from': 'human',
                # Drop any inline '<image>' placeholder from the prompt and prepend the
                # canonical image token instead.
                'value': DEFAULT_IMAGE_TOKEN + captions[0]['value'].replace('<image>', '')
            },
            {
                'from': 'model',
                'value': captions[1]['value']
            }
        ]
        conversations.extend(conv)
        return conversations
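
    # Example of what text_preprocess returns (illustrative values only; in LLaVA-style
    # codebases DEFAULT_IMAGE_TOKEN is typically the literal string '<image>'):
    # [
    #     {'from': 'human', 'value': '<image>Describe this photo in detail.'},
    #     {'from': 'model', 'value': 'The photo shows ...'},
    # ]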

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.annotation[i]
        # Resolve the relative image path against the directory it was downloaded to.
        if 'coco' in item['image']:
            vis_path = os.path.join(self.aux_args['coco_dir'], item['image'])
        elif 'llava' in item['image']:
            # Keep only the last three components of the relative path.
            file_names = item['image'].split('/')
            vis_path = os.path.join(self.aux_args['llava_dir'], *file_names[-3:])
        else:
            vis_path = os.path.join(self.aux_args['other_dir'], item['image'])
        ret = {
            'images': self.vis_preprocess(vis_path),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret


@DATASETS.register_obj
def ShareGPT4V(data_args):
    data_cfg = data_configs['sharegpt4v']
    return ShareGPT4VDataset(None, data_args, aux_args=data_cfg)
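
# Usage sketch (illustrative only; it assumes data_configs['sharegpt4v'] maps 'coco_dir',
# 'llava_dir' and 'other_dir' to local image folders, and that `data_args` is whatever
# argument object the surrounding training pipeline normally builds):
#
#     cfg = data_configs['sharegpt4v']
#     ds = ShareGPT4VDataset(anno_path=None, data_args=data_args, aux_args=cfg)
#     sample = ds[0]   # {'images': ..., 'conversations': [...], 'id': ...}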


if __name__ == '__main__':
    # Standalone sanity check: walk the full annotation set and report every image
    # whose resolved local path does not exist on disk.
    dataset = datasets.load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")['train']
    aux_args = data_configs['sharegpt4v']
    for item in dataset:
        if 'coco' in item['image']:
            vis_path = os.path.join(aux_args['coco_dir'], item['image'])
        elif 'llava' in item['image']:
            file_names = item['image'].split('/')
            vis_path = os.path.join(aux_args['llava_dir'], *file_names[-3:])
        else:
            vis_path = os.path.join(aux_args['other_dir'], item['image'])
        if not os.path.exists(vis_path):
            print(vis_path)
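
# The commented-out helper below was a one-off script: it read (name, url) pairs from
# sam.txt, downloaded the SAM tarballs whose index is below 60 with wget, and then
# unpacked them with tar. It is kept for reference only and relies on hard-coded
# personal paths.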
# with open('/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/sharegpt4v/sam.txt') as f:
#     for line in f:
#         items = line.split('\t')
#         name = items[0].strip()
#         url = items[1].strip()
#         match = re.search(r'(\d+)', name).group(1)
#         idx = int(match)
#         if idx >= 60:
#             continue
#         print(name, url)
#         output_file = os.path.join('/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/sharegpt4v/sam', name)
#         try:
#             subprocess.run(["wget", "-O", output_file, url], check=True)
#         except subprocess.CalledProcessError as e:
#             print("An error occurred while downloading the file.")

# from glob import glob
# file_path = '/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/sharegpt4v/sam'
# for file_name in glob(os.path.join(file_path, '*.tar')):
#     subprocess.run(["tar", "-xf", file_name, '-C', '/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/sharegpt4v/sam/images'], check=True)