Spaces:
Running
Running
import random | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
from torchvision.transforms import Compose | |
from .abinet_aug import CVColorJitter, CVDeterioration, CVGeometry, SVTRDeterioration, SVTRGeometry | |
from .parseq_aug import rand_augment_transform | |
class PARSeqAugPIL:
    """Apply PARSeq's RandAugment policy directly to ``data['image']``.

    The image is handed to the transform as-is, with no conversion.
    NOTE(review): presumably ``data['image']`` is already a PIL image, given
    the class name and the fact that ``PARSeqAug`` converts ndarrays to PIL
    before calling the same transform; confirm against the dataset loader.
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        self.transforms = rand_augment_transform()

    def __call__(self, data):
        """Replace ``data['image']`` with its augmented version; return ``data``."""
        data['image'] = self.transforms(data['image'])
        return data
class PARSeqAug:
    """Apply PARSeq's RandAugment policy to an ndarray image.

    ``data['image']`` is converted with ``Image.fromarray`` before the
    transform runs. The result is converted back with ``np.array``, so
    callers get an ndarray in and an ndarray out.
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        self.transforms = rand_augment_transform()

    def __call__(self, data):
        """Augment ``data['image']`` via a PIL round-trip; return ``data``."""
        as_pil = Image.fromarray(data['image'])
        augmented = self.transforms(as_pil)
        data['image'] = np.array(augmented)
        return data
class ABINetAug:
    """ABINet-style augmentation pipeline.

    The pipeline is geometry, then deterioration, then colour jitter. Each
    stage receives its own application probability.

    Args:
        geometry_p: ``p`` passed to ``CVGeometry``.
        deterioration_p: ``p`` passed to ``CVDeterioration``.
        colorjitter_p: ``p`` passed to ``CVColorJitter``.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 geometry_p=0.5,
                 deterioration_p=0.25,
                 colorjitter_p=0.25,
                 **kwargs):
        geometry = CVGeometry(degrees=45,
                              translate=(0.0, 0.0),
                              scale=(0.5, 2.0),
                              shear=(45, 15),
                              distortion=0.5,
                              p=geometry_p)
        deterioration = CVDeterioration(var=20,
                                        degrees=6,
                                        factor=4,
                                        p=deterioration_p)
        color_jitter = CVColorJitter(brightness=0.5,
                                     contrast=0.5,
                                     saturation=0.5,
                                     hue=0.1,
                                     p=colorjitter_p)
        # Compose runs the stages in list order.
        self.transforms = Compose([geometry, deterioration, color_jitter])

    def __call__(self, data):
        """Run the pipeline on ``data['image']`` in place; return ``data``."""
        data['image'] = self.transforms(data['image'])
        return data
class SVTRAug:
    """SVTR-style augmentation pipeline.

    The pipeline is SVTR geometry, then SVTR deterioration, then colour
    jitter. Each stage receives its own application probability.

    Args:
        aug_type: forwarded unchanged to ``SVTRGeometry``.
        geometry_p: ``p`` passed to ``SVTRGeometry``.
        deterioration_p: ``p`` passed to ``SVTRDeterioration``.
        colorjitter_p: ``p`` passed to ``CVColorJitter``.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 aug_type=0,
                 geometry_p=0.5,
                 deterioration_p=0.25,
                 colorjitter_p=0.25,
                 **kwargs):
        geometry = SVTRGeometry(aug_type=aug_type,
                                degrees=45,
                                translate=(0.0, 0.0),
                                scale=(0.5, 2.0),
                                shear=(45, 15),
                                distortion=0.5,
                                p=geometry_p)
        deterioration = SVTRDeterioration(var=20,
                                          degrees=6,
                                          factor=4,
                                          p=deterioration_p)
        color_jitter = CVColorJitter(brightness=0.5,
                                     contrast=0.5,
                                     saturation=0.5,
                                     hue=0.1,
                                     p=colorjitter_p)
        # Compose runs the stages in list order.
        self.transforms = Compose([geometry, deterioration, color_jitter])

    def __call__(self, data):
        """Run the pipeline on ``data['image']`` in place; return ``data``."""
        data['image'] = self.transforms(data['image'])
        return data
class BaseDataAugmentation:
    """Classic OCR augmentation chain with independent per-step probabilities.

    Each step is gated by its own ``random.random() <= prob`` draw. The steps
    run in this fixed order:

    1. crop (only when both image sides are >= 20 px)
    2. Gaussian blur
    3. HSV brightness nudge
    4. jitter
    5. Gaussian noise
    6. colour inversion (``255 - img``)

    Args:
        crop_prob, reverse_prob, noise_prob, jitter_prob, blur_prob,
        hsv_aug_prob: per-step probabilities.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 crop_prob=0.4,
                 reverse_prob=0.4,
                 noise_prob=0.4,
                 jitter_prob=0.4,
                 blur_prob=0.4,
                 hsv_aug_prob=0.4,
                 **kwargs):
        (self.crop_prob, self.reverse_prob, self.noise_prob,
         self.jitter_prob, self.blur_prob, self.hsv_aug_prob) = (
             crop_prob, reverse_prob, noise_prob,
             jitter_prob, blur_prob, hsv_aug_prob)
        # 5-tap Gaussian kernel (sigma=1). __call__ applies it separably
        # along both axes.
        self.fil = cv2.getGaussianKernel(ksize=5, sigma=1, ktype=cv2.CV_32F)

    def __call__(self, data):
        """Augment ``data['image']`` in place and return ``data``.

        The unpacking below requires a 3-D image array.
        """
        image = data['image']
        rows, cols, _ = image.shape
        # The random draw is evaluated before the size test, so one draw is
        # always consumed for the crop step even on small images.
        if random.random() <= self.crop_prob and rows >= 20 and cols >= 20:
            image = get_crop(image)
        if random.random() <= self.blur_prob:
            image = cv2.sepFilter2D(image, -1, self.fil, self.fil)
        if random.random() <= self.hsv_aug_prob:
            image = hsv_aug(image)
        if random.random() <= self.jitter_prob:
            image = jitter(image)
        if random.random() <= self.noise_prob:
            image = add_gasuss_noise(image)
        if random.random() <= self.reverse_prob:
            image = 255 - image
        data['image'] = image
        return data
def hsv_aug(img):
    """Scale the HSV value (brightness) channel by a tiny random factor.

    The factor is ``1 + delta`` with ``|delta| < 0.001``. The sign of
    ``delta`` comes from ``flag()``.
    NOTE(review): assumes ``img`` is a 3-channel BGR image (BGR2HSV
    conversion); if it is uint8, the scaled channel is truncated on
    assignment back into the array.

    Returns:
        The adjusted image converted back to BGR.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    # random.random() is drawn before flag(), matching left-to-right order.
    scale = 1 + 0.001 * random.random() * flag()
    hsv[:, :, 2] = hsv[:, :, 2] * scale
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
def blur(img):
    """Apply a 5x5 Gaussian blur (sigma=1).

    Images whose height or width is 10 px or less are returned unchanged.
    The input must be a 3-D array (rows, cols, channels).
    """
    height, width, _ = img.shape
    if height <= 10 or width <= 10:
        return img
    return cv2.GaussianBlur(img, (5, 5), 1)
def jitter(img):
    """Apply a small diagonal "smear" shift to an image.

    A random integer shift ``s`` is drawn from ``[0, 0.01 * min(rows,
    cols))``. Pixel ``(r, c)`` of the result is then ``img[r - k, c - k]``
    with ``k = min(r, c, s - 1)``: the interior is shifted down-right by
    ``s - 1``, and the top/left border strips receive progressively
    smaller shifts.

    Args:
        img: array indexed as (rows, cols, channels).

    Returns:
        A new jittered array. ``img`` itself is returned unchanged when
        either side is 10 px or less, or when the drawn shift is zero. The
        input array is never modified.
    """
    rows, cols, _ = img.shape
    if rows <= 10 or cols <= 10:
        return img
    shift = int(random.random() * min(rows, cols) * 0.01)
    if shift == 0:
        return img
    # Write into a copy so the caller's array (which may be cached or
    # shared) is not altered. Read from the untouched original.
    out = img.copy()
    for i in range(shift):
        out[i:, i:, :] = img[:rows - i, :cols - i, :]
    return out
def add_gasuss_noise(image, mean=0, var=0.1):
    """Add half-amplitude Gaussian noise and return a uint8 image.

    The noise is drawn from ``N(mean, sqrt(var))`` with the image's shape
    and scaled by 0.5 before being added. The sum is clipped to
    ``[0, 255]`` and cast to ``np.uint8``; the cast truncates any
    fractional part.
    """
    sigma = var**0.5
    noisy = image + 0.5 * np.random.normal(mean, sigma, image.shape)
    return np.uint8(np.clip(noisy, 0, 255))
def get_crop(image):
    """Randomly trim 1-8 rows from either the top or the bottom of an image.

    The trim amount is capped at ``height - 1``, so at least one row
    survives for images with height >= 1. The result is a slice of a copy,
    so the input is not modified. ``image`` must be 3-D.
    """
    height, _, _ = image.shape
    trim = min(int(random.randint(1, 8)), height - 1)
    result = image.copy()
    cut_from_top = random.randint(0, 1)
    if not cut_from_top:
        return result[0:height - trim, :, :]
    return result[trim:height, :, :]
def flag():
    """Return a random sign, +1 or -1.

    The threshold sits just above 0.5, so +1 is very slightly less likely
    than -1.
    """
    draw = random.random()
    if draw > 0.5000001:
        return 1
    return -1