Spaces:
Running
Running
import torch.utils.data as data | |
import torch | |
from PIL import Image, ImageFilter | |
import os, cv2 | |
import numpy as np | |
import random | |
from scipy.stats import norm | |
from math import floor | |
def random_translate(image, target):
    """Randomly shift the image and its landmarks by up to ~30 pixels.

    With probability 0.5 the inputs are returned untouched. Otherwise an
    integer shift is drawn independently for x and y from roughly
    [-30, 30] pixels, the image is translated with an affine transform and
    the landmark coordinates are moved by the same amount.

    Args:
        image: PIL-style image exposing ``size`` as ``(width, height)`` and
            a ``transform`` method.
        target: Flat sequence ``[x0, y0, x1, y1, ...]`` of landmark
            coordinates normalised by the image width/height.

    Returns:
        ``(image, target)``. When the augmentation fires, ``target`` is a
        new flat float array clipped to ``[0, 1]``; the caller's ``target``
        is never modified.
    """
    if random.random() <= 0.5:
        return image, target

    # PIL reports size as (width, height). The x shift must be normalised
    # by the width and the y shift by the height, otherwise landmarks drift
    # away from the pixels on non-square images.
    image_width, image_height = image.size

    shift_x = int((random.random() - 0.5) * 60)  # left/right, in pixels
    shift_y = int((random.random() - 0.5) * 60)  # up/down, in pixels

    # Affine coefficients (a, b, c, d, e, f) map each output pixel (x, y)
    # to the input pixel (a*x + b*y + c, d*x + e*y + f), so the content
    # moves by (-shift_x, -shift_y).
    shifted = image.transform(
        image.size, Image.AFFINE, (1, 0, shift_x, 0, 1, shift_y)
    )

    # np.array(...) copies, so the caller's landmarks stay intact, and the
    # float dtype makes list and integer inputs work as well.
    landmarks = np.array(target, dtype=float).reshape(-1, 2)
    landmarks[:, 0] -= shift_x / image_width
    landmarks[:, 1] -= shift_y / image_height

    # Landmarks pushed outside the frame are pinned to the border.
    return shifted, np.clip(landmarks.flatten(), 0, 1)
def random_blur(image):
    """Apply a Gaussian blur of random strength to roughly 30% of images.

    Args:
        image: Image exposing a PIL-style ``filter`` method.

    Returns:
        The blurred image when the first random draw exceeds 0.7 (blur
        radius is a second draw scaled to [0, 5)); otherwise the input
        image unchanged.
    """
    if random.random() <= 0.7:
        return image
    radius = random.random() * 5
    return image.filter(ImageFilter.GaussianBlur(radius))
def random_occlusion(image):
    """Paint a random solid-colour rectangle over the image half the time.

    The rectangle is at most 40% of the image height/width, and its
    top-left corner is drawn so that the box stays about 10 pixels clear of
    the right/bottom edges. Each colour channel gets its own random value.

    Args:
        image: Image convertible with ``np.array`` to a 3-D
            (height, width, channels) array.

    Returns:
        A new RGB image with the occluding box when the augmentation fires,
        otherwise the input image unchanged.
    """
    if random.random() <= 0.5:
        return image

    pixels = np.array(image).astype(np.uint8)
    img_h, img_w, _ = pixels.shape

    box_h = int(img_h * 0.4 * random.random())
    box_w = int(img_w * 0.4 * random.random())
    left = int((img_w - box_w - 10) * random.random())
    top = int((img_h - box_h - 10) * random.random())

    rows = slice(top, top + box_h)
    cols = slice(left, left + box_w)
    # Channels are filled last-to-first so that each channel receives the
    # same random draw as it would through a channel-reversed view.
    for channel in (-1, -2, -3):
        pixels[rows, cols, channel] = int(random.random() * 255)

    return Image.fromarray(pixels, 'RGB')
def random_flip(image, target, points_flip):
    """Mirror the image left-right, with its landmarks, half of the time.

    Args:
        image: Image exposing a PIL-style ``transpose`` method.
        target: Flat sequence ``[x0, y0, x1, y1, ...]`` of landmark
            coordinates; x is mirrored as ``1 - x``.
        points_flip: Index sequence used to reorder the landmarks after
            mirroring (row ``i`` of the result is landmark
            ``points_flip[i]`` of the input).

    Returns:
        ``(image, target)``: the mirrored image and a new flat landmark
        array when the flip fires, otherwise both inputs unchanged.
    """
    if random.random() <= 0.5:
        return image, target

    mirrored = image.transpose(Image.FLIP_LEFT_RIGHT)
    # Reorder first, then mirror x: a left-side landmark takes the slot of
    # its right-side counterpart.
    landmarks = np.array(target).reshape(-1, 2)[points_flip, :]
    landmarks[:, 0] = 1 - landmarks[:, 0]
    return mirrored, landmarks.flatten()
def random_rotate(image, target, angle_max):
    """Rotate the image and its landmarks by a random angle half the time.

    The angle is drawn uniformly from ``[-angle_max, angle_max]`` degrees
    and the rotation is performed about the image centre.

    Args:
        image: Image exposing PIL-style ``size`` (``(width, height)``) and
            ``rotate`` (angle in degrees).
        target: Flat sequence ``[x0, y0, x1, y1, ...]`` of landmark
            coordinates normalised by the image width/height.
        angle_max: Maximum absolute rotation angle, in degrees.

    Returns:
        ``(image, target)``: the rotated image and a new flat float array
        of rotated landmarks when the augmentation fires, otherwise both
        inputs unchanged. Rotated landmarks are not clipped to ``[0, 1]``.
    """
    if random.random() <= 0.5:
        return image, target

    theta_max = np.radians(angle_max)
    theta = random.uniform(-theta_max, theta_max)
    rotated = image.rotate(np.degrees(theta))

    # The image is rotated in pixel space, so the landmarks must be too.
    # Rotating the normalised coordinates directly is only correct when
    # width == height; scaling to pixels first handles any aspect ratio
    # and gives the same result for square images.
    scale = np.array(image.size, dtype=float)  # (width, height)
    offsets = (np.asarray(target, dtype=float).reshape(-1, 2) - 0.5) * scale

    c, s = np.cos(theta), np.sin(theta)
    # Row-vector form of the rotation; with the y axis pointing down this
    # follows the direction in which ``image.rotate`` turns the content.
    rot = np.array(((c, -s), (s, c)))
    rotated_points = np.matmul(offsets, rot) / scale + 0.5
    return rotated, rotated_points.flatten()
def gen_target_pip(target, meanface_indices, target_map, target_local_x, target_local_y, target_nb_x, target_nb_y):
    """Fill the per-landmark training maps in place and return them.

    For landmark ``i`` the grid cell is ``floor(coord * map_size)`` clamped
    to the map bounds. At that cell:

    * ``target_map[i]`` is set to 1;
    * ``target_local_x/y[i]`` receive the landmark's offset from the cell
      origin, in grid units;
    * for each neighbour ``j`` listed in ``meanface_indices[i]``, channel
      ``num_nb * i + j`` of ``target_nb_x/y`` receives that neighbour's
      offset from the same cell origin.

    Args:
        target: Array of landmark coordinates, reshaped to ``(-1, 2)`` as
            ``(x, y)`` pairs scaled so that 1.0 spans the whole map.
        meanface_indices: Per-landmark sequences of neighbour landmark
            indices; the length of the first entry is used for all.
        target_map: ``(num_landmarks, height, width)`` array.
        target_local_x, target_local_y: Arrays indexed like ``target_map``.
        target_nb_x, target_nb_y: Arrays with
            ``num_nb * num_landmarks`` channels.

    Returns:
        The five map arrays, in the order they were passed in.

    Raises:
        AssertionError: If ``target_map`` does not have one channel per
            landmark.
    """
    neighbours_per_lm = len(meanface_indices[0])
    n_channels, grid_h, grid_w = target_map.shape
    landmarks = target.reshape(-1, 2)
    assert n_channels == landmarks.shape[0]

    for lm_idx in range(n_channels):
        # Landmark position in grid units.
        grid_x = landmarks[lm_idx][0] * grid_w
        grid_y = landmarks[lm_idx][1] * grid_h

        # Cell containing the landmark, clamped into the map.
        col = min(max(0, int(floor(grid_x))), grid_w - 1)
        row = min(max(0, int(floor(grid_y))), grid_h - 1)

        target_map[lm_idx, row, col] = 1
        target_local_x[lm_idx, row, col] = grid_x - col
        target_local_y[lm_idx, row, col] = grid_y - row

        first_channel = neighbours_per_lm * lm_idx
        neighbours = meanface_indices[lm_idx]
        for slot in range(neighbours_per_lm):
            neighbour = landmarks[neighbours[slot]]
            target_nb_x[first_channel + slot, row, col] = neighbour[0] * grid_w - col
            target_nb_y[first_channel + slot, row, col] = neighbour[1] * grid_h - row

    return target_map, target_local_x, target_local_y, target_nb_x, target_nb_y
class ImageFolder_pip(data.Dataset):
    """Dataset yielding augmented images with their landmark target maps.

    Each item is loaded from ``root``, passed through the random
    translate / occlusion / flip / rotate (up to 30 degrees) / blur
    augmentations, and paired with five maps built by ``gen_target_pip``
    on a grid of ``input_size / net_stride`` cells per side.
    """

    def __init__(self, root, imgs, input_size, num_lms, net_stride, points_flip, meanface_indices, transform=None, target_transform=None):
        """Store the dataset configuration.

        Args:
            root: Directory that image names are joined onto.
            imgs: Indexable collection of ``(image_name, target)`` pairs.
            input_size: Numerator of the map side length.
            num_lms: Number of landmark channels in the maps.
            net_stride: Denominator of the map side length.
            points_flip: Landmark reordering passed to ``random_flip``.
            meanface_indices: Per-landmark neighbour indices; the length
                of the first entry is the neighbour count.
            transform: Optional callable applied to the image.
            target_transform: Optional callable applied to every map.
        """
        self.root = root
        self.imgs = imgs
        self.input_size = input_size
        self.num_lms = num_lms
        self.net_stride = net_stride
        self.points_flip = points_flip
        self.meanface_indices = meanface_indices
        self.num_nb = len(meanface_indices[0])
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(img, map, local_x, local_y, nb_x, nb_y)`` for one item.

        The five maps are float tensors; the image is whatever
        ``transform`` returns, or the augmented image when no transform
        is set.
        """
        img_name, target = self.imgs[index]
        img = Image.open(os.path.join(self.root, img_name)).convert('RGB')

        # Augmentations; the geometric ones update the landmarks as well.
        img, target = random_translate(img, target)
        img = random_occlusion(img)
        img, target = random_flip(img, target, self.points_flip)
        img, target = random_rotate(img, target, 30)
        img = random_blur(img)

        side = int(self.input_size / self.net_stride)
        landmark_shape = (self.num_lms, side, side)
        neighbour_shape = (self.num_nb * self.num_lms, side, side)
        maps = gen_target_pip(
            target,
            self.meanface_indices,
            np.zeros(landmark_shape),
            np.zeros(landmark_shape),
            np.zeros(landmark_shape),
            np.zeros(neighbour_shape),
            np.zeros(neighbour_shape),
        )
        tensors = [torch.from_numpy(array).float() for array in maps]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            tensors = [self.target_transform(tensor) for tensor in tensors]

        return (img, *tensors)

    def __len__(self):
        """Return the number of ``(image_name, target)`` entries."""
        return len(self.imgs)
# Running this module as a script intentionally does nothing; it only
# provides the augmentation helpers and the dataset class for import.
if __name__ == "__main__":
    pass