|
|
|
|
|
|
|
|
|
|
|
import random |
|
import os |
|
import math |
|
import mmcv |
|
import torch |
|
import numpy as np |
|
import torchvision.transforms as T |
|
from torchvision import datasets |
|
from torch.utils.data import Dataset |
|
from megatron import print_rank_0 |
|
from megatron import get_args |
|
from PIL import Image, ImageOps, ImageEnhance |
|
import torchvision.transforms as torch_tr |
|
|
|
def _is_pil_image(img):
    """Tell whether ``img`` is an instance of ``PIL.Image.Image``."""
    pil_image_type = Image.Image
    return isinstance(img, pil_image_type)
|
|
|
class PhotoMetricDistortion(object):
    """Apply a chain of random photometric distortions to an image.

    Every step below is gated by its own coin flip (``random.randint(0, 1)``),
    so each one is applied with probability 0.5. Contrast is applied either
    right after brightness or at the very end, decided by one extra coin flip
    per call:

    1. random brightness
    2. random contrast (when the per-call mode is 1)
    3. random saturation (through an HSV round trip)
    4. random hue (through an HSV round trip)
    5. random contrast (when the per-call mode is 0)

    Args:
        brightness_delta (int): maximum absolute brightness offset.
        contrast_range (tuple): (lower, upper) bounds of the contrast gain.
        saturation_range (tuple): (lower, upper) bounds of the saturation gain.
        hue_delta (int): maximum absolute hue offset.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        contrast_lo, contrast_hi = contrast_range
        saturation_lo, saturation_hi = saturation_range

        self.brightness_delta = brightness_delta
        self.contrast_lower = contrast_lo
        self.contrast_upper = contrast_hi
        self.saturation_lower = saturation_lo
        self.saturation_upper = saturation_hi
        self.hue_delta = hue_delta

    def convert(self, img, alpha=1, beta=0):
        """Return ``img * alpha + beta`` clipped to [0, 255] as uint8."""
        scaled = img.astype(np.float32) * alpha + beta
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def brightness(self, img):
        """Add a random offset in [-brightness_delta, brightness_delta]."""
        if not random.randint(0, 1):
            return img
        offset = random.uniform(-self.brightness_delta, self.brightness_delta)
        return self.convert(img, beta=offset)

    def contrast(self, img):
        """Multiply by a random gain in [contrast_lower, contrast_upper]."""
        if not random.randint(0, 1):
            return img
        gain = random.uniform(self.contrast_lower, self.contrast_upper)
        return self.convert(img, alpha=gain)

    def saturation(self, img):
        """Scale the second HSV channel by a random gain."""
        if not random.randint(0, 1):
            return img
        hsv = mmcv.bgr2hsv(img)
        gain = random.uniform(self.saturation_lower, self.saturation_upper)
        hsv[:, :, 1] = self.convert(hsv[:, :, 1], alpha=gain)
        return mmcv.hsv2bgr(hsv)

    def hue(self, img):
        """Shift the first HSV channel by a random integer, modulo 180."""
        if not random.randint(0, 1):
            return img
        hsv = mmcv.bgr2hsv(img)
        offset = random.randint(-self.hue_delta, self.hue_delta)
        hsv[:, :, 0] = (hsv[:, :, 0].astype(int) + offset) % 180
        return mmcv.hsv2bgr(hsv)

    def __call__(self, img):
        """Distort ``img`` and return the result as an RGB PIL image.

        Args:
            img: Input image; it is converted with ``np.array(img)``.

        Returns:
            PIL Image: The distorted image, converted to mode 'RGB'.
        """
        # NOTE(review): the array is passed to mmcv's BGR<->HSV helpers as-is;
        # confirm the channel order of the incoming image is what is intended.
        pixels = np.array(img)

        pixels = self.brightness(pixels)

        # One coin flip decides whether contrast runs before or after the
        # saturation/hue steps.
        contrast_first = random.randint(0, 1) == 1
        if contrast_first:
            pixels = self.contrast(pixels)

        pixels = self.saturation(pixels)
        pixels = self.hue(pixels)

        if not contrast_first:
            pixels = self.contrast(pixels)

        distorted = Image.fromarray(pixels.astype(np.uint8))
        return distorted.convert('RGB')
|
|
|
|
|
class RandomCrop(object):
    """Take a random fixed-size crop from an image/mask pair.

    ``crop_size`` is ``(height, width)``. When the image is smaller than the
    crop along an axis, both the image and the mask are first padded on both
    sides of that axis by ``(target - actual) // 2 + 1`` pixels: the image
    with ``self.pad_color`` and the mask with ``self.ignore_index``.

    A crop window is then drawn at a random offset. While
    ``self.cat_max_ratio < 1`` the window is re-drawn (up to 10 checks) until
    the most frequent non-ignored label in the cropped mask covers less than
    ``cat_max_ratio`` of the non-ignored pixels and at least two such labels
    are present.
    """

    def __init__(self, crop_size):
        args = get_args()
        self.size = crop_size
        self.cat_max_ratio = 0.75
        self.ignore_index = args.ignore_index
        self.pad_color = (0, 0, 0)

    def get_crop_bbox(self, img):
        """Randomly pick a crop window.

        Args:
            img: Object exposing ``size`` as ``(width, height)``.

        Returns:
            tuple: ``(crop_y1, crop_y2, crop_x1, crop_x2)``.
        """
        img_w, img_h = img.size
        target_h, target_w = self.size
        # A negative margin (image smaller than the crop) is clamped to 0 so
        # the window then starts at the origin.
        margin_h = max(img_h - target_h, 0)
        margin_w = max(img_w - target_w, 0)
        offset_h = random.randint(0, margin_h)
        offset_w = random.randint(0, margin_w)
        crop_y1, crop_y2 = offset_h, offset_h + target_h
        crop_x1, crop_x2 = offset_w, offset_w + target_w

        return crop_y1, crop_y2, crop_x1, crop_x2

    def crop(self, img, crop_bbox):
        """Crop ``img`` to a window given as ``(y1, y2, x1, x2)``."""
        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        # PIL expects the box as (left, upper, right, lower).
        return img.crop((crop_x1, crop_y1, crop_x2, crop_y2))

    @staticmethod
    def crop_in_image(target_w, target_h, w, h, img, mask):
        """Crop ``img`` and ``mask`` with one shared random window.

        Returns:
            list: ``[cropped_img, cropped_mask]``.
        """
        x1 = 0 if w == target_w else random.randint(0, w - target_w)
        y1 = 0 if h == target_h else random.randint(0, h - target_h)
        box = (x1, y1, x1 + target_w, y1 + target_h)
        return [img.crop(box), mask.crop(box)]

    def __call__(self, img, mask):
        """Pad if necessary, then randomly crop ``img`` and ``mask``.

        Args:
            img (PIL Image): Input image.
            mask (PIL Image): Label mask aligned with ``img``.

        Returns:
            tuple: ``(img, mask)`` cropped to ``self.size``; the inputs are
            returned untouched when they already have the target size.
        """
        w, h = img.size
        target_h, target_w = self.size

        if w == target_w and h == target_h:
            return img, mask

        # Pad symmetrically along every axis that is too small.
        pad_h = (target_h - h) // 2 + 1 if target_h > h else 0
        pad_w = (target_w - w) // 2 + 1 if target_w > w else 0
        if pad_h or pad_w:
            border = (pad_w, pad_h, pad_w, pad_h)
            # Use the configured pad colour instead of a duplicated literal.
            img = ImageOps.expand(img, border=border, fill=self.pad_color)
            mask = ImageOps.expand(mask, border=border,
                                   fill=self.ignore_index)

        crop_bbox = self.get_crop_bbox(img)
        if self.cat_max_ratio < 1.:
            # Re-draw the window while a single label dominates the crop.
            for _ in range(10):
                seg_temp = self.crop(mask, crop_bbox)
                labels, cnt = np.unique(seg_temp, return_counts=True)
                cnt = cnt[labels != self.ignore_index]
                if (len(cnt) > 1
                        and np.max(cnt) / np.sum(cnt) < self.cat_max_ratio):
                    break
                crop_bbox = self.get_crop_bbox(img)

        img = self.crop(img, crop_bbox)
        mask = self.crop(mask, crop_bbox)
        # Internal sanity check: PIL sizes are (width, height).
        assert(img.size[0] == self.size[1] and img.size[1] == self.size[0])

        return img, mask
|
|
|
|
|
class RandomSizeAndCrop(object):
    """Rescale an image/mask pair by a random factor, then random-crop it.

    Args:
        crop_size: Size forwarded to ``RandomCrop``.
        scale_min (float): Lower bound of the random scale factor.
        scale_max (float): Upper bound of the random scale factor.
    """

    def __init__(self,
                 crop_size,
                 scale_min=0.5,
                 scale_max=2.0):
        self.crop = RandomCrop(crop_size)
        self.scale_min, self.scale_max = scale_min, scale_max

    def __call__(self, img, mask):
        """Return the randomly rescaled and cropped ``(img, mask)`` pair."""
        factor = random.uniform(self.scale_min, self.scale_max)
        src_w, src_h = img.size
        scaled_size = (int(src_w * factor), int(src_h * factor))

        # Bicubic for the image; nearest-neighbour for the mask so that no
        # interpolated label values are introduced.
        scaled_img = img.resize(scaled_size, Image.BICUBIC)
        scaled_mask = mask.resize(scaled_size, Image.NEAREST)

        cropped_img, cropped_mask = self.crop(scaled_img, scaled_mask)
        return cropped_img, cropped_mask
|
|
|
class RandomHorizontallyFlip(object):
    """Mirror an image/mask pair left-to-right with probability 0.5."""

    def __call__(self, img, mask):
        """Return ``(img, mask)``, both flipped or both untouched."""
        if random.random() >= 0.5:
            return img, mask
        flipped_img = img.transpose(Image.FLIP_LEFT_RIGHT)
        flipped_mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        return flipped_img, flipped_mask
|
|
|
|
|
def adjust_brightness(img, brightness_factor):
    """Adjust the brightness of a PIL image.

    Args:
        img (PIL Image): Image to adjust.
        brightness_factor (float): Non-negative factor handed to
            ``ImageEnhance.Brightness``. 0 gives a black image, 1 gives the
            original image and 2 doubles the brightness.

    Returns:
        PIL Image: Brightness adjusted image.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        return ImageEnhance.Brightness(img).enhance(brightness_factor)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
|
|
|
|
def adjust_contrast(img, contrast_factor):
    """Adjust the contrast of a PIL image.

    Args:
        img (PIL Image): Image to adjust.
        contrast_factor (float): Non-negative factor handed to
            ``ImageEnhance.Contrast``. 0 gives a solid gray image, 1 gives
            the original image and 2 doubles the contrast.

    Returns:
        PIL Image: Contrast adjusted image.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        return ImageEnhance.Contrast(img).enhance(contrast_factor)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
|
|
|
|
def adjust_saturation(img, saturation_factor):
    """Adjust the colour saturation of a PIL image.

    Args:
        img (PIL Image): Image to adjust.
        saturation_factor (float): Factor handed to ``ImageEnhance.Color``.
            0 gives a black and white image, 1 gives the original image and
            2 enhances the saturation by a factor of 2.

    Returns:
        PIL Image: Saturation adjusted image.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        return ImageEnhance.Color(img).enhance(saturation_factor)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
|
|
|
|
def adjust_hue(img, hue_factor):
    """Adjust the hue of a PIL image.

    The image is converted to HSV, the hue channel (H) is shifted cyclically,
    and the result is converted back to the original image mode.

    See https://en.wikipedia.org/wiki/Hue for more details on Hue.

    Args:
        img (PIL Image): Image to adjust.
        hue_factor (float): Amount of shift applied to the H channel; must be
            in [-0.5, 0.5]. 0 means no shift, while 0.5 and -0.5 shift the
            hue channel by half its range in the positive and negative
            direction respectively.

    Returns:
        PIL Image: Hue adjusted image. Images whose mode is 'L', '1', 'I' or
        'F' are returned unchanged.

    Raises:
        ValueError: If ``hue_factor`` is outside [-0.5, 0.5].
        TypeError: If ``img`` is not a PIL image.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        # Include the offending value; the previous message called .format()
        # on a string without a placeholder and silently dropped it.
        raise ValueError(
            'hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)

    # Casting a negative float straight to uint8 is not well defined, so the
    # shift is truncated to an integer and wrapped into [0, 255] first.
    shift = np.uint8(int(hue_factor * 255) % 256)
    with np.errstate(over='ignore'):
        # uint8 addition wraps modulo 256, which yields the cyclic shift.
        np_h += shift
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img
|
|
|
|
|
class ColorJitter(object):
    """Randomly change the brightness, contrast, saturation and hue of an image.

    Args:
        brightness (float): How much to jitter brightness. The factor is
            drawn uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. The factor is drawn
            uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. The factor is
            drawn uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. The factor is drawn uniformly
            from [-hue, hue]. Should be >= 0 and <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness, self.contrast = brightness, contrast
        self.saturation, self.hue = saturation, hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Build one randomized jitter transform.

        Arguments have the same meaning as in ``__init__``. A component whose
        strength is not greater than 0 is left out entirely.

        Returns:
            A composed transform applying the sampled brightness, contrast,
            saturation and hue adjustments in a random order.
        """
        pipeline = []

        def _enqueue(adjust_fn, factor):
            # The helper scope gives each step its own function and factor.
            pipeline.append(
                torch_tr.Lambda(lambda img: adjust_fn(img, factor)))

        if brightness > 0:
            _enqueue(adjust_brightness,
                     np.random.uniform(max(0, 1 - brightness),
                                       1 + brightness))

        if contrast > 0:
            _enqueue(adjust_contrast,
                     np.random.uniform(max(0, 1 - contrast), 1 + contrast))

        if saturation > 0:
            _enqueue(adjust_saturation,
                     np.random.uniform(max(0, 1 - saturation),
                                       1 + saturation))

        if hue > 0:
            _enqueue(adjust_hue, np.random.uniform(-hue, hue))

        np.random.shuffle(pipeline)
        return torch_tr.Compose(pipeline)

    def __call__(self, img):
        """Apply a freshly sampled colour jitter to ``img``.

        Args:
            img (PIL Image): Input image.

        Returns:
            PIL Image: Color jittered image.
        """
        jitter = self.get_params(self.brightness, self.contrast,
                                 self.saturation, self.hue)
        return jitter(img)
|
|
|
|