subatomicseer committed
Commit 7999e5a · 1 Parent(s): fb09963

Initial Commit

.gitignore ADDED
@@ -0,0 +1,3 @@
+ .idea/
+ __pycache__/
+ *.pth
AdaIN.py ADDED
@@ -0,0 +1,73 @@
+ import torch
+ import torch.nn as nn
+ from Network import vgg19, decoder
+ from utils import adaptive_instance_normalization
+
+ class AdaINNet(nn.Module):
+     """
+     AdaIN Style Transfer Network
+
+     Args:
+         vgg_weight: pretrained vgg19 weights
+     """
+     def __init__(self, vgg_weight):
+         super().__init__()
+         self.encoder = vgg19(vgg_weight)
+
+         # Drop the encoder layers after relu4_1
+         self.encoder = nn.Sequential(*list(self.encoder.children())[:22])
+
+         # No optimization for the encoder
+         for parameter in self.encoder.parameters():
+             parameter.requires_grad = False
+
+         self.decoder = decoder()
+
+         self.mseloss = nn.MSELoss()
+
+     """
+     Computes the style loss between two feature maps.
+
+     Args:
+         x (torch.FloatTensor): content image features
+         y (torch.FloatTensor): style image features
+
+     Return:
+         Sum of the MSE between the channel-wise means of x and y and the MSE between their channel-wise standard deviations
+     """
+     def _style_loss(self, x, y):
+         return self.mseloss(torch.mean(x, dim=[2, 3]), torch.mean(y, dim=[2, 3])) + \
+             self.mseloss(torch.std(x, dim=[2, 3]), torch.std(y, dim=[2, 3]))
+
+     def forward(self, content, style, alpha=1.0):
+         # Generate image features
+         content_enc = self.encoder(content)
+         style_enc = self.encoder(style)
+
+         # Perform style transfer in feature space
+         transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
+
+         # Generate output image
+         out = self.decoder(transfer_enc)
+
+         # vgg19 layer relu1_1
+         style_relu11 = self.encoder[:3](style)
+         out_relu11 = self.encoder[:3](out)
+
+         # vgg19 layer relu2_1
+         style_relu21 = self.encoder[3:8](style_relu11)
+         out_relu21 = self.encoder[3:8](out_relu11)
+
+         # vgg19 layer relu3_1
+         style_relu31 = self.encoder[8:13](style_relu21)
+         out_relu31 = self.encoder[8:13](out_relu21)
+
+         # vgg19 layer relu4_1
+         out_enc = self.encoder[13:](out_relu31)
+
+         # Calculate losses
+         content_loss = self.mseloss(out_enc, transfer_enc)
+         style_loss = self._style_loss(out_relu11, style_relu11) + self._style_loss(out_relu21, style_relu21) + \
+             self._style_loss(out_relu31, style_relu31) + self._style_loss(out_enc, style_enc)
+
+         return content_loss, style_loss
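
A minimal sketch of how AdaINNet would typically be driven during training, assuming the pretrained VGG-19 state dict is available locally; the batch tensors, the weight path, and the style-loss weight below are placeholders, not part of this commit:

import torch
from AdaIN import AdaINNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

vgg_weights = torch.load('vgg.pth', map_location=device)   # assumed local path to the VGG-19 weights
model = AdaINNet(vgg_weights).to(device)

content = torch.rand(4, 3, 256, 256, device=device)        # placeholder content batch
style = torch.rand(4, 3, 256, 256, device=device)          # placeholder style batch

content_loss, style_loss = model(content, style)
loss = content_loss + 10.0 * style_loss                    # illustrative style-loss weight
loss.backward()                                             # only the decoder receives gradients; the encoder is frozen

Since the encoder parameters are frozen in __init__, an optimizer would only need to be attached to model.decoder.parameters().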
Network.py ADDED
@@ -0,0 +1,73 @@
+ import torch.nn as nn
+
+ vgg19_cfg = [3, 64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]
+ decoder_cfg = [512, 256, "U", 256, 256, 256, 128, "U", 128, 64, "U", 64, 3]
+
+ def vgg19(weights=None):
+     """
+     Build the vgg19 network. Load weights if they are given.
+
+     Args:
+         weights (dict): vgg19 pretrained weights
+
+     Return:
+         layers (nn.Sequential): vgg19 layers
+     """
+
+     modules = make_block(vgg19_cfg)
+     modules = [nn.Conv2d(3, 3, kernel_size=1)] + list(modules.children())
+     layers = nn.Sequential(*modules)
+
+     if weights:
+         layers.load_state_dict(weights)
+
+     return layers
+
+
+ def decoder(weights=None):
+     """
+     Build the decoder network. Load weights if they are given.
+
+     Args:
+         weights (dict): decoder pretrained weights
+
+     Return:
+         layers (nn.Sequential): decoder layers
+     """
+
+     modules = make_block(decoder_cfg)
+     layers = nn.Sequential(*list(modules.children())[:-1])  # no ReLU after the last layer
+
+     if weights:
+         layers.load_state_dict(weights)
+
+     return layers
+
+
+ def make_block(config):
+     """
+     Helper function for building blocks of convolutional layers.
+
+     Args:
+         config (list): List of layer configs:
+             "M" - Max pooling layer.
+             "U" - Upsampling layer.
+             i (int) - Convolutional layer (i filters) plus ReLU activation.
+     Return:
+         layers (nn.Sequential): block layers
+     """
+     layers = []
+     in_channels = config[0]
+
+     for c in config[1:]:
+         if c == "M":
+             layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
+         elif c == "U":
+             layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
+         else:
+             assert isinstance(c, int)
+             layers.append(nn.Conv2d(in_channels, c, kernel_size=3, padding=1))
+             layers.append(nn.ReLU(inplace=True))
+             in_channels = c
+
+     return nn.Sequential(*layers)
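
The slices that AdaIN.py takes over this encoder ([:3], [3:8], [8:13], [13:22]) depend on the module ordering produced above; a quick sketch for inspecting it (no pretrained weights are needed just to list the layers):

from Network import vgg19, decoder

encoder = vgg19()   # weights argument is optional; omitted here
for idx, layer in enumerate(list(encoder.children())[:22]):
    print(idx, layer)

# With the 1x1 Conv2d prepended at index 0, the ReLU outputs land at
# index 2 (relu1_1), 7 (relu2_1), 12 (relu3_1) and 21 (relu4_1),
# which is why AdaIN.py truncates the encoder with [:22].

print(decoder())    # roughly mirrors the encoder up to relu4_1, with Upsample in place of MaxPool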
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
- title: 2022 AdaIN Pytorch Demo
- emoji:
- colorFrom: blue
- colorTo: gray
+ title: AdaIN
+ emoji: 📚
+ colorFrom: red
+ colorTo: indigo
  sdk: streamlit
  sdk_version: 1.2.0
  app_file: app.py
app.py ADDED
@@ -0,0 +1,95 @@
+ import os
+ import streamlit as st
+ import gdown
+ from packaging.version import Version
+
+ from infer_func import convert
+
+ ROOT = os.path.dirname(os.path.abspath(__file__))
+
+ EXAMPLES = {
+     'content': {
+         'Brad Pitt': ROOT + '/examples/content/brad_pitt.jpg'
+     },
+     'style': {
+         'Flower of Life': ROOT + '/examples/style/flower_of_life.jpg'
+     }
+ }
+
+ VGG_WEIGHT_URL = 'https://drive.google.com/uc?id=1UcSl-Zn3byEmn15NIPXMf9zaGCKc2gfx'
+ DECODER_WEIGHT_URL = 'https://drive.google.com/uc?id=18JpLtMOapA-vwBz-LRomyTl24A9GwhTF'
+
+ VGG_WEIGHT_FILENAME = ROOT + '/vgg.pth'
+ DECODER_WEIGHT_FILENAME = ROOT + '/decoder.pth'
+
+
+ @st.cache
+ def download_models():
+     with st.spinner(text="Downloading VGG weights..."):
+         gdown.download(VGG_WEIGHT_URL, output=VGG_WEIGHT_FILENAME)
+     with st.spinner(text="Downloading Decoder weights..."):
+         gdown.download(DECODER_WEIGHT_URL, output=DECODER_WEIGHT_FILENAME)
+
+
+ def image_getter(image_kind):
+
+     image = None
+
+     options = ['Use Example Image', 'Upload Image']
+
+     if Version(st.__version__) >= Version('1.4.0'):
+         options.append('Open Camera')
+
+     option = st.selectbox(
+         'Choose Image',
+         options, key=image_kind)
+
+     if option == 'Use Example Image':
+         image_key = st.selectbox(
+             'Choose from examples',
+             EXAMPLES[image_kind], key=image_kind)
+         image = EXAMPLES[image_kind][image_key]
+
+     elif option == 'Upload Image':
+         image = st.file_uploader("Upload an image", type=['png', 'jpg', 'PNG', 'JPG', 'JPEG'], key=image_kind)
+     elif option == 'Open Camera':
+         image = st.camera_input('', key=image_kind)
+
+     return image
+
+
+ if __name__ == '__main__':
+
+     st.set_page_config(layout="wide")
+     st.header('Adaptive Instance Normalization demo based on '
+               '[2022-AdaIN-pytorch](https://github.com/media-comp/2022-AdaIN-pytorch)')
+
+     download_models()
+     # col1, col2, col3, col4 = st.columns((2, 2, 1, 3))
+     col1, col2, col3 = st.columns((3, 4, 4))
+     with col1:
+         st.subheader('Content Image')
+         content = image_getter('content')
+         st.subheader('Style Image')
+         style = image_getter('style')
+     with col2:
+         img1 = content if content is not None else 'examples/img.png'
+         img2 = style if style is not None else 'examples/img.png'
+         if img1 is not None:
+             st.image(img1, width=None, caption='Content Image')
+         if img2 is not None:
+             st.image(img2, width=None, caption='Style Image')
+
+     with col3:
+         color_control = st.checkbox('Preserve content image color')
+         alpha = st.slider('Strength of style transfer', 0.0, 1.0, 1.0, 0.01)
+         process = st.button('Stylize')
+
+     if content is not None and style is not None and process:
+         print(content, style)
+         with col3:
+             with st.spinner('Processing...'):
+                 output_image = convert(content, style, VGG_WEIGHT_FILENAME, DECODER_WEIGHT_FILENAME, alpha, color_control)
+
+             st.image(output_image, width=None, caption='Stylized Image')
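
app.py contains no Hugging Face-specific code and downloads its own weights with gdown, so the same demo should also run locally with the standard Streamlit entry point, e.g. `streamlit run app.py` from the repository root.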
examples/content/brad_pitt.jpg ADDED
examples/img.png ADDED
examples/style/flower_of_life.jpg ADDED
examples/style/sketch.jpg ADDED
infer_func.py ADDED
@@ -0,0 +1,60 @@
+ import torch
+ import torchvision.transforms
+ from PIL import Image
+
+ from AdaIN import AdaINNet
+ from utils import adaptive_instance_normalization, transform, linear_histogram_matching
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+ def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
+     """
+     Given a content image and a style image, generate feature maps with the encoder, apply
+     neural style transfer with adaptive instance normalization, and generate the output image
+     with the decoder.
+
+     Args:
+         content_tensor (torch.FloatTensor): Content image
+         style_tensor (torch.FloatTensor): Style image
+         encoder: Encoder (vgg19) network
+         decoder: Decoder network
+         alpha (float, default=1.0): Weight of the style image features
+
+     Return:
+         output_tensor (torch.FloatTensor): Style transfer output image
+     """
+
+     content_enc = encoder(content_tensor)
+     style_enc = encoder(style_tensor)
+
+     transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
+
+     mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
+     return decoder(mix_enc)
+
+
+ def convert(content_path, style_path, vgg_weights_path, decoder_weights_path, alpha, color_control):
+
+     vgg = torch.load(vgg_weights_path)
+     model = AdaINNet(vgg).to(device)
+     model.decoder.load_state_dict(torch.load(decoder_weights_path))
+     model.eval()
+
+     # Prepare the image transform
+     t = transform(512)
+
+     # Load images
+     content_img = Image.open(content_path)
+     content_tensor = t(content_img).unsqueeze(0).to(device)
+     style_tensor = t(Image.open(style_path)).unsqueeze(0).to(device)
+
+     if color_control:
+         style_tensor = linear_histogram_matching(content_tensor, style_tensor)
+
+     with torch.no_grad():
+         out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, alpha).cpu()
+
+     output_image = torchvision.transforms.ToPILImage()(out_tensor.squeeze(0))
+
+     return output_image
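
convert() is self-contained, so it can be exercised outside Streamlit. A minimal sketch, assuming the two weight files that app.py downloads are already sitting next to the script (the weight paths are assumptions):

from infer_func import convert

output = convert(
    'examples/content/brad_pitt.jpg',     # bundled example content image
    'examples/style/flower_of_life.jpg',  # bundled example style image
    'vgg.pth',                            # assumed local path to the VGG weights
    'decoder.pth',                        # assumed local path to the decoder weights
    alpha=1.0,
    color_control=False,
)
output.save('stylized.png')               # convert() returns a PIL image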
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch==1.10.1
+ torchvision==0.11.2
+ opencv-python==4.5.1.48
+ numpy==1.18.4
+ Pillow==8.4.0
+ tqdm==4.62.3
+ imageio==2.9.0
+ imageio-ffmpeg==0.4.6
+ matplotlib==3.3.2
+ gdown
+ packaging
utils.py ADDED
@@ -0,0 +1,144 @@
+ import os
+ from PIL import Image, ImageFile
+ import torch
+ from torch.utils.data import Dataset
+ import torchvision.transforms as transforms
+ import matplotlib.pyplot as plt
+ from pathlib import Path
+ from glob import glob
+
+ def adaptive_instance_normalization(x, y, eps=1e-5):
+     """
+     Adaptive Instance Normalization. Perform neural style transfer given content image x
+     and style image y.
+
+     Args:
+         x (torch.FloatTensor): Content image tensor
+         y (torch.FloatTensor): Style image tensor
+         eps (float, default=1e-5): Small value to avoid zero division
+
+     Return:
+         output (torch.FloatTensor): AdaIN style transferred output
+     """
+
+     mu_x = torch.mean(x, dim=[2, 3])
+     mu_y = torch.mean(y, dim=[2, 3])
+     mu_x = mu_x.unsqueeze(-1).unsqueeze(-1)
+     mu_y = mu_y.unsqueeze(-1).unsqueeze(-1)
+
+     sigma_x = torch.std(x, dim=[2, 3])
+     sigma_y = torch.std(y, dim=[2, 3])
+     sigma_x = sigma_x.unsqueeze(-1).unsqueeze(-1) + eps
+     sigma_y = sigma_y.unsqueeze(-1).unsqueeze(-1) + eps
+
+     return (x - mu_x) / sigma_x * sigma_y + mu_y
+
+ def transform(size):
+     """
+     Image preprocessing transformation. Resize the image and convert it to a tensor.
+
+     Args:
+         size (int): Resized image size
+
+     Return:
+         output (torchvision.transforms): Composition of torchvision.transforms steps
+     """
+
+     t = []
+     t.append(transforms.Resize(size))
+     t.append(transforms.ToTensor())
+     t = transforms.Compose(t)
+     return t
+
+ def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
+     """
+     Generate and save an image that contains a row x col grid of images.
+
+     Args:
+         row (int): number of rows
+         col (int): number of columns
+         images (list of PIL image): list of images
+         height (int): height of each image (inches)
+         width (int): width of each image (inches)
+         save_pth (str): save file path
+     """
+
+     width = col * width
+     height = row * height
+     plt.figure(figsize=(width, height))
+     for i, image in enumerate(images):
+         plt.subplot(row, col, i+1)
+         plt.imshow(image)
+         plt.axis('off')
+     plt.subplots_adjust(wspace=0.01, hspace=0.01)
+     plt.savefig(save_pth)
+
+
+ def linear_histogram_matching(content_tensor, style_tensor):
+     """
+     Given content_tensor and style_tensor, transform the style_tensor histogram to that of content_tensor.
+
+     Args:
+         content_tensor (torch.FloatTensor): Content image
+         style_tensor (torch.FloatTensor): Style image
+
+     Return:
+         style_tensor (torch.FloatTensor): histogram matched style image
+     """
+     # for each image in the batch
+     for b in range(len(content_tensor)):
+         std_ct = []
+         std_st = []
+         mean_ct = []
+         mean_st = []
+         # for each channel
+         for c in range(len(content_tensor[b])):
+             # use the standard deviation (not the variance) so the channel statistics are matched linearly
+             std_ct.append(torch.std(content_tensor[b][c], unbiased=False))
+             mean_ct.append(torch.mean(content_tensor[b][c]))
+             std_st.append(torch.std(style_tensor[b][c], unbiased=False))
+             mean_st.append(torch.mean(style_tensor[b][c]))
+             style_tensor[b][c] = (style_tensor[b][c] - mean_st[c]) * std_ct[c] / std_st[c] + mean_ct[c]
+     return style_tensor
+
+
+ class TrainSet(Dataset):
+     """
+     Build the training dataset.
+     """
+     def __init__(self, content_dir, style_dir, crop_size=256):
+         super().__init__()
+
+         self.content_files = [Path(f) for f in glob(content_dir + '/*')]
+         self.style_files = [Path(f) for f in glob(style_dir + '/*')]
+
+         self.transform = transforms.Compose([
+             transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
+             transforms.RandomCrop(crop_size),
+             transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+         ])
+
+         Image.MAX_IMAGE_PIXELS = None
+         ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+     def __len__(self):
+         return min(len(self.style_files), len(self.content_files))
+
+     def __getitem__(self, index):
+         content_img = Image.open(self.content_files[index]).convert('RGB')
+         style_img = Image.open(self.style_files[index]).convert('RGB')
+
+         content_sample = self.transform(content_img)
+         style_sample = self.transform(style_img)
+
+         return content_sample, style_sample
+
+ class Range(object):
+     """
+     Helper class for restricting the range of an input argument.
+     """
+     def __init__(self, start, end):
+         self.start = start
+         self.end = end
+     def __eq__(self, other):
+         return self.start <= other <= self.end
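
adaptive_instance_normalization implements AdaIN(x, y) = σ(y) · (x − μ(x)) / σ(x) + μ(y) per channel, so the output should take on the style features' channel statistics. A quick sketch to confirm this; the tensors below are random placeholders:

import torch
from utils import adaptive_instance_normalization

x = torch.rand(2, 512, 32, 32)   # placeholder content features
y = torch.rand(2, 512, 32, 32)   # placeholder style features

out = adaptive_instance_normalization(x, y)

# Per-channel mean/std of the output should be close to those of y
print(torch.allclose(out.mean(dim=[2, 3]), y.mean(dim=[2, 3]), atol=1e-3))
print(torch.allclose(out.std(dim=[2, 3]), y.std(dim=[2, 3]), atol=1e-3))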