# llava/datasets/super_dataset.py
import random
import bisect
import copy
import torch
import transformers
from torch.utils.data import get_worker_info
from omegaconf import OmegaConf
from dataclasses import dataclass
from typing import Dict, Sequence, List
from torch.utils.data import Dataset, ConcatDataset
from llava.datasets.registry import build_from_cfg
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.train.arguments import DataArguments
from llava.model.preprocessor import preprocess_multimodal, preprocess
from llava.constants import IGNORE_INDEX
from llava.utils import DatasetIter, get_world_size, get_rank, master_print
from transformers import CLIPImageProcessor, SiglipImageProcessor
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, data_cfg: str,
tokenizer: transformers.PreTrainedTokenizer,
data_args: DataArguments,
num_workers: int):
super(LazySupervisedDataset, self).__init__()
dataset_config = OmegaConf.load(data_cfg)
self.tokenizer = tokenizer
self.data_args = data_args
self.datasets, self.sample_ratios = list(), list()
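        # Build each dataset named in the OmegaConf config through the DATASETS
        # registry, forwarding its per-dataset options as default args.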
for ds in list(dataset_config.datasets.keys()):
ds_cfg = dataset_config.datasets[ds]
external_args = {}
for key, value in ds_cfg.items():
external_args[key] = value
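            # Expose a per-dataset copy of data_args as class attributes so each
            # dataset can override fields without mutating the shared object.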
args_ = copy.deepcopy(vars(data_args))
data_args_copy = type('DataArguments', (object,), args_)
dataset = build_from_cfg(ds, data_args_copy, DATASETS, default_args=external_args)
self.datasets.append(dataset)
if 'sample_ratio' in ds_cfg:
self.sample_ratios.append(ds_cfg.sample_ratio)
        if len(self.sample_ratios) != len(self.datasets):
            # Fall back to uniform sampling when ratios are missing or incomplete.
            self.sample_ratios = [1.0] * len(self.datasets)
self.sample_ratios = [float(ratio) / sum(self.sample_ratios) for ratio in self.sample_ratios]
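        # One DatasetIter per sub-dataset so each (rank, worker) pair advances
        # its own sampling cursor independently.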
self.ds_iters = [DatasetIter(len(dataset), get_world_size(), get_rank(), num_workers)
for dataset in self.datasets]
def __len__(self):
        # One epoch covers the maximum number of iterations over all datasets,
        # given each dataset's sampling ratio.
        return max(int(len(ds) / ratio) for ds, ratio in zip(self.datasets, self.sample_ratios))
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        worker_info = get_worker_info()
        # get_worker_info() returns None in the main process (num_workers == 0).
        worker_id = worker_info.id if worker_info is not None else 0
ds_idx = random.choices(range(len(self.datasets)), self.sample_ratios, k=1)[0]
item = None
while item is None:
            item_id = self.ds_iters[ds_idx].increment(worker_id)
            item = self.datasets[ds_idx][item_id]
sources = item
if isinstance(i, int):
sources = [sources]
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
        images = None
        if 'images' in sources[0]:
            images = sources[0]['images']
            conversations = copy.deepcopy([e['conversations'] for e in sources])
            sources = preprocess_multimodal(
                conversations, self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
data_dict = preprocess(
sources,
self.tokenizer,
has_image=('images' in item))
if isinstance(i, int):
data_dict = dict(input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0])
if images is not None and len(images) > 0:
data_dict["images"] = images
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
            if isinstance(self.data_args.image_processor, SiglipImageProcessor):
                img_size = self.data_args.image_processor.size['height']
            elif isinstance(self.data_args.image_processor, CLIPImageProcessor):
                img_size = self.data_args.image_processor.crop_size['height']
            else:
                img_size = self.data_args.image_processor.img_size
if getattr(self.data_args, 'image_aspect_ratio', 'square') == 'anyres':
data_dict['images'] = [torch.zeros(1, 3, img_size, img_size)]
else:
data_dict['images'] = [torch.zeros(3, img_size, img_size)]
data_dict['labels'][:] = IGNORE_INDEX
return data_dict
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple([instance[key] for instance in instances]
for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id)
labels = torch.nn.utils.rnn.pad_sequence(labels,
batch_first=True,
padding_value=IGNORE_INDEX)
input_ids = input_ids[:, :self.tokenizer.model_max_length]
labels = labels[:, :self.tokenizer.model_max_length]
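        # Note: if pad_token_id doubles as a real token id (e.g. eos reused as
        # pad), ne() below will also mask genuine occurrences of that token.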
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
if 'images' in instances[0]:
images = [instance['images'] for instance in instances]
images_data = []
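            # Images remain a per-sample list (one stacked tensor each) rather than
            # a single batch tensor, since samples may carry different numbers of images.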
for imgs in images:
            # Drop missing images, then any whose shape differs from the first
            # remaining one, so torch.stack always gets a uniform list.
            imgs = [x for x in imgs if x is not None]
            if not all(x.shape == imgs[0].shape for x in imgs):
                imgs = [x for x in imgs if x.shape == imgs[0].shape]
            imgs = torch.stack(imgs)
images_data.append(imgs)
batch["images"] = images_data
        if 'images' not in batch or len(batch['images']) == 0:
            print("Warning: no images found in batch")
return batch
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
data_args,
num_workers) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = LazySupervisedDataset(data_cfg=data_args.dataset_config,
tokenizer=tokenizer,
data_args=data_args,
num_workers=num_workers)
for ds, ratio in zip(train_dataset.datasets, train_dataset.sample_ratios):
master_print(f"==> Real epoch of {ds.name} is {round(len(train_dataset) * ratio / len(ds), 2)} epochs.")
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset,
eval_dataset=None,
data_collator=data_collator)
class SupervisedConcatDataset(ConcatDataset):
r"""Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets.
Args:
datasets (sequence): List of datasets to be concatenated
"""
datasets: List[Dataset]
cumulative_sizes: List[int]
def __init__(self, datasets: List[Dataset],
tokenizer: transformers.PreTrainedTokenizer,
data_args: DataArguments) -> None:
        super().__init__(datasets)
self.tokenizer = tokenizer
self.data_args = data_args
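    # Per-sample length estimates, typically consumed by a length-grouped batch
    # sampler (e.g. LLaVA-style group_by_modality_length).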
@property
def modality_lengths(self):
length_list = []
token_per_image = getattr(self.data_args, 'num_token_per_image', 32)
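        # Heuristic length: whitespace-token count of the conversation text plus
        # a fixed token budget per image (one budget per sampled segment for videos).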
for idx in range(len(self)):
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
item = self.datasets[dataset_idx].annotation[sample_idx]
conversations = self.datasets[dataset_idx].text_preprocess(item)
cur_len = sum([len(conv['value'].split()) for conv in conversations])
if self.datasets[dataset_idx].type == 'images':
cur_len += token_per_image
else:
cur_len += token_per_image * self.data_args.num_segments
length_list.append(cur_len)
return length_list
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError("absolute value of index should not exceed dataset length")
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
item = self.datasets[dataset_idx][sample_idx]
sources = item
if isinstance(idx, int):
sources = [sources]
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
        images = None
        if 'images' in sources[0]:
            images = sources[0]['images']
            conversations = copy.deepcopy([e['conversations'] for e in sources])
            sources = preprocess_multimodal(
                conversations, self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
data_dict = preprocess(
sources,
self.tokenizer,
has_image=('images' in item))
if isinstance(idx, int):
data_dict = dict(input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0])
if images is not None and len(images) > 0:
data_dict["images"] = images
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
if isinstance(self.data_args.image_processor, SiglipImageProcessor):
img_size = self.data_args.image_processor.size['height']
elif isinstance(self.data_args.image_processor, CLIPImageProcessor):
img_size = self.data_args.image_processor.crop_size['height']
else:
img_size = self.data_args.image_processor.img_size
if getattr(self.data_args, 'image_aspect_ratio', 'square') == 'anyres':
data_dict['images'] = [torch.zeros(1, 3, img_size, img_size)]
else:
data_dict['images'] = [torch.zeros(3, img_size, img_size)]
data_dict['labels'][:] = IGNORE_INDEX
return data_dict
def make_supervised_data_module_concatdataset(tokenizer: transformers.PreTrainedTokenizer,
data_args,
num_workers) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
datasets = []
dataset_config = OmegaConf.load(data_args.dataset_config)
for ds in list(dataset_config.datasets.keys()):
ds_cfg = dataset_config.datasets[ds]
external_args = {}
for key, value in ds_cfg.items():
external_args[key] = value
args_ = copy.deepcopy(vars(data_args))
data_args_copy = type('DataArguments', (object,), args_)
dataset = build_from_cfg(ds, data_args_copy, DATASETS, default_args=external_args)
datasets.append(dataset)
train_dataset = SupervisedConcatDataset(datasets=datasets,
tokenizer=tokenizer,
data_args=data_args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset,
eval_dataset=None,
data_collator=data_collator)