subatomicseer's picture
Initial Commit
7999e5a
import os
from PIL import Image, ImageFile
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from pathlib import Path
from glob import glob
def adaptive_instance_normalization(x, y, eps=1e-5):
"""
Adaptive Instance Normalization. Perform neural style transfer given content image x
and style image y.
Args:
x (torch.FloatTensor): Content image tensor
y (torch.FloatTensor): Style image tensor
eps (float, default=1e-5): Small value to avoid zero division
Return:
output (torch.FloatTensor): AdaIN style transferred output
"""
mu_x = torch.mean(x, dim=[2, 3])
mu_y = torch.mean(y, dim=[2, 3])
mu_x = mu_x.unsqueeze(-1).unsqueeze(-1)
mu_y = mu_y.unsqueeze(-1).unsqueeze(-1)
sigma_x = torch.std(x, dim=[2, 3])
sigma_y = torch.std(y, dim=[2, 3])
sigma_x = sigma_x.unsqueeze(-1).unsqueeze(-1) + eps
sigma_y = sigma_y.unsqueeze(-1).unsqueeze(-1) + eps
return (x - mu_x) / sigma_x * sigma_y + mu_y
def transform(size):
"""
Image preprocess transformation. Resize image and convert to tensor.
Args:
size (int): Resize image size
Return:
output (torchvision.transforms): Composition of torchvision.transforms steps
"""
t = []
t.append(transforms.Resize(size))
t.append(transforms.ToTensor())
t = transforms.Compose(t)
return t
def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
"""
Generate and save an image that contains row x col grids of images.
Args:
row (int): number of rows
col (int): number of columns
images (list of PIL image): list of images.
height (int) : height of each image (inch)
width (int) : width of eac image (inch)
save_pth (str): save file path
"""
width = col * width
height = row * height
plt.figure(figsize=(width, height))
for i, image in enumerate(images):
plt.subplot(row, col, i+1)
plt.imshow(image)
plt.axis('off')
plt.subplots_adjust(wspace=0.01, hspace=0.01)
plt.savefig(save_pth)
def linear_histogram_matching(content_tensor, style_tensor):
"""
Given content_tensor and style_tensor, transform style_tensor histogram to that of content_tensor.
Args:
content_tensor (torch.FloatTensor): Content image
style_tensor (torch.FloatTensor): Style Image
Return:
style_tensor (torch.FloatTensor): histogram matched Style Image
"""
#for batch
for b in range(len(content_tensor)):
std_ct = []
std_st = []
mean_ct = []
mean_st = []
#for channel
for c in range(len(content_tensor[b])):
std_ct.append(torch.var(content_tensor[b][c],unbiased = False))
mean_ct.append(torch.mean(content_tensor[b][c]))
std_st.append(torch.var(style_tensor[b][c],unbiased = False))
mean_st.append(torch.mean(style_tensor[b][c]))
style_tensor[b][c] = (style_tensor[b][c] - mean_st[c]) * std_ct[c] / std_st[c] + mean_ct[c]
return style_tensor
class TrainSet(Dataset):
"""
Build Training dataset
"""
def __init__(self, content_dir, style_dir, crop_size = 256):
super().__init__()
self.content_files = [Path(f) for f in glob(content_dir+'/*')]
self.style_files = [Path(f) for f in glob(style_dir+'/*')]
self.transform = transforms.Compose([
transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.RandomCrop(crop_size),
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
def __len__(self):
return min(len(self.style_files), len(self.content_files))
def __getitem__(self, index):
content_img = Image.open(self.content_files[index]).convert('RGB')
style_img = Image.open(self.style_files[index]).convert('RGB')
content_sample = self.transform(content_img)
style_sample = self.transform(style_img)
return content_sample, style_sample
class Range(object):
"""
Helper class for input argument range restriction
"""
def __init__(self, start, end):
self.start = start
self.end = end
def __eq__(self, other):
return self.start <= other <= self.end