import random
import bisect
import copy

import torch
import transformers
from torch.utils.data import Dataset, ConcatDataset, get_worker_info
from omegaconf import OmegaConf
from dataclasses import dataclass
from typing import Dict, List, Sequence

from llava.datasets.registry import build_from_cfg
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs  # unused here but kept: importing may register dataset configs
from llava.train.arguments import DataArguments
from llava.model.preprocessor import preprocess_multimodal, preprocess
from llava.constants import IGNORE_INDEX
from llava.utils import DatasetIter, get_world_size, get_rank, master_print
from transformers import CLIPImageProcessor, SiglipImageProcessor


class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Lazily draws samples from several sub-datasets according to per-dataset
    sampling ratios defined in an OmegaConf YAML config.
    """

    def __init__(self, data_cfg: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments,
                 num_workers: int):
        super().__init__()
        dataset_config = OmegaConf.load(data_cfg)

        self.tokenizer = tokenizer
        self.data_args = data_args

        self.datasets, self.sample_ratios = [], []
        for ds in dataset_config.datasets.keys():
            ds_cfg = dataset_config.datasets[ds]
            external_args = {key: value for key, value in ds_cfg.items()}
            # Each sub-dataset gets an independent attribute namespace so its
            # builder can adjust data_args fields without affecting the others.
            args_ = copy.deepcopy(vars(data_args))
            data_args_copy = type('DataArguments', (object,), args_)
            dataset = build_from_cfg(ds, data_args_copy, DATASETS, default_args=external_args)
            self.datasets.append(dataset)
            if 'sample_ratio' in ds_cfg:
                self.sample_ratios.append(ds_cfg.sample_ratio)

        # Fall back to uniform weights unless every dataset specifies a sample_ratio.
        if len(self.sample_ratios) != len(self.datasets):
            self.sample_ratios = [1.0] * len(self.datasets)

        # Normalize the ratios so they sum to 1 for random.choices.
        self.sample_ratios = [float(ratio) / sum(self.sample_ratios) for ratio in self.sample_ratios]
        self.ds_iters = [DatasetIter(len(dataset), get_world_size(), get_rank(), num_workers)
                         for dataset in self.datasets]
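
    # A minimal sketch of the YAML consumed by __init__ (hypothetical dataset
    # names and paths; only `sample_ratio` is read explicitly, every other key
    # is forwarded to build_from_cfg as default_args):
    #
    #   datasets:
    #     llava_instruct:
    #       sample_ratio: 2.0
    #       anno_path: /path/to/llava_instruct.json
    #     sharegpt:
    #       sample_ratio: 1.0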

    def __len__(self):
        # Length of one virtual epoch: the dataset stretched furthest by its
        # sampling weight dominates. E.g. (hypothetical sizes) two datasets of
        # lengths 1000 and 3000 with ratios 0.5 / 0.5 give
        # max(1000 / 0.5, 3000 / 0.5) = 6000.
        return max(int(len(ds) / ratio)
                   for ds, ratio in zip(self.datasets, self.sample_ratios))

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        worker_info = get_worker_info()
        # worker_info is None when the DataLoader runs with num_workers=0.
        worker_id = worker_info.id if worker_info is not None else 0

        # Pick a source dataset according to the normalized sampling ratios.
        ds_idx = random.choices(range(len(self.datasets)), self.sample_ratios, k=1)[0]

        # Skip over items the underlying dataset fails to load (returns None).
        item = None
        while item is None:
            item_id = self.ds_iters[ds_idx].increment(worker_id)
            item = self.datasets[ds_idx][item_id]

        sources = item
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"

        images = None
        if 'images' in sources[0]:
            images = sources[0]['images']
            conversations = copy.deepcopy([e['conversations'] for e in sources])
            sources = preprocess_multimodal(conversations, self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])

        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=('images' in item))

        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        if images is not None and len(images) > 0:
            data_dict["images"] = images
        elif self.data_args.is_multimodal:
            # Text-only sample in a multimodal run: attach a dummy image of the
            # processor's input size and mask all labels so it adds no loss.
            if isinstance(self.data_args.image_processor, SiglipImageProcessor):
                img_size = self.data_args.image_processor.size['height']
            elif isinstance(self.data_args.image_processor, CLIPImageProcessor):
                img_size = self.data_args.image_processor.crop_size['height']
            else:
                img_size = self.data_args.image_processor.img_size

            if getattr(self.data_args, 'image_aspect_ratio', 'square') == 'anyres':
                # `anyres` images carry an extra leading patch dimension.
                data_dict['images'] = [torch.zeros(1, 3, img_size, img_size)]
            else:
                data_dict['images'] = [torch.zeros(3, img_size, img_size)]
            data_dict['labels'][:] = IGNORE_INDEX
        return data_dict


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        # Truncate to the tokenizer's context length after padding.
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        if 'images' in instances[0]:
            images = [instance['images'] for instance in instances]
            images_data = []
            for imgs in images:
                if all(x is not None and x.shape == imgs[0].shape for x in imgs):
                    imgs = torch.stack(imgs)
                else:
                    # Drop missing images, then keep only those matching the
                    # first surviving shape so the stack is well defined.
                    imgs = [x for x in imgs if x is not None]
                    imgs = [x for x in imgs if x.shape == imgs[0].shape]
                    imgs = torch.stack(imgs)
                images_data.append(imgs)

            batch["images"] = images_data

        if 'images' not in batch or len(batch['images']) == 0:
            print("images not in batch")

        return batch
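
# A minimal collation sketch (toy tensors; assumes pad_token_id is set on the
# tokenizer): two samples of lengths 3 and 5 pad to a (2, 5) batch, with
# attention_mask False on the padded positions.
#
#   collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
#   batch = collator([
#       dict(input_ids=torch.tensor([5, 6, 7]), labels=torch.tensor([5, 6, 7])),
#       dict(input_ids=torch.tensor([5, 6, 7, 8, 9]),
#            labels=torch.tensor([5, 6, 7, 8, 9])),
#   ])
#   batch['input_ids'].shape  # torch.Size([2, 5])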


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args,
                                num_workers) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(data_cfg=data_args.dataset_config,
                                          tokenizer=tokenizer,
                                          data_args=data_args,
                                          num_workers=num_workers)

    for ds, ratio in zip(train_dataset.datasets, train_dataset.sample_ratios):
        master_print(f"==> Real epoch of {ds.name} is {round(len(train_dataset) * ratio / len(ds), 2)} epochs.")

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)
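
# Usage sketch (hypothetical surrounding training script; `model`, `data_args`
# and `training_args` are assumed to exist there):
#
#   tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
#   data_module = make_supervised_data_module(tokenizer, data_args,
#                                             num_workers=training_args.dataloader_num_workers)
#   trainer = transformers.Trainer(model=model, args=training_args, **data_module)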


class SupervisedConcatDataset(ConcatDataset):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """

    datasets: List[Dataset]
    cumulative_sizes: List[int]

    def __init__(self, datasets: List[Dataset],
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments) -> None:
        super().__init__(datasets)
        self.tokenizer = tokenizer
        self.data_args = data_args

    @property
    def modality_lengths(self):
        # Rough token-length estimate per sample, intended for length-grouped
        # batch samplers: whitespace word count plus a fixed budget per image
        # (or per video segment).
        length_list = []
        token_per_image = getattr(self.data_args, 'num_token_per_image', 32)
        for idx in range(len(self)):
            dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
            if dataset_idx == 0:
                sample_idx = idx
            else:
                sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
            item = self.datasets[dataset_idx].annotation[sample_idx]
            conversations = self.datasets[dataset_idx].text_preprocess(item)
            cur_len = sum(len(conv['value'].split()) for conv in conversations)
            if self.datasets[dataset_idx].type == 'images':
                cur_len += token_per_image
            else:
                cur_len += token_per_image * self.data_args.num_segments
            length_list.append(cur_len)
        return length_list
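
    # Sketch of how these lengths are typically consumed (hypothetical wiring;
    # this module does not build the sampler itself):
    #
    #   from transformers.trainer_pt_utils import LengthGroupedSampler
    #   sampler = LengthGroupedSampler(batch_size,
    #                                  dataset=train_dataset,
    #                                  lengths=train_dataset.modality_lengths)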

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # Map the global index to (dataset, local index) via the cumulative sizes.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        item = self.datasets[dataset_idx][sample_idx]

        sources = item
        if isinstance(idx, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"

        images = None
        if 'images' in sources[0]:
            images = sources[0]['images']
            conversations = copy.deepcopy([e['conversations'] for e in sources])
            sources = preprocess_multimodal(conversations, self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])

        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=('images' in item))

        if isinstance(idx, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        if images is not None and len(images) > 0:
            data_dict["images"] = images
        elif self.data_args.is_multimodal:
            # Text-only sample in a multimodal run: attach a dummy image of the
            # processor's input size and mask all labels so it adds no loss.
            if isinstance(self.data_args.image_processor, SiglipImageProcessor):
                img_size = self.data_args.image_processor.size['height']
            elif isinstance(self.data_args.image_processor, CLIPImageProcessor):
                img_size = self.data_args.image_processor.crop_size['height']
            else:
                img_size = self.data_args.image_processor.img_size

            if getattr(self.data_args, 'image_aspect_ratio', 'square') == 'anyres':
                # `anyres` images carry an extra leading patch dimension.
                data_dict['images'] = [torch.zeros(1, 3, img_size, img_size)]
            else:
                data_dict['images'] = [torch.zeros(3, img_size, img_size)]
            data_dict['labels'][:] = IGNORE_INDEX
        return data_dict


def make_supervised_data_module_concatdataset(tokenizer: transformers.PreTrainedTokenizer,
                                              data_args,
                                              num_workers) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    # num_workers is accepted for signature parity with
    # make_supervised_data_module but is not used by ConcatDataset.
    datasets = []
    dataset_config = OmegaConf.load(data_args.dataset_config)
    for ds in dataset_config.datasets.keys():
        ds_cfg = dataset_config.datasets[ds]
        external_args = {key: value for key, value in ds_cfg.items()}
        args_ = copy.deepcopy(vars(data_args))
        data_args_copy = type('DataArguments', (object,), args_)
        dataset = build_from_cfg(ds, data_args_copy, DATASETS, default_args=external_args)
        datasets.append(dataset)

    train_dataset = SupervisedConcatDataset(datasets=datasets,
                                            tokenizer=tokenizer,
                                            data_args=data_args)

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)