import copy
from dataclasses import dataclass
import json
from typing import Dict, Sequence, TYPE_CHECKING

from PIL import Image, ImageFile
import os

from .text_preprocess import TextPreprocess
from .image_preprocess import ImagePreprocess
from ..utils.arguments import DataArguments
from ..utils.constants import *

import transformers
import torch

from torch.utils.data import Dataset

# Let PIL decode images whose files are truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Temporary stand-in for eos token ids while padding (see the collator's HACK
# note). Negative, so it is assumed never to occur as a real token id.
_TEMP_EOS_ID = -300


class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    The JSON file at ``data_path`` holds a list of records. Each record has a
    ``"conversations"`` list whose items carry a ``"value"`` string, and may
    have an ``"image"`` path relative to ``data_args.image_folder``.
    Tokenization and image loading happen lazily in ``__getitem__``.
    """

    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super().__init__()
        # Context manager closes the handle promptly; JSON text is UTF-8.
        with open(data_path, "r", encoding="utf-8") as f:
            list_data_dict = json.load(f)

        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args
        self.text_preprocess = TextPreprocess(tokenizer, data_args.conv_version)
        self.image_preprocess = ImagePreprocess(data_args.image_processor, data_args)

    def __len__(self):
        return len(self.list_data_dict)

    @staticmethod
    def _num_words(sample) -> int:
        """Whitespace-split word count over all conversation turns of ``sample``."""
        return sum(len(conv['value'].split()) for conv in sample['conversations'])

    @property
    def lengths(self):
        """Approximate per-sample lengths.

        Each length is the word count of the conversation, plus a flat 128 for
        samples that have an image.
        """
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if 'image' in sample else 0
            length_list.append(self._num_words(sample) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        """Signed per-sample word counts.

        The count is positive for samples with an image and negative for
        text-only samples, so the sign encodes the modality.
        """
        length_list = []
        for sample in self.list_data_dict:
            cur_len = self._num_words(sample)
            length_list.append(cur_len if 'image' in sample else -cur_len)
        return length_list

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return the preprocessed text fields, plus ``'image'`` when applicable."""
        sources = self.list_data_dict[i]
        # Deep-copy so preprocessing cannot mutate the records held in memory.
        data_dict = self.text_preprocess(copy.deepcopy(sources["conversations"]))
        if 'image' in sources:
            image_path = os.path.join(self.data_args.image_folder, sources['image'])
            # Context manager guarantees the file handle is released.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            data_dict['image'] = self.image_preprocess(image)
        elif self.data_args.is_multimodal:
            # Text-only sample for a multimodal model: supply an all-zero
            # placeholder image sized like the processor's output.
            processor = self.data_args.image_processor
            # Resolve the fallback lazily: the original evaluated
            # getattr(processor, 'size') even when 'crop_size' existed, which
            # raised AttributeError for processors without a 'size' attribute.
            crop_size = getattr(processor, 'crop_size', None)
            if crop_size is None:
                crop_size = processor.size
            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
        return data_dict


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    The collator performs these steps:
    - Right-pads ``input_ids`` with the tokenizer's pad id and ``labels`` with
      ``IGNORE_INDEX``.
    - Truncates both to ``tokenizer.model_max_length``.
    - Builds an attention mask that is False where ``input_ids`` equals the pad
      id.
    - Stacks the images when all of them are present and share one shape;
      otherwise it passes them through as a list.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]

        pad_id = self.tokenizer.pad_token_id
        eos_id = self.tokenizer.eos_token_id
        max_len = self.tokenizer.model_max_length

        # HACK: phi and stablelm share one id for eos, pad and unk. The model
        # should predict eos in the input ids, but that same id is also used for
        # padding. So eos is swapped for a temporary id before padding, kept
        # unmasked in the attention mask, and restored afterwards.
        pad_is_eos = pad_id == eos_id
        if pad_is_eos:
            # masked_fill is out-of-place, so the tensors held by `instances`
            # are left unmodified.
            input_ids = [ids.masked_fill(ids == eos_id, _TEMP_EOS_ID) for ids in input_ids]

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=pad_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :max_len]
        labels = labels[:, :max_len]
        # Computed while eos is still the temporary id, so real eos tokens stay
        # unmasked even when pad_id == eos_id.
        attention_mask = input_ids.ne(pad_id)

        if pad_is_eos:
            input_ids = input_ids.masked_fill(input_ids == _TEMP_EOS_ID, eos_id)

        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )

        if 'image' in instances[0]:
            images = [instance['image'] for instance in instances]
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images

        return batch


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning.

    Returns a dict with ``train_dataset`` (built from ``data_args.data_path``),
    ``eval_dataset`` (always None) and ``data_collator``.
    """
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                          data_path=data_args.data_path,
                                          data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)