File size: 1,134 Bytes
bbfa6f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
from llava.datasets.builder import DATASETS
from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import ImageTaskDataset
from llava.constants import DEFAULT_IMAGE_TOKEN
class LLaVAPretrainDataset(ImageTaskDataset):
    """Image-task dataset whose items carry a list of question/answer pairs.

    Each item is expected to expose a ``'qas'`` entry: an iterable of
    mappings with a ``'q'`` (question) and an ``'a'`` (answer) key. These are
    flattened into alternating ``human`` / ``model`` conversation turns.
    """

    def __init__(self, anno_path, data_args=None, name='llava_pretrain'):
        """Forward the annotation path, data arguments and name to the base class.

        Args:
            anno_path: Annotation location handed to ``ImageTaskDataset``.
            data_args: Optional data arguments handed to ``ImageTaskDataset``.
            name: Dataset name handed to ``ImageTaskDataset``.
        """
        super().__init__(
            anno_path=anno_path,
            data_args=data_args,
            name=name,
        )

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Flatten the item's QA pairs into a list of conversation turns.

        Args:
            item: Mapping with a ``'qas'`` entry; every element of it must
                provide ``'q'`` and ``'a'``.

        Returns:
            A flat list with two dicts per QA pair, in input order: a
            ``'human'`` turn whose value is ``DEFAULT_IMAGE_TOKEN`` followed
            by the question, then a ``'model'`` turn holding the answer.
        """
        turns: List[Dict[str, str]] = []
        for pair in item['qas']:
            # NOTE(review): the image token is prepended to every question,
            # not only the first one -- confirm this is intended for items
            # that hold more than one QA pair.
            human_turn = {
                'from': 'human',
                'value': DEFAULT_IMAGE_TOKEN + pair['q'],
            }
            model_turn = {
                'from': 'model',
                'value': pair['a'],
            }
            turns.append(human_turn)
            turns.append(model_turn)
        return turns
@DATASETS.register_obj
def llava_pretrain(data_args):
    """Build the ``llava_pretrain`` dataset from its configured training path.

    Args:
        data_args: Data arguments passed through to ``LLaVAPretrainDataset``.

    Returns:
        A ``LLaVAPretrainDataset`` reading the annotations found at
        ``data_configs["llava_pretrain"]['train_data_path']``.
    """
    train_data_path = data_configs["llava_pretrain"]['train_data_path']
    return LLaVAPretrainDataset(anno_path=train_data_path, data_args=data_args)