import copy
import json
import os
from dataclasses import dataclass
from typing import Dict, Sequence, TYPE_CHECKING

import torch
import transformers
from PIL import Image, ImageFile
from torch.utils.data import Dataset

from .text_preprocess import TextPreprocess
from .image_preprocess import ImagePreprocess
from ..utils.arguments import DataArguments
from ..utils.constants import *

ImageFile.LOAD_TRUNCATED_IMAGES = True


class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super(LazySupervisedDataset, self).__init__()
        # Annotations are loaded eagerly; images are only read in __getitem__.
        with open(data_path, "r") as f:
            list_data_dict = json.load(f)

        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args
        self.text_preprocess = TextPreprocess(tokenizer, data_args.conv_version)
        self.image_preprocess = ImagePreprocess(data_args.image_processor, data_args)

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            # Approximate length in whitespace-split tokens, budgeting a flat
            # 128 tokens for samples that carry an image.
            img_tokens = 128 if 'image' in sample else 0
            length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
            # Text-only samples are reported as negative lengths so a sampler can
            # group multimodal and text-only examples separately.
            cur_len = cur_len if 'image' in sample else -cur_len
            length_list.append(cur_len)
        return length_list

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        sources = self.list_data_dict[i]
        data_dict = self.text_preprocess(copy.deepcopy(sources["conversations"]))
        if 'image' in sources:
            image_file = self.list_data_dict[i]['image']
            image_folder = self.data_args.image_folder
            image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
            image = self.image_preprocess(image)
            data_dict['image'] = image
        elif self.data_args.is_multimodal:
            # Text-only sample in a multimodal setup: attach an all-zero image so
            # every example in the batch exposes an image tensor of the same shape.
            crop_size = getattr(self.data_args.image_processor, 'crop_size',
                                getattr(self.data_args.image_processor, 'size'))
            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
        return data_dict
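

# Expected layout of the annotation JSON, inferred from the accesses above. Only
# the "image" key and the "conversations"/"value" fields are read in this file;
# the "from" field and the example values are illustrative assumptions.
#
#   [
#     {
#       "image": "relative/path/to/img.jpg",   # optional; resolved under data_args.image_folder
#       "conversations": [
#         {"from": "human", "value": "..."},
#         {"from": "gpt", "value": "..."}
#       ]
#     },
#     ...
#   ]

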
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        # If the pad token reuses the EOS id, genuine EOS tokens would later be
        # masked out as padding. Temporarily swap them to a sentinel id (-300)
        # so the attention mask only hides real padding.
        if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
            for input_id in input_ids:
                input_id[input_id == self.tokenizer.eos_token_id] = -300
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        labels = labels[:, :self.tokenizer.model_max_length]

        # Restore the genuine EOS tokens now that the attention mask is built.
        if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
            for input_id in input_ids:
                input_id[input_id == -300] = self.tokenizer.eos_token_id

        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )

        if 'image' in instances[0]:
            images = [instance['image'] for instance in instances]
            # Stack into one tensor only when every image has the same shape;
            # otherwise pass the list through unchanged.
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images

        return batch


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                          data_path=data_args.data_path,
                                          data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)
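

# Minimal usage sketch (illustrative only; `model`, `training_args`, and the
# populated `tokenizer`/`data_args` objects are assumed to come from the
# training script, not from this module):
#
#   data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
#   trainer = transformers.Trainer(model=model, args=training_args, **data_module)
#   trainer.train()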