"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from llava.mm_utils import select_best_resolution


class BaseProcessor:
    """Minimal callable processor whose default transform is the identity."""

    def __init__(self):
        self.transform = lambda x: x

    def __call__(self, item):
        """Apply ``self.transform`` to ``item`` and return the result."""
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build an instance from a config; the base implementation ignores ``cfg``."""
        return cls()

    def build(self, **kwargs):
        """Wrap ``kwargs`` in an OmegaConf config and delegate to ``from_config``."""
        cfg = OmegaConf.create(kwargs)
        return self.from_config(cfg)


class BlipImageBaseProcessor(BaseProcessor):
    """Base image processor that owns a per-channel ``transforms.Normalize``.

    Args:
        image_mean: Per-channel mean; ``None`` selects the default constants below.
        image_std: Per-channel std; ``None`` selects the default constants below.
    """

    def __init__(self, image_mean=None, image_std=None):
        # Initialise the identity ``transform`` so direct instances are callable;
        # subclasses overwrite it with their own pipeline.
        super().__init__()
        if image_mean is None:
            image_mean = (0.48145466, 0.4578275, 0.40821073)
        if image_std is None:
            image_std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(image_mean, image_std)
        self.image_mean = image_mean
        self.image_std = image_std


class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
    """Resize an image (bicubic), convert it to a tensor and normalize it.

    Args:
        image_size: An int (square output) or a sequence interpreted as ``(H, W)``.
        image_mean: Per-channel mean; ``None`` selects the base-class default.
        image_std: Per-channel std; ``None`` selects the base-class default.
        min_scale: Accepted for config compatibility; not used by this class.
        max_scale: Accepted for config compatibility; not used by this class.
        is_training: Stored as ``self.is_training``; not read within this class.
        dynamic_resolution: Optional iterable of sizes. One resize pipeline is built
            per entry in ``self.transform_dic`` (entries are used as dict keys, so
            they must be hashable). ``preprocess`` does not currently use them.

    Raises:
        TypeError: if ``image_size`` is neither an int nor an iterable of sizes.
        ValueError: if ``image_size`` is an empty sequence.
    """

    def __init__(self, image_size=224, image_mean=None, image_std=None, min_scale=0.5, max_scale=1.0,
                 is_training=True, dynamic_resolution=None):
        super().__init__(image_mean=image_mean, image_std=image_std)
        self.is_training = is_training
        self.dynamic_resolution = dynamic_resolution

        if isinstance(image_size, int):
            self.img_size = image_size
            size_tuple = (image_size, image_size)
        else:
            # Accept tuples as before, plus lists / OmegaConf ListConfig coming from configs.
            try:
                size_tuple = tuple(image_size)  # H, W
            except TypeError as err:
                raise TypeError(
                    f"image_size must be an int or a sequence of ints, got {type(image_size).__name__}"
                ) from err
            if not size_tuple:
                raise ValueError("image_size must not be an empty sequence")
            self.img_size = size_tuple[0]

        # NOTE(review): both dimensions use ``img_size`` (the first entry), even when
        # ``image_size`` is a non-square (H, W) pair — confirm this matches consumers of crop_size.
        self.crop_size = {
            'height': self.img_size,
            'width': self.img_size
        }

        if self.dynamic_resolution:
            self.transform_dic = {}
            for size_ in self.dynamic_resolution:
                self.transform_dic[size_] = transforms.Compose(
                    [
                        transforms.Resize(
                            size_, interpolation=InterpolationMode.BICUBIC  # H, W
                        ),
                        transforms.ToTensor(),
                        self.normalize,
                    ]
                )

        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    size_tuple, interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def preprocess(self, item):
        """Apply the fixed-size resize / ToTensor / normalize pipeline to ``item``.

        A dynamic-resolution path was sketched here previously and is disabled. That
        path resized with the best-fitting entry of ``transform_dic`` (chosen via
        ``select_best_resolution``) and tiled the result with ``split_images``. Only
        ``self.transform`` is applied.
        """
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from an OmegaConf config (or a default empty config).

        Recognised keys: ``image_size``, ``mean`` (fallback ``image_mean``),
        ``image_std`` (fallback ``std``), ``min_scale`` and ``max_scale``.
        """
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 224)

        # The historical keys are "mean" and "image_std"; also accept the symmetric
        # spellings so a config using either naming scheme keeps both statistics.
        image_mean = cfg.get("mean", cfg.get("image_mean", None))
        image_std = cfg.get("image_std", cfg.get("std", None))

        min_scale = cfg.get("min_scale", 0.5)
        max_scale = cfg.get("max_scale", 1.0)

        return cls(
            image_size=image_size,
            image_mean=image_mean,
            image_std=image_std,
            min_scale=min_scale,
            max_scale=max_scale,
        )

    @staticmethod
    def split_images(image, split_size):
        """Tile a ``(C, H, W)`` tensor into non-overlapping patches of ``split_size``.

        Patches are returned in row-major order. Each patch is ``clone()``d, so it
        does not share storage with ``image``.

        Raises:
            ValueError: if H or W is not a multiple of the matching ``split_size`` entry.
        """
        _, h, w = image.shape  # C, H, W
        patch_h, patch_w = split_size[0], split_size[1]
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if h % patch_h != 0 or w % patch_w != 0:
            raise ValueError(
                "dynamic resolution must be a multiple of input image size: "
                f"got image ({h}, {w}) and split size ({patch_h}, {patch_w})"
            )
        return [
            image[:, i:i + patch_h, j:j + patch_w].clone()
            for i in range(0, h, patch_h)
            for j in range(0, w, patch_w)
        ]