import os
import re
import subprocess
from typing import Dict, List, Optional, Sequence

import datasets
import torch

from llava.constants import DEFAULT_IMAGE_TOKEN
from llava.datasets.base_dataset import ImageTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.utils import master_print


class ShareGPT4VDataset(ImageTaskDataset):
    def __init__(self, anno_path=None, data_args=None, aux_args=None, name='sharegpt4v'):
        super().__init__(anno_path=anno_path, data_args=data_args, name=name)
        self.annotation = datasets.load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")['train']
        self.aux_args = aux_args

        master_print(f"Finished loading dataset {name}: {len(self.annotation)} samples...")

    def __len__(self):
        return len(self.annotation)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        captions = item['conversations']
        conversations = []
        # Strip the image token already embedded in the raw human caption, then prepend a
        # single DEFAULT_IMAGE_TOKEN so each sample carries exactly one image placeholder.
        # (The literal argument to replace() was garbled in the original; DEFAULT_IMAGE_TOKEN
        # is assumed here.)
        conv = [
            {
                'from': 'human',
                'value': DEFAULT_IMAGE_TOKEN + captions[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '')
            },
            {
                'from': 'model',
                'value': captions[1]['value']
            }
        ]
        conversations.extend(conv)
        return conversations

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.annotation[i]

        # Resolve the image path against the directory that hosts each sub-source
        # (COCO, LLaVA pretraining images, or everything else).
        if 'coco' in item['image']:
            vis_path = os.path.join(self.aux_args['coco_dir'], item['image'])
        elif 'llava' in item['image']:
            file_names = item['image'].split('/')
            vis_path = os.path.join(self.aux_args['llava_dir'], *file_names[-3:])
        else:
            vis_path = os.path.join(self.aux_args['other_dir'], item['image'])

        ret = {
            'images': self.vis_preprocess(vis_path),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret


@DATASETS.register_obj
def ShareGPT4V(data_args):
    data_cfg = data_configs['sharegpt4v']
    return ShareGPT4VDataset(None, data_args, aux_args=data_cfg)


if __name__ == '__main__':
    # Sanity check: report every referenced image that is missing on disk.
    dataset = datasets.load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")['train']
    aux_args = data_configs['sharegpt4v']
    for item in dataset:
        if 'coco' in item['image']:
            vis_path = os.path.join(aux_args['coco_dir'], item['image'])
        elif 'llava' in item['image']:
            file_names = item['image'].split('/')
            vis_path = os.path.join(aux_args['llava_dir'], *file_names[-3:])
        else:
            vis_path = os.path.join(aux_args['other_dir'], item['image'])
        if not os.path.exists(vis_path):
            print(vis_path)

    # One-off helper: download the SAM tar shards listed in sam.txt.
    # with open('/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/sharegpt4v/sam.txt') as f:
    #     for line in f:
    #         items = line.split('\t')
    #         name = items[0].strip()
    #         url = items[1].strip()
    #         match = re.search(r'(\d+)', name).group(1)
    #         idx = int(match)
    #         if idx >= 60:
    #             continue
    #         print(name, url)
    #         output_file = os.path.join('/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/sharegpt4v/sam', name)
    #         try:
    #             subprocess.run(["wget", "-O", output_file, url], check=True)
    #         except subprocess.CalledProcessError as e:
    #             print("An error occurred while downloading the file.")

    # One-off helper: extract the downloaded SAM tar shards.
    # from glob import glob
    # file_path = '/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/sharegpt4v/sam'
    # for file_name in glob(os.path.join(file_path, '*.tar')):
    #     subprocess.run(["tar", "-xf", file_name, '-C', '/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/sharegpt4v/sam/images'], check=True)
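
# Sketch of the config entry this dataset expects. The key names below are inferred
# from how __getitem__ reads self.aux_args; the paths are hypothetical placeholders,
# not values from the source.
#
# data_configs['sharegpt4v'] = {
#     'coco_dir': '/path/to/coco',          # root holding coco/... images
#     'llava_dir': '/path/to/llava',        # root holding llava/... pretraining images
#     'other_dir': '/path/to/other',        # fallback root for all remaining sources
# }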