# model1/llava/datasets/m3it_dataset.py
import logging
import torch
import datasets
import cv2
import numpy as np
from base64 import b64decode
from io import BytesIO
from PIL import Image
from torch.utils.data import ConcatDataset
from llava.datasets.builder import DATASETS
from typing import Dict, Optional, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import ImageTaskDataset
from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN
from llava.utils import master_print


class M3ITDataset(ImageTaskDataset):
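    """Instruction-tuning dataset built on MMInstruction/M3IT.

    Loads the selected M3IT task configs from the HuggingFace Hub, keeps the
    train split of each task, and concatenates them into one dataset. Each
    sample pairs base64-encoded images with an instruction/inputs/outputs text
    triple, converted into the `images` / `conversations` format used by the
    LLaVA training pipeline.
    """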
def __init__(self, anno_path, data_args=None, name='m3it', selected_tasks=None):
super().__init__(anno_path, data_args, name)
self.selected_tasks = selected_tasks
        dataset_list = [
            datasets.load_dataset("MMInstruction/M3IT", task, num_proc=16)
            for task in selected_tasks
        ]
        # some datasets have no train split; keep only the ones that do
target_dataset_list = []
master_print('#' * 50)
for d in dataset_list:
try:
target_dataset_list.append(d['train'])
master_print(f"TASK {d['train']._info.config_name}, SIZE {len(d['train'])}")
            except KeyError:
                # d['train'] is missing here, so read the config name from any available split
                config_name = next(iter(d.values()))._info.config_name
                logging.warning(f"{config_name} has no train split, skipping.")
self.dataset = ConcatDataset(target_dataset_list)
master_print(f"Finished loading dataset {name} {len(self.dataset)} samples...")
def __len__(self):
return len(self.dataset)
def text_preprocess(self, item, is_video=False) -> List[Dict[str, str]]:
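        """Build a two-turn conversation: the human turn carries the instruction,
        a media placeholder token, and the question; the model turn carries the answer."""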
instruction = item['instruction']
question = item['inputs']
answer = item['outputs']
query = f"{instruction} {DEFAULT_IMAGE_TOKEN if not is_video else DEFAULT_VIDEO_TOKEN}"
if len(question) > 0:
query += question
conversations = [
{
'from': 'human',
'value': query
},
{
'from': 'model',
'value': answer
}
]
return conversations
def bin2image(self, image_base64_str):
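        """Decode a single base64-encoded image to RGB and apply the shared image preprocessing."""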
img = Image.open(BytesIO(b64decode(image_base64_str))).convert("RGB")
img = np.array(img)
        # defensive: .convert("RGB") already yields 3 channels, but a 2-D grayscale
        # array would make img.shape[2] raise IndexError, so check ndim instead
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
img = Image.fromarray(img).convert('RGB')
img = self.preprocess_image(img)
return img
    def vis_preprocess(self, image_base64_str_list) -> Optional[List]:
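        """Decode and preprocess every base64 image in the list, flattening any
        per-image list returned by preprocess_image. Returns None if decoding fails."""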
try:
images = list(map(self.bin2image, image_base64_str_list))
formatted_images = []
for image in images:
if isinstance(image, list):
formatted_images.extend(image)
else:
formatted_images.append(image)
return formatted_images
        except Exception as e:
            logging.warning(f"Invalid sample, skip: {e}")
            return None
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
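        """Return a dict with preprocessed `images` and `conversations`, or None
        when the sample's images cannot be decoded."""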
item = self.dataset[i]
img_data = item['image_base64_str']
images = self.vis_preprocess(img_data)
if images is None:
return None
        # M3IT video samples carry 8 frames; count the raw base64 images (not the
        # preprocessed tensors, which may be multiple per image) to detect video
        is_video = len(img_data) > 1
ret = {
'images': images,
'conversations': self.text_preprocess(item, is_video)
}
        return ret


@DATASETS.register_obj
def m3it(data_args):
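    """Registered builder for the m3it dataset; the task list may be overridden
    via data_args.external_args['tasks'], otherwise the configured defaults apply."""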
tasks = data_configs['m3it']['default_tasks']
if 'tasks' in data_args.external_args:
tasks = data_args.external_args['tasks']
return M3ITDataset(anno_path=None, data_args=data_args, selected_tasks=tasks)
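

# Usage sketch (assumptions: the shape of `data_args` is inferred from the builder
# above, and the task names are illustrative M3IT config names, not verified here):
#
#   data_args.external_args = {'tasks': ['coco', 'vqa-v2']}
#   dataset = m3it(data_args)
#   sample = dataset[0]  # {'images': [...], 'conversations': [...]} or None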