# ovis/train/dataset/conversation_dataset.py
import copy
import json
import logging
from datetime import datetime
from typing import Dict

import torch

from ovis.train.dataset.multimodal_dataset import MultimodalDataset
from ovis.util.utils import rank0_print


class ConversationDataset(MultimodalDataset):
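    """
    Conversation-style multimodal dataset.

    Loads samples from the JSON file at self.meta_file and converts each
    sample into model-ready tensors (input_ids, labels, pixel_values) via
    the model's preprocessing.
    """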

    def load(self):
        rank0_print(f"[{datetime.now()}] Loading dataset {self.name} from {self.meta_file} begin")
        with open(self.meta_file, 'r', encoding='utf-8') as f:
            samples = json.load(f)
        rank0_print(f'#samples: {len(samples)}')
        rank0_print(f'sample: {samples[0]}')
        rank0_print(f"[{datetime.now()}] Loading dataset {self.name} end")
        return samples
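
    # Expected sample layout (a minimal sketch inferred from __getitem__ below;
    # the exact conversation schema comes from the meta files and is not
    # defined in this file):
    #   {
    #     "conversations": [...],                  # chat turns passed to model.preprocess_inputs
    #     "image": "a.jpg" or ["a.jpg", "b.jpg"]   # optional: one path or a list of paths
    #   }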
    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        sample = self.samples[i]
        conversations = copy.deepcopy(sample["conversations"])

        # Collect images when the sample references one or more image paths;
        # a single failed read invalidates all images for this sample.
        images = None
        max_partition = None
        if 'image' in sample:
            image_paths = sample['image']
            if isinstance(image_paths, str):
                image_paths = [image_paths]
            images = []
            for image_path in image_paths:
                image, e = self.read_image(image_path)
                if image is None:
                    logging.warning(
                        f'reading image failed with index: {i}, image path: {image_path}, and exception: {e}')
                    images = None
                    break
                images.append(image)
        elif 'video' in sample:
            raise RuntimeError('video is not supported yet')

        # The partition budget differs for single-image and multi-image samples.
        if images:
            max_partition = self.max_partitions[0] if len(images) == 1 else self.max_partitions[1]

        # Tokenize the conversation and preprocess the images into model inputs.
        prompt, input_ids, pixel_values, labels = self.model.preprocess_inputs(
            conversations,
            images,
            max_partition=max_partition,
            generation_preface=None,
            return_labels=True,
            propagate_exception=False
        )
        # When no real pixel values were produced, substitute the visual
        # tokenizer's mock input.
        if pixel_values is None:
            pixel_values, _ = self.visual_tokenizer.mock_input()

        # Truncate text tokens and labels to the configured maximum length.
        input_ids = input_ids[:self.text_max_length]
        labels = labels[:self.text_max_length]
        return dict(
            pixel_values=pixel_values,
            input_ids=input_ids,
            labels=labels
        )
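
# Each item is a dict with "pixel_values", "input_ids", and "labels"; how these
# are padded and batched is assumed to be handled by the training collator,
# which lives outside this file.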