"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from llava.mm_utils import select_best_resolution  # only used by the (currently disabled) dynamic-resolution path
class BaseProcessor:
    def __init__(self):
        self.transform = lambda x: x
def __call__(self, item):
return self.transform(item)
@classmethod
def from_config(cls, cfg=None):
return cls()
def build(self, **kwargs):
cfg = OmegaConf.create(kwargs)
return self.from_config(cfg)
class BlipImageBaseProcessor(BaseProcessor):
def __init__(self, image_mean=None, image_std=None):
        if image_mean is None:
            # Defaults are the OpenAI CLIP normalization statistics.
            image_mean = (0.48145466, 0.4578275, 0.40821073)
        if image_std is None:
            image_std = (0.26862954, 0.26130258, 0.27577711)
self.normalize = transforms.Normalize(image_mean, image_std)
self.image_mean = image_mean
self.image_std = image_std
class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
    def __init__(self, image_size=224, image_mean=None, image_std=None,
                 min_scale=0.5, max_scale=1.0, is_training=True,
                 dynamic_resolution=None):
        super().__init__(image_mean=image_mean, image_std=image_std)
        # NOTE: min_scale and max_scale are accepted for config compatibility
        # but are not used below (no random resized crop is applied).
        self.is_training = is_training
        self.dynamic_resolution = dynamic_resolution
        if isinstance(image_size, int):
            self.img_size = image_size
            size_tuple = (image_size, image_size)
        elif isinstance(image_size, (tuple, list)):  # (H, W)
            self.img_size = image_size[0]
            size_tuple = tuple(image_size)
        else:
            raise ValueError(f"Unsupported image_size type: {type(image_size)}")
self.crop_size = {
'height': self.img_size,
'width': self.img_size
}
        if self.dynamic_resolution:
            # Build one resize pipeline per candidate (H, W) resolution.
            self.transform_dic = {}
            for size_ in self.dynamic_resolution:
                self.transform_dic[size_] = transforms.Compose(
                    [
                        transforms.Resize(
                            size_, interpolation=InterpolationMode.BICUBIC  # (H, W)
                        ),
                        transforms.ToTensor(),
                        self.normalize,
                    ]
                )
self.transform = transforms.Compose(
[
transforms.Resize(
size_tuple, interpolation=InterpolationMode.BICUBIC
),
transforms.ToTensor(),
self.normalize,
]
)
    def preprocess(self, item):
        # Dynamic-resolution preprocessing (pick the best-fit resolution, resize,
        # then split into base-size patches) is currently disabled:
        # if self.dynamic_resolution is not None:
        #     images = []
        #     images.append(self.transform(item))
        #     width, height = item.size
        #     best_fit_res = select_best_resolution((width, height), self.dynamic_resolution)
        #     resize_img = self.transform_dic[best_fit_res](item)
        #     splitted_imgs = self.split_images(resize_img, (self.img_size, self.img_size))
        #     images.extend(splitted_imgs)
        #     return images
        # else:
        return self.transform(item)
@classmethod
def from_config(cls, cfg=None):
if cfg is None:
cfg = OmegaConf.create()
image_size = cfg.get("image_size", 224)
image_mean = cfg.get("mean", None)
image_std = cfg.get("image_std", None)
min_scale = cfg.get("min_scale", 0.5)
max_scale = cfg.get("max_scale", 1.0)
return cls(
image_size=image_size,
image_mean=image_mean,
image_std=image_std,
min_scale=min_scale,
max_scale=max_scale,
)
@staticmethod
    def split_images(image, split_size):
        """Split a (C, H, W) tensor into non-overlapping patches of size split_size (h, w)."""
        patches = []
        _, h, w = image.shape  # C, H, W
        assert h % split_size[0] == 0 and w % split_size[1] == 0, \
            "dynamic resolution must be a multiple of the input image size"
        for i in range(0, h, split_size[0]):
            for j in range(0, w, split_size[1]):
                patch = image[:, i:i + split_size[0], j:j + split_size[1]].clone()
                patches.append(patch)
        return patches
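
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): builds the
# processor, preprocesses a PIL image, and splits a larger tensor into patches
# the way the disabled dynamic-resolution path would. The 224 size, the dummy
# image, and the 448x448 tensor are arbitrary assumptions for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from PIL import Image

    processor = Blip2ImageTrainProcessor(image_size=224)
    dummy = Image.new("RGB", (640, 480), color=(128, 128, 128))
    pixel_values = processor.preprocess(dummy)
    print(pixel_values.shape)  # torch.Size([3, 224, 224])

    # Split a 448x448 tensor into four 224x224 patches.
    big = torch.zeros(3, 448, 448)
    patches = Blip2ImageTrainProcessor.split_images(big, (224, 224))
    print(len(patches))  # 4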