Commit 7999e5a
Parent(s): fb09963
Initial Commit

Files changed:
- .gitignore +3 -0
- AdaIN.py +73 -0
- Network.py +73 -0
- README.md +4 -4
- app.py +95 -0
- examples/content/brad_pitt.jpg +0 -0
- examples/img.png +0 -0
- examples/style/flower_of_life.jpg +0 -0
- examples/style/sketch.jpg +0 -0
- infer_func.py +60 -0
- requirements.txt +11 -0
- utils.py +144 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
+.idea/
+__pycache__/
+*.pth
AdaIN.py
ADDED
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+from Network import vgg19, decoder
+from utils import adaptive_instance_normalization
+
+class AdaINNet(nn.Module):
+    """
+    AdaIN Style Transfer Network
+
+    Args:
+        vgg_weight: pretrained vgg19 weight
+    """
+    def __init__(self, vgg_weight):
+        super().__init__()
+        self.encoder = vgg19(vgg_weight)
+
+        # drop layers after relu4_1
+        self.encoder = nn.Sequential(*list(self.encoder.children())[:22])
+
+        # No optimization for encoder
+        for parameter in self.encoder.parameters():
+            parameter.requires_grad = False
+
+        self.decoder = decoder()
+
+        self.mseloss = nn.MSELoss()
+
+    def _style_loss(self, x, y):
+        """
+        Computes style loss of two images
+
+        Args:
+            x (torch.FloatTensor): content image tensor
+            y (torch.FloatTensor): style image tensor
+
+        Return:
+            Mean Squared Error between x.mean, y.mean and MSE between x.std, y.std
+        """
+        return self.mseloss(torch.mean(x, dim=[2, 3]), torch.mean(y, dim=[2, 3])) + \
+            self.mseloss(torch.std(x, dim=[2, 3]), torch.std(y, dim=[2, 3]))
+
+    def forward(self, content, style, alpha=1.0):
+        # Generate image features
+        content_enc = self.encoder(content)
+        style_enc = self.encoder(style)
+
+        # Perform style transfer on feature space
+        transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
+
+        # Generate output image
+        out = self.decoder(transfer_enc)
+
+        # vgg19 layer relu1_1
+        style_relu11 = self.encoder[:3](style)
+        out_relu11 = self.encoder[:3](out)
+
+        # vgg19 layer relu2_1
+        style_relu21 = self.encoder[3:8](style_relu11)
+        out_relu21 = self.encoder[3:8](out_relu11)
+
+        # vgg19 layer relu3_1
+        style_relu31 = self.encoder[8:13](style_relu21)
+        out_relu31 = self.encoder[8:13](out_relu21)
+
+        # vgg19 layer relu4_1
+        out_enc = self.encoder[13:](out_relu31)
+
+        # Calculate loss
+        content_loss = self.mseloss(out_enc, transfer_enc)
+        style_loss = self._style_loss(out_relu11, style_relu11) + self._style_loss(out_relu21, style_relu21) + \
+            self._style_loss(out_relu31, style_relu31) + self._style_loss(out_enc, style_enc)
+
+        return content_loss, style_loss
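Note: AdaINNet.forward returns the content and style losses separately, leaving the weighting to a training script that is not part of this commit. A minimal training-step sketch, with a style-loss weight and optimizer that are assumptions rather than values from this repo:

    # Hypothetical training step; style_weight and the optimizer settings are assumptions.
    import torch
    from AdaIN import AdaINNet

    model = AdaINNet(torch.load('vgg.pth'))  # encoder weights, downloaded separately (see app.py)
    optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-4)  # only the decoder is trainable
    style_weight = 10.0  # assumed weighting of the style loss

    content = torch.randn(4, 3, 256, 256)  # stand-in content batch
    style = torch.randn(4, 3, 256, 256)    # stand-in style batch

    content_loss, style_loss = model(content, style)
    loss = content_loss + style_weight * style_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()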
Network.py
ADDED
@@ -0,0 +1,73 @@
+import torch.nn as nn
+
+vgg19_cfg = [3, 64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]
+decoder_cfg = [512, 256, "U", 256, 256, 256, 128, "U", 128, 64, "U", 64, 3]
+
+def vgg19(weights=None):
+    """
+    Build vgg19 network. Load weights if weights are given.
+
+    Args:
+        weights (dict): vgg19 pretrained weights
+
+    Return:
+        layers (nn.Sequential): vgg19 layers
+    """
+
+    modules = make_block(vgg19_cfg)
+    modules = [nn.Conv2d(3, 3, kernel_size=1)] + list(modules.children())
+    layers = nn.Sequential(*modules)
+
+    if weights:
+        layers.load_state_dict(weights)
+
+    return layers
+
+
+def decoder(weights=None):
+    """
+    Build decoder network. Load weights if weights are given.
+
+    Args:
+        weights (dict): decoder pretrained weights
+
+    Return:
+        layers (nn.Sequential): decoder layers
+    """
+
+    modules = make_block(decoder_cfg)
+    layers = nn.Sequential(*list(modules.children())[:-1])  # no relu at the last layer
+
+    if weights:
+        layers.load_state_dict(weights)
+
+    return layers
+
+
+def make_block(config):
+    """
+    Helper function for building blocks of convolutional layers.
+
+    Args:
+        config (list): List of layer configs. The first entry is the input channel count; then
+            "M" - Max pooling layer.
+            "U" - Upsampling layer.
+            i (int) - Convolutional layer (i filters) plus ReLU activation.
+    Return:
+        layers (nn.Sequential): block layers
+    """
+    layers = []
+    in_channels = config[0]
+
+    for c in config[1:]:
+        if c == "M":
+            layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
+        elif c == "U":
+            layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
+        else:
+            assert isinstance(c, int)
+            layers.append(nn.Conv2d(in_channels, c, kernel_size=3, padding=1))
+            layers.append(nn.ReLU(inplace=True))
+            in_channels = c
+
+    return nn.Sequential(*layers)
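Note: the [:22] slice used in AdaIN.py keeps the encoder only up to relu4_1, so features come out at 1/8 of the input resolution with 512 channels, and the decoder's three upsampling stages mirror them back to image size. A quick shape check on randomly initialized networks (not a result from this commit):

    import torch
    from Network import vgg19, decoder

    enc = torch.nn.Sequential(*list(vgg19().children())[:22])  # truncate at relu4_1, as AdaIN.py does
    dec = decoder()

    feat = enc(torch.randn(1, 3, 256, 256))
    print(feat.shape)       # expected: torch.Size([1, 512, 32, 32])
    print(dec(feat).shape)  # expected: torch.Size([1, 3, 256, 256])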
README.md
CHANGED
@@ -1,8 +1,8 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: AdaIN
+emoji: 📚
+colorFrom: red
+colorTo: indigo
 sdk: streamlit
 sdk_version: 1.2.0
 app_file: app.py
app.py
ADDED
@@ -0,0 +1,95 @@
+import os
+import streamlit as st
+import gdown
+from packaging.version import Version
+
+from infer_func import convert
+
+ROOT = os.path.dirname(os.path.abspath(__file__))
+
+EXAMPLES = {
+    'content': {
+        'Brad Pitt': ROOT + '/examples/content/brad_pitt.jpg'
+    },
+    'style': {
+        'Flower of Life': ROOT + '/examples/style/flower_of_life.jpg'
+    }
+}
+
+VGG_WEIGHT_URL = 'https://drive.google.com/uc?id=1UcSl-Zn3byEmn15NIPXMf9zaGCKc2gfx'
+DECODER_WEIGHT_URL = 'https://drive.google.com/uc?id=18JpLtMOapA-vwBz-LRomyTl24A9GwhTF'
+
+VGG_WEIGHT_FILENAME = ROOT + '/vgg.pth'
+DECODER_WEIGHT_FILENAME = ROOT + '/decoder.pth'
+
+
+@st.cache
+def download_models():
+    with st.spinner(text="Downloading VGG weights..."):
+        gdown.download(VGG_WEIGHT_URL, output=VGG_WEIGHT_FILENAME)
+    with st.spinner(text="Downloading Decoder weights..."):
+        gdown.download(DECODER_WEIGHT_URL, output=DECODER_WEIGHT_FILENAME)
+
+
+def image_getter(image_kind):
+
+    image = None
+
+    options = ['Use Example Image', 'Upload Image']
+
+    if Version(st.__version__) >= Version('1.4.0'):
+        options.append('Open Camera')
+
+    option = st.selectbox(
+        'Choose Image',
+        options, key=image_kind)
+
+    if option == 'Use Example Image':
+        image_key = st.selectbox(
+            'Choose from examples',
+            EXAMPLES[image_kind], key=image_kind)
+        image = EXAMPLES[image_kind][image_key]
+
+    elif option == 'Upload Image':
+        image = st.file_uploader("Upload an image", type=['png', 'jpg', 'PNG', 'JPG', 'JPEG'], key=image_kind)
+    elif option == 'Open Camera':
+        image = st.camera_input('', key=image_kind)
+
+    return image
+
+
+if __name__ == '__main__':
+
+    st.set_page_config(layout="wide")
+    st.header('Adaptive Instance Normalization demo based on '
+              '[2022-AdaIN-pytorch](https://github.com/media-comp/2022-AdaIN-pytorch)')
+
+    download_models()
+    # col1, col2, col3, col4 = st.columns((2, 2, 1, 3))
+    col1, col2, col3 = st.columns((3, 4, 4))
+    with col1:
+        st.subheader('Content Image')
+        content = image_getter('content')
+        st.subheader('Style Image')
+        style = image_getter('style')
+    with col2:
+        img1 = content if content is not None else 'examples/img.png'
+        img2 = style if style is not None else 'examples/img.png'
+        if img1 is not None:
+            st.image(img1, width=None, caption='Content Image')
+        if img2 is not None:
+            st.image(img2, width=None, caption='Style Image')
+
+    with col3:
+        color_control = st.checkbox('Preserve content image color')
+        alpha = st.slider('Strength of style transfer', 0.0, 1.0, 1.0, 0.01)
+        process = st.button('Stylize')
+
+    if content is not None and style is not None and process:
+        print(content, style)
+        with col3:
+            with st.spinner('Processing...'):
+                output_image = convert(content, style, VGG_WEIGHT_FILENAME, DECODER_WEIGHT_FILENAME, alpha, color_control)
+
+            st.image(output_image, width=None, caption='Stylized Image')
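Note: download_models() fetches the two weight files from Google Drive, and @st.cache keeps them from being re-downloaded on every rerun. The same files can be fetched once outside Streamlit, e.g. for running infer_func.py locally; a sketch using the URLs defined above (output paths assumed relative to the repo root):

    import gdown

    gdown.download('https://drive.google.com/uc?id=1UcSl-Zn3byEmn15NIPXMf9zaGCKc2gfx', output='vgg.pth')
    gdown.download('https://drive.google.com/uc?id=18JpLtMOapA-vwBz-LRomyTl24A9GwhTF', output='decoder.pth')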
examples/content/brad_pitt.jpg
ADDED
examples/img.png
ADDED
examples/style/flower_of_life.jpg
ADDED
examples/style/sketch.jpg
ADDED
infer_func.py
ADDED
@@ -0,0 +1,60 @@
+import torch
+import torchvision.transforms
+from PIL import Image
+
+from AdaIN import AdaINNet
+from utils import adaptive_instance_normalization, transform, linear_histogram_matching
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
+    """
+    Given content image and style image, generate feature maps with encoder, apply
+    neural style transfer with adaptive instance normalization, and generate the
+    output image with decoder.
+
+    Args:
+        content_tensor (torch.FloatTensor): Content image
+        style_tensor (torch.FloatTensor): Style image
+        encoder: Encoder (vgg19) network
+        decoder: Decoder network
+        alpha (float, default=1.0): Weight of style image feature
+
+    Return:
+        output_tensor (torch.FloatTensor): Style transfer output image
+    """
+
+    content_enc = encoder(content_tensor)
+    style_enc = encoder(style_tensor)
+
+    transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
+
+    mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
+    return decoder(mix_enc)
+
+
+def convert(content_path, style_path, vgg_weights_path, decoder_weights_path, alpha, color_control):
+
+    vgg = torch.load(vgg_weights_path)
+    model = AdaINNet(vgg).to(device)
+    model.decoder.load_state_dict(torch.load(decoder_weights_path))
+    model.eval()
+
+    # Prepare image transform
+    t = transform(512)
+
+    # Load images
+    content_img = Image.open(content_path)
+    content_tensor = t(content_img).unsqueeze(0).to(device)
+    style_tensor = t(Image.open(style_path)).unsqueeze(0).to(device)
+
+    if color_control:
+        style_tensor = linear_histogram_matching(content_tensor, style_tensor)
+
+    with torch.no_grad():
+        out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, alpha).cpu()
+
+    output_image = torchvision.transforms.ToPILImage()(out_tensor.squeeze(0))
+
+    return output_image
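Note: with vgg.pth and decoder.pth in place, convert() can be driven directly from Python using the example images bundled in this commit; a minimal sketch (the output filename is arbitrary):

    from infer_func import convert

    out = convert(
        'examples/content/brad_pitt.jpg',
        'examples/style/flower_of_life.jpg',
        'vgg.pth', 'decoder.pth',
        alpha=1.0,            # full-strength style transfer
        color_control=False,  # keep the style image's colors
    )
    out.save('stylized.png')  # convert() returns a PIL image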
requirements.txt
ADDED
@@ -0,0 +1,11 @@
+torch==1.10.1
+torchvision==0.11.2
+opencv-python==4.5.1.48
+numpy==1.18.4
+Pillow==8.4.0
+tqdm==4.62.3
+imageio==2.9.0
+imageio-ffmpeg==0.4.6
+matplotlib==3.3.2
+gdown
+packaging
utils.py
ADDED
@@ -0,0 +1,144 @@
+import os
+from PIL import Image, ImageFile
+import torch
+from torch.utils.data import Dataset
+import torchvision.transforms as transforms
+import matplotlib.pyplot as plt
+from pathlib import Path
+from glob import glob
+
+def adaptive_instance_normalization(x, y, eps=1e-5):
+    """
+    Adaptive Instance Normalization. Perform neural style transfer given content image x
+    and style image y.
+
+    Args:
+        x (torch.FloatTensor): Content image tensor
+        y (torch.FloatTensor): Style image tensor
+        eps (float, default=1e-5): Small value to avoid zero division
+
+    Return:
+        output (torch.FloatTensor): AdaIN style transferred output
+    """
+
+    mu_x = torch.mean(x, dim=[2, 3])
+    mu_y = torch.mean(y, dim=[2, 3])
+    mu_x = mu_x.unsqueeze(-1).unsqueeze(-1)
+    mu_y = mu_y.unsqueeze(-1).unsqueeze(-1)
+
+    sigma_x = torch.std(x, dim=[2, 3])
+    sigma_y = torch.std(y, dim=[2, 3])
+    sigma_x = sigma_x.unsqueeze(-1).unsqueeze(-1) + eps
+    sigma_y = sigma_y.unsqueeze(-1).unsqueeze(-1) + eps
+
+    return (x - mu_x) / sigma_x * sigma_y + mu_y
+
+def transform(size):
+    """
+    Image preprocess transformation. Resize image and convert to tensor.
+
+    Args:
+        size (int): Resize image size
+
+    Return:
+        output (torchvision.transforms): Composition of torchvision.transforms steps
+    """
+
+    t = []
+    t.append(transforms.Resize(size))
+    t.append(transforms.ToTensor())
+    t = transforms.Compose(t)
+    return t
+
+def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
+    """
+    Generate and save an image that contains row x col grids of images.
+
+    Args:
+        row (int): number of rows
+        col (int): number of columns
+        images (list of PIL image): list of images
+        height (int): height of each image (inch)
+        width (int): width of each image (inch)
+        save_pth (str): save file path
+    """
+
+    width = col * width
+    height = row * height
+    plt.figure(figsize=(width, height))
+    for i, image in enumerate(images):
+        plt.subplot(row, col, i + 1)
+        plt.imshow(image)
+        plt.axis('off')
+    plt.subplots_adjust(wspace=0.01, hspace=0.01)
+    plt.savefig(save_pth)
+
+
+def linear_histogram_matching(content_tensor, style_tensor):
+    """
+    Given content_tensor and style_tensor, transform style_tensor histogram to that of content_tensor.
+
+    Args:
+        content_tensor (torch.FloatTensor): Content image
+        style_tensor (torch.FloatTensor): Style image
+
+    Return:
+        style_tensor (torch.FloatTensor): histogram matched style image
+    """
+    # for batch
+    for b in range(len(content_tensor)):
+        std_ct = []
+        std_st = []
+        mean_ct = []
+        mean_st = []
+        # for channel
+        for c in range(len(content_tensor[b])):
+            std_ct.append(torch.var(content_tensor[b][c], unbiased=False))
+            mean_ct.append(torch.mean(content_tensor[b][c]))
+            std_st.append(torch.var(style_tensor[b][c], unbiased=False))
+            mean_st.append(torch.mean(style_tensor[b][c]))
+            style_tensor[b][c] = (style_tensor[b][c] - mean_st[c]) * std_ct[c] / std_st[c] + mean_ct[c]
+    return style_tensor
+
+
+class TrainSet(Dataset):
+    """
+    Build training dataset
+    """
+    def __init__(self, content_dir, style_dir, crop_size=256):
+        super().__init__()
+
+        self.content_files = [Path(f) for f in glob(content_dir + '/*')]
+        self.style_files = [Path(f) for f in glob(style_dir + '/*')]
+
+        self.transform = transforms.Compose([
+            transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
+            transforms.RandomCrop(crop_size),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+        ])
+
+        Image.MAX_IMAGE_PIXELS = None
+        ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+    def __len__(self):
+        return min(len(self.style_files), len(self.content_files))
+
+    def __getitem__(self, index):
+        content_img = Image.open(self.content_files[index]).convert('RGB')
+        style_img = Image.open(self.style_files[index]).convert('RGB')
+
+        content_sample = self.transform(content_img)
+        style_sample = self.transform(style_img)
+
+        return content_sample, style_sample
+
+class Range(object):
+    """
+    Helper class for input argument range restriction
+    """
+    def __init__(self, start, end):
+        self.start = start
+        self.end = end
+
+    def __eq__(self, other):
+        return self.start <= other <= self.end
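Note: adaptive_instance_normalization re-normalizes each feature channel of x to the per-channel mean and standard deviation of y, which is the entire style-transfer step in feature space. A quick property check on random tensors (not from this commit):

    import torch
    from utils import adaptive_instance_normalization

    x = torch.randn(2, 512, 32, 32)       # stand-in content features
    y = torch.randn(2, 512, 32, 32) * 3   # stand-in style features with a different scale

    out = adaptive_instance_normalization(x, y)
    print(torch.allclose(out.mean(dim=[2, 3]), y.mean(dim=[2, 3]), atol=1e-3))  # True
    print(torch.allclose(out.std(dim=[2, 3]), y.std(dim=[2, 3]), atol=1e-2))    # True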