from glob import glob
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageFile
from torch.utils.data import Dataset

def adaptive_instance_normalization(x, y, eps=1e-5):
	"""
	Adaptive Instance Normalization. Perform neural style transfer given content image x
	and style image y.

	Args:
		x (torch.FloatTensor): Content image tensor
		y (torch.FloatTensor): Style image tensor
		eps (float, default=1e-5): Small value to avoid zero division

	Return:
		output (torch.FloatTensor): AdaIN style transferred output
	"""

	mu_x = torch.mean(x, dim=[2, 3])
	mu_y = torch.mean(y, dim=[2, 3])
	mu_x = mu_x.unsqueeze(-1).unsqueeze(-1)
	mu_y = mu_y.unsqueeze(-1).unsqueeze(-1)

	sigma_x = torch.std(x, dim=[2, 3])
	sigma_y = torch.std(y, dim=[2, 3])
	sigma_x = sigma_x.unsqueeze(-1).unsqueeze(-1) + eps
	sigma_y = sigma_y.unsqueeze(-1).unsqueeze(-1) + eps

	return (x - mu_x) / sigma_x * sigma_y + mu_y

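# Usage sketch (illustrative only): AdaIN is normally applied to encoder
# feature maps rather than raw pixels; `encoder` and `decoder` below are
# hypothetical stand-ins for a VGG-style encoder/decoder pair.
#
#	content_feat = encoder(content_img)  # (N, C, H, W) feature maps
#	style_feat = encoder(style_img)      # same shape as content_feat
#	t = adaptive_instance_normalization(content_feat, style_feat)
#	stylized = decoder(t)
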
def transform(size):
	"""
	Image preprocessing transformation. Resize the image and convert it to a tensor.

	Args:
		size (int): Target size for the shorter edge of the image

	Return:
		output (transforms.Compose): Composition of torchvision.transforms steps
	"""

	return transforms.Compose([
		transforms.Resize(size),
		transforms.ToTensor(),
	])

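# Usage sketch: build the pipeline once, then apply it to a PIL image.
# The file name is a placeholder.
#
#	prep = transform(512)
#	img_tensor = prep(Image.open('content.jpg').convert('RGB'))  # (C, H, W)
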
def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
	"""
	Generate and save an image that contains a row x col grid of images.

	Args:
		row (int): Number of rows
		col (int): Number of columns
		images (list of PIL.Image): List of images
		height (int): Height of each image (inches)
		width (int): Width of each image (inches)
		save_pth (str): Save file path
	"""

	width = col * width
	height = row * height
	plt.figure(figsize=(width, height))
	for i, image in enumerate(images):
		plt.subplot(row, col, i + 1)
		plt.imshow(image)
		plt.axis('off')
	plt.subplots_adjust(wspace=0.01, hspace=0.01)
	plt.savefig(save_pth)
	plt.close()

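# Usage sketch: save a 2 x 3 grid from six PIL images; the file names
# are placeholders.
#
#	images = [Image.open(f'out_{i}.png') for i in range(6)]
#	grid_image(2, 3, images, save_pth='comparison.png')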

def linear_histogram_matching(content_tensor, style_tensor):
	"""
	Given content_tensor and style_tensor, match the per-channel mean and
	standard deviation of style_tensor to those of content_tensor.
	Note: style_tensor is modified in place.

	Args:
		content_tensor (torch.FloatTensor): Content image
		style_tensor (torch.FloatTensor): Style image

	Return:
		style_tensor (torch.FloatTensor): Histogram-matched style image
	"""

	# Iterate over the batch dimension
	for b in range(len(content_tensor)):
		# Iterate over the channel dimension
		for c in range(len(content_tensor[b])):
			std_ct = torch.std(content_tensor[b][c], unbiased=False)
			mean_ct = torch.mean(content_tensor[b][c])
			std_st = torch.std(style_tensor[b][c], unbiased=False)
			mean_st = torch.mean(style_tensor[b][c])
			# Shift and rescale the style channel to the content statistics
			style_tensor[b][c] = (style_tensor[b][c] - mean_st) * std_ct / std_st + mean_ct
	return style_tensor

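# Usage sketch: the matching writes into style_tensor, so pass a clone
# when the original batch must be preserved.
#
#	matched_style = linear_histogram_matching(content_batch, style_batch.clone())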

class TrainSet(Dataset):
	"""
	Build Training dataset
	"""
	def __init__(self, content_dir, style_dir, crop_size=256):
		super().__init__()

		# Sort the file lists so content/style pairing by index is deterministic
		self.content_files = [Path(f) for f in sorted(glob(content_dir + '/*'))]
		self.style_files = [Path(f) for f in sorted(glob(style_dir + '/*'))]

		self.transform = transforms.Compose([
			transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
			transforms.RandomCrop(crop_size),
			transforms.ToTensor(),
			transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
			])

		# Tolerate very large and truncated image files when loading datasets
		Image.MAX_IMAGE_PIXELS = None
		ImageFile.LOAD_TRUNCATED_IMAGES = True
	
	def __len__(self):
		return min(len(self.style_files), len(self.content_files))

	def __getitem__(self, index):
		content_img = Image.open(self.content_files[index]).convert('RGB')
		style_img = Image.open(self.style_files[index]).convert('RGB')
	
		content_sample = self.transform(content_img)
		style_sample = self.transform(style_img)

		return content_sample, style_sample

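# Usage sketch: wrap the dataset in a DataLoader for training; the
# directory paths and batch size are placeholders.
#
#	from torch.utils.data import DataLoader
#	dataset = TrainSet('data/content', 'data/style', crop_size=256)
#	loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
#	for content_batch, style_batch in loader:
#		...  # one training step per batch
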
class Range(object):
	"""
	Helper class for restricting a numeric input argument to a range
	(e.g. via argparse `choices`)
	"""
	def __init__(self, start, end):
		self.start = start
		self.end = end
	def __eq__(self, other):
		return self.start <= other <= self.end
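
# Usage sketch: argparse checks membership in `choices` with `==`, which
# dispatches to Range.__eq__; the argument name is a placeholder.
#
#	import argparse
#	parser = argparse.ArgumentParser()
#	parser.add_argument('--alpha', type=float, choices=[Range(0.0, 1.0)])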