Chengkai Yang
committed on
Commit · 7930ce0
0 Parent(s):
init
Browse files
- AdaIN.py +54 -0
- Network.py +73 -0
- README.md +132 -0
- test.py +125 -0
- test_interpolate.py +126 -0
- test_style_transfer.py +58 -0
- test_video.py +107 -0
- train.py +100 -0
- utils.py +117 -0
AdaIN.py
ADDED
@@ -0,0 +1,54 @@
import torch
import torch.nn as nn
from Network import vgg19, decoder
from utils import adaptive_instance_normalization


class AdaINNet(nn.Module):
    """
    AdaIN Style Transfer Network

    Args:
        vgg_weight: pretrained vgg19 weight
    """
    def __init__(self, vgg_weight):
        super().__init__()
        self.encoder = vgg19(vgg_weight)
        self.encoder = nn.Sequential(*list(self.encoder.children())[:22])  # drop layers after 4_1
        for parameter in self.encoder.parameters():
            parameter.requires_grad = False

        self.decoder = decoder()

        self.mseloss = nn.MSELoss()

    def _style_loss(self, x, y):
        return self.mseloss(torch.mean(x, dim=[2, 3]), torch.mean(y, dim=[2, 3])) + \
            self.mseloss(torch.std(x, dim=[2, 3]), torch.std(y, dim=[2, 3]))

    def forward(self, content, style, alpha=1.0):
        content_enc = self.encoder(content)
        style_enc = self.encoder(style)
        transfer_enc = adaptive_instance_normalization(content_enc, style_enc)

        out = self.decoder(transfer_enc)

        # vgg19 layer relu1_1
        style_relu11 = self.encoder[:3](style)
        out_relu11 = self.encoder[:3](out)

        # vgg19 layer relu2_1
        style_relu21 = self.encoder[3:8](style_relu11)
        out_relu21 = self.encoder[3:8](out_relu11)

        # vgg19 layer relu3_1
        style_relu31 = self.encoder[8:13](style_relu21)
        out_relu31 = self.encoder[8:13](out_relu21)

        # vgg19 layer relu4_1
        out_enc = self.encoder[13:](out_relu31)

        content_loss = self.mseloss(out_enc, transfer_enc)
        style_loss = self._style_loss(out_relu11, style_relu11) + self._style_loss(out_relu21, style_relu21) + \
            self._style_loss(out_relu31, style_relu31) + self._style_loss(out_enc, style_enc)

        return content_loss, style_loss
Network.py
ADDED
@@ -0,0 +1,73 @@
import torch.nn as nn

vgg19_cfg = [3, 64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]
decoder_cfg = [512, 256, "U", 256, 256, 256, 128, "U", 128, 64, 'U', 64, 3]

def vgg19(weights=None):
    """
    Build vgg19 network. Load weights if weights are given.

    Args:
        weights (dict): vgg19 pretrained weights

    Return:
        layers (nn.Sequential): vgg19 layers
    """

    modules = make_block(vgg19_cfg)
    modules = [nn.Conv2d(3, 3, kernel_size=1)] + list(modules.children())
    layers = nn.Sequential(*modules)

    if weights:
        layers.load_state_dict(weights)

    return layers


def decoder(weights=None):
    """
    Build decoder network. Load weights if weights are given.

    Args:
        weights (dict): decoder pretrained weights

    Return:
        layers (nn.Sequential): decoder layers
    """

    modules = make_block(decoder_cfg)
    layers = nn.Sequential(*list(modules.children())[:-1])  # no relu at the last layer

    if weights:
        layers.load_state_dict(weights)

    return layers


def make_block(config):
    """
    Helper function for building blocks of convolutional layers.

    Args:
        config (list): List of layer configs.
            "M" - Max pooling layer.
            "U" - Upsampling layer.
            i (int) - Convolutional layer (i filters) plus ReLU activation.

    Return:
        layers (nn.Sequential): block layers
    """
    layers = []
    in_channels = config[0]

    for c in config[1:]:
        if c == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
        elif c == "U":
            layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
        else:
            assert(isinstance(c, int))
            layers.append(nn.Conv2d(in_channels, c, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            in_channels = c

    return nn.Sequential(*layers)
README.md
ADDED
@@ -0,0 +1,132 @@
2022-AdaIN-pytorch
============================
This is an unofficial PyTorch implementation of the paper "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization" [arxiv](https://arxiv.org/abs/1703.06868). I referred to the [official implementation](https://github.com/xunhuang1995/AdaIN-style) in Torch and used pretrained weights for vgg19 and the decoder from [naoto0804](https://github.com/naoto0804/pytorch-AdaIN).
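
AdaIN itself is a parameter-free operation: it re-normalizes the content feature map so that its channel-wise mean and standard deviation (computed over the spatial dimensions) match those of the style feature map, and the decoder then maps the result back to image space. This is exactly what `adaptive_instance_normalization` in `utils.py` computes:

```
AdaIN(x, y) = sigma(y) * (x - mu(x)) / sigma(x) + mu(y)
```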

Requirements
----------------------------
* Python 3.7
* PyTorch 1.10
* Pillow
* TorchVision
* NumPy
* tqdm


Usage
----------------------------

### Training

The encoder uses a pretrained vgg19 network. Download the [vgg19 weights](https://drive.google.com/file/d/1UcSl-Zn3byEmn15NIPXMf9zaGCKc2gfx/view?usp=sharing). The decoder is trained on the MSCOCO and WikiArt datasets.
Run the script train.py:
```
$ python train.py --content_dir $CONTENT_DIR --style_dir $STYLE_DIR --cuda

usage: train.py [-h] [--content_dir CONTENT_DIR] [--style_dir STYLE_DIR]
                [--epochs EPOCHS] [--batch_size BATCH_SIZE] [--resume RESUME] [--cuda]

optional arguments:
  -h, --help            show this help message and exit
  --content_dir CONTENT_DIR
                        content images folder path
  --style_dir STYLE_DIR
                        style images folder path
  --epochs EPOCHS       Number of epochs
  --batch_size BATCH_SIZE
                        Batch size
  --resume RESUME       Continue training from epoch
  --cuda                Use CUDA
```
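
Only the decoder is trained; the vgg19 encoder stays frozen. Each update minimizes the content loss plus a weighted style loss, as in `train.py`:

```
loss_content, loss_style = model(content_batch, style_batch)  # AdaINNet returns both losses
loss = loss_content + 10 * loss_style                         # style weight of 10, as used in train.py
loss.backward()
```

Checkpoints are written to `./check_point/` and decoder weights to `./weights/` after every epoch.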

### Test Image Style Transfer

Download the [decoder weight](https://drive.google.com/file/d/18JpLtMOapA-vwBz-LRomyTl24A9GwhTF/view?usp=sharing).

To test basic style transfer, run the script test.py:

```
$ python test.py --content_image $IMG --style_image $STYLE --decoder_weight $WEIGHT --cuda

usage: test.py [-h] [--content_image CONTENT_IMAGE] [--content_dir CONTENT_DIR]
               [--style_image STYLE_IMAGE] [--style_dir STYLE_DIR]
               [--decoder_weight DECODER_WEIGHT] [--alpha {Alpha Range}]
               [--cuda] [--grid_pth GRID_PTH]

optional arguments:
  -h, --help            show this help message and exit
  --content_image CONTENT_IMAGE
                        single content image file
  --content_dir CONTENT_DIR
                        content image directory, iterate all images under this directory
  --style_image STYLE_IMAGE
                        single style image
  --style_dir STYLE_DIR
                        style image directory, iterate all images under this directory
  --decoder_weight DECODER_WEIGHT
                        decoder weight file
  --alpha {Alpha Range}
                        Alpha [0.0, 1.0] controls style transfer level
  --cuda                Use CUDA
  --grid_pth GRID_PTH
                        Specify a grid image path (default=None) to generate a grid image that contains all style transferred images
```
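
The `--alpha` flag trades stylization off against content preservation by blending the AdaIN-transformed features with the original content features before decoding, as implemented in `style_transfer` in `test.py`:

```
content_enc = encoder(content_tensor)
style_enc = encoder(style_tensor)
transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc  # alpha=1.0 full stylization, alpha=0.0 reconstruction
output = decoder(mix_enc)
```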

### Test Image Interpolation Style Transfer

To test style transfer interpolation, run the script test_interpolate.py:

```
$ python test_interpolate.py --content_image $IMG --style_image $STYLE --decoder_weight $WEIGHT --cuda

usage: test_interpolate.py [-h] [--content_image CONTENT_IMAGE] [--style_image STYLE_IMAGE]
                           [--decoder_weight DECODER_WEIGHT] [--alpha {Alpha Range}]
                           [--interpolation_weights INTERPOLATION_WEIGHTS]
                           [--cuda] [--grid_pth GRID_PTH]

optional arguments:
  -h, --help            show this help message and exit
  --content_image CONTENT_IMAGE
                        single content image file
  --style_image STYLE_IMAGE
                        multiple style image files, separated by comma
  --decoder_weight DECODER_WEIGHT
                        decoder weight file
  --alpha {Alpha Range}
                        Alpha [0.0, 1.0] controls style transfer level
  --interpolation_weights INTERPOLATION_WEIGHTS
                        Interpolation weight of each style image, separated by comma. Performs interpolation style transfer once. Do not specify together with --grid_pth.
  --cuda                Use CUDA
  --grid_pth GRID_PTH
                        Specify a grid image path (default=None) to perform interpolation style transfer multiple times with different built-in weights and generate a grid image that contains all style transferred images. Provide 4 style images. Do not specify together with --interpolation_weights.
```
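
Interpolation takes a weighted combination of the AdaIN-normalized feature maps, one per style image, before decoding (see `interpolate_style_transfer` in `test_interpolate.py`); user-supplied weights are normalized to sum to 1 before being applied:

```
full_enc = adaptive_instance_normalization(content_enc, style_enc)  # one feature map per style image
transfer_enc = sum(w * full_enc[i] for i, w in enumerate(interpolation_weights))
output = decoder(alpha * transfer_enc + (1 - alpha) * content_enc)
```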

### Test Video Style Transfer

To test video style transfer, run the script test_video.py:

```
$ python test_video.py --content_video $VID --style_image $STYLE --decoder_weight $WEIGHT --cuda

usage: test_video.py [-h] [--content_video CONTENT_VIDEO] [--style_image STYLE_IMAGE]
                     [--decoder_weight DECODER_WEIGHT] [--alpha {Alpha Range}]
                     [--cuda]

optional arguments:
  -h, --help            show this help message and exit
  --content_video CONTENT_VIDEO
                        single content video file
  --style_image STYLE_IMAGE
                        single style image
  --decoder_weight DECODER_WEIGHT
                        decoder weight file
  --alpha {Alpha Range}
                        Alpha [0.0, 1.0] controls style transfer level
  --cuda                Use CUDA
```


### References

- [1]: X. Huang and S. Belongie. "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization", in ICCV, 2017. [arxiv](https://arxiv.org/abs/1703.06868)
- [2]: [Original implementation in Torch](https://github.com/xunhuang1995/AdaIN-style)
- [3]: [Pretrained weights](https://github.com/naoto0804/pytorch-AdaIN)
test.py
ADDED
@@ -0,0 +1,125 @@
import os
import argparse
import torch
import time
import numpy as np
from pathlib import Path
from AdaIN import AdaINNet
from PIL import Image
from torchvision.utils import save_image
from torchvision.transforms import ToPILImage
from utils import adaptive_instance_normalization, grid_image, transform, Range
from glob import glob

parser = argparse.ArgumentParser()
parser.add_argument('--content_image', type=str, help='Content image file path')
parser.add_argument('--content_dir', type=str, help='Content image folder path')
parser.add_argument('--style_image', type=str, help='Style image file path')
parser.add_argument('--style_dir', type=str, help='Style image folder path')
parser.add_argument('--decoder_weight', type=str, required=True, help='Decoder weight file path')
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) to generate a grid image that contains all style transferred images')
args = parser.parse_args()
assert args.content_image or args.content_dir
assert args.style_image or args.style_dir
assert args.decoder_weight

device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')


def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
    """
    Given content image and style image, generate feature maps with encoder, apply
    neural style transfer with adaptive instance normalization, generate output image
    with decoder

    Args:
        content_tensor (torch.FloatTensor): Content image
        style_tensor (torch.FloatTensor): Style image
        encoder: Encoder (vgg19) network
        decoder: Decoder network
        alpha (float, default=1.0): Weight of style image feature

    Return:
        output_tensor (torch.FloatTensor): Style transfer output image
    """

    content_enc = encoder(content_tensor)
    style_enc = encoder(style_tensor)

    transfer_enc = adaptive_instance_normalization(content_enc, style_enc)

    mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
    return decoder(mix_enc)


def main():
    # Read content images and style images
    if args.content_image:
        content_pths = [Path(args.content_image)]
    else:
        content_pths = [Path(f) for f in glob(args.content_dir + '/*')]

    if args.style_image:
        style_pths = [Path(args.style_image)]
    else:
        style_pths = [Path(f) for f in glob(args.style_dir + '/*')]

    out_dir = './results/'
    os.makedirs(out_dir, exist_ok=True)

    # Load AdaIN model
    vgg = torch.load('vgg_normalized.pth')
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(args.decoder_weight))
    model.eval()

    # Prepare image transform
    t = transform(512)

    # Prepare grid image
    if args.grid_pth:
        imgs = [np.zeros((1, 1))]
        for style_pth in style_pths:
            imgs.append(Image.open(style_pth))

    # Timer
    times = []

    for content_pth in content_pths:
        content_img = Image.open(content_pth)
        content_tensor = t(content_img).unsqueeze(0).to(device)

        if args.grid_pth:
            imgs.append(content_img)

        for style_pth in style_pths:

            style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)

            tic = time.perf_counter()  # Start time
            with torch.no_grad():
                out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, args.alpha).cpu()

            toc = time.perf_counter()  # End time
            print("Content: " + content_pth.stem + ". Style: " \
                + style_pth.stem + '. Style Transfer time: %.4f seconds' % (toc - tic))
            times.append(toc - tic)

            out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha) + content_pth.suffix
            save_image(out_tensor, out_pth)

            if args.grid_pth:
                imgs.append(Image.open(out_pth))

    avg = sum(times) / len(times)
    print("Average style transfer time: %.4f seconds" % (avg))

    if args.grid_pth:
        print("Generating grid image")
        grid_image(len(content_pths) + 1, len(style_pths) + 1, imgs, save_pth=args.grid_pth)
        print("Finished")


if __name__ == '__main__':
    main()
test_interpolate.py
ADDED
@@ -0,0 +1,126 @@
import os
import argparse
import torch
import time
import numpy as np
from pathlib import Path
from AdaIN import AdaINNet
from PIL import Image
from torchvision.utils import save_image
from utils import adaptive_instance_normalization, transform, Range, grid_image
from glob import glob

parser = argparse.ArgumentParser()
parser.add_argument('--content_image', type=str, help='Test image file path')
parser.add_argument('--style_image', type=str, required=True, help='Multiple style image file paths, separated by comma')
parser.add_argument('--decoder_weight', type=str, required=True, help='Decoder weight file path')
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
parser.add_argument('--interpolation_weights', type=str, help='Weights to interpolate multiple style images')
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) to generate a grid image that contains all style transferred images. \
    If grid mode is used, provide 4 style images')
args = parser.parse_args()
assert args.content_image
assert args.style_image
assert args.decoder_weight
assert args.interpolation_weights or args.grid_pth

device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')


def interpolate_style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0, interpolation_weights=None):
    """
    Given content image and multiple style images, generate feature maps with encoder, apply
    neural style transfer with adaptive instance normalization, interpolate style image features
    with interpolation weights, generate output image with decoder

    Args:
        content_tensor (torch.FloatTensor): Content image
        style_tensor (torch.FloatTensor): Multiple style images
        encoder: Encoder (vgg19) network
        decoder: Decoder network
        alpha (float, default=1.0): Weight of style image feature
        interpolation_weights (list): Weight of each style image

    Return:
        output_tensor (torch.FloatTensor): Interpolation style transfer output image
    """

    content_enc = encoder(content_tensor)
    style_enc = encoder(style_tensor)

    transfer_enc = torch.zeros_like(content_enc).to(device)
    full_enc = adaptive_instance_normalization(content_enc, style_enc)
    for i, w in enumerate(interpolation_weights):
        transfer_enc += w * full_enc[i]

    mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
    return decoder(mix_enc)


def main():
    # Read content and style images
    content_pths = [Path(args.content_image)]

    style_pths_list = args.style_image.split(',')
    style_pths = [Path(pth) for pth in style_pths_list]

    inter_weights = []
    # If grid mode, use 4 style images and a 5x5 grid of interpolation weights
    if args.grid_pth:
        assert len(style_pths) == 4
        inter_weights = [[min(4 - a, 4 - b) / 4, min(4 - a, b) / 4, min(a, 4 - b) / 4, min(a, b) / 4]
                         for a in range(5) for b in range(5)]

    # Otherwise use the user-supplied interpolation weights
    else:
        inter_weight = [float(i) for i in args.interpolation_weights.split(',')]
        inter_weight = [i / sum(inter_weight) for i in inter_weight]
        inter_weights.append(inter_weight)


    out_dir = './results_interpolate/'
    os.makedirs(out_dir, exist_ok=True)

    # Load AdaIN model
    vgg = torch.load('vgg_normalized.pth')
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(args.decoder_weight))
    model.eval()

    # Prepare image transform
    t = transform(512)

    imgs = []

    for content_pth in content_pths:
        content_tensor = t(Image.open(content_pth)).unsqueeze(0).to(device)

        style_tensor = []
        for style_pth in style_pths:
            img = Image.open(style_pth)
            style_tensor.append(transform([512, 512])(img))  # Convert style images to the same size
        style_tensor = torch.stack(style_tensor, dim=0).to(device)

        for inter_weight in inter_weights:
            with torch.no_grad():
                out_tensor = interpolate_style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, args.alpha, inter_weight).cpu()

            print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weight: " + str(inter_weight))

            out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight) + content_pth.suffix
            save_image(out_tensor, out_pth)

            if args.grid_pth:
                imgs.append(Image.open(out_pth))

    if args.grid_pth:
        print("Generating grid image")
        grid_image(5, 5, imgs, save_pth=args.grid_pth)
        print("Finished")


if __name__ == '__main__':
    main()
test_style_transfer.py
ADDED
@@ -0,0 +1,58 @@
import argparse
import os
import numpy as np
from AdaIN import AdaINNet
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
from utils import adaptive_instance_normalization

class AlphaRange(object):
    def __init__(self, start, end):
        self.start = start
        self.end = end
    def __eq__(self, other):
        return self.start <= other <= self.end
    def __str__(self):
        return 'Alpha Range'

parser = argparse.ArgumentParser()
parser.add_argument('--input_image', type=str, help='test image')
parser.add_argument('--style_image', type=str, help='style image')
parser.add_argument('--weight', type=str, help='decoder weight file')
parser.add_argument('--alpha', type=float, default=1.0, choices=[AlphaRange(0.0, 1.0)], help='Level of style transfer, value between 0 and 1')
parser.add_argument('--cuda', action='store_true', help='Use GPU for inference')


if __name__ == '__main__':
    opt = parser.parse_args()
    input_image = Image.open(opt.input_image)
    style_image = Image.open(opt.style_image)
    output_format = opt.input_image[opt.input_image.find('.'):]
    out_dir = './results/'
    os.makedirs(out_dir, exist_ok=True)
    with torch.no_grad():
        vgg_model = torch.load('vgg_normalized.pth')

        net = AdaINNet(vgg_model)
        net.decoder.load_state_dict(torch.load(opt.weight))

        net.eval()

        input_image = transforms.Resize(512)(input_image)
        style_image = transforms.Resize(512)(style_image)

        input_tensor = transforms.ToTensor()(input_image).unsqueeze(0)
        style_tensor = transforms.ToTensor()(style_image).unsqueeze(0)


        if torch.cuda.is_available() and opt.cuda:
            net.cuda()
            input_tensor = input_tensor.cuda()
            style_tensor = style_tensor.cuda()

        # Encode both images, apply AdaIN, blend with alpha, then decode
        content_enc = net.encoder(input_tensor)
        style_enc = net.encoder(style_tensor)
        transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
        mix_enc = opt.alpha * transfer_enc + (1 - opt.alpha) * content_enc
        out_tensor = net.decoder(mix_enc)

        save_image(out_tensor, out_dir + opt.input_image[opt.input_image.rfind('/') + 1: opt.input_image.find('.')]
            + "_style_" + opt.style_image[opt.style_image.rfind('/') + 1: opt.style_image.find('.')]
            + output_format)
test_video.py
ADDED
@@ -0,0 +1,107 @@
import os
import argparse
import torch
from pathlib import Path
from AdaIN import AdaINNet
from PIL import Image
from utils import transform, adaptive_instance_normalization, Range
import cv2
import imageio
import numpy as np
from tqdm import tqdm


parser = argparse.ArgumentParser()
parser.add_argument('--content_video', type=str, required=True, help='Content video file path')
parser.add_argument('--style_image', type=str, required=True, help='Style image file path')
parser.add_argument('--decoder_weight', type=str, required=True, help='Decoder weight file path')
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
args = parser.parse_args()

device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')


def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
    """
    Given content image and style image, generate feature maps with encoder, apply
    neural style transfer with adaptive instance normalization, generate output image
    with decoder

    Args:
        content_tensor (torch.FloatTensor): Content image
        style_tensor (torch.FloatTensor): Style image
        encoder: Encoder (vgg19) network
        decoder: Decoder network
        alpha (float, default=1.0): Weight of style image feature

    Return:
        output_tensor (torch.FloatTensor): Style transfer output image
    """

    content_enc = encoder(content_tensor)
    style_enc = encoder(style_tensor)

    transfer_enc = adaptive_instance_normalization(content_enc, style_enc)

    mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
    return decoder(mix_enc)


def main():
    # Read video file
    content_video_pth = Path(args.content_video)
    content_video = cv2.VideoCapture(str(content_video_pth))
    style_image_pth = Path(args.style_image)
    style_image = Image.open(style_image_pth)

    fps = int(content_video.get(cv2.CAP_PROP_FPS))
    frame_count = int(content_video.get(cv2.CAP_PROP_FRAME_COUNT))
    video_height = int(content_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    video_width = int(content_video.get(cv2.CAP_PROP_FRAME_WIDTH))

    video_tqdm = tqdm(total=frame_count)

    out_dir = './results_video/'
    os.makedirs(out_dir, exist_ok=True)
    out_pth = Path(out_dir + content_video_pth.stem + '_style_' \
        + style_image_pth.stem + content_video_pth.suffix)
    writer = imageio.get_writer(out_pth, mode='I', fps=fps)

    # Load AdaIN model
    vgg = torch.load('vgg_normalized.pth')
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(args.decoder_weight))
    model.eval()

    t = transform(512)

    style_tensor = t(style_image).unsqueeze(0).to(device)


    while content_video.isOpened():
        ret, content_image = content_video.read()
        if not ret:  # Failed to read a frame
            break

        # OpenCV reads frames as BGR, convert to RGB before feeding the encoder
        content_image = cv2.cvtColor(content_image, cv2.COLOR_BGR2RGB)
        content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)

        with torch.no_grad():
            out_tensor = style_transfer(content_tensor, style_tensor, model.encoder,
                model.decoder, args.alpha).cpu().detach().numpy()

        # Convert output frame to original size and rgb range (0, 255)
        out_tensor = np.squeeze(out_tensor, axis=0)
        out_tensor = np.transpose(out_tensor, (1, 2, 0))
        out_tensor = cv2.normalize(src=out_tensor, dst=None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        out_tensor = cv2.resize(out_tensor, (video_width, video_height), interpolation=cv2.INTER_CUBIC)

        writer.append_data(np.array(out_tensor))
        video_tqdm.update(1)

    content_video.release()
    writer.close()

    print("\nContent: " + content_video_pth.stem + ". Style: " + style_image_pth.stem + '\n')


if __name__ == '__main__':
    main()
train.py
ADDED
@@ -0,0 +1,100 @@
import argparse
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
from utils import TrainSet
from AdaIN import AdaINNet
from tqdm import tqdm

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--content_dir', type=str, required=True, help='content images folder path')
    parser.add_argument('--style_dir', type=str, required=True, help='style images folder path')
    parser.add_argument('--epochs', type=int, default=1, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
    parser.add_argument('--resume', type=int, default=0, help='Continue training from epoch')
    parser.add_argument('--cuda', action='store_true', help='Use CUDA')
    args = parser.parse_args()

    device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')

    check_point_dir = './check_point/'
    weights_dir = './weights/'
    os.makedirs(check_point_dir, exist_ok=True)
    os.makedirs(weights_dir, exist_ok=True)
    train_set = TrainSet(args.content_dir, args.style_dir)
    train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True)

    vgg_model = torch.load('vgg_normalized.pth')
    model = AdaINNet(vgg_model).to(device)

    decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-6)
    total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
    total_num = 0
    losses = []
    iteration = 0

    if args.resume > 0:
        states = torch.load(check_point_dir + "epoch_" + str(args.resume) + '.pth')
        model.decoder.load_state_dict(states['decoder'])
        decoder_optimizer.load_state_dict(states['decoder_optimizer'])
        losses = states['losses']
        iteration = states['iteration']


    for epoch in range(args.resume + 1, args.epochs + 1):
        print("Begin epoch: %i/%i" % (epoch, int(args.epochs)))
        train_tqdm = tqdm(train_loader)
        train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))
        losses.append((iteration, total_loss, content_loss, style_loss))
        total_loss, content_loss, style_loss = 0.0, 0.0, 0.0

        for content_batch, style_batch in train_tqdm:

            content_batch = content_batch.to(device)
            style_batch = style_batch.to(device)

            loss_content, loss_style = model(content_batch, style_batch)
            loss_scaled = loss_content + 10 * loss_style
            loss_scaled.backward()
            decoder_optimizer.step()
            total_loss += loss_scaled.item() * style_batch.size(0)
            content_loss += loss_content.item() * style_batch.size(0)
            style_loss += loss_style.item() * style_batch.size(0)
            decoder_optimizer.zero_grad()

            total_num += style_batch.size(0)

            if iteration % 100 == 0 and iteration > 0:

                total_loss /= total_num
                content_loss /= total_num
                style_loss /= total_num
                print('')
                train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))

                losses.append((iteration, total_loss, content_loss, style_loss))

                total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
                total_num = 0

            if iteration % np.ceil(len(train_loader.dataset) / args.batch_size) == 0 and iteration > 0:
                total_loss /= total_num
                content_loss /= total_num
                style_loss /= total_num
                total_num = 0

            iteration += 1

        print('Finished epoch: %i/%i' % (epoch, int(args.epochs)))

        states = {'decoder': model.decoder.state_dict(), 'decoder_optimizer': decoder_optimizer.state_dict(),
                  'losses': losses, 'iteration': iteration}
        torch.save(states, check_point_dir + 'epoch_%i.pth' % (epoch))
        torch.save(model.decoder.state_dict(), weights_dir + 'decoder_epoch_%i.pth' % (epoch))
        np.savetxt("losses", losses, fmt='%i,%.4f,%.4f,%.4f')

if __name__ == '__main__':
    main()
utils.py
ADDED
@@ -0,0 +1,117 @@
import os
from PIL import Image, ImageFile
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from pathlib import Path
from glob import glob

def adaptive_instance_normalization(x, y, eps=1e-5):
    """
    Adaptive Instance Normalization. Perform neural style transfer given content image x
    and style image y.

    Args:
        x (torch.FloatTensor): Content image tensor
        y (torch.FloatTensor): Style image tensor
        eps (float, default=1e-5): Small value to avoid zero division

    Return:
        output (torch.FloatTensor): AdaIN style transferred output
    """

    mu_x = torch.mean(x, dim=[2, 3])
    mu_y = torch.mean(y, dim=[2, 3])
    mu_x = mu_x.unsqueeze(-1).unsqueeze(-1)
    mu_y = mu_y.unsqueeze(-1).unsqueeze(-1)

    sigma_x = torch.std(x, dim=[2, 3])
    sigma_y = torch.std(y, dim=[2, 3])
    sigma_x = sigma_x.unsqueeze(-1).unsqueeze(-1) + eps
    sigma_y = sigma_y.unsqueeze(-1).unsqueeze(-1) + eps

    return (x - mu_x) / sigma_x * sigma_y + mu_y

def transform(size):
    """
    Image preprocess transformation. Resize image and convert to tensor.

    Args:
        size (int): Resize image size

    Return:
        output (torchvision.transforms): Composition of torchvision.transforms steps
    """

    t = []
    t.append(transforms.Resize(size))
    t.append(transforms.ToTensor())
    t = transforms.Compose(t)
    return t

def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
    """
    Generate and save an image that contains row x col grids of images.

    Args:
        row (int): number of rows
        col (int): number of columns
        images (list of PIL image): list of images
        height (int): height of each image (inch)
        width (int): width of each image (inch)
        save_pth (str): save file path
    """

    width = col * width
    height = row * height
    plt.figure(figsize=(width, height))
    for i, image in enumerate(images):
        plt.subplot(row, col, i + 1)
        plt.imshow(image)
        plt.axis('off')
    plt.subplots_adjust(wspace=0.01, hspace=0.01)
    plt.savefig(save_pth)


class TrainSet(Dataset):
    """
    Build training dataset
    """
    def __init__(self, content_dir, style_dir, crop_size=256):
        super().__init__()

        self.content_files = [Path(f) for f in glob(content_dir + '/*')]
        self.style_files = [Path(f) for f in glob(style_dir + '/*')]

        self.transform = transforms.Compose([
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        Image.MAX_IMAGE_PIXELS = None
        ImageFile.LOAD_TRUNCATED_IMAGES = True

    def __len__(self):
        return min(len(self.style_files), len(self.content_files))

    def __getitem__(self, index):
        content_img = Image.open(self.content_files[index]).convert('RGB')
        style_img = Image.open(self.style_files[index]).convert('RGB')

        content_sample = self.transform(content_img)
        style_sample = self.transform(style_img)

        return content_sample, style_sample

class Range(object):
    """
    Helper class for restricting the range of an input argument
    """
    def __init__(self, start, end):
        self.start = start
        self.end = end
    def __eq__(self, other):
        return self.start <= other <= self.end
    def __str__(self):
        return 'Alpha Range'  # shown by argparse in the usage string