# ovis/train/dataset/multimodal_dataset.py
import logging
import os
from typing import Dict, List, Sequence, Union

import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer

from ovis.model.modeling_ovis import Ovis
from ovis.train.arguments import TrainingArguments
from ovis.util.constants import IGNORE_ID


class MultimodalDataset(Dataset):
    """Abstract base class for multimodal (image + text) training datasets."""

    def __init__(self, name: str, info: Dict, model: Ovis, training_args: TrainingArguments):
        self.name = name
        self.meta_file = info['meta_file']
        self.image_dir = info['image_dir']
        self.caption_template = info.get('caption_template')
        self.text_tokenizer = model.get_text_tokenizer()
        self.visual_tokenizer = model.get_visual_tokenizer()
        self.image_height, self.image_width = self.visual_tokenizer.get_image_size()
        self.model = model
        self.text_max_length = training_args.text_max_length
        # Maximum partition counts, parsed from a '|'-separated string.
        self.max_partitions = [int(m.strip()) for m in training_args.max_partitions.split('|')]
        self.samples = self.load()

    def load(self):
        # Subclasses must return the full list of samples.
        raise NotImplementedError

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        # Subclasses must build the per-sample tensors consumed by the
        # collator below: pixel_values, input_ids, and labels.
        raise NotImplementedError

    def __len__(self):
        return len(self.samples)

    def read_image(self, path):
        # Returns (image, None) on success or (None, exception) on failure,
        # letting callers skip or log unreadable images without crashing.
        try:
            full_path = os.path.join(self.image_dir, path)
            image = Image.open(full_path).convert('RGB')
            return image, None
        except Exception as e:
            return None, e
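

# A minimal sketch of a concrete subclass, not part of the original module:
# it assumes the meta file is a JSON list of {"image", "caption"} dicts, and
# the `CaptionDataset` name is hypothetical. Converting the image and caption
# into tensors is model-specific and therefore left elided.
class CaptionDataset(MultimodalDataset):
    def load(self):
        import json  # local import: used only by this illustrative sketch
        with open(self.meta_file) as f:
            return json.load(f)

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        sample = self.samples[i]
        image, error = self.read_image(sample['image'])
        if error is not None:
            logging.warning(f"[{self.name}] failed to read {sample['image']}: {error}")
        # Producing pixel_values / input_ids / labels from `image` and
        # sample['caption'] requires the visual and text tokenizers and is
        # intentionally omitted from this sketch.
        raise NotImplementedError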


class DataCollatorForMultimodalDataset:
    """Collates per-sample dicts into padded batch tensors."""

    def __init__(self, text_tokenizer: PreTrainedTokenizer):
        self.text_tokenizer = text_tokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        # Gather each field across the batch into its own list.
        pixel_values, input_ids, labels = tuple(
            [instance[key] for instance in instances]
            for key in ("pixel_values", "input_ids", "labels"))
        # Right-pad token ids to the longest sequence in the batch.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.text_tokenizer.pad_token_id)
        attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
        # Pad labels with IGNORE_ID so padded positions are excluded from the loss.
        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=IGNORE_ID)
        num_valid_label = torch.not_equal(labels, IGNORE_ID).sum().item()
        if num_valid_label == 0:
            logging.warning(
                '[DataCollatorForMultimodalDataset] All labels in a batch are ignored, '
                f'which may lead to training instability\n{input_ids=}\n{attention_mask=}\n{labels=}')
        # pixel_values is returned as a list of tensors, since per-sample
        # image tensors may differ in shape (see the List[torch.Tensor] hint).
        return dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            pixel_values=pixel_values
        )
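

# A minimal usage sketch, not part of the original module. `dataset` stands in
# for any concrete MultimodalDataset instance; only the standard
# torch.utils.data API is used otherwise.
#
#   from torch.utils.data import DataLoader
#
#   collator = DataCollatorForMultimodalDataset(dataset.text_tokenizer)
#   loader = DataLoader(dataset, batch_size=4, collate_fn=collator)
#   for batch in loader:
#       # input_ids, attention_mask, labels are padded tensors;
#       # pixel_values is a list of per-sample image tensors.
#       ...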