"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from llava.mm_utils import select_best_resolution

class BaseProcessor:
    def __init__(self):
        # Identity transform by default; subclasses override self.transform.
        self.transform = lambda x: x

    def __call__(self, item):
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        return cls()

    def build(self, **kwargs):
        cfg = OmegaConf.create(kwargs)

        return self.from_config(cfg)


class BlipImageBaseProcessor(BaseProcessor):
    def __init__(self, image_mean=None, image_std=None):
        # Default to the CLIP normalization statistics.
        if image_mean is None:
            image_mean = (0.48145466, 0.4578275, 0.40821073)
        if image_std is None:
            image_std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(image_mean, image_std)
        self.image_mean = image_mean
        self.image_std = image_std


class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
    def __init__(self, image_size=224, image_mean=None, image_std=None,
                 min_scale=0.5, max_scale=1.0, is_training=True,
                 dynamic_resolution=None):
        super().__init__(image_mean=image_mean, image_std=image_std)

        # min_scale/max_scale are accepted for config compatibility but are
        # unused here: the transform below is a deterministic resize, not a
        # random resized crop.
        self.is_training = is_training
        self.dynamic_resolution = dynamic_resolution
        if isinstance(image_size, int):
            self.img_size = image_size
            size_tuple = (image_size, image_size)
        elif isinstance(image_size, (tuple, list)):
            self.img_size = image_size[0]
            size_tuple = tuple(image_size)  # (H, W)
        else:
            raise TypeError(
                f"image_size must be an int or an (H, W) tuple, got {type(image_size)}"
            )
        self.crop_size = {
            'height': self.img_size,
            'width': self.img_size
        }
        if self.dynamic_resolution:
            # One resize transform per candidate resolution, keyed by the
            # (H, W) sizes listed in dynamic_resolution.
            self.transform_dic = {}
            for size_ in self.dynamic_resolution:
                self.transform_dic[size_] = transforms.Compose(
                    [
                        transforms.Resize(
                            size_, interpolation=InterpolationMode.BICUBIC  # H, W
                        ),
                        transforms.ToTensor(),
                        self.normalize,
                    ]
                )
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    size_tuple, interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def preprocess(self, item):
        if self.dynamic_resolution is not None:
            # Global view first, then the patches cut from the image resized
            # to the best-fitting candidate resolution.
            images = [self.transform(item)]
            width, height = item.size
            best_fit_res = select_best_resolution((width, height), self.dynamic_resolution)
            resize_img = self.transform_dic[best_fit_res](item)
            splitted_imgs = self.split_images(resize_img, (self.img_size, self.img_size))
            images.extend(splitted_imgs)
            return images
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 224)

        image_mean = cfg.get("image_mean", None)
        image_std = cfg.get("image_std", None)

        min_scale = cfg.get("min_scale", 0.5)
        max_scale = cfg.get("max_scale", 1.0)

        dynamic_resolution = cfg.get("dynamic_resolution", None)

        return cls(
            image_size=image_size,
            image_mean=image_mean,
            image_std=image_std,
            min_scale=min_scale,
            max_scale=max_scale,
            dynamic_resolution=dynamic_resolution,
        )

    @staticmethod
    def split_images(image, split_size):
        # Cut a (C, H, W) tensor into non-overlapping patches of split_size,
        # scanning row-major (left to right, top to bottom).
        split_images = []
        _, h, w = image.shape  # C, H, W
        assert h % split_size[0] == 0 and w % split_size[1] == 0, \
            "dynamic resolution must be a multiple of the input image size"
        for i in range(0, h, split_size[0]):
            for j in range(0, w, split_size[1]):
                patch = image[:, i:i + split_size[0], j:j + split_size[1]].clone()
                split_images.append(patch)
        return split_images
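

# Minimal usage sketch (illustrative only; the 448x448 dummy input is an
# arbitrary example value): build the processor from an OmegaConf config,
# as from_config expects, and run it on a blank PIL image.
if __name__ == "__main__":
    from PIL import Image

    cfg = OmegaConf.create({"image_size": 224})
    processor = Blip2ImageTrainProcessor.from_config(cfg)
    dummy = Image.new("RGB", (448, 448))
    out = processor.preprocess(dummy)
    print(out.shape)  # torch.Size([3, 224, 224])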