from llava.datasets.builder import DATASETS | |
from typing import Dict, Optional, Sequence, List | |
from llava.datasets.data_cfgs import data_configs | |
from llava.datasets.base_dataset import ImageTaskDataset | |
from llava.constants import DEFAULT_IMAGE_TOKEN | |
class LLaVAPretrainDataset(ImageTaskDataset):
    """Image-text pretraining dataset for LLaVA-style alignment data.

    Thin subclass of ImageTaskDataset: it only customizes how the raw
    annotation item's QA pairs are turned into conversation turns.
    """

    def __init__(self, anno_path, data_args=None, name='llava_pretrain'):
        # Delegate all loading/setup to the base class; this subclass only
        # overrides text_preprocess below.
        super().__init__(anno_path=anno_path, data_args=data_args, name=name)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Flatten ``item['qas']`` into an alternating conversation list.

        Each QA pair yields two turns: a human turn whose value is the
        image placeholder token prepended to the question text, followed
        by a model turn carrying the answer.

        Returns:
            A flat list of ``{'from': ..., 'value': ...}`` dicts,
            two entries per QA pair, in original order.
        """
        turns: List[Dict[str, str]] = []
        for pair in item['qas']:
            turns.append({
                'from': 'human',
                'value': DEFAULT_IMAGE_TOKEN + pair['q'],
            })
            turns.append({
                'from': 'model',
                'value': pair['a'],
            })
        return turns
def llava_pretrain(data_args):
    """Builder: instantiate the LLaVA pretraining dataset.

    Looks up the registered 'llava_pretrain' entry in the project's data
    configuration table and constructs the dataset from its training
    annotation path.
    """
    cfg = data_configs["llava_pretrain"]
    return LLaVAPretrainDataset(cfg['train_data_path'], data_args)