import os | |
import random | |
from llava.datasets.builder import DATASETS | |
from typing import Dict, Optional, Sequence, List | |
from llava.datasets.data_cfgs import data_configs | |
from llava.datasets.base_dataset import FramesTaskDataset | |
from llava.datasets.prompts import internvid_prompt | |
from llava.constants import DEFAULT_VIDEO_TOKEN | |
class InternVidDataset(FramesTaskDataset):
    """InternVid dataset built on ``FramesTaskDataset``.

    Each annotation item is turned into a two-turn conversation: a human turn
    holding ``DEFAULT_VIDEO_TOKEN`` plus a randomly chosen entry of
    ``internvid_prompt``, and a model turn holding the item's caption.
    """

    def __init__(self, anno_path, data_args=None, name='internvid'):
        """Forward all construction arguments to ``FramesTaskDataset``.

        Args:
            anno_path: Annotation source, handed to the base class unchanged.
            data_args: Optional data arguments, handed to the base class.
            name: Dataset name handed to the base class; defaults to
                ``'internvid'``.
        """
        super().__init__(
            anno_path=anno_path,
            data_args=data_args,
            name=name,
        )

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Convert one annotation item into a human/model conversation.

        Args:
            item: Mapping that must provide a ``'caption'`` entry.

        Returns:
            A two-element list of ``{'from': ..., 'value': ...}`` dicts:
            first the ``'human'`` turn (``DEFAULT_VIDEO_TOKEN`` followed by a
            prompt drawn with ``random.choice`` from ``internvid_prompt``),
            then the ``'model'`` turn whose value is the caption.

        Raises:
            KeyError: If ``item`` is a dict without a ``'caption'`` key.
        """
        # Read the caption before sampling a prompt, so a missing key fails
        # without consuming state from the module-level random generator.
        answer = item['caption']
        question = DEFAULT_VIDEO_TOKEN + random.choice(internvid_prompt)

        human_turn = {'from': 'human', 'value': question}
        model_turn = {'from': 'model', 'value': answer}
        return [human_turn, model_turn]
def internvid(data_args):
    """Create the InternVid training dataset.

    The annotation path is read from
    ``data_configs["internvid"]['train_data_path']`` and passed, together with
    ``data_args``, to ``InternVidDataset``.

    Args:
        data_args: Data arguments forwarded to ``InternVidDataset``.

    Returns:
        A new ``InternVidDataset`` instance.
    """
    # NOTE(review): ``DATASETS`` is imported in this module but not referenced
    # in the visible code; if this factory is meant to be registered with it,
    # a registration decorator may be missing -- confirm against sibling files.
    train_anno_path = data_configs["internvid"]['train_data_path']
    return InternVidDataset(train_anno_path, data_args)