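"""M3IT dataset wrapper for LLaVA-style training.

Loads the selected MMInstruction/M3IT task subsets from the Hugging Face Hub,
concatenates their train splits, and converts each sample into preprocessed
images plus a two-turn conversation."""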

import logging

from base64 import b64decode
from io import BytesIO
from typing import Dict, List, Optional, Sequence

import cv2
import datasets
import numpy as np
import torch
from PIL import Image
from torch.utils.data import ConcatDataset

from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN
from llava.datasets.base_dataset import ImageTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.utils import master_print

class M3ITDataset(ImageTaskDataset):
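    """Concatenation of the train splits of selected MMInstruction/M3IT task
    subsets, exposed as an image-instruction dataset."""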

    def __init__(self, anno_path, data_args=None, name='m3it', selected_tasks=None):
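        # `selected_tasks` holds M3IT config names as defined on the dataset
        # card; one subset is downloaded from the Hub per task.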
        super().__init__(anno_path, data_args, name)

        self.selected_tasks = selected_tasks
        dataset_list = [
            datasets.load_dataset("MMInstruction/M3IT", task, num_proc=16) for task in selected_tasks
        ]

        target_dataset_list = []
        master_print('#' * 50)
        for d in dataset_list:
            # Read the config name from any available split, since the train
            # split may be missing for some tasks.
            config_name = next(iter(d.values()))._info.config_name
            if 'train' in d:
                target_dataset_list.append(d['train'])
                master_print(f"TASK {config_name}, SIZE {len(d['train'])}")
            else:
                master_print(f"{config_name} has no train set.")
        self.dataset = ConcatDataset(target_dataset_list)
        master_print(f"Finished loading dataset {name}: {len(self.dataset)} samples.")

    def __len__(self):
        return len(self.dataset)

    def text_preprocess(self, item, is_video=False) -> List[Dict[str, str]]:
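        """Build a two-turn conversation: the instruction, media token, and
        optional question from the human side, the annotated answer from the
        model side."""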
        instruction = item['instruction']
        question = item['inputs']
        answer = item['outputs']

        media_token = DEFAULT_VIDEO_TOKEN if is_video else DEFAULT_IMAGE_TOKEN
        query = f"{instruction} {media_token}"
        if len(question) > 0:
            query += question

        conversations = [
            {
                'from': 'human',
                'value': query
            },
            {
                'from': 'model',
                'value': answer
            }
        ]

        return conversations

    def bin2image(self, image_base64_str):
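        """Decode a single base64-encoded image and run it through the model's
        image preprocessing."""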
        img = Image.open(BytesIO(b64decode(image_base64_str))).convert("RGB")
        img = np.array(img)

        # convert("RGB") should already yield an HxWx3 array; promote any
        # remaining single-channel image to three channels as a safeguard
        # (indexing shape[2] on a 2-D array would raise IndexError).
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

        img = Image.fromarray(img).convert('RGB')
        img = self.preprocess_image(img)

        return img

    def vis_preprocess(self, image_base64_str_list) -> Optional[List]:
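        """Decode every base64 image in the list; return None when any image
        fails so the caller can skip the sample."""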
        try:
            images = list(map(self.bin2image, image_base64_str_list))
            formatted_images = []
            for image in images:
                # preprocess_image may produce a list per image; flatten so
                # the result is a flat list of images.
                if isinstance(image, list):
                    formatted_images.extend(image)
                else:
                    formatted_images.append(image)
            return formatted_images
        except Exception as e:
            logging.warning(f"vis_preprocess failed: {e}")
            return None

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
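        """Return the i-th sample as a dict of images and conversations, or
        None when image decoding failed."""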
        item = self.dataset[i]

        img_data = item['image_base64_str']

        images = self.vis_preprocess(img_data)
        if images is None:
            return None

        # Samples with more than one image are treated as video so that
        # text_preprocess emits the video token instead of the image token.
        is_video = len(images) > 1

        ret = {
            'images': images,
            'conversations': self.text_preprocess(item, is_video)
        }

        return ret


@DATASETS.register_obj
def m3it(data_args):
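    """Registry builder: instantiate M3ITDataset with the default task list
    from data_configs, overridable via data_args.external_args['tasks']."""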
    tasks = data_configs['m3it']['default_tasks']
    if 'tasks' in data_args.external_args:
        tasks = data_args.external_args['tasks']

    return M3ITDataset(anno_path=None, data_args=data_args, selected_tasks=tasks)