init_commit
- .gitignore +130 -0
- Dockerfile +32 -0
- app.py +130 -0
- requirements.txt +6 -0
- seg2art/checkpoints/multimodal_artworks/latest_net_G-fp16.pth +3 -0
- seg2art/inference_util.py +77 -0
- seg2art/model_util.py +158 -0
- seg2art/options/__init__.py +4 -0
- seg2art/options/base_options.py +184 -0
- seg2art/options/test_options.py +22 -0
- seg2art/sstan_models/__init__.py +44 -0
- seg2art/sstan_models/networks/__init__.py +63 -0
- seg2art/sstan_models/networks/architecture.py +231 -0
- seg2art/sstan_models/networks/base_network.py +59 -0
- seg2art/sstan_models/networks/dual_attention_module.py +51 -0
- seg2art/sstan_models/networks/generator.py +184 -0
- seg2art/sstan_models/networks/normalization.py +222 -0
- seg2art/sstan_models/networks/sync_batchnorm/__init__.py +13 -0
- seg2art/sstan_models/networks/sync_batchnorm/batchnorm.py +361 -0
- seg2art/sstan_models/networks/sync_batchnorm/batchnorm_reimpl.py +74 -0
- seg2art/sstan_models/networks/sync_batchnorm/comm.py +137 -0
- seg2art/sstan_models/networks/sync_batchnorm/replicate.py +94 -0
- seg2art/sstan_models/networks/sync_batchnorm/unittest.py +29 -0
- seg2art/sstan_models/pix2pix_model.py +285 -0
- static/index.js +256 -0
- static/init_code +0 -0
- static/style.css +36 -0
- templates/index.html +124 -0
- utils/boundaries_amp_52/artwork_ink_boundary/boundary.npy +3 -0
- utils/boundaries_amp_52/artwork_ink_boundary/log.txt +12 -0
- utils/boundaries_amp_52/artwork_monet_boundary/boundary.npy +3 -0
- utils/boundaries_amp_52/artwork_monet_boundary/log.txt +12 -0
- utils/boundaries_amp_52/artwork_vangogh_boundary/boundary.npy +3 -0
- utils/boundaries_amp_52/artwork_vangogh_boundary/log.txt +12 -0
- utils/boundaries_amp_52/artwork_water_boundary/boundary.npy +3 -0
- utils/boundaries_amp_52/artwork_water_boundary/log.txt +12 -0
- utils/umap_utils.py +99 -0
.gitignore
ADDED
@@ -0,0 +1,130 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+.vercel
Dockerfile
ADDED
@@ -0,0 +1,32 @@
+# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+# you will also find guides on how best to write your Dockerfile
+
+# Include base image
+FROM docker.io/pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
+
+# Define working directory
+WORKDIR /workspace/
+
+# Set timezone
+ENV TZ=Asia/Tokyo
+RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
+
+# Install dependencies
+RUN apt-get update && apt-get -y install libgl1 libglib2.0-0 vim
+RUN apt-get autoremove -y && apt-get clean -y
+
+# Add pretrained model
+ADD seg2art ./seg2art
+ADD static ./static
+ADD templates ./templates
+ADD utils ./utils
+
+# Add necessary files
+ADD app.py ./
+
+# pip install
+ADD requirements.txt ./
+RUN pip install -r requirements.txt
+
+# Run server
+CMD [ "python", "-u", "./app.py" ]
app.py
ADDED
@@ -0,0 +1,130 @@
+import os
+import sys
+import time
+import json
+import torch
+import base64
+from PIL import Image
+from io import BytesIO
+
+# set CUDA_MODULE_LOADING=LAZY to speed up the serverless function
+os.environ["CUDA_MODULE_LOADING"] = "LAZY"
+# set SAFETENSORS_FAST_GPU=1 to speed up the serverless function
+os.environ["SAFETENSORS_FAST_GPU"] = "1"
+
+sys.path.append(os.path.join(os.path.dirname(__file__), "seg2art"))
+from seg2art.sstan_models.pix2pix_model import Pix2PixModel
+from seg2art.options.test_options import TestOptions
+from seg2art.inference_util import get_artwork
+
+import uvicorn
+from fastapi import FastAPI, Form
+from fastapi.templating import Jinja2Templates
+from fastapi.responses import PlainTextResponse, HTMLResponse
+from fastapi.requests import Request
+from fastapi.staticfiles import StaticFiles
+
+
+# declare constants
+HOST = "0.0.0.0"
+PORT = 7860
+# FastAPI
+app = FastAPI(root_path=os.path.abspath(os.path.dirname(__file__)))
+app.mount("/static", StaticFiles(directory="static"), name="static")
+templates = Jinja2Templates(directory="templates")
+
+
+# initialize SEAN model.
+opt = TestOptions().parse()
+opt.status = "test"
+model = Pix2PixModel(opt)
+model = model.half() if torch.cuda.is_available() else model
+model.eval()
+
+
+from utils.umap_utils import get_code, load_boundries, modify_code
+
+boundaries = load_boundries()
+global current_codes
+current_codes = {}
+max_user_num = 5
+
+initial_code_path = os.path.join(os.path.dirname(__file__), "static/init_code")
+initial_code = torch.load(initial_code_path) if torch.cuda.is_available() else torch.load(initial_code_path, map_location=torch.device("cpu"))
+
+
+def EncodeImage(img_pil):
+    with BytesIO() as buffer:
+        img_pil.save(buffer, "jpeg")
+        image_data = base64.b64encode(buffer.getvalue())
+    return image_data
+
+
+def DecodeImage(img_pil):
+    img_pil = BytesIO(base64.urlsafe_b64decode(img_pil))
+    img_pil = Image.open(img_pil).convert("RGB")
+    return img_pil
+
+
+def process_input(body, random=False):
+    global current_codes
+    json_body = json.loads(body.decode("utf-8"))
+    user_id = json_body["user_id"]
+    start_time = time.time()
+
+    # save current code for different users
+    if user_id not in current_codes:
+        current_codes[user_id] = initial_code.clone()
+    if len(current_codes[user_id]) > max_user_num:
+        current_codes[user_id] = current_codes[user_id][-max_user_num:]
+
+    if random:
+        # randomize code
+        domain = json_body["model"]
+        current_codes[user_id] = get_code(domain, boundaries)
+
+    # get input
+    input_img = DecodeImage(json_body["img"])
+
+    try:
+        move_range = float(json_body["move_range"])
+    except:
+        move_range = 0
+
+    # set move range to 3 if random is True
+    move_range = 3 if random else move_range
+    # print("Input image was received")
+    # get selected style
+    domain = json_body["model"]
+    if move_range != 0:
+        modified_code = modify_code(current_codes[user_id], boundaries, domain, move_range)
+    else:
+        modified_code = current_codes[user_id].clone()
+
+    # inference
+    result = get_artwork(model, input_img, modified_code)
+    print("Time Cost: ", time.time() - start_time)
+    return EncodeImage(result)
+
+
+@app.get("/", response_class=HTMLResponse)
+def root(request: Request):
+    return templates.TemplateResponse("index.html", {"request": request})
+
+
+@app.post("/predict")
+async def predict(request: Request):
+    body = await request.body()
+    result = process_input(body, random=False)
+    return result
+
+
+@app.post("/predict_random")
+async def predict_random(request: Request):
+    body = await request.body()
+    result = process_input(body, random=True)
+    return result
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host=HOST, port=PORT, log_level="info")
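A minimal client sketch (not part of the commit) of how the /predict route above expects to be called. The JSON fields user_id, img, model, and move_range come from process_input(), and img must be URL-safe base64 because DecodeImage() uses base64.urlsafe_b64decode; the host/port follow HOST/PORT in app.py. The input file name, the style name "vangogh", and the use of the requests library are assumptions.

```python
import base64
import json

import requests  # assumed client-side dependency, not in requirements.txt

with open("segmentation.png", "rb") as f:              # hypothetical input drawing
    img_b64 = base64.urlsafe_b64encode(f.read()).decode("ascii")

payload = {
    "user_id": "demo-user",          # keys read by process_input() in app.py
    "img": img_b64,
    "model": "vangogh",              # assumed style-domain name (see utils/boundaries_amp_52)
    "move_range": 0,
}
resp = requests.post("http://localhost:7860/predict", data=json.dumps(payload))
# The route returns the base64 JPEG produced by EncodeImage(); decode it to get the artwork.
artwork_jpeg = base64.b64decode(resp.json())
```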
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+scikit-learn
+scikit-image
+torchvision>=0.7.0
+torch>=1.6.0
+fastapi
+uvicorn
seg2art/checkpoints/multimodal_artworks/latest_net_G-fp16.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11e5d324b59dce20e81cea0eed77b25c7b9f6b56ccb44970b67593b4287ddb4c
+size 418576205
seg2art/inference_util.py
ADDED
@@ -0,0 +1,77 @@
+"""
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
+import numpy as np
+from PIL import Image
+from torchvision import transforms
+
+# define constants
+image_size = 256
+
+# to label
+values = [12, 2, 6, 8, 1, 10, 3, 14, 11, 4, 5, 13, 9]
+values = np.array(values)
+
+# from color
+colors = [
+    (135, 206, 235),
+    (155, 118, 83),
+    (176, 212, 155),
+    (90, 188, 216),
+    (193, 190, 186),
+    (90, 77, 65),
+    (86, 125, 70),
+    (66, 105, 47),
+    (21, 119, 190),
+    (58, 46, 39),
+    (77, 65, 90),
+    (253, 218, 22),
+    (208, 204, 204),
+]
+colors = np.array(colors)
+
+
+def remap_label(arr):
+    # compare only first 1 channel to speed up
+    arr_r = arr[:, :, 0]
+
+    # remap color to label
+    for i in range(len(colors)):
+        arr_r[arr_r == colors[i][0]] = values[i]
+    # others to 15
+    arr_r[arr_r > 15] = 15
+    return arr_r
+
+
+preprocess = transforms.Compose(
+    [
+        transforms.Resize([image_size, image_size]),
+        transforms.ToTensor(),
+    ]
+)
+
+
+def image_loader(loader, label_inp):
+    image = Image.fromarray(label_inp).convert("RGB")
+    image = image.resize((image_size, image_size))
+    image = loader(image).float() * 255
+    image = image.clone().detach().requires_grad_(True)
+    image = image.unsqueeze(0)
+    return image
+
+
+def tensor2im(image_tensor):
+    image_numpy = image_tensor[0].detach().cpu().float().numpy()
+    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+    image_numpy = np.clip(image_numpy, 0, 255)
+    return Image.fromarray(image_numpy.astype(np.uint8))
+
+
+def get_artwork(model, data, code):
+    label_inp = remap_label(np.array(data))
+    label_inp = (image_loader(preprocess, label_inp)).detach().half()
+
+    image_out = model(label_inp, mode="inference", style_codes=code)
+    image_out = tensor2im(image_out)
+    return image_out
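A small sketch (not part of the commit) of what remap_label() above does: it matches only the red channel of each palette color, rewrites it to the corresponding class id, and clamps anything unmatched to 15. The toy array below uses the first palette entry (135, 206, 235), which maps to label 12.

```python
import numpy as np

values = np.array([12, 2, 6, 8, 1, 10, 3, 14, 11, 4, 5, 13, 9])
colors = np.array([(135, 206, 235), (155, 118, 83)])       # first two palette entries, for brevity

arr = np.full((2, 2, 3), (135, 206, 235), dtype=np.uint8)   # a uniform "sky"-colored patch
arr_r = arr[:, :, 0]                                        # red channel only, as in remap_label()
for i in range(len(colors)):
    arr_r[arr_r == colors[i][0]] = values[i]
arr_r[arr_r > 15] = 15                                      # everything unmatched becomes class 15
print(arr_r)                                                # [[12 12] [12 12]]
```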
seg2art/model_util.py
ADDED
@@ -0,0 +1,158 @@
+"""
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
+
+import re
+import importlib
+import torch
+from argparse import Namespace
+import numpy as np
+from PIL import Image
+import os
+
+
+# Converts a Tensor into a Numpy array
+# |imtype|: the desired type of the converted numpy array
+def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
+    if isinstance(image_tensor, list):
+        image_numpy = []
+        for i in range(len(image_tensor)):
+            image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
+        return image_numpy
+
+    if image_tensor.dim() == 4:
+        # transform each image in the batch
+        images_np = []
+        for b in range(image_tensor.size(0)):
+            one_image = image_tensor[b]
+            one_image_np = tensor2im(one_image)
+            images_np.append(one_image_np.reshape(1, *one_image_np.shape))
+        images_np = np.concatenate(images_np, axis=0)
+        return images_np
+
+    if image_tensor.dim() == 2:
+        image_tensor = image_tensor.unsqueeze(0)
+    image_numpy = image_tensor.detach().cpu().float().numpy()
+    if normalize:
+        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+    else:
+        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
+    image_numpy = np.clip(image_numpy, 0, 255)
+    if image_numpy.shape[2] == 1:
+        image_numpy = image_numpy[:, :, 0]
+    return image_numpy.astype(imtype)
+
+
+# Converts a one-hot tensor into a colorful label map
+def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
+    if label_tensor.dim() == 4:
+        # transform each image in the batch
+        images_np = []
+        for b in range(label_tensor.size(0)):
+            one_image = label_tensor[b]
+            one_image_np = tensor2label(one_image, n_label, imtype)
+            images_np.append(one_image_np.reshape(1, *one_image_np.shape))
+        images_np = np.concatenate(images_np, axis=0)
+        if tile:
+            images_tiled = tile_images(images_np)
+            return images_tiled
+        else:
+            images_np = images_np[0]
+            return images_np
+
+    if label_tensor.dim() == 1:
+        return np.zeros((64, 64, 3), dtype=np.uint8)
+    if n_label == 0:
+        return tensor2im(label_tensor, imtype)
+    label_tensor = label_tensor.cpu().float()
+    if label_tensor.size()[0] > 1:
+        label_tensor = label_tensor.max(0, keepdim=True)[1]
+    label_tensor = Colorize(n_label)(label_tensor)
+    label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
+    result = label_numpy.astype(imtype)
+    return result
+
+
+def save_image(image_numpy, image_path, create_dir=False):
+    if create_dir:
+        os.makedirs(os.path.dirname(image_path), exist_ok=True)
+    if len(image_numpy.shape) == 2:
+        image_numpy = np.expand_dims(image_numpy, axis=2)
+    if image_numpy.shape[2] == 1:
+        image_numpy = np.repeat(image_numpy, 3, 2)
+    image_pil = Image.fromarray(image_numpy)
+
+    # save to png
+    image_pil.save(image_path.replace('.jpg', '.png'))
+
+
+def mkdirs(paths):
+    if isinstance(paths, list) and not isinstance(paths, str):
+        for path in paths:
+            mkdir(path)
+    else:
+        mkdir(paths)
+
+
+def mkdir(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def atoi(text):
+    return int(text) if text.isdigit() else text
+
+
+def natural_keys(text):
+    '''
+    alist.sort(key=natural_keys) sorts in human order
+    http://nedbatchelder.com/blog/200712/human_sorting.html
+    (See Toothy's implementation in the comments)
+    '''
+    return [atoi(c) for c in re.split('(\d+)', text)]
+
+
+def natural_sort(items):
+    items.sort(key=natural_keys)
+
+
+def str2bool(v):
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+def find_class_in_module(target_cls_name, module):
+    target_cls_name = target_cls_name.replace('_', '').lower()
+    clslib = importlib.import_module(module)
+    cls = None
+    for name, clsobj in clslib.__dict__.items():
+        if name.lower() == target_cls_name:
+            cls = clsobj
+
+    if cls is None:
+        print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
+        exit(0)
+
+    return cls
+
+
+def save_network(net, label, epoch, opt):
+    save_filename = '%s_net_%s.pth' % (epoch, label)
+    save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
+    torch.save(net.cpu().state_dict(), save_path)
+    if len(opt.gpu_ids) and torch.cuda.is_available():
+        net.cuda()
+
+
+def load_network(net, label, epoch, opt):
+    save_filename = '%s_net_%s.pth' % (epoch, label)
+    save_dir = os.path.join(opt.checkpoints_dir, opt.name)
+    save_path = os.path.join(save_dir, save_filename)
+    weights = torch.load(save_path)
+    net.load_state_dict(weights, strict=False)#
+    return net
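A self-contained sketch (not part of the commit) of the human-order sort that natural_keys() above enables, per its own docstring; the checkpoint names are made up.

```python
import re

def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    return [atoi(c) for c in re.split(r"(\d+)", text)]

ckpts = ["10_net_G.pth", "2_net_G.pth", "latest_net_G.pth"]   # hypothetical file names
print(sorted(ckpts, key=natural_keys))
# ['2_net_G.pth', '10_net_G.pth', 'latest_net_G.pth'] -- numeric runs compare as numbers
```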
seg2art/options/__init__.py
ADDED
@@ -0,0 +1,4 @@
+"""
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
seg2art/options/base_options.py
ADDED
@@ -0,0 +1,184 @@
+"""
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
+
+import os
+import sys
+import torch
+import pickle
+import argparse
+import sstan_models
+import utils as util
+
+
+class BaseOptions():
+    def __init__(self):
+        self.initialized = False
+
+    def initialize(self, parser):
+        # experiment specifics
+        parser.add_argument('--name', type=str, default='multimodal_artworks', help='name of the experiment. It decides where to store samples and sstan_models')
+
+        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
+        parser.add_argument('--checkpoints_dir', type=str, default='./seg2art/checkpoints', help='sstan_models are saved here')
+        parser.add_argument('--model', type=str, default='pix2pix', help='which model to use')
+        parser.add_argument('--norm_G', type=str, default='spectralinstance', help='instance normalization or batch normalization')
+        parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization')
+        parser.add_argument('--norm_E', type=str, default='spectralinstance', help='instance normalization or batch normalization')
+        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
+
+        # input/output sizes
+        parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
+        parser.add_argument('--preprocess_mode', type=str, default='scale_width_and_crop', help='scaling and cropping of images at load time.', choices=("resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none"))
+        parser.add_argument('--load_size', type=int, default=512, help='Scale images to this size. The final image will be cropped to --crop_size.')
+        parser.add_argument('--crop_size', type=int, default=512, help='Crop to the width of crop_size (after initially scaling the images to load_size.)')
+        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='The ratio width/height. The final height of the load image will be crop_size/aspect_ratio')
+        parser.add_argument('--label_nc', type=int, default=16, help='# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dopntcare_label.')
+        parser.add_argument('--contain_dontcare_label', action='store_true', help='if the label map contains dontcare label (dontcare=255)')
+        parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
+
+        # for setting inputs
+        parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/')
+        parser.add_argument('--dataset_mode', type=str, default='custom')
+        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation')
+        parser.add_argument('--nThreads', default=0, type=int, help='# threads for loading data')
+        parser.add_argument('--max_dataset_size', type=int, default=sys.maxsize, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+        parser.add_argument('--load_from_opt_file', action='store_true', help='load the options from checkpoints and use that as default')
+        parser.add_argument('--cache_filelist_write', action='store_true', help='saves the current filelist into a text file, so that it loads faster')
+        parser.add_argument('--cache_filelist_read', action='store_true', help='reads from the file list cache')
+
+        # for displays
+        parser.add_argument('--display_winsize', type=int, default=400, help='display window size')
+
+        # for generator
+        parser.add_argument('--netG', type=str, default='spade', help='selects model to use for netG (pix2pixhd | spade)')
+        parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
+        parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')
+        parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
+        parser.add_argument('--z_dim', type=int, default=256,
+                            help="dimension of the latent z vector")
+
+        # for instance-wise features
+        parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input')
+        parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
+        parser.add_argument('--use_vae', action='store_true', help='enable training with an image encoder.')
+
+        self.initialized = True
+        return parser
+
+    def gather_options(self):
+        # initialize parser with basic options
+        if not self.initialized:
+            parser = argparse.ArgumentParser(
+                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+            parser = self.initialize(parser)
+
+        # get the basic options
+        opt, unknown = parser.parse_known_args()
+
+        # modify model-related parser options
+        model_name = opt.model
+        model_option_setter = sstan_models.get_option_setter(model_name)
+        parser = model_option_setter(parser, self.isTrain)
+
+
+        # # modify dataset-related parser options
+        # dataset_mode = opt.dataset_mode
+        # dataset_option_setter = data.get_option_setter(dataset_mode)
+        # parser = dataset_option_setter(parser, self.isTrain)
+
+        # opt, unknown = parser.parse_known_args()
+
+        # # if there is opt_file, load it.
+        # # The previous default options will be overwritten
+        # if opt.load_from_opt_file:
+        #     parser = self.update_options_from_file(parser, opt)
+
+        opt = parser.parse_args()
+
+        opt.contain_dontcare_label = False
+        opt.no_instance = True
+        opt.use_vae = False
+
+        self.parser = parser
+        return opt
+
+    def print_options(self, opt):
+        message = ''
+        message += '----------------- Options ---------------\n'
+        for k, v in sorted(vars(opt).items()):
+            comment = ''
+            default = self.parser.get_default(k)
+            if v != default:
+                comment = '\t[default: %s]' % str(default)
+            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+        message += '----------------- End -------------------'
+        print(message)
+
+    def option_file_path(self, opt, makedir=False):
+        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+        if makedir:
+            util.mkdirs(expr_dir)
+        file_name = os.path.join(expr_dir, 'opt')
+        return file_name
+
+    def save_options(self, opt):
+        file_name = self.option_file_path(opt, makedir=True)
+        with open(file_name + '.txt', 'wt') as opt_file:
+            for k, v in sorted(vars(opt).items()):
+                comment = ''
+                default = self.parser.get_default(k)
+                if v != default:
+                    comment = '\t[default: %s]' % str(default)
+                opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
+
+        with open(file_name + '.pkl', 'wb') as opt_file:
+            pickle.dump(opt, opt_file)
+
+    def update_options_from_file(self, parser, opt):
+        new_opt = self.load_options(opt)
+        for k, v in sorted(vars(opt).items()):
+            if hasattr(new_opt, k) and v != getattr(new_opt, k):
+                new_val = getattr(new_opt, k)
+                parser.set_defaults(**{k: new_val})
+        return parser
+
+    def load_options(self, opt):
+        file_name = self.option_file_path(opt, makedir=False)
+        new_opt = pickle.load(open(file_name + '.pkl', 'rb'))
+        return new_opt
+
+    def parse(self, save=False):
+
+        opt = self.gather_options()
+        opt.isTrain = self.isTrain  # train or test
+
+        #self.print_options(opt)
+        if opt.isTrain:
+            self.save_options(opt)
+
+        # Set semantic_nc based on the option.
+        # This will be convenient in many places
+        opt.semantic_nc = opt.label_nc + \
+            (1 if opt.contain_dontcare_label else 0) + \
+            (0 if opt.no_instance else 1)
+
+        # set gpu ids
+        str_ids = opt.gpu_ids.split(',')
+        opt.gpu_ids = []
+        for str_id in str_ids:
+            id = int(str_id)
+            if id >= 0:
+                opt.gpu_ids.append(id)
+        opt.gpu_ids = [] if torch.cuda.device_count() == 0 else opt.gpu_ids
+        if len(opt.gpu_ids) > 0:
+            torch.cuda.set_device(opt.gpu_ids[0])
+
+        assert len(opt.gpu_ids) == 0 or opt.batchSize % len(opt.gpu_ids) == 0, \
+            "Batch size %d is wrong. It must be a multiple of # GPUs %d." \
+            % (opt.batchSize, len(opt.gpu_ids))
+
+        self.opt = opt
+        return self.opt
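A worked sketch (not part of the commit) of the semantic_nc value that parse() above derives under this repo's forced settings (label_nc=16, contain_dontcare_label=False, no_instance=True), i.e. the number of label channels the SPADE layers are configured with.

```python
label_nc = 16                      # default in base_options.py
contain_dontcare_label = False     # forced in gather_options()
no_instance = True                 # forced in gather_options()

semantic_nc = label_nc + (1 if contain_dontcare_label else 0) + (0 if no_instance else 1)
print(semantic_nc)                 # 16 -- covers the 0..15 label ids produced by remap_label()
```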
seg2art/options/test_options.py
ADDED
@@ -0,0 +1,22 @@
+"""
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
+
+from .base_options import BaseOptions
+
+
+class TestOptions(BaseOptions):
+    def initialize(self, parser):
+        BaseOptions.initialize(self, parser)
+        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
+        parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+        parser.add_argument('--checkpoint_path', type=str, default='./checkpoints/multimodal_artworks/latest_net_G-fp16.pth', help='load model from a checkpoint')
+        parser.add_argument('--how_many', type=int, default=float("inf"), help='how many test images to run')
+
+        parser.set_defaults(preprocess_mode='scale_width_and_crop', crop_size=512, load_size=512, display_winsize=256)
+        parser.set_defaults(serial_batches=True)
+        parser.set_defaults(no_flip=True)
+        parser.set_defaults(phase='test')
+        self.isTrain = False
+        return parser
seg2art/sstan_models/__init__.py
ADDED
@@ -0,0 +1,44 @@
+"""
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
+
+import importlib
+import torch
+
+
+def find_model_using_name(model_name):
+    # Given the option --model [modelname],
+    # the file "sstan_models/modelname_model.py"
+    # will be imported.
+    model_filename = "sstan_models." + model_name + "_model"
+    modellib = importlib.import_module(model_filename)
+
+    # In the file, the class called ModelNameModel() will
+    # be instantiated. It has to be a subclass of torch.nn.Module,
+    # and it is case-insensitive.
+    model = None
+    target_model_name = model_name.replace('_', '') + 'model'
+    for name, cls in modellib.__dict__.items():
+        if name.lower() == target_model_name.lower() \
+                and issubclass(cls, torch.nn.Module):
+            model = cls
+
+    if model is None:
+        print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." % (model_filename, target_model_name))
+        exit(0)
+
+    return model
+
+
+def get_option_setter(model_name):
+    model_class = find_model_using_name(model_name)
+    return model_class.modify_commandline_options
+
+
+def create_model(opt):
+    model = find_model_using_name(opt.model)
+    instance = model(opt)
+    print("model [%s] was created" % (type(instance).__name__))
+
+    return instance
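A string-level sketch (not part of the commit) of the naming convention find_model_using_name() above relies on, using the default --model value from base_options.py.

```python
model_name = "pix2pix"                                      # default --model
model_filename = "sstan_models." + model_name + "_model"    # -> "sstan_models.pix2pix_model"
target_model_name = model_name.replace("_", "") + "model"   # -> "pix2pixmodel"

# importlib.import_module(model_filename) is then scanned for a torch.nn.Module
# subclass whose lowercased name equals target_model_name, i.e. Pix2PixModel.
print(model_filename, target_model_name)
```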
seg2art/sstan_models/networks/__init__.py
ADDED
@@ -0,0 +1,63 @@
+"""
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
+
+import torch
+from sstan_models.networks.base_network import BaseNetwork
+# from sstan_models.networks.loss import *
+# from sstan_models.networks.discriminator import *
+from sstan_models.networks.generator import *
+# from sstan_models.networks.encoder import *
+import model_util as util
+
+
+def find_network_using_name(target_network_name, filename):
+    target_class_name = target_network_name + filename
+    module_name = 'sstan_models.networks.' + filename
+    network = util.find_class_in_module(target_class_name, module_name)
+
+    assert issubclass(network, BaseNetwork), \
+        "Class %s should be a subclass of BaseNetwork" % network
+
+    return network
+
+
+def modify_commandline_options(parser, is_train):
+    opt, _ = parser.parse_known_args()
+
+    netG_cls = find_network_using_name(opt.netG, 'generator')
+    parser = netG_cls.modify_commandline_options(parser, is_train)
+    if is_train:
+        netD_cls = find_network_using_name(opt.netD, 'discriminator')
+        parser = netD_cls.modify_commandline_options(parser, is_train)
+    # netE_cls = find_network_using_name('conv', 'encoder')
+    # parser = netE_cls.modify_commandline_options(parser, is_train)
+
+    return parser
+
+
+def create_network(cls, opt):
+    net = cls(opt)
+    net.print_network()
+    if len(opt.gpu_ids) > 0:
+        assert(torch.cuda.is_available())
+        net.cuda()
+    net.init_weights(opt.init_type, opt.init_variance)
+    return net
+
+
+def define_G(opt):
+    netG_cls = find_network_using_name(opt.netG, 'generator')
+    return create_network(netG_cls, opt)
+
+
+def define_D(opt):
+    netD_cls = find_network_using_name(opt.netD, 'discriminator')
+    return create_network(netD_cls, opt)
+
+
+def define_E(opt):
+    # there exists only one encoder type
+    netE_cls = find_network_using_name('conv', 'encoder')
+    return create_network(netE_cls, opt)
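A sketch (not part of the commit) of how define_G() above resolves the generator class for the default --netG value; the actual lookup is performed by find_class_in_module() from model_util.py.

```python
target_network_name, filename = "spade", "generator"    # defaults used by define_G()
target_class_name = target_network_name + filename      # "spadegenerator"
module_name = "sstan_models.networks." + filename       # "sstan_models.networks.generator"

# find_class_in_module() lowercases class names and strips underscores, so this
# resolves to SPADEGenerator in generator.py (added later in this commit).
print(target_class_name, module_name)
```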
seg2art/sstan_models/networks/architecture.py
ADDED
@@ -0,0 +1,231 @@
+"""
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+import torch.nn.utils.spectral_norm as spectral_norm
+from sstan_models.networks.normalization import SPADE
+
+
+# ResNet block that uses SPADE.
+# It differs from the ResNet block of pix2pixHD in that
+# it takes in the segmentation map as input, learns the skip connection if necessary,
+# and applies normalization first and then convolution.
+# This architecture seemed like a standard architecture for unconditional or
+# class-conditional GAN architecture using residual block.
+# The code was inspired from https://github.com/LMescheder/GAN_stability.
+class SPADEResnetBlock(nn.Module):
+    def __init__(self, fin, fout, opt, feed_code=False):
+        super().__init__()
+
+        self.status = 'train'
+        # Attributes
+        self.learned_shortcut = (fin != fout)
+        fmiddle = min(fin, fout)
+
+        # create conv layers
+        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
+        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
+        if self.learned_shortcut:
+            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
+
+        # apply spectral norm if specified
+        if 'spectral' in opt.norm_G:
+            self.conv_0 = spectral_norm(self.conv_0)
+            self.conv_1 = spectral_norm(self.conv_1)
+            if self.learned_shortcut:
+                self.conv_s = spectral_norm(self.conv_s)
+
+        # define normalization layers
+        spade_config_str = opt.norm_G.replace('spectral', '')
+
+        # Attention related
+        self.channelAtt = ChannelAttention(fout)
+        self.spatialAtt = SpatialAttention()
+
+        self.norm_0 = SPADE(spade_config_str, fin, feed_code, status=self.status, spade_params=[
+            spade_config_str, fin, opt.semantic_nc])
+        self.norm_1 = SPADE(spade_config_str, fmiddle, feed_code, status=self.status, spade_params=[
+            spade_config_str, fmiddle, opt.semantic_nc])
+        if self.learned_shortcut:
+            self.norm_s = SPADE(spade_config_str, fin, feed_code, status=self.status, spade_params=[
+                spade_config_str, fin, opt.semantic_nc])
+
+    # note the resnet block with SPADE also takes in |seg|,
+    # the semantic segmentation map as input
+    def forward(self, x, seg, style_codes=None, self_attention=False):
+        x_s = self.shortcut(x, seg, style_codes=style_codes)
+
+        dx = self.conv_0(self.actvn(
+            self.norm_0(x, seg, style_codes=style_codes)))
+        dx = self.conv_1(self.actvn(self.norm_1(
+            dx, seg, style_codes=style_codes)))
+
+        dx = self.channelAtt(dx) * dx
+        dx = self.spatialAtt(dx) * dx
+        out = x_s + dx
+
+        return out
+
+    def shortcut(self, x, seg, style_codes):
+        if self.learned_shortcut:
+            x_s = self.conv_s(self.norm_s(x, seg, style_codes=style_codes))
+        else:
+            x_s = x
+        return x_s
+
+    def actvn(self, x):
+        return F.leaky_relu(x, 2e-1)
+
+
+# ResNet block used in pix2pixHD
+# We keep the same architecture as pix2pixHD.
+class ResnetBlock(nn.Module):
+    def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3):
+        super().__init__()
+
+        pw = (kernel_size - 1) // 2
+        self.conv_block = nn.Sequential(
+            nn.ReflectionPad2d(pw),
+            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
+            activation,
+            nn.ReflectionPad2d(pw),
+            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size))
+        )
+
+    def forward(self, x):
+        y = self.conv_block(x)
+        out = x + y
+        return out
+
+
+# VGG architecter, used for the perceptual loss using a pretrained VGG network
+class VGG19(torch.nn.Module):
+    def __init__(self, requires_grad=False):
+        super().__init__()
+        vgg_pretrained_features = torchvision.models.vgg19(
+            pretrained=True).features
+        self.slice1 = torch.nn.Sequential()
+        self.slice2 = torch.nn.Sequential()
+        self.slice3 = torch.nn.Sequential()
+        self.slice4 = torch.nn.Sequential()
+        self.slice5 = torch.nn.Sequential()
+        for x in range(2):
+            self.slice1.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(2, 7):
+            self.slice2.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(7, 12):
+            self.slice3.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(12, 21):
+            self.slice4.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(21, 30):
+            self.slice5.add_module(str(x), vgg_pretrained_features[x])
+        if not requires_grad:
+            for param in self.parameters():
+                param.requires_grad = False
+        # torch.cuda.empty_cache()
+    def forward(self, X):
+        with torch.cuda.amp.autocast():
+            # with torch.no_grad():
+            h_relu1 = self.slice1(X)
+            h_relu2 = self.slice2(h_relu1)
+            h_relu3 = self.slice3(h_relu2)
+            h_relu4 = self.slice4(h_relu3)
+            h_relu5 = self.slice5(h_relu4)
+            out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
+        return out
+
+'''
+class SourceReferenceAttention(nn.Module):
+    """
+    Source-Reference Attention Layer
+    """
+
+    def __init__(self, in_planes_s, in_planes_r):
+        """
+        Parameters
+        ----------
+        in_planes_s: int
+            Number of input source feature vector channels.
+        in_planes_r: int
+            Number of input reference feature vector channels.
+        """
+        super(SourceReferenceAttention, self).__init__()
+        self.query_conv = nn.Conv2d(in_channels=in_planes_s,
+                                    out_channels=in_planes_s//8, kernel_size=1)
+        self.key_conv = nn.Conv2d(in_channels=in_planes_r,
+                                  out_channels=in_planes_r//8, kernel_size=1)
+        self.value_conv = nn.Conv2d(in_channels=in_planes_r,
+                                    out_channels=in_planes_r, kernel_size=1)
+        self.gamma = nn.Parameter(torch.zeros(1))
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, source, reference):
+        """
+        Parameters
+        ----------
+        source : torch.Tensor
+            Source feature maps (B x Cs x Ts x Hs x Ws)
+        reference : torch.Tensor
+            Reference feature maps (B x Cr x Tr x Hr x Wr )
+        Returns :
+            torch.Tensor
+                Source-reference attention value added to the input source features
+            torch.Tensor
+                Attention map (B x Ns x Nt) (Ns=Ts*Hs*Ws, Nr=Tr*Hr*Wr)
+        """
+        s_batchsize, sC, sH, sW = source.size()
+        r_batchsize, rC, rH, rW = reference.size()
+        proj_query = self.query_conv(source).view(
+            s_batchsize, -1, sH*sW).permute(0, 2, 1)
+        proj_key = self.key_conv(reference).view(r_batchsize, -1, rW*rH)
+        energy = torch.bmm(proj_query, proj_key)
+        attention = self.softmax(energy)
+        proj_value = self.value_conv(reference).view(r_batchsize, -1, rH*rW)
+        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
+        out = out.view(s_batchsize, sC, sH, sW)
+        out = self.gamma*out + source
+        return out, attention
+'''
+
+class ChannelAttention(nn.Module):
+    def __init__(self, in_planes, ratio=16):
+        super(ChannelAttention, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.max_pool = nn.AdaptiveMaxPool2d(1)
+
+        self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)
+        self.relu1 = nn.ReLU()
+        self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)
+
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+
+        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
+        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
+        out = avg_out + max_out
+        return self.sigmoid(out)
+
+
+class SpatialAttention(nn.Module):
+    def __init__(self, kernel_size=7):
+        super(SpatialAttention, self).__init__()
+
+        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
+        padding = 3 if kernel_size == 7 else 1
+
+        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+
+        avg_out = torch.mean(x, dim=1, keepdim=True)
+        max_out, _ = torch.max(x, dim=1, keepdim=True)
+        x = torch.cat([avg_out, max_out], dim=1)
+        x = self.conv1(x)
+        return self.sigmoid(x)
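A shape-level sketch (not part of the commit) of the two CBAM-style gates that SPADEResnetBlock above multiplies onto its residual branch: ChannelAttention pools each feature map down to a (B, C, 1, 1) gate, while SpatialAttention stacks the per-pixel channel mean and max into a (B, 2, H, W) tensor before its convolution produces a (B, 1, H, W) gate.

```python
import torch

x = torch.randn(2, 64, 32, 32)                                  # (B, C, H, W) residual features
channel_in = torch.nn.AdaptiveAvgPool2d(1)(x)                   # (2, 64, 1, 1), fed to fc1/fc2
spatial_in = torch.cat([x.mean(dim=1, keepdim=True),
                        x.max(dim=1, keepdim=True)[0]], dim=1)  # (2, 2, 32, 32), fed to conv1
print(channel_in.shape, spatial_in.shape)
```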
seg2art/sstan_models/networks/base_network.py
ADDED
@@ -0,0 +1,59 @@
+"""
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
+
+import torch.nn as nn
+from torch.nn import init
+
+
+class BaseNetwork(nn.Module):
+    def __init__(self):
+        super(BaseNetwork, self).__init__()
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        return parser
+
+    def print_network(self):
+        if isinstance(self, list):
+            self = self[0]
+        num_params = 0
+        for param in self.parameters():
+            num_params += param.numel()
+        print('Network [%s] was created. Total number of parameters: %.1f million. '
+              'To see the architecture, do print(network).'
+              % (type(self).__name__, num_params / 1000000))
+
+    def init_weights(self, init_type='normal', gain=0.02):
+        def init_func(m):
+            classname = m.__class__.__name__
+            if classname.find('BatchNorm2d') != -1:
+                if hasattr(m, 'weight') and m.weight is not None:
+                    init.normal_(m.weight.data, 1.0, gain)
+                if hasattr(m, 'bias') and m.bias is not None:
+                    init.constant_(m.bias.data, 0.0)
+            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+                if init_type == 'normal':
+                    init.normal_(m.weight.data, 0.0, gain)
+                elif init_type == 'xavier':
+                    init.xavier_normal_(m.weight.data, gain=gain)
+                elif init_type == 'xavier_uniform':
+                    init.xavier_uniform_(m.weight.data, gain=1.0)
+                elif init_type == 'kaiming':
+                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+                elif init_type == 'orthogonal':
+                    init.orthogonal_(m.weight.data, gain=gain)
+                elif init_type == 'none':  # uses pytorch's default init method
+                    m.reset_parameters()
+                else:
+                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+                if hasattr(m, 'bias') and m.bias is not None:
+                    init.constant_(m.bias.data, 0.0)
+
+        self.apply(init_func)
+
+        # propagate to children
+        for m in self.children():
+            if hasattr(m, 'init_weights'):
+                m.init_weights(init_type, gain)
seg2art/sstan_models/networks/dual_attention_module.py
ADDED
@@ -0,0 +1,51 @@
+
+class ChannelAttention(nn.Module):
+    def __init__(self, in_planes, ratio=16):
+        super(ChannelAttention, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.max_pool = nn.AdaptiveMaxPool2d(1)
+
+        self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)
+        self.relu1 = nn.ReLU()
+        self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)
+
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
+        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
+        out = avg_out + max_out
+        return self.sigmoid(out)
+
+
+class SpatialAttention(nn.Module):
+    def __init__(self, kernel_size=7):
+        super(SpatialAttention, self).__init__()
+
+        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
+        padding = 3 if kernel_size == 7 else 1
+
+        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        avg_out = torch.mean(x, dim=1, keepdim=True)
+        max_out, _ = torch.max(x, dim=1, keepdim=True)
+        x = torch.cat([avg_out, max_out], dim=1)
+        x = self.conv1(x)
+        return self.sigmoid(x)
+
+
+'''
+    def forward(self, x, seg):
+        x_s = self.shortcut(x, seg)
+
+        dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
+        dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))
+
+        dx = self.channelAtt(dx) * dx
+        dx = self.spatialAtt(dx) * dx
+        out = x_s + dx
+
+        return out
+'''
seg2art/sstan_models/networks/generator.py
ADDED
@@ -0,0 +1,184 @@
+"""
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from sstan_models.networks.base_network import BaseNetwork
+from sstan_models.networks.normalization import get_nonspade_norm_layer
+from sstan_models.networks.architecture import ResnetBlock as ResnetBlock
+from sstan_models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock
+import numpy as np
+torch.manual_seed(1234)
+
+
+class SPADEGenerator(BaseNetwork):
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        parser.set_defaults(norm_G='spectralspadesyncbatch3x3')
+        parser.add_argument('--num_upsampling_layers',
+                            choices=('normal', 'more', 'most'), default='normal',
+                            help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator")
+
+        return parser
+
+    def __init__(self, opt):
+        super().__init__()
+        self.opt = opt
+        nf = opt.ngf
+
+        self.sw, self.sh = self.compute_latent_vector_size(opt)
+
+        # if opt.use_vae:
+        #     # In case of VAE, we will sample from random z vector
+        #     self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
+        # else:
+        #     # Otherwise, we make the network deterministic by starting with
+        #     # downsampled segmentation map instead of random z
+        #     self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)
+        self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
+        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, feed_code = True)
+
+        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, feed_code = True)
+        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, feed_code = True)
+
+        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt, feed_code = True)
+        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt, feed_code = True)
+        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt, feed_code = True)
+        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt, feed_code = False)
+
+        final_nc = nf
+
+        # if opt.num_upsampling_layers == 'most':
+        #     print('used?')
+        #     self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
+        #     final_nc = nf // 2
+
+        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
+
+        self.up = nn.Upsample(scale_factor=2)
+
+    def compute_latent_vector_size(self, opt):
+        if opt.num_upsampling_layers == 'normal':
+            num_up_layers = 5
+        elif opt.num_upsampling_layers == 'more':
+            num_up_layers = 6
+        elif opt.num_upsampling_layers == 'most':
+            num_up_layers = 7
+        else:
+            raise ValueError('opt.num_upsampling_layers [%s] not recognized' %
+                             opt.num_upsampling_layers)
+
+        sw = opt.crop_size // (2**num_up_layers)
+        sh = round(sw / opt.aspect_ratio)
+
+        return sw, sh
+
+    def forward(self, input, rgb_img, style_codes=None):
+        with torch.cuda.amp.autocast():
+            seg = input
+            # if self.opt.use_vae:
+            #     # we sample z from unit normal and reshape the tensor
+            #     if z is None:
+            #         z = torch.randn(input.size(0), self.opt.z_dim,
+            #                         dtype=torch.float32, device=input.get_device())
+            x = self.fc(style_codes)
+            x = x.view(-1, 16 * self.opt.ngf, self.sh, self.sw)
+            # else:
+            #     # we downsample segmap and run convolution
+            #     x = F.interpolate(seg, size=(self.sh, self.sw))
+            #     x = self.fc(x)
+            x = self.head_0(x, seg, style_codes=style_codes)
+
+            x = self.up(x)
+            x = self.G_middle_0(x, seg, style_codes=style_codes)
+
+            if self.opt.num_upsampling_layers == 'more' or \
+                    self.opt.num_upsampling_layers == 'most':
+                x = self.up(x)
+
+            x = self.G_middle_1(x, seg, style_codes=style_codes)
+
+            x = self.up(x)
+            x = self.up_0(x, seg, style_codes=style_codes)
+            x = self.up(x)
+            x = self.up_1(x, seg, style_codes=style_codes)
+            x = self.up(x)
+            x = self.up_2(x, seg, style_codes=style_codes)
+            x = self.up(x)
+            x = self.up_3(x, seg)
+
+            # if self.opt.num_upsampling_layers == 'most':
+            #     x = self.up(x)
+            #     x = self.up_4(x, seg)
+
+            x = self.conv_img(F.leaky_relu(x, 2e-1))
+            x = torch.tanh(x)#F.tanh(x)
+
+        return x#, style_codes
+
+
+class Pix2PixHDGenerator(BaseNetwork):
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        parser.add_argument('--resnet_n_downsample', type=int, default=4, help='number of downsampling layers in netG')
+        parser.add_argument('--resnet_n_blocks', type=int, default=9, help='number of residual blocks in the global generator network')
+        parser.add_argument('--resnet_kernel_size', type=int, default=3,
+                            help='kernel size of the resnet block')
+        parser.add_argument('--resnet_initial_kernel_size', type=int, default=7,
+                            help='kernel size of the first convolution')
+        parser.set_defaults(norm_G='instance')
+        return parser
+
+    def __init__(self, opt):
+        super().__init__()
+        input_nc = opt.label_nc + (1 if opt.contain_dontcare_label else 0) + (0 if opt.no_instance else 1)
+
+        norm_layer = get_nonspade_norm_layer(opt, opt.norm_G)
+        activation = nn.ReLU(False)
+
+        model = []
+
|
144 |
+
# initial conv
|
145 |
+
model += [nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2),
|
146 |
+
norm_layer(nn.Conv2d(input_nc, opt.ngf,
|
147 |
+
kernel_size=opt.resnet_initial_kernel_size,
|
148 |
+
padding=0)),
|
149 |
+
activation]
|
150 |
+
|
151 |
+
# downsample
|
152 |
+
mult = 1
|
153 |
+
for i in range(opt.resnet_n_downsample):
|
154 |
+
model += [norm_layer(nn.Conv2d(opt.ngf * mult, opt.ngf * mult * 2,
|
155 |
+
kernel_size=3, stride=2, padding=1)),
|
156 |
+
activation]
|
157 |
+
mult *= 2
|
158 |
+
|
159 |
+
# resnet blocks
|
160 |
+
for i in range(opt.resnet_n_blocks):
|
161 |
+
model += [ResnetBlock(opt.ngf * mult,
|
162 |
+
norm_layer=norm_layer,
|
163 |
+
activation=activation,
|
164 |
+
kernel_size=opt.resnet_kernel_size)]
|
165 |
+
|
166 |
+
# upsample
|
167 |
+
for i in range(opt.resnet_n_downsample):
|
168 |
+
nc_in = int(opt.ngf * mult)
|
169 |
+
nc_out = int((opt.ngf * mult) / 2)
|
170 |
+
model += [norm_layer(nn.ConvTranspose2d(nc_in, nc_out,
|
171 |
+
kernel_size=3, stride=2,
|
172 |
+
padding=1, output_padding=1)),
|
173 |
+
activation]
|
174 |
+
mult = mult // 2
|
175 |
+
|
176 |
+
# final output conv
|
177 |
+
model += [nn.ReflectionPad2d(3),
|
178 |
+
nn.Conv2d(nc_out, opt.output_nc, kernel_size=7, padding=0),
|
179 |
+
nn.Tanh()]
|
180 |
+
|
181 |
+
self.model = nn.Sequential(*model)
|
182 |
+
|
183 |
+
def forward(self, input, z=None):
|
184 |
+
return self.model(input)
|
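As a rough orientation for the generator above: a one-hot segmentation map and a 256-dimensional style code go in, and an RGB image in [-1, 1] comes out. The opt values below are illustrative assumptions inferred from the constructor, not the repository's actual configuration (that lives in seg2art/options/).

from types import SimpleNamespace
import torch

opt = SimpleNamespace(ngf=64, z_dim=256, crop_size=512, aspect_ratio=1.0,
                      num_upsampling_layers='normal', semantic_nc=14)  # hypothetical values
seg = torch.zeros(1, opt.semantic_nc, opt.crop_size, opt.crop_size)    # one-hot segmentation map
style = torch.randn(1, opt.z_dim)                                      # style code fed to self.fc
# netG = SPADEGenerator(opt)
# fake = netG(seg, rgb_img=None, style_codes=style)   # expected shape: (1, 3, 512, 512)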
seg2art/sstan_models/networks/normalization.py
ADDED
@@ -0,0 +1,222 @@
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from sstan_models.networks.sync_batchnorm import SynchronizedBatchNorm2d
import torch.nn.utils.spectral_norm as spectral_norm


# Returns a function that creates a normalization function
# that does not condition on semantic map
def get_nonspade_norm_layer(opt, norm_type='instance'):
    # helper function to get # output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        nonlocal norm_type
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'sync_batch':
            norm_layer = SynchronizedBatchNorm2d(
                get_out_channel(layer), affine=True)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(
                get_out_channel(layer), affine=False)
        else:
            raise ValueError(
                'normalization layer %s is not recognized' % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer


# Creates SPADE normalization layer based on the given configuration
# SPADE consists of two steps. First, it normalizes the activations using
# your favorite normalization method, such as Batch Norm or Instance Norm.
# Second, it applies scale and bias to the normalized output, conditioned on
# the segmentation map.
# The format of |config_text| is spade(norm)(ks), where
# (norm) specifies the type of parameter-free normalization.
#       (e.g. syncbatch, batch, instance)
# (ks) specifies the size of kernel in the SPADE module (e.g. 3x3)
# Example |config_text| will be spadesyncbatch3x3, or spadeinstance5x5.
# Also, the other arguments are
# |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE
# |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE
class SPADE(nn.Module):
    def __init__(self, config_text, norm_nc, feed_code, status='train', spade_params=None):
        super().__init__()

        self.style_length = 256
        # self.noise_var = nn.Parameter(torch.zeros(norm_nc), requires_grad=True)
        self.Spade = SPADE_ori(*spade_params)

        assert config_text.startswith('spade')
        parsed = re.search('spade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))
        pw = ks // 2

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(
                norm_nc, affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError('%s is not a recognized param-free norm type in SPADE'
                             % param_free_norm_type)

        # self.create_gamma_beta_fc_layers()
        if feed_code:
            self.blending_gamma = nn.Parameter(torch.zeros(1), requires_grad=True)
            self.blending_beta = nn.Parameter(torch.zeros(1), requires_grad=True)
            self.conv_gamma = nn.Conv2d(
                self.style_length, norm_nc, kernel_size=ks, padding=pw)
            self.conv_beta = nn.Conv2d(
                self.style_length, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap, style_codes=None):
        if style_codes is None:
            input_code = False
        else:
            input_code = True

        # Part 1. generate parameter-free normalized activations
        # added_noise = (torch.randn(
        #     x.shape[0], x.shape[3], x.shape[2], 1).cuda() * self.noise_var).transpose(1, 3)
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')

        if input_code:
            [b_size, f_size, h_size, w_size] = normalized.shape
            middle_avg = torch.zeros(
                (b_size, self.style_length, h_size, w_size), device=normalized.device)

            for i in range(b_size):

                middle_mu = F.relu((style_codes[i]))

                middle_mu = middle_mu.reshape(self.style_length, 1).expand(
                    self.style_length, h_size*w_size)
                middle_mu = middle_mu.reshape(
                    self.style_length, h_size, w_size)
                middle_avg[i] = middle_mu

            gamma_avg = self.conv_gamma(middle_avg)
            beta_avg = self.conv_beta(middle_avg)

            gamma_spade, beta_spade = self.Spade(segmap)

            gamma_alpha = torch.sigmoid(self.blending_gamma)  # F.sigmoid(self.blending_gamma)
            beta_alpha = torch.sigmoid(self.blending_gamma)  # F.sigmoid(self.blending_beta)

            gamma_final = gamma_alpha * gamma_avg + \
                (1 - gamma_alpha) * gamma_spade

            beta_final = beta_alpha * beta_avg + (1 - beta_alpha) * beta_spade

            out = normalized * (1 + gamma_final) + beta_final
        else:
            gamma_spade, beta_spade = self.Spade(segmap)
            gamma_final = gamma_spade
            beta_final = beta_spade
            out = normalized * (1 + gamma_final) + beta_final
        return out

    # def create_gamma_beta_fc_layers(self):

    #     # These codes should be replaced with torch.nn.ModuleList

    #     style_length = self.style_length

    #     self.fc_mu0 = nn.Linear(style_length, style_length)
    #     self.fc_mu1 = nn.Linear(style_length, style_length)
    #     self.fc_mu2 = nn.Linear(style_length, style_length)
    #     self.fc_mu3 = nn.Linear(style_length, style_length)
    #     self.fc_mu4 = nn.Linear(style_length, style_length)
    #     self.fc_mu5 = nn.Linear(style_length, style_length)
    #     self.fc_mu6 = nn.Linear(style_length, style_length)
    #     self.fc_mu7 = nn.Linear(style_length, style_length)
    #     self.fc_mu8 = nn.Linear(style_length, style_length)
    #     self.fc_mu9 = nn.Linear(style_length, style_length)
    #     self.fc_mu10 = nn.Linear(style_length, style_length)
    #     self.fc_mu11 = nn.Linear(style_length, style_length)
    #     self.fc_mu12 = nn.Linear(style_length, style_length)
    #     self.fc_mu13 = nn.Linear(style_length, style_length)
    #     self.fc_mu14 = nn.Linear(style_length, style_length)
    #     self.fc_mu15 = nn.Linear(style_length, style_length)
    #     self.fc_mu16 = nn.Linear(style_length, style_length)
    #     self.fc_mu17 = nn.Linear(style_length, style_length)
    #     self.fc_mu18 = nn.Linear(style_length, style_length)


class SPADE_ori(nn.Module):
    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        assert config_text.startswith('spade')
        parsed = re.search('spade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(
                norm_nc, affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError('%s is not a recognized param-free norm type in SPADE'
                             % param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128

        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU()
        )

        self.mlp_gamma = nn.Conv2d(
            nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, segmap):

        inputmap = segmap

        actv = self.mlp_shared(inputmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        return gamma, beta
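The core of SPADE.forward above is the learned blend between style-driven and segmentation-driven modulation: with alpha = sigmoid(blending_gamma), gamma_final = alpha * gamma_avg + (1 - alpha) * gamma_spade, and the output is normalized * (1 + gamma_final) + beta_final. A tiny arithmetic sketch of that blend with plain tensors (shapes chosen arbitrarily, no SPADE module involved):

import torch

normalized = torch.randn(1, 4, 8, 8)                  # param-free normalized activations
gamma_avg, beta_avg = torch.randn(2, 1, 4, 8, 8)      # style-code branch (conv_gamma / conv_beta)
gamma_spade, beta_spade = torch.randn(2, 1, 4, 8, 8)  # segmentation branch (SPADE_ori)

alpha = torch.sigmoid(torch.zeros(1))                 # blending_gamma starts at 0, so alpha = 0.5
gamma_final = alpha * gamma_avg + (1 - alpha) * gamma_spade
beta_final = alpha * beta_avg + (1 - alpha) * beta_spade
out = normalized * (1 + gamma_final) + beta_final
print(out.shape)                                      # torch.Size([1, 4, 8, 8])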
seg2art/sstan_models/networks/sync_batchnorm/__init__.py
ADDED
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
# File   : __init__.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .batchnorm import convert_model
from .replicate import DataParallelWithCallback, patch_replication_callback
seg2art/sstan_models/networks/sync_batchnorm/batchnorm.py
ADDED
@@ -0,0 +1,361 @@
# -*- coding: utf-8 -*-
# File   : batchnorm.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

from .comm import SyncMaster
from .replicate import DataParallelWithCallback

__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d',
           'SynchronizedBatchNorm3d', 'convert_model']


def _sum_ft(tensor):
    """sum over the first and last dimention"""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """add new dementions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalize the tensor on each device using
    the statistics only on that device, which accelerated the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for one-GPU or CPU-only case, this module behaves exactly same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalize the tensor on each device using
    the statistics only on that device, which accelerated the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for one-GPU or CPU-only case, this module behaves exactly same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalize the tensor on each device using
    the statistics only on that device, which accelerated the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for one-GPU or CPU-only case, this module behaves exactly same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)


def convert_model(module):
    """Traverse the input module and its child recursively
       and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
       to SynchronizedBatchNorm*N*d

    Args:
        module: the input module needs to be convert to SyncBN model

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after convert, m is using SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        mod = module.module
        mod = convert_model(mod)
        mod = DataParallelWithCallback(mod)
        return mod

    mod = module
    for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
                                        torch.nn.modules.batchnorm.BatchNorm2d,
                                        torch.nn.modules.batchnorm.BatchNorm3d],
                                       [SynchronizedBatchNorm1d,
                                        SynchronizedBatchNorm2d,
                                        SynchronizedBatchNorm3d]):
        if isinstance(module, pth_module):
            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod
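Because the non-parallel branch of _SynchronizedBatchNorm.forward above simply calls F.batch_norm, a single-device result should match the built-in nn.BatchNorm2d once the parameters agree. A small sanity-check sketch along those lines (the package import is commented out because it assumes the seg2art directory is on sys.path):

import torch
import torch.nn as nn
# from sstan_models.networks.sync_batchnorm import SynchronizedBatchNorm2d

def matches_builtin(sync_bn, ref_bn, x):
    # Start both layers from identical affine parameters and running stats.
    ref_bn.load_state_dict(sync_bn.state_dict())
    return torch.allclose(sync_bn(x), ref_bn(x), atol=1e-6)

# x = torch.randn(8, 16, 32, 32)
# print(matches_builtin(SynchronizedBatchNorm2d(16), nn.BatchNorm2d(16), x))  # expected: True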
seg2art/sstan_models/networks/sync_batchnorm/batchnorm_reimpl.py
ADDED
@@ -0,0 +1,74 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : batchnorm_reimpl.py
# Author : acgtyrant
# Date   : 11/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import torch
import torch.nn as nn
import torch.nn.init as init

__all__ = ['BatchNormReimpl']


class BatchNorm2dReimpl(nn.Module):
    """
    A re-implementation of batch normalization, used for testing the numerical
    stability.

    Author: acgtyrant
    See also:
    https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_running_stats(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input_):
        batchsize, channels, height, width = input_.size()
        numel = batchsize * height * width
        input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
        sum_ = input_.sum(1)
        sum_of_square = input_.pow(2).sum(1)
        mean = sum_ / numel
        sumvar = sum_of_square - sum_ * mean

        self.running_mean = (
            (1 - self.momentum) * self.running_mean
            + self.momentum * mean.detach()
        )
        unbias_var = sumvar / (numel - 1)
        self.running_var = (
            (1 - self.momentum) * self.running_var
            + self.momentum * unbias_var.detach()
        )

        bias_var = sumvar / numel
        inv_std = 1 / (bias_var + self.eps).pow(0.5)
        output = (
            (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
            self.weight.unsqueeze(1) + self.bias.unsqueeze(1))

        return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
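BatchNorm2dReimpl is meant for numerical comparison against the built-in layer. A sketch of such a check, assuming the affine parameters are copied across first (the helper name and shapes below are illustrative, and the import assumes the seg2art package layout on sys.path):

import torch
import torch.nn as nn
# from sstan_models.networks.sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl

def agrees_with_builtin(reimpl, reference, x):
    with torch.no_grad():
        reference.weight.copy_(reimpl.weight)   # the reimpl initialises weight uniformly, not to ones
        reference.bias.copy_(reimpl.bias)
    return torch.allclose(reimpl(x), reference(x), atol=1e-5)

# x = torch.randn(16, 8, 14, 14)
# print(agrees_with_builtin(BatchNorm2dReimpl(8), nn.BatchNorm2d(8), x))  # expected: True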
seg2art/sstan_models/networks/sync_batchnorm/comm.py
ADDED
@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
    call `register(id)` and obtain an `SlavePipe` to communicate with the master.
    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
    and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine to message passed
    back to each slave devices.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register an slave device.

        Args:
            identifier: an identifier, usually is the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages were first collected from each devices (including the master device), and then
        an callback will be invoked to compute the message to be sent back to each devices
        (including the master device).

        Args:
            master_msg: the message that the master want to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
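FutureResult is the one-to-one pipe that the register_slave/run_master handshake above is built on: one thread put()s a value, another blocks in get() until it arrives. A minimal sketch of that exchange (the import assumes the seg2art directory is on sys.path, as the rest of the package expects):

import threading
from sstan_models.networks.sync_batchnorm.comm import FutureResult

future = FutureResult()
worker = threading.Thread(target=lambda: future.put('reduced statistics'))
worker.start()
print(future.get())   # blocks until the worker thread delivers the value
worker.join()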
seg2art/sstan_models/networks/sync_batchnorm/replicate.py
ADDED
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all modules are isomorphism, we assign each sub-module with a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
    of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
    original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
seg2art/sstan_models/networks/sync_batchnorm/unittest.py
ADDED
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest
import torch


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, x, y):
        adiff = float((x - y).abs().max())
        if (y == 0).all():
            rdiff = 'NaN'
        else:
            rdiff = float((adiff / y).abs().max())

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.assertTrue(torch.allclose(x, y), message)
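A quick sketch of how TorchTestCase could be exercised; the test body is illustrative only, and the import assumes the seg2art directory is on sys.path:

import torch
from sstan_models.networks.sync_batchnorm.unittest import TorchTestCase

class TanhCloseTest(TorchTestCase):
    def test_tanh_matches_exp_formula(self):
        x = torch.linspace(-3, 3, steps=51).double()
        manual = (x.exp() - (-x).exp()) / (x.exp() + (-x).exp())
        self.assertTensorClose(torch.tanh(x), manual)

# run with: python -m unittest <module containing TanhCloseTest>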
seg2art/sstan_models/pix2pix_model.py
ADDED
@@ -0,0 +1,285 @@
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import os
import torch
import sstan_models.networks as networks
import model_util as util


class Pix2PixModel(torch.nn.Module):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        networks.modify_commandline_options(parser, is_train)
        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() else torch.FloatTensor
        self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() else torch.ByteTensor

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        # set loss functions
        if opt.isTrain:
            self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
            if opt.use_vae:
                self.KLDLoss = networks.KLDLoss()

    # Entry point for all calls involving forward pass
    # of deep networks. We used this approach since DataParallel module
    # can't parallelize custom functions, we branch to different
    # routines based on |mode|.
    def forward(self, data, mode, style_codes=None):
        input_semantics, real_image = self.preprocess_input(data)
        domain = None

        # print(torch.cuda.memory_cached(0))
        if mode == "generator":
            g_loss, generated = self.compute_generator_loss(input_semantics, real_image, domain)
            return g_loss, generated
        elif mode == "discriminator":
            d_loss = self.compute_discriminator_loss(input_semantics, real_image, domain)
            return d_loss
        elif mode == "encode_only":
            _, mu, logvar = self.encode_z(real_image, domain)
            return mu, logvar
        elif mode == "inference":
            with torch.no_grad():
                fake_image, _, _ = self.generate_fake(input_semantics, real_image, domain, style_codes=style_codes, compute_kld_loss=False)
            return fake_image
        elif mode == "generate_img_npy":
            with torch.no_grad():
                fake_image, encoded_style_code = self.generate_img_npy(input_semantics, real_image, domain)
            return fake_image, encoded_style_code
        else:
            raise ValueError("|mode| is invalid")

    def create_optimizers(self, opt):
        G_params = list(self.netG.parameters())
        if opt.use_vae:
            G_params += list(self.netE.parameters())
        if opt.isTrain:
            D_params = list(self.netD.parameters())

        beta1, beta2 = opt.beta1, opt.beta2
        if opt.no_TTUR:
            G_lr, D_lr = opt.lr, opt.lr
        else:
            G_lr, D_lr = opt.lr / 2, opt.lr * 2

        optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))
        optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))

        return optimizer_G, optimizer_D

    def save(self, epoch):
        util.save_network(self.netG, "G", epoch, self.opt)
        util.save_network(self.netD, "D", epoch, self.opt)
        if self.opt.use_vae:
            util.save_network(self.netE, "E", epoch, self.opt)

    ############################################################################
    # Private helper methods
    ############################################################################

    def initialize_networks(self, opt):
        netG = networks.define_G(opt)
        netD = networks.define_D(opt) if opt.isTrain else None
        netE = networks.define_E(opt) if opt.use_vae else None

        if not opt.isTrain or opt.continue_train:
            # netG = util.load_network(netG, 'G', opt.which_epoch, opt)
            checkpoint_path = os.path.join(os.path.dirname(__file__), "..", opt.checkpoint_path)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            if device == "cuda":
                checkpoint = torch.load(checkpoint_path)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
            s = checkpoint
            netG.load_state_dict(s)

            if opt.isTrain:
                netD = util.load_network(netD, "D", opt.which_epoch, opt)
                if opt.use_vae:
                    netE = util.load_network(netE, "E", opt.which_epoch, opt)

        return netG, netD, netE

    # preprocess the input, such as moving the tensors to GPUs and
    # transforming the label map to one-hot encoding
    # |data|: dictionary of the input data

    def preprocess_input(self, data):
        """
        # move to GPU and change data types
        data['label'] = data['label'].long()
        if self.use_gpu():
            data['label'] = data['label'].cuda(non_blocking=True)
            data['instance'] = data['instance'].cuda(non_blocking=True)
            data['image'] = data['image'].cuda(non_blocking=True)
            data['domain'] = data['domain'].cuda(non_blocking=True)

        # create one-hot label map
        label_map = data['label']
        bs, _, h, w = label_map.size()
        nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label \
            else self.opt.label_nc
        input_label = self.FloatTensor(bs, nc, h, w).zero_()
        input_semantics = input_label.scatter_(1, label_map, 1.0)

        # concatenate instance map if it exists
        if not self.opt.no_instance:
            inst_map = data['instance']
            instance_edge_map = self.get_edges(inst_map)
            input_semantics = torch.cat((input_semantics, instance_edge_map), dim=1)

        return input_semantics, data['image'], data['domain']
        """

        data = data.long()
        image = (data - 128).float() / 128.0
        if self.use_gpu():
            data = data.cuda()
            image = image.float().cuda()
        label_map = data
        bs, _, h, w = label_map.size()
        nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label else self.opt.label_nc
        input_label = self.FloatTensor(bs, nc, h, w).zero_()
        input_semantics = input_label.scatter_(1, label_map, 1.0)

        return input_semantics, image  # data['image'],

    def compute_generator_loss(self, input_semantics, real_image, domain):
        G_losses = {}

        fake_image, KLD_loss, _ = self.generate_fake(input_semantics, real_image, domain, compute_kld_loss=self.opt.use_vae)

        if self.opt.use_vae:
            if KLD_loss.data.item() > 2.5:
                print("ng")
                print(KLD_loss.data.item())
                KLD_loss.data = torch.Tensor([min(999.9999, KLD_loss.data.item())]).cuda()
            G_losses["KLD"] = KLD_loss

        pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image, domain)

        G_losses["GAN"] = self.criterionGAN(pred_fake, True, for_discriminator=False)

        if not self.opt.no_ganFeat_loss:
            num_D = len(pred_fake)
            GAN_Feat_loss = self.FloatTensor(1).fill_(0)
            for i in range(num_D):  # for each discriminator
                # last output is the final prediction, so we exclude it
                num_intermediate_outputs = len(pred_fake[i]) - 1
                for j in range(num_intermediate_outputs):  # for each layer output
                    unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
                    GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D
            G_losses["GAN_Feat"] = GAN_Feat_loss

        if not self.opt.no_vgg_loss:
            G_losses["VGG"] = self.criterionVGG(fake_image, real_image) * self.opt.lambda_vgg

        return G_losses, fake_image

    def compute_discriminator_loss(self, input_semantics, real_image, domain):
        D_losses = {}
        with torch.no_grad():
            fake_image, _, _ = self.generate_fake(input_semantics, real_image, domain)
            fake_image = fake_image.detach()
            fake_image.requires_grad_()

        pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image, domain)

        D_losses["D_Fake"] = self.criterionGAN(pred_fake, False, for_discriminator=True)
        D_losses["D_real"] = self.criterionGAN(pred_real, True, for_discriminator=True)

        return D_losses

    def encode_z(self, real_image, domain):
        mu, logvar = self.netE(real_image, domain)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def generate_fake(self, input_semantics, real_image, domain, style_codes=None, compute_kld_loss=True):
        KLD_loss = None
        if self.opt.use_vae and style_codes is None:
            # print('yes')
            style_codes, mu, logvar = self.encode_z(real_image, domain)
            if compute_kld_loss:
                KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld

        fake_image = self.netG(input_semantics, real_image, style_codes=style_codes)

        assert (not compute_kld_loss) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False"

        return fake_image, KLD_loss, style_codes

    def generate_img_npy(self, input_semantics, real_image, domain, compute_kld_loss=False):
        KLD_loss = None
        style_codes, mu, logvar = self.encode_z(real_image, domain)
        if compute_kld_loss:
            KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld

        fake_image = self.netG(input_semantics, real_image, style_codes=style_codes)
        # print(real_image, fake_image.shape)
        assert (not compute_kld_loss) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False"

        return fake_image, style_codes

    # Given fake and real image, return the prediction of discriminator
    # for each fake and real image.

    def discriminate(self, input_semantics, fake_image, real_image, domain):
        fake_concat = torch.cat([input_semantics, fake_image], dim=1)
        real_concat = torch.cat([input_semantics, real_image], dim=1)

        # In Batch Normalization, the fake and real images are
        # recommended to be in the same batch to avoid disparate
        # statistics in fake and real images.
        # So both fake and real images are fed to D all at once.
        fake_and_real = torch.cat([fake_concat, real_concat], dim=0)

        discriminator_out = self.netD(fake_and_real, domain)

        pred_fake, pred_real = self.divide_pred(discriminator_out)

        return pred_fake, pred_real
|
254 |
+
|
255 |
+
# Take the prediction of fake and real images from the combined batch
|
256 |
+
def divide_pred(self, pred):
|
257 |
+
# the prediction contains the intermediate outputs of multiscale GAN,
|
258 |
+
# so it's usually a list
|
259 |
+
if type(pred) == list:
|
260 |
+
fake = []
|
261 |
+
real = []
|
262 |
+
for p in pred:
|
263 |
+
fake.append([tensor[: tensor.size(0) // 2] for tensor in p])
|
264 |
+
real.append([tensor[tensor.size(0) // 2 :] for tensor in p])
|
265 |
+
else:
|
266 |
+
fake = pred[: pred.size(0) // 2]
|
267 |
+
real = pred[pred.size(0) // 2 :]
|
268 |
+
|
269 |
+
return fake, real
|
270 |
+
|
271 |
+
def get_edges(self, t):
|
272 |
+
edge = self.ByteTensor(t.size()).zero_()
|
273 |
+
edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
|
274 |
+
edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
|
275 |
+
edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
|
276 |
+
edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
|
277 |
+
return edge.float()
|
278 |
+
|
279 |
+
def reparameterize(self, mu, logvar):
|
280 |
+
std = torch.exp(0.5 * logvar)
|
281 |
+
eps = torch.randn_like(std)
|
282 |
+
return eps.mul(std) + mu
|
283 |
+
|
284 |
+
def use_gpu(self):
|
285 |
+
return len(self.opt.gpu_ids) > 0
|
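The key step in preprocess_input above is the scatter_-based conversion of an integer label map into a one-hot semantic tensor for the generator. The following standalone sketch shows the same pattern; the sizes are illustrative only and are not the model's actual label_nc or canvas resolution:

    import torch

    # Integer label map of shape [bs, 1, h, w]: one class index per pixel.
    bs, nc, h, w = 1, 14, 4, 4
    label_map = torch.randint(0, nc, (bs, 1, h, w))

    # Scatter a 1.0 into the channel given by each pixel's class index,
    # producing a one-hot tensor of shape [bs, nc, h, w].
    input_label = torch.zeros(bs, nc, h, w)
    input_semantics = input_label.scatter_(1, label_map, 1.0)

    # Exactly one channel is active per pixel.
    assert input_semantics.sum(dim=1).eq(1).all()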
static/index.js
ADDED
@@ -0,0 +1,256 @@
let cvsIn = document.getElementById("inputimg");
let ctxIn = cvsIn.getContext('2d');
let style = document.getElementById("style");
let svgGraph = null;
let mouselbtn = false;
var current_time = (new Date()).getTime();
var user_id = Math.floor(Math.random() * 1000000000);


// initialize
window.onload = function () {
    ctxIn.fillStyle = "#87ceeb";
    ctxIn.fillRect(0, 0, cvsIn.width, 300);
    ctxIn.fillStyle = "#567d46";
    ctxIn.fillRect(0, 300, cvsIn.width, 512);

    ctxIn.color = "#b0d49b";
    ctxIn.lineWidth = 30;
    ctxIn.lineJoin = ctxIn.lineCap = 'round';
}


// add canvas events
cvsIn.addEventListener("mousedown", function (e) {
    if (e.button == 0) {
        let rect = e.target.getBoundingClientRect();
        let x = e.clientX - rect.left;
        let y = e.clientY - rect.top;
        mouselbtn = true;
        ctxIn.beginPath();
        ctxIn.moveTo(x, y);
    }
    else if (e.button == 2) {
        onClear(); // right click for clear input
    }
});

cvsIn.addEventListener("mouseup", function (e) {
    if (e.button == 0) {
        mouselbtn = false;
        move_range = domainSlider.value;
        onRecognition(move_range);
    }
});
cvsIn.addEventListener("mousemove", function (e) {
    let rect = e.target.getBoundingClientRect();
    let x = e.clientX - rect.left;
    let y = e.clientY - rect.top;
    if (mouselbtn) {
        ctxIn.lineTo(x, y);
        ctxIn.strokeStyle = ctxIn.color;
        ctxIn.stroke();
        // throttle live prediction to at most one request every 400 ms
        if (((new Date).getTime() - current_time) >= 400) {
            move_range = domainSlider.value;
            onRecognition(move_range);
            current_time = (new Date).getTime();
        }
    }
});

cvsIn.addEventListener("touchstart", function (e) {
    // for touch device
    if (e.targetTouches.length == 1) {
        let rect = e.target.getBoundingClientRect();
        let touch = e.targetTouches[0];
        let x = touch.clientX - rect.left;
        let y = touch.clientY - rect.top;
        ctxIn.beginPath();
        ctxIn.moveTo(x, y);
    }
});

cvsIn.addEventListener("touchmove", function (e) {
    // for touch device
    if (e.targetTouches.length == 1) {
        let rect = e.target.getBoundingClientRect();
        let touch = e.targetTouches[0];
        let x = touch.clientX - rect.left;
        let y = touch.clientY - rect.top;
        ctxIn.lineTo(x, y);
        ctxIn.strokeStyle = ctxIn.color;
        ctxIn.stroke();
        e.preventDefault();
    }
});

cvsIn.addEventListener("touchend", function (e) {
    // for touch device
    move_range = domainSlider.value;
    onRecognition(move_range);
});

// prevent displaying the context menu
cvsIn.addEventListener('contextmenu', function (e) {
    e.preventDefault();
});

document.getElementById("clearbtn").onclick = onClear;
function onClear() {
    mouselbtn = false;
    ctxIn.clearRect(0, 0, 512, 512);
    ctxIn.fillStyle = "#87ceeb";
    ctxIn.fillRect(0, 0, cvsIn.width, 300);
    ctxIn.fillStyle = "#567d46";
    ctxIn.fillRect(0, 300, cvsIn.width, 512);
}


document.getElementById("random_pick").addEventListener("click", function () {
    //ctxIn.color = "#F5F5F5";
    onRecognition_random();
});


document.getElementById("color1").addEventListener("click", function () {
    //ctxIn.color = "#D5D5D5";
    ctxIn.color = "#87ceeb";
});
document.getElementById("color2").addEventListener("click", function () {
    //ctxIn.color = "#696969";
    ctxIn.color = "#9b7653"
});
document.getElementById("color3").addEventListener("click", function () {
    //ctxIn.color = "#676767";
    ctxIn.color = "#b0d49b"
});
document.getElementById("color4").addEventListener("click", function () {
    //ctxIn.color = "#F5F5F5";
    ctxIn.color = "#5abcd8"
});
document.getElementById("color5").addEventListener("click", function () {
    //ctxIn.color = "#F5F5F5";
    ctxIn.color = "#C1BEBA"
});
document.getElementById("color6").addEventListener("click", function () {
    ctxIn.color = "#5A4D41"
});
document.getElementById("color7").addEventListener("click", function () {
    ctxIn.color = "#567d46"
});
document.getElementById("color8").addEventListener("click", function () {
    ctxIn.color = "#42692f"
});

document.getElementById("color9").addEventListener("click", function () {
    ctxIn.color = "#1577be"
});
//document.getElementById("color10").addEventListener("click", function(){
//    //ctxIn.color = "#676767";
//    ctxIn.color = "#808080"
//});
document.getElementById("color11").addEventListener("click", function () {
    //ctxIn.color = "#F5F5F5";
    ctxIn.color = "#3a2e27"
});
document.getElementById("color12").addEventListener("click", function () {
    //ctxIn.color = "#F5F5F5";
    ctxIn.color = "#4D415A"
});
//document.getElementById("color13").addEventListener("click", function(){
//    ctxIn.color = "#74cc8c"
//});
document.getElementById("color14").addEventListener("click", function () {
    ctxIn.color = "#FDDA16"
});
document.getElementById("color15").addEventListener("click", function () {
    ctxIn.color = "#d0cccc"
});

var brushSlider = document.getElementById("brushSlider");
ctxIn.lineWidth = brushSlider.value;

brushSlider.addEventListener("change", function () {
    ctxIn.lineWidth = brushSlider.value;
});


var move_range = 3;//domainSlider_1.value;


document.getElementById('style').addEventListener('change', function (event) {
    domainSlider.value = 3;
    onRecognition(domainSlider.value);
})


domainSlider.addEventListener("change", function () {
    // style.value = "ink";
    move_range = domainSlider.value;
    onRecognition(move_range);
});


// post data to server for recognition
function onRecognition(range) {
    console.time("predict");

    $.ajax({
        url: './predict',
        type: 'POST',
        data: JSON.stringify({
            img: cvsIn.toDataURL("image/png").replace('data:image/png;base64,', ''),
            model: style.value,
            move_range: range,
            user_id: user_id
        }),
        contentType: 'application/json',
    }).done(function (data) {
        drawImgToCanvas("outputimg", data)
    }).fail(function (XMLHttpRequest, textStatus, errorThrown) {
        console.log(XMLHttpRequest);
        alert("error");
    })

    console.timeEnd("predict"); // matches the console.time() label above
}

function onRecognition_random(range) {
    console.time("predict");

    $.ajax({
        url: './predict_random',
        type: 'POST',
        data: JSON.stringify({
            img: cvsIn.toDataURL("image/png").replace('data:image/png;base64,', ''),
            model: style.value,
            move_range: range,
            user_id: user_id
        }),
        contentType: 'application/json',
    }).done(function (data) {
        drawImgToCanvas("outputimg", data)
    }).fail(function (XMLHttpRequest, textStatus, errorThrown) {
        console.log(XMLHttpRequest);
        alert("error");
    })

    console.timeEnd("predict"); // matches the console.time() label above
}

function drawImgToCanvas(canvasId, b64Img) {
    let canvas = document.getElementById(canvasId);
    let ctx = canvas.getContext('2d');
    let img = new Image();
    img.src = "data:image/png;base64," + b64Img;
    img.onload = function () {
        ctx.drawImage(img, 0, 0, img.width, img.height, 0, 0, canvas.width, canvas.height);
    }
}
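onRecognition and onRecognition_random serialize the 512x512 drawing canvas as a base64 PNG (with the data-URL prefix stripped) and POST it, together with the selected style, slider value, and user id, to ./predict or ./predict_random. As a rough sketch of how such a payload can be turned back into an image array on the Python side (the real handler lives in app.py and is not reproduced here; the function name below is illustrative):

    import base64
    import io

    import numpy as np
    from PIL import Image

    def decode_canvas_png(b64_png: str) -> np.ndarray:
        # b64_png is the canvas content exactly as index.js sends it,
        # i.e. base64 without the 'data:image/png;base64,' prefix.
        raw = base64.b64decode(b64_png)
        img = Image.open(io.BytesIO(raw)).convert("RGB")
        return np.array(img)  # (512, 512, 3) for the drawing canvas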
static/init_code
ADDED
Binary file (1.79 kB)
static/style.css
ADDED
@@ -0,0 +1,36 @@
.common {
    text-align: center;
}

.boxitem1 {
    display: inline-block;
    vertical-align: left;
}

select {
    font-size: 1.7em;
    border: 1px;
}

button {
    border-radius: 10px;
}

.boxitem2 {
    display: inline-block;
    vertical-align: right;
}
.boxitem {
    display: inline-block;
    vertical-align: top;
}

#inputimg {
    vertical-align: left;
    border: solid 1px black;
}

#outputimg {
    vertical-align: left;
    border: solid 1px black;
}
templates/index.html
ADDED
@@ -0,0 +1,124 @@
<!DOCTYPE html>
<html>

<head>
    <link rel="stylesheet" href="https://cdn.rawgit.com/Chalarangelo/mini.css/v2.3.7/dist/mini-default.min.css">
    <link rel="stylesheet" href="./static/style.css">
    <meta charset="UTF-8">
    <title>Label to Art Demo_V_0.7 </title>
</head>

<div class="common">

<body>
    <div class="row">
        <div class="col-sm-12 col-md-10 col-md-offset-1">

            <div class="boxitem1">
                <label for='brushSlider' style='font-size: 1.7em; font-weight: bold'>Stroke Width </label>
                <input type="range" name="brushsize" min="0" max="100" id="brushSlider" step="1" value="30"
                    onchange="this.setAttribute('value',this.value);">
            </div>

            <div class="boxitem2">
                <label for="style" style='font-size: 1.7em; font-weight: bold'>Domain:</label>
                <select id="style">
                    <option style='font-size: 1.7em; font-weight: bold' value="ink">Ink Wash</option>
                    <option style='font-size: 1.7em; font-weight: bold' value="monet">Monet</option>
                    <option style='font-size: 1.7em; font-weight: bold' value="vangogh">Van Gogh</option>
                    <option style='font-size: 1.7em; font-weight: bold' value="water">WaterColor</option>
                </select>
            </div>

            <div class="row">
                <div class="col-sm-12 col-md-10 col-md-offset-1">
                    <div class="boxitem1">
                        <label for='domainSlider' style='font-size: 1.7em; font-weight: bold'> Style Strength </label>
                        <input type="range" name="style range" min="1" max="5" id="domainSlider" step="0.2" value="3"
                            onchange="this.setAttribute('value',this.value);">
                    </div>

                    <div class="boxitem2">
                        <button id="random_pick" class="primary"
                            style="background-color:#003449; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
                            Random </button>
                    </div>
                </div>
                <br>

            </div>
            <div class="divider"></div>
            <div class="boxitem1">
                <canvas id="inputimg" width="512" height="512" style="border:5px solid #ffffff;"></canvas>
            </div>
            <div class="boxitem2">
                <canvas id="outputimg" width="512" height="512" style="border:5px solid #ffffff;"></canvas>
            </div>
        </div>
    </div>

    <div class="boxitem">
        <button id="color1" class="primary"
            style="background-color:#87ceeb; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
            Sky </button>
        <button id="color2" class="primary"
            style="background-color:#9b7653; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
            Dirt </button>
        <button id="color3" class="primary"
            style="background-color:#b0d49b; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
            Mountain </button>
        <button id="color4" class="primary"
            style="background-color:#5abcd8; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
            River </button>
        <button id="color5" class="primary"
            style="background-color:#C1BEBA; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
            Clouds </button>
        <button id="color6" class="primary"
            style="background-color:#5A4D41; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
            Rock </button>
        <button id="color7" class="primary"
            style="background-color:#567d46; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
            Grass </button>
        <button id="color8" class="primary"
            style="background-color:#42692f; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
            Tree </button>
    </div>
    <br>
    <div class="boxitem">
        <button id="color9" class="primary"
            style="background-color:#1577be; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
            Sea </button>
        <button id="color11" class="primary"
            style="background-color:#3a2e27; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
            Ground </button>
        <button id="color12" class="primary"
            style="background-color:#4D415A; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
            Hill </button>
        <button id="color14" class="primary"
            style="background-color:#FDDA16; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
            Road </button>
        <button id="color15" class="primary"
            style="background-color:#d0cccc; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
            Snow </button>
        <button id="clearbtn" class="primary"
            style="margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;"> Clear </button>

    </div>

    <script src="./static/index.js"></script>
    <script src="//ajax.googleapis.com/ajax/libs/jquery/3.3.1/jquery.min.js"></script>
    <script src="//d3js.org/d3.v5.min.js"></script>
</body>
utils/boundaries_amp_52/artwork_ink_boundary/boundary.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:63f02c41f68b7fcd603e158fbd4542ebbe3b42d5b25b64a5e1e1e0699e1286c0
size 1152
utils/boundaries_amp_52/artwork_ink_boundary/log.txt
ADDED
@@ -0,0 +1,12 @@
[2021-09-15 00:57:58,220][INFO] Loading latent codes.
[2021-09-15 00:57:58,233][INFO] Loading attribute scores.
[2021-09-15 00:57:58,234][INFO] Filtering training data.
[2021-09-15 00:57:58,234][INFO] Sorting scores to get positive and negative samples.
[2021-09-15 00:57:58,246][INFO] Spliting training and validation sets:
[2021-09-15 00:57:58,337][INFO] Training: 4200 positive, 4200 negative.
[2021-09-15 00:57:58,342][INFO] Validation: 1800 positive, 1800 negative.
[2021-09-15 00:57:58,353][INFO] Remaining: 4255 positive, 23745 negative.
[2021-09-15 00:57:58,356][INFO] Training boundary.
[2021-09-15 00:57:59,433][INFO] Finish training.
[2021-09-15 00:57:59,711][INFO] Accuracy for validation set: 3596 / 3600 = 0.998889
[2021-09-15 00:58:01,331][INFO] Accuracy for remaining set: 27037 / 28000 = 0.965607
utils/boundaries_amp_52/artwork_monet_boundary/boundary.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:477c0e8af08ccb05d8ddb79e38b5c71cbc856bb4ebe0adf80194de48cf1510ad
size 1152
utils/boundaries_amp_52/artwork_monet_boundary/log.txt
ADDED
@@ -0,0 +1,12 @@
[2021-09-15 00:57:58,250][INFO] Loading latent codes.
[2021-09-15 00:57:58,265][INFO] Loading attribute scores.
[2021-09-15 00:57:58,266][INFO] Filtering training data.
[2021-09-15 00:57:58,266][INFO] Sorting scores to get positive and negative samples.
[2021-09-15 00:57:58,281][INFO] Spliting training and validation sets:
[2021-09-15 00:57:58,393][INFO] Training: 4200 positive, 4200 negative.
[2021-09-15 00:57:58,398][INFO] Validation: 1800 positive, 1800 negative.
[2021-09-15 00:57:58,400][INFO] Remaining: 5854 positive, 22146 negative.
[2021-09-15 00:57:58,407][INFO] Training boundary.
[2021-09-15 00:57:59,699][INFO] Finish training.
[2021-09-15 00:57:59,912][INFO] Accuracy for validation set: 3549 / 3600 = 0.985833
[2021-09-15 00:58:01,556][INFO] Accuracy for remaining set: 23849 / 28000 = 0.851750
utils/boundaries_amp_52/artwork_vangogh_boundary/boundary.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b616bf764d3ea1c85b9f852b5320b90fcb6ff5d90f80d22ef8a7e5c3a47a5959
size 1152
utils/boundaries_amp_52/artwork_vangogh_boundary/log.txt
ADDED
@@ -0,0 +1,12 @@
[2021-09-15 00:57:58,217][INFO] Loading latent codes.
[2021-09-15 00:57:58,230][INFO] Loading attribute scores.
[2021-09-15 00:57:58,231][INFO] Filtering training data.
[2021-09-15 00:57:58,231][INFO] Sorting scores to get positive and negative samples.
[2021-09-15 00:57:58,246][INFO] Spliting training and validation sets:
[2021-09-15 00:57:58,279][INFO] Training: 4200 positive, 4200 negative.
[2021-09-15 00:57:58,281][INFO] Validation: 1800 positive, 1800 negative.
[2021-09-15 00:57:58,281][INFO] Remaining: 3401 positive, 24599 negative.
[2021-09-15 00:57:58,281][INFO] Training boundary.
[2021-09-15 00:57:59,347][INFO] Finish training.
[2021-09-15 00:57:59,551][INFO] Accuracy for validation set: 3596 / 3600 = 0.998889
[2021-09-15 00:58:01,098][INFO] Accuracy for remaining set: 23785 / 28000 = 0.849464
utils/boundaries_amp_52/artwork_water_boundary/boundary.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f0827fe3ea8f6361bc298c968623c2ff2e33e5183f4d9df18c118f751c10c713
size 1152
utils/boundaries_amp_52/artwork_water_boundary/log.txt
ADDED
@@ -0,0 +1,12 @@
[2021-09-15 00:57:58,226][INFO] Loading latent codes.
[2021-09-15 00:57:58,244][INFO] Loading attribute scores.
[2021-09-15 00:57:58,245][INFO] Filtering training data.
[2021-09-15 00:57:58,245][INFO] Sorting scores to get positive and negative samples.
[2021-09-15 00:57:58,263][INFO] Spliting training and validation sets:
[2021-09-15 00:57:58,390][INFO] Training: 4200 positive, 4200 negative.
[2021-09-15 00:57:58,393][INFO] Validation: 1800 positive, 1800 negative.
[2021-09-15 00:57:58,398][INFO] Remaining: 4465 positive, 23535 negative.
[2021-09-15 00:57:58,401][INFO] Training boundary.
[2021-09-15 00:57:59,812][INFO] Finish training.
[2021-09-15 00:58:00,021][INFO] Accuracy for validation set: 3584 / 3600 = 0.995556
[2021-09-15 00:58:01,830][INFO] Accuracy for remaining set: 24271 / 28000 = 0.866821
utils/umap_utils.py
ADDED
@@ -0,0 +1,99 @@
import numpy as np
import torch
import copy
import os
from sklearn import svm

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


def linear_interpolate(latent_code, boundary, start_distance=-3, end_distance=3, steps=10):
    """Manipulates the given latent code with respect to a particular boundary.

    Basically, this function takes a latent code and a boundary as inputs, and
    outputs a collection of manipulated latent codes. For example, let `steps`
    be 10; then the input `latent_code` has shape [1, latent_space_dim], the input
    `boundary` has shape [1, latent_space_dim] and unit norm, and the output has
    shape [10, latent_space_dim]. The first output latent code is
    `start_distance` away from the given `boundary`, while the last output latent
    code is `end_distance` away from the given `boundary`. The remaining latent
    codes are linearly interpolated.

    Input `latent_code` can also have shape [1, num_layers, latent_space_dim]
    to support the W+ space in StyleGAN. In this case, all features in W+ space
    are manipulated in the same way. Accordingly, the output has shape
    [10, num_layers, latent_space_dim].

    NOTE: Distance is sign sensitive.

    Args:
      latent_code: The input latent code for manipulation.
      boundary: The semantic boundary as reference.
      start_distance: The distance to the boundary where the manipulation starts.
        (default: -3.0)
      end_distance: The distance to the boundary where the manipulation ends.
        (default: 3.0)
      steps: Number of steps to move the latent code from start position to end
        position. (default: 10)
    """
    assert latent_code.shape[0] == 1 and boundary.shape[0] == 1 and len(boundary.shape) == 2 and boundary.shape[1] == latent_code.shape[-1]

    linspace = np.linspace(start_distance, end_distance, steps)
    if len(latent_code.shape) == 2:
        linspace = linspace - latent_code.dot(boundary.T)
        linspace = linspace.reshape(-1, 1).astype(np.float32)
        return latent_code + linspace * boundary
    if len(latent_code.shape) == 3:
        linspace = linspace.reshape(-1, 1, 1).astype(np.float32)
        return latent_code + linspace * boundary.reshape(1, 1, -1)
    raise ValueError(
        f"Input `latent_code` should be with shape "
        f"[1, latent_space_dim] or [1, N, latent_space_dim] for "
        f"W+ space in Style GAN!\n"
        f"But {latent_code.shape} is received."
    )


def get_code(domain, boundaries):
    if domain == "ink":
        domain = 0
    elif domain == "monet":
        domain = 1
    elif domain == "vangogh":
        domain = 2
    elif domain == "water":
        domain = 3

    res = np.array(torch.randn(1, 256, dtype=torch.float32))
    # res = linear_interpolate(res, boundaries[domain], end_distance=3, steps=3)[-1:]
    res = torch.Tensor(res).cuda() if torch.cuda.is_available() else torch.Tensor(res)
    return res


def modify_code(code, boundaries, domain, range):
    if domain == "ink":
        domain = 0
    elif domain == "monet":
        domain = 1
    elif domain == "vangogh":
        domain = 2
    elif domain == "water":
        domain = 3
    # print(domain, range)
    if range == 0:
        return code
    else:
        res = np.array(code.cpu().detach().numpy())
        res = linear_interpolate(res, boundaries[domain], end_distance=range, steps=3)[-1:]
        res = torch.Tensor(res).cuda() if torch.cuda.is_available() else torch.Tensor(res)
        return res


def load_boundries():
    domains = ["ink", "monet", "vangogh", "water"]
    domains.sort()
    boundaries = [
        np.load(os.path.join(os.path.dirname(__file__), "boundaries_amp_52/artwork_" + domain + "_boundary/boundary.npy")) for domain in domains
    ]
    return boundaries
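Taken together, load_boundries, get_code, and modify_code implement the "Style Strength" control: a random 256-dimensional latent code is moved a chosen distance along a per-domain boundary (the boundary.npy files added above). A small usage sketch under assumed conditions follows; it is illustrative only, the call is made from the repository root, and the value 3.0 simply mirrors the #domainSlider default rather than anything in app.py:

    from utils.umap_utils import get_code, load_boundries, modify_code

    boundaries = load_boundries()              # one boundary array per (sorted) domain
    code = get_code("vangogh", boundaries)     # random 1x256 code, on GPU when available
    edited = modify_code(code, boundaries, "vangogh", 3.0)
    print(edited.shape)                        # torch.Size([1, 256])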