Chengkai Yang committed
Commit 7930ce0 · 0 Parent(s)
Files changed (9):
  1. AdaIN.py +54 -0
  2. Network.py +73 -0
  3. README.md +132 -0
  4. test.py +125 -0
  5. test_interpolate.py +126 -0
  6. test_style_transfer.py +58 -0
  7. test_video.py +107 -0
  8. train.py +100 -0
  9. utils.py +117 -0
AdaIN.py ADDED
@@ -0,0 +1,54 @@
import torch
import torch.nn as nn
from Network import vgg19, decoder
from utils import adaptive_instance_normalization


class AdaINNet(nn.Module):
    """
    AdaIN Style Transfer Network

    Args:
        vgg_weight: pretrained vgg19 weights
    """
    def __init__(self, vgg_weight):
        super().__init__()
        self.encoder = vgg19(vgg_weight)
        self.encoder = nn.Sequential(*list(self.encoder.children())[:22])  # drop layers after relu4_1
        for parameter in self.encoder.parameters():
            parameter.requires_grad = False

        self.decoder = decoder()

        self.mseloss = nn.MSELoss()

    def _style_loss(self, x, y):
        return self.mseloss(torch.mean(x, dim=[2, 3]), torch.mean(y, dim=[2, 3])) + \
            self.mseloss(torch.std(x, dim=[2, 3]), torch.std(y, dim=[2, 3]))

    def forward(self, content, style, alpha=1.0):
        content_enc = self.encoder(content)
        style_enc = self.encoder(style)
        transfer_enc = adaptive_instance_normalization(content_enc, style_enc)

        out = self.decoder(transfer_enc)

        # vgg19 layer relu1_1
        style_relu11 = self.encoder[:3](style)
        out_relu11 = self.encoder[:3](out)

        # vgg19 layer relu2_1
        style_relu21 = self.encoder[3:8](style_relu11)
        out_relu21 = self.encoder[3:8](out_relu11)

        # vgg19 layer relu3_1
        style_relu31 = self.encoder[8:13](style_relu21)
        out_relu31 = self.encoder[8:13](out_relu21)

        # vgg19 layer relu4_1
        out_enc = self.encoder[13:](out_relu31)

        content_loss = self.mseloss(out_enc, transfer_enc)
        style_loss = self._style_loss(out_relu11, style_relu11) + self._style_loss(out_relu21, style_relu21) + \
            self._style_loss(out_relu31, style_relu31) + self._style_loss(out_enc, style_enc)

        return content_loss, style_loss
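Note that `AdaINNet.forward` returns the two training losses rather than a stylized image; decoding happens separately at inference time (see the test scripts below). A minimal smoke test of the loss API — the dummy shapes and the local weight file are assumptions — might look like:

```python
import torch
from AdaIN import AdaINNet

# Assumes vgg_normalized.pth is present locally (see README for the download link).
model = AdaINNet(torch.load('vgg_normalized.pth'))

content = torch.rand(2, 3, 256, 256)  # dummy content batch
style = torch.rand(2, 3, 256, 256)    # dummy style batch
content_loss, style_loss = model(content, style)
print(content_loss.item(), style_loss.item())
```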
Network.py ADDED
@@ -0,0 +1,73 @@
import torch.nn as nn

vgg19_cfg = [3, 64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]
decoder_cfg = [512, 256, "U", 256, 256, 256, 128, "U", 128, 64, "U", 64, 3]

def vgg19(weights=None):
    """
    Build the vgg19 network. Load weights if given.

    Args:
        weights (dict): vgg19 pretrained weights

    Return:
        layers (nn.Sequential): vgg19 layers
    """

    modules = make_block(vgg19_cfg)
    modules = [nn.Conv2d(3, 3, kernel_size=1)] + list(modules.children())
    layers = nn.Sequential(*modules)

    if weights:
        layers.load_state_dict(weights)

    return layers


def decoder(weights=None):
    """
    Build the decoder network. Load weights if given.

    Args:
        weights (dict): decoder pretrained weights

    Return:
        layers (nn.Sequential): decoder layers
    """

    modules = make_block(decoder_cfg)
    layers = nn.Sequential(*list(modules.children())[:-1])  # no ReLU after the last layer

    if weights:
        layers.load_state_dict(weights)

    return layers


def make_block(config):
    """
    Helper function for building blocks of convolutional layers.

    Args:
        config (list): List of layer configs.
            "M" - max pooling layer.
            "U" - upsampling layer.
            i (int) - convolutional layer with i filters, followed by ReLU activation.
    Return:
        layers (nn.Sequential): block layers
    """
    layers = []
    in_channels = config[0]

    for c in config[1:]:
        if c == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
        elif c == "U":
            layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
        else:
            assert isinstance(c, int)
            layers.append(nn.Conv2d(in_channels, c, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            in_channels = c

    return nn.Sequential(*layers)
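Since the decoder's upsampling mirrors the encoder's pooling, features taken at relu4_1 decode back to the input resolution. A quick shape check (untrained weights; the input size is an arbitrary choice) illustrates this:

```python
import torch
from Network import vgg19, decoder

enc = torch.nn.Sequential(*list(vgg19().children())[:22])  # up to relu4_1, as in AdaIN.py
dec = decoder()

x = torch.rand(1, 3, 256, 256)
feat = enc(x)
print(feat.shape)       # torch.Size([1, 512, 32, 32]): three 2x poolings, 512 channels
print(dec(feat).shape)  # torch.Size([1, 3, 256, 256]): three 2x upsamplings restore the size
```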
README.md ADDED
@@ -0,0 +1,132 @@
2022-AdaIN-pytorch
============================
This is an unofficial PyTorch implementation of the paper "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization" [arxiv](https://arxiv.org/abs/1703.06868). I referred to the [official implementation](https://github.com/xunhuang1995/AdaIN-style) in Torch. I used pretrained weights for vgg19 and the decoder from [naoto0804](https://github.com/naoto0804/pytorch-AdaIN).

Requirements
----------------------------
* Python 3.7
* PyTorch 1.10
* Pillow
* TorchVision
* Numpy
* tqdm


Usage
----------------------------

### Training

The encoder uses a pretrained vgg19 network. Download the [vgg19 weights](https://drive.google.com/file/d/1UcSl-Zn3byEmn15NIPXMf9zaGCKc2gfx/view?usp=sharing). The decoder is trained on the MSCOCO and WikiArt datasets.
Run the script train.py:
```
$ python train.py --content_dir $CONTENT_DIR --style_dir $STYLE_DIR --cuda

usage: train.py [-h] [--content_dir CONTENT_DIR] [--style_dir STYLE_DIR]
                [--epochs EPOCHS] [--batch_size BATCH_SIZE] [--resume RESUME] [--cuda]

optional arguments:
  -h, --help            show this help message and exit
  --content_dir CONTENT_DIR
                        content images folder path
  --style_dir STYLE_DIR
                        style images folder path
  --epochs EPOCHS       Number of epochs
  --batch_size BATCH_SIZE
                        Batch size
  --resume RESUME       Continue training from epoch
  --cuda                Use CUDA
```

### Test Image Style Transfer

Download the [decoder weights](https://drive.google.com/file/d/18JpLtMOapA-vwBz-LRomyTl24A9GwhTF/view?usp=sharing).

To test basic style transfer, run the script test.py:

```
$ python test.py --content_image $IMG --style_image $STYLE --decoder_weight $WEIGHT --cuda

usage: test.py [-h] [--content_image CONTENT_IMAGE] [--content_dir CONTENT_DIR]
               [--style_image STYLE_IMAGE] [--style_dir STYLE_DIR]
               [--decoder_weight DECODER_WEIGHT] [--alpha {Alpha Range}]
               [--cuda] [--grid_pth GRID_PTH]

optional arguments:
  -h, --help            show this help message and exit
  --content_image CONTENT_IMAGE
                        single content image file
  --content_dir CONTENT_DIR
                        content image directory; iterates over all images under this directory
  --style_image STYLE_IMAGE
                        single style image
  --style_dir STYLE_DIR
                        style image directory; iterates over all images under this directory
  --decoder_weight DECODER_WEIGHT
                        decoder weight file
  --alpha {Alpha Range}
                        Alpha [0.0, 1.0] controls the style transfer level
  --cuda                Use CUDA
  --grid_pth GRID_PTH
                        Specify a grid image path (default=None) to generate a grid image that contains all style transferred images
```

### Test Image Interpolation Style Transfer

To test style transfer interpolation, run the script test_interpolate.py:

```
$ python test_interpolate.py --content_image $IMG --style_image $STYLE --decoder_weight $WEIGHT --cuda

usage: test_interpolate.py [-h] [--content_image CONTENT_IMAGE] [--style_image STYLE_IMAGE]
                           [--decoder_weight DECODER_WEIGHT] [--alpha {Alpha Range}]
                           [--interpolation_weights INTERPOLATION_WEIGHTS]
                           [--cuda] [--grid_pth GRID_PTH]

optional arguments:
  -h, --help            show this help message and exit
  --content_image CONTENT_IMAGE
                        single content image file
  --style_image STYLE_IMAGE
                        multiple style image files, separated by commas
  --decoder_weight DECODER_WEIGHT
                        decoder weight file
  --alpha {Alpha Range}
                        Alpha [0.0, 1.0] controls the style transfer level
  --interpolation_weights INTERPOLATION_WEIGHTS
                        Interpolation weight of each style image, separated by commas. Performs interpolation style transfer once. Do not specify together with grid_pth.
  --cuda                Use CUDA
  --grid_pth GRID_PTH
                        Specify a grid image path (default=None) to perform interpolation style transfer multiple times with different built-in weights and generate a grid image that contains all style transferred images. Provide 4 style images. Do not specify together with interpolation_weights.
```

### Test Video Style Transfer

To test video style transfer, run the script test_video.py:

```
$ python test_video.py --content_video $VID --style_image $STYLE --decoder_weight $WEIGHT --cuda

usage: test_video.py [-h] [--content_video CONTENT_VIDEO] [--style_image STYLE_IMAGE]
                     [--decoder_weight DECODER_WEIGHT] [--alpha {Alpha Range}]
                     [--cuda]

optional arguments:
  -h, --help            show this help message and exit
  --content_video CONTENT_VIDEO
                        single content video file
  --style_image STYLE_IMAGE
                        single style image
  --decoder_weight DECODER_WEIGHT
                        decoder weight file
  --alpha {Alpha Range}
                        Alpha [0.0, 1.0] controls the style transfer level
  --cuda                Use CUDA
```


### References

- [1]: X. Huang and S. Belongie. "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization," in ICCV, 2017. [arxiv](https://arxiv.org/abs/1703.06868)
- [2]: [Original implementation in Torch](https://github.com/xunhuang1995/AdaIN-style)
- [3]: [Pretrained weights](https://github.com/naoto0804/pytorch-AdaIN)
test.py ADDED
@@ -0,0 +1,125 @@
import os
import argparse
import torch
import time
import numpy as np
from pathlib import Path
from AdaIN import AdaINNet
from PIL import Image
from torchvision.utils import save_image
from utils import adaptive_instance_normalization, grid_image, transform, Range
from glob import glob

parser = argparse.ArgumentParser()
parser.add_argument('--content_image', type=str, help='Content image file path')
parser.add_argument('--content_dir', type=str, help='Content image folder path')
parser.add_argument('--style_image', type=str, help='Style image file path')
parser.add_argument('--style_dir', type=str, help='Style image folder path')
parser.add_argument('--decoder_weight', type=str, required=True, help='Decoder weight file path')
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) to generate a grid image that contains all style transferred images')
args = parser.parse_args()
assert args.content_image or args.content_dir
assert args.style_image or args.style_dir
assert args.decoder_weight

device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')


def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
    """
    Given a content image and a style image, generate feature maps with the encoder,
    apply neural style transfer with adaptive instance normalization, and generate
    the output image with the decoder.

    Args:
        content_tensor (torch.FloatTensor): Content image
        style_tensor (torch.FloatTensor): Style image
        encoder: Encoder (vgg19) network
        decoder: Decoder network
        alpha (float, default=1.0): Weight of the style image features

    Return:
        output_tensor (torch.FloatTensor): Style transfer output image
    """

    content_enc = encoder(content_tensor)
    style_enc = encoder(style_tensor)

    transfer_enc = adaptive_instance_normalization(content_enc, style_enc)

    mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
    return decoder(mix_enc)


def main():
    # Read content images and style images
    if args.content_image:
        content_pths = [Path(args.content_image)]
    else:
        content_pths = [Path(f) for f in glob(args.content_dir + '/*')]

    if args.style_image:
        style_pths = [Path(args.style_image)]
    else:
        style_pths = [Path(f) for f in glob(args.style_dir + '/*')]

    out_dir = './results/'
    os.makedirs(out_dir, exist_ok=True)

    # Load AdaIN model
    vgg = torch.load('vgg_normalized.pth')
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(args.decoder_weight))
    model.eval()

    # Prepare image transform
    t = transform(512)

    # Prepare grid image; the top-left corner cell stays empty
    if args.grid_pth:
        imgs = [np.zeros((1, 1))]
        for style_pth in style_pths:
            imgs.append(Image.open(style_pth))

    # Timer
    times = []

    for content_pth in content_pths:
        content_img = Image.open(content_pth)
        content_tensor = t(content_img).unsqueeze(0).to(device)

        if args.grid_pth:
            imgs.append(content_img)

        for style_pth in style_pths:
            style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)

            tic = time.perf_counter()  # Start time
            with torch.no_grad():
                out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, args.alpha).cpu()
            toc = time.perf_counter()  # End time

            print("Content: " + content_pth.stem + ". Style: "
                + style_pth.stem + '. Style transfer time: %.4f seconds' % (toc - tic))
            times.append(toc - tic)

            out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha) + content_pth.suffix
            save_image(out_tensor, out_pth)

            if args.grid_pth:
                imgs.append(Image.open(out_pth))

    avg = sum(times) / len(times)
    print("Average style transfer time: %.4f seconds" % avg)

    if args.grid_pth:
        print("Generating grid image")
        grid_image(len(content_pths) + 1, len(style_pths) + 1, imgs, save_pth=args.grid_pth)
        print("Finished")


if __name__ == '__main__':
    main()
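Because test.py parses command-line arguments at import time, it is awkward to call from other code; the same inference steps can be reproduced directly with the model and utils. A sketch follows — the image and weight file names are placeholders:

```python
import torch
from PIL import Image
from torchvision.utils import save_image
from AdaIN import AdaINNet
from utils import adaptive_instance_normalization, transform

model = AdaINNet(torch.load('vgg_normalized.pth'))
model.decoder.load_state_dict(torch.load('decoder.pth'))  # placeholder weight file
model.eval()

t = transform(512)
content = t(Image.open('content.jpg')).unsqueeze(0)  # placeholder paths
style = t(Image.open('style.jpg')).unsqueeze(0)

alpha = 0.8
with torch.no_grad():
    content_enc = model.encoder(content)
    style_enc = model.encoder(style)
    transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
    out = model.decoder(alpha * transfer_enc + (1 - alpha) * content_enc)
save_image(out, 'out.jpg')
```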
test_interpolate.py ADDED
@@ -0,0 +1,126 @@
import os
import argparse
import torch
from pathlib import Path
from AdaIN import AdaINNet
from PIL import Image
from torchvision.utils import save_image
from utils import adaptive_instance_normalization, transform, Range, grid_image

parser = argparse.ArgumentParser()
parser.add_argument('--content_image', type=str, required=True, help='Content image file path')
parser.add_argument('--style_image', type=str, required=True, help='Multiple style image file paths, separated by commas')
parser.add_argument('--decoder_weight', type=str, required=True, help='Decoder weight file path')
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
parser.add_argument('--interpolation_weights', type=str, help='Interpolation weight of each style image, separated by commas')
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) to generate a grid image that contains all style transferred images. Grid mode requires 4 style images.')
args = parser.parse_args()
assert args.interpolation_weights or args.grid_pth

device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')


def interpolate_style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0, interpolation_weights=None):
    """
    Given a content image and multiple style images, generate feature maps with the
    encoder, apply neural style transfer with adaptive instance normalization,
    interpolate the style features with the interpolation weights, and generate the
    output image with the decoder.

    Args:
        content_tensor (torch.FloatTensor): Content image
        style_tensor (torch.FloatTensor): Batch of style images
        encoder: Encoder (vgg19) network
        decoder: Decoder network
        alpha (float, default=1.0): Weight of the style image features
        interpolation_weights (list): Weight of each style image

    Return:
        output_tensor (torch.FloatTensor): Interpolation style transfer output image
    """

    content_enc = encoder(content_tensor)
    style_enc = encoder(style_tensor)

    transfer_enc = torch.zeros_like(content_enc).to(device)
    full_enc = adaptive_instance_normalization(content_enc, style_enc)
    for i, w in enumerate(interpolation_weights):
        transfer_enc += w * full_enc[i]

    mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
    return decoder(mix_enc)


def main():
    # Read content and style images
    content_pths = [Path(args.content_image)]
    style_pths = [Path(pth) for pth in args.style_image.split(',')]

    inter_weights = []
    # In grid mode, use 4 style images and a 5x5 grid of interpolation weights
    if args.grid_pth:
        assert len(style_pths) == 4
        inter_weights = [[min(4 - a, 4 - b) / 4, min(4 - a, b) / 4, min(a, 4 - b) / 4, min(a, b) / 4]
                         for a in range(5) for b in range(5)]
    # Otherwise, use the user-supplied interpolation weights, normalized to sum to 1
    else:
        inter_weight = [float(i) for i in args.interpolation_weights.split(',')]
        inter_weight = [i / sum(inter_weight) for i in inter_weight]
        inter_weights.append(inter_weight)

    out_dir = './results_interpolate/'
    os.makedirs(out_dir, exist_ok=True)

    # Load AdaIN model
    vgg = torch.load('vgg_normalized.pth')
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(args.decoder_weight))
    model.eval()

    # Prepare image transform
    t = transform(512)

    imgs = []

    for content_pth in content_pths:
        content_tensor = t(Image.open(content_pth)).unsqueeze(0).to(device)

        style_tensor = []
        for style_pth in style_pths:
            img = Image.open(style_pth)
            style_tensor.append(transform([512, 512])(img))  # Resize style images to the same size
        style_tensor = torch.stack(style_tensor, dim=0).to(device)

        for inter_weight in inter_weights:
            with torch.no_grad():
                out_tensor = interpolate_style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, args.alpha, inter_weight).cpu()

            print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weights: " + str(inter_weight))

            out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight) + content_pth.suffix
            save_image(out_tensor, out_pth)

            if args.grid_pth:
                imgs.append(Image.open(out_pth))

    if args.grid_pth:
        print("Generating grid image")
        grid_image(5, 5, imgs, save_pth=args.grid_pth)
        print("Finished")


if __name__ == '__main__':
    main()
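The built-in 5x5 grid sweeps mixtures of the four styles so that each corner cell is a single pure style. Note that, unlike the user-supplied path above, these built-in weights are not normalized to sum to 1. A small check of the comprehension:

```python
inter_weights = [[min(4 - a, 4 - b) / 4, min(4 - a, b) / 4, min(a, 4 - b) / 4, min(a, b) / 4]
                 for a in range(5) for b in range(5)]

print(inter_weights[0])        # [1.0, 0.0, 0.0, 0.0] -> pure style 1 at the top-left corner
print(inter_weights[24])       # [0.0, 0.0, 0.0, 1.0] -> pure style 4 at the bottom-right corner
print(sum(inter_weights[12]))  # 2.0 -> the center cell's weights sum to 2, not 1
```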
test_style_transfer.py ADDED
@@ -0,0 +1,58 @@
import argparse
import os
import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image
from AdaIN import AdaINNet
from utils import adaptive_instance_normalization

class AlphaRange(object):
    def __init__(self, start, end):
        self.start = start
        self.end = end
    def __eq__(self, other):
        return self.start <= other <= self.end
    def __str__(self):
        return 'Alpha Range'

parser = argparse.ArgumentParser()
parser.add_argument('--input_image', type=str, help='test image')
parser.add_argument('--style_image', type=str, help='style image')
parser.add_argument('--weight', type=str, help='decoder weight file')
parser.add_argument('--alpha', type=float, default=1.0, choices=[AlphaRange(0.0, 1.0)], help='Level of style transfer, value between 0 and 1')
parser.add_argument('--cuda', action='store_true', help='Use GPU')


if __name__ == '__main__':
    opt = parser.parse_args()
    input_image = Image.open(opt.input_image)
    style_image = Image.open(opt.style_image)
    output_format = os.path.splitext(opt.input_image)[1]
    out_dir = './results/'
    os.makedirs(out_dir, exist_ok=True)
    with torch.no_grad():
        vgg_model = torch.load('vgg_normalized.pth')

        net = AdaINNet(vgg_model)
        net.decoder.load_state_dict(torch.load(opt.weight))
        net.eval()

        input_image = transforms.Resize(512)(input_image)
        style_image = transforms.Resize(512)(style_image)

        input_tensor = transforms.ToTensor()(input_image).unsqueeze(0)
        style_tensor = transforms.ToTensor()(style_image).unsqueeze(0)

        if torch.cuda.is_available() and opt.cuda:
            net.cuda()
            input_tensor = input_tensor.cuda()
            style_tensor = style_tensor.cuda()

        # AdaINNet.forward returns losses, so run the transfer manually:
        # encode, mix features with AdaIN, then decode
        content_enc = net.encoder(input_tensor)
        style_enc = net.encoder(style_tensor)
        transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
        out_tensor = net.decoder(opt.alpha * transfer_enc + (1 - opt.alpha) * content_enc)

        save_image(out_tensor, out_dir
            + os.path.splitext(os.path.basename(opt.input_image))[0]
            + '_style_' + os.path.splitext(os.path.basename(opt.style_image))[0]
            + output_format)
test_video.py ADDED
@@ -0,0 +1,107 @@
import os
import argparse
import torch
from pathlib import Path
from AdaIN import AdaINNet
from PIL import Image
from utils import transform, adaptive_instance_normalization, Range
import cv2
import imageio
import numpy as np
from tqdm import tqdm


parser = argparse.ArgumentParser()
parser.add_argument('--content_video', type=str, required=True, help='Content video file path')
parser.add_argument('--style_image', type=str, required=True, help='Style image file path')
parser.add_argument('--decoder_weight', type=str, required=True, help='Decoder weight file path')
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
args = parser.parse_args()

device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')


def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
    """
    Given a content image and a style image, generate feature maps with the encoder,
    apply neural style transfer with adaptive instance normalization, and generate
    the output image with the decoder.

    Args:
        content_tensor (torch.FloatTensor): Content image
        style_tensor (torch.FloatTensor): Style image
        encoder: Encoder (vgg19) network
        decoder: Decoder network
        alpha (float, default=1.0): Weight of the style image features

    Return:
        output_tensor (torch.FloatTensor): Style transfer output image
    """

    content_enc = encoder(content_tensor)
    style_enc = encoder(style_tensor)

    transfer_enc = adaptive_instance_normalization(content_enc, style_enc)

    mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
    return decoder(mix_enc)


def main():
    # Read video file
    content_video_pth = Path(args.content_video)
    content_video = cv2.VideoCapture(str(content_video_pth))
    style_image_pth = Path(args.style_image)
    style_image = Image.open(style_image_pth)

    fps = int(content_video.get(cv2.CAP_PROP_FPS))
    frame_count = int(content_video.get(cv2.CAP_PROP_FRAME_COUNT))
    video_height = int(content_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    video_width = int(content_video.get(cv2.CAP_PROP_FRAME_WIDTH))

    video_tqdm = tqdm(total=frame_count)

    out_dir = './results_video/'
    os.makedirs(out_dir, exist_ok=True)
    out_pth = Path(out_dir + content_video_pth.stem + '_style_'
        + style_image_pth.stem + content_video_pth.suffix)
    writer = imageio.get_writer(out_pth, mode='I', fps=fps)

    # Load AdaIN model
    vgg = torch.load('vgg_normalized.pth')
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(args.decoder_weight))
    model.eval()

    t = transform(512)

    style_tensor = t(style_image).unsqueeze(0).to(device)

    while content_video.isOpened():
        ret, content_image = content_video.read()
        if not ret:  # Failed to read a frame
            break

        # OpenCV reads frames as BGR; convert to RGB before building the tensor
        content_image = cv2.cvtColor(content_image, cv2.COLOR_BGR2RGB)
        content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)

        with torch.no_grad():
            out_tensor = style_transfer(content_tensor, style_tensor, model.encoder,
                model.decoder, args.alpha).cpu().numpy()

        # Convert the output frame back to the original size and (0, 255) RGB range
        out_tensor = np.squeeze(out_tensor, axis=0)
        out_tensor = np.transpose(out_tensor, (1, 2, 0))
        out_tensor = cv2.normalize(src=out_tensor, dst=None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        out_tensor = cv2.resize(out_tensor, (video_width, video_height), interpolation=cv2.INTER_CUBIC)

        writer.append_data(out_tensor)
        video_tqdm.update(1)

    content_video.release()
    writer.close()

    print("\nContent: " + content_video_pth.stem + ". Style: " + style_image_pth.stem + '\n')


if __name__ == '__main__':
    main()
train.py ADDED
@@ -0,0 +1,100 @@
import argparse
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
from utils import TrainSet
from AdaIN import AdaINNet
from tqdm import tqdm

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--content_dir', type=str, required=True, help='content images folder path')
    parser.add_argument('--style_dir', type=str, required=True, help='style images folder path')
    parser.add_argument('--epochs', type=int, default=1, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
    parser.add_argument('--resume', type=int, default=0, help='Continue training from epoch')
    parser.add_argument('--cuda', action='store_true', help='Use CUDA')
    args = parser.parse_args()

    device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')

    check_point_dir = './check_point/'
    weights_dir = './weights/'
    os.makedirs(check_point_dir, exist_ok=True)
    os.makedirs(weights_dir, exist_ok=True)

    train_set = TrainSet(args.content_dir, args.style_dir)
    train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True)

    vgg_model = torch.load('vgg_normalized.pth')
    model = AdaINNet(vgg_model).to(device)

    decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-6)
    total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
    total_num = 0
    losses = []
    iteration = 0

    if args.resume > 0:
        states = torch.load(check_point_dir + 'epoch_' + str(args.resume) + '.pth')
        model.decoder.load_state_dict(states['decoder'])
        decoder_optimizer.load_state_dict(states['decoder_optimizer'])
        losses = states['losses']
        iteration = states['iteration']

    for epoch in range(args.resume + 1, args.epochs + 1):
        print("Begin epoch: %i/%i" % (epoch, args.epochs))
        train_tqdm = tqdm(train_loader)
        train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))
        losses.append((iteration, total_loss, content_loss, style_loss))
        total_loss, content_loss, style_loss = 0.0, 0.0, 0.0

        for content_batch, style_batch in train_tqdm:
            content_batch = content_batch.to(device)
            style_batch = style_batch.to(device)

            loss_content, loss_style = model(content_batch, style_batch)
            loss_scaled = loss_content + 10 * loss_style

            decoder_optimizer.zero_grad()
            loss_scaled.backward()
            decoder_optimizer.step()

            # Accumulate running losses, weighted by the batch size
            total_loss += loss_scaled.item() * style_batch.size(0)
            content_loss += loss_content.item() * style_batch.size(0)
            style_loss += loss_style.item() * style_batch.size(0)
            total_num += style_batch.size(0)

            if iteration % 100 == 0 and iteration > 0:
                total_loss /= total_num
                content_loss /= total_num
                style_loss /= total_num
                print('')
                train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))

                losses.append((iteration, total_loss, content_loss, style_loss))

                total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
                total_num = 0

            # Average the remaining running losses at the epoch boundary
            if total_num > 0 and iteration % np.ceil(len(train_loader.dataset) / args.batch_size) == 0 and iteration > 0:
                total_loss /= total_num
                content_loss /= total_num
                style_loss /= total_num
                total_num = 0

            iteration += 1

        print('Finished epoch: %i/%i' % (epoch, args.epochs))

        states = {'decoder': model.decoder.state_dict(), 'decoder_optimizer': decoder_optimizer.state_dict(),
                  'losses': losses, 'iteration': iteration}
        torch.save(states, check_point_dir + 'epoch_%i.pth' % epoch)
        torch.save(model.decoder.state_dict(), weights_dir + 'decoder_epoch_%i.pth' % epoch)
        np.savetxt("losses", losses, fmt='%i,%.4f,%.4f,%.4f')

if __name__ == '__main__':
    main()
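Each checkpoint bundles everything --resume needs: the decoder weights, the optimizer state, the logged losses, and the iteration counter. A quick way to inspect one (the path assumes the default checkpoint directory and epoch 1):

```python
import torch

states = torch.load('./check_point/epoch_1.pth', map_location='cpu')
print(list(states.keys()))  # ['decoder', 'decoder_optimizer', 'losses', 'iteration']
```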
utils.py ADDED
@@ -0,0 +1,117 @@
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image, ImageFile
from pathlib import Path
from glob import glob

def adaptive_instance_normalization(x, y, eps=1e-5):
    """
    Adaptive Instance Normalization. Perform neural style transfer given a content
    image x and a style image y.

    Args:
        x (torch.FloatTensor): Content image tensor
        y (torch.FloatTensor): Style image tensor
        eps (float, default=1e-5): Small value to avoid zero division

    Return:
        output (torch.FloatTensor): AdaIN style transferred output
    """

    mu_x = torch.mean(x, dim=[2, 3])
    mu_y = torch.mean(y, dim=[2, 3])
    mu_x = mu_x.unsqueeze(-1).unsqueeze(-1)
    mu_y = mu_y.unsqueeze(-1).unsqueeze(-1)

    sigma_x = torch.std(x, dim=[2, 3])
    sigma_y = torch.std(y, dim=[2, 3])
    sigma_x = sigma_x.unsqueeze(-1).unsqueeze(-1) + eps
    sigma_y = sigma_y.unsqueeze(-1).unsqueeze(-1) + eps

    return (x - mu_x) / sigma_x * sigma_y + mu_y

def transform(size):
    """
    Image preprocessing transformation. Resize the image and convert it to a tensor.

    Args:
        size (int): Target image size

    Return:
        output (torchvision.transforms): Composition of torchvision.transforms steps
    """

    t = []
    t.append(transforms.Resize(size))
    t.append(transforms.ToTensor())
    t = transforms.Compose(t)
    return t

def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
    """
    Generate and save an image that contains a row x col grid of images.

    Args:
        row (int): number of rows
        col (int): number of columns
        images (list of PIL images): list of images
        height (int): height of each image (inches)
        width (int): width of each image (inches)
        save_pth (str): save file path
    """

    width = col * width
    height = row * height
    plt.figure(figsize=(width, height))
    for i, image in enumerate(images):
        plt.subplot(row, col, i + 1)
        plt.imshow(image)
        plt.axis('off')
    plt.subplots_adjust(wspace=0.01, hspace=0.01)
    plt.savefig(save_pth)


class TrainSet(Dataset):
    """
    Training dataset of content/style image pairs.
    """
    def __init__(self, content_dir, style_dir, crop_size=256):
        super().__init__()

        self.content_files = [Path(f) for f in glob(content_dir + '/*')]
        self.style_files = [Path(f) for f in glob(style_dir + '/*')]

        self.transform = transforms.Compose([
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        Image.MAX_IMAGE_PIXELS = None
        ImageFile.LOAD_TRUNCATED_IMAGES = True

    def __len__(self):
        return min(len(self.style_files), len(self.content_files))

    def __getitem__(self, index):
        content_img = Image.open(self.content_files[index]).convert('RGB')
        style_img = Image.open(self.style_files[index]).convert('RGB')

        content_sample = self.transform(content_img)
        style_sample = self.transform(style_img)

        return content_sample, style_sample

class Range(object):
    """
    Helper class for restricting the range of an input argument.
    """
    def __init__(self, start, end):
        self.start = start
        self.end = end
    def __eq__(self, other):
        return self.start <= other <= self.end
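AdaIN re-normalizes each content feature channel to carry the style's channel-wise statistics: AdaIN(x, y) = sigma(y) * (x - mu(x)) / sigma(x) + mu(y). A quick numerical check of this property on random features (the shapes are arbitrary stand-ins for relu4_1 activations):

```python
import torch
from utils import adaptive_instance_normalization

x = torch.rand(1, 512, 32, 32)  # stand-in content features
y = torch.rand(1, 512, 32, 32)  # stand-in style features
out = adaptive_instance_normalization(x, y)

# Per-channel statistics of the output should match the style's.
print(torch.allclose(out.mean(dim=[2, 3]), y.mean(dim=[2, 3]), atol=1e-3))  # True
print(torch.allclose(out.std(dim=[2, 3]), y.std(dim=[2, 3]), atol=1e-3))    # True
```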