from typing import Dict, List

from llava.constants import DEFAULT_IMAGE_TOKEN
from llava.datasets.base_dataset import ImageTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs


class LLaVAPretrainDataset(ImageTaskDataset):
    """Image-text pretraining dataset for LLaVA.

    Each annotation item is expected to carry a `qas` list of
    question/answer pairs about a single image.
    """

    def __init__(self, anno_path, data_args=None, name='llava_pretrain'):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        # Flatten every QA pair into a two-turn human/model exchange,
        # prefixing each question with the image placeholder token.
        qas = item['qas']
        conversations = []
        for qa in qas:
            conv = [
                {
                    'from': 'human',
                    'value': DEFAULT_IMAGE_TOKEN + qa['q']
                },
                {
                    'from': 'model',
                    'value': qa['a']
                }
            ]
            conversations.extend(conv)
        return conversations

# Register the builder so the dataset can be constructed by name
# through the DATASETS registry.
@DATASETS.register_obj
def llava_pretrain(data_args):
    return LLaVAPretrainDataset(
        data_configs["llava_pretrain"]['train_data_path'], data_args)
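

# --- Usage sketch ------------------------------------------------------------
# A minimal, hypothetical illustration of what text_preprocess() emits; the
# sample `item` below is made up, and DEFAULT_IMAGE_TOKEN is assumed to be the
# usual "<image>" placeholder defined in llava.constants.
if __name__ == '__main__':
    _item = {
        'image': '000001.jpg',  # hypothetical image filename
        'qas': [{'q': 'What is shown in the picture?',
                 'a': 'A dog running on a beach.'}],
    }
    # text_preprocess() never touches `self`, so the unbound method can be
    # exercised without building a full dataset (which needs a real anno file).
    _convs = LLaVAPretrainDataset.text_preprocess(None, _item)
    # Expected output:
    # [{'from': 'human', 'value': DEFAULT_IMAGE_TOKEN + 'What is shown in the picture?'},
    #  {'from': 'model', 'value': 'A dog running on a beach.'}]
    print(_convs)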