# Copyright (c) 2020 The MMSegmenation Authors.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.

import random
import os
import math
import mmcv
import torch
import numpy as np
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset
from megatron import print_rank_0
from megatron import get_args
from PIL import Image, ImageOps, ImageEnhance
import torchvision.transforms as torch_tr


def _is_pil_image(img):
    """Return True if ``img`` is a PIL ``Image.Image`` instance."""
    return isinstance(img, Image.Image)


class PhotoMetricDistortion(object):
    """Apply photometric distortion to an image sequentially.

    Every individual transformation is applied with a probability of 0.5.
    The random-contrast step is placed either right after the brightness
    step or at the very end; the position is chosen at random per call.

    1. random brightness
    2. random contrast (mode 1)
    3. random saturation (scales S after converting to HSV and back)
    4. random hue (shifts H after converting to HSV and back)
    5. random contrast (mode 0)

    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def convert(self, img, alpha=1, beta=0):
        """Multiply by ``alpha``, add ``beta``, clip to [0, 255], cast to uint8."""
        img = img.astype(np.float32) * alpha + beta
        img = np.clip(img, 0, 255)
        return img.astype(np.uint8)

    def brightness(self, img):
        """Brightness distortion: add a uniform offset in [-delta, delta]."""
        if random.randint(0, 1):
            return self.convert(
                img,
                beta=random.uniform(-self.brightness_delta,
                                    self.brightness_delta))
        return img

    def contrast(self, img):
        """Contrast distortion: scale by a uniform factor in contrast_range."""
        if random.randint(0, 1):
            return self.convert(
                img,
                alpha=random.uniform(self.contrast_lower, self.contrast_upper))
        return img

    def saturation(self, img):
        """Saturation distortion: scale the HSV S channel by a random factor."""
        if random.randint(0, 1):
            img = mmcv.bgr2hsv(img)
            img[:, :, 1] = self.convert(
                img[:, :, 1],
                alpha=random.uniform(self.saturation_lower,
                                     self.saturation_upper))
            img = mmcv.hsv2bgr(img)
        return img

    def hue(self, img):
        """Hue distortion: cyclically shift the HSV H channel."""
        if random.randint(0, 1):
            img = mmcv.bgr2hsv(img)
            # `% 180` wraps the shifted hue within the 8-bit hue range
            # [0, 180) used by the HSV conversion.
            img[:, :, 0] = (img[:, :, 0].astype(int) +
                            random.randint(-self.hue_delta,
                                           self.hue_delta)) % 180
            img = mmcv.hsv2bgr(img)
        return img

    def __call__(self, img):
        """Apply the photometric distortion pipeline to one image.

        Args:
            img (PIL.Image.Image | np.ndarray): Input image; it is converted
                with ``np.array`` before processing.
                NOTE(review): a PIL image yields RGB channel order, while the
                HSV helpers used here are mmcv's BGR variants. The round-trip
                is self-consistent and the hue shift is drawn symmetrically,
                so this appears harmless — confirm if exact color semantics
                matter.

        Returns:
            PIL.Image.Image: The distorted image, converted to RGB mode.
        """
        img = np.array(img)
        # random brightness
        img = self.brightness(img)

        # mode == 1 --> do random contrast first (right after brightness)
        # mode == 0 --> do random contrast last
        mode = random.randint(0, 1)
        if mode == 1:
            img = self.contrast(img)

        # random saturation
        img = self.saturation(img)

        # random hue
        img = self.hue(img)

        # random contrast
        if mode == 0:
            img = self.contrast(img)

        img = Image.fromarray(img.astype(np.uint8)).convert('RGB')
        return img


class RandomCrop(object):
    """Take a random crop of size ``crop_size`` from an image/mask pair.

    If the image is smaller than the crop in a dimension, the image and mask
    are first padded symmetrically along that dimension. The image is padded
    with ``pad_color`` and the mask with ``args.ignore_index``. Each side
    gets one extra pixel, so the padded image is always larger than the crop.

    A random crop window is then sampled. Up to 10 windows are tried to find
    one that contains at least two non-ignored labels, with the most frequent
    label covering less than ``cat_max_ratio`` of the labelled pixels. If
    none qualifies, the last sampled window is used.

    If the image already has exactly the crop size, the pair is returned
    unchanged.

    Args:
        crop_size (tuple): Target crop size as ``(height, width)``.
    """

    def __init__(self, crop_size):
        args = get_args()
        self.size = crop_size
        self.cat_max_ratio = 0.75
        self.ignore_index = args.ignore_index
        self.pad_color = (0, 0, 0)

    def get_crop_bbox(self, img):
        """Randomly get a crop bounding box as ``(y1, y2, x1, x2)``."""
        img_w, img_h = img.size
        target_h, target_w = self.size  # [H W]
        margin_h = max(img_h - target_h, 0)
        margin_w = max(img_w - target_w, 0)
        offset_h = random.randint(0, margin_h)
        offset_w = random.randint(0, margin_w)
        crop_y1, crop_y2 = offset_h, offset_h + target_h
        crop_x1, crop_x2 = offset_w, offset_w + target_w

        return crop_y1, crop_y2, crop_x1, crop_x2

    def crop(self, img, crop_bbox):
        """Crop ``img`` to a ``(y1, y2, x1, x2)`` bounding box."""
        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        img = img.crop((crop_x1, crop_y1, crop_x2, crop_y2))
        return img

    @staticmethod
    def crop_in_image(target_w, target_h, w, h, img, mask):
        """Crop image and mask at the same random offset.

        Requires ``w >= target_w`` and ``h >= target_h``; otherwise
        ``random.randint`` raises ``ValueError``.

        Returns:
            list: ``[cropped_img, cropped_mask]``.
        """
        if w == target_w:
            x1 = 0
        else:
            x1 = random.randint(0, w - target_w)
        if h == target_h:
            y1 = 0
        else:
            y1 = random.randint(0, h - target_h)

        return [img.crop((x1, y1, x1 + target_w, y1 + target_h)),
                mask.crop((x1, y1, x1 + target_w, y1 + target_h))]

    def __call__(self, img, mask):
        """Pad (if needed) and randomly crop an image/mask pair.

        Args:
            img (PIL.Image.Image): Input image.
            mask (PIL.Image.Image): Segmentation mask of the same size.

        Returns:
            tuple: ``(img, mask)`` cropped to ``self.size``.
        """
        w, h = img.size
        target_h, target_w = self.size  # ASSUME H, W

        if w == target_w and h == target_h:
            return img, mask

        # Pad image if image < crop
        if target_h > h:
            pad_h = (target_h - h) // 2 + 1
        else:
            pad_h = 0
        if target_w > w:
            pad_w = (target_w - w) // 2 + 1
        else:
            pad_w = 0
        border = (pad_w, pad_h, pad_w, pad_h)
        if pad_h or pad_w:
            img = ImageOps.expand(img, border=border, fill=self.pad_color)
            # Padded mask pixels are marked as ignored so they carry no label.
            mask = ImageOps.expand(mask, border=border, fill=self.ignore_index)

        crop_bbox = self.get_crop_bbox(img)
        if self.cat_max_ratio < 1.:
            # Resample up to 10 times to avoid crops dominated by one class.
            for _ in range(10):
                seg_temp = self.crop(mask, crop_bbox)
                labels, cnt = np.unique(seg_temp, return_counts=True)
                cnt = cnt[labels != self.ignore_index]
                if len(cnt) > 1 and np.max(cnt) / np.sum(
                        cnt) < self.cat_max_ratio:
                    break
                crop_bbox = self.get_crop_bbox(img)

        # crop the image
        img = self.crop(img, crop_bbox)

        # crop semantic seg
        mask = self.crop(mask, crop_bbox)
        assert(img.size[0] == self.size[1] and img.size[1] == self.size[0])

        return img, mask


class RandomSizeAndCrop(object):
    """Randomly rescale an image/mask pair, then apply ``RandomCrop``.

    Args:
        crop_size (tuple): Crop size ``(height, width)`` for ``RandomCrop``.
        scale_min (float): Lower bound of the uniform scale factor.
        scale_max (float): Upper bound of the uniform scale factor.
    """

    def __init__(self, crop_size, scale_min=0.5, scale_max=2.0):
        self.crop = RandomCrop(crop_size)
        self.scale_min = scale_min
        self.scale_max = scale_max

    def __call__(self, img, mask):
        scale_amt = random.uniform(self.scale_min, self.scale_max)
        w, h = [int(i * scale_amt) for i in img.size]

        # Bicubic for the image; nearest for the mask so labels stay discrete.
        resized_img = img.resize((w, h), Image.BICUBIC)
        resized_mask = mask.resize((w, h), Image.NEAREST)
        img, mask = self.crop(resized_img, resized_mask)
        return img, mask


class RandomHorizontallyFlip(object):
    """Horizontally flip an image/mask pair together with probability 0.5."""

    def __call__(self, img, mask):
        if random.random() < 0.5:
            return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(
                Image.FLIP_LEFT_RIGHT)
        return img, mask


def adjust_brightness(img, brightness_factor):
    """Adjust brightness of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        brightness_factor (float):  How much to adjust the brightness. Can be
            any non negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image: Brightness adjusted image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(brightness_factor)
    return img


def adjust_contrast(img, contrast_factor):
    """Adjust contrast of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image: Contrast adjusted image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(contrast_factor)
    return img


def adjust_saturation(img, saturation_factor):
    """Adjust color saturation of an image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        saturation_factor (float):  How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image: Saturation adjusted image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Color(img)
    img = enhancer.enhance(saturation_factor)
    return img


def adjust_hue(img, hue_factor):
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See https://en.wikipedia.org/wiki/Hue for more details on Hue.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        hue_factor (float):  How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image: Hue adjusted image. Single-channel modes ('L', '1', 'I',
        'F') are returned unchanged.

    Raises:
        ValueError: If ``hue_factor`` is outside [-0.5, 0.5].
        TypeError: If ``img`` is not a PIL Image.
    """
    if not(-0.5 <= hue_factor <= 0.5):
        raise ValueError(
            'hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    # Single-channel modes have no hue to shift.
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # Reduce the shift modulo 256 as a Python int before casting. Casting a
    # negative float straight to uint8 is platform-dependent in NumPy (and
    # warns on NumPy >= 1.24). int() truncates toward zero, which matches
    # the result the direct cast gave on common platforms.
    shift = np.uint8(int(hue_factor * 255) % 256)
    # uint8 addition takes care of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += shift
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img


class ColorJitter(object):
    """Randomly change the brightness, contrast, saturation and hue of an image.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue(float): How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue]. Should be >=0 and <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__. A component whose jitter
        amount is not positive is skipped. Factors and the application order
        are drawn from ``np.random`` (not Python's ``random``).

        Returns:
            Transform which randomly adjusts brightness, contrast, saturation
            and hue in a random order.
        """
        transforms = []
        if brightness > 0:
            brightness_factor = np.random.uniform(max(0, 1 - brightness),
                                                  1 + brightness)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor)))

        if contrast > 0:
            contrast_factor = np.random.uniform(max(0, 1 - contrast),
                                                1 + contrast)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor)))

        if saturation > 0:
            saturation_factor = np.random.uniform(max(0, 1 - saturation),
                                                  1 + saturation)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor)))

        if hue > 0:
            hue_factor = np.random.uniform(-hue, hue)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor)))

        np.random.shuffle(transforms)
        transform = torch_tr.Compose(transforms)

        return transform

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Input image.

        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img)