StarCycle's picture
init
377d3d1
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
from xtuner.registry import BUILDER
from .huggingface import process_hf_dataset
from .utils import expand2square
class LLaVADataset(Dataset):
    """Dataset pairing tokenized LLaVA-format text samples with images.

    Annotations are read from a JSON file, converted to a HuggingFace
    dataset and tokenized via ``process_hf_dataset``. On access, each sample
    is returned together with a ``'pixel_values'`` tensor: the preprocessed
    image named by the sample's ``'image'`` field, or an all-zero
    placeholder when the sample has no image.

    Args:
        data_path: Path to a UTF-8 JSON file whose top level is a list of
            sample dicts. Integer ``'id'`` values are converted to ``str``.
            An ``'image'`` entry, when present and not None, is a path
            relative to ``image_folder``.
        image_folder: Directory that each sample's ``'image'`` path is
            joined onto.
        tokenizer: Forwarded to ``process_hf_dataset``.
        image_processor: Either a ready image-processor object or a
            ``dict`` / ``Config`` / ``ConfigDict`` that is built with
            ``BUILDER``. It must provide ``preprocess``, ``image_mean``
            (used when ``pad_image_to_square`` is set), and a ``size`` or
            ``crop_size`` with ``'height'`` and ``'width'`` entries (used for
            text-only samples).
        max_dataset_length: Forwarded to ``process_hf_dataset``.
        dataset_map_fn: Forwarded to ``process_hf_dataset``.
        template_map_fn: Forwarded to ``process_hf_dataset``.
        max_length: Forwarded to ``process_hf_dataset``. Defaults to 2048.
        pad_image_to_square: If True, pad each image to a square filled with
            the processor's mean color (via ``expand2square``) before
            preprocessing. Defaults to False.
    """

    def __init__(self,
                 data_path,
                 image_folder,
                 tokenizer,
                 image_processor,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False):
        super().__init__()
        records = self._load_json(data_path)
        json_data = DatasetDict({'train': HFDataset.from_list(records)})
        self.text_data = process_hf_dataset(
            dataset=json_data,
            tokenizer=tokenizer,
            max_length=max_length,
            dataset_map_fn=dataset_map_fn,
            template_map_fn=template_map_fn,
            split='train',
            max_dataset_length=max_dataset_length,
            remove_unused_columns=False,
            pack_to_max_length=False,
            with_image_token=True)
        self.image_folder = image_folder
        # Configs are built lazily through the registry; anything else is
        # assumed to already be a usable image-processor instance.
        if isinstance(image_processor, (dict, Config, ConfigDict)):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor
        self.pad_image_to_square = pad_image_to_square

    @staticmethod
    def _load_json(data_path):
        """Read the annotation list and normalize integer ``'id'`` fields.

        Args:
            data_path: Path to a UTF-8 JSON file containing a list of dicts.

        Returns:
            list[dict]: The parsed records; any ``int`` ``'id'`` is replaced
            by its ``str`` form.
        """
        with open(data_path, encoding='utf-8') as f:
            records = json.load(f)
        for record in records:
            # Presumably so the Arrow-backed dataset infers a single string
            # type for the 'id' column when files mix int and str ids.
            if isinstance(record.get('id'), int):
                record['id'] = str(record['id'])
        return records

    @property
    def modality_length(self):
        """Per-sample token counts, negated for samples without an image.

        Returns:
            list[int]: ``len(input_ids)`` for each sample; the value is
            negative when the sample's ``'image'`` is missing or None
            (presumably so a length-grouped sampler can tell text-only and
            multimodal samples apart).
        """
        lengths = []
        for sample in self.text_data:
            n_tokens = len(sample['input_ids'])
            has_image = sample.get('image', None) is not None
            lengths.append(n_tokens if has_image else -n_tokens)
        return lengths

    def __len__(self):
        """Return the number of tokenized samples."""
        return len(self.text_data)

    def _blank_pixel_size(self):
        """Return ``(height, width)`` for the placeholder image tensor.

        ``image_processor.size`` is preferred when it carries usable
        ``'height'``/``'width'`` entries. Some processors (e.g. CLIP-style
        ones) describe ``size`` as ``{'shortest_edge': N}`` and expose the
        output resolution as ``crop_size`` instead, so that is tried next.

        Raises:
            ValueError: If neither attribute provides both dimensions.
        """
        for attr in ('size', 'crop_size'):
            spec = getattr(self.image_processor, attr, None)
            try:
                height, width = spec['height'], spec['width']
            except (KeyError, TypeError):
                # Missing attribute (None), non-mapping value, or no such key.
                continue
            if height is not None and width is not None:
                return height, width
        raise ValueError(
            'Cannot infer the placeholder image size: image_processor has '
            "no 'size' or 'crop_size' with 'height' and 'width' entries.")

    def __getitem__(self, index):
        """Return sample ``index`` with a ``'pixel_values'`` tensor added.

        For samples with an image, the image is loaded from
        ``image_folder``, converted to RGB, optionally padded to a square,
        and run through ``image_processor.preprocess``. Samples without an
        image get an all-zero ``(3, height, width)`` tensor, presumably so
        text-only and multimodal samples can be collated together.
        """
        data_dict = self.text_data[index]
        image_file = data_dict.get('image', None)
        if image_file is None:
            height, width = self._blank_pixel_size()
            data_dict['pixel_values'] = torch.zeros(3, height, width)
            return data_dict

        image_path = os.path.join(self.image_folder, image_file)
        # The context manager releases the file handle deterministically;
        # convert() returns an independent image, so it stays valid after.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.pad_image_to_square:
            background = tuple(
                int(x * 255) for x in self.image_processor.image_mean)
            image = expand2square(image, background)
        data_dict['pixel_values'] = self.image_processor.preprocess(
            image, return_tensors='pt')['pixel_values'][0]
        return data_dict