|
""" |
|
Created on 2020/9/8 |
|
|
|
@author: Boyun Li |
|
""" |
|
import os |
|
import numpy as np |
|
import torch |
|
import random |
|
import torch.nn as nn |
|
from torch.nn import init |
|
from PIL import Image |
|
|
|
class EdgeComputation(nn.Module):
    """Compute a per-pixel edge-strength map from neighbor differences.

    Absolute horizontal and vertical neighbor differences are accumulated
    onto every pixel that touches them, then averaged over the (assumed 3)
    color channels and over the up-to-4 contributing differences.

    Args:
        test: if True, ``forward`` expects a batched 4-D input
            (N, C, H, W) and returns an (N, 1, H, W) map; otherwise it
            expects a 3-D (C, H, W) input and returns a (1, H, W) map.
    """

    def __init__(self, test=False):
        super(EdgeComputation, self).__init__()
        self.test = test

    def forward(self, x):
        if self.test:
            # Absolute differences between horizontal / vertical neighbors.
            x_diffx = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
            x_diffy = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])

            # new_zeros keeps the accumulator on the same device/dtype as
            # the input; the previous torch.Tensor(x.size()) was always a
            # CPU float32 tensor and crashed for CUDA or double inputs.
            y = x.new_zeros(x.size())
            y[:, :, :, 1:] += x_diffx
            y[:, :, :, :-1] += x_diffx
            y[:, :, 1:, :] += x_diffy
            y[:, :, :-1, :] += x_diffy
            # Average over the 3 channels, then over the 4 contributions.
            y = torch.sum(y, 1, keepdim=True) / 3
            y /= 4
            return y
        else:
            # Unbatched (C, H, W) variant of the same computation.
            x_diffx = torch.abs(x[:, :, 1:] - x[:, :, :-1])
            x_diffy = torch.abs(x[:, 1:, :] - x[:, :-1, :])

            y = x.new_zeros(x.size())
            y[:, :, 1:] += x_diffx
            y[:, :, :-1] += x_diffx
            y[:, 1:, :] += x_diffy
            y[:, :-1, :] += x_diffy
            y = torch.sum(y, 0) / 3
            y /= 4
            return y.unsqueeze(0)
|
|
|
|
|
|
|
def crop_patch(im, pch_size):
    """Return a random pch_size x pch_size crop of image `im` (H x W x ...)."""
    height, width = im.shape[0], im.shape[1]
    # Random top-left corner; randint bounds are inclusive, so the window
    # always fits inside the image.
    top = random.randint(0, height - pch_size)
    left = random.randint(0, width - pch_size)
    return im[top:top + pch_size, left:left + pch_size]
|
|
|
|
|
|
|
def crop_img(image, base=64):
    """Center-crop an H x W x C image so both spatial dims are multiples of `base`."""
    h, w = image.shape[0], image.shape[1]
    # Number of excess rows/cols to shave off, split between the two sides
    # (the extra pixel goes to the bottom/right when the excess is odd).
    extra_h = h % base
    extra_w = w % base
    top = extra_h // 2
    left = extra_w // 2
    return image[top:top + h - extra_h, left:left + w - extra_w, :]
|
|
|
|
|
|
|
def slice_image2patches(image, patch_size=64, overlap=0):
    """Cut an H x W x C image into a row-major grid of square patches.

    The image is edge-padded by `overlap` on every side, and each patch
    spans patch_size + overlap rows/cols of the padded image, so adjacent
    patches share `overlap` pixels. Returns an array of shape
    (num_patches, patch_size + overlap, patch_size + overlap, C).
    """
    assert image.shape[0] % patch_size == 0 and image.shape[1] % patch_size == 0
    n_rows = image.shape[0] // patch_size
    n_cols = image.shape[1] // patch_size
    padded = np.pad(image, ((overlap, overlap), (overlap, overlap), (0, 0)), mode='edge')
    patches = [
        padded[r * patch_size:(r + 1) * patch_size + overlap,
               c * patch_size:(c + 1) * patch_size + overlap, :][np.newaxis]
        for r in range(n_rows)
        for c in range(n_cols)
    ]
    return np.concatenate(patches, axis=0)
|
|
|
|
|
|
|
def splice_patches2image(patches, image_size, overlap=0):
    """Reassemble a row-major grid of square patches into one image.

    Inverse of the slicing step: the leading `overlap` rows/cols of each
    patch are discarded and the remaining patch_size x patch_size core is
    written into its grid cell of a zero-initialized image of `image_size`.
    """
    assert len(image_size) > 1
    assert patches.shape[-3] == patches.shape[-2]
    # Each patch carries `overlap` extra rows/cols on the leading side.
    patch_size = patches.shape[-2] - overlap
    n_rows = image_size[0] // patch_size
    n_cols = image_size[1] // patch_size
    image = np.zeros(image_size)
    for idx in range(n_rows * n_cols):
        r, c = divmod(idx, n_cols)
        core = patches[idx, overlap:patch_size + overlap, overlap:patch_size + overlap, :]
        image[r * patch_size:(r + 1) * patch_size, c * patch_size:(c + 1) * patch_size, :] = core
    return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def data_augmentation(image, mode):
    """Apply one of 8 deterministic flip/rotation transforms to an image.

    Modes enumerate the dihedral group of the square:
        0: identity            4: rotate 180
        1: flip up-down        5: rotate 180 + flip up-down
        2: rotate 90           6: rotate 270
        3: rotate 90 + flip    7: rotate 270 + flip up-down

    :param image: numpy array (a torch tensor is also accepted for mode 0,
        preserving the original tensor-to-array conversion there).
    :param mode: integer in [0, 7] selecting the transform.
    :return: transformed numpy array (may be a view of the input).
    :raises Exception: if mode is outside [0, 7].
    """
    if mode == 0:
        # Identity. The original unconditionally called image.numpy(),
        # which crashed for numpy inputs even though every other mode
        # operates on numpy arrays; convert only when given a tensor.
        out = image.numpy() if torch.is_tensor(image) else image
    elif mode == 1:
        out = np.flipud(image)
    elif mode == 2:
        out = np.rot90(image)
    elif mode == 3:
        out = np.flipud(np.rot90(image))
    elif mode == 4:
        out = np.rot90(image, k=2)
    elif mode == 5:
        out = np.flipud(np.rot90(image, k=2))
    elif mode == 6:
        out = np.rot90(image, k=3)
    elif mode == 7:
        out = np.flipud(np.rot90(image, k=3))
    else:
        raise Exception('Invalid choice of image transformation')
    return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def random_augmentation(*args):
    """Apply the same random non-identity augmentation (modes 1-7) to every input.

    Returns a list of augmented copies, one per positional argument.
    """
    # One flag for all inputs so paired data (e.g. image/label) stays aligned.
    flag_aug = random.randint(1, 7)
    return [data_augmentation(item, flag_aug).copy() for item in args]
|
|
|
|
|
def weights_init_normal_(m):
    """Initialize a module's weights from N(0, 0.02); BatchNorm from N(1, 0.02).

    Meant to be passed to ``Module.apply``. Only Conv*, Linear and
    BatchNorm2d modules are touched; only BatchNorm2d biases are zeroed.

    Fixes: the original used the long-deprecated (now removed)
    ``init.uniform``/``init.constant`` aliases, and for BatchNorm passed
    bounds (1.0, 0.02) with low > high, which ``uniform_`` rejects. The
    (mean, std) argument pattern and the function name indicate a normal
    distribution was intended.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
|
|
|
|
|
def weights_init_normal(m):
    """Per-module hook for ``net.apply``: normal-distribution initialization.

    Conv2d modules are (re-)initialized recursively via
    ``weights_init_normal_``; Linear weights from N(0, 0.02); BatchNorm2d
    weights from N(1, 0.02) with biases zeroed.

    Fixes: replaced the deprecated (removed) ``init.uniform``/
    ``init.constant`` aliases, and the invalid BatchNorm uniform bounds
    (low=1.0 > high=0.02), with the in-place normal/constant initializers
    that match the function's name and (mean, std) arguments.
    """
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        # Delegate so the conv module (and any children) get the shared scheme.
        m.apply(weights_init_normal_)
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
|
|
|
|
|
def weights_init_xavier(m):
    """Per-module hook for ``net.apply``: Xavier (Glorot) initialization.

    Conv*/Linear weights get xavier_normal_; BatchNorm2d weights are drawn
    from N(1, 0.02) and biases zeroed.

    Fixes: replaced the deprecated (removed) ``init.xavier_normal``/
    ``init.uniform``/``init.constant`` aliases with their in-place
    underscore counterparts; the BatchNorm uniform bounds (low=1.0 >
    high=0.02) were invalid, so the intended N(1, 0.02) is used instead.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
|
|
|
|
|
def weights_init_kaiming(m):
    """Per-module hook for ``net.apply``: Kaiming (He) initialization.

    Conv*/Linear weights get kaiming_normal_ (fan_in, a=0); BatchNorm2d
    weights are drawn from N(1, 0.02) and biases zeroed.

    Fixes: replaced the deprecated (removed) ``init.kaiming_normal``/
    ``init.uniform``/``init.constant`` aliases with their in-place
    underscore counterparts; the BatchNorm uniform bounds (low=1.0 >
    high=0.02) were invalid, so the intended N(1, 0.02) is used instead.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
|
|
|
|
|
def weights_init_orthogonal(m):
    """Per-module hook for ``net.apply``: orthogonal initialization.

    Conv*/Linear weights get orthogonal_; BatchNorm2d weights are drawn
    from N(1, 0.02) and biases zeroed.

    Fixes: replaced the deprecated (removed) ``init.orthogonal``/
    ``init.uniform``/``init.constant`` aliases with their in-place
    underscore counterparts; the invalid BatchNorm uniform bounds
    (low=1.0 > high=0.02) became the intended N(1, 0.02); dropped a
    leftover debug ``print(classname)`` that fired for every module.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
|
|
|
|
|
def init_weights(net, init_type='normal'):
    """Initialize every module of `net` with the chosen scheme.

    Supported schemes: 'normal', 'xavier', 'kaiming', 'orthogonal'.
    Raises NotImplementedError for anything else.
    """
    print('initialization method [%s]' % init_type)
    # Dispatch table instead of an if/elif chain.
    initializers = {
        'normal': weights_init_normal,
        'xavier': weights_init_xavier,
        'kaiming': weights_init_kaiming,
        'orthogonal': weights_init_orthogonal,
    }
    if init_type not in initializers:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    net.apply(initializers[init_type])
|
|
|
|
|
def np_to_torch(img_np):
    """Convert a numpy image to a batched torch.Tensor.

    From C x W x H [0..1] to 1 x C x W x H [0..1]; the tensor shares
    memory with the input array.

    :param img_np: numpy image array.
    :return: torch.Tensor with a leading batch dimension of 1.
    """
    tensor = torch.from_numpy(img_np)
    return tensor.unsqueeze(0)
|
|
|
|
|
def torch_to_np(img_var):
    """Convert an image in torch.Tensor format to np.array.

    The tensor is detached from the graph and moved to CPU; the shape is
    preserved as-is (a 1 x C x W x H input yields a 1 x C x W x H array —
    no batch dimension is dropped here).

    :param img_var: torch.Tensor image.
    :return: numpy array with the same shape and values.
    """
    detached = img_var.detach()
    return detached.cpu().numpy()
|
|
|
|
|
|
|
def save_image(name, image_np, output_path="output/normal/"):
    """Save a C x W x H [0..1] numpy image as <output_path>/<name>.png.

    :param name: file stem (extension .png is appended).
    :param image_np: numpy image in the format expected by np_to_pil.
    :param output_path: destination directory; created if missing.
    """
    # makedirs(..., exist_ok=True) creates intermediate directories and is
    # race-free, unlike the exists()+mkdir() pair it replaces (which failed
    # whenever output_path's parent did not exist).
    os.makedirs(output_path, exist_ok=True)

    p = np_to_pil(image_np)
    # os.path.join keeps the path correct even when output_path has no
    # trailing slash (plain concatenation produced a mangled filename).
    p.save(os.path.join(output_path, "{}.png".format(name)))
|
|
|
|
|
def np_to_pil(img_np):
    """Converts image in np.array format to PIL image.

    From C x W x H [0..1] to W x H x C [0...255]. Single-channel inputs
    become 2-D grayscale images; otherwise exactly 3 channels are required.

    :param img_np: float numpy image, channels first, values in [0, 1].
    :return: PIL.Image built from the uint8-scaled data.
    """
    scaled = np.clip(img_np * 255, 0, 255).astype(np.uint8)

    channels = img_np.shape[0]
    if channels == 1:
        # Drop the singleton channel axis for grayscale output.
        scaled = scaled[0]
    else:
        assert channels == 3, img_np.shape
        # Channels-first -> channels-last, as PIL expects.
        scaled = scaled.transpose(1, 2, 0)

    return Image.fromarray(scaled)