import random
import bisect
import copy

import torch
import transformers
from omegaconf import OmegaConf
from dataclasses import dataclass
from typing import Dict, Sequence, List
from torch.utils.data import Dataset, ConcatDataset, get_worker_info
from transformers import CLIPImageProcessor, SiglipImageProcessor

from llava.datasets.registry import build_from_cfg
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.train.arguments import DataArguments
from llava.model.preprocessor import preprocess_multimodal, preprocess
from llava.constants import IGNORE_INDEX
from llava.utils import DatasetIter, get_world_size, get_rank, master_print


class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Sub-datasets listed in the config are sampled according to their
    ``sample_ratio``; items are fetched lazily through per-dataset iterators.
    """

    def __init__(self, data_cfg: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments,
                 num_workers: int):
        super(LazySupervisedDataset, self).__init__()

        dataset_config = OmegaConf.load(data_cfg)
        self.tokenizer = tokenizer
        self.data_args = data_args

        self.datasets, self.sample_ratios = list(), list()
        for ds in list(dataset_config.datasets.keys()):
            ds_cfg = dataset_config.datasets[ds]
            external_args = {key: value for key, value in ds_cfg.items()}
            # Build each sub-dataset against its own copy of the data arguments
            # so per-dataset overrides cannot leak into the shared DataArguments.
            args_ = copy.deepcopy(vars(data_args))
            data_args_copy = type('DataArguments', (object,), args_)
            dataset = build_from_cfg(ds, data_args_copy, DATASETS, default_args=external_args)
            self.datasets.append(dataset)
            if 'sample_ratio' in ds_cfg:
                self.sample_ratios.append(ds_cfg.sample_ratio)

        # Fall back to uniform sampling unless every dataset specifies a ratio,
        # then normalize the ratios to sum to 1.
        if len(self.sample_ratios) != len(self.datasets):
            self.sample_ratios = [1.0] * len(self.datasets)
        self.sample_ratios = [float(ratio) / sum(self.sample_ratios) for ratio in self.sample_ratios]

        self.ds_iters = [DatasetIter(len(dataset), get_world_size(), get_rank(), num_workers)
                         for dataset in self.datasets]

    def __len__(self):
        # Define one epoch as the largest number of iterations any single
        # dataset needs at its sampling ratio.
        return max(int(len(ds) / ratio) for ds, ratio in zip(self.datasets, self.sample_ratios))

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        worker_info = get_worker_info()
        # get_worker_info() returns None in the main process (num_workers == 0).
        worker_id = worker_info.id if worker_info is not None else 0

        ds_idx = random.choices(range(len(self.datasets)), self.sample_ratios, k=1)[0]
        item = None
        while item is None:
            item_id = self.ds_iters[ds_idx].increment(worker_id)
            item = self.datasets[ds_idx][item_id]

        sources = item
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME

        images = None
        if 'images' in sources[0]:
            images = sources[0]['images']
            conversations = copy.deepcopy([e['conversations'] for e in sources])
            sources = preprocess_multimodal(conversations, self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])

        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=('images' in item))
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        if images is not None and len(images) > 0:
            data_dict["images"] = images
        elif self.data_args.is_multimodal:
            # The sample has no image but the model is multimodal: feed a dummy
            # all-zero image and mask out every label so the sample contributes
            # no loss.
            img_size = self.data_args.image_processor.img_size
            if getattr(self.data_args, 'image_aspect_ratio', 'square') == 'anyres':
                data_dict['images'] = [torch.zeros(1, 3, img_size, img_size)]
            else:
                data_dict['images'] = [torch.zeros(3, img_size, img_size)]
            data_dict['labels'][:] = IGNORE_INDEX
        return data_dict

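
# Illustrative sketch (not called anywhere in this module): wiring
# LazySupervisedDataset into a plain PyTorch DataLoader. The batch size and
# worker count are arbitrary; tokenizer and data_args are assumed to be
# prepared by the caller, as in make_supervised_data_module below.
def _example_lazy_dataset_loader(tokenizer, data_args, num_workers=4):
    from torch.utils.data import DataLoader

    dataset = LazySupervisedDataset(data_cfg=data_args.dataset_config,
                                    tokenizer=tokenizer,
                                    data_args=data_args,
                                    num_workers=num_workers)
    loader = DataLoader(dataset,
                        batch_size=2,
                        num_workers=num_workers,
                        collate_fn=DataCollatorForSupervisedDataset(tokenizer=tokenizer))
    return next(iter(loader))
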

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        if 'images' in instances[0]:
            images = [instance['images'] for instance in instances]
            images_data = []
            for imgs in images:
                if all(x is not None and x.shape == imgs[0].shape for x in imgs):
                    imgs = torch.stack(imgs)
                else:
                    # Drop missing images, then drop any image whose shape
                    # disagrees with the first remaining one so the stack is
                    # well-formed.
                    imgs = [x for x in imgs if x is not None]
                    imgs = [x for x in imgs if x.shape == imgs[0].shape]
                    imgs = torch.stack(imgs)
                images_data.append(imgs)
            batch["images"] = images_data

        if 'images' not in batch or len(batch['images']) == 0:
            print("images not in batch")

        return batch


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args,
                                num_workers) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(data_cfg=data_args.dataset_config,
                                          tokenizer=tokenizer,
                                          data_args=data_args,
                                          num_workers=num_workers)
    for ds, ratio in zip(train_dataset.datasets, train_dataset.sample_ratios):
        master_print(f"==> Real epoch of {ds.name} is "
                     f"{round(len(train_dataset) * ratio / len(ds), 2)} epochs.")
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)

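
# Self-contained sketch of the collator's padding behaviour, assuming the given
# tokenizer has pad_token_id set (callers typically set a pad token beforehand).
# The token ids below are arbitrary.
def _example_collator_padding(tokenizer):
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    instances = [
        dict(input_ids=torch.tensor([1, 5, 9]), labels=torch.tensor([IGNORE_INDEX, 5, 9])),
        dict(input_ids=torch.tensor([1, 7]), labels=torch.tensor([IGNORE_INDEX, 7])),
    ]
    batch = collator(instances)
    # The shorter sequence is right-padded: input_ids with pad_token_id, labels
    # with IGNORE_INDEX; attention_mask is False at the padded position.
    assert batch["input_ids"].shape == batch["labels"].shape == (2, 3)
    return batch
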

class SupervisedConcatDataset(ConcatDataset):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """

    datasets: List[Dataset]
    cumulative_sizes: List[int]

    def __init__(self, datasets: List[Dataset],
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments) -> None:
        super().__init__(datasets)
        self.tokenizer = tokenizer
        self.data_args = data_args

    @property
    def modality_lengths(self):
        """Approximate token length of every sample (word count plus image
        tokens), e.g. for length-grouped batch sampling."""
        length_list = []
        token_per_image = getattr(self.data_args, 'num_token_per_image', 32)
        for idx in range(len(self)):
            dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
            if dataset_idx == 0:
                sample_idx = idx
            else:
                sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
            item = self.datasets[dataset_idx].annotation[sample_idx]
            conversations = self.datasets[dataset_idx].text_preprocess(item)
            cur_len = sum(len(conv['value'].split()) for conv in conversations)
            if self.datasets[dataset_idx].type == 'images':
                cur_len += token_per_image
            else:
                # Video samples: one set of image tokens per sampled segment.
                cur_len += token_per_image * self.data_args.num_segments
            length_list.append(cur_len)
        return length_list

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        item = self.datasets[dataset_idx][sample_idx]

        sources = item
        if isinstance(idx, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME

        images = None
        if 'images' in sources[0]:
            images = sources[0]['images']
            conversations = copy.deepcopy([e['conversations'] for e in sources])
            sources = preprocess_multimodal(conversations, self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])

        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=('images' in item))
        if isinstance(idx, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        if images is not None and len(images) > 0:
            data_dict["images"] = images
        elif self.data_args.is_multimodal:
            # The sample has no image but the model is multimodal: feed a dummy
            # all-zero image and mask out every label. The dummy image size
            # depends on the processor type.
            if isinstance(self.data_args.image_processor, SiglipImageProcessor):
                img_size = self.data_args.image_processor.size['height']
            elif isinstance(self.data_args.image_processor, CLIPImageProcessor):
                img_size = self.data_args.image_processor.crop_size['height']
            else:
                img_size = self.data_args.image_processor.img_size
            if getattr(self.data_args, 'image_aspect_ratio', 'square') == 'anyres':
                data_dict['images'] = [torch.zeros(1, 3, img_size, img_size)]
            else:
                data_dict['images'] = [torch.zeros(3, img_size, img_size)]
            data_dict['labels'][:] = IGNORE_INDEX
        return data_dict

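
# Hedged sketch: how modality_lengths can drive length-grouped batching so that
# samples of similar token length land in the same batch (less padding waste).
# The simple sort-into-chunks scheme below is illustrative only, not the
# sampler this repo actually uses.
def _example_length_grouped_indices(dataset: "SupervisedConcatDataset",
                                    batch_size: int) -> List[List[int]]:
    lengths = dataset.modality_lengths
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
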

def make_supervised_data_module_concatdataset(tokenizer: transformers.PreTrainedTokenizer,
                                              data_args,
                                              num_workers) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    datasets = []
    dataset_config = OmegaConf.load(data_args.dataset_config)
    for ds in list(dataset_config.datasets.keys()):
        ds_cfg = dataset_config.datasets[ds]
        external_args = {key: value for key, value in ds_cfg.items()}
        # As in LazySupervisedDataset, build each sub-dataset against its own
        # copy of the data arguments.
        args_ = copy.deepcopy(vars(data_args))
        data_args_copy = type('DataArguments', (object,), args_)
        dataset = build_from_cfg(ds, data_args_copy, DATASETS, default_args=external_args)
        datasets.append(dataset)

    train_dataset = SupervisedConcatDataset(datasets=datasets,
                                            tokenizer=tokenizer,
                                            data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)
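
# Illustrative sketch: wiring the concat-dataset module into a HuggingFace
# Trainer. Model and training_args construction is schematic; the names here
# are placeholders, not this repo's training entry point.
def _example_concatdataset_training(model, tokenizer, data_args, training_args):
    data_module = make_supervised_data_module_concatdataset(
        tokenizer=tokenizer, data_args=data_args, num_workers=4)
    # **data_module expands to train_dataset=..., eval_dataset=None, data_collator=...
    trainer = transformers.Trainer(model=model, args=training_args, **data_module)
    return trainer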