ICDR / utils /image_utils.py
Siwon123's picture
q
7f43945
raw
history blame
9.36 kB
"""
Created on 2020/9/8
@author: Boyun Li
"""
import os
import numpy as np
import torch
import random
import torch.nn as nn
from torch.nn import init
from PIL import Image
class EdgeComputation(nn.Module):
    """
    Builds an edge-strength map from absolute differences between
    neighbouring elements along the last two axes of the input.

    :param test: selects the input layout handled by ``forward``. When True
        the input is indexed with four axes; when False with three.
    """

    def __init__(self, test=False):
        super(EdgeComputation, self).__init__()
        self.test = test

    def forward(self, x):
        """
        Accumulate neighbour differences and reduce them over the channel axis.

        :param x: rank-4 tensor when ``self.test`` is True, rank-3 otherwise
            (presumably B x C x H x W and C x H x W -- confirm against callers).
        :return: ``test=True``: the sum over axis 1 (kept as a size-1 axis),
            divided by 3 and then by 4. ``test=False``: the sum over axis 0,
            divided by 3 and then by 4, with a new leading axis prepended.
        """
        # zeros_like keeps the accumulator on the same device and dtype as x.
        # A bare torch.Tensor(size) uses the default tensor type instead, which
        # fails for CUDA inputs and silently downcasts float64 inputs.
        y = torch.zeros_like(x)
        if self.test:
            x_diffx = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
            x_diffy = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
            # Each difference is credited to both elements it was taken between.
            y[:, :, :, 1:] += x_diffx
            y[:, :, :, :-1] += x_diffx
            y[:, :, 1:, :] += x_diffy
            y[:, :, :-1, :] += x_diffy
            # NOTE(review): the divisor 3 looks like a 3-channel average and the
            # divisor 4 like the number of accumulated neighbours -- confirm.
            y = torch.sum(y, 1, keepdim=True) / 3
            y /= 4
            return y

        x_diffx = torch.abs(x[:, :, 1:] - x[:, :, :-1])
        x_diffy = torch.abs(x[:, 1:, :] - x[:, :-1, :])
        y[:, :, 1:] += x_diffx
        y[:, :, :-1] += x_diffx
        y[:, 1:, :] += x_diffy
        y[:, :-1, :] += x_diffy
        y = torch.sum(y, 0) / 3
        y /= 4
        return y.unsqueeze(0)
# randomly crop a patch from image
def crop_patch(im, pch_size):
    """
    Cut a square window of side ``pch_size`` out of ``im`` at a random position.

    The top-left corner is drawn with ``random.randint`` (row offset first,
    then column offset), so results follow the global ``random`` seed.

    :param im: array indexed as [row, column, ...].
    :param pch_size: side length of the window.
    :return: the slice ``im[top:top + pch_size, left:left + pch_size]``.
    """
    height, width = im.shape[0], im.shape[1]
    top = random.randint(0, height - pch_size)
    left = random.randint(0, width - pch_size)
    return im[top:top + pch_size, left:left + pch_size]
# crop an image to the multiple of base
def crop_img(image, base=64):
    """
    Centre-crop ``image`` so its first two dimensions are multiples of ``base``.

    The surplus ``size % base`` is removed from each of the first two axes,
    with ``surplus // 2`` taken from the leading edge and the rest from the
    trailing edge. Only the first two axes are sliced, so both H x W and
    H x W x C arrays are accepted.

    :param image: array indexed as [row, column, ...].
    :param base: the multiple to crop down to.
    :return: the cropped slice of ``image``.
    """
    h, w = image.shape[0], image.shape[1]
    crop_h = h % base
    crop_w = w % base
    top = crop_h // 2
    left = crop_w // 2
    # No trailing ", :" index: it is redundant for 3-D input and would reject
    # 2-D (single-channel) arrays.
    return image[top:h - crop_h + top, left:w - crop_w + left]
# image (H, W, C) -> patches (B, H, W, C)
def slice_image2patches(image, patch_size=64, overlap=0):
    """
    Cut an H x W x C image into a row-major batch of square tiles.

    The image is edge-padded by ``overlap`` on every side of the first two
    axes. Each tile then spans ``patch_size + overlap`` elements per axis: its
    own ``patch_size`` cell plus ``overlap`` elements of context before it
    (above / to the left).

    :param image: 3-axis array whose first two dimensions are multiples of
        ``patch_size``.
    :param patch_size: side length of a grid cell.
    :param overlap: number of leading context elements added to each tile.
    :return: array of shape (tiles, patch_size + overlap, patch_size + overlap, C).
    """
    height, width = image.shape[0], image.shape[1]
    assert height % patch_size == 0 and width % patch_size == 0
    border = (overlap, overlap)
    padded = np.pad(image, (border, border, (0, 0)), mode='edge')
    span = patch_size + overlap
    tiles = [
        padded[top:top + span, left:left + span, :]
        for top in range(0, (height // patch_size) * patch_size, patch_size)
        for left in range(0, (width // patch_size) * patch_size, patch_size)
    ]
    return np.stack(tiles, axis=0)
# patches (B, H, W, C) -> image (H, W, C)
def splice_patches2image(patches, image_size, overlap=0):
    """
    Reassemble a row-major batch of square tiles into a single image.

    The first ``overlap`` rows and columns of every tile are discarded; the
    remaining core is written into its grid cell of a zero-initialised canvas.

    :param patches: array indexed as [tile, row, column, channel] with square
        tiles.
    :param image_size: shape of the output image (at least two entries; the
        assignment below indexes three axes).
    :param overlap: number of leading rows/columns to drop from each tile.
    :return: ``np.zeros(image_size)`` filled with the tile cores.
    """
    assert len(image_size) > 1
    assert patches.shape[-3] == patches.shape[-2]
    core = patches.shape[-2] - overlap
    canvas = np.zeros(image_size)
    n_rows = image_size[0] // core
    n_cols = image_size[1] // core
    interior = slice(overlap, core + overlap)
    for row in range(n_rows):
        for col in range(n_cols):
            # Tiles are stored row by row, so the flat index is row * n_cols + col.
            tile = patches[row * n_cols + col, interior, interior, :]
            canvas[row * core:(row + 1) * core, col * core:(col + 1) * core, :] = tile
    return canvas
# def data_augmentation(image, mode):
# if mode == 0:
# # original
# out = image.numpy()
# elif mode == 1:
# # flip up and down
# out = np.flipud(image)
# elif mode == 2:
# # rotate counterwise 90 degree
# out = np.rot90(image, axes=(1, 2))
# elif mode == 3:
# # rotate 90 degree and flip up and down
# out = np.rot90(image, axes=(1, 2))
# out = np.flipud(out)
# elif mode == 4:
# # rotate 180 degree
# out = np.rot90(image, k=2, axes=(1, 2))
# elif mode == 5:
# # rotate 180 degree and flip
# out = np.rot90(image, k=2, axes=(1, 2))
# out = np.flipud(out)
# elif mode == 6:
# # rotate 270 degree
# out = np.rot90(image, k=3, axes=(1, 2))
# elif mode == 7:
# # rotate 270 degree and flip
# out = np.rot90(image, k=3, axes=(1, 2))
# out = np.flipud(out)
# else:
# raise Exception('Invalid choice of image transformation')
# return out
def data_augmentation(image, mode):
    """
    Apply one of eight rotate/flip transformations to ``image``.

    ``mode // 2`` is the number of counter-clockwise quarter turns applied
    with ``np.rot90`` and ``mode % 2`` tells whether the result is then
    flipped up/down with ``np.flipud``:

        0 original          1 flip
        2 rot 90            3 rot 90 + flip
        4 rot 180           5 rot 180 + flip
        6 rot 270           7 rot 270 + flip

    :param image: NumPy array, or an object exposing ``.numpy()`` (e.g. a
        ``torch.Tensor``).
    :param mode: transformation id, 0..7.
    :return: NumPy array (possibly a view of the input, not a copy).
    :raises Exception: if ``mode`` is not one of 0..7.
    """
    if mode not in range(8):
        raise Exception('Invalid choice of image transformation')

    # Mode 0 used to call image.numpy() unconditionally, which failed for the
    # ndarray inputs every other mode works with. Convert once, for all modes.
    array = image.numpy() if hasattr(image, 'numpy') else np.asanyarray(image)

    quarter_turns, flip = divmod(int(mode), 2)
    out = np.rot90(array, k=quarter_turns) if quarter_turns else array
    if flip:
        out = np.flipud(out)
    return out
# def random_augmentation(*args):
# out = []
# if random.randint(0, 1) == 1:
# flag_aug = random.randint(1, 7)
# for data in args:
# out.append(data_augmentation(data, flag_aug).copy())
# else:
# for data in args:
# out.append(data)
# return out
def random_augmentation(*args):
    """
    Apply the same randomly chosen transformation to every argument.

    One mode in 1..7 is drawn with ``random.randint`` and passed to
    ``data_augmentation`` for each input; each result is copied.

    :param args: the inputs to transform together.
    :return: list with one transformed copy per input, in order.
    """
    flag_aug = random.randint(1, 7)
    return [data_augmentation(item, flag_aug).copy() for item in args]
def weights_init_normal_(m):
    """
    Initialise one module in place, chosen by substring match on its class name.

    * name contains ``Conv`` or ``Linear``: weight drawn from U(0.0, 0.02)
    * name contains ``BatchNorm2d``: weight drawn from N(1.0, 0.02), bias set to 0

    :param m: the module to initialise; other modules are left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        # NOTE(review): despite the function name this draws from a uniform
        # distribution, so every weight is non-negative. Kept as-is to avoid
        # changing training behaviour -- confirm whether normal_ was intended.
        init.uniform_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm2d' in classname:
        # (1.0, 0.02) is a (mean, std) pair. As uniform bounds the lower bound
        # exceeds the upper one, which PyTorch rejects.
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
def weights_init_normal(m):
    """
    Initialise one module in place, chosen by substring match on its class name.

    * name contains ``Conv2d``: ``weights_init_normal_`` is applied to the
      module and its children via ``m.apply``
    * name contains ``Linear``: weight drawn from U(0.0, 0.02)
    * name contains ``BatchNorm2d``: weight drawn from N(1.0, 0.02), bias set to 0

    :param m: the module to initialise; other modules are left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv2d' in classname:
        m.apply(weights_init_normal_)
    elif 'Linear' in classname:
        # NOTE(review): uniform rather than normal despite the function name;
        # kept as-is to avoid changing training behaviour -- confirm intent.
        init.uniform_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm2d' in classname:
        # (1.0, 0.02) is a (mean, std) pair. As uniform bounds the lower bound
        # exceeds the upper one, which PyTorch rejects.
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
def weights_init_xavier(m):
    """
    Initialise one module in place, chosen by substring match on its class name.

    * name contains ``Conv`` or ``Linear``: Xavier-normal weight with gain 1
    * name contains ``BatchNorm2d``: weight drawn from N(1.0, 0.02), bias set to 0

    :param m: the module to initialise; other modules are left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.xavier_normal_(m.weight.data, gain=1)
    elif 'BatchNorm2d' in classname:
        # (1.0, 0.02) is a (mean, std) pair. As uniform bounds the lower bound
        # exceeds the upper one, which PyTorch rejects.
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
def weights_init_kaiming(m):
    """
    Initialise one module in place, chosen by substring match on its class name.

    * name contains ``Conv`` or ``Linear``: Kaiming-normal weight
      (``a=0``, ``mode='fan_in'``)
    * name contains ``BatchNorm2d``: weight drawn from N(1.0, 0.02), bias set to 0

    :param m: the module to initialise; other modules are left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif 'BatchNorm2d' in classname:
        # (1.0, 0.02) is a (mean, std) pair. As uniform bounds the lower bound
        # exceeds the upper one, which PyTorch rejects.
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
def weights_init_orthogonal(m):
    """
    Initialise one module in place, chosen by substring match on its class name.

    * name contains ``Conv`` or ``Linear``: orthogonal weight with gain 1
    * name contains ``BatchNorm2d``: weight drawn from N(1.0, 0.02), bias set to 0

    The class name of every visited module is printed.

    :param m: the module to initialise; other modules are left untouched.
    """
    classname = m.__class__.__name__
    print(classname)
    if 'Conv' in classname or 'Linear' in classname:
        init.orthogonal_(m.weight.data, gain=1)
    elif 'BatchNorm2d' in classname:
        # (1.0, 0.02) is a (mean, std) pair. As uniform bounds the lower bound
        # exceeds the upper one, which PyTorch rejects.
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
def init_weights(net, init_type='normal'):
    """
    Apply the initialiser named by ``init_type`` to every module of ``net``.

    The chosen method name is printed first. Supported names are 'normal',
    'xavier', 'kaiming' and 'orthogonal'.

    :param net: object exposing ``apply(fn)`` (e.g. an ``nn.Module``).
    :param init_type: name of the initialisation scheme.
    :raises NotImplementedError: for any other ``init_type``.
    """
    print('initialization method [%s]' % init_type)
    if init_type not in ('normal', 'xavier', 'kaiming', 'orthogonal'):
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

    if init_type == 'normal':
        initializer = weights_init_normal
    elif init_type == 'xavier':
        initializer = weights_init_xavier
    elif init_type == 'kaiming':
        initializer = weights_init_kaiming
    else:
        initializer = weights_init_orthogonal
    net.apply(initializer)
def np_to_torch(img_np):
    """
    Wraps a numpy.array as a torch.Tensor and prepends a size-1 leading axis.
    E.g. from C x W x H to 1 x C x W x H; values are not rescaled.
    The tensor comes from torch.from_numpy, so it shares memory with img_np.
    :param img_np: the array to wrap
    :return: tensor with one extra leading axis
    """
    tensor = torch.from_numpy(img_np)
    return tensor[None, :]
def torch_to_np(img_var):
    """
    Converts a torch.Tensor to np.array.
    The tensor is detached from the autograd graph and moved to the CPU first.
    The shape is unchanged: no leading axis is removed.
    :param img_var: the tensor to convert
    :return: array with the same shape as img_var
    """
    detached = img_var.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()
# return img_var.detach().cpu().numpy()[0]
def save_image(name, image_np, output_path="output/normal/"):
    """
    Save an image array as ``<output_path>/<name>.png``.

    The array is converted with ``np_to_pil`` and the output directory is
    created when missing.

    :param name: file name without extension.
    :param image_np: image array accepted by ``np_to_pil``.
    :param output_path: destination directory.
    """
    # makedirs creates missing parents too (the default path is nested, which
    # os.mkdir cannot create in one call); exist_ok avoids the
    # check-then-create race.
    os.makedirs(output_path, exist_ok=True)
    p = np_to_pil(image_np)
    # join keeps the file inside output_path even without a trailing separator.
    p.save(os.path.join(output_path, "{}.png".format(name)))
def np_to_pil(img_np):
    """
    Converts an image in np.array format to a PIL image.
    Values are multiplied by 255, clipped to [0, 255] and cast to uint8.
    A leading axis of size 1 is dropped; a leading axis of size 3 is moved
    to the end (the array is transposed with axes (1, 2, 0)).
    :param img_np: array whose first axis has size 1 or 3
    :return: PIL image built with Image.fromarray
    """
    scaled = np.clip(img_np * 255, 0, 255).astype(np.uint8)
    channels = img_np.shape[0]
    if channels == 1:
        return Image.fromarray(scaled[0])
    assert channels == 3, img_np.shape
    return Image.fromarray(scaled.transpose(1, 2, 0))