import numpy as np
import torch
import cv2
import albumentations
from PIL import Image
from torchvision import transforms
from einops import rearrange
from omegaconf import ListConfig
from open_clip.transform import ResizeMaxSize

from ldm.modules.midas.api import load_midas_transform
from ldm.util import instantiate_from_config


class AddMiDaS(object):
    """Adds the MiDaS-preprocessed image to a sample under 'midas_in'."""

    def __init__(self, model_type):
        super().__init__()
        self.transform = load_midas_transform(model_type)

    def pt2np(self, x):
        # [-1, 1] tensor -> [0, 1] numpy array
        x = ((x + 1.0) * .5).detach().cpu().numpy()
        return x

    def np2pt(self, x):
        # [0, 1] numpy array -> [-1, 1] tensor
        x = torch.from_numpy(x) * 2 - 1.
        return x

    def __call__(self, sample):
        # sample['jpg'] is an h w c tensor in [-1, 1] at this point
        x = self.pt2np(sample['jpg'])
        x = self.transform({"image": x})["image"]
        sample['midas_in'] = x
        return sample


class new_process_im_base:
    def __init__(self,
                 size=512,
                 interpolation=3,
                 do_flip=True,
                 flip_p=0.5,
                 hint_range_m11=False,
                 ):
        self.do_flip = do_flip
        self.flip_p = flip_p
        self.rescale = transforms.Resize(size=size, interpolation=interpolation)
        if self.do_flip:
            self.flip = transforms.functional.hflip
        # base_tf [-1, 1]
        base_tf_m11 = [
            transforms.ToTensor(),  # c h w in [0, 1]
            transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]
        self.base_tf_m11 = transforms.Compose(base_tf_m11)
        # base_tf [0, 1]
        base_tf_01 = [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: rearrange(x, 'c h w -> h w c'))]
        self.base_tf_01 = transforms.Compose(base_tf_01)
        self.hint_range_m11 = hint_range_m11

    def __call__(self, im, pos_info, im_hint=None):
        im = im.convert("RGB")
        # crop a square box along the longer axis, guided by pos_info
        size = im.size
        crop_size = min(size)
        crop_axis = size.index(crop_size)
        lf, up, rg, dn = pos_info
        if crop_axis == 0:  # width is the shorter side: crop along height
            box_up, box_dn = self.generate_range(up, dn, size[1], size[0])
            box_lf, box_rg = 0, size[0]
        else:  # height is the shorter side: crop along width
            box_lf, box_rg = self.generate_range(lf, rg, size[0], size[1])
            box_up, box_dn = 0, size[1]
        im = im.crop((box_lf, box_up, box_rg, box_dn))
        # rescale
        im = self.rescale(im)
        # flip
        flip_img = False
        if self.do_flip:
            if torch.rand(1) < self.flip_p:
                im = self.flip(im)
                flip_img = True
        im = self.base_tf_m11(im)
        # apply the same crop / rescale / flip to the hint image, if any
        if im_hint is not None:
            im_hint = im_hint.convert("RGB")
            im_hint = im_hint.crop((box_lf, box_up, box_rg, box_dn))
            im_hint = self.rescale(im_hint)
            if flip_img:
                im_hint = self.flip(im_hint)
            im_hint = self.base_tf_m11(im_hint) if self.hint_range_m11 else self.base_tf_01(im_hint)
        return im, im_hint

    def generate_range(self, low, high, len_max, len_min):
        # pos_info may be normalised to [0, 1] or given in pixels
        mid = (low + high) / 2 * (len_max if high <= 1 else 1)
        max_range = min(mid + len_min / 2, len_max)
        min_range = min(max(mid - len_min / 2, 0),
                        max(max_range - len_min, 0))
        return int(min_range), int(min_range + len_min)


class new_process_im(new_process_im_base):
    def __call__(self, filename, pos_info, hint_filename=None):
        im = Image.open(filename)
        if hint_filename is not None:
            im_hint = Image.open(hint_filename)
        else:
            im_hint = None
        return super().__call__(im, pos_info, im_hint)
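
# A minimal usage sketch for `new_process_im` (the paths and the normalised
# box below are hypothetical, not part of this module). `pos_info` is a
# (left, up, right, down) region, either in [0, 1] or in pixels, that steers
# where the square crop is taken along the image's longer axis.
def _demo_new_process_im():
    proc = new_process_im(size=512, do_flip=True, flip_p=0.5)
    im, im_hint = proc("example.jpg",  # hypothetical path
                       pos_info=(0.2, 0.1, 0.8, 0.9),
                       hint_filename="example_hint.png")  # hypothetical path
    # im: (512, 512, 3) float tensor in [-1, 1];
    # im_hint: same shape, in [0, 1] unless hint_range_m11=True
    return im, im_hint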

class imagenet_process_im:
    def __init__(self,
                 size=512,
                 do_flip=False,
                 min_crop_f=0.5,
                 max_crop_f=1.,
                 flip_p=0.5,
                 random_crop=False
                 ):
        self.do_flip = do_flip
        if self.do_flip:
            self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert max_crop_f <= 1.
        self.center_crop = not random_crop
        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
        self.size = size

    def __call__(self, im):
        im = im.convert("RGB")
        image = np.array(im).astype(np.uint8)
        # crop a square whose side is a random fraction of the shorter side
        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f)
        crop_side_len = int(crop_side_len)
        if self.center_crop:
            cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
        else:
            cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
        image = cropper(image=image)["image"]
        # rescale
        image = self.image_rescaler(image=image)["image"]
        # flip
        if self.do_flip:
            image = self.flip(Image.fromarray(image))
            image = np.array(image).astype(np.uint8)
        return (image / 127.5 - 1.0).astype(np.float32)
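
# Minimal sketch of `imagenet_process_im`: random-fraction square crop,
# resize so the shorter side equals `size`, optional flip, returning an
# h w c float32 array in [-1, 1]. The path below is hypothetical.
def _demo_imagenet_process_im():
    proc = imagenet_process_im(size=512, do_flip=True, min_crop_f=0.5, max_crop_f=1.)
    image = proc(Image.open("example.jpg"))  # hypothetical path
    # image: (512, 512, 3) np.float32 in [-1, 1]
    return image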

# used for the CLIP image encoder
class process_wb_im:
    def __init__(self,
                 size=224,
                 image_transforms=None,
                 use_clip_resize=False,
                 image_mean=None,
                 image_std=None,
                 exchange_channel=True,
                 ):
        self.image_rescaler = albumentations.LongestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
        self.image_size = size
        self.pad = albumentations.PadIfNeeded(min_height=self.image_size,
                                              min_width=self.image_size,
                                              border_mode=cv2.BORDER_CONSTANT,
                                              value=(255, 255, 255))
        # avoid the mutable-default pitfall: never mutate a shared default list
        if image_transforms is None:
            image_transforms = []
        elif isinstance(image_transforms, ListConfig):
            image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
        else:
            image_transforms = list(image_transforms)
        image_transforms.extend([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=image_mean if image_mean is not None else (0.48145466, 0.4578275, 0.40821073),
                std=image_std if image_std is not None else (0.26862954, 0.26130258, 0.27577711)),
        ])
        if exchange_channel:
            image_transforms.append(
                transforms.Lambda(lambda x: rearrange(x, 'c h w -> h w c')))
        self.tform = transforms.Compose(image_transforms)
        self.use_clip_resize = use_clip_resize
        self.clip_resize = ResizeMaxSize(max_size=self.image_size,
                                         interpolation=transforms.InterpolationMode.BICUBIC,
                                         fill=(255, 255, 255))

    def __call__(self, im):
        im = im.convert("RGB")
        if self.use_clip_resize:
            im = self.clip_resize(im)
        else:
            im = self.padding_image(im)
        return self.tform(im)

    def padding_image(self, im):
        # resize so the longest side matches, then pad to a white square
        im = np.array(im).astype(np.uint8)
        im_rescaled = self.image_rescaler(image=im)["image"]
        im_padded = self.pad(image=im_rescaled)["image"]
        return im_padded


# used for VQ-GAN
class vqgan_process_im:
    def __init__(self, size=384, random_crop=False, augment=False,
                 ori_preprocessor=False, to_tensor=False):
        self.size = size
        self.random_crop = random_crop
        self.augment = augment
        assert self.size is not None and self.size > 0
        if ori_preprocessor:
            # original preprocessing: rescale the shortest side, then crop
            self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            if not self.random_crop:  # eval: deterministic centre crop
                self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
            else:  # train: random crop
                self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
            self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
        else:
            # rescale the longest side, then pad to a white square
            self.rescaler = albumentations.LongestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
            self.pad = albumentations.PadIfNeeded(min_height=self.size,
                                                  min_width=self.size,
                                                  border_mode=cv2.BORDER_CONSTANT,
                                                  value=(255, 255, 255))
            self.preprocessor = albumentations.Compose([self.rescaler, self.pad])
        if self.augment:
            # train-time blur / colour augmentations
            self.data_augmentation = albumentations.Compose([
                albumentations.GaussianBlur(p=0.1),
                albumentations.OneOf([
                    albumentations.HueSaturationValue(p=0.3),
                    albumentations.ToGray(p=0.3),
                    albumentations.ChannelShuffle(p=0.3)
                ], p=0.3)
            ])
        if to_tensor:
            self.tform = transforms.ToTensor()
        self.to_tensor = to_tensor

    def __call__(self, image):
        image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        if self.augment:
            image = self.data_augmentation(image=image)["image"]
        image = (image / 127.5 - 1.0).astype(np.float32)
        if self.to_tensor:
            image = self.tform(image)
        return image
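
# A minimal end-to-end sketch of the two remaining preprocessors; the image
# path is hypothetical. `process_wb_im` yields a CLIP-normalised tensor,
# `vqgan_process_im` an array (or tensor, with to_tensor=True) in [-1, 1].
if __name__ == "__main__":
    im = Image.open("example.jpg")  # hypothetical path
    clip_proc = process_wb_im(size=224, use_clip_resize=False)
    clip_in = clip_proc(im)  # (224, 224, 3) after the channel exchange
    vq_proc = vqgan_process_im(size=384, ori_preprocessor=False, to_tensor=True)
    vq_in = vq_proc(im)      # (3, 384, 384) tensor in [-1, 1]
    print(clip_in.shape, vq_in.shape)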