sky24h committed
Commit f3daba8 · 1 Parent(s): 1880a57

init_commit
Files changed (37)
  1. .gitignore +130 -0
  2. Dockerfile +32 -0
  3. app.py +130 -0
  4. requirements.txt +6 -0
  5. seg2art/checkpoints/multimodal_artworks/latest_net_G-fp16.pth +3 -0
  6. seg2art/inference_util.py +77 -0
  7. seg2art/model_util.py +158 -0
  8. seg2art/options/__init__.py +4 -0
  9. seg2art/options/base_options.py +184 -0
  10. seg2art/options/test_options.py +22 -0
  11. seg2art/sstan_models/__init__.py +44 -0
  12. seg2art/sstan_models/networks/__init__.py +63 -0
  13. seg2art/sstan_models/networks/architecture.py +231 -0
  14. seg2art/sstan_models/networks/base_network.py +59 -0
  15. seg2art/sstan_models/networks/dual_attention_module.py +51 -0
  16. seg2art/sstan_models/networks/generator.py +184 -0
  17. seg2art/sstan_models/networks/normalization.py +222 -0
  18. seg2art/sstan_models/networks/sync_batchnorm/__init__.py +13 -0
  19. seg2art/sstan_models/networks/sync_batchnorm/batchnorm.py +361 -0
  20. seg2art/sstan_models/networks/sync_batchnorm/batchnorm_reimpl.py +74 -0
  21. seg2art/sstan_models/networks/sync_batchnorm/comm.py +137 -0
  22. seg2art/sstan_models/networks/sync_batchnorm/replicate.py +94 -0
  23. seg2art/sstan_models/networks/sync_batchnorm/unittest.py +29 -0
  24. seg2art/sstan_models/pix2pix_model.py +285 -0
  25. static/index.js +256 -0
  26. static/init_code +0 -0
  27. static/style.css +36 -0
  28. templates/index.html +124 -0
  29. utils/boundaries_amp_52/artwork_ink_boundary/boundary.npy +3 -0
  30. utils/boundaries_amp_52/artwork_ink_boundary/log.txt +12 -0
  31. utils/boundaries_amp_52/artwork_monet_boundary/boundary.npy +3 -0
  32. utils/boundaries_amp_52/artwork_monet_boundary/log.txt +12 -0
  33. utils/boundaries_amp_52/artwork_vangogh_boundary/boundary.npy +3 -0
  34. utils/boundaries_amp_52/artwork_vangogh_boundary/log.txt +12 -0
  35. utils/boundaries_amp_52/artwork_water_boundary/boundary.npy +3 -0
  36. utils/boundaries_amp_52/artwork_water_boundary/log.txt +12 -0
  37. utils/umap_utils.py +99 -0
.gitignore ADDED
@@ -0,0 +1,130 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+ .vercel
Dockerfile ADDED
@@ -0,0 +1,32 @@
1
+ # read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
2
+ # you will also find guides on how best to write your Dockerfile
3
+
4
+ # Include base image
5
+ FROM docker.io/pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
6
+
7
+ # Define working directory
8
+ WORKDIR /workspace/
9
+
10
+ # Set timezone
11
+ ENV TZ=Asia/Tokyo
12
+ RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
13
+
14
+ # Install dependencies
15
+ RUN apt-get update && apt-get -y install libgl1 libglib2.0-0 vim
16
+ RUN apt-get autoremove -y && apt-get clean -y
17
+
18
+ # Add pretrained model
19
+ ADD seg2art ./seg2art
20
+ ADD static ./static
21
+ ADD templates ./templates
22
+ ADD utils ./utils
23
+
24
+ # Add necessary files
25
+ ADD app.py ./
26
+
27
+ # pip install
28
+ ADD requirements.txt ./
29
+ RUN pip install -r requirements.txt
30
+
31
+ # Run server
32
+ CMD [ "python", "-u", "./app.py" ]
app.py ADDED
@@ -0,0 +1,130 @@
1
+ import os
2
+ import sys
3
+ import time
4
+ import json
5
+ import torch
6
+ import base64
7
+ from PIL import Image
8
+ from io import BytesIO
9
+
10
+ # set CUDA_MODULE_LOADING=LAZY to speed up the serverless function
11
+ os.environ["CUDA_MODULE_LOADING"] = "LAZY"
12
+ # set SAFETENSORS_FAST_GPU=1 to speed up the serverless function
13
+ os.environ["SAFETENSORS_FAST_GPU"] = "1"
14
+
15
+ sys.path.append(os.path.join(os.path.dirname(__file__), "seg2art"))
16
+ from seg2art.sstan_models.pix2pix_model import Pix2PixModel
17
+ from seg2art.options.test_options import TestOptions
18
+ from seg2art.inference_util import get_artwork
19
+
20
+ import uvicorn
21
+ from fastapi import FastAPI, Form
22
+ from fastapi.templating import Jinja2Templates
23
+ from fastapi.responses import PlainTextResponse, HTMLResponse
24
+ from fastapi.requests import Request
25
+ from fastapi.staticfiles import StaticFiles
26
+
27
+
28
+ # declare constants
29
+ HOST = "0.0.0.0"
30
+ PORT = 7860
31
+ # FastAPI
32
+ app = FastAPI(root_path=os.path.abspath(os.path.dirname(__file__)))
33
+ app.mount("/static", StaticFiles(directory="static"), name="static")
34
+ templates = Jinja2Templates(directory="templates")
35
+
36
+
37
+ # initialize SEAN model.
38
+ opt = TestOptions().parse()
39
+ opt.status = "test"
40
+ model = Pix2PixModel(opt)
41
+ model = model.half() if torch.cuda.is_available() else model
42
+ model.eval()
43
+
44
+
45
+ from utils.umap_utils import get_code, load_boundries, modify_code
46
+
47
+ boundaries = load_boundries()
48
+ global current_codes
49
+ current_codes = {}
50
+ max_user_num = 5
51
+
52
+ initial_code_path = os.path.join(os.path.dirname(__file__), "static/init_code")
53
+ initial_code = torch.load(initial_code_path) if torch.cuda.is_available() else torch.load(initial_code_path, map_location=torch.device("cpu"))
54
+
55
+
56
+ def EncodeImage(img_pil):
57
+ with BytesIO() as buffer:
58
+ img_pil.save(buffer, "jpeg")
59
+ image_data = base64.b64encode(buffer.getvalue())
60
+ return image_data
61
+
62
+
63
+ def DecodeImage(img_pil):
64
+ img_pil = BytesIO(base64.urlsafe_b64decode(img_pil))
65
+ img_pil = Image.open(img_pil).convert("RGB")
66
+ return img_pil
67
+
68
+
69
+ def process_input(body, random=False):
70
+ global current_codes
71
+ json_body = json.loads(body.decode("utf-8"))
72
+ user_id = json_body["user_id"]
73
+ start_time = time.time()
74
+
75
+ # save current code for different users
76
+ if user_id not in current_codes:
77
+ current_codes[user_id] = initial_code.clone()
78
+ if len(current_codes[user_id]) > max_user_num:
79
+ current_codes[user_id] = current_codes[user_id][-max_user_num:]
80
+
81
+ if random:
82
+ # randomize code
83
+ domain = json_body["model"]
84
+ current_codes[user_id] = get_code(domain, boundaries)
85
+
86
+ # get input
87
+ input_img = DecodeImage(json_body["img"])
88
+
89
+ try:
90
+ move_range = float(json_body["move_range"])
91
+ except:
92
+ move_range = 0
93
+
94
+ # set move range to 3 if random is True
95
+ move_range = 3 if random else move_range
96
+ # print("Input image was received")
97
+ # get selected style
98
+ domain = json_body["model"]
99
+ if move_range != 0:
100
+ modified_code = modify_code(current_codes[user_id], boundaries, domain, move_range)
101
+ else:
102
+ modified_code = current_codes[user_id].clone()
103
+
104
+ # inference
105
+ result = get_artwork(model, input_img, modified_code)
106
+ print("Time Cost: ", time.time() - start_time)
107
+ return EncodeImage(result)
108
+
109
+
110
+ @app.get("/", response_class=HTMLResponse)
111
+ def root(request: Request):
112
+ return templates.TemplateResponse("index.html", {"request": request})
113
+
114
+
115
+ @app.post("/predict")
116
+ async def predict(request: Request):
117
+ body = await request.body()
118
+ result = process_input(body, random=False)
119
+ return result
120
+
121
+
122
+ @app.post("/predict_random")
123
+ async def predict_random(request: Request):
124
+ body = await request.body()
125
+ result = process_input(body, random=True)
126
+ return result
127
+
128
+
129
+ if __name__ == "__main__":
130
+ uvicorn.run(app, host=HOST, port=PORT, log_level="info")
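
The /predict route above expects a JSON body with user_id, img (a urlsafe-base64 image, matching DecodeImage), model, and move_range, and returns the generated artwork as a base64-encoded JPEG (see EncodeImage). A minimal client sketch follows, assuming the server is running locally on port 7860; the file name and domain value are placeholders, not part of this repository, and the exact response framing may differ.

# Hypothetical client for the /predict endpoint above (standard library + Pillow only).
import json
import base64
import urllib.request
from io import BytesIO
from PIL import Image

with open("label_map.png", "rb") as f:  # placeholder segmentation-map file
    # urlsafe base64, to match base64.urlsafe_b64decode() in DecodeImage()
    img_b64 = base64.urlsafe_b64encode(f.read()).decode("utf-8")

payload = {
    "user_id": "demo-user",
    "img": img_b64,
    "model": "vangogh",     # assumed to be one of the boundary domains (ink/monet/vangogh/water)
    "move_range": "1.0",
}

req = urllib.request.Request(
    "http://localhost:7860/predict",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    body = resp.read().strip(b'"')  # the base64 string may come back JSON-quoted

Image.open(BytesIO(base64.b64decode(body))).save("artwork.jpg")
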
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ scikit-learn
2
+ scikit-image
3
+ torchvision>=0.7.0
4
+ torch>=1.6.0
5
+ fastapi
6
+ uvicorn
seg2art/checkpoints/multimodal_artworks/latest_net_G-fp16.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11e5d324b59dce20e81cea0eed77b25c7b9f6b56ccb44970b67593b4287ddb4c
3
+ size 418576205
seg2art/inference_util.py ADDED
@@ -0,0 +1,77 @@
1
+ """
2
+ Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
+ """
5
+ import numpy as np
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+
9
+ # define constants
10
+ image_size = 256
11
+
12
+ # to label
13
+ values = [12, 2, 6, 8, 1, 10, 3, 14, 11, 4, 5, 13, 9]
14
+ values = np.array(values)
15
+
16
+ # from color
17
+ colors = [
18
+ (135, 206, 235),
19
+ (155, 118, 83),
20
+ (176, 212, 155),
21
+ (90, 188, 216),
22
+ (193, 190, 186),
23
+ (90, 77, 65),
24
+ (86, 125, 70),
25
+ (66, 105, 47),
26
+ (21, 119, 190),
27
+ (58, 46, 39),
28
+ (77, 65, 90),
29
+ (253, 218, 22),
30
+ (208, 204, 204),
31
+ ]
32
+ colors = np.array(colors)
33
+
34
+
35
+ def remap_label(arr):
36
+ # compare only first 1 channel to speed up
37
+ arr_r = arr[:, :, 0]
38
+
39
+ # remap color to label
40
+ for i in range(len(colors)):
41
+ arr_r[arr_r == colors[i][0]] = values[i]
42
+ # others to 15
43
+ arr_r[arr_r > 15] = 15
44
+ return arr_r
45
+
46
+
47
+ preprocess = transforms.Compose(
48
+ [
49
+ transforms.Resize([image_size, image_size]),
50
+ transforms.ToTensor(),
51
+ ]
52
+ )
53
+
54
+
55
+ def image_loader(loader, label_inp):
56
+ image = Image.fromarray(label_inp).convert("RGB")
57
+ image = image.resize((image_size, image_size))
58
+ image = loader(image).float() * 255
59
+ image = image.clone().detach().requires_grad_(True)
60
+ image = image.unsqueeze(0)
61
+ return image
62
+
63
+
64
+ def tensor2im(image_tensor):
65
+ image_numpy = image_tensor[0].detach().cpu().float().numpy()
66
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
67
+ image_numpy = np.clip(image_numpy, 0, 255)
68
+ return Image.fromarray(image_numpy.astype(np.uint8))
69
+
70
+
71
+ def get_artwork(model, data, code):
72
+ label_inp = remap_label(np.array(data))
73
+ label_inp = (image_loader(preprocess, label_inp)).detach().half()
74
+
75
+ image_out = model(label_inp, mode="inference", style_codes=code)
76
+ image_out = tensor2im(image_out)
77
+ return image_out
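
get_artwork above turns a painted color map into integer class labels via remap_label, which matches only the red channel of each pixel against the colors table and clamps anything unmatched to the "others" label 15. A toy sketch of that remapping, with purely illustrative pixel values:

# Toy illustration of the red-channel remapping performed by remap_label() above.
import numpy as np

colors_r = np.array([135, 155, 176, 90, 193, 90, 86, 66, 21, 58, 77, 253, 208])
labels   = np.array([ 12,   2,   6,  8,   1, 10,  3, 14, 11,  4,  5,  13,   9])

arr_r = np.array([[135,  86],      # two red-channel values from the table above
                  [253, 200]])     # one more known value, one unknown value

out = arr_r.copy()
for r, v in zip(colors_r, labels):
    out[out == r] = v
out[out > 15] = 15                 # unmatched colors fall into the 'others' class

print(out)                         # labels: 12, 3, 13, 15 (last one is 'others')
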
seg2art/model_util.py ADDED
@@ -0,0 +1,158 @@
1
+ """
2
+ Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
+ """
5
+
6
+ import re
7
+ import importlib
8
+ import torch
9
+ from argparse import Namespace
10
+ import numpy as np
11
+ from PIL import Image
12
+ import os
13
+
14
+
15
+ # Converts a Tensor into a Numpy array
16
+ # |imtype|: the desired type of the converted numpy array
17
+ def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
18
+ if isinstance(image_tensor, list):
19
+ image_numpy = []
20
+ for i in range(len(image_tensor)):
21
+ image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
22
+ return image_numpy
23
+
24
+ if image_tensor.dim() == 4:
25
+ # transform each image in the batch
26
+ images_np = []
27
+ for b in range(image_tensor.size(0)):
28
+ one_image = image_tensor[b]
29
+ one_image_np = tensor2im(one_image)
30
+ images_np.append(one_image_np.reshape(1, *one_image_np.shape))
31
+ images_np = np.concatenate(images_np, axis=0)
32
+ return images_np
33
+
34
+ if image_tensor.dim() == 2:
35
+ image_tensor = image_tensor.unsqueeze(0)
36
+ image_numpy = image_tensor.detach().cpu().float().numpy()
37
+ if normalize:
38
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
39
+ else:
40
+ image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
41
+ image_numpy = np.clip(image_numpy, 0, 255)
42
+ if image_numpy.shape[2] == 1:
43
+ image_numpy = image_numpy[:, :, 0]
44
+ return image_numpy.astype(imtype)
45
+
46
+
47
+ # Converts a one-hot tensor into a colorful label map
48
+ def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
49
+ if label_tensor.dim() == 4:
50
+ # transform each image in the batch
51
+ images_np = []
52
+ for b in range(label_tensor.size(0)):
53
+ one_image = label_tensor[b]
54
+ one_image_np = tensor2label(one_image, n_label, imtype)
55
+ images_np.append(one_image_np.reshape(1, *one_image_np.shape))
56
+ images_np = np.concatenate(images_np, axis=0)
57
+ if tile:
58
+ images_tiled = tile_images(images_np)
59
+ return images_tiled
60
+ else:
61
+ images_np = images_np[0]
62
+ return images_np
63
+
64
+ if label_tensor.dim() == 1:
65
+ return np.zeros((64, 64, 3), dtype=np.uint8)
66
+ if n_label == 0:
67
+ return tensor2im(label_tensor, imtype)
68
+ label_tensor = label_tensor.cpu().float()
69
+ if label_tensor.size()[0] > 1:
70
+ label_tensor = label_tensor.max(0, keepdim=True)[1]
71
+ label_tensor = Colorize(n_label)(label_tensor)
72
+ label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
73
+ result = label_numpy.astype(imtype)
74
+ return result
75
+
76
+
77
+ def save_image(image_numpy, image_path, create_dir=False):
78
+ if create_dir:
79
+ os.makedirs(os.path.dirname(image_path), exist_ok=True)
80
+ if len(image_numpy.shape) == 2:
81
+ image_numpy = np.expand_dims(image_numpy, axis=2)
82
+ if image_numpy.shape[2] == 1:
83
+ image_numpy = np.repeat(image_numpy, 3, 2)
84
+ image_pil = Image.fromarray(image_numpy)
85
+
86
+ # save to png
87
+ image_pil.save(image_path.replace('.jpg', '.png'))
88
+
89
+
90
+ def mkdirs(paths):
91
+ if isinstance(paths, list) and not isinstance(paths, str):
92
+ for path in paths:
93
+ mkdir(path)
94
+ else:
95
+ mkdir(paths)
96
+
97
+
98
+ def mkdir(path):
99
+ if not os.path.exists(path):
100
+ os.makedirs(path)
101
+
102
+
103
+ def atoi(text):
104
+ return int(text) if text.isdigit() else text
105
+
106
+
107
+ def natural_keys(text):
108
+ '''
109
+ alist.sort(key=natural_keys) sorts in human order
110
+ http://nedbatchelder.com/blog/200712/human_sorting.html
111
+ (See Toothy's implementation in the comments)
112
+ '''
113
+ return [atoi(c) for c in re.split('(\d+)', text)]
114
+
115
+
116
+ def natural_sort(items):
117
+ items.sort(key=natural_keys)
118
+
119
+
120
+ def str2bool(v):
121
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
122
+ return True
123
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
124
+ return False
125
+ else:
126
+ raise argparse.ArgumentTypeError('Boolean value expected.')
127
+
128
+
129
+ def find_class_in_module(target_cls_name, module):
130
+ target_cls_name = target_cls_name.replace('_', '').lower()
131
+ clslib = importlib.import_module(module)
132
+ cls = None
133
+ for name, clsobj in clslib.__dict__.items():
134
+ if name.lower() == target_cls_name:
135
+ cls = clsobj
136
+
137
+ if cls is None:
138
+ print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
139
+ exit(0)
140
+
141
+ return cls
142
+
143
+
144
+ def save_network(net, label, epoch, opt):
145
+ save_filename = '%s_net_%s.pth' % (epoch, label)
146
+ save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
147
+ torch.save(net.cpu().state_dict(), save_path)
148
+ if len(opt.gpu_ids) and torch.cuda.is_available():
149
+ net.cuda()
150
+
151
+
152
+ def load_network(net, label, epoch, opt):
153
+ save_filename = '%s_net_%s.pth' % (epoch, label)
154
+ save_dir = os.path.join(opt.checkpoints_dir, opt.name)
155
+ save_path = os.path.join(save_dir, save_filename)
156
+ weights = torch.load(save_path)
157
+ net.load_state_dict(weights, strict=False)#
158
+ return net
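
One caveat in model_util.py above: str2bool raises argparse.ArgumentTypeError, but the module never imports argparse, so the error branch would itself fail with a NameError. A minimal sketch of how such a helper is normally wired into a parser, with the import added and a hypothetical flag name:

# Sketch only: str2bool as an argparse type (note the explicit "import argparse").
import argparse

def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser()
parser.add_argument('--use_flip', type=str2bool, default=False)   # hypothetical flag
print(parser.parse_args(['--use_flip', 'yes']).use_flip)           # True
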
seg2art/options/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ """
2
+ Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
+ """
seg2art/options/base_options.py ADDED
@@ -0,0 +1,184 @@
1
+ """
2
+ Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import torch
9
+ import pickle
10
+ import argparse
11
+ import sstan_models
12
+ import utils as util
13
+
14
+
15
+ class BaseOptions():
16
+ def __init__(self):
17
+ self.initialized = False
18
+
19
+ def initialize(self, parser):
20
+ # experiment specifics
21
+ parser.add_argument('--name', type=str, default='multimodal_artworks', help='name of the experiment. It decides where to store samples and sstan_models')
22
+
23
+ parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
24
+ parser.add_argument('--checkpoints_dir', type=str, default='./seg2art/checkpoints', help='sstan_models are saved here')
25
+ parser.add_argument('--model', type=str, default='pix2pix', help='which model to use')
26
+ parser.add_argument('--norm_G', type=str, default='spectralinstance', help='instance normalization or batch normalization')
27
+ parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization')
28
+ parser.add_argument('--norm_E', type=str, default='spectralinstance', help='instance normalization or batch normalization')
29
+ parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
30
+
31
+ # input/output sizes
32
+ parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
33
+ parser.add_argument('--preprocess_mode', type=str, default='scale_width_and_crop', help='scaling and cropping of images at load time.', choices=("resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none"))
34
+ parser.add_argument('--load_size', type=int, default=512, help='Scale images to this size. The final image will be cropped to --crop_size.')
35
+ parser.add_argument('--crop_size', type=int, default=512, help='Crop to the width of crop_size (after initially scaling the images to load_size.)')
36
+ parser.add_argument('--aspect_ratio', type=float, default=1.0, help='The ratio width/height. The final height of the load image will be crop_size/aspect_ratio')
37
+ parser.add_argument('--label_nc', type=int, default=16, help='# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dontcare_label.')
38
+ parser.add_argument('--contain_dontcare_label', action='store_true', help='if the label map contains dontcare label (dontcare=255)')
39
+ parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
40
+
41
+ # for setting inputs
42
+ parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/')
43
+ parser.add_argument('--dataset_mode', type=str, default='custom')
44
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
45
+ parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
46
+ parser.add_argument('--nThreads', default=0, type=int, help='# threads for loading data')
47
+ parser.add_argument('--max_dataset_size', type=int, default=sys.maxsize, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
48
+ parser.add_argument('--load_from_opt_file', action='store_true', help='load the options from checkpoints and use that as default')
49
+ parser.add_argument('--cache_filelist_write', action='store_true', help='saves the current filelist into a text file, so that it loads faster')
50
+ parser.add_argument('--cache_filelist_read', action='store_true', help='reads from the file list cache')
51
+
52
+ # for displays
53
+ parser.add_argument('--display_winsize', type=int, default=400, help='display window size')
54
+
55
+ # for generator
56
+ parser.add_argument('--netG', type=str, default='spade', help='selects model to use for netG (pix2pixhd | spade)')
57
+ parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
58
+ parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')
59
+ parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
60
+ parser.add_argument('--z_dim', type=int, default=256,
61
+ help="dimension of the latent z vector")
62
+
63
+ # for instance-wise features
64
+ parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input')
65
+ parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
66
+ parser.add_argument('--use_vae', action='store_true', help='enable training with an image encoder.')
67
+
68
+ self.initialized = True
69
+ return parser
70
+
71
+ def gather_options(self):
72
+ # initialize parser with basic options
73
+ if not self.initialized:
74
+ parser = argparse.ArgumentParser(
75
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
76
+ parser = self.initialize(parser)
77
+
78
+ # get the basic options
79
+ opt, unknown = parser.parse_known_args()
80
+
81
+ # modify model-related parser options
82
+ model_name = opt.model
83
+ model_option_setter = sstan_models.get_option_setter(model_name)
84
+ parser = model_option_setter(parser, self.isTrain)
85
+
86
+
87
+ # # modify dataset-related parser options
88
+ # dataset_mode = opt.dataset_mode
89
+ # dataset_option_setter = data.get_option_setter(dataset_mode)
90
+ # parser = dataset_option_setter(parser, self.isTrain)
91
+
92
+ # opt, unknown = parser.parse_known_args()
93
+
94
+ # # if there is opt_file, load it.
95
+ # # The previous default options will be overwritten
96
+ # if opt.load_from_opt_file:
97
+ # parser = self.update_options_from_file(parser, opt)
98
+
99
+ opt = parser.parse_args()
100
+
101
+ opt.contain_dontcare_label = False
102
+ opt.no_instance = True
103
+ opt.use_vae = False
104
+
105
+ self.parser = parser
106
+ return opt
107
+
108
+ def print_options(self, opt):
109
+ message = ''
110
+ message += '----------------- Options ---------------\n'
111
+ for k, v in sorted(vars(opt).items()):
112
+ comment = ''
113
+ default = self.parser.get_default(k)
114
+ if v != default:
115
+ comment = '\t[default: %s]' % str(default)
116
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
117
+ message += '----------------- End -------------------'
118
+ print(message)
119
+
120
+ def option_file_path(self, opt, makedir=False):
121
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
122
+ if makedir:
123
+ util.mkdirs(expr_dir)
124
+ file_name = os.path.join(expr_dir, 'opt')
125
+ return file_name
126
+
127
+ def save_options(self, opt):
128
+ file_name = self.option_file_path(opt, makedir=True)
129
+ with open(file_name + '.txt', 'wt') as opt_file:
130
+ for k, v in sorted(vars(opt).items()):
131
+ comment = ''
132
+ default = self.parser.get_default(k)
133
+ if v != default:
134
+ comment = '\t[default: %s]' % str(default)
135
+ opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
136
+
137
+ with open(file_name + '.pkl', 'wb') as opt_file:
138
+ pickle.dump(opt, opt_file)
139
+
140
+ def update_options_from_file(self, parser, opt):
141
+ new_opt = self.load_options(opt)
142
+ for k, v in sorted(vars(opt).items()):
143
+ if hasattr(new_opt, k) and v != getattr(new_opt, k):
144
+ new_val = getattr(new_opt, k)
145
+ parser.set_defaults(**{k: new_val})
146
+ return parser
147
+
148
+ def load_options(self, opt):
149
+ file_name = self.option_file_path(opt, makedir=False)
150
+ new_opt = pickle.load(open(file_name + '.pkl', 'rb'))
151
+ return new_opt
152
+
153
+ def parse(self, save=False):
154
+
155
+ opt = self.gather_options()
156
+ opt.isTrain = self.isTrain # train or test
157
+
158
+ #self.print_options(opt)
159
+ if opt.isTrain:
160
+ self.save_options(opt)
161
+
162
+ # Set semantic_nc based on the option.
163
+ # This will be convenient in many places
164
+ opt.semantic_nc = opt.label_nc + \
165
+ (1 if opt.contain_dontcare_label else 0) + \
166
+ (0 if opt.no_instance else 1)
167
+
168
+ # set gpu ids
169
+ str_ids = opt.gpu_ids.split(',')
170
+ opt.gpu_ids = []
171
+ for str_id in str_ids:
172
+ id = int(str_id)
173
+ if id >= 0:
174
+ opt.gpu_ids.append(id)
175
+ opt.gpu_ids = [] if torch.cuda.device_count() == 0 else opt.gpu_ids
176
+ if len(opt.gpu_ids) > 0:
177
+ torch.cuda.set_device(opt.gpu_ids[0])
178
+
179
+ assert len(opt.gpu_ids) == 0 or opt.batchSize % len(opt.gpu_ids) == 0, \
180
+ "Batch size %d is wrong. It must be a multiple of # GPUs %d." \
181
+ % (opt.batchSize, len(opt.gpu_ids))
182
+
183
+ self.opt = opt
184
+ return self.opt
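
For the settings hard-coded in gather_options() above (contain_dontcare_label=False, no_instance=True) and the default label_nc=16, the semantic_nc computed in parse() works out to 16. A quick sketch of that arithmetic:

# Reproduces the semantic_nc computation from BaseOptions.parse() with the defaults above.
label_nc = 16
contain_dontcare_label = False
no_instance = True

semantic_nc = label_nc + (1 if contain_dontcare_label else 0) + (0 if no_instance else 1)
print(semantic_nc)   # 16 label channels are fed to the SPADE layers as the segmap input
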
seg2art/options/test_options.py ADDED
@@ -0,0 +1,22 @@
1
+ """
2
+ Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
+ """
5
+
6
+ from .base_options import BaseOptions
7
+
8
+
9
+ class TestOptions(BaseOptions):
10
+ def initialize(self, parser):
11
+ BaseOptions.initialize(self, parser)
12
+ parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
13
+ parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
14
+ parser.add_argument('--checkpoint_path', type=str, default='./checkpoints/multimodal_artworks/latest_net_G-fp16.pth', help='load model from a checkpoint')
15
+ parser.add_argument('--how_many', type=int, default=float("inf"), help='how many test images to run')
16
+
17
+ parser.set_defaults(preprocess_mode='scale_width_and_crop', crop_size=512, load_size=512, display_winsize=256)
18
+ parser.set_defaults(serial_batches=True)
19
+ parser.set_defaults(no_flip=True)
20
+ parser.set_defaults(phase='test')
21
+ self.isTrain = False
22
+ return parser
seg2art/sstan_models/__init__.py ADDED
@@ -0,0 +1,44 @@
1
+ """
2
+ Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
+ """
5
+
6
+ import importlib
7
+ import torch
8
+
9
+
10
+ def find_model_using_name(model_name):
11
+ # Given the option --model [modelname],
12
+ # the file "sstan_models/modelname_model.py"
13
+ # will be imported.
14
+ model_filename = "sstan_models." + model_name + "_model"
15
+ modellib = importlib.import_module(model_filename)
16
+
17
+ # In the file, the class called ModelNameModel() will
18
+ # be instantiated. It has to be a subclass of torch.nn.Module,
19
+ # and it is case-insensitive.
20
+ model = None
21
+ target_model_name = model_name.replace('_', '') + 'model'
22
+ for name, cls in modellib.__dict__.items():
23
+ if name.lower() == target_model_name.lower() \
24
+ and issubclass(cls, torch.nn.Module):
25
+ model = cls
26
+
27
+ if model is None:
28
+ print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." % (model_filename, target_model_name))
29
+ exit(0)
30
+
31
+ return model
32
+
33
+
34
+ def get_option_setter(model_name):
35
+ model_class = find_model_using_name(model_name)
36
+ return model_class.modify_commandline_options
37
+
38
+
39
+ def create_model(opt):
40
+ model = find_model_using_name(opt.model)
41
+ instance = model(opt)
42
+ print("model [%s] was created" % (type(instance).__name__))
43
+
44
+ return instance
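
find_model_using_name above resolves --model pix2pix to the module sstan_models.pix2pix_model and then looks for a torch.nn.Module subclass whose lowercased name equals the model name (underscores removed) plus "model". A small sketch of that string-based resolution, without performing the actual import:

# Sketch of the name resolution used by find_model_using_name() above.
model_name = "pix2pix"
model_filename = "sstan_models." + model_name + "_model"      # -> "sstan_models.pix2pix_model"
target_model_name = model_name.replace('_', '') + 'model'     # -> "pix2pixmodel"
# importlib.import_module(model_filename) is then scanned for a class whose
# name.lower() == target_model_name, i.e. Pix2PixModel.
print(model_filename, target_model_name)
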
seg2art/sstan_models/networks/__init__.py ADDED
@@ -0,0 +1,63 @@
1
+ """
2
+ Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
+ """
5
+
6
+ import torch
7
+ from sstan_models.networks.base_network import BaseNetwork
8
+ # from sstan_models.networks.loss import *
9
+ # from sstan_models.networks.discriminator import *
10
+ from sstan_models.networks.generator import *
11
+ # from sstan_models.networks.encoder import *
12
+ import model_util as util
13
+
14
+
15
+ def find_network_using_name(target_network_name, filename):
16
+ target_class_name = target_network_name + filename
17
+ module_name = 'sstan_models.networks.' + filename
18
+ network = util.find_class_in_module(target_class_name, module_name)
19
+
20
+ assert issubclass(network, BaseNetwork), \
21
+ "Class %s should be a subclass of BaseNetwork" % network
22
+
23
+ return network
24
+
25
+
26
+ def modify_commandline_options(parser, is_train):
27
+ opt, _ = parser.parse_known_args()
28
+
29
+ netG_cls = find_network_using_name(opt.netG, 'generator')
30
+ parser = netG_cls.modify_commandline_options(parser, is_train)
31
+ if is_train:
32
+ netD_cls = find_network_using_name(opt.netD, 'discriminator')
33
+ parser = netD_cls.modify_commandline_options(parser, is_train)
34
+ # netE_cls = find_network_using_name('conv', 'encoder')
35
+ # parser = netE_cls.modify_commandline_options(parser, is_train)
36
+
37
+ return parser
38
+
39
+
40
+ def create_network(cls, opt):
41
+ net = cls(opt)
42
+ net.print_network()
43
+ if len(opt.gpu_ids) > 0:
44
+ assert(torch.cuda.is_available())
45
+ net.cuda()
46
+ net.init_weights(opt.init_type, opt.init_variance)
47
+ return net
48
+
49
+
50
+ def define_G(opt):
51
+ netG_cls = find_network_using_name(opt.netG, 'generator')
52
+ return create_network(netG_cls, opt)
53
+
54
+
55
+ def define_D(opt):
56
+ netD_cls = find_network_using_name(opt.netD, 'discriminator')
57
+ return create_network(netD_cls, opt)
58
+
59
+
60
+ def define_E(opt):
61
+ # there exists only one encoder type
62
+ netE_cls = find_network_using_name('conv', 'encoder')
63
+ return create_network(netE_cls, opt)
seg2art/sstan_models/networks/architecture.py ADDED
@@ -0,0 +1,231 @@
1
+ """
2
+ Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision
10
+ import torch.nn.utils.spectral_norm as spectral_norm
11
+ from sstan_models.networks.normalization import SPADE
12
+
13
+
14
+ # ResNet block that uses SPADE.
15
+ # It differs from the ResNet block of pix2pixHD in that
16
+ # it takes in the segmentation map as input, learns the skip connection if necessary,
17
+ # and applies normalization first and then convolution.
18
+ # This architecture seemed like a standard architecture for unconditional or
19
+ # class-conditional GAN architecture using residual block.
20
+ # The code was inspired from https://github.com/LMescheder/GAN_stability.
21
+ class SPADEResnetBlock(nn.Module):
22
+ def __init__(self, fin, fout, opt, feed_code=False):
23
+ super().__init__()
24
+
25
+ self.status = 'train'
26
+ # Attributes
27
+ self.learned_shortcut = (fin != fout)
28
+ fmiddle = min(fin, fout)
29
+
30
+ # create conv layers
31
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
32
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
33
+ if self.learned_shortcut:
34
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
35
+
36
+ # apply spectral norm if specified
37
+ if 'spectral' in opt.norm_G:
38
+ self.conv_0 = spectral_norm(self.conv_0)
39
+ self.conv_1 = spectral_norm(self.conv_1)
40
+ if self.learned_shortcut:
41
+ self.conv_s = spectral_norm(self.conv_s)
42
+
43
+ # define normalization layers
44
+ spade_config_str = opt.norm_G.replace('spectral', '')
45
+
46
+ # Attention related
47
+ self.channelAtt = ChannelAttention(fout)
48
+ self.spatialAtt = SpatialAttention()
49
+
50
+ self.norm_0 = SPADE(spade_config_str, fin, feed_code, status=self.status, spade_params=[
51
+ spade_config_str, fin, opt.semantic_nc])
52
+ self.norm_1 = SPADE(spade_config_str, fmiddle, feed_code, status=self.status, spade_params=[
53
+ spade_config_str, fmiddle, opt.semantic_nc])
54
+ if self.learned_shortcut:
55
+ self.norm_s = SPADE(spade_config_str, fin, feed_code, status=self.status, spade_params=[
56
+ spade_config_str, fin, opt.semantic_nc])
57
+
58
+ # note the resnet block with SPADE also takes in |seg|,
59
+ # the semantic segmentation map as input
60
+ def forward(self, x, seg, style_codes=None, self_attention=False):
61
+ x_s = self.shortcut(x, seg, style_codes=style_codes)
62
+
63
+ dx = self.conv_0(self.actvn(
64
+ self.norm_0(x, seg, style_codes=style_codes)))
65
+ dx = self.conv_1(self.actvn(self.norm_1(
66
+ dx, seg, style_codes=style_codes)))
67
+
68
+ dx = self.channelAtt(dx) * dx
69
+ dx = self.spatialAtt(dx) * dx
70
+ out = x_s + dx
71
+
72
+ return out
73
+
74
+ def shortcut(self, x, seg, style_codes):
75
+ if self.learned_shortcut:
76
+ x_s = self.conv_s(self.norm_s(x, seg, style_codes=style_codes))
77
+ else:
78
+ x_s = x
79
+ return x_s
80
+
81
+ def actvn(self, x):
82
+ return F.leaky_relu(x, 2e-1)
83
+
84
+
85
+ # ResNet block used in pix2pixHD
86
+ # We keep the same architecture as pix2pixHD.
87
+ class ResnetBlock(nn.Module):
88
+ def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3):
89
+ super().__init__()
90
+
91
+ pw = (kernel_size - 1) // 2
92
+ self.conv_block = nn.Sequential(
93
+ nn.ReflectionPad2d(pw),
94
+ norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
95
+ activation,
96
+ nn.ReflectionPad2d(pw),
97
+ norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size))
98
+ )
99
+
100
+ def forward(self, x):
101
+ y = self.conv_block(x)
102
+ out = x + y
103
+ return out
104
+
105
+
106
+ # VGG architecture, used for the perceptual loss using a pretrained VGG network
107
+ class VGG19(torch.nn.Module):
108
+ def __init__(self, requires_grad=False):
109
+ super().__init__()
110
+ vgg_pretrained_features = torchvision.models.vgg19(
111
+ pretrained=True).features
112
+ self.slice1 = torch.nn.Sequential()
113
+ self.slice2 = torch.nn.Sequential()
114
+ self.slice3 = torch.nn.Sequential()
115
+ self.slice4 = torch.nn.Sequential()
116
+ self.slice5 = torch.nn.Sequential()
117
+ for x in range(2):
118
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
119
+ for x in range(2, 7):
120
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
121
+ for x in range(7, 12):
122
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
123
+ for x in range(12, 21):
124
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
125
+ for x in range(21, 30):
126
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
127
+ if not requires_grad:
128
+ for param in self.parameters():
129
+ param.requires_grad = False
130
+ # torch.cuda.empty_cache()
131
+ def forward(self, X):
132
+ with torch.cuda.amp.autocast():
133
+ # with torch.no_grad():
134
+ h_relu1 = self.slice1(X)
135
+ h_relu2 = self.slice2(h_relu1)
136
+ h_relu3 = self.slice3(h_relu2)
137
+ h_relu4 = self.slice4(h_relu3)
138
+ h_relu5 = self.slice5(h_relu4)
139
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
140
+ return out
141
+
142
+ '''
143
+ class SourceReferenceAttention(nn.Module):
144
+ """
145
+ Source-Reference Attention Layer
146
+ """
147
+
148
+ def __init__(self, in_planes_s, in_planes_r):
149
+ """
150
+ Parameters
151
+ ----------
152
+ in_planes_s: int
153
+ Number of input source feature vector channels.
154
+ in_planes_r: int
155
+ Number of input reference feature vector channels.
156
+ """
157
+ super(SourceReferenceAttention, self).__init__()
158
+ self.query_conv = nn.Conv2d(in_channels=in_planes_s,
159
+ out_channels=in_planes_s//8, kernel_size=1)
160
+ self.key_conv = nn.Conv2d(in_channels=in_planes_r,
161
+ out_channels=in_planes_r//8, kernel_size=1)
162
+ self.value_conv = nn.Conv2d(in_channels=in_planes_r,
163
+ out_channels=in_planes_r, kernel_size=1)
164
+ self.gamma = nn.Parameter(torch.zeros(1))
165
+ self.softmax = nn.Softmax(dim=-1)
166
+
167
+ def forward(self, source, reference):
168
+ """
169
+ Parameters
170
+ ----------
171
+ source : torch.Tensor
172
+ Source feature maps (B x Cs x Ts x Hs x Ws)
173
+ reference : torch.Tensor
174
+ Reference feature maps (B x Cr x Tr x Hr x Wr )
175
+ Returns :
176
+ torch.Tensor
177
+ Source-reference attention value added to the input source features
178
+ torch.Tensor
179
+ Attention map (B x Ns x Nt) (Ns=Ts*Hs*Ws, Nr=Tr*Hr*Wr)
180
+ """
181
+ s_batchsize, sC, sH, sW = source.size()
182
+ r_batchsize, rC, rH, rW = reference.size()
183
+ proj_query = self.query_conv(source).view(
184
+ s_batchsize, -1, sH*sW).permute(0, 2, 1)
185
+ proj_key = self.key_conv(reference).view(r_batchsize, -1, rW*rH)
186
+ energy = torch.bmm(proj_query, proj_key)
187
+ attention = self.softmax(energy)
188
+ proj_value = self.value_conv(reference).view(r_batchsize, -1, rH*rW)
189
+ out = torch.bmm(proj_value, attention.permute(0, 2, 1))
190
+ out = out.view(s_batchsize, sC, sH, sW)
191
+ out = self.gamma*out + source
192
+ return out, attention
193
+ '''
194
+
195
+ class ChannelAttention(nn.Module):
196
+ def __init__(self, in_planes, ratio=16):
197
+ super(ChannelAttention, self).__init__()
198
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
199
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
200
+
201
+ self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)
202
+ self.relu1 = nn.ReLU()
203
+ self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)
204
+
205
+ self.sigmoid = nn.Sigmoid()
206
+
207
+ def forward(self, x):
208
+
209
+ avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
210
+ max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
211
+ out = avg_out + max_out
212
+ return self.sigmoid(out)
213
+
214
+
215
+ class SpatialAttention(nn.Module):
216
+ def __init__(self, kernel_size=7):
217
+ super(SpatialAttention, self).__init__()
218
+
219
+ assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
220
+ padding = 3 if kernel_size == 7 else 1
221
+
222
+ self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
223
+ self.sigmoid = nn.Sigmoid()
224
+
225
+ def forward(self, x):
226
+
227
+ avg_out = torch.mean(x, dim=1, keepdim=True)
228
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
229
+ x = torch.cat([avg_out, max_out], dim=1)
230
+ x = self.conv1(x)
231
+ return self.sigmoid(x)
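
ChannelAttention and SpatialAttention above implement CBAM-style gating, and SPADEResnetBlock.forward applies them multiplicatively to the residual branch (dx = att(dx) * dx). A standalone usage sketch on a dummy feature map, assuming the two classes above are importable from this module; note that ChannelAttention's reduction is hard-coded to 16 regardless of the ratio argument:

# Standalone sketch of the CBAM-style gating used inside SPADEResnetBlock.forward().
import torch

channel_att = ChannelAttention(in_planes=64)   # channel gate, output shape (B, 64, 1, 1)
spatial_att = SpatialAttention(kernel_size=7)  # spatial gate, output shape (B, 1, H, W)

dx = torch.randn(1, 64, 32, 32)                # dummy residual features
dx = channel_att(dx) * dx                      # re-weight channels
dx = spatial_att(dx) * dx                      # re-weight spatial positions
print(dx.shape)                                # torch.Size([1, 64, 32, 32])
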
seg2art/sstan_models/networks/base_network.py ADDED
@@ -0,0 +1,59 @@
1
+ """
2
+ Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
+ """
5
+
6
+ import torch.nn as nn
7
+ from torch.nn import init
8
+
9
+
10
+ class BaseNetwork(nn.Module):
11
+ def __init__(self):
12
+ super(BaseNetwork, self).__init__()
13
+
14
+ @staticmethod
15
+ def modify_commandline_options(parser, is_train):
16
+ return parser
17
+
18
+ def print_network(self):
19
+ if isinstance(self, list):
20
+ self = self[0]
21
+ num_params = 0
22
+ for param in self.parameters():
23
+ num_params += param.numel()
24
+ print('Network [%s] was created. Total number of parameters: %.1f million. '
25
+ 'To see the architecture, do print(network).'
26
+ % (type(self).__name__, num_params / 1000000))
27
+
28
+ def init_weights(self, init_type='normal', gain=0.02):
29
+ def init_func(m):
30
+ classname = m.__class__.__name__
31
+ if classname.find('BatchNorm2d') != -1:
32
+ if hasattr(m, 'weight') and m.weight is not None:
33
+ init.normal_(m.weight.data, 1.0, gain)
34
+ if hasattr(m, 'bias') and m.bias is not None:
35
+ init.constant_(m.bias.data, 0.0)
36
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
37
+ if init_type == 'normal':
38
+ init.normal_(m.weight.data, 0.0, gain)
39
+ elif init_type == 'xavier':
40
+ init.xavier_normal_(m.weight.data, gain=gain)
41
+ elif init_type == 'xavier_uniform':
42
+ init.xavier_uniform_(m.weight.data, gain=1.0)
43
+ elif init_type == 'kaiming':
44
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
45
+ elif init_type == 'orthogonal':
46
+ init.orthogonal_(m.weight.data, gain=gain)
47
+ elif init_type == 'none': # uses pytorch's default init method
48
+ m.reset_parameters()
49
+ else:
50
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
51
+ if hasattr(m, 'bias') and m.bias is not None:
52
+ init.constant_(m.bias.data, 0.0)
53
+
54
+ self.apply(init_func)
55
+
56
+ # propagate to children
57
+ for m in self.children():
58
+ if hasattr(m, 'init_weights'):
59
+ m.init_weights(init_type, gain)
seg2art/sstan_models/networks/dual_attention_module.py ADDED
@@ -0,0 +1,51 @@
1
+
2
+ class ChannelAttention(nn.Module):
3
+ def __init__(self, in_planes, ratio=16):
4
+ super(ChannelAttention, self).__init__()
5
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
6
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
7
+
8
+ self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)
9
+ self.relu1 = nn.ReLU()
10
+ self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)
11
+
12
+ self.sigmoid = nn.Sigmoid()
13
+
14
+ def forward(self, x):
15
+ avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
16
+ max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
17
+ out = avg_out + max_out
18
+ return self.sigmoid(out)
19
+
20
+
21
+ class SpatialAttention(nn.Module):
22
+ def __init__(self, kernel_size=7):
23
+ super(SpatialAttention, self).__init__()
24
+
25
+ assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
26
+ padding = 3 if kernel_size == 7 else 1
27
+
28
+ self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
29
+ self.sigmoid = nn.Sigmoid()
30
+
31
+ def forward(self, x):
32
+ avg_out = torch.mean(x, dim=1, keepdim=True)
33
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
34
+ x = torch.cat([avg_out, max_out], dim=1)
35
+ x = self.conv1(x)
36
+ return self.sigmoid(x)
37
+
38
+
39
+ '''
40
+ def forward(self, x, seg):
41
+ x_s = self.shortcut(x, seg)
42
+
43
+ dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
44
+ dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))
45
+
46
+ dx = self.channelAtt(dx) * dx
47
+ dx = self.spatialAtt(dx) * dx
48
+ out = x_s + dx
49
+
50
+ return out
51
+ '''
seg2art/sstan_models/networks/generator.py ADDED
@@ -0,0 +1,184 @@
1
+ """
2
+ Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from sstan_models.networks.base_network import BaseNetwork
10
+ from sstan_models.networks.normalization import get_nonspade_norm_layer
11
+ from sstan_models.networks.architecture import ResnetBlock as ResnetBlock
12
+ from sstan_models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock
13
+ import numpy as np
14
+ torch.manual_seed(1234)
15
+
16
+
17
+ class SPADEGenerator(BaseNetwork):
18
+ @staticmethod
19
+ def modify_commandline_options(parser, is_train):
20
+ parser.set_defaults(norm_G='spectralspadesyncbatch3x3')
21
+ parser.add_argument('--num_upsampling_layers',
22
+ choices=('normal', 'more', 'most'), default='normal',
23
+ help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator")
24
+
25
+ return parser
26
+
27
+ def __init__(self, opt):
28
+ super().__init__()
29
+ self.opt = opt
30
+ nf = opt.ngf
31
+
32
+ self.sw, self.sh = self.compute_latent_vector_size(opt)
33
+
34
+ # if opt.use_vae:
35
+ # # In case of VAE, we will sample from random z vector
36
+ # self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
37
+ # else:
38
+ # # Otherwise, we make the network deterministic by starting with
39
+ # # downsampled segmentation map instead of random z
40
+ # self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)
41
+ self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
42
+ self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, feed_code = True)
43
+
44
+ self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, feed_code = True)
45
+ self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, feed_code = True)
46
+
47
+ self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt, feed_code = True)
48
+ self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt, feed_code = True)
49
+ self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt, feed_code = True)
50
+ self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt, feed_code = False)
51
+
52
+ final_nc = nf
53
+
54
+ # if opt.num_upsampling_layers == 'most':
55
+ # print('used?')
56
+ # self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
57
+ # final_nc = nf // 2
58
+
59
+ self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
60
+
61
+ self.up = nn.Upsample(scale_factor=2)
62
+
63
+ def compute_latent_vector_size(self, opt):
64
+ if opt.num_upsampling_layers == 'normal':
65
+ num_up_layers = 5
66
+ elif opt.num_upsampling_layers == 'more':
67
+ num_up_layers = 6
68
+ elif opt.num_upsampling_layers == 'most':
69
+ num_up_layers = 7
70
+ else:
71
+ raise ValueError('opt.num_upsampling_layers [%s] not recognized' %
72
+ opt.num_upsampling_layers)
73
+
74
+ sw = opt.crop_size // (2**num_up_layers)
75
+ sh = round(sw / opt.aspect_ratio)
76
+
77
+ return sw, sh
78
+
79
+ def forward(self, input, rgb_img, style_codes=None):
80
+ with torch.cuda.amp.autocast():
81
+ seg = input
82
+ # if self.opt.use_vae:
83
+ # # we sample z from unit normal and reshape the tensor
84
+ # if z is None:
85
+ # z = torch.randn(input.size(0), self.opt.z_dim,
86
+ # dtype=torch.float32, device=input.get_device())
87
+ x = self.fc(style_codes)
88
+ x = x.view(-1, 16 * self.opt.ngf, self.sh, self.sw)
89
+ # else:
90
+ # # we downsample segmap and run convolution
91
+ # x = F.interpolate(seg, size=(self.sh, self.sw))
92
+ # x = self.fc(x)
93
+ x = self.head_0(x, seg, style_codes=style_codes)
94
+
95
+ x = self.up(x)
96
+ x = self.G_middle_0(x, seg, style_codes=style_codes)
97
+
98
+ if self.opt.num_upsampling_layers == 'more' or \
99
+ self.opt.num_upsampling_layers == 'most':
100
+ x = self.up(x)
101
+
102
+ x = self.G_middle_1(x, seg, style_codes=style_codes)
103
+
104
+ x = self.up(x)
105
+ x = self.up_0(x, seg, style_codes=style_codes)
106
+ x = self.up(x)
107
+ x = self.up_1(x, seg, style_codes=style_codes)
108
+ x = self.up(x)
109
+ x = self.up_2(x, seg, style_codes=style_codes)
110
+ x = self.up(x)
111
+ x = self.up_3(x, seg)
112
+
113
+ # if self.opt.num_upsampling_layers == 'most':
114
+ # x = self.up(x)
115
+ # x = self.up_4(x, seg)
116
+
117
+ x = self.conv_img(F.leaky_relu(x, 2e-1))
118
+ x = torch.tanh(x)#F.tanh(x)
119
+
120
+ return x#, style_codes
121
+
122
+
123
+ class Pix2PixHDGenerator(BaseNetwork):
124
+ @staticmethod
125
+ def modify_commandline_options(parser, is_train):
126
+ parser.add_argument('--resnet_n_downsample', type=int, default=4, help='number of downsampling layers in netG')
127
+ parser.add_argument('--resnet_n_blocks', type=int, default=9, help='number of residual blocks in the global generator network')
128
+ parser.add_argument('--resnet_kernel_size', type=int, default=3,
129
+ help='kernel size of the resnet block')
130
+ parser.add_argument('--resnet_initial_kernel_size', type=int, default=7,
131
+ help='kernel size of the first convolution')
132
+ parser.set_defaults(norm_G='instance')
133
+ return parser
134
+
135
+ def __init__(self, opt):
136
+ super().__init__()
137
+ input_nc = opt.label_nc + (1 if opt.contain_dontcare_label else 0) + (0 if opt.no_instance else 1)
138
+
139
+ norm_layer = get_nonspade_norm_layer(opt, opt.norm_G)
140
+ activation = nn.ReLU(False)
141
+
142
+ model = []
143
+
144
+ # initial conv
145
+ model += [nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2),
146
+ norm_layer(nn.Conv2d(input_nc, opt.ngf,
147
+ kernel_size=opt.resnet_initial_kernel_size,
148
+ padding=0)),
149
+ activation]
150
+
151
+ # downsample
152
+ mult = 1
153
+ for i in range(opt.resnet_n_downsample):
154
+ model += [norm_layer(nn.Conv2d(opt.ngf * mult, opt.ngf * mult * 2,
155
+ kernel_size=3, stride=2, padding=1)),
156
+ activation]
157
+ mult *= 2
158
+
159
+ # resnet blocks
160
+ for i in range(opt.resnet_n_blocks):
161
+ model += [ResnetBlock(opt.ngf * mult,
162
+ norm_layer=norm_layer,
163
+ activation=activation,
164
+ kernel_size=opt.resnet_kernel_size)]
165
+
166
+ # upsample
167
+ for i in range(opt.resnet_n_downsample):
168
+ nc_in = int(opt.ngf * mult)
169
+ nc_out = int((opt.ngf * mult) / 2)
170
+ model += [norm_layer(nn.ConvTranspose2d(nc_in, nc_out,
171
+ kernel_size=3, stride=2,
172
+ padding=1, output_padding=1)),
173
+ activation]
174
+ mult = mult // 2
175
+
176
+ # final output conv
177
+ model += [nn.ReflectionPad2d(3),
178
+ nn.Conv2d(nc_out, opt.output_nc, kernel_size=7, padding=0),
179
+ nn.Tanh()]
180
+
181
+ self.model = nn.Sequential(*model)
182
+
183
+ def forward(self, input, z=None):
184
+ return self.model(input)
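
compute_latent_vector_size above fixes the spatial size of the initial latent tensor from the crop size and the number of upsampling stages; with the test defaults (crop_size=512, num_upsampling_layers='normal', aspect_ratio=1.0) the SPADEGenerator starts from a 16x16 grid, which is what self.fc reshapes the style code into. A quick sketch of the arithmetic:

# Reproduces SPADEGenerator.compute_latent_vector_size() for the test defaults.
crop_size = 512
aspect_ratio = 1.0
num_up_layers = 5                      # 'normal' -> 5 upsampling stages

sw = crop_size // (2 ** num_up_layers)
sh = round(sw / aspect_ratio)
print(sw, sh)                          # 16 16 -> x is reshaped to (-1, 16 * ngf, 16, 16)
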
seg2art/sstan_models/networks/normalization.py ADDED
@@ -0,0 +1,222 @@
1
+ """
2
+ Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
+ """
5
+
6
+ import re
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from sstan_models.networks.sync_batchnorm import SynchronizedBatchNorm2d
11
+ import torch.nn.utils.spectral_norm as spectral_norm
12
+
13
+
14
+ # Returns a function that creates a normalization function
15
+ # that does not condition on semantic map
16
+ def get_nonspade_norm_layer(opt, norm_type='instance'):
17
+ # helper function to get # output channels of the previous layer
18
+ def get_out_channel(layer):
19
+ if hasattr(layer, 'out_channels'):
20
+ return getattr(layer, 'out_channels')
21
+ return layer.weight.size(0)
22
+
23
+ # this function will be returned
24
+ def add_norm_layer(layer):
25
+ nonlocal norm_type
26
+ if norm_type.startswith('spectral'):
27
+ layer = spectral_norm(layer)
28
+ subnorm_type = norm_type[len('spectral'):]
29
+
30
+ if subnorm_type == 'none' or len(subnorm_type) == 0:
31
+ return layer
32
+
33
+ # remove bias in the previous layer, which is meaningless
34
+ # since it has no effect after normalization
35
+ if getattr(layer, 'bias', None) is not None:
36
+ delattr(layer, 'bias')
37
+ layer.register_parameter('bias', None)
38
+
39
+ if subnorm_type == 'batch':
40
+ norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
41
+ elif subnorm_type == 'sync_batch':
42
+ norm_layer = SynchronizedBatchNorm2d(
43
+ get_out_channel(layer), affine=True)
44
+ elif subnorm_type == 'instance':
45
+ norm_layer = nn.InstanceNorm2d(
46
+ get_out_channel(layer), affine=False)
47
+ else:
48
+ raise ValueError(
49
+ 'normalization layer %s is not recognized' % subnorm_type)
50
+
51
+ return nn.Sequential(layer, norm_layer)
52
+
53
+ return add_norm_layer
54
+
55
+
56
+ # Creates SPADE normalization layer based on the given configuration
57
+ # SPADE consists of two steps. First, it normalizes the activations using
58
+ # your favorite normalization method, such as Batch Norm or Instance Norm.
59
+ # Second, it applies scale and bias to the normalized output, conditioned on
60
+ # the segmentation map.
61
+ # The format of |config_text| is spade(norm)(ks), where
62
+ # (norm) specifies the type of parameter-free normalization.
63
+ # (e.g. syncbatch, batch, instance)
64
+ # (ks) specifies the size of kernel in the SPADE module (e.g. 3x3)
65
+ # Example |config_text| will be spadesyncbatch3x3, or spadeinstance5x5.
66
+ # Also, the other arguments are
67
+ # |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE
68
+ # |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE
69
+ class SPADE(nn.Module):
70
+ def __init__(self, config_text, norm_nc, feed_code, status='train', spade_params=None):
71
+ super().__init__()
72
+
73
+ self.style_length = 256
74
+ # self.noise_var = nn.Parameter(torch.zeros(norm_nc), requires_grad=True)
75
+ self.Spade = SPADE_ori(*spade_params)
76
+
77
+
78
+ assert config_text.startswith('spade')
79
+ parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
80
+ param_free_norm_type = str(parsed.group(1))
81
+ ks = int(parsed.group(2))
82
+ pw = ks // 2
83
+
84
+ if param_free_norm_type == 'instance':
85
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
86
+ elif param_free_norm_type == 'syncbatch':
87
+ self.param_free_norm = SynchronizedBatchNorm2d(
88
+ norm_nc, affine=False)
89
+ elif param_free_norm_type == 'batch':
90
+ self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
91
+ else:
92
+ raise ValueError('%s is not a recognized param-free norm type in SPADE'
93
+ % param_free_norm_type)
94
+
95
+ # self.create_gamma_beta_fc_layers()
96
+ if feed_code:
97
+ self.blending_gamma = nn.Parameter(torch.zeros(1), requires_grad=True)
98
+ self.blending_beta = nn.Parameter(torch.zeros(1), requires_grad=True)
99
+ self.conv_gamma = nn.Conv2d(
100
+ self.style_length, norm_nc, kernel_size=ks, padding=pw)
101
+ self.conv_beta = nn.Conv2d(
102
+ self.style_length, norm_nc, kernel_size=ks, padding=pw)
103
+
104
+ def forward(self, x, segmap, style_codes=None):
105
+ if style_codes is None:
106
+ input_code = False
107
+ else:
108
+ input_code = True
109
+
110
+ # Part 1. generate parameter-free normalized activations
111
+ # added_noise = (torch.randn(
112
+ # x.shape[0], x.shape[3], x.shape[2], 1).cuda() * self.noise_var).transpose(1, 3)
113
+ normalized = self.param_free_norm(x)
114
+
115
+ # Part 2. produce scaling and bias conditioned on semantic map
116
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
117
+
118
+ if input_code:
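+ # Broadcast each sample's 1-D style code across the spatial grid, then convolve it into per-pixel gamma/beta.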
119
+ [b_size, f_size, h_size, w_size] = normalized.shape
120
+ middle_avg = torch.zeros(
121
+ (b_size, self.style_length, h_size, w_size), device=normalized.device)
122
+
123
+ for i in range(b_size):
124
+
125
+ middle_mu = F.relu((style_codes[i]))
126
+
127
+ middle_mu = middle_mu.reshape(self.style_length, 1).expand(
128
+ self.style_length, h_size*w_size)
129
+ middle_mu = middle_mu.reshape(
130
+ self.style_length, h_size, w_size)
131
+ middle_avg[i] = middle_mu
132
+
133
+ gamma_avg = self.conv_gamma(middle_avg)
134
+ beta_avg = self.conv_beta(middle_avg)
135
+
136
+ gamma_spade, beta_spade = self.Spade(segmap)
137
+
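+ # Blend the style-code modulation (gamma_avg/beta_avg) with the segmentation-driven SPADE modulation using learned blending weights.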
138
+ gamma_alpha = torch.sigmoid(self.blending_gamma)
139
+ beta_alpha = torch.sigmoid(self.blending_gamma)  # note: reuses blending_gamma; self.blending_beta is defined above but unused here
140
+
141
+ gamma_final = gamma_alpha * gamma_avg + \
142
+ (1 - gamma_alpha) * gamma_spade
143
+
144
+ beta_final = beta_alpha * beta_avg + (1 - beta_alpha) * beta_spade
145
+
146
+ out = normalized * (1 + gamma_final) + beta_final
147
+ else:
148
+ gamma_spade, beta_spade = self.Spade(segmap)
149
+ gamma_final = gamma_spade
150
+ beta_final = beta_spade
151
+ out = normalized * (1 + gamma_final) + beta_final
152
+ return out
153
+
154
+ # def create_gamma_beta_fc_layers(self):
155
+
156
+ # # These codes should be replaced with torch.nn.ModuleList
157
+
158
+ # style_length = self.style_length
159
+
160
+ # self.fc_mu0 = nn.Linear(style_length, style_length)
161
+ # self.fc_mu1 = nn.Linear(style_length, style_length)
162
+ # self.fc_mu2 = nn.Linear(style_length, style_length)
163
+ # self.fc_mu3 = nn.Linear(style_length, style_length)
164
+ # self.fc_mu4 = nn.Linear(style_length, style_length)
165
+ # self.fc_mu5 = nn.Linear(style_length, style_length)
166
+ # self.fc_mu6 = nn.Linear(style_length, style_length)
167
+ # self.fc_mu7 = nn.Linear(style_length, style_length)
168
+ # self.fc_mu8 = nn.Linear(style_length, style_length)
169
+ # self.fc_mu9 = nn.Linear(style_length, style_length)
170
+ # self.fc_mu10 = nn.Linear(style_length, style_length)
171
+ # self.fc_mu11 = nn.Linear(style_length, style_length)
172
+ # self.fc_mu12 = nn.Linear(style_length, style_length)
173
+ # self.fc_mu13 = nn.Linear(style_length, style_length)
174
+ # self.fc_mu14 = nn.Linear(style_length, style_length)
175
+ # self.fc_mu15 = nn.Linear(style_length, style_length)
176
+ # self.fc_mu16 = nn.Linear(style_length, style_length)
177
+ # self.fc_mu17 = nn.Linear(style_length, style_length)
178
+ # self.fc_mu18 = nn.Linear(style_length, style_length)
179
+
180
+
181
+ class SPADE_ori(nn.Module):
182
+ def __init__(self, config_text, norm_nc, label_nc):
183
+ super().__init__()
184
+
185
+ assert config_text.startswith('spade')
186
+ parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
187
+ param_free_norm_type = str(parsed.group(1))
188
+ ks = int(parsed.group(2))
189
+
190
+ if param_free_norm_type == 'instance':
191
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
192
+ elif param_free_norm_type == 'syncbatch':
193
+ self.param_free_norm = SynchronizedBatchNorm2d(
194
+ norm_nc, affine=False)
195
+ elif param_free_norm_type == 'batch':
196
+ self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
197
+ else:
198
+ raise ValueError('%s is not a recognized param-free norm type in SPADE'
199
+ % param_free_norm_type)
200
+
201
+ # The dimension of the intermediate embedding space. Yes, hardcoded.
202
+ nhidden = 128
203
+
204
+ pw = ks // 2
205
+ self.mlp_shared = nn.Sequential(
206
+ nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
207
+ nn.ReLU()
208
+ )
209
+
210
+ self.mlp_gamma = nn.Conv2d(
211
+ nhidden, norm_nc, kernel_size=ks, padding=pw)
212
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
213
+
214
+ def forward(self, segmap):
215
+
216
+ inputmap = segmap
217
+
218
+ actv = self.mlp_shared(inputmap)
219
+ gamma = self.mlp_gamma(actv)
220
+ beta = self.mlp_beta(actv)
221
+
222
+ return gamma, beta
seg2art/sstan_models/networks/sync_batchnorm/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : __init__.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12
+ from .batchnorm import convert_model
13
+ from .replicate import DataParallelWithCallback, patch_replication_callback
seg2art/sstan_models/networks/sync_batchnorm/batchnorm.py ADDED
@@ -0,0 +1,361 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import collections
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from torch.nn.modules.batchnorm import _BatchNorm
17
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18
+
19
+ from .comm import SyncMaster
20
+ from .replicate import DataParallelWithCallback
21
+
22
+ __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d',
23
+ 'SynchronizedBatchNorm3d', 'convert_model']
24
+
25
+
26
+ def _sum_ft(tensor):
27
+ """sum over the first and last dimention"""
28
+ return tensor.sum(dim=0).sum(dim=-1)
29
+
30
+
31
+ def _unsqueeze_ft(tensor):
32
+ """add new dementions at the front and the tail"""
33
+ return tensor.unsqueeze(0).unsqueeze(-1)
34
+
35
+
36
+ _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
37
+ _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
38
+
39
+
40
+ class _SynchronizedBatchNorm(_BatchNorm):
41
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
42
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
43
+
44
+ self._sync_master = SyncMaster(self._data_parallel_master)
45
+
46
+ self._is_parallel = False
47
+ self._parallel_id = None
48
+ self._slave_pipe = None
49
+
50
+ def forward(self, input):
51
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
52
+ if not (self._is_parallel and self.training):
53
+ return F.batch_norm(
54
+ input, self.running_mean, self.running_var, self.weight, self.bias,
55
+ self.training, self.momentum, self.eps)
56
+
57
+ # Resize the input to (B, C, -1).
58
+ input_shape = input.size()
59
+ input = input.view(input.size(0), self.num_features, -1)
60
+
61
+ # Compute the sum and square-sum.
62
+ sum_size = input.size(0) * input.size(2)
63
+ input_sum = _sum_ft(input)
64
+ input_ssum = _sum_ft(input ** 2)
65
+
66
+ # Reduce-and-broadcast the statistics.
67
+ if self._parallel_id == 0:
68
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
69
+ else:
70
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
71
+
72
+ # Compute the output.
73
+ if self.affine:
74
+ # MJY:: Fuse the multiplication for speed.
75
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
76
+ else:
77
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
78
+
79
+ # Reshape it.
80
+ return output.view(input_shape)
81
+
82
+ def __data_parallel_replicate__(self, ctx, copy_id):
83
+ self._is_parallel = True
84
+ self._parallel_id = copy_id
85
+
86
+ # parallel_id == 0 means master device.
87
+ if self._parallel_id == 0:
88
+ ctx.sync_master = self._sync_master
89
+ else:
90
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
91
+
92
+ def _data_parallel_master(self, intermediates):
93
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
94
+
95
+ # Always using same "device order" makes the ReduceAdd operation faster.
96
+ # Thanks to:: Tete Xiao (http://tetexiao.com/)
97
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
98
+
99
+ to_reduce = [i[1][:2] for i in intermediates]
100
+ to_reduce = [j for i in to_reduce for j in i] # flatten
101
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
102
+
103
+ sum_size = sum([i[1].sum_size for i in intermediates])
104
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
105
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
106
+
107
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
108
+
109
+ outputs = []
110
+ for i, rec in enumerate(intermediates):
111
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
112
+
113
+ return outputs
114
+
115
+ def _compute_mean_std(self, sum_, ssum, size):
116
+ """Compute the mean and standard-deviation with sum and square-sum. This method
117
+ also maintains the moving average on the master device."""
118
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
119
+ mean = sum_ / size
120
+ sumvar = ssum - sum_ * mean
121
+ unbias_var = sumvar / (size - 1)
122
+ bias_var = sumvar / size
123
+
124
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
125
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
126
+
127
+ return mean, bias_var.clamp(self.eps) ** -0.5
128
+
129
+
130
+ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
131
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
132
+ mini-batch.
133
+
134
+ .. math::
135
+
136
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
137
+
138
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
139
+ standard-deviation are reduced across all devices during training.
140
+
141
+ For example, when one uses `nn.DataParallel` to wrap the network during
142
+ training, PyTorch's implementation normalizes the tensor on each device using
143
+ the statistics only on that device, which accelerates the computation and
144
+ is also easy to implement, but the statistics might be inaccurate.
145
+ Instead, in this synchronized version, the statistics will be computed
146
+ over all training samples distributed on multiple devices.
147
+
148
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
149
+ as the built-in PyTorch implementation.
150
+
151
+ The mean and standard-deviation are calculated per-dimension over
152
+ the mini-batches and gamma and beta are learnable parameter vectors
153
+ of size C (where C is the input size).
154
+
155
+ During training, this layer keeps a running estimate of its computed mean
156
+ and variance. The running sum is kept with a default momentum of 0.1.
157
+
158
+ During evaluation, this running mean/variance is used for normalization.
159
+
160
+ Because the BatchNorm is done over the `C` dimension, computing statistics
161
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
162
+
163
+ Args:
164
+ num_features: num_features from an expected input of size
165
+ `batch_size x num_features [x width]`
166
+ eps: a value added to the denominator for numerical stability.
167
+ Default: 1e-5
168
+ momentum: the value used for the running_mean and running_var
169
+ computation. Default: 0.1
170
+ affine: a boolean value that when set to ``True``, gives the layer learnable
171
+ affine parameters. Default: ``True``
172
+
173
+ Shape:
174
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
175
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
176
+
177
+ Examples:
178
+ >>> # With Learnable Parameters
179
+ >>> m = SynchronizedBatchNorm1d(100)
180
+ >>> # Without Learnable Parameters
181
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
182
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
183
+ >>> output = m(input)
184
+ """
185
+
186
+ def _check_input_dim(self, input):
187
+ if input.dim() != 2 and input.dim() != 3:
188
+ raise ValueError('expected 2D or 3D input (got {}D input)'
189
+ .format(input.dim()))
190
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
191
+
192
+
193
+ class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
194
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
195
+ of 3d inputs
196
+
197
+ .. math::
198
+
199
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
200
+
201
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
202
+ standard-deviation are reduced across all devices during training.
203
+
204
+ For example, when one uses `nn.DataParallel` to wrap the network during
205
+ training, PyTorch's implementation normalizes the tensor on each device using
206
+ the statistics only on that device, which accelerates the computation and
207
+ is also easy to implement, but the statistics might be inaccurate.
208
+ Instead, in this synchronized version, the statistics will be computed
209
+ over all training samples distributed on multiple devices.
210
+
211
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
212
+ as the built-in PyTorch implementation.
213
+
214
+ The mean and standard-deviation are calculated per-dimension over
215
+ the mini-batches and gamma and beta are learnable parameter vectors
216
+ of size C (where C is the input size).
217
+
218
+ During training, this layer keeps a running estimate of its computed mean
219
+ and variance. The running sum is kept with a default momentum of 0.1.
220
+
221
+ During evaluation, this running mean/variance is used for normalization.
222
+
223
+ Because the BatchNorm is done over the `C` dimension, computing statistics
224
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
225
+
226
+ Args:
227
+ num_features: num_features from an expected input of
228
+ size batch_size x num_features x height x width
229
+ eps: a value added to the denominator for numerical stability.
230
+ Default: 1e-5
231
+ momentum: the value used for the running_mean and running_var
232
+ computation. Default: 0.1
233
+ affine: a boolean value that when set to ``True``, gives the layer learnable
234
+ affine parameters. Default: ``True``
235
+
236
+ Shape:
237
+ - Input: :math:`(N, C, H, W)`
238
+ - Output: :math:`(N, C, H, W)` (same shape as input)
239
+
240
+ Examples:
241
+ >>> # With Learnable Parameters
242
+ >>> m = SynchronizedBatchNorm2d(100)
243
+ >>> # Without Learnable Parameters
244
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
245
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
246
+ >>> output = m(input)
247
+ """
248
+
249
+ def _check_input_dim(self, input):
250
+ if input.dim() != 4:
251
+ raise ValueError('expected 4D input (got {}D input)'
252
+ .format(input.dim()))
253
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
254
+
255
+
256
+ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
257
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
258
+ of 4d inputs
259
+
260
+ .. math::
261
+
262
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
263
+
264
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
265
+ standard-deviation are reduced across all devices during training.
266
+
267
+ For example, when one uses `nn.DataParallel` to wrap the network during
268
+ training, PyTorch's implementation normalizes the tensor on each device using
269
+ the statistics only on that device, which accelerates the computation and
270
+ is also easy to implement, but the statistics might be inaccurate.
271
+ Instead, in this synchronized version, the statistics will be computed
272
+ over all training samples distributed on multiple devices.
273
+
274
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
275
+ as the built-in PyTorch implementation.
276
+
277
+ The mean and standard-deviation are calculated per-dimension over
278
+ the mini-batches and gamma and beta are learnable parameter vectors
279
+ of size C (where C is the input size).
280
+
281
+ During training, this layer keeps a running estimate of its computed mean
282
+ and variance. The running sum is kept with a default momentum of 0.1.
283
+
284
+ During evaluation, this running mean/variance is used for normalization.
285
+
286
+ Because the BatchNorm is done over the `C` dimension, computing statistics
287
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
288
+ or Spatio-temporal BatchNorm
289
+
290
+ Args:
291
+ num_features: num_features from an expected input of
292
+ size batch_size x num_features x depth x height x width
293
+ eps: a value added to the denominator for numerical stability.
294
+ Default: 1e-5
295
+ momentum: the value used for the running_mean and running_var
296
+ computation. Default: 0.1
297
+ affine: a boolean value that when set to ``True``, gives the layer learnable
298
+ affine parameters. Default: ``True``
299
+
300
+ Shape:
301
+ - Input: :math:`(N, C, D, H, W)`
302
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
303
+
304
+ Examples:
305
+ >>> # With Learnable Parameters
306
+ >>> m = SynchronizedBatchNorm3d(100)
307
+ >>> # Without Learnable Parameters
308
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
309
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
310
+ >>> output = m(input)
311
+ """
312
+
313
+ def _check_input_dim(self, input):
314
+ if input.dim() != 5:
315
+ raise ValueError('expected 5D input (got {}D input)'
316
+ .format(input.dim()))
317
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
318
+
319
+
320
+ def convert_model(module):
321
+ """Traverse the input module and its child recursively
322
+ and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
323
+ to SynchronizedBatchNorm*N*d
324
+
325
+ Args:
326
+ module: the input module to be converted to a SyncBN model
327
+
328
+ Examples:
329
+ >>> import torch.nn as nn
330
+ >>> import torchvision
331
+ >>> # m is a standard pytorch model
332
+ >>> m = torchvision.models.resnet18(True)
333
+ >>> m = nn.DataParallel(m)
334
+ >>> # after convert, m is using SyncBN
335
+ >>> m = convert_model(m)
336
+ """
337
+ if isinstance(module, torch.nn.DataParallel):
338
+ mod = module.module
339
+ mod = convert_model(mod)
340
+ mod = DataParallelWithCallback(mod)
341
+ return mod
342
+
343
+ mod = module
344
+ for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
345
+ torch.nn.modules.batchnorm.BatchNorm2d,
346
+ torch.nn.modules.batchnorm.BatchNorm3d],
347
+ [SynchronizedBatchNorm1d,
348
+ SynchronizedBatchNorm2d,
349
+ SynchronizedBatchNorm3d]):
350
+ if isinstance(module, pth_module):
351
+ mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
352
+ mod.running_mean = module.running_mean
353
+ mod.running_var = module.running_var
354
+ if module.affine:
355
+ mod.weight.data = module.weight.data.clone().detach()
356
+ mod.bias.data = module.bias.data.clone().detach()
357
+
358
+ for name, child in module.named_children():
359
+ mod.add_module(name, convert_model(child))
360
+
361
+ return mod
seg2art/sstan_models/networks/sync_batchnorm/batchnorm_reimpl.py ADDED
@@ -0,0 +1,74 @@
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : batchnorm_reimpl.py
4
+ # Author : acgtyrant
5
+ # Date : 11/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.init as init
14
+
15
+ __all__ = ['BatchNorm2dReimpl']
16
+
17
+
18
+ class BatchNorm2dReimpl(nn.Module):
19
+ """
20
+ A re-implementation of batch normalization, used for testing the numerical
21
+ stability.
22
+
23
+ Author: acgtyrant
24
+ See also:
25
+ https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
26
+ """
27
+ def __init__(self, num_features, eps=1e-5, momentum=0.1):
28
+ super().__init__()
29
+
30
+ self.num_features = num_features
31
+ self.eps = eps
32
+ self.momentum = momentum
33
+ self.weight = nn.Parameter(torch.empty(num_features))
34
+ self.bias = nn.Parameter(torch.empty(num_features))
35
+ self.register_buffer('running_mean', torch.zeros(num_features))
36
+ self.register_buffer('running_var', torch.ones(num_features))
37
+ self.reset_parameters()
38
+
39
+ def reset_running_stats(self):
40
+ self.running_mean.zero_()
41
+ self.running_var.fill_(1)
42
+
43
+ def reset_parameters(self):
44
+ self.reset_running_stats()
45
+ init.uniform_(self.weight)
46
+ init.zeros_(self.bias)
47
+
48
+ def forward(self, input_):
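+ # Flatten to (C, N*H*W), compute per-channel batch statistics, update the running estimates, then normalize.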
49
+ batchsize, channels, height, width = input_.size()
50
+ numel = batchsize * height * width
51
+ input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
52
+ sum_ = input_.sum(1)
53
+ sum_of_square = input_.pow(2).sum(1)
54
+ mean = sum_ / numel
55
+ sumvar = sum_of_square - sum_ * mean
56
+
57
+ self.running_mean = (
58
+ (1 - self.momentum) * self.running_mean
59
+ + self.momentum * mean.detach()
60
+ )
61
+ unbias_var = sumvar / (numel - 1)
62
+ self.running_var = (
63
+ (1 - self.momentum) * self.running_var
64
+ + self.momentum * unbias_var.detach()
65
+ )
66
+
67
+ bias_var = sumvar / numel
68
+ inv_std = 1 / (bias_var + self.eps).pow(0.5)
69
+ output = (
70
+ (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
71
+ self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
72
+
73
+ return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74
+
seg2art/sstan_models/networks/sync_batchnorm/comm.py ADDED
@@ -0,0 +1,137 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : comm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import queue
12
+ import collections
13
+ import threading
14
+
15
+ __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16
+
17
+
18
+ class FutureResult(object):
19
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
20
+
21
+ def __init__(self):
22
+ self._result = None
23
+ self._lock = threading.Lock()
24
+ self._cond = threading.Condition(self._lock)
25
+
26
+ def put(self, result):
27
+ with self._lock:
28
+ assert self._result is None, 'Previous result has\'t been fetched.'
29
+ self._result = result
30
+ self._cond.notify()
31
+
32
+ def get(self):
33
+ with self._lock:
34
+ if self._result is None:
35
+ self._cond.wait()
36
+
37
+ res = self._result
38
+ self._result = None
39
+ return res
40
+
41
+
42
+ _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43
+ _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44
+
45
+
46
+ class SlavePipe(_SlavePipeBase):
47
+ """Pipe for master-slave communication."""
48
+
49
+ def run_slave(self, msg):
50
+ self.queue.put((self.identifier, msg))
51
+ ret = self.result.get()
52
+ self.queue.put(True)
53
+ return ret
54
+
55
+
56
+ class SyncMaster(object):
57
+ """An abstract `SyncMaster` object.
58
+
59
+ - During the replication, as data parallel triggers a callback on each module, all slave devices should
60
+ call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61
+ - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
62
+ and passed to a registered callback.
63
+ - After receiving the messages, the master device should gather the information and determine the message to be passed
64
+ back to each slave device.
65
+ """
66
+
67
+ def __init__(self, master_callback):
68
+ """
69
+
70
+ Args:
71
+ master_callback: a callback to be invoked after having collected messages from slave devices.
72
+ """
73
+ self._master_callback = master_callback
74
+ self._queue = queue.Queue()
75
+ self._registry = collections.OrderedDict()
76
+ self._activated = False
77
+
78
+ def __getstate__(self):
79
+ return {'master_callback': self._master_callback}
80
+
81
+ def __setstate__(self, state):
82
+ self.__init__(state['master_callback'])
83
+
84
+ def register_slave(self, identifier):
85
+ """
86
+ Register an slave device.
87
+
88
+ Args:
89
+ identifier: an identifier, usually is the device id.
90
+
91
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
92
+
93
+ """
94
+ if self._activated:
95
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
96
+ self._activated = False
97
+ self._registry.clear()
98
+ future = FutureResult()
99
+ self._registry[identifier] = _MasterRegistry(future)
100
+ return SlavePipe(identifier, self._queue, future)
101
+
102
+ def run_master(self, master_msg):
103
+ """
104
+ Main entry for the master device in each forward pass.
105
+ The messages were first collected from each devices (including the master device), and then
106
+ an callback will be invoked to compute the message to be sent back to each devices
107
+ (including the master device).
108
+
109
+ Args:
110
+ master_msg: the message that the master want to send to itself. This will be placed as the first
111
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112
+
113
+ Returns: the message to be sent back to the master device.
114
+
115
+ """
116
+ self._activated = True
117
+
118
+ intermediates = [(0, master_msg)]
119
+ for i in range(self.nr_slaves):
120
+ intermediates.append(self._queue.get())
121
+
122
+ results = self._master_callback(intermediates)
123
+ assert results[0][0] == 0, 'The first result should belongs to the master.'
124
+
125
+ for i, res in results:
126
+ if i == 0:
127
+ continue
128
+ self._registry[i].result.put(res)
129
+
130
+ for i in range(self.nr_slaves):
131
+ assert self._queue.get() is True
132
+
133
+ return results[0][1]
134
+
135
+ @property
136
+ def nr_slaves(self):
137
+ return len(self._registry)
seg2art/sstan_models/networks/sync_batchnorm/replicate.py ADDED
@@ -0,0 +1,94 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : replicate.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import functools
12
+
13
+ from torch.nn.parallel.data_parallel import DataParallel
14
+
15
+ __all__ = [
16
+ 'CallbackContext',
17
+ 'execute_replication_callbacks',
18
+ 'DataParallelWithCallback',
19
+ 'patch_replication_callback'
20
+ ]
21
+
22
+
23
+ class CallbackContext(object):
24
+ pass
25
+
26
+
27
+ def execute_replication_callbacks(modules):
28
+ """
29
+ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
30
+
31
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32
+
33
+ Note that, as all modules are isomorphic, we assign each sub-module a context
34
+ (shared among multiple copies of this module on different devices).
35
+ Through this context, different copies can share some information.
36
+
37
+ We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38
+ of any slave copies.
39
+ """
40
+ master_copy = modules[0]
41
+ nr_modules = len(list(master_copy.modules()))
42
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
43
+
44
+ for i, module in enumerate(modules):
45
+ for j, m in enumerate(module.modules()):
46
+ if hasattr(m, '__data_parallel_replicate__'):
47
+ m.__data_parallel_replicate__(ctxs[j], i)
48
+
49
+
50
+ class DataParallelWithCallback(DataParallel):
51
+ """
52
+ Data Parallel with a replication callback.
53
+
54
+ A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by
55
+ the original `replicate` function.
56
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57
+
58
+ Examples:
59
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61
+ # sync_bn.__data_parallel_replicate__ will be invoked.
62
+ """
63
+
64
+ def replicate(self, module, device_ids):
65
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66
+ execute_replication_callbacks(modules)
67
+ return modules
68
+
69
+
70
+ def patch_replication_callback(data_parallel):
71
+ """
72
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
73
+ Useful when you have a customized `DataParallel` implementation.
74
+
75
+ Examples:
76
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78
+ > patch_replication_callback(sync_bn)
79
+ # this is equivalent to
80
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82
+ """
83
+
84
+ assert isinstance(data_parallel, DataParallel)
85
+
86
+ old_replicate = data_parallel.replicate
87
+
88
+ @functools.wraps(old_replicate)
89
+ def new_replicate(module, device_ids):
90
+ modules = old_replicate(module, device_ids)
91
+ execute_replication_callbacks(modules)
92
+ return modules
93
+
94
+ data_parallel.replicate = new_replicate
seg2art/sstan_models/networks/sync_batchnorm/unittest.py ADDED
@@ -0,0 +1,29 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : unittest.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import unittest
12
+ import torch
13
+
14
+
15
+ class TorchTestCase(unittest.TestCase):
16
+ def assertTensorClose(self, x, y):
17
+ adiff = float((x - y).abs().max())
18
+ if (y == 0).all():
19
+ rdiff = 'NaN'
20
+ else:
21
+ rdiff = float((adiff / y).abs().max())
22
+
23
+ message = (
24
+ 'Tensor close check failed\n'
25
+ 'adiff={}\n'
26
+ 'rdiff={}\n'
27
+ ).format(adiff, rdiff)
28
+ self.assertTrue(torch.allclose(x, y), message)
29
+
seg2art/sstan_models/pix2pix_model.py ADDED
@@ -0,0 +1,285 @@
1
+ """
2
+ Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
+ """
5
+
6
+ import os
7
+ import torch
8
+ import sstan_models.networks as networks
9
+ import model_util as util
10
+
11
+
12
+ class Pix2PixModel(torch.nn.Module):
13
+ @staticmethod
14
+ def modify_commandline_options(parser, is_train):
15
+ networks.modify_commandline_options(parser, is_train)
16
+ return parser
17
+
18
+ def __init__(self, opt):
19
+ super().__init__()
20
+ self.opt = opt
21
+ self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() else torch.FloatTensor
22
+ self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() else torch.ByteTensor
23
+
24
+ self.netG, self.netD, self.netE = self.initialize_networks(opt)
25
+
26
+ # set loss functions
27
+ if opt.isTrain:
28
+ self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
29
+ self.criterionFeat = torch.nn.L1Loss()
30
+ if not opt.no_vgg_loss:
31
+ self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
32
+ if opt.use_vae:
33
+ self.KLDLoss = networks.KLDLoss()
34
+
35
+ # Entry point for all calls involving forward pass
36
+ # of deep networks. We use this approach because the DataParallel module
37
+ # can't parallelize custom functions, so we branch to different
38
+ # routines based on |mode|.
39
+ def forward(self, data, mode, style_codes=None):
40
+ input_semantics, real_image = self.preprocess_input(data)
41
+ domain = None
42
+
43
+ # print(torch.cuda.memory_cached(0))
44
+ if mode == "generator":
45
+ g_loss, generated = self.compute_generator_loss(input_semantics, real_image, domain)
46
+ return g_loss, generated
47
+ elif mode == "discriminator":
48
+ d_loss = self.compute_discriminator_loss(input_semantics, real_image, domain)
49
+ return d_loss
50
+ elif mode == "encode_only":
51
+ _, mu, logvar = self.encode_z(real_image, domain)
52
+ return mu, logvar
53
+ elif mode == "inference":
54
+ with torch.no_grad():
55
+ fake_image, _, _ = self.generate_fake(input_semantics, real_image, domain, style_codes=style_codes, compute_kld_loss=False)
56
+ return fake_image
57
+ elif mode == "generate_img_npy":
58
+ with torch.no_grad():
59
+ fake_image, encoded_style_code = self.generate_img_npy(input_semantics, real_image, domain)
60
+ return fake_image, encoded_style_code
61
+ else:
62
+ raise ValueError("|mode| is invalid")
63
+
64
+ def create_optimizers(self, opt):
65
+ G_params = list(self.netG.parameters())
66
+ if opt.use_vae:
67
+ G_params += list(self.netE.parameters())
68
+ if opt.isTrain:
69
+ D_params = list(self.netD.parameters())
70
+
71
+ beta1, beta2 = opt.beta1, opt.beta2
72
+ if opt.no_TTUR:
73
+ G_lr, D_lr = opt.lr, opt.lr
74
+ else:
75
+ G_lr, D_lr = opt.lr / 2, opt.lr * 2
76
+
77
+ optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))
78
+ optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))
79
+
80
+ return optimizer_G, optimizer_D
81
+
82
+ def save(self, epoch):
83
+ util.save_network(self.netG, "G", epoch, self.opt)
84
+ util.save_network(self.netD, "D", epoch, self.opt)
85
+ if self.opt.use_vae:
86
+ util.save_network(self.netE, "E", epoch, self.opt)
87
+
88
+ ############################################################################
89
+ # Private helper methods
90
+ ############################################################################
91
+
92
+ def initialize_networks(self, opt):
93
+ netG = networks.define_G(opt)
94
+ netD = networks.define_D(opt) if opt.isTrain else None
95
+ netE = networks.define_E(opt) if opt.use_vae else None
96
+
97
+ if not opt.isTrain or opt.continue_train:
98
+ # netG = util.load_network(netG, 'G', opt.which_epoch, opt)
99
+ checkpoint_path = os.path.join(os.path.dirname(__file__), "..", opt.checkpoint_path)
100
+ device = "cuda" if torch.cuda.is_available() else "cpu"
101
+ if device == "cuda":
102
+ checkpoint = torch.load(checkpoint_path)
103
+ else:
104
+ checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
105
+ s = checkpoint
106
+ netG.load_state_dict(s)
107
+
108
+ if opt.isTrain:
109
+ netD = util.load_network(netD, "D", opt.which_epoch, opt)
110
+ if opt.use_vae:
111
+ netE = util.load_network(netE, "E", opt.which_epoch, opt)
112
+
113
+ return netG, netD, netE
114
+
115
+ # preprocess the input, such as moving the tensors to GPUs and
116
+ # transforming the label map to one-hot encoding
117
+ # |data|: dictionary of the input data
118
+
119
+ def preprocess_input(self, data):
120
+ """
121
+ # move to GPU and change data types
122
+ data['label'] = data['label'].long()
123
+ if self.use_gpu():
124
+ data['label'] = data['label'].cuda(non_blocking=True)
125
+ data['instance'] = data['instance'].cuda(non_blocking=True)
126
+ data['image'] = data['image'].cuda(non_blocking=True)
127
+ data['domain'] = data['domain'].cuda(non_blocking=True)
128
+
129
+ # create one-hot label map
130
+ label_map = data['label']
131
+ bs, _, h, w = label_map.size()
132
+ nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label \
133
+ else self.opt.label_nc
134
+ input_label = self.FloatTensor(bs, nc, h, w).zero_()
135
+ input_semantics = input_label.scatter_(1, label_map, 1.0)
136
+
137
+ # concatenate instance map if it exists
138
+ if not self.opt.no_instance:
139
+ inst_map = data['instance']
140
+ instance_edge_map = self.get_edges(inst_map)
141
+ input_semantics = torch.cat((input_semantics, instance_edge_map), dim=1)
142
+
143
+ return input_semantics, data['image'], data['domain']
144
+ """
145
+
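+ # Inference path: build the one-hot semantic map from the label tensor; a scaled copy of the labels fills the real_image slot.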
146
+ data = data.long()
147
+ image = (data - 128).float() / 128.0
148
+ if self.use_gpu():
149
+ data = data.cuda()
150
+ image = image.float().cuda()
151
+ label_map = data
152
+ bs, _, h, w = label_map.size()
153
+ nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label else self.opt.label_nc
154
+ input_label = self.FloatTensor(bs, nc, h, w).zero_()
155
+ input_semantics = input_label.scatter_(1, label_map, 1.0)
156
+
157
+ return input_semantics, image # data['image'],
158
+
159
+ def compute_generator_loss(self, input_semantics, real_image, domain):
160
+ G_losses = {}
161
+
162
+ fake_image, KLD_loss, _ = self.generate_fake(input_semantics, real_image, domain, compute_kld_loss=self.opt.use_vae)
163
+
164
+ if self.opt.use_vae:
165
+ if KLD_loss.data.item() > 2.5:
166
+ print("ng")
167
+ print(KLD_loss.data.item())
168
+ KLD_loss.data = torch.Tensor([min(999.9999, KLD_loss.data.item())]).cuda()
169
+ G_losses["KLD"] = KLD_loss
170
+
171
+ pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image, domain)
172
+
173
+ G_losses["GAN"] = self.criterionGAN(pred_fake, True, for_discriminator=False)
174
+
175
+ if not self.opt.no_ganFeat_loss:
176
+ num_D = len(pred_fake)
177
+ GAN_Feat_loss = self.FloatTensor(1).fill_(0)
178
+ for i in range(num_D): # for each discriminator
179
+ # last output is the final prediction, so we exclude it
180
+ num_intermediate_outputs = len(pred_fake[i]) - 1
181
+ for j in range(num_intermediate_outputs): # for each layer output
182
+ unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
183
+ GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D
184
+ G_losses["GAN_Feat"] = GAN_Feat_loss
185
+
186
+ if not self.opt.no_vgg_loss:
187
+ G_losses["VGG"] = self.criterionVGG(fake_image, real_image) * self.opt.lambda_vgg
188
+
189
+ return G_losses, fake_image
190
+
191
+ def compute_discriminator_loss(self, input_semantics, real_image, domain):
192
+ D_losses = {}
193
+ with torch.no_grad():
194
+ fake_image, _, _ = self.generate_fake(input_semantics, real_image, domain)
195
+ fake_image = fake_image.detach()
196
+ fake_image.requires_grad_()
197
+
198
+ pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image, domain)
199
+
200
+ D_losses["D_Fake"] = self.criterionGAN(pred_fake, False, for_discriminator=True)
201
+ D_losses["D_real"] = self.criterionGAN(pred_real, True, for_discriminator=True)
202
+
203
+ return D_losses
204
+
205
+ def encode_z(self, real_image, domain):
206
+ mu, logvar = self.netE(real_image, domain)
207
+ z = self.reparameterize(mu, logvar)
208
+ return z, mu, logvar
209
+
210
+ def generate_fake(self, input_semantics, real_image, domain, style_codes=None, compute_kld_loss=True):
211
+ KLD_loss = None
212
+ if self.opt.use_vae and style_codes is None:
213
+ # print('yes')
214
+ style_codes, mu, logvar = self.encode_z(real_image, domain)
215
+ if compute_kld_loss:
216
+ KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld
217
+
218
+ fake_image = self.netG(input_semantics, real_image, style_codes=style_codes)
219
+
220
+ assert (not compute_kld_loss) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False"
221
+
222
+ return fake_image, KLD_loss, style_codes
223
+
224
+ def generate_img_npy(self, input_semantics, real_image, domain, compute_kld_loss=False):
225
+ KLD_loss = None
226
+ style_codes, mu, logvar = self.encode_z(real_image, domain)
227
+ if compute_kld_loss:
228
+ KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld
229
+
230
+ fake_image = self.netG(input_semantics, real_image, style_codes=style_codes)
231
+ # print(real_image, fake_image.shape)
232
+ assert (not compute_kld_loss) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False"
233
+
234
+ return fake_image, style_codes
235
+
236
+ # Given fake and real image, return the prediction of discriminator
237
+ # for each fake and real image.
238
+
239
+ def discriminate(self, input_semantics, fake_image, real_image, domain):
240
+ fake_concat = torch.cat([input_semantics, fake_image], dim=1)
241
+ real_concat = torch.cat([input_semantics, real_image], dim=1)
242
+
243
+ # In Batch Normalization, the fake and real images are
244
+ # recommended to be in the same batch to avoid disparate
245
+ # statistics in fake and real images.
246
+ # So both fake and real images are fed to D all at once.
247
+ fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
248
+
249
+ discriminator_out = self.netD(fake_and_real, domain)
250
+
251
+ pred_fake, pred_real = self.divide_pred(discriminator_out)
252
+
253
+ return pred_fake, pred_real
254
+
255
+ # Take the prediction of fake and real images from the combined batch
256
+ def divide_pred(self, pred):
257
+ # the prediction contains the intermediate outputs of multiscale GAN,
258
+ # so it's usually a list
259
+ if type(pred) == list:
260
+ fake = []
261
+ real = []
262
+ for p in pred:
263
+ fake.append([tensor[: tensor.size(0) // 2] for tensor in p])
264
+ real.append([tensor[tensor.size(0) // 2 :] for tensor in p])
265
+ else:
266
+ fake = pred[: pred.size(0) // 2]
267
+ real = pred[pred.size(0) // 2 :]
268
+
269
+ return fake, real
270
+
271
+ def get_edges(self, t):
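+ # Mark pixels where the instance label differs from a horizontal or vertical neighbour (instance boundary map).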
272
+ edge = self.ByteTensor(t.size()).zero_()
273
+ edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
274
+ edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
275
+ edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
276
+ edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
277
+ return edge.float()
278
+
279
+ def reparameterize(self, mu, logvar):
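+ # VAE reparameterization trick: z = mu + std * eps, with eps ~ N(0, I).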
280
+ std = torch.exp(0.5 * logvar)
281
+ eps = torch.randn_like(std)
282
+ return eps.mul(std) + mu
283
+
284
+ def use_gpu(self):
285
+ return len(self.opt.gpu_ids) > 0
static/index.js ADDED
@@ -0,0 +1,256 @@
1
+
2
+
3
+ let cvsIn = document.getElementById("inputimg");
4
+ let ctxIn = cvsIn.getContext('2d');
5
+ let style = document.getElementById("style");
6
+ let svgGraph = null;
7
+ let mouselbtn = false;
8
+ var current_time = (new Date()).getTime();
9
+ var user_id = Math.floor(Math.random() * 1000000000);
10
+
11
+
12
+ // initialize
13
+ window.onload = function () {
14
+ ctxIn.fillStyle = "#87ceeb";
15
+ ctxIn.fillRect(0, 0, cvsIn.width, 300);
16
+ ctxIn.fillStyle = "#567d46";
17
+ ctxIn.fillRect(0, 300, cvsIn.width, 512);
18
+
19
+ ctxIn.color = "#b0d49b";
20
+ ctxIn.lineWidth = 30;
21
+ ctxIn.lineJoin = ctxIn.lineCap = 'round';
22
+ }
23
+
24
+
25
+ // add canvas events
26
+ cvsIn.addEventListener("mousedown", function (e) {
27
+ if (e.button == 0) {
28
+ let rect = e.target.getBoundingClientRect();
29
+ let x = e.clientX - rect.left;
30
+ let y = e.clientY - rect.top;
31
+ mouselbtn = true;
32
+ ctxIn.beginPath();
33
+ ctxIn.moveTo(x, y);
34
+ }
35
+ else if (e.button == 2) {
36
+ onClear(); // right click for clear input
37
+ }
38
+ });
39
+
40
+ cvsIn.addEventListener("mouseup", function (e) {
41
+ if (e.button == 0) {
42
+ mouselbtn = false;
43
+ move_range = domainSlider.value;
44
+ onRecognition(move_range);
45
+ }
46
+ });
47
+ cvsIn.addEventListener("mousemove", function (e) {
48
+ let rect = e.target.getBoundingClientRect();
49
+ let x = e.clientX - rect.left;
50
+ let y = e.clientY - rect.top;
51
+ if (mouselbtn) {
52
+ ctxIn.lineTo(x, y);
53
+ ctxIn.strokeStyle = ctxIn.color;
54
+ ctxIn.stroke();
55
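+ // Throttle the live preview: while drawing, call the server at most once every 400 ms.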
+ if (((new Date).getTime() - current_time) >= 400) {
56
+ move_range = domainSlider.value;
57
+ onRecognition(move_range);
58
+ current_time = (new Date).getTime();
59
+ }
60
+
61
+ }
62
+ });
63
+
64
+ cvsIn.addEventListener("touchstart", function (e) {
65
+ // for touch device
66
+ if (e.targetTouches.length == 1) {
67
+ let rect = e.target.getBoundingClientRect();
68
+ let touch = e.targetTouches[0];
69
+ let x = touch.clientX - rect.left;
70
+ let y = touch.clientY - rect.top;
71
+ ctxIn.beginPath();
72
+ ctxIn.moveTo(x, y);
73
+ }
74
+ });
75
+
76
+ cvsIn.addEventListener("touchmove", function (e) {
77
+ // for touch device
78
+ if (e.targetTouches.length == 1) {
79
+ let rect = e.target.getBoundingClientRect();
80
+ let touch = e.targetTouches[0];
81
+ let x = touch.clientX - rect.left;
82
+ let y = touch.clientY - rect.top;
83
+ ctxIn.lineTo(x, y);
84
+ ctxIn.strokeStyle = ctxIn.color;
85
+ ctxIn.stroke();
86
+ e.preventDefault();
87
+ }
88
+ });
89
+
90
+ cvsIn.addEventListener("touchend", function (e) {
91
+ // for touch device
92
+ move_range = domainSlider.value;
93
+ onRecognition(move_range);
94
+ });
95
+
96
+ // prevent displaying the context menu
97
+ cvsIn.addEventListener('contextmenu', function (e) {
98
+ e.preventDefault();
99
+ });
100
+
101
+ document.getElementById("clearbtn").onclick = onClear;
102
+ function onClear() {
103
+ mouselbtn = false;
104
+ ctxIn.clearRect(0, 0, 512, 512);
105
+ ctxIn.fillStyle = "#87ceeb";
106
+ ctxIn.fillRect(0, 0, cvsIn.width, 300);
107
+ ctxIn.fillStyle = "#567d46";
108
+ ctxIn.fillRect(0, 300, cvsIn.width, 512);
109
+ }
110
+
111
+
112
+ document.getElementById("random_pick").addEventListener("click", function () {
113
+ //ctxIn.color = "#F5F5F5";
114
+ onRecognition_random();
115
+ });
116
+
117
+
118
+ document.getElementById("color1").addEventListener("click", function () {
119
+ //ctxIn.color = "#D5D5D5";
120
+ ctxIn.color = "#87ceeb";
121
+ });
122
+ document.getElementById("color2").addEventListener("click", function () {
123
+ //ctxIn.color = "#696969";
124
+ ctxIn.color = "#9b7653"
125
+ });
126
+ document.getElementById("color3").addEventListener("click", function () {
127
+ //ctxIn.color = "#676767";
128
+ ctxIn.color = "#b0d49b"
129
+ });
130
+ document.getElementById("color4").addEventListener("click", function () {
131
+ //ctxIn.color = "#F5F5F5";
132
+ ctxIn.color = "#5abcd8"
133
+ });
134
+ document.getElementById("color5").addEventListener("click", function () {
135
+ //ctxIn.color = "#F5F5F5";
136
+ ctxIn.color = "#C1BEBA"
137
+ });
138
+ document.getElementById("color6").addEventListener("click", function () {
139
+ ctxIn.color = "#5A4D41"
140
+ });
141
+ document.getElementById("color7").addEventListener("click", function () {
142
+ ctxIn.color = "#567d46"
143
+ });
144
+ document.getElementById("color8").addEventListener("click", function () {
145
+ ctxIn.color = "#42692f"
146
+ });
147
+
148
+ document.getElementById("color9").addEventListener("click", function () {
149
+ ctxIn.color = "#1577be"
150
+ });
151
+ //document.getElementById("color10").addEventListener("click", function(){
152
+ //ctxIn.color = "#676767";
153
+ // ctxIn.color = "#808080"
154
+ //});
155
+ document.getElementById("color11").addEventListener("click", function () {
156
+ //ctxIn.color = "#F5F5F5";
157
+ ctxIn.color = "#3a2e27"
158
+ });
159
+ document.getElementById("color12").addEventListener("click", function () {
160
+ //ctxIn.color = "#F5F5F5";
161
+ ctxIn.color = "#4D415A"
162
+ });
163
+ //document.getElementById("color13").addEventListener("click", function(){
164
+ // ctxIn.color = "#74cc8c"
165
+ //});
166
+ document.getElementById("color14").addEventListener("click", function () {
167
+ ctxIn.color = "#FDDA16"
168
+ });
169
+ document.getElementById("color15").addEventListener("click", function () {
170
+ ctxIn.color = "#d0cccc"
171
+ });
172
+
173
+ var brushSlider = document.getElementById("brushSlider");
174
+ ctxIn.lineWidth = brushSlider.value;
175
+
176
+ brushSlider.addEventListener("change", function () {
177
+ ctxIn.lineWidth = brushSlider.value;
178
+ });
179
+
180
+
181
+
182
+
183
+ var move_range = 3;//domainSlider_1.value;
184
+
185
+
186
+ document.getElementById('style').addEventListener('change', function (event) {
187
+ domainSlider.value = 3;
188
+ onRecognition(domainSlider.value);
189
+ })
190
+
191
+
192
+ domainSlider.addEventListener("change", function () {
193
+ // style.value = "ink";
194
+ move_range = domainSlider.value;
195
+ onRecognition(move_range);
196
+ });
197
+
198
+
199
+ // post data to server for recognition
200
+ function onRecognition(range) {
201
+ console.time("predict");
202
+
203
+ $.ajax({
204
+ url: './predict',
205
+ type: 'POST',
206
+ data: JSON.stringify({
207
+ img: cvsIn.toDataURL("image/png").replace('data:image/png;base64,', ''),
208
+ model: style.value,
209
+ move_range: range,
210
+ user_id: user_id
211
+ }),
212
+ contentType: 'application/json',
213
+ }).done(function (data) {
214
+ drawImgToCanvas("outputimg", data)
215
+
216
+ }).fail(function (XMLHttpRequest, textStatus, errorThrown) {
217
+ console.log(XMLHttpRequest);
218
+ alert("error");
219
+ })
220
+
221
+ console.timeEnd("time");
222
+ }
223
+
224
+ function onRecognition_random(range) {
225
+ console.time("predict");
226
+
227
+ $.ajax({
228
+ url: './predict_random',
229
+ type: 'POST',
230
+ data: JSON.stringify({
231
+ img: cvsIn.toDataURL("image/png").replace('data:image/png;base64,', ''),
232
+ model: style.value,
233
+ move_range: range,
234
+ user_id: user_id
235
+ }),
236
+ contentType: 'application/json',
237
+ }).done(function (data) {
238
+ drawImgToCanvas("outputimg", data)
239
+
240
+ }).fail(function (XMLHttpRequest, textStatus, errorThrown) {
241
+ console.log(XMLHttpRequest);
242
+ alert("error");
243
+ })
244
+
245
+ console.timeEnd("time");
246
+ }
247
+
248
+ function drawImgToCanvas(canvasId, b64Img) {
249
+ let canvas = document.getElementById(canvasId);
250
+ let ctx = canvas.getContext('2d');
251
+ let img = new Image();
252
+ img.src = "data:image/png;base64," + b64Img;
253
+ img.onload = function () {
254
+ ctx.drawImage(img, 0, 0, img.width, img.height, 0, 0, canvas.width, canvas.height);
255
+ }
256
+ }
static/init_code ADDED
Binary file (1.79 kB).
 
static/style.css ADDED
@@ -0,0 +1,36 @@
1
+ .common {
2
+ text-align: center;
3
+ }
4
+
5
+ .boxitem1 {
6
+ display: inline-block;
7
+ vertical-align: left;
8
+ }
9
+
10
+ select {
11
+ font-size: 1.7em;
12
+ border: 1px;
13
+ }
14
+
15
+ button {
16
+ border-radius: 10px;
17
+ }
18
+
19
+ .boxitem2 {
20
+ display: inline-block;
21
+ vertical-align: right;
22
+ }
23
+ .boxitem {
24
+ display: inline-block;
25
+ vertical-align: top;
26
+ }
27
+
28
+ #inputimg {
29
+ vertical-align: left;
30
+ border: solid 1px black;
31
+ }
32
+
33
+ #outputimg {
34
+ vertical-align: left;
35
+ border: solid 1px black;
36
+ }
templates/index.html ADDED
@@ -0,0 +1,124 @@
1
+ <!DOCTYPE html>
2
+ <html>
3
+
4
+ <head>
5
+ <link rel="stylesheet" href="https://cdn.rawgit.com/Chalarangelo/mini.css/v2.3.7/dist/mini-default.min.css">
6
+ <link rel="stylesheet" href="./static/style.css">
7
+ <meta charset="UTF-8">
8
+ <title>Label to Art Demo_V_0.7 </title>
9
+ </head>
10
+
11
+ <div class="common">
12
+
13
+ <body>
14
+ <div class="row">
15
+ <div class="col-sm-12 col-md-10 col-md-offset-1">
16
+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
17
+ <div class="boxitem1">
18
+ <label for='brushSlider' style='font-size: 1.7em; font-weight: bold'>Stroke Width </label>
19
+ <input type="range" name="brushsize" min="0" max="100" id="brushSlider" step="1" value="30"
20
+ onchange="this.setAttribute('value',this.value);">
21
+ </div>
22
+
23
+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
24
+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
25
+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
26
+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
27
+ <div class="boxitem2">
28
+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
29
+ <label for="style" style='font-size: 1.7em; font-weight: bold'>Domain:</label>
30
+ <select id="style">
31
+ <option style='font-size: 1.7em; font-weight: bold' value="ink">Ink Wash</option>
32
+ <option style='font-size: 1.7em; font-weight: bold' value="monet">Monet</option>
33
+ <option style='font-size: 1.7em; font-weight: bold' value="vangogh">Van Gogh</option>
34
+ <option style='font-size: 1.7em; font-weight: bold' value="water">WaterColor</option>
35
+ </select>
36
+ </div>
37
+
38
+ <div class="row">
39
+ <div class="col-sm-12 col-md-10 col-md-offset-1">
40
+ <div class="boxitem1">
41
+ <label for='domainSlider' style='font-size: 1.7em; font-weight: bold'> Style Strength </label>
42
+ <input type="range" name="style range" min="1" max="5" id="domainSlider" step="0.2" value="3"
43
+ onchange="this.setAttribute('value',this.value);">
44
+ </div>
45
+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
46
+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
47
+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
48
+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
49
+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
50
+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
51
+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
52
+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
53
+ <div class="boxitem2">
54
+ <button id="random_pick" class="primary"
55
+ style="background-color:#003449; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
56
+ Random </button>
57
+ </div>
58
+ </div>
59
+ <br>
60
+
61
+ </div>
62
+ <div class="divider"></div>
63
+ <div class="boxitem1">
64
+ <canvas id="inputimg" width="512" height="512" style="border:5px solid #ffffff;"></canvas>
65
+ </div>
66
+ <div class="boxitem2">
67
+ <canvas id="outputimg" width="512" height="512" style="border:5px solid #ffffff;"></canvas>
68
+ </div>
69
+ </div>
70
+ </div>
71
+
72
+
73
+ <div class="boxitem">
74
+ <button id="color1" class="primary"
75
+ style="background-color:#87ceeb; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
76
+ Sky </button>
77
+ <button id="color2" class="primary"
78
+ style="background-color:#9b7653; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
79
+ Dirt </button>
80
+ <button id="color3" class="primary"
81
+ style="background-color:#b0d49b; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
82
+ Mountain </button>
83
+ <button id="color4" class="primary"
84
+ style="background-color:#5abcd8; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
85
+ River </button>
86
+ <button id="color5" class="primary"
87
+ style="background-color:#C1BEBA; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
88
+ Clouds </button>
89
+ <button id="color6" class="primary"
90
+ style="background-color:#5A4D41; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
91
+ Rock </button>
92
+ <button id="color7" class="primary"
93
+ style="background-color:#567d46; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
94
+ Grass </button>
95
+ <button id="color8" class="primary"
96
+ style="background-color:#42692f; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
97
+ Tree </button>
98
+ </div>
99
+ <br>
100
+ <div class="boxitem">
101
+ <button id="color9" class="primary"
102
+ style="background-color:#1577be; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
103
+ Sea </button>
104
+ <button id="color11" class="primary"
105
+ style="background-color:#3a2e27; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
106
+ Ground </button>
107
+ <button id="color12" class="primary"
108
+ style="background-color:#4D415A; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
109
+ Hill </button>
110
+ <button id="color14" class="primary"
111
+ style="background-color:#FDDA16; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
112
+ Road </button>
113
+ <button id="color15" class="primary"
114
+ style="background-color:#d0cccc; margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;">
115
+ Snow </button>
116
+ <button id="clearbtn" class="primary"
117
+ style="margin-right:0px; font-size: 1.9em; font-weight: bold; border-radius:15px;"> Clear </button>
118
+
119
+ </div>
120
+
121
+ <script src="./static/index.js"></script>
122
+ <script src="//ajax.googleapis.com/ajax/libs/jquery/3.3.1/jquery.min.js"></script>
123
+ <script src="//d3js.org/d3.v5.min.js"></script>
124
+ </body>
utils/boundaries_amp_52/artwork_ink_boundary/boundary.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:63f02c41f68b7fcd603e158fbd4542ebbe3b42d5b25b64a5e1e1e0699e1286c0
+ size 1152
utils/boundaries_amp_52/artwork_ink_boundary/log.txt ADDED
@@ -0,0 +1,12 @@
+ [2021-09-15 00:57:58,220][INFO] Loading latent codes.
+ [2021-09-15 00:57:58,233][INFO] Loading attribute scores.
+ [2021-09-15 00:57:58,234][INFO] Filtering training data.
+ [2021-09-15 00:57:58,234][INFO] Sorting scores to get positive and negative samples.
+ [2021-09-15 00:57:58,246][INFO] Spliting training and validation sets:
+ [2021-09-15 00:57:58,337][INFO] Training: 4200 positive, 4200 negative.
+ [2021-09-15 00:57:58,342][INFO] Validation: 1800 positive, 1800 negative.
+ [2021-09-15 00:57:58,353][INFO] Remaining: 4255 positive, 23745 negative.
+ [2021-09-15 00:57:58,356][INFO] Training boundary.
+ [2021-09-15 00:57:59,433][INFO] Finish training.
+ [2021-09-15 00:57:59,711][INFO] Accuracy for validation set: 3596 / 3600 = 0.998889
+ [2021-09-15 00:58:01,331][INFO] Accuracy for remaining set: 27037 / 28000 = 0.965607
utils/boundaries_amp_52/artwork_monet_boundary/boundary.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:477c0e8af08ccb05d8ddb79e38b5c71cbc856bb4ebe0adf80194de48cf1510ad
+ size 1152
utils/boundaries_amp_52/artwork_monet_boundary/log.txt ADDED
@@ -0,0 +1,12 @@
+ [2021-09-15 00:57:58,250][INFO] Loading latent codes.
+ [2021-09-15 00:57:58,265][INFO] Loading attribute scores.
+ [2021-09-15 00:57:58,266][INFO] Filtering training data.
+ [2021-09-15 00:57:58,266][INFO] Sorting scores to get positive and negative samples.
+ [2021-09-15 00:57:58,281][INFO] Spliting training and validation sets:
+ [2021-09-15 00:57:58,393][INFO] Training: 4200 positive, 4200 negative.
+ [2021-09-15 00:57:58,398][INFO] Validation: 1800 positive, 1800 negative.
+ [2021-09-15 00:57:58,400][INFO] Remaining: 5854 positive, 22146 negative.
+ [2021-09-15 00:57:58,407][INFO] Training boundary.
+ [2021-09-15 00:57:59,699][INFO] Finish training.
+ [2021-09-15 00:57:59,912][INFO] Accuracy for validation set: 3549 / 3600 = 0.985833
+ [2021-09-15 00:58:01,556][INFO] Accuracy for remaining set: 23849 / 28000 = 0.851750
utils/boundaries_amp_52/artwork_vangogh_boundary/boundary.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b616bf764d3ea1c85b9f852b5320b90fcb6ff5d90f80d22ef8a7e5c3a47a5959
+ size 1152
utils/boundaries_amp_52/artwork_vangogh_boundary/log.txt ADDED
@@ -0,0 +1,12 @@
+ [2021-09-15 00:57:58,217][INFO] Loading latent codes.
+ [2021-09-15 00:57:58,230][INFO] Loading attribute scores.
+ [2021-09-15 00:57:58,231][INFO] Filtering training data.
+ [2021-09-15 00:57:58,231][INFO] Sorting scores to get positive and negative samples.
+ [2021-09-15 00:57:58,246][INFO] Spliting training and validation sets:
+ [2021-09-15 00:57:58,279][INFO] Training: 4200 positive, 4200 negative.
+ [2021-09-15 00:57:58,281][INFO] Validation: 1800 positive, 1800 negative.
+ [2021-09-15 00:57:58,281][INFO] Remaining: 3401 positive, 24599 negative.
+ [2021-09-15 00:57:58,281][INFO] Training boundary.
+ [2021-09-15 00:57:59,347][INFO] Finish training.
+ [2021-09-15 00:57:59,551][INFO] Accuracy for validation set: 3596 / 3600 = 0.998889
+ [2021-09-15 00:58:01,098][INFO] Accuracy for remaining set: 23785 / 28000 = 0.849464
utils/boundaries_amp_52/artwork_water_boundary/boundary.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f0827fe3ea8f6361bc298c968623c2ff2e33e5183f4d9df18c118f751c10c713
+ size 1152
utils/boundaries_amp_52/artwork_water_boundary/log.txt ADDED
@@ -0,0 +1,12 @@
+ [2021-09-15 00:57:58,226][INFO] Loading latent codes.
+ [2021-09-15 00:57:58,244][INFO] Loading attribute scores.
+ [2021-09-15 00:57:58,245][INFO] Filtering training data.
+ [2021-09-15 00:57:58,245][INFO] Sorting scores to get positive and negative samples.
+ [2021-09-15 00:57:58,263][INFO] Spliting training and validation sets:
+ [2021-09-15 00:57:58,390][INFO] Training: 4200 positive, 4200 negative.
+ [2021-09-15 00:57:58,393][INFO] Validation: 1800 positive, 1800 negative.
+ [2021-09-15 00:57:58,398][INFO] Remaining: 4465 positive, 23535 negative.
+ [2021-09-15 00:57:58,401][INFO] Training boundary.
+ [2021-09-15 00:57:59,812][INFO] Finish training.
+ [2021-09-15 00:58:00,021][INFO] Accuracy for validation set: 3584 / 3600 = 0.995556
+ [2021-09-15 00:58:01,830][INFO] Accuracy for remaining set: 24271 / 28000 = 0.866821
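The four log.txt files above record fitting one linear separating boundary per art style (ink, monet, vangogh, water) in the generator's 256-dimensional latent space, with validation and held-out accuracy reported at the end of each run. The training script itself is not part of this commit; the sketch below only illustrates, under stated assumptions, how such a unit-norm boundary could be fit with a linear SVM so that it matches the (1, 256) float32 arrays stored in the accompanying boundary.npy files. The function name, input shapes and labeling scheme are assumptions, not the authors' code.

import numpy as np
from sklearn import svm


def train_boundary(latent_codes, labels):
    # latent_codes: assumed (N, 256) array of sampled latent codes.
    # labels: assumed (N,) array, 1 where the sample was scored as belonging to the style, else 0.
    clf = svm.LinearSVC()  # linear decision boundary in latent space
    clf.fit(latent_codes, labels)
    boundary = clf.coef_.reshape(1, -1).astype(np.float32)
    # Normalize to unit norm, matching the convention described in linear_interpolate()'s docstring below.
    return boundary / np.linalg.norm(boundary)


# np.save("boundary.npy", train_boundary(codes, labels)) would then yield a (1, 256) array,
# consistent with the 1152-byte .npy files added in this commit.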
utils/umap_utils.py ADDED
@@ -0,0 +1,99 @@
+ import numpy as np
+ import torch
+ import copy
+ import os
+ from sklearn import svm
+
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+
+ def linear_interpolate(latent_code, boundary, start_distance=-3, end_distance=3, steps=10):
+     """Manipulates the given latent code with respect to a particular boundary.
+
+     This function takes a latent code and a boundary as inputs, and outputs a
+     collection of manipulated latent codes. For example, if `steps` is 10, the
+     input `latent_code` has shape [1, latent_space_dim], the input `boundary`
+     has shape [1, latent_space_dim] with unit norm, and the output has shape
+     [10, latent_space_dim]. The first output latent code is `start_distance`
+     away from the given `boundary`, the last one is `end_distance` away, and
+     the remaining latent codes are linearly interpolated in between.
+
+     Input `latent_code` can also have shape [1, num_layers, latent_space_dim]
+     to support the W+ space of StyleGAN. In this case, all layers in W+ space
+     are manipulated in the same way, and the output has shape
+     [10, num_layers, latent_space_dim].
+
+     NOTE: Distance is sign sensitive.
+
+     Args:
+         latent_code: The input latent code for manipulation.
+         boundary: The semantic boundary as reference.
+         start_distance: The distance to the boundary where the manipulation
+             starts. (default: -3.0)
+         end_distance: The distance to the boundary where the manipulation
+             ends. (default: 3.0)
+         steps: Number of steps to move the latent code from start position to
+             end position. (default: 10)
+     """
+     assert latent_code.shape[0] == 1 and boundary.shape[0] == 1 and len(boundary.shape) == 2 and boundary.shape[1] == latent_code.shape[-1]
+
+     linspace = np.linspace(start_distance, end_distance, steps)
+     if len(latent_code.shape) == 2:
+         # Offset by the code's current projection onto the boundary normal so the
+         # outputs end up at the requested absolute distances from the boundary.
+         linspace = linspace - latent_code.dot(boundary.T)
+         linspace = linspace.reshape(-1, 1).astype(np.float32)
+         return latent_code + linspace * boundary
+     if len(latent_code.shape) == 3:
+         linspace = linspace.reshape(-1, 1, 1).astype(np.float32)
+         return latent_code + linspace * boundary.reshape(1, 1, -1)
+     raise ValueError(
+         f"Input `latent_code` should have shape "
+         f"[1, latent_space_dim] or [1, N, latent_space_dim] for "
+         f"W+ space in StyleGAN!\n"
+         f"But {latent_code.shape} is received."
+     )
+
+
+ def get_code(domain, boundaries):
+     # Map the domain name to the index of its boundary (domains are sorted alphabetically).
+     if domain == "ink":
+         domain = 0
+     elif domain == "monet":
+         domain = 1
+     elif domain == "vangogh":
+         domain = 2
+     elif domain == "water":
+         domain = 3
+
+     # Sample a random 256-dimensional latent code.
+     res = np.array(torch.randn(1, 256, dtype=torch.float32))
+     # res = linear_interpolate(res, boundaries[domain], end_distance=3, steps=3)[-1:]
+     res = torch.Tensor(res).cuda() if torch.cuda.is_available() else torch.Tensor(res)
+     return res
+
+
+ def modify_code(code, boundaries, domain, range):
+     # `range` is the style-strength distance to move past the boundary (not the builtin).
+     if domain == "ink":
+         domain = 0
+     elif domain == "monet":
+         domain = 1
+     elif domain == "vangogh":
+         domain = 2
+     elif domain == "water":
+         domain = 3
+     # print(domain, range)
+     if range == 0:
+         return code
+     else:
+         res = np.array(code.cpu().detach().numpy())
+         res = linear_interpolate(res, boundaries[domain], end_distance=range, steps=3)[-1:]
+         res = torch.Tensor(res).cuda() if torch.cuda.is_available() else torch.Tensor(res)
+         return res
+
+
+ def load_boundries():
+     domains = ["ink", "monet", "vangogh", "water"]
+     domains.sort()
+     boundaries = [
+         np.load(os.path.join(os.path.dirname(__file__), "boundaries_amp_52/artwork_" + domain + "_boundary/boundary.npy")) for domain in domains
+     ]
+     return boundaries
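Taken together, these helpers load one boundary per style, draw a random latent code, and push that code a chosen distance past the selected style's boundary before it is fed to the generator. The actual wiring happens in app.py (added earlier in this commit); the snippet below is only an illustrative usage sketch and assumes it is run from the repository root.

from utils.umap_utils import load_boundries, get_code, modify_code

boundaries = load_boundries()             # one (1, 256) boundary per style, in alphabetical order
code = get_code("vangogh", boundaries)    # random (1, 256) latent code, on GPU if available
# Move the code 3 units past the "vangogh" boundary to strengthen that style.
styled = modify_code(code, boundaries, "vangogh", 3)
print(styled.shape)                       # torch.Size([1, 256])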