subatomicseer committed
Commit 6b396d0 · 1 Parent(s): b2c93a6

Added Streamlit App


Created a Streamlit app for style transfer. The app is designed to be self-contained so that it can be uploaded directly to a Huggingface Space; hence all code the app needs has been copied into this directory.

- Currently only single-image style transfer is supported (1 content and 1 style image at a time)
- Supports preserving the color of the content image (checkbox)
- Supports changing alpha (slider)
- Downloads the models (vgg and decoder) from the Google Drive links
- Updated README

- Uploaded the contents of the streamlit_app directory to the Huggingface space [link](https://huggingface.co/spaces/subatomicseer/2022-AdaIN-pytorch-Demo)
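For context, the alpha slider controls the standard AdaIN feature blend implemented in streamlit_app/infer_func.py; a minimal sketch of that operation (the `blend` helper below is illustrative, built from the repo's own `adaptive_instance_normalization`):

```python
# Sketch of the blend behind the alpha slider (mirrors style_transfer in infer_func.py).
from utils import adaptive_instance_normalization  # streamlit_app/utils.py

def blend(content_enc, style_enc, decoder, alpha=1.0):
    # alpha = 1.0 -> full style transfer; alpha = 0.0 -> plain content reconstruction
    transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
    return decoder(alpha * transfer_enc + (1 - alpha) * content_enc)
```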

.gitignore CHANGED
@@ -1,7 +1,163 @@
- #Ignore __pycache__
- /__pycache__/
-
  #Ignore results
  /results*/
 
- .idea/
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ .idea/
README.md CHANGED
@@ -15,6 +15,26 @@ Install requirements by `$ pip install -r requirements.txt`
  - tqdm
 
  ## Usage
+ ### Demo Website
+ You can access a demo and perform style transfer at the [2022-AdaIN-pytorch-Demo](https://huggingface.co/spaces/subatomicseer/2022-AdaIN-pytorch-Demo) Huggingface Space.
+
+ ### Local Web App
+ If you would like to run the Streamlit app on your local system, follow these steps:
+
+ Install the requirements:
+
+ `$ pip install -r streamlit_app/requirements.txt`
+
+ The following additional packages are required for the web app:
+ - streamlit
+ - gdown
+ - packaging
+
+ Run the web app:
+
+ `$ streamlit run streamlit_app/app.py`
+
+ The above command will open a tab in your default browser (if available) and will also print a local URL that you can open to use the app.
 
  ### Training
streamlit_app/AdaIN.py ADDED
@@ -0,0 +1,73 @@
+ import torch
+ import torch.nn as nn
+ from Network import vgg19, decoder
+ from utils import adaptive_instance_normalization
+
+ class AdaINNet(nn.Module):
+     """
+     AdaIN Style Transfer Network
+
+     Args:
+         vgg_weight: pretrained vgg19 weights
+     """
+     def __init__(self, vgg_weight):
+         super().__init__()
+         self.encoder = vgg19(vgg_weight)
+
+         # drop layers after relu4_1
+         self.encoder = nn.Sequential(*list(self.encoder.children())[:22])
+
+         # No optimization for the encoder
+         for parameter in self.encoder.parameters():
+             parameter.requires_grad = False
+
+         self.decoder = decoder()
+
+         self.mseloss = nn.MSELoss()
+
+     def _style_loss(self, x, y):
+         """
+         Computes the style loss between two feature maps.
+
+         Args:
+             x (torch.FloatTensor): output image features
+             y (torch.FloatTensor): style image features
+
+         Return:
+             MSE between the channel-wise means of x and y, plus MSE between their channel-wise stds
+         """
+         return self.mseloss(torch.mean(x, dim=[2, 3]), torch.mean(y, dim=[2, 3])) + \
+             self.mseloss(torch.std(x, dim=[2, 3]), torch.std(y, dim=[2, 3]))
+
+     def forward(self, content, style, alpha=1.0):
+         # Generate image features
+         content_enc = self.encoder(content)
+         style_enc = self.encoder(style)
+
+         # Perform style transfer in feature space
+         transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
+
+         # Generate output image
+         out = self.decoder(transfer_enc)
+
+         # vgg19 layer relu1_1
+         style_relu11 = self.encoder[:3](style)
+         out_relu11 = self.encoder[:3](out)
+
+         # vgg19 layer relu2_1
+         style_relu21 = self.encoder[3:8](style_relu11)
+         out_relu21 = self.encoder[3:8](out_relu11)
+
+         # vgg19 layer relu3_1
+         style_relu31 = self.encoder[8:13](style_relu21)
+         out_relu31 = self.encoder[8:13](out_relu21)
+
+         # vgg19 layer relu4_1
+         out_enc = self.encoder[13:](out_relu31)
+
+         # Calculate losses
+         content_loss = self.mseloss(out_enc, transfer_enc)
+         style_loss = self._style_loss(out_relu11, style_relu11) + self._style_loss(out_relu21, style_relu21) + \
+             self._style_loss(out_relu31, style_relu31) + self._style_loss(out_enc, style_enc)
+
+         return content_loss, style_loss
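As a quick sanity check, the network can also be built without pretrained weights (since `vgg19(None)` simply skips `load_state_dict`) and run on random tensors; a minimal sketch, with illustrative shapes:

```python
import torch
from AdaIN import AdaINNet  # streamlit_app/AdaIN.py

# Build without pretrained weights; fine for a structural smoke test only.
model = AdaINNet(None).eval()

content = torch.randn(1, 3, 256, 256)
style = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    content_loss, style_loss = model(content, style)
print(content_loss.item(), style_loss.item())
```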
streamlit_app/Network.py ADDED
@@ -0,0 +1,73 @@
+ import torch.nn as nn
+
+ vgg19_cfg = [3, 64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]
+ decoder_cfg = [512, 256, "U", 256, 256, 256, 128, "U", 128, 64, "U", 64, 3]
+
+ def vgg19(weights=None):
+     """
+     Build the vgg19 network. Load weights if they are given.
+
+     Args:
+         weights (dict): vgg19 pretrained weights
+
+     Return:
+         layers (nn.Sequential): vgg19 layers
+     """
+
+     modules = make_block(vgg19_cfg)
+     modules = [nn.Conv2d(3, 3, kernel_size=1)] + list(modules.children())
+     layers = nn.Sequential(*modules)
+
+     if weights:
+         layers.load_state_dict(weights)
+
+     return layers
+
+
+ def decoder(weights=None):
+     """
+     Build the decoder network. Load weights if they are given.
+
+     Args:
+         weights (dict): decoder pretrained weights
+
+     Return:
+         layers (nn.Sequential): decoder layers
+     """
+
+     modules = make_block(decoder_cfg)
+     layers = nn.Sequential(*list(modules.children())[:-1])  # no ReLU after the last layer
+
+     if weights:
+         layers.load_state_dict(weights)
+
+     return layers
+
+
+ def make_block(config):
+     """
+     Helper function for building blocks of convolutional layers.
+
+     Args:
+         config (list): List of layer configs.
+             "M" - Max pooling layer.
+             "U" - Upsampling layer.
+             i (int) - Convolutional layer (i filters) plus ReLU activation.
+     Return:
+         layers (nn.Sequential): block layers
+     """
+     layers = []
+     in_channels = config[0]
+
+     for c in config[1:]:
+         if c == "M":
+             layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
+         elif c == "U":
+             layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
+         else:
+             assert isinstance(c, int)
+             layers.append(nn.Conv2d(in_channels, c, kernel_size=3, padding=1))
+             layers.append(nn.ReLU(inplace=True))
+             in_channels = c
+
+     return nn.Sequential(*layers)
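For orientation, the first 22 modules of this vgg19 (the 1x1 input conv plus the conv/ReLU/pool layers up to and including relu4_1) are exactly the slice AdaIN.py keeps as its encoder; a quick hedged way to inspect that slice:

```python
from Network import vgg19

net = vgg19()  # no weights needed just to inspect the structure
encoder = net[:22]  # same slice AdaIN.py keeps; the last module is relu4_1
for i, layer in enumerate(encoder):
    print(i, layer.__class__.__name__)
```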
streamlit_app/app.py ADDED
@@ -0,0 +1,95 @@
+ import os
+ import streamlit as st
+ import gdown
+ from packaging.version import Version
+
+ from infer_func import convert
+
+ ROOT = os.path.dirname(os.path.abspath(__file__))
+
+ EXAMPLES = {
+     'content': {
+         'Brad Pitt': ROOT + '/examples/content/brad_pitt.jpg'
+     },
+     'style': {
+         'Flower of Life': ROOT + '/examples/style/flower_of_life.jpg'
+     }
+ }
+
+ VGG_WEIGHT_URL = 'https://drive.google.com/uc?id=1UcSl-Zn3byEmn15NIPXMf9zaGCKc2gfx'
+ DECODER_WEIGHT_URL = 'https://drive.google.com/uc?id=18JpLtMOapA-vwBz-LRomyTl24A9GwhTF'
+
+ VGG_WEIGHT_FILENAME = ROOT + '/vgg.pth'
+ DECODER_WEIGHT_FILENAME = ROOT + '/decoder.pth'
+
+
+ @st.cache
+ def download_models():
+     with st.spinner(text="Downloading VGG weights..."):
+         gdown.download(VGG_WEIGHT_URL, output=VGG_WEIGHT_FILENAME)
+     with st.spinner(text="Downloading Decoder weights..."):
+         gdown.download(DECODER_WEIGHT_URL, output=DECODER_WEIGHT_FILENAME)
+
+
+ def image_getter(image_kind):
+
+     image = None
+
+     options = ['Use Example Image', 'Upload Image']
+
+     # st.camera_input was added in Streamlit 1.4.0
+     if Version(st.__version__) >= Version('1.4.0'):
+         options.append('Open Camera')
+
+     # Widget keys must be unique, so suffix the image kind per widget
+     option = st.selectbox(
+         'Choose Image',
+         options, key=image_kind + '_source')
+
+     if option == 'Use Example Image':
+         image_key = st.selectbox(
+             'Choose from examples',
+             EXAMPLES[image_kind], key=image_kind + '_example')
+         image = EXAMPLES[image_kind][image_key]
+
+     elif option == 'Upload Image':
+         image = st.file_uploader("Upload an image", type=['png', 'jpg', 'PNG', 'JPG', 'JPEG'], key=image_kind + '_upload')
+     elif option == 'Open Camera':
+         image = st.camera_input('', key=image_kind + '_camera')
+
+     return image
+
+
+ if __name__ == '__main__':
+
+     st.set_page_config(layout="wide")
+     st.header('Adaptive Instance Normalization demo based on '
+               '[2022-AdaIN-pytorch](https://github.com/media-comp/2022-AdaIN-pytorch)')
+
+     download_models()
+     col1, col2, col3 = st.columns((3, 4, 4))
+     with col1:
+         st.subheader('Content Image')
+         content = image_getter('content')
+         st.subheader('Style Image')
+         style = image_getter('style')
+     with col2:
+         # Fall back to a placeholder image until the user picks something
+         img1 = content if content is not None else 'examples/img.png'
+         img2 = style if style is not None else 'examples/img.png'
+         st.image(img1, width=None, caption='Content Image')
+         st.image(img2, width=None, caption='Style Image')
+
+     with col3:
+         color_control = st.checkbox('Preserve content image color')
+         alpha = st.slider('Strength of style transfer', 0.0, 1.0, 1.0, 0.01)
+         process = st.button('Stylize')
+
+     if content is not None and style is not None and process:
+         with col3:
+             with st.spinner('Processing...'):
+                 output_image = convert(content, style, VGG_WEIGHT_FILENAME, DECODER_WEIGHT_FILENAME, alpha, color_control)
+
+             st.image(output_image, width=None, caption='Stylized Image')
streamlit_app/infer_func.py ADDED
@@ -0,0 +1,60 @@
+ import torch
+ import torchvision.transforms
+ from PIL import Image
+
+ from AdaIN import AdaINNet
+ from utils import adaptive_instance_normalization, transform, linear_histogram_matching
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+ def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
+     """
+     Given a content image and a style image, generate feature maps with the encoder,
+     apply neural style transfer with adaptive instance normalization, and generate
+     the output image with the decoder.
+
+     Args:
+         content_tensor (torch.FloatTensor): Content image
+         style_tensor (torch.FloatTensor): Style image
+         encoder: Encoder (vgg19) network
+         decoder: Decoder network
+         alpha (float, default=1.0): Weight of the style image features
+
+     Return:
+         output_tensor (torch.FloatTensor): Style transfer output image
+     """
+
+     content_enc = encoder(content_tensor)
+     style_enc = encoder(style_tensor)
+
+     transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
+
+     mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
+     return decoder(mix_enc)
+
+
+ def convert(content_path, style_path, vgg_weights_path, decoder_weights_path, alpha, color_control):
+
+     # map_location keeps CPU-only machines working with GPU-saved weights
+     vgg = torch.load(vgg_weights_path, map_location=device)
+     model = AdaINNet(vgg).to(device)
+     model.decoder.load_state_dict(torch.load(decoder_weights_path, map_location=device))
+     model.eval()
+
+     # Prepare the image transform
+     t = transform(512)
+
+     # Load images
+     content_img = Image.open(content_path)
+     content_tensor = t(content_img).unsqueeze(0).to(device)
+     style_tensor = t(Image.open(style_path)).unsqueeze(0).to(device)
+
+     if color_control:
+         style_tensor = linear_histogram_matching(content_tensor, style_tensor)
+
+     with torch.no_grad():
+         out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, alpha).cpu()
+
+     # Clamp to the valid image range before converting to PIL
+     output_image = torchvision.transforms.ToPILImage()(out_tensor.squeeze(0).clamp(0, 1))
+
+     return output_image
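A minimal usage sketch for `convert` (the paths are illustrative placeholders; it assumes the weights have already been downloaded):

```python
from infer_func import convert

out = convert(
    content_path='examples/content/brad_pitt.jpg',
    style_path='examples/style/flower_of_life.jpg',
    vgg_weights_path='vgg.pth',
    decoder_weights_path='decoder.pth',
    alpha=0.8,           # 0.0 = content only, 1.0 = full stylization
    color_control=True,  # match style colors to the content image first
)
out.save('stylized.png')
```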
streamlit_app/requirements.txt ADDED
@@ -0,0 +1,12 @@
+ torch==1.10.1
+ torchvision==0.11.2
+ opencv-python==4.5.1.48
+ numpy==1.18.4
+ Pillow==8.4.0
+ tqdm==4.62.3
+ imageio==2.9.0
+ imageio-ffmpeg==0.4.6
+ matplotlib==3.3.2
+ gdown
+ packaging
+ streamlit
streamlit_app/utils.py ADDED
@@ -0,0 +1,144 @@
+ from PIL import Image, ImageFile
+ import torch
+ from torch.utils.data import Dataset
+ import torchvision.transforms as transforms
+ import matplotlib.pyplot as plt
+ from pathlib import Path
+ from glob import glob
+
+ def adaptive_instance_normalization(x, y, eps=1e-5):
+     """
+     Adaptive Instance Normalization. Perform neural style transfer given content image x
+     and style image y.
+
+     Args:
+         x (torch.FloatTensor): Content image tensor
+         y (torch.FloatTensor): Style image tensor
+         eps (float, default=1e-5): Small value to avoid division by zero
+
+     Return:
+         output (torch.FloatTensor): AdaIN style-transferred output
+     """
+
+     mu_x = torch.mean(x, dim=[2, 3])
+     mu_y = torch.mean(y, dim=[2, 3])
+     mu_x = mu_x.unsqueeze(-1).unsqueeze(-1)
+     mu_y = mu_y.unsqueeze(-1).unsqueeze(-1)
+
+     sigma_x = torch.std(x, dim=[2, 3])
+     sigma_y = torch.std(y, dim=[2, 3])
+     sigma_x = sigma_x.unsqueeze(-1).unsqueeze(-1) + eps
+     sigma_y = sigma_y.unsqueeze(-1).unsqueeze(-1) + eps
+
+     return (x - mu_x) / sigma_x * sigma_y + mu_y
+
+ def transform(size):
+     """
+     Image preprocessing transformation. Resize the image and convert it to a tensor.
+
+     Args:
+         size (int): Target size for the resize
+
+     Return:
+         output (torchvision.transforms): Composition of torchvision.transforms steps
+     """
+
+     t = []
+     t.append(transforms.Resize(size))
+     t.append(transforms.ToTensor())
+     t = transforms.Compose(t)
+     return t
+
+ def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
+     """
+     Generate and save an image that contains a row x col grid of images.
+
+     Args:
+         row (int): number of rows
+         col (int): number of columns
+         images (list of PIL images): list of images
+         height (int): height of each image (inches)
+         width (int): width of each image (inches)
+         save_pth (str): save file path
+     """
+
+     width = col * width
+     height = row * height
+     plt.figure(figsize=(width, height))
+     for i, image in enumerate(images):
+         plt.subplot(row, col, i+1)
+         plt.imshow(image)
+         plt.axis('off')
+     plt.subplots_adjust(wspace=0.01, hspace=0.01)
+     plt.savefig(save_pth)
+
+
+ def linear_histogram_matching(content_tensor, style_tensor):
+     """
+     Given content_tensor and style_tensor, match the channel-wise mean and standard
+     deviation of style_tensor to those of content_tensor.
+
+     Args:
+         content_tensor (torch.FloatTensor): Content image
+         style_tensor (torch.FloatTensor): Style image
+
+     Return:
+         style_tensor (torch.FloatTensor): histogram-matched style image
+     """
+     # for each image in the batch
+     for b in range(len(content_tensor)):
+         # for each channel
+         for c in range(len(content_tensor[b])):
+             # scale by the standard deviation (not the variance) for the linear match
+             std_ct = torch.std(content_tensor[b][c], unbiased=False)
+             mean_ct = torch.mean(content_tensor[b][c])
+             std_st = torch.std(style_tensor[b][c], unbiased=False)
+             mean_st = torch.mean(style_tensor[b][c])
+             style_tensor[b][c] = (style_tensor[b][c] - mean_st) * std_ct / std_st + mean_ct
+     return style_tensor
+
+
+ class TrainSet(Dataset):
+     """
+     Training dataset
+     """
+     def __init__(self, content_dir, style_dir, crop_size=256):
+         super().__init__()
+
+         self.content_files = [Path(f) for f in glob(content_dir + '/*')]
+         self.style_files = [Path(f) for f in glob(style_dir + '/*')]
+
+         self.transform = transforms.Compose([
+             transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
+             transforms.RandomCrop(crop_size),
+             transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+         ])
+
+         Image.MAX_IMAGE_PIXELS = None
+         ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+     def __len__(self):
+         return min(len(self.style_files), len(self.content_files))
+
+     def __getitem__(self, index):
+         content_img = Image.open(self.content_files[index]).convert('RGB')
+         style_img = Image.open(self.style_files[index]).convert('RGB')
+
+         content_sample = self.transform(content_img)
+         style_sample = self.transform(style_img)
+
+         return content_sample, style_sample
+
+ class Range(object):
+     """
+     Helper class for restricting the range of an input argument
+     """
+     def __init__(self, start, end):
+         self.start = start
+         self.end = end
+     def __eq__(self, other):
+         return self.start <= other <= self.end
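A quick hedged check that `adaptive_instance_normalization` moves the content features' channel statistics onto the style features' statistics (shapes are illustrative):

```python
import torch
from utils import adaptive_instance_normalization

x = torch.randn(1, 8, 16, 16)               # "content" features
y = 3.0 * torch.randn(1, 8, 16, 16) + 2.0   # "style" features with shifted stats

out = adaptive_instance_normalization(x, y)

# Per-channel mean/std of the output should now track y's, up to eps
print(torch.allclose(out.mean(dim=[2, 3]), y.mean(dim=[2, 3]), atol=1e-3))
print(torch.allclose(out.std(dim=[2, 3]), y.std(dim=[2, 3]), atol=1e-2))
```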