Commit
·
6b396d0
1
Parent(s):
b2c93a6
Added Streamlit App
Browse filesCreated a Streamlit app for style transfer. Designed the app to be self contained, so that it can be directly uploaded to Huggingface space. Hence all necessary code for the app has been copied to this directory.
- Currently only single image style transfer supported (1 content and 1 style image at a time)
- Supports preserving color of content image (checkbox)
- Support changing alpha (slider)
- downloads the models (vgg and decoder) from the google drive link
- Updated README
- Uploaded the contents of streamlit_app directory to the Huggingface space [link](https://huggingface.co/spaces/subatomicseer/2022-AdaIN-pytorch-Demo)
- .gitignore +160 -4
- README.md +20 -0
- streamlit_app/AdaIN.py +73 -0
- streamlit_app/Network.py +73 -0
- streamlit_app/app.py +95 -0
- streamlit_app/infer_func.py +60 -0
- streamlit_app/requirements.txt +12 -0
- streamlit_app/utils.py +144 -0
.gitignore
CHANGED
@@ -1,7 +1,163 @@
|
|
1 |
-
#Ignore __pycache__
|
2 |
-
/__pycache__/
|
3 |
-
|
4 |
#Ignore results
|
5 |
/results*/
|
6 |
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
#Ignore results
|
2 |
/results*/
|
3 |
|
4 |
+
# Byte-compiled / optimized / DLL files
|
5 |
+
__pycache__/
|
6 |
+
*.py[cod]
|
7 |
+
*$py.class
|
8 |
+
|
9 |
+
# C extensions
|
10 |
+
*.so
|
11 |
+
|
12 |
+
# Distribution / packaging
|
13 |
+
.Python
|
14 |
+
build/
|
15 |
+
develop-eggs/
|
16 |
+
dist/
|
17 |
+
downloads/
|
18 |
+
eggs/
|
19 |
+
.eggs/
|
20 |
+
lib/
|
21 |
+
lib64/
|
22 |
+
parts/
|
23 |
+
sdist/
|
24 |
+
var/
|
25 |
+
wheels/
|
26 |
+
share/python-wheels/
|
27 |
+
*.egg-info/
|
28 |
+
.installed.cfg
|
29 |
+
*.egg
|
30 |
+
MANIFEST
|
31 |
+
|
32 |
+
# PyInstaller
|
33 |
+
# Usually these files are written by a python script from a template
|
34 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
35 |
+
*.manifest
|
36 |
+
*.spec
|
37 |
+
|
38 |
+
# Installer logs
|
39 |
+
pip-log.txt
|
40 |
+
pip-delete-this-directory.txt
|
41 |
+
|
42 |
+
# Unit test / coverage reports
|
43 |
+
htmlcov/
|
44 |
+
.tox/
|
45 |
+
.nox/
|
46 |
+
.coverage
|
47 |
+
.coverage.*
|
48 |
+
.cache
|
49 |
+
nosetests.xml
|
50 |
+
coverage.xml
|
51 |
+
*.cover
|
52 |
+
*.py,cover
|
53 |
+
.hypothesis/
|
54 |
+
.pytest_cache/
|
55 |
+
cover/
|
56 |
+
|
57 |
+
# Translations
|
58 |
+
*.mo
|
59 |
+
*.pot
|
60 |
+
|
61 |
+
# Django stuff:
|
62 |
+
*.log
|
63 |
+
local_settings.py
|
64 |
+
db.sqlite3
|
65 |
+
db.sqlite3-journal
|
66 |
+
|
67 |
+
# Flask stuff:
|
68 |
+
instance/
|
69 |
+
.webassets-cache
|
70 |
+
|
71 |
+
# Scrapy stuff:
|
72 |
+
.scrapy
|
73 |
+
|
74 |
+
# Sphinx documentation
|
75 |
+
docs/_build/
|
76 |
+
|
77 |
+
# PyBuilder
|
78 |
+
.pybuilder/
|
79 |
+
target/
|
80 |
+
|
81 |
+
# Jupyter Notebook
|
82 |
+
.ipynb_checkpoints
|
83 |
+
|
84 |
+
# IPython
|
85 |
+
profile_default/
|
86 |
+
ipython_config.py
|
87 |
+
|
88 |
+
# pyenv
|
89 |
+
# For a library or package, you might want to ignore these files since the code is
|
90 |
+
# intended to run in multiple environments; otherwise, check them in:
|
91 |
+
# .python-version
|
92 |
+
|
93 |
+
# pipenv
|
94 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
95 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
96 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
97 |
+
# install all needed dependencies.
|
98 |
+
#Pipfile.lock
|
99 |
+
|
100 |
+
# poetry
|
101 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
102 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
103 |
+
# commonly ignored for libraries.
|
104 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
105 |
+
#poetry.lock
|
106 |
+
|
107 |
+
# pdm
|
108 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
109 |
+
#pdm.lock
|
110 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
111 |
+
# in version control.
|
112 |
+
# https://pdm.fming.dev/#use-with-ide
|
113 |
+
.pdm.toml
|
114 |
+
|
115 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
116 |
+
__pypackages__/
|
117 |
+
|
118 |
+
# Celery stuff
|
119 |
+
celerybeat-schedule
|
120 |
+
celerybeat.pid
|
121 |
+
|
122 |
+
# SageMath parsed files
|
123 |
+
*.sage.py
|
124 |
+
|
125 |
+
# Environments
|
126 |
+
.env
|
127 |
+
.venv
|
128 |
+
env/
|
129 |
+
venv/
|
130 |
+
ENV/
|
131 |
+
env.bak/
|
132 |
+
venv.bak/
|
133 |
+
|
134 |
+
# Spyder project settings
|
135 |
+
.spyderproject
|
136 |
+
.spyproject
|
137 |
+
|
138 |
+
# Rope project settings
|
139 |
+
.ropeproject
|
140 |
+
|
141 |
+
# mkdocs documentation
|
142 |
+
/site
|
143 |
+
|
144 |
+
# mypy
|
145 |
+
.mypy_cache/
|
146 |
+
.dmypy.json
|
147 |
+
dmypy.json
|
148 |
+
|
149 |
+
# Pyre type checker
|
150 |
+
.pyre/
|
151 |
+
|
152 |
+
# pytype static type analyzer
|
153 |
+
.pytype/
|
154 |
+
|
155 |
+
# Cython debug symbols
|
156 |
+
cython_debug/
|
157 |
+
|
158 |
+
# PyCharm
|
159 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
160 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
161 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
162 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
163 |
+
.idea/
|
README.md
CHANGED
@@ -15,6 +15,26 @@ Install requirements by `$ pip install -r requirements.txt`
|
|
15 |
- tqdm
|
16 |
|
17 |
## Usage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
### Training
|
20 |
|
|
|
15 |
- tqdm
|
16 |
|
17 |
## Usage
|
18 |
+
### Demo Website
|
19 |
+
You can access a demo and perform style transfer at [2022-AdaIN-pytorch-Demo](https://huggingface.co/spaces/subatomicseer/2022-AdaIN-pytorch-Demo) Huggingface Space.
|
20 |
+
|
21 |
+
### Local Web App
|
22 |
+
If you would like to run the Streamlit app on your local system do the following steps:
|
23 |
+
|
24 |
+
Install requirements by:
|
25 |
+
|
26 |
+
`$ pip install -r streamlit_app/requirements.txt`
|
27 |
+
|
28 |
+
Following additional packages are for the web app:
|
29 |
+
- streamlit
|
30 |
+
- gdown
|
31 |
+
- packaging
|
32 |
+
|
33 |
+
Run the webapp by:
|
34 |
+
|
35 |
+
`$ streamlit run streamlit_app/app.py`
|
36 |
+
|
37 |
+
The above command will open a window in your default browser (if available), and will also display the local url, which you can navigate to, to use the app.
|
38 |
|
39 |
### Training
|
40 |
|
streamlit_app/AdaIN.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from Network import vgg19, decoder
|
4 |
+
from utils import adaptive_instance_normalization
|
5 |
+
|
6 |
+
class AdaINNet(nn.Module):
|
7 |
+
"""
|
8 |
+
AdaIN Style Transfer Network
|
9 |
+
|
10 |
+
Args:
|
11 |
+
vgg_weight: pretrained vgg19 weight
|
12 |
+
"""
|
13 |
+
def __init__(self, vgg_weight):
|
14 |
+
super().__init__()
|
15 |
+
self.encoder = vgg19(vgg_weight)
|
16 |
+
|
17 |
+
# drop layers after 4_1
|
18 |
+
self.encoder = nn.Sequential(*list(self.encoder.children())[:22])
|
19 |
+
|
20 |
+
# No optimization for encoder
|
21 |
+
for parameter in self.encoder.parameters():
|
22 |
+
parameter.requires_grad = False
|
23 |
+
|
24 |
+
self.decoder = decoder()
|
25 |
+
|
26 |
+
self.mseloss = nn.MSELoss()
|
27 |
+
|
28 |
+
"""
|
29 |
+
Computes style loss of two images
|
30 |
+
|
31 |
+
Args:
|
32 |
+
x (torch.FloatTensor): content image tensor
|
33 |
+
y (torch.FloatTensor): style image tensor
|
34 |
+
|
35 |
+
Return:
|
36 |
+
Mean Squared Error between x.mean, y.mean and MSE between x.std, y.std
|
37 |
+
"""
|
38 |
+
def _style_loss(self, x, y):
|
39 |
+
return self.mseloss(torch.mean(x, dim=[2, 3]), torch.mean(y, dim=[2, 3])) + \
|
40 |
+
self.mseloss(torch.std(x, dim=[2, 3]), torch.std(y, dim=[2, 3]))
|
41 |
+
|
42 |
+
def forward(self, content, style, alpha=1.0):
|
43 |
+
# Generate image features
|
44 |
+
content_enc = self.encoder(content)
|
45 |
+
style_enc = self.encoder(style)
|
46 |
+
|
47 |
+
# Perform style transfer on feature space
|
48 |
+
transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
|
49 |
+
|
50 |
+
# Generate outptu image
|
51 |
+
out = self.decoder(transfer_enc)
|
52 |
+
|
53 |
+
# vgg19 layer relu1_1
|
54 |
+
style_relu11 = self.encoder[:3](style)
|
55 |
+
out_relu11 = self.encoder[:3](out)
|
56 |
+
|
57 |
+
# vgg19 layer relu2_1
|
58 |
+
style_relu21 = self.encoder[3:8](style_relu11)
|
59 |
+
out_relu21 = self.encoder[3:8](out_relu11)
|
60 |
+
|
61 |
+
# vgg19 layer relu3_1
|
62 |
+
style_relu31 = self.encoder[8:13](style_relu21)
|
63 |
+
out_relu31 = self.encoder[8:13](out_relu21)
|
64 |
+
|
65 |
+
# vgg19 layer relu4_1
|
66 |
+
out_enc = self.encoder[13:](out_relu31)
|
67 |
+
|
68 |
+
# Calculate loss
|
69 |
+
content_loss = self.mseloss(out_enc, transfer_enc)
|
70 |
+
style_loss = self._style_loss(out_relu11, style_relu11) + self._style_loss(out_relu21, style_relu21) + \
|
71 |
+
self._style_loss(out_relu31, style_relu31) + self._style_loss(out_enc, style_enc)
|
72 |
+
|
73 |
+
return content_loss, style_loss
|
streamlit_app/Network.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
vgg19_cfg = [3, 64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]
|
4 |
+
decoder_cfg = [512, 256, "U", 256, 256, 256, 128, "U", 128, 64, 'U', 64, 3]
|
5 |
+
|
6 |
+
def vgg19(weights=None):
|
7 |
+
"""
|
8 |
+
Build vgg19 network. Load weights if weights are given.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
weights (dict): vgg19 pretrained weights
|
12 |
+
|
13 |
+
Return:
|
14 |
+
layers (nn.Sequential): vgg19 layers
|
15 |
+
"""
|
16 |
+
|
17 |
+
modules = make_block(vgg19_cfg)
|
18 |
+
modules = [nn.Conv2d(3, 3, kernel_size=1)] + list(modules.children())
|
19 |
+
layers = nn.Sequential(*modules)
|
20 |
+
|
21 |
+
if weights:
|
22 |
+
layers.load_state_dict(weights)
|
23 |
+
|
24 |
+
return layers
|
25 |
+
|
26 |
+
|
27 |
+
def decoder(weights=None):
|
28 |
+
"""
|
29 |
+
Build decoder network. Load weights if weights are given.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
weights (dict): decoder pretrained weights
|
33 |
+
|
34 |
+
Return:
|
35 |
+
layers (nn.Sequential): decoder layers
|
36 |
+
"""
|
37 |
+
|
38 |
+
modules = make_block(decoder_cfg)
|
39 |
+
layers = nn.Sequential(*list(modules.children())[:-1]) # no relu at the last layer
|
40 |
+
|
41 |
+
if weights:
|
42 |
+
layers.load_state_dict(weights)
|
43 |
+
|
44 |
+
return layers
|
45 |
+
|
46 |
+
|
47 |
+
def make_block(config):
|
48 |
+
"""
|
49 |
+
Helper function for building blocks of convolutional layers.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
config (list): List of layer configs. "M"
|
53 |
+
"M" - Max pooling layer.
|
54 |
+
"U" - Upsampling layer.
|
55 |
+
i (int) - Convolutional layer (i filters) plus ReLU activation.
|
56 |
+
Return:
|
57 |
+
layers (nn.Sequential): block layers
|
58 |
+
"""
|
59 |
+
layers = []
|
60 |
+
in_channels = config[0]
|
61 |
+
|
62 |
+
for c in config[1:]:
|
63 |
+
if c == "M":
|
64 |
+
layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
|
65 |
+
elif c == "U":
|
66 |
+
layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
|
67 |
+
else:
|
68 |
+
assert(isinstance(c, int))
|
69 |
+
layers.append(nn.Conv2d(in_channels, c, kernel_size=3, padding=1))
|
70 |
+
layers.append(nn.ReLU(inplace=True))
|
71 |
+
in_channels = c
|
72 |
+
|
73 |
+
return nn.Sequential(*layers)
|
streamlit_app/app.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import streamlit as st
|
3 |
+
import gdown
|
4 |
+
from packaging.version import Version
|
5 |
+
|
6 |
+
from infer_func import convert
|
7 |
+
|
8 |
+
ROOT = os.path.dirname(os.path.abspath(__file__))
|
9 |
+
|
10 |
+
EXAMPLES = {
|
11 |
+
'content': {
|
12 |
+
'Brad Pitt': ROOT + '/examples/content/brad_pitt.jpg'
|
13 |
+
},
|
14 |
+
'style': {
|
15 |
+
'Flower of Life': ROOT + '/examples/style/flower_of_life.jpg'
|
16 |
+
}
|
17 |
+
}
|
18 |
+
|
19 |
+
VGG_WEIGHT_URL = 'https://drive.google.com/uc?id=1UcSl-Zn3byEmn15NIPXMf9zaGCKc2gfx'
|
20 |
+
DECODER_WEIGHT_URL = 'https://drive.google.com/uc?id=18JpLtMOapA-vwBz-LRomyTl24A9GwhTF'
|
21 |
+
|
22 |
+
VGG_WEIGHT_FILENAME = ROOT + '/vgg.pth'
|
23 |
+
DECODER_WEIGHT_FILENAME = ROOT + '/decoder.pth'
|
24 |
+
|
25 |
+
|
26 |
+
@st.cache
|
27 |
+
def download_models():
|
28 |
+
with st.spinner(text="Downloading VGG weights..."):
|
29 |
+
gdown.download(VGG_WEIGHT_URL, output=VGG_WEIGHT_FILENAME)
|
30 |
+
with st.spinner(text="Downloading Decoder weights..."):
|
31 |
+
gdown.download(DECODER_WEIGHT_URL, output=DECODER_WEIGHT_FILENAME)
|
32 |
+
|
33 |
+
|
34 |
+
def image_getter(image_kind):
|
35 |
+
|
36 |
+
image = None
|
37 |
+
|
38 |
+
options = ['Use Example Image', 'Upload Image']
|
39 |
+
|
40 |
+
if Version(st.__version__) >= Version('1.4.0'):
|
41 |
+
options.append('Open Camera')
|
42 |
+
|
43 |
+
option = st.selectbox(
|
44 |
+
'Choose Image',
|
45 |
+
options, key=image_kind)
|
46 |
+
|
47 |
+
if option == 'Use Example Image':
|
48 |
+
image_key = st.selectbox(
|
49 |
+
'Choose from examples',
|
50 |
+
EXAMPLES[image_kind], key=image_kind)
|
51 |
+
image = EXAMPLES[image_kind][image_key]
|
52 |
+
|
53 |
+
elif option == 'Upload Image':
|
54 |
+
image = st.file_uploader("Upload an image", type=['png', 'jpg', 'PNG', 'JPG', 'JPEG'], key=image_kind)
|
55 |
+
elif option == 'Open Camera':
|
56 |
+
image = st.camera_input('', key=image_kind)
|
57 |
+
|
58 |
+
return image
|
59 |
+
|
60 |
+
|
61 |
+
if __name__ == '__main__':
|
62 |
+
|
63 |
+
st.set_page_config(layout="wide")
|
64 |
+
st.header('Adaptive Instance Normalization demo based on '
|
65 |
+
'[2022-AdaIN-pytorch](https://github.com/media-comp/2022-AdaIN-pytorch)')
|
66 |
+
|
67 |
+
download_models()
|
68 |
+
# col1, col2, col3, col4 = st.columns((2, 2, 1, 3))
|
69 |
+
col1, col2, col3 = st.columns((3, 4, 4))
|
70 |
+
with col1:
|
71 |
+
st.subheader('Content Image')
|
72 |
+
content = image_getter('content')
|
73 |
+
st.subheader('Style Image')
|
74 |
+
style = image_getter('style')
|
75 |
+
with col2:
|
76 |
+
img1 = content if content is not None else 'examples/img.png'
|
77 |
+
img2 = style if style is not None else 'examples/img.png'
|
78 |
+
if img1 is not None:
|
79 |
+
st.image(img1, width=None, caption='Content Image')
|
80 |
+
if img2 is not None:
|
81 |
+
st.image(img2, width=None, caption='Style Image')
|
82 |
+
|
83 |
+
with col3:
|
84 |
+
color_control = st.checkbox('Preserve content image color')
|
85 |
+
alpha = st.slider('Strength of style transfer', 0.0, 1.0, 1.0, 0.01)
|
86 |
+
process = st.button('Stylize')
|
87 |
+
|
88 |
+
if content is not None and style is not None and process:
|
89 |
+
print(content, style)
|
90 |
+
with col3:
|
91 |
+
with st.spinner('Processing...'):
|
92 |
+
output_image = convert(content, style, VGG_WEIGHT_FILENAME, DECODER_WEIGHT_FILENAME, alpha, color_control)
|
93 |
+
|
94 |
+
st.image(output_image, width=None, caption='Stylized Image')
|
95 |
+
|
streamlit_app/infer_func.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision.transforms
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
from AdaIN import AdaINNet
|
6 |
+
from utils import adaptive_instance_normalization, transform, linear_histogram_matching
|
7 |
+
|
8 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
9 |
+
|
10 |
+
|
11 |
+
def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
|
12 |
+
"""
|
13 |
+
Given content image and style image, generate feature maps with encoder, apply
|
14 |
+
neural style transfer with adaptive instance normalization, generate output image
|
15 |
+
with decoder
|
16 |
+
|
17 |
+
Args:
|
18 |
+
content_tensor (torch.FloatTensor): Content image
|
19 |
+
style_tensor (torch.FloatTensor): Style Image
|
20 |
+
encoder: Encoder (vgg19) network
|
21 |
+
decoder: Decoder network
|
22 |
+
alpha (float, default=1.0): Weight of style image feature
|
23 |
+
|
24 |
+
Return:
|
25 |
+
output_tensor (torch.FloatTensor): Style Transfer output image
|
26 |
+
"""
|
27 |
+
|
28 |
+
content_enc = encoder(content_tensor)
|
29 |
+
style_enc = encoder(style_tensor)
|
30 |
+
|
31 |
+
transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
|
32 |
+
|
33 |
+
mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
|
34 |
+
return decoder(mix_enc)
|
35 |
+
|
36 |
+
|
37 |
+
def convert(content_path, style_path, vgg_weights_path, decoder_weights_path, alpha, color_control):
|
38 |
+
|
39 |
+
vgg = torch.load(vgg_weights_path)
|
40 |
+
model = AdaINNet(vgg).to(device)
|
41 |
+
model.decoder.load_state_dict(torch.load(decoder_weights_path))
|
42 |
+
model.eval()
|
43 |
+
|
44 |
+
# Prepare image transform
|
45 |
+
t = transform(512)
|
46 |
+
|
47 |
+
# load images
|
48 |
+
content_img = Image.open(content_path)
|
49 |
+
content_tensor = t(content_img).unsqueeze(0).to(device)
|
50 |
+
style_tensor = t(Image.open(style_path)).unsqueeze(0).to(device)
|
51 |
+
|
52 |
+
if color_control:
|
53 |
+
style_tensor = linear_histogram_matching(content_tensor, style_tensor)
|
54 |
+
|
55 |
+
with torch.no_grad():
|
56 |
+
out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, alpha).cpu()
|
57 |
+
|
58 |
+
output_image = torchvision.transforms.ToPILImage()(out_tensor.squeeze(0))
|
59 |
+
|
60 |
+
return output_image
|
streamlit_app/requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.10.1
|
2 |
+
torchvision==0.11.2
|
3 |
+
opencv-python==4.5.1.48
|
4 |
+
numpy == 1.18.4
|
5 |
+
Pillow==8.4.0
|
6 |
+
tqdm==4.62.3
|
7 |
+
imageio==2.9.0
|
8 |
+
imageio-ffmpeg==0.4.6
|
9 |
+
matplotlib==3.3.2
|
10 |
+
gdown
|
11 |
+
packaging
|
12 |
+
streamlit
|
streamlit_app/utils.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from PIL import Image, ImageFile
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
import torchvision.transforms as transforms
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from pathlib import Path
|
8 |
+
from glob import glob
|
9 |
+
|
10 |
+
def adaptive_instance_normalization(x, y, eps=1e-5):
|
11 |
+
"""
|
12 |
+
Adaptive Instance Normalization. Perform neural style transfer given content image x
|
13 |
+
and style image y.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
x (torch.FloatTensor): Content image tensor
|
17 |
+
y (torch.FloatTensor): Style image tensor
|
18 |
+
eps (float, default=1e-5): Small value to avoid zero division
|
19 |
+
|
20 |
+
Return:
|
21 |
+
output (torch.FloatTensor): AdaIN style transferred output
|
22 |
+
"""
|
23 |
+
|
24 |
+
mu_x = torch.mean(x, dim=[2, 3])
|
25 |
+
mu_y = torch.mean(y, dim=[2, 3])
|
26 |
+
mu_x = mu_x.unsqueeze(-1).unsqueeze(-1)
|
27 |
+
mu_y = mu_y.unsqueeze(-1).unsqueeze(-1)
|
28 |
+
|
29 |
+
sigma_x = torch.std(x, dim=[2, 3])
|
30 |
+
sigma_y = torch.std(y, dim=[2, 3])
|
31 |
+
sigma_x = sigma_x.unsqueeze(-1).unsqueeze(-1) + eps
|
32 |
+
sigma_y = sigma_y.unsqueeze(-1).unsqueeze(-1) + eps
|
33 |
+
|
34 |
+
return (x - mu_x) / sigma_x * sigma_y + mu_y
|
35 |
+
|
36 |
+
def transform(size):
|
37 |
+
"""
|
38 |
+
Image preprocess transformation. Resize image and convert to tensor.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
size (int): Resize image size
|
42 |
+
|
43 |
+
Return:
|
44 |
+
output (torchvision.transforms): Composition of torchvision.transforms steps
|
45 |
+
"""
|
46 |
+
|
47 |
+
t = []
|
48 |
+
t.append(transforms.Resize(size))
|
49 |
+
t.append(transforms.ToTensor())
|
50 |
+
t = transforms.Compose(t)
|
51 |
+
return t
|
52 |
+
|
53 |
+
def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
|
54 |
+
"""
|
55 |
+
Generate and save an image that contains row x col grids of images.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
row (int): number of rows
|
59 |
+
col (int): number of columns
|
60 |
+
images (list of PIL image): list of images.
|
61 |
+
height (int) : height of each image (inch)
|
62 |
+
width (int) : width of eac image (inch)
|
63 |
+
save_pth (str): save file path
|
64 |
+
"""
|
65 |
+
|
66 |
+
width = col * width
|
67 |
+
height = row * height
|
68 |
+
plt.figure(figsize=(width, height))
|
69 |
+
for i, image in enumerate(images):
|
70 |
+
plt.subplot(row, col, i+1)
|
71 |
+
plt.imshow(image)
|
72 |
+
plt.axis('off')
|
73 |
+
plt.subplots_adjust(wspace=0.01, hspace=0.01)
|
74 |
+
plt.savefig(save_pth)
|
75 |
+
|
76 |
+
|
77 |
+
def linear_histogram_matching(content_tensor, style_tensor):
|
78 |
+
"""
|
79 |
+
Given content_tensor and style_tensor, transform style_tensor histogram to that of content_tensor.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
content_tensor (torch.FloatTensor): Content image
|
83 |
+
style_tensor (torch.FloatTensor): Style Image
|
84 |
+
|
85 |
+
Return:
|
86 |
+
style_tensor (torch.FloatTensor): histogram matched Style Image
|
87 |
+
"""
|
88 |
+
#for batch
|
89 |
+
for b in range(len(content_tensor)):
|
90 |
+
std_ct = []
|
91 |
+
std_st = []
|
92 |
+
mean_ct = []
|
93 |
+
mean_st = []
|
94 |
+
#for channel
|
95 |
+
for c in range(len(content_tensor[b])):
|
96 |
+
std_ct.append(torch.var(content_tensor[b][c],unbiased = False))
|
97 |
+
mean_ct.append(torch.mean(content_tensor[b][c]))
|
98 |
+
std_st.append(torch.var(style_tensor[b][c],unbiased = False))
|
99 |
+
mean_st.append(torch.mean(style_tensor[b][c]))
|
100 |
+
style_tensor[b][c] = (style_tensor[b][c] - mean_st[c]) * std_ct[c] / std_st[c] + mean_ct[c]
|
101 |
+
return style_tensor
|
102 |
+
|
103 |
+
|
104 |
+
class TrainSet(Dataset):
|
105 |
+
"""
|
106 |
+
Build Training dataset
|
107 |
+
"""
|
108 |
+
def __init__(self, content_dir, style_dir, crop_size = 256):
|
109 |
+
super().__init__()
|
110 |
+
|
111 |
+
self.content_files = [Path(f) for f in glob(content_dir+'/*')]
|
112 |
+
self.style_files = [Path(f) for f in glob(style_dir+'/*')]
|
113 |
+
|
114 |
+
self.transform = transforms.Compose([
|
115 |
+
transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
|
116 |
+
transforms.RandomCrop(crop_size),
|
117 |
+
transforms.ToTensor(),
|
118 |
+
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
|
119 |
+
])
|
120 |
+
|
121 |
+
Image.MAX_IMAGE_PIXELS = None
|
122 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
123 |
+
|
124 |
+
def __len__(self):
|
125 |
+
return min(len(self.style_files), len(self.content_files))
|
126 |
+
|
127 |
+
def __getitem__(self, index):
|
128 |
+
content_img = Image.open(self.content_files[index]).convert('RGB')
|
129 |
+
style_img = Image.open(self.style_files[index]).convert('RGB')
|
130 |
+
|
131 |
+
content_sample = self.transform(content_img)
|
132 |
+
style_sample = self.transform(style_img)
|
133 |
+
|
134 |
+
return content_sample, style_sample
|
135 |
+
|
136 |
+
class Range(object):
|
137 |
+
"""
|
138 |
+
Helper class for input argument range restriction
|
139 |
+
"""
|
140 |
+
def __init__(self, start, end):
|
141 |
+
self.start = start
|
142 |
+
self.end = end
|
143 |
+
def __eq__(self, other):
|
144 |
+
return self.start <= other <= self.end
|