import logging

import torch
import datasets
import cv2
import numpy as np
from base64 import b64decode
from io import BytesIO
from PIL import Image
from torch.utils.data import ConcatDataset
from typing import Dict, List, Optional

from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import ImageTaskDataset
from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN
from llava.utils import master_print


class M3ITDataset(ImageTaskDataset):
    def __init__(self, anno_path, data_args=None, name='m3it', selected_tasks=None):
        super().__init__(anno_path, data_args, name)
        self.selected_tasks = selected_tasks
        dataset_list = [
            datasets.load_dataset("MMInstruction/M3IT", task, num_proc=16)
            for task in selected_tasks
        ]

        # Some M3IT tasks ship without a train split; skip them instead of crashing.
        target_dataset_list = []
        master_print('#' * 50)
        for d in dataset_list:
            config_name = next(iter(d.values())).info.config_name
            if 'train' in d:
                target_dataset_list.append(d['train'])
                master_print(f"TASK {config_name}, SIZE {len(d['train'])}")
            else:
                master_print(f"{config_name} has no train set.")
        self.dataset = ConcatDataset(target_dataset_list)
        master_print(f"Finished loading dataset {name}: {len(self.dataset)} samples...")

    def __len__(self):
        return len(self.dataset)

    def text_preprocess(self, item, is_video=False) -> List[Dict[str, str]]:
        instruction = item['instruction']
        question = item['inputs']
        answer = item['outputs']

        # Prepend the instruction and the image/video placeholder token to the question.
        query = f"{instruction} {DEFAULT_VIDEO_TOKEN if is_video else DEFAULT_IMAGE_TOKEN}"
        if len(question) > 0:
            query += f" {question}"

        conversations = [
            {'from': 'human', 'value': query},
            {'from': 'model', 'value': answer},
        ]
        return conversations

    def bin2image(self, image_base64_str):
        img = Image.open(BytesIO(b64decode(image_base64_str))).convert("RGB")
        img = np.array(img)
        # Guard against single-channel decodes before re-wrapping as a PIL image.
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        img = Image.fromarray(img).convert('RGB')
        img = self.preprocess_image(img)
        return img

    def vis_preprocess(self, image_base64_str_list) -> Optional[List[Image.Image]]:
        try:
            images = list(map(self.bin2image, image_base64_str_list))
            # preprocess_image may return a single image or a list of crops;
            # flatten to one flat list either way.
            formatted_images = []
            for image in images:
                if isinstance(image, list):
                    formatted_images.extend(image)
                else:
                    formatted_images.append(image)
            return formatted_images
        except Exception as e:
            logging.warning(f"Invalid sample, skip: {e}")
            return None

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.dataset[i]
        img_data = item['image_base64_str']
        images = self.vis_preprocess(img_data)
        if images is None:
            return None

        # An M3IT video sample is stored as 8 frames, so more than one
        # decoded image marks the item as a video.
        is_video = len(images) > 1
        ret = {
            'images': images,
            'conversations': self.text_preprocess(item, is_video)
        }
        return ret


@DATASETS.register_obj
def m3it(data_args):
    tasks = data_configs['m3it']['default_tasks']
    if 'tasks' in data_args.external_args:
        tasks = data_args.external_args['tasks']
    return M3ITDataset(anno_path=None, data_args=data_args, selected_tasks=tasks)
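

# --- Usage sketch (illustrative, not part of the original module) ---------
# A minimal example of building the dataset through the registered m3it()
# builder and inspecting one sample. `SimpleNamespace` stands in for the real
# `data_args` object; the only field the builder reads is `external_args`.
# The task names below and any extra fields the ImageTaskDataset base class
# expects are assumptions, not confirmed by this module.
if __name__ == '__main__':
    from types import SimpleNamespace

    # Hypothetical data_args: 'tasks' overrides data_configs['m3it']['default_tasks'].
    data_args = SimpleNamespace(external_args={'tasks': ['coco', 'vqa-v2']})
    dataset = m3it(data_args)

    sample = dataset[0]
    if sample is not None:  # __getitem__ returns None for undecodable samples
        print(f"{len(sample['images'])} image tensor(s)")
        print(sample['conversations'][0]['value'])  # human turn with the placeholder token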