from llava.datasets.builder import DATASETS
from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import ImageTaskDataset
from llava.constants import DEFAULT_IMAGE_TOKEN


class LLaVAPretrainDataset(ImageTaskDataset):
    """Image-text pretraining dataset for LLaVA.

    Thin subclass of ImageTaskDataset that converts each annotation item's
    question/answer pairs into the conversation format consumed downstream.
    """

    def __init__(self, anno_path, data_args=None, name='llava_pretrain'):
        """Forward annotation path, data args, and dataset name to the base class."""
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Turn an annotation item into a flat list of conversation turns.

        Each entry of ``item['qas']`` (assumes keys 'q' and 'a' — matches the
        access below) yields two turns: a 'human' turn whose value is the
        image token prepended to the question, then a 'model' turn with the
        answer. Turns are returned in order, interleaved human/model.
        """
        conversations: List[Dict[str, str]] = []
        for pair in item['qas']:
            # Question turn: image placeholder token goes in front of the text.
            conversations.append({
                'from': 'human',
                'value': DEFAULT_IMAGE_TOKEN + pair['q'],
            })
            # Answer turn from the model side.
            conversations.append({
                'from': 'model',
                'value': pair['a'],
            })
        return conversations


@DATASETS.register_obj
def llava_pretrain(data_args):
    """Registry factory: build the dataset from the configured train path."""
    train_path = data_configs["llava_pretrain"]['train_data_path']
    return LLaVAPretrainDataset(train_path, data_args)