import logging
import os
from typing import Dict, List, Sequence, Union

import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer

from ovis.model.modeling_ovis import Ovis
from ovis.train.arguments import TrainingArguments
from ovis.util.constants import IGNORE_ID


class MultimodalDataset(Dataset):
    """Base class for multimodal training datasets; subclasses implement `load` and `__getitem__`."""

    def __init__(self, name: str, info: Dict, model: Ovis, training_args: TrainingArguments):
        self.name = name
        self.meta_file = info['meta_file']
        self.image_dir = info['image_dir']
        self.caption_template = info.get('caption_template', None)
        self.text_tokenizer = model.get_text_tokenizer()
        self.visual_tokenizer = model.get_visual_tokenizer()
        self.image_height, self.image_width = self.visual_tokenizer.get_image_size()
        self.model = model
        self.text_max_length = training_args.text_max_length
        # Parse the '|'-separated list of maximum partition counts into integers.
        self.max_partitions = [int(m.strip()) for m in training_args.max_partitions.split('|')]
        self.samples = self.load()

    def load(self):
        """Load and return the list of samples; implemented by subclasses."""
        raise NotImplementedError

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def __len__(self):
        return len(self.samples)

    def read_image(self, path):
        """Read an image under `image_dir` as RGB; return (image, None) on success or (None, error) on failure."""
        try:
            full_path = os.path.join(self.image_dir, path)
            image = Image.open(full_path).convert('RGB')
            return image, None
        except Exception as e:
            return None, e


class DataCollatorForMultimodalDataset:
    """Collates variable-length samples into padded batches for training."""

    def __init__(self, text_tokenizer: PreTrainedTokenizer):
        self.text_tokenizer = text_tokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        pixel_values, input_ids, labels = tuple(
            [instance[key] for instance in instances]
            for key in ("pixel_values", "input_ids", "labels")
        )
        # Pad token ids to the longest sequence in the batch; padded positions are masked out below.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.text_tokenizer.pad_token_id)
        attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
        # Pad labels with IGNORE_ID so padded positions contribute no loss.
        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=IGNORE_ID)
        num_valid_label = torch.not_equal(labels, IGNORE_ID).sum().item()
        if num_valid_label == 0:
            logging.warning(
                f'[DataCollatorForMultimodalDataset] All labels in a batch are ignored, '
                f'which may lead to training instability\n{input_ids=}\n{attention_mask=}\n{labels=}')
        # pixel_values is returned as a list of per-sample tensors rather than stacked,
        # since visual inputs may have different shapes across samples.
        return dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            pixel_values=pixel_values
        )