import logging

import torch
import datasets
import cv2

import numpy as np
from base64 import b64decode
from io import BytesIO
from PIL import Image
from torch.utils.data import ConcatDataset
from llava.datasets.builder import DATASETS
from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import ImageTaskDataset
from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN
from llava.utils import master_print


class M3ITDataset(ImageTaskDataset):
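    """Instruction-tuning dataset built from the MMInstruction/M3IT collection.

    Each config name in `selected_tasks` is loaded with `datasets.load_dataset` and the
    available train splits are concatenated into a single dataset.
    """
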
    def __init__(self, anno_path, data_args=None, name='m3it', selected_tasks=None):
        super().__init__(anno_path, data_args, name)

        self.selected_tasks = selected_tasks
        dataset_list = [
            datasets.load_dataset("MMInstruction/M3IT", i, num_proc=16) for i in selected_tasks
        ]
        # some task configs have no train split
        target_dataset_list = []
        master_print('#' * 50)
        for d in dataset_list:
            try:
                target_dataset_list.append(d['train'])
                master_print(f"TASK {d['train']._info.config_name}, SIZE {len(d['train'])}")
            except KeyError:
                # d['train'] is missing here, so read the config name from whichever split exists
                config_name = next(iter(d.values()))._info.config_name
                print(f"{config_name} has no train split.")
        self.dataset = ConcatDataset(target_dataset_list)
        master_print(f"Finished loading dataset {name} {len(self.dataset)} samples...")

    def __len__(self):
        return len(self.dataset)

    def text_preprocess(self, item, is_video=False) -> List[Dict[str, str]]:
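        """Build a two-turn conversation: the instruction plus an image/video token (followed by
        the task input, if any) from the human side, and the reference answer from the model side."""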
        instruction = item['instruction']
        question = item['inputs']
        answer = item['outputs']

        query = f"{instruction} {DEFAULT_IMAGE_TOKEN if not is_video else DEFAULT_VIDEO_TOKEN}"
        if len(question) > 0:
            query += question

        conversations = [
            {
                'from': 'human',
                'value': query
            },
            {
                'from': 'model',
                'value': answer
            }
        ]

        return conversations

    def bin2image(self, image_base64_str):
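        """Decode a base64-encoded image, force RGB, and run the standard image preprocessing."""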
        img = Image.open(BytesIO(b64decode(image_base64_str))).convert("RGB")
        img = np.array(img)

        # Defensive fallback: promote single-channel arrays to 3-channel RGB
        if img.ndim == 2 or img.shape[2] != 3:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

        img = Image.fromarray(img).convert('RGB')
        img = self.preprocess_image(img)

        return img

    def vis_preprocess(self, image_base64_str_list) -> Optional[List]:
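        """Decode a list of base64 images into a flat list of preprocessed images; return None on failure."""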
        try:
            images = list(map(self.bin2image, image_base64_str_list))
            formatted_images = []
            for image in images:
                if isinstance(image, list):
                    formatted_images.extend(image)
                else:
                    formatted_images.append(image)
            return formatted_images
        except Exception as e:
            # Corrupt or undecodable sample: log quietly and let the caller skip it
            logging.debug(f"Invalid sample, skipped: {e}")
            return None

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
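        """Return a dict with preprocessed images and the conversation, or None if the sample cannot be decoded."""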
        item = self.dataset[i]

        img_data = item['image_base64_str']

        images = self.vis_preprocess(img_data)
        if images is None:
            return None

        # M3IT video samples decode to multiple frames (8); image samples decode to one
        is_video = len(images) > 1

        ret = {
            'images': images,
            'conversations': self.text_preprocess(item, is_video)
        }

        return ret


@DATASETS.register_obj
def m3it(data_args):
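    """Registry entry point: build an M3ITDataset, with `data_args.external_args['tasks']` overriding the default task list."""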
    tasks = data_configs['m3it']['default_tasks']
    if 'tasks' in data_args.external_args:
        tasks = data_args.external_args['tasks']

    return M3ITDataset(anno_path=None, data_args=data_args, selected_tasks=tasks)