yuaiyu JamesXu commited on
Commit
004fe4b
·
0 Parent(s):

Duplicate from shi-labs/Versatile-Diffusion

Browse files

Co-authored-by: Xingqian Xu <[email protected]>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +36 -0
  2. .gitignore +11 -0
  3. README.md +15 -0
  4. app.py +1114 -0
  5. assets/demo/mcg_example/e0i0.jpg +0 -0
  6. assets/demo/mcg_example/e0i1.jpg +0 -0
  7. assets/demo/mcg_example/e0i2.jpg +0 -0
  8. assets/demo/misc/mask_inst1.gif +3 -0
  9. assets/demo/misc/mask_inst2.gif +3 -0
  10. assets/demo/misc/mask_inst3.gif +3 -0
  11. assets/demo/misc/noimage.jpg +0 -0
  12. assets/demo/reg_example/benz.jpg +0 -0
  13. assets/demo/reg_example/boy_and_girl.jpg +0 -0
  14. assets/demo/reg_example/church.jpg +0 -0
  15. assets/demo/reg_example/firework.jpg +0 -0
  16. assets/demo/reg_example/ghibli.jpg +0 -0
  17. assets/demo/reg_example/horse.jpg +0 -0
  18. assets/demo/reg_example/house_by_lake.jpg +0 -0
  19. assets/demo/reg_example/matisse.jpg +0 -0
  20. assets/demo/reg_example/night_light.jpg +0 -0
  21. assets/demo/reg_example/noimage.jpg +0 -0
  22. assets/demo/reg_example/paris.jpg +0 -0
  23. assets/demo/reg_example/penguin.jpg +0 -0
  24. assets/demo/reg_example/san_diego.jpg +0 -0
  25. assets/demo/reg_example/scream.jpg +0 -0
  26. assets/demo/reg_example/space.jpg +0 -0
  27. assets/demo/reg_example/tiger.jpg +0 -0
  28. assets/demo/reg_example/train.jpg +0 -0
  29. assets/demo/reg_example/vermeer.jpg +0 -0
  30. assets/demo/tcg_example/e0i0.jpg +0 -0
  31. assets/demo/tcg_example/e0i1.jpg +0 -0
  32. assets/demo/tcg_example/e1i0.jpg +0 -0
  33. assets/demo/tcg_example/e1i1.jpg +0 -0
  34. assets/demo/tcg_example/e2i0.jpg +0 -0
  35. assets/figures/share_instruction.png +0 -0
  36. configs/model/autokl.yaml +23 -0
  37. configs/model/clip.yaml +13 -0
  38. configs/model/openai_unet.yaml +96 -0
  39. configs/model/optimus.yaml +103 -0
  40. configs/model/vd.yaml +29 -0
  41. cusomized_gradio_blocks.py +271 -0
  42. lib/__init__.py +0 -0
  43. lib/cfg_helper.py +612 -0
  44. lib/cfg_holder.py +28 -0
  45. lib/log_service.py +166 -0
  46. lib/model_zoo/__init__.py +4 -0
  47. lib/model_zoo/attention.py +435 -0
  48. lib/model_zoo/autokl.py +140 -0
  49. lib/model_zoo/autokl_modules.py +835 -0
  50. lib/model_zoo/autokl_utils.py +400 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.pth filter=lfs diff=lfs merge=lfs -text
36
+ *.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ .vscode/
3
+ src/
4
+ data/
5
+ data
6
+ log/
7
+ log
8
+ pretrained/
9
+ pretrained
10
+ gradio_cached_examples/
11
+ gradio_cached_examples
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Versatile Diffusion
3
+ emoji: null
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.17.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ python_version: 3.8.5
12
+ duplicated_from: shi-labs/Versatile-Diffusion
13
+ ---
14
+
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,1114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ # Copyright (C) 2023 Xingqian Xu - All Rights Reserved #
3
+ # #
4
+ # Please visit Versatile Diffusion's arXiv paper for more details, link at #
5
+ # arxiv.org/abs/2211.08332 #
6
+ # #
7
+ # Besides, this work is also inspired by many established techniques including:#
8
+ # Denoising Diffusion Probablistic Model; Denoising Diffusion Implicit Model; #
9
+ # Latent Diffusion Model; Stable Diffusion; Stable Diffusion - Img2Img; Stable #
10
+ # Diffusion - Variation; ImageMixer; DreamBooth; Stable Diffusion - Lora; More #
11
+ # Control for Free; Prompt-to-Prompt; #
12
+ # #
13
+ ################################################################################
14
+
15
+ import gradio as gr
16
+ import os
17
+ import PIL
18
+ from PIL import Image
19
+ from pathlib import Path
20
+ import numpy as np
21
+ import numpy.random as npr
22
+ from contextlib import nullcontext
23
+ import types
24
+
25
+ import torch
26
+ import torchvision.transforms as tvtrans
27
+ from lib.cfg_helper import model_cfg_bank
28
+ from lib.model_zoo import get_model
29
+ from cusomized_gradio_blocks import create_myexamples, customized_as_example, customized_postprocess
30
+
31
+ n_sample_image = 2
32
+ n_sample_text = 4
33
+ cache_examples = True
34
+
35
+ from lib.model_zoo.ddim import DDIMSampler
36
+
37
+ ##########
38
+ # helper #
39
+ ##########
40
+
41
+ def highlight_print(info):
42
+ print('')
43
+ print(''.join(['#']*(len(info)+4)))
44
+ print('# '+info+' #')
45
+ print(''.join(['#']*(len(info)+4)))
46
+ print('')
47
+
48
+ def decompose(x, q=20, niter=100):
49
+ x_mean = x.mean(-1, keepdim=True)
50
+ x_input = x - x_mean
51
+ u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
52
+ ss = torch.stack([torch.diag(si) for si in s])
53
+ x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))
54
+ x_remain = x_input - x_lowrank
55
+ return u, s, v, x_mean, x_remain
56
+
57
+ class adjust_rank(object):
58
+ def __init__(self, max_drop_rank=[1, 5], q=20):
59
+ self.max_semantic_drop_rank = max_drop_rank[0]
60
+ self.max_style_drop_rank = max_drop_rank[1]
61
+ self.q = q
62
+
63
+ def t2y0_semf_wrapper(t0, y00, t1, y01):
64
+ return lambda t: (np.exp((t-0.5)*2)-t0)/(t1-t0)*(y01-y00)+y00
65
+ t0, y00 = np.exp((0 -0.5)*2), -self.max_semantic_drop_rank
66
+ t1, y01 = np.exp((0.5-0.5)*2), 1
67
+ self.t2y0_semf = t2y0_semf_wrapper(t0, y00, t1, y01)
68
+
69
+ def x2y_semf_wrapper(x0, x1, y1):
70
+ return lambda x, y0: (x-x0)/(x1-x0)*(y1-y0)+y0
71
+ x0 = 0
72
+ x1, y1 = self.max_semantic_drop_rank+1, 1
73
+ self.x2y_semf = x2y_semf_wrapper(x0, x1, y1)
74
+
75
+ def t2y0_styf_wrapper(t0, y00, t1, y01):
76
+ return lambda t: (np.exp((t-0.5)*2)-t0)/(t1-t0)*(y01-y00)+y00
77
+ t0, y00 = np.exp((1 -0.5)*2), -(q-self.max_style_drop_rank)
78
+ t1, y01 = np.exp((0.5-0.5)*2), 1
79
+ self.t2y0_styf = t2y0_styf_wrapper(t0, y00, t1, y01)
80
+
81
+ def x2y_styf_wrapper(x0, x1, y1):
82
+ return lambda x, y0: (x-x0)/(x1-x0)*(y1-y0)+y0
83
+ x0 = q-1
84
+ x1, y1 = self.max_style_drop_rank-1, 1
85
+ self.x2y_styf = x2y_styf_wrapper(x0, x1, y1)
86
+
87
+ def __call__(self, x, lvl):
88
+ if lvl == 0.5:
89
+ return x
90
+
91
+ if x.dtype == torch.float16:
92
+ fp16 = True
93
+ x = x.float()
94
+ else:
95
+ fp16 = False
96
+ std_save = x.std(axis=[-2, -1])
97
+
98
+ u, s, v, x_mean, x_remain = decompose(x, q=self.q)
99
+
100
+ if lvl < 0.5:
101
+ assert lvl>=0
102
+ for xi in range(0, self.max_semantic_drop_rank+1):
103
+ y0 = self.t2y0_semf(lvl)
104
+ yi = self.x2y_semf(xi, y0)
105
+ yi = 0 if yi<0 else yi
106
+ s[:, xi] *= yi
107
+
108
+ elif lvl > 0.5:
109
+ assert lvl <= 1
110
+ for xi in range(self.max_style_drop_rank, self.q):
111
+ y0 = self.t2y0_styf(lvl)
112
+ yi = self.x2y_styf(xi, y0)
113
+ yi = 0 if yi<0 else yi
114
+ s[:, xi] *= yi
115
+ x_remain = 0
116
+
117
+ ss = torch.stack([torch.diag(si) for si in s])
118
+ x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))
119
+ x_new = x_lowrank + x_mean + x_remain
120
+
121
+ std_new = x_new.std(axis=[-2, -1])
122
+ x_new = x_new / std_new * std_save
123
+
124
+ if fp16:
125
+ x_new = x_new.half()
126
+
127
+ return x_new
128
+
129
+ def remove_duplicate_word(tx):
130
+ def combine_words(input, length):
131
+ combined_inputs = []
132
+ if len(splitted_input)>1:
133
+ for i in range(len(input)-1):
134
+ combined_inputs.append(input[i]+" "+last_word_of(splitted_input[i+1],length)) #add the last word of the right-neighbour (overlapping) sequence (before it has expanded), which is the next word in the original sentence
135
+ return combined_inputs, length+1
136
+
137
+ def remove_duplicates(input, length):
138
+ bool_broke=False #this means we didn't find any duplicates here
139
+ for i in range(len(input) - length):
140
+ if input[i]==input[i + length]: #found a duplicate piece of sentence!
141
+ for j in range(0, length): #remove the overlapping sequences in reverse order
142
+ del input[i + length - j]
143
+ bool_broke = True
144
+ break #break the for loop as the loop length does not matches the length of splitted_input anymore as we removed elements
145
+ if bool_broke:
146
+ return remove_duplicates(input, length) #if we found a duplicate, look for another duplicate of the same length
147
+ return input
148
+
149
+ def last_word_of(input, length):
150
+ splitted = input.split(" ")
151
+ if len(splitted)==0:
152
+ return input
153
+ else:
154
+ return splitted[length-1]
155
+
156
+ def split_and_puncsplit(text):
157
+ tx = text.split(" ")
158
+ txnew = []
159
+ for txi in tx:
160
+ txqueue=[]
161
+ while True:
162
+ if txi[0] in '([{':
163
+ txqueue.extend([txi[:1], '<puncnext>'])
164
+ txi = txi[1:]
165
+ if len(txi) == 0:
166
+ break
167
+ else:
168
+ break
169
+ txnew += txqueue
170
+ txstack=[]
171
+ if len(txi) == 0:
172
+ continue
173
+ while True:
174
+ if txi[-1] in '?!.,:;}])':
175
+ txstack = ['<puncnext>', txi[-1:]] + txstack
176
+ txi = txi[:-1]
177
+ if len(txi) == 0:
178
+ break
179
+ else:
180
+ break
181
+ if len(txi) != 0:
182
+ txnew += [txi]
183
+ txnew += txstack
184
+ return txnew
185
+
186
+ if tx == '':
187
+ return tx
188
+
189
+ splitted_input = split_and_puncsplit(tx)
190
+ word_length = 1
191
+ intermediate_output = False
192
+ while len(splitted_input)>1:
193
+ splitted_input = remove_duplicates(splitted_input, word_length)
194
+ if len(splitted_input)>1:
195
+ splitted_input, word_length = combine_words(splitted_input, word_length)
196
+ if intermediate_output:
197
+ print(splitted_input)
198
+ print(word_length)
199
+ output = splitted_input[0]
200
+ output = output.replace(' <puncnext> ', '')
201
+ return output
202
+
203
+ def get_instruction(mode):
204
+ t2i_instruction = ["Generate image from text prompt."]
205
+ i2i_instruction = ["Generate image conditioned on reference image.",]
206
+ i2t_instruction = ["Generate text from reference image. "]
207
+ t2t_instruction = ["Generate text from reference text prompt. "]
208
+ dcg_instruction = ["Generate image conditioned on both text and image."]
209
+ tcg_instruction = ["Generate image conditioned on text and up to two images."]
210
+ mcg_instruction = ["Generate image from multiple contexts."]
211
+
212
+ if mode == "Text-to-Image":
213
+ return '\n'.join(t2i_instruction)
214
+ elif mode == "Image-Variation":
215
+ return '\n'.join(i2i_instruction)
216
+ elif mode == "Image-to-Text":
217
+ return '\n'.join(i2t_instruction)
218
+ elif mode == "Text-Variation":
219
+ return '\n'.join(t2t_instruction)
220
+ elif mode == "Dual-Context":
221
+ return '\n'.join(dcg_instruction)
222
+ elif mode == "Triple-Context":
223
+ return '\n'.join(tcg_instruction)
224
+ elif mode == "Multi-Context":
225
+ return '\n'.join(mcg_instruction)
226
+ else:
227
+ assert False
228
+
229
+ ########
230
+ # main #
231
+ ########
232
+ class vd_dummy(object):
233
+ def __init__(self, *args, **kwarg):
234
+ self.which = 'Vdummy'
235
+ def inference_t2i(self, *args, **kwarg): pass
236
+ def inference_i2i(self, *args, **kwarg): pass
237
+ def inference_i2t(self, *args, **kwarg): pass
238
+ def inference_t2t(self, *args, **kwarg): pass
239
+ def inference_dcg(self, *args, **kwarg): pass
240
+ def inference_tcg(self, *args, **kwarg): pass
241
+ def inference_mcg(self, *args, **kwarg):
242
+ return None, None
243
+
244
+ class vd_inference(object):
245
+ def __init__(self, fp16=False, which='v2.0'):
246
+ highlight_print(which)
247
+ self.which = which
248
+
249
+ if self.which == 'v1.0':
250
+ cfgm = model_cfg_bank()('vd_four_flow_v1-0')
251
+ else:
252
+ assert False, 'Model type not supported'
253
+ net = get_model()(cfgm)
254
+
255
+ if fp16:
256
+ highlight_print('Running in FP16')
257
+ if self.which == 'v1.0':
258
+ net.ctx['text'].fp16 = True
259
+ net.ctx['image'].fp16 = True
260
+ net = net.half()
261
+ self.dtype = torch.float16
262
+ else:
263
+ self.dtype = torch.float32
264
+
265
+ if self.which == 'v1.0':
266
+ # if fp16:
267
+ # sd = torch.load('pretrained/vd-four-flow-v1-0-fp16.pth', map_location='cpu')
268
+ # else:
269
+ # sd = torch.load('pretrained/vd-four-flow-v1-0.pth', map_location='cpu')
270
+ from huggingface_hub import hf_hub_download
271
+ if fp16:
272
+ temppath = hf_hub_download('shi-labs/versatile-diffusion-model', 'pretrained_pth/vd-four-flow-v1-0-fp16.pth')
273
+ else:
274
+ temppath = hf_hub_download('shi-labs/versatile-diffusion-model', 'pretrained_pth/vd-four-flow-v1-0.pth')
275
+ sd = torch.load(temppath, map_location='cpu')
276
+
277
+ net.load_state_dict(sd, strict=False)
278
+
279
+ self.use_cuda = torch.cuda.is_available()
280
+ if self.use_cuda:
281
+ net.to('cuda')
282
+ self.net = net
283
+ self.sampler = DDIMSampler(net)
284
+
285
+ self.output_dim = [512, 512]
286
+ self.n_sample_image = n_sample_image
287
+ self.n_sample_text = n_sample_text
288
+ self.ddim_steps = 50
289
+ self.ddim_eta = 0.0
290
+ self.scale_textto = 7.5
291
+ self.image_latent_dim = 4
292
+ self.text_latent_dim = 768
293
+ self.text_temperature = 1
294
+
295
+ if which == 'v1.0':
296
+ self.adjust_rank_f = adjust_rank(max_drop_rank=[1, 5], q=20)
297
+ self.scale_imgto = 7.5
298
+ self.disentanglement_noglobal = True
299
+
300
+ def inference_t2i(self, text, seed):
301
+ n_samples = self.n_sample_image
302
+ scale = self.scale_textto
303
+ sampler = self.sampler
304
+ h, w = self.output_dim
305
+ u = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1)
306
+ c = self.net.ctx_encode([text], which='text').repeat(n_samples, 1, 1)
307
+ shape = [n_samples, self.image_latent_dim, h//8, w//8]
308
+ np.random.seed(seed)
309
+ torch.manual_seed(seed + 100)
310
+ x, _ = sampler.sample(
311
+ steps=self.ddim_steps,
312
+ x_info={'type':'image'},
313
+ c_info={'type':'text', 'conditioning':c, 'unconditional_conditioning':u,
314
+ 'unconditional_guidance_scale':scale},
315
+ shape=shape,
316
+ verbose=False,
317
+ eta=self.ddim_eta)
318
+ im = self.net.vae_decode(x, which='image')
319
+ im = [tvtrans.ToPILImage()(i) for i in im]
320
+ return im
321
+
322
+ def inference_i2i(self, im, fid_lvl, fcs_lvl, clr_adj, seed):
323
+ n_samples = self.n_sample_image
324
+ scale = self.scale_imgto
325
+ sampler = self.sampler
326
+ h, w = self.output_dim
327
+ device = self.net.device
328
+
329
+ BICUBIC = PIL.Image.Resampling.BICUBIC
330
+ im = im.resize([w, h], resample=BICUBIC)
331
+
332
+ if fid_lvl == 1:
333
+ return [im]*n_samples
334
+
335
+ cx = tvtrans.ToTensor()(im)[None].to(device).to(self.dtype)
336
+
337
+ c = self.net.ctx_encode(cx, which='image')
338
+ if self.disentanglement_noglobal:
339
+ c_glb = c[:, 0:1]
340
+ c_loc = c[:, 1: ]
341
+ c_loc = self.adjust_rank_f(c_loc, fcs_lvl)
342
+ c = torch.cat([c_glb, c_loc], dim=1).repeat(n_samples, 1, 1)
343
+ else:
344
+ c = self.adjust_rank_f(c, fcs_lvl).repeat(n_samples, 1, 1)
345
+ u = torch.zeros_like(c)
346
+
347
+ shape = [n_samples, self.image_latent_dim, h//8, w//8]
348
+ np.random.seed(seed)
349
+ torch.manual_seed(seed + 100)
350
+ if fid_lvl!=0:
351
+ x0 = self.net.vae_encode(cx, which='image').repeat(n_samples, 1, 1, 1)
352
+ step = int(self.ddim_steps * (1-fid_lvl))
353
+ x, _ = sampler.sample(
354
+ steps=self.ddim_steps,
355
+ x_info={'type':'image', 'x0':x0, 'x0_forward_timesteps':step},
356
+ c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u,
357
+ 'unconditional_guidance_scale':scale},
358
+ shape=shape,
359
+ verbose=False,
360
+ eta=self.ddim_eta)
361
+ else:
362
+ x, _ = sampler.sample(
363
+ steps=self.ddim_steps,
364
+ x_info={'type':'image',},
365
+ c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u,
366
+ 'unconditional_guidance_scale':scale},
367
+ shape=shape,
368
+ verbose=False,
369
+ eta=self.ddim_eta)
370
+
371
+ imout = self.net.vae_decode(x, which='image')
372
+
373
+ if clr_adj == 'Simple':
374
+ cx_mean = cx.view(3, -1).mean(-1)[:, None, None]
375
+ cx_std = cx.view(3, -1).std(-1)[:, None, None]
376
+ imout_mean = [imouti.view(3, -1).mean(-1)[:, None, None] for imouti in imout]
377
+ imout_std = [imouti.view(3, -1).std(-1)[:, None, None] for imouti in imout]
378
+ imout = [(ii-mi)/si*cx_std+cx_mean for ii, mi, si in zip(imout, imout_mean, imout_std)]
379
+ imout = [torch.clamp(ii, 0, 1) for ii in imout]
380
+
381
+ imout = [tvtrans.ToPILImage()(i) for i in imout]
382
+ return imout
383
+
384
+ def inference_i2t(self, im, seed):
385
+ n_samples = self.n_sample_text
386
+ scale = self.scale_imgto
387
+ sampler = self.sampler
388
+ h, w = self.output_dim
389
+ device = self.net.device
390
+
391
+ BICUBIC = PIL.Image.Resampling.BICUBIC
392
+ im = im.resize([w, h], resample=BICUBIC)
393
+
394
+ cx = tvtrans.ToTensor()(im)[None].to(device)
395
+ c = self.net.ctx_encode(cx, which='image').repeat(n_samples, 1, 1)
396
+ u = self.net.ctx_encode(torch.zeros_like(cx), which='image').repeat(n_samples, 1, 1)
397
+
398
+ shape = [n_samples, self.text_latent_dim]
399
+ np.random.seed(seed)
400
+ torch.manual_seed(seed + 100)
401
+ x, _ = sampler.sample(
402
+ steps=self.ddim_steps,
403
+ x_info={'type':'text',},
404
+ c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u,
405
+ 'unconditional_guidance_scale':scale},
406
+ shape=shape,
407
+ verbose=False,
408
+ eta=self.ddim_eta)
409
+ tx = self.net.vae_decode(x, which='text', temperature=self.text_temperature)
410
+ tx = [remove_duplicate_word(txi) for txi in tx]
411
+ tx_combined = '\n'.join(tx)
412
+ return tx_combined
413
+
414
+ def inference_t2t(self, text, seed):
415
+ n_samples = self.n_sample_text
416
+ scale = self.scale_textto
417
+ sampler = self.sampler
418
+ u = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1)
419
+ c = self.net.ctx_encode([text], which='text').repeat(n_samples, 1, 1)
420
+ shape = [n_samples, self.text_latent_dim]
421
+ np.random.seed(seed)
422
+ torch.manual_seed(seed + 100)
423
+ x, _ = sampler.sample(
424
+ steps=self.ddim_steps,
425
+ x_info={'type':'text',},
426
+ c_info={'type':'text', 'conditioning':c, 'unconditional_conditioning':u,
427
+ 'unconditional_guidance_scale':scale},
428
+ shape=shape,
429
+ verbose=False,
430
+ eta=self.ddim_eta)
431
+ tx = self.net.vae_decode(x, which='text', temperature=self.text_temperature)
432
+ tx = [remove_duplicate_word(txi) for txi in tx]
433
+ tx_combined = '\n'.join(tx)
434
+ return tx_combined
435
+
436
+ def inference_dcg(self, imctx, fcs_lvl, textctx, textstrength, seed):
437
+ n_samples = self.n_sample_image
438
+ sampler = self.sampler
439
+ h, w = self.output_dim
440
+ device = self.net.device
441
+
442
+ c_info_list = []
443
+
444
+ if (textctx is not None) and (textctx != "") and (textstrength != 0):
445
+ ut = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1)
446
+ ct = self.net.ctx_encode([textctx], which='text').repeat(n_samples, 1, 1)
447
+ scale = self.scale_imgto*(1-textstrength) + self.scale_textto*textstrength
448
+
449
+ c_info_list.append({
450
+ 'type':'text',
451
+ 'conditioning':ct,
452
+ 'unconditional_conditioning':ut,
453
+ 'unconditional_guidance_scale':scale,
454
+ 'ratio': textstrength, })
455
+ else:
456
+ scale = self.scale_imgto
457
+ textstrength = 0
458
+
459
+ BICUBIC = PIL.Image.Resampling.BICUBIC
460
+ cx = imctx.resize([w, h], resample=BICUBIC)
461
+ cx = tvtrans.ToTensor()(cx)[None].to(device).to(self.dtype)
462
+ ci = self.net.ctx_encode(cx, which='image')
463
+
464
+ if self.disentanglement_noglobal:
465
+ ci_glb = ci[:, 0:1]
466
+ ci_loc = ci[:, 1: ]
467
+ ci_loc = self.adjust_rank_f(ci_loc, fcs_lvl)
468
+ ci = torch.cat([ci_glb, ci_loc], dim=1).repeat(n_samples, 1, 1)
469
+ else:
470
+ ci = self.adjust_rank_f(ci, fcs_lvl).repeat(n_samples, 1, 1)
471
+
472
+ c_info_list.append({
473
+ 'type':'image',
474
+ 'conditioning':ci,
475
+ 'unconditional_conditioning':torch.zeros_like(ci),
476
+ 'unconditional_guidance_scale':scale,
477
+ 'ratio': (1-textstrength), })
478
+
479
+ shape = [n_samples, self.image_latent_dim, h//8, w//8]
480
+ np.random.seed(seed)
481
+ torch.manual_seed(seed + 100)
482
+ x, _ = sampler.sample_multicontext(
483
+ steps=self.ddim_steps,
484
+ x_info={'type':'image',},
485
+ c_info_list=c_info_list,
486
+ shape=shape,
487
+ verbose=False,
488
+ eta=self.ddim_eta)
489
+
490
+ imout = self.net.vae_decode(x, which='image')
491
+ imout = [tvtrans.ToPILImage()(i) for i in imout]
492
+ return imout
493
+
494
+ def inference_tcg(self, *args):
495
+ args_imag = list(args[0:10]) + [None, None, None, None, None]*2
496
+ args_rest = args[10:]
497
+ imin, imout = self.inference_mcg(*args_imag, *args_rest)
498
+ return imin, imout
499
+
500
+ def inference_mcg(self, *args):
501
+ imctx = [args[0:5], args[5:10], args[10:15], args[15:20]]
502
+ textctx, textstrength, seed = args[20:]
503
+
504
+ n_samples = self.n_sample_image
505
+ sampler = self.sampler
506
+ h, w = self.output_dim
507
+ device = self.net.device
508
+
509
+ c_info_list = []
510
+
511
+ if (textctx is not None) and (textctx != "") and (textstrength != 0):
512
+ ut = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1)
513
+ ct = self.net.ctx_encode([textctx], which='text').repeat(n_samples, 1, 1)
514
+ scale = self.scale_imgto*(1-textstrength) + self.scale_textto*textstrength
515
+
516
+ c_info_list.append({
517
+ 'type':'text',
518
+ 'conditioning':ct,
519
+ 'unconditional_conditioning':ut,
520
+ 'unconditional_guidance_scale':scale,
521
+ 'ratio': textstrength, })
522
+ else:
523
+ scale = self.scale_imgto
524
+ textstrength = 0
525
+
526
+ input_save = []
527
+ imc = []
528
+ for im, imm, strength, fcs_lvl, use_mask in imctx:
529
+ if (im is None) and (imm is None):
530
+ continue
531
+ BILINEAR = PIL.Image.Resampling.BILINEAR
532
+ BICUBIC = PIL.Image.Resampling.BICUBIC
533
+ if use_mask:
534
+ cx = imm['image'].resize([w, h], resample=BICUBIC)
535
+ cx = tvtrans.ToTensor()(cx)[None].to(self.dtype).to(device)
536
+ m = imm['mask'].resize([w, h], resample=BILINEAR)
537
+ m = tvtrans.ToTensor()(m)[None, 0:1].to(self.dtype).to(device)
538
+ m = (1-m)
539
+ cx_show = cx*m
540
+ ci = self.net.ctx_encode(cx, which='image', masks=m)
541
+ else:
542
+ cx = im.resize([w, h], resample=BICUBIC)
543
+ cx = tvtrans.ToTensor()(cx)[None].to(self.dtype).to(device)
544
+ ci = self.net.ctx_encode(cx, which='image')
545
+ cx_show = cx
546
+
547
+ input_save.append(tvtrans.ToPILImage()(cx_show[0]))
548
+
549
+ if self.disentanglement_noglobal:
550
+ ci_glb = ci[:, 0:1]
551
+ ci_loc = ci[:, 1: ]
552
+ ci_loc = self.adjust_rank_f(ci_loc, fcs_lvl)
553
+ ci = torch.cat([ci_glb, ci_loc], dim=1).repeat(n_samples, 1, 1)
554
+ else:
555
+ ci = self.adjust_rank_f(ci, fcs_lvl).repeat(n_samples, 1, 1)
556
+ imc.append(ci * strength)
557
+
558
+ cis = torch.cat(imc, dim=1)
559
+ c_info_list.append({
560
+ 'type':'image',
561
+ 'conditioning':cis,
562
+ 'unconditional_conditioning':torch.zeros_like(cis),
563
+ 'unconditional_guidance_scale':scale,
564
+ 'ratio': (1-textstrength), })
565
+
566
+ shape = [n_samples, self.image_latent_dim, h//8, w//8]
567
+ np.random.seed(seed)
568
+ torch.manual_seed(seed + 100)
569
+ x, _ = sampler.sample_multicontext(
570
+ steps=self.ddim_steps,
571
+ x_info={'type':'image',},
572
+ c_info_list=c_info_list,
573
+ shape=shape,
574
+ verbose=False,
575
+ eta=self.ddim_eta)
576
+
577
+ imout = self.net.vae_decode(x, which='image')
578
+ imout = [tvtrans.ToPILImage()(i) for i in imout]
579
+ return input_save, imout
580
+
581
+ # vd_inference = vd_dummy()
582
+ vd_inference = vd_inference(which='v1.0', fp16=True)
583
+
584
+ #################
585
+ # sub interface #
586
+ #################
587
+
588
+ def t2i_interface(with_example=False):
589
+ gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Text-to-Image") + '</p>')
590
+ with gr.Row():
591
+ with gr.Column():
592
+ text = gr.Textbox(lines=4, placeholder="Input prompt...", label='Text Input')
593
+ seed = gr.Number(20, label="Seed", precision=0)
594
+ button = gr.Button("Run")
595
+ with gr.Column():
596
+ img_output = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image)
597
+
598
+ button.click(
599
+ vd_inference.inference_t2i,
600
+ inputs=[text, seed],
601
+ outputs=[img_output])
602
+
603
+ if with_example:
604
+ gr.Examples(
605
+ label='Examples',
606
+ examples=get_example('Text-to-Image'),
607
+ fn=vd_inference.inference_t2i,
608
+ inputs=[text, seed],
609
+ outputs=[img_output],
610
+ cache_examples=cache_examples),
611
+
612
+ def i2i_interface(with_example=False):
613
+ gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Image-Variation") + '</p>')
614
+ with gr.Row():
615
+ with gr.Column():
616
+ img_input = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
617
+ sim_flag = gr.Checkbox(label='Show Detail Controls')
618
+ with gr.Row():
619
+ fid_lvl = gr.Slider(label="Fidelity (Dislike -- Same)", minimum=0, maximum=1, value=0, step=0.02, visible=False)
620
+ fcs_lvl = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02, visible=False)
621
+ clr_adj = gr.Radio(label="Color Adjustment", choices=["None", "Simple"], value='Simple', visible=False)
622
+ explain = gr.HTML('<p id=myinst>&nbsp Fidelity: How likely the output image looks like the referece image (0-dislike (default), 1-same).</p>'+
623
+ '<p id=myinst>&nbsp Focus: What the output image should focused on (0-semantic, 0.5-balanced (default), 1-style).</p>',
624
+ visible=False)
625
+ seed = gr.Number(20, label="Seed", precision=0)
626
+ button = gr.Button("Run")
627
+ with gr.Column():
628
+ img_output = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image)
629
+
630
+ sim_flag.change(
631
+ fn=lambda x: {
632
+ explain : gr.update(visible=x),
633
+ fid_lvl : gr.update(visible=x),
634
+ fcs_lvl : gr.update(visible=x),
635
+ clr_adj : gr.update(visible=x), },
636
+ inputs=sim_flag,
637
+ outputs=[explain, fid_lvl, fcs_lvl, clr_adj, seed],)
638
+
639
+ button.click(
640
+ vd_inference.inference_i2i,
641
+ inputs=[img_input, fid_lvl, fcs_lvl, clr_adj, seed],
642
+ outputs=[img_output])
643
+
644
+ if with_example:
645
+ gr.Examples(
646
+ label='Examples',
647
+ examples=get_example('Image-Variation'),
648
+ fn=vd_inference.inference_i2i,
649
+ inputs=[img_input, fid_lvl, fcs_lvl, clr_adj, seed],
650
+ outputs=[img_output],
651
+ cache_examples=cache_examples),
652
+
653
+ def i2t_interface(with_example=False):
654
+ gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Image-to-Text") + '</p>')
655
+ with gr.Row():
656
+ with gr.Column():
657
+ img_input = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
658
+ seed = gr.Number(20, label="Seed", precision=0)
659
+ button = gr.Button("Run")
660
+ with gr.Column():
661
+ txt_output = gr.Textbox(lines=4, label='Text Result')
662
+
663
+ button.click(
664
+ vd_inference.inference_i2t,
665
+ inputs=[img_input, seed],
666
+ outputs=[txt_output])
667
+
668
+ if with_example:
669
+ gr.Examples(
670
+ label='Examples',
671
+ examples=get_example('Image-to-Text'),
672
+ fn=vd_inference.inference_i2t,
673
+ inputs=[img_input, seed],
674
+ outputs=[txt_output],
675
+ cache_examples=cache_examples),
676
+
677
+ def t2t_interface(with_example=False):
678
+ gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Text-Variation") + '</p>')
679
+ with gr.Row():
680
+ with gr.Column():
681
+ text = gr.Textbox(lines=4, placeholder="Input prompt...", label='Text Input')
682
+ seed = gr.Number(20, label="Seed", precision=0)
683
+ button = gr.Button("Run")
684
+ with gr.Column():
685
+ txt_output = gr.Textbox(lines=4, label='Text Result')
686
+
687
+ button.click(
688
+ vd_inference.inference_t2t,
689
+ inputs=[text, seed],
690
+ outputs=[txt_output])
691
+
692
+ if with_example:
693
+ gr.Examples(
694
+ label='Examples',
695
+ examples=get_example('Text-Variation'),
696
+ fn=vd_inference.inference_t2t,
697
+ inputs=[text, seed],
698
+ outputs=[txt_output],
699
+ cache_examples=cache_examples, )
700
+
701
+ class image_mimage_swap(object):
702
+ def __init__(self, block0, block1):
703
+ self.block0 = block0
704
+ self.block1 = block1
705
+ self.which_update = 'both'
706
+
707
+ def __call__(self, x0, x1, flag):
708
+ if self.which_update == 'both':
709
+ return self.update_both(x0, x1, flag)
710
+ elif self.which_update == 'visible':
711
+ return self.update_visible(x0, x1, flag)
712
+ elif self.which_update == 'visible_oneoff':
713
+ return self.update_visible_oneoff(x0, x1, flag)
714
+ else:
715
+ assert False
716
+
717
+ def update_both(self, x0, x1, flag):
718
+ if flag:
719
+ ug0 = gr.update(visible=False)
720
+ if x0 is None:
721
+ ug1 = gr.update(value=None, visible=True)
722
+ else:
723
+ if (x1 is not None) and ('mask' in x1):
724
+ value1 = {'image':x0, 'mask':x1['mask']}
725
+ else:
726
+ value1 = {'image':x0, 'mask':None}
727
+ ug1 = gr.update(value=value1, visible=True)
728
+ else:
729
+ if (x1 is not None) and ('image' in x1):
730
+ value0 = x1['image']
731
+ else:
732
+ value0 = None
733
+ ug0 = gr.update(value=value0, visible=True)
734
+ ug1 = gr.update(visible=False)
735
+ return {
736
+ self.block0 : ug0,
737
+ self.block1 : ug1,}
738
+
739
+ def update_visible(self, x0, x1, flag):
740
+ return {
741
+ self.block0 : gr.update(visible=not flag),
742
+ self.block1 : gr.update(visible=flag), }
743
+
744
+ def update_visible_oneoff(self, x0, x1, flag):
745
+ self.which_update = 'both'
746
+ return {
747
+ self.block0 : gr.update(visible=not flag),
748
+ self.block1 : gr.update(visible=flag), }
749
+
750
+ class example_visible_only_hack(object):
751
+ def __init__(self, checkbox_list, functor_list):
752
+ self.checkbox_list = checkbox_list
753
+ self.functor_list = functor_list
754
+
755
+ def __call__(self, *args):
756
+ for bi, fi, vi in zip(self.checkbox_list, self.functor_list, args):
757
+ if bi.value != vi:
758
+ fi.which_update = 'visible_oneoff'
759
+
760
+ def dcg_interface(with_example=False):
761
+ gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Dual-Context") + '</p>')
762
+ with gr.Row():
763
+ input_session = []
764
+ with gr.Column():
765
+ img = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
766
+ fcs = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02)
767
+ gr.HTML('<p id=myinst>&nbsp Focus: Focus on what aspect of the image? (0-semantic, 0.5-balanced (default), 1-style).</p>')
768
+
769
+ text = gr.Textbox(lines=2, placeholder="Input prompt...", label='Text Input')
770
+ tstrength = gr.Slider(label="Text Domination (NoEffect -- TextOnly)", minimum=0, maximum=1, value=0, step=0.02)
771
+
772
+ seed = gr.Number(20, label="Seed", precision=0)
773
+ button = gr.Button("Run")
774
+
775
+ with gr.Column():
776
+ output_gallary = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image)
777
+
778
+ input_list = []
779
+ for i in input_session:
780
+ input_list += i
781
+ button.click(
782
+ vd_inference.inference_dcg,
783
+ inputs=[img, fcs, text, tstrength, seed],
784
+ outputs=[output_gallary])
785
+
786
+ if with_example:
787
+ gr.Examples(
788
+ label='Examples',
789
+ examples=get_example('Dual-Context'),
790
+ fn=vd_inference.inference_dcg,
791
+ inputs=[img, fcs, text, tstrength, seed],
792
+ outputs=[output_gallary],
793
+ cache_examples=cache_examples)
794
+
795
+ def tcg_interface(with_example=False):
796
+ gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Triple-Context") + '</p>')
797
+ with gr.Row():
798
+ input_session = []
799
+ with gr.Column(min_width=940):
800
+ with gr.Row():
801
+ with gr.Column():
802
+ img0 = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
803
+ img0.as_example = types.MethodType(customized_as_example, img0)
804
+ imgm0 = gr.Image(label='Image Input with Mask', type='pil', elem_id='customized_imbox', tool='sketch', source="upload", visible=False)
805
+ imgm0.postprocess = types.MethodType(customized_postprocess, imgm0)
806
+ imgm0.as_example = types.MethodType(customized_as_example, imgm0)
807
+ istrength0 = gr.Slider(label="Weight", minimum=0, maximum=1, value=1, step=0.02)
808
+ fcs0 = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02)
809
+ msk0 = gr.Checkbox(label='Use mask?')
810
+ swapf0 = image_mimage_swap(img0, imgm0)
811
+
812
+ msk0.change(
813
+ fn=swapf0,
814
+ inputs=[img0, imgm0, msk0],
815
+ outputs=[img0, imgm0],)
816
+ input_session.append([img0, imgm0, istrength0, fcs0, msk0])
817
+
818
+ with gr.Column():
819
+ img1 = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
820
+ img1.as_example = types.MethodType(customized_as_example, img1)
821
+ imgm1 = gr.Image(label='Image Input with Mask', type='pil', elem_id='customized_imbox', tool='sketch', source="upload", visible=False)
822
+ imgm1.postprocess = types.MethodType(customized_postprocess, imgm1)
823
+ imgm1.as_example = types.MethodType(customized_as_example, imgm1)
824
+ istrength1 = gr.Slider(label="Weight", minimum=0, maximum=1, value=1, step=0.02)
825
+ fcs1 = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02)
826
+ msk1 = gr.Checkbox(label='Use mask?')
827
+ swapf1 = image_mimage_swap(img1, imgm1)
828
+
829
+ msk1.change(
830
+ fn=swapf1,
831
+ inputs=[img1, imgm1, msk1],
832
+ outputs=[img1, imgm1],)
833
+ input_session.append([img1, imgm1, istrength1, fcs1, msk1])
834
+
835
+ gr.HTML('<p id=myinst>&nbsp Weight: The strength of the reference image. This weight is subject to <u>Text Domination</u>).</p>'+
836
+ '<p id=myinst>&nbsp Focus: Focus on what aspect of the image? (0-semantic, 0.5-balanced (default), 1-style).</p>'+
837
+ '<p id=myinst>&nbsp Mask: Remove regions on reference image so they will not influence the output.</p>',)
838
+
839
+ text = gr.Textbox(lines=2, placeholder="Input prompt...", label='Text Input')
840
+ tstrength = gr.Slider(label="Text Domination (NoEffect -- TextOnly)", minimum=0, maximum=1, value=0, step=0.02)
841
+
842
+ seed = gr.Number(20, label="Seed", precision=0)
843
+ button = gr.Button("Run")
844
+
845
+ with gr.Column(min_width=470):
846
+ input_gallary = gr.Gallery(label="Input Display", elem_id="customized_imbox").style(grid=2)
847
+ output_gallary = gr.Gallery(label="Image Result", elem_id="customized_imbox").style(grid=n_sample_image)
848
+
849
+ input_list = []
850
+ for i in input_session:
851
+ input_list += i
852
+ input_list += [text, tstrength, seed]
853
+ button.click(
854
+ vd_inference.inference_tcg,
855
+ inputs=input_list,
856
+ outputs=[input_gallary, output_gallary])
857
+
858
+ if with_example:
859
+ create_myexamples(
860
+ label='Examples',
861
+ examples=get_example('Triple-Context'),
862
+ fn=vd_inference.inference_tcg,
863
+ inputs=input_list,
864
+ outputs=[input_gallary, output_gallary, ],
865
+ cache_examples=cache_examples, )
866
+
867
+ gr.HTML('<br><p id=myinst>&nbsp How to add mask: Please see the following instructions.</p><br>'+
868
+ '<div id="maskinst">'+
869
+ '<img src="file/assets/demo/misc/mask_inst1.gif">'+
870
+ '<img src="file/assets/demo/misc/mask_inst2.gif">'+
871
+ '<img src="file/assets/demo/misc/mask_inst3.gif">'+
872
+ '</div>')
873
+
874
+ def mcg_interface(with_example=False):
875
+ num_img_input = 4
876
+ gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Multi-Context") + '</p>')
877
+ with gr.Row():
878
+ input_session = []
879
+ with gr.Column():
880
+ for idx in range(num_img_input):
881
+ with gr.Tab('Image{}'.format(idx+1)):
882
+ img = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
883
+ img.as_example = types.MethodType(customized_as_example, img)
884
+ imgm = gr.Image(label='Image Input with Mask', type='pil', elem_id='customized_imbox', tool='sketch', source="upload", visible=False)
885
+ imgm.postprocess = types.MethodType(customized_postprocess, imgm)
886
+ imgm.as_example = types.MethodType(customized_as_example, imgm)
887
+
888
+ with gr.Row():
889
+ istrength = gr.Slider(label="Weight", minimum=0, maximum=1, value=1, step=0.02)
890
+ fcs = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02)
891
+ msk = gr.Checkbox(label='Use mask?')
892
+ gr.HTML('<p id=myinst>&nbsp Weight: The strength of the reference image. This weight is subject to <u>Text Domination</u>).</p>'+
893
+ '<p id=myinst>&nbsp Focus: Focus on what aspect of the image? (0-semantic, 0.5-balanced (default), 1-style).</p>'+
894
+ '<p id=myinst>&nbsp Mask: Remove regions on reference image so they will not influence the output.</p>',)
895
+
896
+ msk.change(
897
+ fn=image_mimage_swap(img, imgm),
898
+ inputs=[img, imgm, msk],
899
+ outputs=[img, imgm],)
900
+ input_session.append([img, imgm, istrength, fcs, msk])
901
+
902
+ text = gr.Textbox(lines=2, placeholder="Input prompt...", label='Text Input')
903
+ tstrength = gr.Slider(label="Text Domination (NoEffect -- TextOnly)", minimum=0, maximum=1, value=0, step=0.02)
904
+
905
+ seed = gr.Number(20, label="Seed", precision=0)
906
+ button = gr.Button("Run")
907
+
908
+
909
+ with gr.Column():
910
+ input_gallary = gr.Gallery(label="Input Display", elem_id='customized_imbox').style(grid=4)
911
+ output_gallary = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image)
912
+
913
+ input_list = []
914
+ for i in input_session:
915
+ input_list += i
916
+ input_list += [text, tstrength, seed]
917
+ button.click(
918
+ vd_inference.inference_mcg,
919
+ inputs=input_list,
920
+ outputs=[input_gallary, output_gallary], )
921
+
922
+ if with_example:
923
+ create_myexamples(
924
+ label='Examples',
925
+ examples=get_example('Multi-Context'),
926
+ fn=vd_inference.inference_mcg,
927
+ inputs=input_list,
928
+ outputs=[input_gallary, output_gallary],
929
+ cache_examples=cache_examples, )
930
+
931
+ gr.HTML('<br><p id=myinst>&nbsp How to add mask: Please see the following instructions.</p><br>'+
932
+ '<div id="maskinst">'+
933
+ '<img src="file/assets/demo/misc/mask_inst1.gif">'+
934
+ '<img src="file/assets/demo/misc/mask_inst2.gif">'+
935
+ '<img src="file/assets/demo/misc/mask_inst3.gif">'+
936
+ '</div>')
937
+
938
+ ###########
939
+ # Example #
940
+ ###########
941
+
942
+ def get_example(mode):
943
+ if mode == 'Text-to-Image':
944
+ case = [
945
+ ['a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ', 23],
946
+ ['a beautiful landscape with mountains and rivers', 20],
947
+ ]
948
+ elif mode == "Image-Variation":
949
+ case = [
950
+ ['assets/demo/reg_example/ghibli.jpg', 0, 0.5, 'None', 20],
951
+ ['assets/demo/reg_example/ghibli.jpg', 0.5, 0.5, 'None', 20],
952
+ ['assets/demo/reg_example/matisse.jpg', 0, 0, 'None', 20],
953
+ ['assets/demo/reg_example/matisse.jpg', 0, 1, 'Simple', 20],
954
+ ['assets/demo/reg_example/vermeer.jpg', 0.2, 0.3, 'None', 30],
955
+ ]
956
+ elif mode == "Image-to-Text":
957
+ case = [
958
+ ['assets/demo/reg_example/house_by_lake.jpg', 20],
959
+ ]
960
+ elif mode == "Text-Variation":
961
+ case = [
962
+ ['heavy arms gundam penguin mech', 20],
963
+ ]
964
+ elif mode == "Dual-Context":
965
+ case = [
966
+ ['assets/demo/reg_example/benz.jpg', 0.5, 'cyberpunk 2077', 0.7, 22],
967
+ ['assets/demo/reg_example/ghibli.jpg', 1, 'Red maple on a hill in golden Autumn.', 0.5, 21],
968
+ ]
969
+ elif mode == "Triple-Context":
970
+ case = [
971
+ [
972
+ 'assets/demo/reg_example/night_light.jpg', None, 1 , 0.5, False,
973
+ 'assets/demo/reg_example/paris.jpg' , None, 0.94, 0.5, False,
974
+ "snow on the street", 0.4, 28],
975
+ [
976
+ 'assets/demo/tcg_example/e1i0.jpg', None, 1 , 0.5, False,
977
+ 'assets/demo/tcg_example/e1i1.jpg', None, 0.94, 0.5, False,
978
+ "a painting of an elegant woman in front of the moon", 0.2, 217],
979
+ [
980
+ 'assets/demo/tcg_example/e2i0.jpg', None, 1, 0.5, False,
981
+ 'assets/demo/reg_example/paris.jpg', None, 1, 0.5, False,
982
+ "", 0, 29],
983
+ [
984
+ 'assets/demo/tcg_example/e0i0.jpg', None, 1 , 0.5, False,
985
+ 'assets/demo/tcg_example/e0i1.jpg', None, 0.9, 0.5, False,
986
+ "rose blooms on the tree", 0.2, 20],
987
+ [
988
+ 'assets/demo/reg_example/ghibli.jpg', None, 1 , 1 , False,
989
+ 'assets/demo/reg_example/space.jpg' , None, 0.88, 0.5, False,
990
+ "", 0, 20],
991
+ [
992
+ 'assets/demo/reg_example/train.jpg' , None, 0.8, 0.5, False,
993
+ 'assets/demo/reg_example/matisse.jpg', None, 1 , 1 , False,
994
+ "", 0, 20],
995
+ ]
996
+ elif mode == "Multi-Context":
997
+ case = [
998
+ [
999
+ 'assets/demo/mcg_example/e0i0.jpg', None, 1, 0.5, False,
1000
+ 'assets/demo/mcg_example/e0i1.jpg', None, 1, 0.5, False,
1001
+ 'assets/demo/mcg_example/e0i2.jpg', None, 0.86, 0.5, False,
1002
+ None, None, 1, 0.5, False,
1003
+ "", 0, 20],
1004
+ ]
1005
+ else:
1006
+ raise ValueError
1007
+ return case
1008
+
1009
+ #############
1010
+ # Interface #
1011
+ #############
1012
+
1013
+ css = """
1014
+ #customized_imbox {
1015
+ min-height: 450px;
1016
+ }
1017
+ #customized_imbox>div[data-testid="image"] {
1018
+ min-height: 450px;
1019
+ }
1020
+ #customized_imbox>div[data-testid="image"]>div {
1021
+ min-height: 450px;
1022
+ }
1023
+ #customized_imbox>div[data-testid="image"]>iframe {
1024
+ min-height: 450px;
1025
+ }
1026
+ #customized_imbox>div.unpadded_box {
1027
+ min-height: 450px;
1028
+ }
1029
+ #myinst {
1030
+ font-size: 0.8rem;
1031
+ margin: 0rem;
1032
+ color: #6B7280;
1033
+ }
1034
+ #maskinst {
1035
+ text-align: justify;
1036
+ min-width: 1200px;
1037
+ }
1038
+ #maskinst>img {
1039
+ min-width:399px;
1040
+ max-width:450px;
1041
+ vertical-align: top;
1042
+ display: inline-block;
1043
+ }
1044
+ #maskinst:after {
1045
+ content: "";
1046
+ width: 100%;
1047
+ display: inline-block;
1048
+ }
1049
+ """
1050
+
1051
+ if True:
1052
+ with gr.Blocks(css=css) as demo:
1053
+ gr.HTML(
1054
+ """
1055
+ <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
1056
+ <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
1057
+ Versatile Diffusion
1058
+ </h1>
1059
+ <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
1060
+ We built <b>Versatile Diffusion (VD), the first unified multi-flow multimodal diffusion framework</b>, as a step towards <b>Universal Generative AI</b>.
1061
+ VD can natively support image-to-text, image-variation, text-to-image, and text-variation,
1062
+ and can be further extended to other applications such as
1063
+ semantic-style disentanglement, image-text dual-guided generation, latent image-to-text-to-image editing, and more.
1064
+ Future versions will support more modalities such as speech, music, video and 3D.
1065
+ </h2>
1066
+ <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
1067
+ Xingqian Xu, Atlas Wang, Eric Zhang, Kai Wang,
1068
+ and <a href="https://www.humphreyshi.com/home">Humphrey Shi</a>
1069
+ [<a href="https://arxiv.org/abs/2211.08332" style="color:blue;">arXiv</a>]
1070
+ [<a href="https://github.com/SHI-Labs/Versatile-Diffusion" style="color:blue;">GitHub</a>]
1071
+ </h3>
1072
+ </div>
1073
+ """)
1074
+
1075
+ with gr.Tab('Text-to-Image'):
1076
+ t2i_interface(with_example=True)
1077
+ with gr.Tab('Image-Variation'):
1078
+ i2i_interface(with_example=True)
1079
+ with gr.Tab('Image-to-Text'):
1080
+ i2t_interface(with_example=True)
1081
+ with gr.Tab('Text-Variation'):
1082
+ t2t_interface(with_example=True)
1083
+ with gr.Tab('Dual-Context Image-Generation'):
1084
+ dcg_interface(with_example=True)
1085
+ with gr.Tab('Triple-Context Image-Blender'):
1086
+ tcg_interface(with_example=True)
1087
+ with gr.Tab('Multi-Context Image-Blender'):
1088
+ mcg_interface(with_example=True)
1089
+
1090
+ gr.HTML(
1091
+ """
1092
+ <div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
1093
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
1094
+ <b>Version</b>: {}
1095
+ </h3>
1096
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
1097
+ <b>Caution</b>:
1098
+ We would like the raise the awareness of users of this demo of its potential issues and concerns.
1099
+ Like previous large foundation models, Versatile Diffusion could be problematic in some cases, partially due to the imperfect training data and pretrained network (VAEs / context encoders) with limited scope.
1100
+ In its future research phase, VD may do better on tasks such as text-to-image, image-to-text, etc., with the help of more powerful VAEs, more sophisticated network designs, and more cleaned data.
1101
+ So far, we keep all features available for research testing both to show the great potential of the VD framework and to collect important feedback to improve the model in the future.
1102
+ We welcome researchers and users to report issues with the HuggingFace community discussion feature or email the authors.
1103
+ </h3>
1104
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
1105
+ <b>Biases and content acknowledgement</b>:
1106
+ Beware that VD may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence.
1107
+ VD was trained on the LAION-2B dataset, which scraped non-curated online images and text, and may contained unintended exceptions as we removed illegal content.
1108
+ VD in this demo is meant only for research purposes.
1109
+ </h3>
1110
+ </div>
1111
+ """.format(' '+vd_inference.which))
1112
+
1113
+ # demo.launch(share=True)
1114
+ demo.launch(debug=True)
assets/demo/mcg_example/e0i0.jpg ADDED
assets/demo/mcg_example/e0i1.jpg ADDED
assets/demo/mcg_example/e0i2.jpg ADDED
assets/demo/misc/mask_inst1.gif ADDED

Git LFS Details

  • SHA256: 90732a23a9a275649068654ae0c29418ea28ffb45eef6605da6d42e77390e808
  • Pointer size: 132 Bytes
  • Size of remote file: 5.23 MB
assets/demo/misc/mask_inst2.gif ADDED

Git LFS Details

  • SHA256: 183544affa3f5c76cf347e25d991a87e0eeb426b042f70ea33ef9acc6217d53f
  • Pointer size: 132 Bytes
  • Size of remote file: 5.82 MB
assets/demo/misc/mask_inst3.gif ADDED

Git LFS Details

  • SHA256: 6136887307c45b86ce451eff1102c7e996d46a107795439b3f35e4391d348b30
  • Pointer size: 132 Bytes
  • Size of remote file: 5.53 MB
assets/demo/misc/noimage.jpg ADDED
assets/demo/reg_example/benz.jpg ADDED
assets/demo/reg_example/boy_and_girl.jpg ADDED
assets/demo/reg_example/church.jpg ADDED
assets/demo/reg_example/firework.jpg ADDED
assets/demo/reg_example/ghibli.jpg ADDED
assets/demo/reg_example/horse.jpg ADDED
assets/demo/reg_example/house_by_lake.jpg ADDED
assets/demo/reg_example/matisse.jpg ADDED
assets/demo/reg_example/night_light.jpg ADDED
assets/demo/reg_example/noimage.jpg ADDED
assets/demo/reg_example/paris.jpg ADDED
assets/demo/reg_example/penguin.jpg ADDED
assets/demo/reg_example/san_diego.jpg ADDED
assets/demo/reg_example/scream.jpg ADDED
assets/demo/reg_example/space.jpg ADDED
assets/demo/reg_example/tiger.jpg ADDED
assets/demo/reg_example/train.jpg ADDED
assets/demo/reg_example/vermeer.jpg ADDED
assets/demo/tcg_example/e0i0.jpg ADDED
assets/demo/tcg_example/e0i1.jpg ADDED
assets/demo/tcg_example/e1i0.jpg ADDED
assets/demo/tcg_example/e1i1.jpg ADDED
assets/demo/tcg_example/e2i0.jpg ADDED
assets/figures/share_instruction.png ADDED
configs/model/autokl.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ autokl:
2
+ symbol: autokl
3
+ find_unused_parameters: false
4
+
5
+ autokl_v1:
6
+ super_cfg: autokl
7
+ type: autoencoderkl
8
+ args:
9
+ embed_dim: 4
10
+ ddconfig:
11
+ double_z: true
12
+ z_channels: 4
13
+ resolution: 256
14
+ in_channels: 3
15
+ out_ch: 3
16
+ ch: 128
17
+ ch_mult: [1, 2, 4, 4]
18
+ num_res_blocks: 2
19
+ attn_resolutions: []
20
+ dropout: 0.0
21
+ lossconfig: null
22
+ # pth: pretrained/kl-f8.pth
23
+ hfm: ['shi-labs/versatile-diffusion-model', 'pretrained_pth/kl-f8.pth']
configs/model/clip.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ clip:
2
+ symbol: clip
3
+ args: {}
4
+
5
+ clip_text_context_encoder:
6
+ super_cfg: clip
7
+ type: clip_text_context_encoder
8
+ args: {}
9
+
10
+ clip_image_context_encoder:
11
+ super_cfg: clip
12
+ type: clip_image_context_encoder
13
+ args: {}
configs/model/openai_unet.yaml ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #########
2
+ # v1 2d #
3
+ #########
4
+
5
+ openai_unet_2d_v1:
6
+ type: openai_unet_2d_next
7
+ args:
8
+ in_channels: 4
9
+ out_channels: 4
10
+ model_channels: 320
11
+ attention_resolutions: [ 4, 2, 1 ]
12
+ num_res_blocks: [ 2, 2, 2, 2 ]
13
+ channel_mult: [ 1, 2, 4, 4 ]
14
+ num_heads: 8
15
+ context_dim: 768
16
+ use_checkpoint: True
17
+ parts: [global, data, context]
18
+
19
+ openai_unet_2d_v1_g:
20
+ super_cfg: openai_unet_2d_v1
21
+ args:
22
+ parts: [global]
23
+
24
+ openai_unet_2d_v1_d:
25
+ super_cfg: openai_unet_2d_v1
26
+ args:
27
+ parts: [data]
28
+
29
+ openai_unet_2d_v1_c:
30
+ super_cfg: openai_unet_2d_v1
31
+ args:
32
+ parts: [context]
33
+
34
+ openai_unet_2d_v1_gd:
35
+ super_cfg: openai_unet_2d_v1
36
+ args:
37
+ parts: [global, data]
38
+
39
+ openai_unet_2d_v1_gc:
40
+ super_cfg: openai_unet_2d_v1
41
+ args:
42
+ parts: [global, context]
43
+
44
+ openai_unet_2d_v1_dc:
45
+ super_cfg: openai_unet_2d_v1
46
+ args:
47
+ parts: [data, context]
48
+
49
+ #########
50
+ # v1 0d #
51
+ #########
52
+
53
+ openai_unet_0d_v1:
54
+ type: openai_unet_0d_next
55
+ args:
56
+ input_channels: 768
57
+ model_channels: 320
58
+ output_channels: 768
59
+ num_noattn_blocks: [ 2, 2, 2, 2 ]
60
+ channel_mult: [ 1, 2, 4, 4 ]
61
+ second_dim: [ 4, 4, 4, 4 ]
62
+ with_attn: [true, true, true, false]
63
+ num_heads: 8
64
+ context_dim: 768
65
+ use_checkpoint: True
66
+ parts: [global, data, context]
67
+
68
+ openai_unet_0d_v1_g:
69
+ super_cfg: openai_unet_0d_v1
70
+ args:
71
+ parts: [global]
72
+
73
+ openai_unet_0d_v1_d:
74
+ super_cfg: openai_unet_0d_v1
75
+ args:
76
+ parts: [data]
77
+
78
+ openai_unet_0d_v1_c:
79
+ super_cfg: openai_unet_0d_v1
80
+ args:
81
+ parts: [context]
82
+
83
+ openai_unet_0d_v1_gd:
84
+ super_cfg: openai_unet_0d_v1
85
+ args:
86
+ parts: [global, data]
87
+
88
+ openai_unet_0d_v1_gc:
89
+ super_cfg: openai_unet_0d_v1
90
+ args:
91
+ parts: [global, context]
92
+
93
+ openai_unet_0d_v1_dc:
94
+ super_cfg: openai_unet_0d_v1
95
+ args:
96
+ parts: [data, context]
configs/model/optimus.yaml ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ optimus:
3
+ symbol: optimus
4
+ find_unused_parameters: false
5
+ args: {}
6
+
7
+ optimus_bert_encoder:
8
+ super_cfg: optimus
9
+ type: optimus_bert_connector
10
+ # pth: pretrained/optimus_bert_encoder.pth
11
+ args:
12
+ config:
13
+ architectures:
14
+ - BertForMaskedLM
15
+ attention_probs_dropout_prob: 0.1
16
+ finetuning_task: null
17
+ hidden_act: gelu
18
+ hidden_dropout_prob: 0.1
19
+ hidden_size: 768
20
+ initializer_range: 0.02
21
+ intermediate_size: 3072
22
+ layer_norm_eps: 1.e-12
23
+ max_position_embeddings: 512
24
+ num_attention_heads: 12
25
+ num_hidden_layers: 12
26
+ num_labels: 2
27
+ output_attentions: false
28
+ output_hidden_states: false
29
+ pruned_heads: {}
30
+ torchscript: false
31
+ type_vocab_size: 2
32
+ vocab_size: 28996
33
+ latent_size: 768
34
+
35
+ optimus_bert_tokenizer:
36
+ super_cfg: optimus
37
+ type: optimus_bert_tokenizer
38
+ args:
39
+ do_lower_case: false
40
+ max_len: 512
41
+ vocab_file: lib/model_zoo/optimus_models/vocab/bert-base-cased-vocab.txt
42
+
43
+ optimus_gpt2_decoder:
44
+ super_cfg: optimus
45
+ type: optimus_gpt2_connector
46
+ # pth: pretrained/optimus_gpt2_decoder.pth
47
+ args:
48
+ config:
49
+ architectures:
50
+ - GPT2LMHeadModel
51
+ attn_pdrop: 0.1
52
+ embd_pdrop: 0.1
53
+ finetuning_task: null
54
+ hidden_size: 768
55
+ initializer_range: 0.02
56
+ latent_size: 768
57
+ layer_norm_epsilon: 1.e-05
58
+ max_position_embeddings: 1024
59
+ n_ctx: 1024
60
+ n_embd: 768
61
+ n_head: 12
62
+ n_layer: 12
63
+ n_positions: 1024
64
+ num_attention_heads: 12
65
+ num_hidden_layers: 12
66
+ num_labels: 1
67
+ output_attentions: false
68
+ output_hidden_states: false
69
+ pretrained_config_archive_map:
70
+ gpt2 : https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json
71
+ gpt2-medium : https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json
72
+ gpt2-large : https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json
73
+ pruned_heads: {}
74
+ resid_pdrop: 0.1
75
+ summary_activation: null
76
+ summary_first_dropout: 0.1
77
+ summary_proj_to_labels: true
78
+ summary_type: cls_index
79
+ summary_use_proj: true
80
+ torchscript: false
81
+ vocab_size: 50260
82
+
83
+ optimus_gpt2_tokenizer:
84
+ super_cfg: optimus
85
+ type: optimus_gpt2_tokenizer
86
+ args:
87
+ do_lower_case: false
88
+ max_len: 1024
89
+ vocab_file: lib/model_zoo/optimus_models/vocab/gpt2-vocab.json
90
+ merges_file: lib/model_zoo/optimus_models/vocab/gpt2-merges.txt
91
+
92
+ optimus_v1:
93
+ super_cfg: optimus
94
+ type: optimus_vae_next
95
+ args:
96
+ encoder: MODEL(optimus_bert_encoder)
97
+ decoder: MODEL(optimus_gpt2_decoder)
98
+ tokenizer_encoder: MODEL(optimus_bert_tokenizer)
99
+ tokenizer_decoder: MODEL(optimus_gpt2_tokenizer)
100
+ args:
101
+ latent_size: 768
102
+ # pth: pretrained/optimus-vae.pth
103
+ hfm: ['shi-labs/versatile-diffusion-model', 'pretrained_pth/optimus-vae.pth']
configs/model/vd.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ vd_base:
2
+ symbol: vd
3
+ find_unused_parameters: true
4
+ type: vd_v2_0
5
+ args:
6
+ beta_linear_start: 0.00085
7
+ beta_linear_end: 0.012
8
+ timesteps: 1000
9
+ use_ema: false
10
+
11
+ ###########
12
+ # vd v1.0 #
13
+ ###########
14
+
15
+ vd_four_flow_v1-0:
16
+ super_cfg: vd_base
17
+ args:
18
+ vae_cfg_list:
19
+ - [image, MODEL(autokl_v1)]
20
+ - [text, MODEL(optimus_v1)]
21
+ ctx_cfg_list:
22
+ - [image, MODEL(clip_image_context_encoder)]
23
+ - [text, MODEL(clip_text_context_encoder)]
24
+ diffuser_cfg_list:
25
+ - [image, MODEL(openai_unet_2d_v1)]
26
+ - [text, MODEL(openai_unet_0d_v1_dc)]
27
+ global_layer_ptr: image
28
+ latent_scale_factor:
29
+ image: 0.18215
cusomized_gradio_blocks.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import csv
5
+ import inspect
6
+ import os
7
+ import subprocess
8
+ import tempfile
9
+ import threading
10
+ import warnings
11
+ from pathlib import Path
12
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple
13
+
14
+ import matplotlib
15
+ import matplotlib.pyplot as plt
16
+ import numpy as np
17
+ import PIL
18
+ import PIL.Image
19
+
20
+ import gradio
21
+ from gradio import components, processing_utils, routes, utils
22
+ from gradio.context import Context
23
+ from gradio.documentation import document, set_documentation_group
24
+ from gradio.flagging import CSVLogger
25
+
26
+ if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
27
+ from gradio.components import IOComponent
28
+
29
+ CACHED_FOLDER = "gradio_cached_examples"
30
+ LOG_FILE = "log.csv"
31
+
32
+ def create_myexamples(
33
+ examples: List[Any] | List[List[Any]] | str,
34
+ inputs: IOComponent | List[IOComponent],
35
+ outputs: IOComponent | List[IOComponent] | None = None,
36
+ fn: Callable | None = None,
37
+ cache_examples: bool = False,
38
+ examples_per_page: int = 10,
39
+ _api_mode: bool = False,
40
+ label: str | None = None,
41
+ elem_id: str | None = None,
42
+ run_on_click: bool = False,
43
+ preprocess: bool = True,
44
+ postprocess: bool = True,
45
+ batch: bool = False,):
46
+ """Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component."""
47
+ examples_obj = MyExamples(
48
+ examples=examples,
49
+ inputs=inputs,
50
+ outputs=outputs,
51
+ fn=fn,
52
+ cache_examples=cache_examples,
53
+ examples_per_page=examples_per_page,
54
+ _api_mode=_api_mode,
55
+ label=label,
56
+ elem_id=elem_id,
57
+ run_on_click=run_on_click,
58
+ preprocess=preprocess,
59
+ postprocess=postprocess,
60
+ batch=batch,
61
+ _initiated_directly=False,
62
+ )
63
+ utils.synchronize_async(examples_obj.create)
64
+ return examples_obj
65
+
66
+ class MyExamples(gradio.helpers.Examples):
67
+ def __init__(
68
+ self,
69
+ examples: List[Any] | List[List[Any]] | str,
70
+ inputs: IOComponent | List[IOComponent],
71
+ outputs: IOComponent | List[IOComponent] | None = None,
72
+ fn: Callable | None = None,
73
+ cache_examples: bool = False,
74
+ examples_per_page: int = 10,
75
+ _api_mode: bool = False,
76
+ label: str | None = "Examples",
77
+ elem_id: str | None = None,
78
+ run_on_click: bool = False,
79
+ preprocess: bool = True,
80
+ postprocess: bool = True,
81
+ batch: bool = False,
82
+ _initiated_directly: bool = True,):
83
+
84
+ if _initiated_directly:
85
+ warnings.warn(
86
+ "Please use gr.Examples(...) instead of gr.examples.Examples(...) to create the Examples.",
87
+ )
88
+
89
+ if cache_examples and (fn is None or outputs is None):
90
+ raise ValueError("If caching examples, `fn` and `outputs` must be provided")
91
+
92
+ if not isinstance(inputs, list):
93
+ inputs = [inputs]
94
+ if outputs and not isinstance(outputs, list):
95
+ outputs = [outputs]
96
+
97
+ working_directory = Path().absolute()
98
+
99
+ if examples is None:
100
+ raise ValueError("The parameter `examples` cannot be None")
101
+ elif isinstance(examples, list) and (
102
+ len(examples) == 0 or isinstance(examples[0], list)
103
+ ):
104
+ pass
105
+ elif (
106
+ isinstance(examples, list) and len(inputs) == 1
107
+ ): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
108
+ examples = [[e] for e in examples]
109
+ elif isinstance(examples, str):
110
+ if not Path(examples).exists():
111
+ raise FileNotFoundError(
112
+ "Could not find examples directory: " + examples
113
+ )
114
+ working_directory = examples
115
+ if not (Path(examples) / LOG_FILE).exists():
116
+ if len(inputs) == 1:
117
+ examples = [[e] for e in os.listdir(examples)]
118
+ else:
119
+ raise FileNotFoundError(
120
+ "Could not find log file (required for multiple inputs): "
121
+ + LOG_FILE
122
+ )
123
+ else:
124
+ with open(Path(examples) / LOG_FILE) as logs:
125
+ examples = list(csv.reader(logs))
126
+ examples = [
127
+ examples[i][: len(inputs)] for i in range(1, len(examples))
128
+ ] # remove header and unnecessary columns
129
+
130
+ else:
131
+ raise ValueError(
132
+ "The parameter `examples` must either be a string directory or a list"
133
+ "(if there is only 1 input component) or (more generally), a nested "
134
+ "list, where each sublist represents a set of inputs."
135
+ )
136
+
137
+ input_has_examples = [False] * len(inputs)
138
+ for example in examples:
139
+ for idx, example_for_input in enumerate(example):
140
+ # if not (example_for_input is None):
141
+ if True:
142
+ try:
143
+ input_has_examples[idx] = True
144
+ except IndexError:
145
+ pass # If there are more example components than inputs, ignore. This can sometimes be intentional (e.g. loading from a log file where outputs and timestamps are also logged)
146
+
147
+ inputs_with_examples = [
148
+ inp for (inp, keep) in zip(inputs, input_has_examples) if keep
149
+ ]
150
+ non_none_examples = [
151
+ [ex for (ex, keep) in zip(example, input_has_examples) if keep]
152
+ for example in examples
153
+ ]
154
+
155
+ self.examples = examples
156
+ self.non_none_examples = non_none_examples
157
+ self.inputs = inputs
158
+ self.inputs_with_examples = inputs_with_examples
159
+ self.outputs = outputs
160
+ self.fn = fn
161
+ self.cache_examples = cache_examples
162
+ self._api_mode = _api_mode
163
+ self.preprocess = preprocess
164
+ self.postprocess = postprocess
165
+ self.batch = batch
166
+
167
+ with utils.set_directory(working_directory):
168
+ self.processed_examples = [
169
+ [
170
+ component.postprocess(sample)
171
+ for component, sample in zip(inputs, example)
172
+ ]
173
+ for example in examples
174
+ ]
175
+ self.non_none_processed_examples = [
176
+ [ex for (ex, keep) in zip(example, input_has_examples) if keep]
177
+ for example in self.processed_examples
178
+ ]
179
+ if cache_examples:
180
+ for example in self.examples:
181
+ if len([ex for ex in example if ex is not None]) != len(self.inputs):
182
+ warnings.warn(
183
+ "Examples are being cached but not all input components have "
184
+ "example values. This may result in an exception being thrown by "
185
+ "your function. If you do get an error while caching examples, make "
186
+ "sure all of your inputs have example values for all of your examples "
187
+ "or you provide default values for those particular parameters in your function."
188
+ )
189
+ break
190
+
191
+ with utils.set_directory(working_directory):
192
+ self.dataset = components.Dataset(
193
+ components=inputs_with_examples,
194
+ samples=non_none_examples,
195
+ type="index",
196
+ label=label,
197
+ samples_per_page=examples_per_page,
198
+ elem_id=elem_id,
199
+ )
200
+
201
+ self.cached_folder = Path(CACHED_FOLDER) / str(self.dataset._id)
202
+ self.cached_file = Path(self.cached_folder) / "log.csv"
203
+ self.cache_examples = cache_examples
204
+ self.run_on_click = run_on_click
205
+
206
+ from gradio import utils, processing_utils
207
+ from PIL import Image as _Image
208
+ from pathlib import Path
209
+ import numpy as np
210
+
211
+ def customized_postprocess(self, y):
212
+ if y is None:
213
+ return None
214
+
215
+ if isinstance(y, dict):
216
+ if self.tool == "sketch" and self.source in ["upload", "webcam"]:
217
+ y, mask = y["image"], y["mask"]
218
+ if y is None:
219
+ return None
220
+ elif isinstance(y, np.ndarray):
221
+ im = processing_utils.encode_array_to_base64(y)
222
+ elif isinstance(y, _Image.Image):
223
+ im = processing_utils.encode_pil_to_base64(y)
224
+ elif isinstance(y, (str, Path)):
225
+ im = processing_utils.encode_url_or_file_to_base64(y)
226
+ else:
227
+ raise ValueError("Cannot process this value as an Image")
228
+ im = self._format_image(im)
229
+
230
+ if mask is None:
231
+ return im
232
+ elif isinstance(y, np.ndarray):
233
+ mask_im = processing_utils.encode_array_to_base64(mask)
234
+ elif isinstance(y, _Image.Image):
235
+ mask_im = processing_utils.encode_pil_to_base64(mask)
236
+ elif isinstance(y, (str, Path)):
237
+ mask_im = processing_utils.encode_url_or_file_to_base64(mask)
238
+ else:
239
+ raise ValueError("Cannot process this value as an Image")
240
+
241
+ return {"image": im, "mask" : mask_im,}
242
+
243
+ elif isinstance(y, np.ndarray):
244
+ return processing_utils.encode_array_to_base64(y)
245
+ elif isinstance(y, _Image.Image):
246
+ return processing_utils.encode_pil_to_base64(y)
247
+ elif isinstance(y, (str, Path)):
248
+ return processing_utils.encode_url_or_file_to_base64(y)
249
+ else:
250
+ raise ValueError("Cannot process this value as an Image")
251
+
252
+ # def customized_as_example(self, input_data=None):
253
+ # if input_data is None:
254
+ # return str('assets/demo/misc/noimage.jpg')
255
+ # elif isinstance(input_data, dict):
256
+ # im = np.array(PIL.Image.open(input_data["image"])).astype(float)
257
+ # mask = np.array(PIL.Image.open(input_data["mask"])).astype(float)/255
258
+ # imm = (im * (1-mask)).astype(np.uint8)
259
+ # import time
260
+ # ctime = int(time.time()*100)
261
+ # impath = 'assets/demo/temp/temp_{}.png'.format(ctime)
262
+ # PIL.Image.fromarray(imm).save(impath)
263
+ # return str(utils.abspath(impath))
264
+ # else:
265
+ # return str(utils.abspath(input_data))
266
+
267
+ def customized_as_example(self, input_data=None):
268
+ if input_data is None:
269
+ return str('assets/demo/misc/noimage.jpg')
270
+ else:
271
+ return str(utils.abspath(input_data))
lib/__init__.py ADDED
File without changes
lib/cfg_helper.py ADDED
@@ -0,0 +1,612 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import shutil
4
+ import copy
5
+ import time
6
+ import pprint
7
+ import numpy as np
8
+ import torch
9
+ import matplotlib
10
+ import argparse
11
+ import json
12
+ import yaml
13
+ from easydict import EasyDict as edict
14
+
15
+ from .model_zoo import get_model
16
+
17
+ ############
18
+ # cfg_bank #
19
+ ############
20
+
21
+ def cfg_solvef(cmd, root):
22
+ if not isinstance(cmd, str):
23
+ return cmd
24
+
25
+ if cmd.find('SAME')==0:
26
+ zoom = root
27
+ p = cmd[len('SAME'):].strip('()').split('.')
28
+ p = [pi.strip() for pi in p]
29
+ for pi in p:
30
+ try:
31
+ pi = int(pi)
32
+ except:
33
+ pass
34
+
35
+ try:
36
+ zoom = zoom[pi]
37
+ except:
38
+ return cmd
39
+ return cfg_solvef(zoom, root)
40
+
41
+ if cmd.find('SEARCH')==0:
42
+ zoom = root
43
+ p = cmd[len('SEARCH'):].strip('()').split('.')
44
+ p = [pi.strip() for pi in p]
45
+ find = True
46
+ # Depth first search
47
+ for pi in p:
48
+ try:
49
+ pi = int(pi)
50
+ except:
51
+ pass
52
+
53
+ try:
54
+ zoom = zoom[pi]
55
+ except:
56
+ find = False
57
+ break
58
+
59
+ if find:
60
+ return cfg_solvef(zoom, root)
61
+ else:
62
+ if isinstance(root, dict):
63
+ for ri in root:
64
+ rv = cfg_solvef(cmd, root[ri])
65
+ if rv != cmd:
66
+ return rv
67
+ if isinstance(root, list):
68
+ for ri in root:
69
+ rv = cfg_solvef(cmd, ri)
70
+ if rv != cmd:
71
+ return rv
72
+ return cmd
73
+
74
+ if cmd.find('MODEL')==0:
75
+ goto = cmd[len('MODEL'):].strip('()')
76
+ return model_cfg_bank()(goto)
77
+
78
+ if cmd.find('DATASET')==0:
79
+ goto = cmd[len('DATASET'):].strip('()')
80
+ return dataset_cfg_bank()(goto)
81
+
82
+ return cmd
83
+
84
+ def cfg_solve(cfg, cfg_root):
85
+ # The function solve cfg element such that
86
+ # all sorrogate input are settled.
87
+ # (i.e. SAME(***) )
88
+ if isinstance(cfg, list):
89
+ for i in range(len(cfg)):
90
+ if isinstance(cfg[i], (list, dict)):
91
+ cfg[i] = cfg_solve(cfg[i], cfg_root)
92
+ else:
93
+ cfg[i] = cfg_solvef(cfg[i], cfg_root)
94
+ if isinstance(cfg, dict):
95
+ for k in cfg:
96
+ if isinstance(cfg[k], (list, dict)):
97
+ cfg[k] = cfg_solve(cfg[k], cfg_root)
98
+ else:
99
+ cfg[k] = cfg_solvef(cfg[k], cfg_root)
100
+ return cfg
101
+
102
+ class model_cfg_bank(object):
103
+ def __init__(self):
104
+ self.cfg_dir = osp.join('configs', 'model')
105
+ self.cfg_bank = edict()
106
+
107
+ def __call__(self, name):
108
+ if name not in self.cfg_bank:
109
+ cfg_path = self.get_yaml_path(name)
110
+ with open(cfg_path, 'r') as f:
111
+ cfg_new = yaml.load(
112
+ f, Loader=yaml.FullLoader)
113
+ cfg_new = edict(cfg_new)
114
+ self.cfg_bank.update(cfg_new)
115
+
116
+ cfg = self.cfg_bank[name]
117
+ cfg.name = name
118
+ if 'super_cfg' not in cfg:
119
+ cfg = cfg_solve(cfg, cfg)
120
+ self.cfg_bank[name] = cfg
121
+ return copy.deepcopy(cfg)
122
+
123
+ super_cfg = self.__call__(cfg.super_cfg)
124
+ # unlike other field,
125
+ # args will not be replaced but update.
126
+ if 'args' in cfg:
127
+ if 'args' in super_cfg:
128
+ super_cfg.args.update(cfg.args)
129
+ else:
130
+ super_cfg.args = cfg.args
131
+ cfg.pop('args')
132
+
133
+ super_cfg.update(cfg)
134
+ super_cfg.pop('super_cfg')
135
+ cfg = super_cfg
136
+ try:
137
+ delete_args = cfg.pop('delete_args')
138
+ except:
139
+ delete_args = []
140
+
141
+ for dargs in delete_args:
142
+ cfg.args.pop(dargs)
143
+
144
+ cfg = cfg_solve(cfg, cfg)
145
+ self.cfg_bank[name] = cfg
146
+ return copy.deepcopy(cfg)
147
+
148
+ def get_yaml_path(self, name):
149
+ if name.find('openai_unet')==0:
150
+ return osp.join(
151
+ self.cfg_dir, 'openai_unet.yaml')
152
+ elif (name.find('clip')==0) or (name.find('openclip')==0):
153
+ return osp.join(
154
+ self.cfg_dir, 'clip.yaml')
155
+ elif name.find('vd')==0:
156
+ return osp.join(
157
+ self.cfg_dir, 'vd.yaml')
158
+ elif name.find('optimus')==0:
159
+ return osp.join(
160
+ self.cfg_dir, 'optimus.yaml')
161
+ elif name.find('autokl')==0:
162
+ return osp.join(
163
+ self.cfg_dir, 'autokl.yaml')
164
+ else:
165
+ raise ValueError
166
+
167
+ class dataset_cfg_bank(object):
168
+ def __init__(self):
169
+ self.cfg_dir = osp.join('configs', 'dataset')
170
+ self.cfg_bank = edict()
171
+
172
+ def __call__(self, name):
173
+ if name not in self.cfg_bank:
174
+ cfg_path = self.get_yaml_path(name)
175
+ with open(cfg_path, 'r') as f:
176
+ cfg_new = yaml.load(
177
+ f, Loader=yaml.FullLoader)
178
+ cfg_new = edict(cfg_new)
179
+ self.cfg_bank.update(cfg_new)
180
+
181
+ cfg = self.cfg_bank[name]
182
+ cfg.name = name
183
+ if cfg.get('super_cfg', None) is None:
184
+ cfg = cfg_solve(cfg, cfg)
185
+ self.cfg_bank[name] = cfg
186
+ return copy.deepcopy(cfg)
187
+
188
+ super_cfg = self.__call__(cfg.super_cfg)
189
+ super_cfg.update(cfg)
190
+ cfg = super_cfg
191
+ cfg.super_cfg = None
192
+ try:
193
+ delete = cfg.pop('delete')
194
+ except:
195
+ delete = []
196
+
197
+ for dargs in delete:
198
+ cfg.pop(dargs)
199
+
200
+ cfg = cfg_solve(cfg, cfg)
201
+ self.cfg_bank[name] = cfg
202
+ return copy.deepcopy(cfg)
203
+
204
+ def get_yaml_path(self, name):
205
+ if name.find('laion2b')==0:
206
+ return osp.join(
207
+ self.cfg_dir, 'laion2b.yaml')
208
+ else:
209
+ raise ValueError
210
+
211
+ class experiment_cfg_bank(object):
212
+ def __init__(self):
213
+ self.cfg_dir = osp.join('configs', 'experiment')
214
+ self.cfg_bank = edict()
215
+
216
+ def __call__(self, name):
217
+ if name not in self.cfg_bank:
218
+ cfg_path = self.get_yaml_path(name)
219
+ with open(cfg_path, 'r') as f:
220
+ cfg = yaml.load(
221
+ f, Loader=yaml.FullLoader)
222
+ cfg = edict(cfg)
223
+
224
+ cfg = cfg_solve(cfg, cfg)
225
+ cfg = cfg_solve(cfg, cfg)
226
+ # twice for SEARCH
227
+ self.cfg_bank[name] = cfg
228
+ return copy.deepcopy(cfg)
229
+
230
+ def get_yaml_path(self, name):
231
+ return osp.join(
232
+ self.cfg_dir, name+'.yaml')
233
+
234
+ def load_cfg_yaml(path):
235
+ if osp.isfile(path):
236
+ cfg_path = path
237
+ elif osp.isfile(osp.join('configs', 'experiment', path)):
238
+ cfg_path = osp.join('configs', 'experiment', path)
239
+ elif osp.isfile(osp.join('configs', 'experiment', path+'.yaml')):
240
+ cfg_path = osp.join('configs', 'experiment', path+'.yaml')
241
+ else:
242
+ assert False, 'No such config!'
243
+
244
+ with open(cfg_path, 'r') as f:
245
+ cfg = yaml.load(f, Loader=yaml.FullLoader)
246
+ cfg = edict(cfg)
247
+ cfg = cfg_solve(cfg, cfg)
248
+ cfg = cfg_solve(cfg, cfg)
249
+ return cfg
250
+
251
+ ##############
252
+ # cfg_helper #
253
+ ##############
254
+
255
+ def get_experiment_id(ref=None):
256
+ if ref is None:
257
+ time.sleep(0.5)
258
+ return int(time.time()*100)
259
+ else:
260
+ try:
261
+ return int(ref)
262
+ except:
263
+ pass
264
+
265
+ _, ref = osp.split(ref)
266
+ ref = ref.split('_')[0]
267
+ try:
268
+ return int(ref)
269
+ except:
270
+ assert False, 'Invalid experiment ID!'
271
+
272
+ def record_resume_cfg(path):
273
+ cnt = 0
274
+ while True:
275
+ if osp.exists(path+'.{:04d}'.format(cnt)):
276
+ cnt += 1
277
+ continue
278
+ shutil.copyfile(path, path+'.{:04d}'.format(cnt))
279
+ break
280
+
281
+ def get_command_line_args():
282
+ parser = argparse.ArgumentParser()
283
+ parser.add_argument('--debug', action='store_true', default=False)
284
+ parser.add_argument('--config', type=str)
285
+ parser.add_argument('--gpu', nargs='+', type=int)
286
+
287
+ parser.add_argument('--node_rank', type=int)
288
+ parser.add_argument('--node_list', nargs='+', type=str)
289
+ parser.add_argument('--nodes', type=int)
290
+ parser.add_argument('--addr', type=str, default='127.0.0.1')
291
+ parser.add_argument('--port', type=int, default=11233)
292
+
293
+ parser.add_argument('--signature', nargs='+', type=str)
294
+ parser.add_argument('--seed', type=int)
295
+
296
+ parser.add_argument('--eval', type=str)
297
+ parser.add_argument('--eval_subdir', type=str)
298
+ parser.add_argument('--pretrained', type=str)
299
+
300
+ parser.add_argument('--resume_dir', type=str)
301
+ parser.add_argument('--resume_step', type=int)
302
+ parser.add_argument('--resume_weight', type=str)
303
+
304
+ args = parser.parse_args()
305
+
306
+ # Special handling the resume
307
+ if args.resume_dir is not None:
308
+ cfg = edict()
309
+ cfg.env = edict()
310
+ cfg.env.debug = args.debug
311
+ cfg.env.resume = edict()
312
+ cfg.env.resume.dir = args.resume_dir
313
+ cfg.env.resume.step = args.resume_step
314
+ cfg.env.resume.weight = args.resume_weight
315
+ return cfg
316
+
317
+ cfg = load_cfg_yaml(args.config)
318
+ cfg.env.debug = args.debug
319
+ cfg.env.gpu_device = [0] if args.gpu is None else list(args.gpu)
320
+ cfg.env.master_addr = args.addr
321
+ cfg.env.master_port = args.port
322
+ cfg.env.dist_url = 'tcp://{}:{}'.format(args.addr, args.port)
323
+
324
+ if args.node_list is None:
325
+ cfg.env.node_rank = 0 if args.node_rank is None else args.node_rank
326
+ cfg.env.nodes = 1 if args.nodes is None else args.nodes
327
+ else:
328
+ import socket
329
+ hostname = socket.gethostname()
330
+ assert cfg.env.master_addr == args.node_list[0]
331
+ cfg.env.node_rank = args.node_list.index(hostname)
332
+ cfg.env.nodes = len(args.node_list)
333
+ cfg.env.node_list = args.node_list
334
+
335
+ istrain = False if args.eval is not None else True
336
+ isdebug = cfg.env.debug
337
+
338
+ if istrain:
339
+ if isdebug:
340
+ cfg.env.experiment_id = 999999999999
341
+ cfg.train.signature = ['debug']
342
+ else:
343
+ cfg.env.experiment_id = get_experiment_id()
344
+ if args.signature is not None:
345
+ cfg.train.signature = args.signature
346
+ else:
347
+ if 'train' in cfg:
348
+ cfg.pop('train')
349
+ cfg.env.experiment_id = get_experiment_id(args.eval)
350
+ if args.signature is not None:
351
+ cfg.eval.signature = args.signature
352
+
353
+ if isdebug and (args.eval is None):
354
+ cfg.env.experiment_id = 999999999999
355
+ cfg.eval.signature = ['debug']
356
+
357
+ if args.eval_subdir is not None:
358
+ if isdebug:
359
+ cfg.eval.eval_subdir = 'debug'
360
+ else:
361
+ cfg.eval.eval_subdir = args.eval_subdir
362
+ if args.pretrained is not None:
363
+ cfg.eval.pretrained = args.pretrained
364
+ # The override pretrained over the setting in cfg.model
365
+
366
+ if args.seed is not None:
367
+ cfg.env.rnd_seed = args.seed
368
+
369
+ return cfg
370
+
371
+ def cfg_initiates(cfg):
372
+ cfge = cfg.env
373
+ isdebug = cfge.debug
374
+ isresume = 'resume' in cfge
375
+ istrain = 'train' in cfg
376
+ haseval = 'eval' in cfg
377
+ cfgt = cfg.train if istrain else None
378
+ cfgv = cfg.eval if haseval else None
379
+
380
+ ###############################
381
+ # get some environment params #
382
+ ###############################
383
+
384
+ cfge.computer = os.uname()
385
+ cfge.torch_version = str(torch.__version__)
386
+
387
+ ##########
388
+ # resume #
389
+ ##########
390
+
391
+ if isresume:
392
+ resume_cfg_path = osp.join(cfge.resume.dir, 'config.yaml')
393
+ record_resume_cfg(resume_cfg_path)
394
+ with open(resume_cfg_path, 'r') as f:
395
+ cfg_resume = yaml.load(f, Loader=yaml.FullLoader)
396
+ cfg_resume = edict(cfg_resume)
397
+ cfg_resume.env.update(cfge)
398
+ cfg = cfg_resume
399
+ cfge = cfg.env
400
+ log_file = cfg.train.log_file
401
+
402
+ print('')
403
+ print('##########')
404
+ print('# resume #')
405
+ print('##########')
406
+ print('')
407
+ with open(log_file, 'a') as f:
408
+ print('', file=f)
409
+ print('##########', file=f)
410
+ print('# resume #', file=f)
411
+ print('##########', file=f)
412
+ print('', file=f)
413
+
414
+ pprint.pprint(cfg)
415
+ with open(log_file, 'a') as f:
416
+ pprint.pprint(cfg, f)
417
+
418
+ ####################
419
+ # node distributed #
420
+ ####################
421
+
422
+ if cfg.env.master_addr!='127.0.0.1':
423
+ os.environ['MASTER_ADDR'] = cfge.master_addr
424
+ os.environ['MASTER_PORT'] = '{}'.format(cfge.master_port)
425
+ if cfg.env.dist_backend=='nccl':
426
+ os.environ['NCCL_SOCKET_FAMILY'] = 'AF_INET'
427
+ if cfg.env.dist_backend=='gloo':
428
+ os.environ['GLOO_SOCKET_FAMILY'] = 'AF_INET'
429
+
430
+ #######################
431
+ # cuda visible device #
432
+ #######################
433
+
434
+ os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
435
+ [str(gid) for gid in cfge.gpu_device])
436
+
437
+ #####################
438
+ # return resume cfg #
439
+ #####################
440
+
441
+ if isresume:
442
+ return cfg
443
+
444
+ #############################################
445
+ # some misc setting that not need in resume #
446
+ #############################################
447
+
448
+ cfgm = cfg.model
449
+ cfge.gpu_count = len(cfge.gpu_device)
450
+
451
+ ##########################################
452
+ # align batch size and num worker config #
453
+ ##########################################
454
+
455
+ gpu_n = cfge.gpu_count * cfge.nodes
456
+ def align_batch_size(bs, bs_per_gpu):
457
+ assert (bs is not None) or (bs_per_gpu is not None)
458
+ bs = bs_per_gpu * gpu_n if bs is None else bs
459
+ bs_per_gpu = bs // gpu_n if bs_per_gpu is None else bs_per_gpu
460
+ assert (bs == bs_per_gpu * gpu_n)
461
+ return bs, bs_per_gpu
462
+
463
+ if istrain:
464
+ cfgt.batch_size, cfgt.batch_size_per_gpu = \
465
+ align_batch_size(cfgt.batch_size, cfgt.batch_size_per_gpu)
466
+ cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu = \
467
+ align_batch_size(cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu)
468
+ if haseval:
469
+ cfgv.batch_size, cfgv.batch_size_per_gpu = \
470
+ align_batch_size(cfgv.batch_size, cfgv.batch_size_per_gpu)
471
+ cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu = \
472
+ align_batch_size(cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu)
473
+
474
+ ##################
475
+ # create log dir #
476
+ ##################
477
+
478
+ if istrain:
479
+ if not isdebug:
480
+ sig = cfgt.get('signature', [])
481
+ sig = sig + ['s{}'.format(cfge.rnd_seed)]
482
+ else:
483
+ sig = ['debug']
484
+
485
+ log_dir = [
486
+ cfge.log_root_dir,
487
+ '{}_{}'.format(cfgm.symbol, cfgt.dataset.symbol),
488
+ '_'.join([str(cfge.experiment_id)] + sig)
489
+ ]
490
+ log_dir = osp.join(*log_dir)
491
+ log_file = osp.join(log_dir, 'train.log')
492
+ if not osp.exists(log_file):
493
+ os.makedirs(osp.dirname(log_file))
494
+ cfgt.log_dir = log_dir
495
+ cfgt.log_file = log_file
496
+
497
+ if haseval:
498
+ cfgv.log_dir = log_dir
499
+ cfgv.log_file = log_file
500
+ else:
501
+ model_symbol = cfgm.symbol
502
+ if cfgv.get('dataset', None) is None:
503
+ dataset_symbol = 'nodataset'
504
+ else:
505
+ dataset_symbol = cfgv.dataset.symbol
506
+
507
+ log_dir = osp.join(cfge.log_root_dir, '{}_{}'.format(model_symbol, dataset_symbol))
508
+ exp_dir = search_experiment_folder(log_dir, cfge.experiment_id)
509
+ if exp_dir is None:
510
+ if not isdebug:
511
+ sig = cfgv.get('signature', []) + ['evalonly']
512
+ else:
513
+ sig = ['debug']
514
+ exp_dir = '_'.join([str(cfge.experiment_id)] + sig)
515
+
516
+ eval_subdir = cfgv.get('eval_subdir', None)
517
+ # override subdir in debug mode (if eval_subdir is set)
518
+ eval_subdir = 'debug' if (eval_subdir is not None) and isdebug else eval_subdir
519
+
520
+ if eval_subdir is not None:
521
+ log_dir = osp.join(log_dir, exp_dir, eval_subdir)
522
+ else:
523
+ log_dir = osp.join(log_dir, exp_dir)
524
+
525
+ disable_log_override = cfgv.get('disable_log_override', False)
526
+ if osp.isdir(log_dir):
527
+ if disable_log_override:
528
+ assert False, 'Override an exsited log_dir is disabled at [{}]'.format(log_dir)
529
+ else:
530
+ os.makedirs(log_dir)
531
+
532
+ log_file = osp.join(log_dir, 'eval.log')
533
+ cfgv.log_dir = log_dir
534
+ cfgv.log_file = log_file
535
+
536
+ ######################
537
+ # print and save cfg #
538
+ ######################
539
+
540
+ pprint.pprint(cfg)
541
+ if cfge.node_rank==0:
542
+ with open(log_file, 'w') as f:
543
+ pprint.pprint(cfg, f)
544
+ with open(osp.join(log_dir, 'config.yaml'), 'w') as f:
545
+ yaml.dump(edict_2_dict(cfg), f)
546
+ else:
547
+ with open(osp.join(log_dir, 'config.yaml.{}'.format(cfge.node_rank)), 'w') as f:
548
+ yaml.dump(edict_2_dict(cfg), f)
549
+
550
+ #############
551
+ # save code #
552
+ #############
553
+
554
+ save_code = False
555
+ if istrain:
556
+ save_code = cfgt.get('save_code', False)
557
+ elif haseval:
558
+ save_code = cfgv.get('save_code', False)
559
+ save_code = save_code and (cfge.node_rank==0)
560
+
561
+ if save_code:
562
+ codedir = osp.join(log_dir, 'code')
563
+ if osp.exists(codedir):
564
+ shutil.rmtree(codedir)
565
+ for d in ['configs', 'lib']:
566
+ fromcodedir = d
567
+ tocodedir = osp.join(codedir, d)
568
+ shutil.copytree(
569
+ fromcodedir, tocodedir,
570
+ ignore=shutil.ignore_patterns(
571
+ '*__pycache__*', '*build*'))
572
+ for codei in os.listdir('.'):
573
+ if osp.splitext(codei)[1] == 'py':
574
+ shutil.copy(codei, codedir)
575
+
576
+ #######################
577
+ # set matplotlib mode #
578
+ #######################
579
+
580
+ if 'matplotlib_mode' in cfge:
581
+ try:
582
+ matplotlib.use(cfge.matplotlib_mode)
583
+ except:
584
+ print('Warning: matplotlib mode [{}] failed to be set!'.format(cfge.matplotlib_mode))
585
+
586
+ return cfg
587
+
588
+ def edict_2_dict(x):
589
+ if isinstance(x, dict):
590
+ xnew = {}
591
+ for k in x:
592
+ xnew[k] = edict_2_dict(x[k])
593
+ return xnew
594
+ elif isinstance(x, list):
595
+ xnew = []
596
+ for i in range(len(x)):
597
+ xnew.append( edict_2_dict(x[i]) )
598
+ return xnew
599
+ else:
600
+ return x
601
+
602
+ def search_experiment_folder(root, exid):
603
+ target = None
604
+ for fi in os.listdir(root):
605
+ if not osp.isdir(osp.join(root, fi)):
606
+ continue
607
+ if int(fi.split('_')[0]) == exid:
608
+ if target is not None:
609
+ return None # duplicated
610
+ elif target is None:
611
+ target = fi
612
+ return target
lib/cfg_holder.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ def singleton(class_):
4
+ instances = {}
5
+ def getinstance(*args, **kwargs):
6
+ if class_ not in instances:
7
+ instances[class_] = class_(*args, **kwargs)
8
+ return instances[class_]
9
+ return getinstance
10
+
11
+ ##############
12
+ # cfg_holder #
13
+ ##############
14
+
15
+ @singleton
16
+ class cfg_unique_holder(object):
17
+ def __init__(self):
18
+ self.cfg = None
19
+ # this is use to track the main codes.
20
+ self.code = set()
21
+ def save_cfg(self, cfg):
22
+ self.cfg = copy.deepcopy(cfg)
23
+ def add_code(self, code):
24
+ """
25
+ A new main code is reached and
26
+ its name is added.
27
+ """
28
+ self.code.add(code)
lib/log_service.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timeit
2
+ import numpy as np
3
+ import os
4
+ import os.path as osp
5
+ import shutil
6
+ import copy
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.distributed as dist
10
+ from .cfg_holder import cfg_unique_holder as cfguh
11
+ from . import sync
12
+
13
+ print_console_local_rank0_only = True
14
+
15
+ def print_log(*console_info):
16
+ local_rank = sync.get_rank('local')
17
+ if print_console_local_rank0_only and (local_rank!=0):
18
+ return
19
+ console_info = [str(i) for i in console_info]
20
+ console_info = ' '.join(console_info)
21
+ print(console_info)
22
+
23
+ if local_rank!=0:
24
+ return
25
+
26
+ log_file = None
27
+ try:
28
+ log_file = cfguh().cfg.train.log_file
29
+ except:
30
+ try:
31
+ log_file = cfguh().cfg.eval.log_file
32
+ except:
33
+ return
34
+ if log_file is not None:
35
+ with open(log_file, 'a') as f:
36
+ f.write(console_info + '\n')
37
+
38
+ class distributed_log_manager(object):
39
+ def __init__(self):
40
+ self.sum = {}
41
+ self.cnt = {}
42
+ self.time_check = timeit.default_timer()
43
+
44
+ cfgt = cfguh().cfg.train
45
+ use_tensorboard = getattr(cfgt, 'log_tensorboard', False)
46
+
47
+ self.ddp = sync.is_ddp()
48
+ self.rank = sync.get_rank('local')
49
+ self.world_size = sync.get_world_size('local')
50
+
51
+ self.tb = None
52
+ if use_tensorboard and (self.rank==0):
53
+ import tensorboardX
54
+ monitoring_dir = osp.join(cfguh().cfg.train.log_dir, 'tensorboard')
55
+ self.tb = tensorboardX.SummaryWriter(osp.join(monitoring_dir))
56
+
57
+ def accumulate(self, n, **data):
58
+ if n < 0:
59
+ raise ValueError
60
+
61
+ for itemn, di in data.items():
62
+ if itemn in self.sum:
63
+ self.sum[itemn] += di * n
64
+ self.cnt[itemn] += n
65
+ else:
66
+ self.sum[itemn] = di * n
67
+ self.cnt[itemn] = n
68
+
69
+ def get_mean_value_dict(self):
70
+ value_gather = [
71
+ self.sum[itemn]/self.cnt[itemn] \
72
+ for itemn in sorted(self.sum.keys()) ]
73
+
74
+ value_gather_tensor = torch.FloatTensor(value_gather).to(self.rank)
75
+ if self.ddp:
76
+ dist.all_reduce(value_gather_tensor, op=dist.ReduceOp.SUM)
77
+ value_gather_tensor /= self.world_size
78
+
79
+ mean = {}
80
+ for idx, itemn in enumerate(sorted(self.sum.keys())):
81
+ mean[itemn] = value_gather_tensor[idx].item()
82
+ return mean
83
+
84
+ def tensorboard_log(self, step, data, mode='train', **extra):
85
+ if self.tb is None:
86
+ return
87
+ if mode == 'train':
88
+ self.tb.add_scalar('other/epochn', extra['epochn'], step)
89
+ if 'lr' in extra:
90
+ self.tb.add_scalar('other/lr', extra['lr'], step)
91
+ for itemn, di in data.items():
92
+ if itemn.find('loss') == 0:
93
+ self.tb.add_scalar('loss/'+itemn, di, step)
94
+ elif itemn == 'Loss':
95
+ self.tb.add_scalar('Loss', di, step)
96
+ else:
97
+ self.tb.add_scalar('other/'+itemn, di, step)
98
+ elif mode == 'eval':
99
+ if isinstance(data, dict):
100
+ for itemn, di in data.items():
101
+ self.tb.add_scalar('eval/'+itemn, di, step)
102
+ else:
103
+ self.tb.add_scalar('eval', data, step)
104
+ return
105
+
106
+ def train_summary(self, itern, epochn, samplen, lr, tbstep=None):
107
+ console_info = [
108
+ 'Iter:{}'.format(itern),
109
+ 'Epoch:{}'.format(epochn),
110
+ 'Sample:{}'.format(samplen),]
111
+
112
+ if lr is not None:
113
+ console_info += ['LR:{:.4E}'.format(lr)]
114
+
115
+ mean = self.get_mean_value_dict()
116
+
117
+ tbstep = itern if tbstep is None else tbstep
118
+ self.tensorboard_log(
119
+ tbstep, mean, mode='train',
120
+ itern=itern, epochn=epochn, lr=lr)
121
+
122
+ loss = mean.pop('Loss')
123
+ mean_info = ['Loss:{:.4f}'.format(loss)] + [
124
+ '{}:{:.4f}'.format(itemn, mean[itemn]) \
125
+ for itemn in sorted(mean.keys()) \
126
+ if itemn.find('loss') == 0
127
+ ]
128
+ console_info += mean_info
129
+ console_info.append('Time:{:.2f}s'.format(
130
+ timeit.default_timer() - self.time_check))
131
+ return ' , '.join(console_info)
132
+
133
+ def clear(self):
134
+ self.sum = {}
135
+ self.cnt = {}
136
+ self.time_check = timeit.default_timer()
137
+
138
+ def tensorboard_close(self):
139
+ if self.tb is not None:
140
+ self.tb.close()
141
+
142
+ # ----- also include some small utils -----
143
+
144
+ def torch_to_numpy(*argv):
145
+ if len(argv) > 1:
146
+ data = list(argv)
147
+ else:
148
+ data = argv[0]
149
+
150
+ if isinstance(data, torch.Tensor):
151
+ return data.to('cpu').detach().numpy()
152
+
153
+ elif isinstance(data, (list, tuple)):
154
+ out = []
155
+ for di in data:
156
+ out.append(torch_to_numpy(di))
157
+ return out
158
+
159
+ elif isinstance(data, dict):
160
+ out = {}
161
+ for ni, di in data.items():
162
+ out[ni] = torch_to_numpy(di)
163
+ return out
164
+
165
+ else:
166
+ return data
lib/model_zoo/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .common.get_model import get_model
2
+ from .common.get_optimizer import get_optimizer
3
+ from .common.get_scheduler import get_scheduler
4
+ from .common.utils import get_unit
lib/model_zoo/attention.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+
8
+ from .diffusion_utils import checkpoint
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def uniq(arr):
16
+ return{el: True for el in arr}.keys()
17
+
18
+
19
+ def default(val, d):
20
+ if exists(val):
21
+ return val
22
+ return d() if isfunction(d) else d
23
+
24
+
25
+ def max_neg_value(t):
26
+ return -torch.finfo(t.dtype).max
27
+
28
+
29
+ def init_(tensor):
30
+ dim = tensor.shape[-1]
31
+ std = 1 / math.sqrt(dim)
32
+ tensor.uniform_(-std, std)
33
+ return tensor
34
+
35
+
36
+ # feedforward
37
+ class GEGLU(nn.Module):
38
+ def __init__(self, dim_in, dim_out):
39
+ super().__init__()
40
+ self.proj = nn.Linear(dim_in, dim_out * 2)
41
+
42
+ def forward(self, x):
43
+ x, gate = self.proj(x).chunk(2, dim=-1)
44
+ return x * F.gelu(gate)
45
+
46
+
47
+ class FeedForward(nn.Module):
48
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49
+ super().__init__()
50
+ inner_dim = int(dim * mult)
51
+ dim_out = default(dim_out, dim)
52
+ project_in = nn.Sequential(
53
+ nn.Linear(dim, inner_dim),
54
+ nn.GELU()
55
+ ) if not glu else GEGLU(dim, inner_dim)
56
+
57
+ self.net = nn.Sequential(
58
+ project_in,
59
+ nn.Dropout(dropout),
60
+ nn.Linear(inner_dim, dim_out)
61
+ )
62
+
63
+ def forward(self, x):
64
+ return self.net(x)
65
+
66
+
67
+ def zero_module(module):
68
+ """
69
+ Zero out the parameters of a module and return it.
70
+ """
71
+ for p in module.parameters():
72
+ p.detach().zero_()
73
+ return module
74
+
75
+
76
+ def Normalize(in_channels):
77
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78
+
79
+
80
+ class LinearAttention(nn.Module):
81
+ def __init__(self, dim, heads=4, dim_head=32):
82
+ super().__init__()
83
+ self.heads = heads
84
+ hidden_dim = dim_head * heads
85
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87
+
88
+ def forward(self, x):
89
+ b, c, h, w = x.shape
90
+ qkv = self.to_qkv(x)
91
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92
+ k = k.softmax(dim=-1)
93
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
94
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
95
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96
+ return self.to_out(out)
97
+
98
+
99
+ class SpatialSelfAttention(nn.Module):
100
+ def __init__(self, in_channels):
101
+ super().__init__()
102
+ self.in_channels = in_channels
103
+
104
+ self.norm = Normalize(in_channels)
105
+ self.q = torch.nn.Conv2d(in_channels,
106
+ in_channels,
107
+ kernel_size=1,
108
+ stride=1,
109
+ padding=0)
110
+ self.k = torch.nn.Conv2d(in_channels,
111
+ in_channels,
112
+ kernel_size=1,
113
+ stride=1,
114
+ padding=0)
115
+ self.v = torch.nn.Conv2d(in_channels,
116
+ in_channels,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0)
120
+ self.proj_out = torch.nn.Conv2d(in_channels,
121
+ in_channels,
122
+ kernel_size=1,
123
+ stride=1,
124
+ padding=0)
125
+
126
+ def forward(self, x):
127
+ h_ = x
128
+ h_ = self.norm(h_)
129
+ q = self.q(h_)
130
+ k = self.k(h_)
131
+ v = self.v(h_)
132
+
133
+ # compute attention
134
+ b,c,h,w = q.shape
135
+ q = rearrange(q, 'b c h w -> b (h w) c')
136
+ k = rearrange(k, 'b c h w -> b c (h w)')
137
+ w_ = torch.einsum('bij,bjk->bik', q, k)
138
+
139
+ w_ = w_ * (int(c)**(-0.5))
140
+ w_ = torch.nn.functional.softmax(w_, dim=2)
141
+
142
+ # attend to values
143
+ v = rearrange(v, 'b c h w -> b c (h w)')
144
+ w_ = rearrange(w_, 'b i j -> b j i')
145
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
146
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147
+ h_ = self.proj_out(h_)
148
+
149
+ return x+h_
150
+
151
+
152
+ class CrossAttention(nn.Module):
153
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154
+ super().__init__()
155
+ inner_dim = dim_head * heads
156
+ context_dim = default(context_dim, query_dim)
157
+
158
+ self.scale = dim_head ** -0.5
159
+ self.heads = heads
160
+
161
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164
+
165
+ self.to_out = nn.Sequential(
166
+ nn.Linear(inner_dim, query_dim),
167
+ nn.Dropout(dropout)
168
+ )
169
+
170
+ def forward(self, x, context=None, mask=None):
171
+ h = self.heads
172
+
173
+ q = self.to_q(x)
174
+ context = default(context, x)
175
+ k = self.to_k(context)
176
+ v = self.to_v(context)
177
+
178
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179
+
180
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181
+
182
+ if exists(mask):
183
+ mask = rearrange(mask, 'b ... -> b (...)')
184
+ max_neg_value = -torch.finfo(sim.dtype).max
185
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
186
+ sim.masked_fill_(~mask, max_neg_value)
187
+
188
+ # attention, what we cannot get enough of
189
+ attn = sim.softmax(dim=-1)
190
+
191
+ out = einsum('b i j, b j d -> b i d', attn, v)
192
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193
+ return self.to_out(out)
194
+
195
+
196
+ class BasicTransformerBlock(nn.Module):
197
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
198
+ disable_self_attn=False):
199
+ super().__init__()
200
+ self.disable_self_attn = disable_self_attn
201
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
202
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
203
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
204
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
205
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
206
+ self.norm1 = nn.LayerNorm(dim)
207
+ self.norm2 = nn.LayerNorm(dim)
208
+ self.norm3 = nn.LayerNorm(dim)
209
+ self.checkpoint = checkpoint
210
+
211
+ def forward(self, x, context=None):
212
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
213
+
214
+ def _forward(self, x, context=None):
215
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
216
+ x = self.attn2(self.norm2(x), context=context) + x
217
+ x = self.ff(self.norm3(x)) + x
218
+ return x
219
+
220
+
221
+ class SpatialTransformer(nn.Module):
222
+ """
223
+ Transformer block for image-like data.
224
+ First, project the input (aka embedding)
225
+ and reshape to b, t, d.
226
+ Then apply standard transformer action.
227
+ Finally, reshape to image
228
+ """
229
+ def __init__(self, in_channels, n_heads, d_head,
230
+ depth=1, dropout=0., context_dim=None,
231
+ disable_self_attn=False):
232
+ super().__init__()
233
+ self.in_channels = in_channels
234
+ inner_dim = n_heads * d_head
235
+ self.norm = Normalize(in_channels)
236
+
237
+ self.proj_in = nn.Conv2d(in_channels,
238
+ inner_dim,
239
+ kernel_size=1,
240
+ stride=1,
241
+ padding=0)
242
+
243
+ self.transformer_blocks = nn.ModuleList(
244
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
245
+ disable_self_attn=disable_self_attn)
246
+ for d in range(depth)]
247
+ )
248
+
249
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
250
+ in_channels,
251
+ kernel_size=1,
252
+ stride=1,
253
+ padding=0))
254
+
255
+ def forward(self, x, context=None):
256
+ # note: if no context is given, cross-attention defaults to self-attention
257
+ b, c, h, w = x.shape
258
+ x_in = x
259
+ x = self.norm(x)
260
+ x = self.proj_in(x)
261
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
262
+ for block in self.transformer_blocks:
263
+ x = block(x, context=context)
264
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
265
+ x = self.proj_out(x)
266
+ return x + x_in
267
+
268
+
269
+ ##########################
270
+ # transformer no context #
271
+ ##########################
272
+
273
+ class BasicTransformerBlockNoContext(nn.Module):
274
+ def __init__(self, dim, n_heads, d_head, dropout=0., gated_ff=True, checkpoint=True):
275
+ super().__init__()
276
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
277
+ dropout=dropout, context_dim=None)
278
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
279
+ self.attn2 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
280
+ dropout=dropout, context_dim=None)
281
+ self.norm1 = nn.LayerNorm(dim)
282
+ self.norm2 = nn.LayerNorm(dim)
283
+ self.norm3 = nn.LayerNorm(dim)
284
+ self.checkpoint = checkpoint
285
+
286
+ def forward(self, x):
287
+ return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
288
+
289
+ def _forward(self, x):
290
+ x = self.attn1(self.norm1(x)) + x
291
+ x = self.attn2(self.norm2(x)) + x
292
+ x = self.ff(self.norm3(x)) + x
293
+ return x
294
+
295
+ class SpatialTransformerNoContext(nn.Module):
296
+ """
297
+ Transformer block for image-like data.
298
+ First, project the input (aka embedding)
299
+ and reshape to b, t, d.
300
+ Then apply standard transformer action.
301
+ Finally, reshape to image
302
+ """
303
+ def __init__(self, in_channels, n_heads, d_head,
304
+ depth=1, dropout=0.,):
305
+ super().__init__()
306
+ self.in_channels = in_channels
307
+ inner_dim = n_heads * d_head
308
+ self.norm = Normalize(in_channels)
309
+
310
+ self.proj_in = nn.Conv2d(in_channels,
311
+ inner_dim,
312
+ kernel_size=1,
313
+ stride=1,
314
+ padding=0)
315
+
316
+ self.transformer_blocks = nn.ModuleList(
317
+ [BasicTransformerBlockNoContext(inner_dim, n_heads, d_head, dropout=dropout)
318
+ for d in range(depth)]
319
+ )
320
+
321
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
322
+ in_channels,
323
+ kernel_size=1,
324
+ stride=1,
325
+ padding=0))
326
+
327
+ def forward(self, x):
328
+ # note: if no context is given, cross-attention defaults to self-attention
329
+ b, c, h, w = x.shape
330
+ x_in = x
331
+ x = self.norm(x)
332
+ x = self.proj_in(x)
333
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
334
+ for block in self.transformer_blocks:
335
+ x = block(x)
336
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
337
+ x = self.proj_out(x)
338
+ return x + x_in
339
+
340
+
341
+ #######################################
342
+ # Spatial Transformer with Two Branch #
343
+ #######################################
344
+
345
+ class DualSpatialTransformer(nn.Module):
346
+ def __init__(self, in_channels, n_heads, d_head,
347
+ depth=1, dropout=0., context_dim=None,
348
+ disable_self_attn=False):
349
+ super().__init__()
350
+ self.in_channels = in_channels
351
+ inner_dim = n_heads * d_head
352
+
353
+ # First crossattn
354
+ self.norm_0 = Normalize(in_channels)
355
+ self.proj_in_0 = nn.Conv2d(
356
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
357
+ self.transformer_blocks_0 = nn.ModuleList(
358
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
359
+ disable_self_attn=disable_self_attn)
360
+ for d in range(depth)]
361
+ )
362
+ self.proj_out_0 = zero_module(nn.Conv2d(
363
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
364
+
365
+ # Second crossattn
366
+ self.norm_1 = Normalize(in_channels)
367
+ self.proj_in_1 = nn.Conv2d(
368
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
369
+ self.transformer_blocks_1 = nn.ModuleList(
370
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
371
+ disable_self_attn=disable_self_attn)
372
+ for d in range(depth)]
373
+ )
374
+ self.proj_out_1 = zero_module(nn.Conv2d(
375
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
376
+
377
+ def forward(self, x, context=None, which=None):
378
+ # note: if no context is given, cross-attention defaults to self-attention
379
+ b, c, h, w = x.shape
380
+ x_in = x
381
+ if which==0:
382
+ norm, proj_in, blocks, proj_out = \
383
+ self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
384
+ elif which==1:
385
+ norm, proj_in, blocks, proj_out = \
386
+ self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
387
+ else:
388
+ # assert False, 'DualSpatialTransformer forward with a invalid which branch!'
389
+ # import numpy.random as npr
390
+ # rwhich = 0 if npr.rand() < which else 1
391
+ # context = context[rwhich]
392
+ # if rwhich==0:
393
+ # norm, proj_in, blocks, proj_out = \
394
+ # self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
395
+ # elif rwhich==1:
396
+ # norm, proj_in, blocks, proj_out = \
397
+ # self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
398
+
399
+ # import numpy.random as npr
400
+ # rwhich = 0 if npr.rand() < 0.33 else 1
401
+ # if rwhich==0:
402
+ # context = context[rwhich]
403
+ # norm, proj_in, blocks, proj_out = \
404
+ # self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
405
+ # else:
406
+
407
+ norm, proj_in, blocks, proj_out = \
408
+ self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
409
+ x0 = norm(x)
410
+ x0 = proj_in(x0)
411
+ x0 = rearrange(x0, 'b c h w -> b (h w) c').contiguous()
412
+ for block in blocks:
413
+ x0 = block(x0, context=context[0])
414
+ x0 = rearrange(x0, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
415
+ x0 = proj_out(x0)
416
+
417
+ norm, proj_in, blocks, proj_out = \
418
+ self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
419
+ x1 = norm(x)
420
+ x1 = proj_in(x1)
421
+ x1 = rearrange(x1, 'b c h w -> b (h w) c').contiguous()
422
+ for block in blocks:
423
+ x1 = block(x1, context=context[1])
424
+ x1 = rearrange(x1, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
425
+ x1 = proj_out(x1)
426
+ return x0*which + x1*(1-which) + x_in
427
+
428
+ x = norm(x)
429
+ x = proj_in(x)
430
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
431
+ for block in blocks:
432
+ x = block(x, context=context)
433
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
434
+ x = proj_out(x)
435
+ return x + x_in
lib/model_zoo/autokl.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from contextlib import contextmanager
5
+ from lib.model_zoo.common.get_model import get_model, register
6
+
7
+ # from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
8
+
9
+ from .autokl_modules import Encoder, Decoder
10
+ from .distributions import DiagonalGaussianDistribution
11
+
12
+ from .autokl_utils import LPIPSWithDiscriminator
13
+
14
+ @register('autoencoderkl')
15
+ class AutoencoderKL(nn.Module):
16
+ def __init__(self,
17
+ ddconfig,
18
+ lossconfig,
19
+ embed_dim,):
20
+ super().__init__()
21
+ self.encoder = Encoder(**ddconfig)
22
+ self.decoder = Decoder(**ddconfig)
23
+ if lossconfig is not None:
24
+ self.loss = LPIPSWithDiscriminator(**lossconfig)
25
+ assert ddconfig["double_z"]
26
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
27
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
28
+ self.embed_dim = embed_dim
29
+
30
+ @torch.no_grad()
31
+ def encode(self, x, out_posterior=False):
32
+ return self.encode_trainable(x, out_posterior)
33
+
34
+ def encode_trainable(self, x, out_posterior=False):
35
+ x = x*2-1
36
+ h = self.encoder(x)
37
+ moments = self.quant_conv(h)
38
+ posterior = DiagonalGaussianDistribution(moments)
39
+ if out_posterior:
40
+ return posterior
41
+ else:
42
+ return posterior.sample()
43
+
44
+ @torch.no_grad()
45
+ def decode(self, z):
46
+ z = self.post_quant_conv(z)
47
+ dec = self.decoder(z)
48
+ dec = torch.clamp((dec+1)/2, 0, 1)
49
+ return dec
50
+
51
+ def decode_trainable(self, z):
52
+ z = self.post_quant_conv(z)
53
+ dec = self.decoder(z)
54
+ dec = (dec+1)/2
55
+ return dec
56
+
57
+ def apply_model(self, input, sample_posterior=True):
58
+ posterior = self.encode_trainable(input, out_posterior=True)
59
+ if sample_posterior:
60
+ z = posterior.sample()
61
+ else:
62
+ z = posterior.mode()
63
+ dec = self.decode_trainable(z)
64
+ return dec, posterior
65
+
66
+ def get_input(self, batch, k):
67
+ x = batch[k]
68
+ if len(x.shape) == 3:
69
+ x = x[..., None]
70
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
71
+ return x
72
+
73
+ def forward(self, x, optimizer_idx, global_step):
74
+ reconstructions, posterior = self.apply_model(x)
75
+
76
+ if optimizer_idx == 0:
77
+ # train encoder+decoder+logvar
78
+ aeloss, log_dict_ae = self.loss(x, reconstructions, posterior, optimizer_idx, global_step=global_step,
79
+ last_layer=self.get_last_layer(), split="train")
80
+ return aeloss, log_dict_ae
81
+
82
+ if optimizer_idx == 1:
83
+ # train the discriminator
84
+ discloss, log_dict_disc = self.loss(x, reconstructions, posterior, optimizer_idx, global_step=global_step,
85
+ last_layer=self.get_last_layer(), split="train")
86
+
87
+ return discloss, log_dict_disc
88
+
89
+ def validation_step(self, batch, batch_idx):
90
+ inputs = self.get_input(batch, self.image_key)
91
+ reconstructions, posterior = self(inputs)
92
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
93
+ last_layer=self.get_last_layer(), split="val")
94
+
95
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
96
+ last_layer=self.get_last_layer(), split="val")
97
+
98
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
99
+ self.log_dict(log_dict_ae)
100
+ self.log_dict(log_dict_disc)
101
+ return self.log_dict
102
+
103
+ def configure_optimizers(self):
104
+ lr = self.learning_rate
105
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
106
+ list(self.decoder.parameters())+
107
+ list(self.quant_conv.parameters())+
108
+ list(self.post_quant_conv.parameters()),
109
+ lr=lr, betas=(0.5, 0.9))
110
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
111
+ lr=lr, betas=(0.5, 0.9))
112
+ return [opt_ae, opt_disc], []
113
+
114
+ def get_last_layer(self):
115
+ return self.decoder.conv_out.weight
116
+
117
+ @torch.no_grad()
118
+ def log_images(self, batch, only_inputs=False, **kwargs):
119
+ log = dict()
120
+ x = self.get_input(batch, self.image_key)
121
+ x = x.to(self.device)
122
+ if not only_inputs:
123
+ xrec, posterior = self(x)
124
+ if x.shape[1] > 3:
125
+ # colorize with random projection
126
+ assert xrec.shape[1] > 3
127
+ x = self.to_rgb(x)
128
+ xrec = self.to_rgb(xrec)
129
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
130
+ log["reconstructions"] = xrec
131
+ log["inputs"] = x
132
+ return log
133
+
134
+ def to_rgb(self, x):
135
+ assert self.image_key == "segmentation"
136
+ if not hasattr(self, "colorize"):
137
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
138
+ x = F.conv2d(x, weight=self.colorize)
139
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
140
+ return x
lib/model_zoo/autokl_modules.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+
8
+ # from .diffusion_utils import instantiate_from_config
9
+ from .attention import LinearAttention
10
+
11
+
12
+ def get_timestep_embedding(timesteps, embedding_dim):
13
+ """
14
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
15
+ From Fairseq.
16
+ Build sinusoidal embeddings.
17
+ This matches the implementation in tensor2tensor, but differs slightly
18
+ from the description in Section 3.5 of "Attention Is All You Need".
19
+ """
20
+ assert len(timesteps.shape) == 1
21
+
22
+ half_dim = embedding_dim // 2
23
+ emb = math.log(10000) / (half_dim - 1)
24
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
25
+ emb = emb.to(device=timesteps.device)
26
+ emb = timesteps.float()[:, None] * emb[None, :]
27
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
28
+ if embedding_dim % 2 == 1: # zero pad
29
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
30
+ return emb
31
+
32
+
33
+ def nonlinearity(x):
34
+ # swish
35
+ return x*torch.sigmoid(x)
36
+
37
+
38
+ def Normalize(in_channels, num_groups=32):
39
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
40
+
41
+
42
+ class Upsample(nn.Module):
43
+ def __init__(self, in_channels, with_conv):
44
+ super().__init__()
45
+ self.with_conv = with_conv
46
+ if self.with_conv:
47
+ self.conv = torch.nn.Conv2d(in_channels,
48
+ in_channels,
49
+ kernel_size=3,
50
+ stride=1,
51
+ padding=1)
52
+
53
+ def forward(self, x):
54
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
55
+ if self.with_conv:
56
+ x = self.conv(x)
57
+ return x
58
+
59
+
60
+ class Downsample(nn.Module):
61
+ def __init__(self, in_channels, with_conv):
62
+ super().__init__()
63
+ self.with_conv = with_conv
64
+ if self.with_conv:
65
+ # no asymmetric padding in torch conv, must do it ourselves
66
+ self.conv = torch.nn.Conv2d(in_channels,
67
+ in_channels,
68
+ kernel_size=3,
69
+ stride=2,
70
+ padding=0)
71
+
72
+ def forward(self, x):
73
+ if self.with_conv:
74
+ pad = (0,1,0,1)
75
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
76
+ x = self.conv(x)
77
+ else:
78
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
79
+ return x
80
+
81
+
82
+ class ResnetBlock(nn.Module):
83
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
84
+ dropout, temb_channels=512):
85
+ super().__init__()
86
+ self.in_channels = in_channels
87
+ out_channels = in_channels if out_channels is None else out_channels
88
+ self.out_channels = out_channels
89
+ self.use_conv_shortcut = conv_shortcut
90
+
91
+ self.norm1 = Normalize(in_channels)
92
+ self.conv1 = torch.nn.Conv2d(in_channels,
93
+ out_channels,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1)
97
+ if temb_channels > 0:
98
+ self.temb_proj = torch.nn.Linear(temb_channels,
99
+ out_channels)
100
+ self.norm2 = Normalize(out_channels)
101
+ self.dropout = torch.nn.Dropout(dropout)
102
+ self.conv2 = torch.nn.Conv2d(out_channels,
103
+ out_channels,
104
+ kernel_size=3,
105
+ stride=1,
106
+ padding=1)
107
+ if self.in_channels != self.out_channels:
108
+ if self.use_conv_shortcut:
109
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
110
+ out_channels,
111
+ kernel_size=3,
112
+ stride=1,
113
+ padding=1)
114
+ else:
115
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
116
+ out_channels,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0)
120
+
121
+ def forward(self, x, temb):
122
+ h = x
123
+ h = self.norm1(h)
124
+ h = nonlinearity(h)
125
+ h = self.conv1(h)
126
+
127
+ if temb is not None:
128
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
129
+
130
+ h = self.norm2(h)
131
+ h = nonlinearity(h)
132
+ h = self.dropout(h)
133
+ h = self.conv2(h)
134
+
135
+ if self.in_channels != self.out_channels:
136
+ if self.use_conv_shortcut:
137
+ x = self.conv_shortcut(x)
138
+ else:
139
+ x = self.nin_shortcut(x)
140
+
141
+ return x+h
142
+
143
+
144
+ class LinAttnBlock(LinearAttention):
145
+ """to match AttnBlock usage"""
146
+ def __init__(self, in_channels):
147
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
148
+
149
+
150
+ class AttnBlock(nn.Module):
151
+ def __init__(self, in_channels):
152
+ super().__init__()
153
+ self.in_channels = in_channels
154
+
155
+ self.norm = Normalize(in_channels)
156
+ self.q = torch.nn.Conv2d(in_channels,
157
+ in_channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0)
161
+ self.k = torch.nn.Conv2d(in_channels,
162
+ in_channels,
163
+ kernel_size=1,
164
+ stride=1,
165
+ padding=0)
166
+ self.v = torch.nn.Conv2d(in_channels,
167
+ in_channels,
168
+ kernel_size=1,
169
+ stride=1,
170
+ padding=0)
171
+ self.proj_out = torch.nn.Conv2d(in_channels,
172
+ in_channels,
173
+ kernel_size=1,
174
+ stride=1,
175
+ padding=0)
176
+
177
+
178
+ def forward(self, x):
179
+ h_ = x
180
+ h_ = self.norm(h_)
181
+ q = self.q(h_)
182
+ k = self.k(h_)
183
+ v = self.v(h_)
184
+
185
+ # compute attention
186
+ b,c,h,w = q.shape
187
+ q = q.reshape(b,c,h*w)
188
+ q = q.permute(0,2,1) # b,hw,c
189
+ k = k.reshape(b,c,h*w) # b,c,hw
190
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
191
+ w_ = w_ * (int(c)**(-0.5))
192
+ w_ = torch.nn.functional.softmax(w_, dim=2)
193
+
194
+ # attend to values
195
+ v = v.reshape(b,c,h*w)
196
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
197
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
198
+ h_ = h_.reshape(b,c,h,w)
199
+
200
+ h_ = self.proj_out(h_)
201
+
202
+ return x+h_
203
+
204
+
205
+ def make_attn(in_channels, attn_type="vanilla"):
206
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
207
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
208
+ if attn_type == "vanilla":
209
+ return AttnBlock(in_channels)
210
+ elif attn_type == "none":
211
+ return nn.Identity(in_channels)
212
+ else:
213
+ return LinAttnBlock(in_channels)
214
+
215
+
216
+ class Model(nn.Module):
217
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
218
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
219
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
220
+ super().__init__()
221
+ if use_linear_attn: attn_type = "linear"
222
+ self.ch = ch
223
+ self.temb_ch = self.ch*4
224
+ self.num_resolutions = len(ch_mult)
225
+ self.num_res_blocks = num_res_blocks
226
+ self.resolution = resolution
227
+ self.in_channels = in_channels
228
+
229
+ self.use_timestep = use_timestep
230
+ if self.use_timestep:
231
+ # timestep embedding
232
+ self.temb = nn.Module()
233
+ self.temb.dense = nn.ModuleList([
234
+ torch.nn.Linear(self.ch,
235
+ self.temb_ch),
236
+ torch.nn.Linear(self.temb_ch,
237
+ self.temb_ch),
238
+ ])
239
+
240
+ # downsampling
241
+ self.conv_in = torch.nn.Conv2d(in_channels,
242
+ self.ch,
243
+ kernel_size=3,
244
+ stride=1,
245
+ padding=1)
246
+
247
+ curr_res = resolution
248
+ in_ch_mult = (1,)+tuple(ch_mult)
249
+ self.down = nn.ModuleList()
250
+ for i_level in range(self.num_resolutions):
251
+ block = nn.ModuleList()
252
+ attn = nn.ModuleList()
253
+ block_in = ch*in_ch_mult[i_level]
254
+ block_out = ch*ch_mult[i_level]
255
+ for i_block in range(self.num_res_blocks):
256
+ block.append(ResnetBlock(in_channels=block_in,
257
+ out_channels=block_out,
258
+ temb_channels=self.temb_ch,
259
+ dropout=dropout))
260
+ block_in = block_out
261
+ if curr_res in attn_resolutions:
262
+ attn.append(make_attn(block_in, attn_type=attn_type))
263
+ down = nn.Module()
264
+ down.block = block
265
+ down.attn = attn
266
+ if i_level != self.num_resolutions-1:
267
+ down.downsample = Downsample(block_in, resamp_with_conv)
268
+ curr_res = curr_res // 2
269
+ self.down.append(down)
270
+
271
+ # middle
272
+ self.mid = nn.Module()
273
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
274
+ out_channels=block_in,
275
+ temb_channels=self.temb_ch,
276
+ dropout=dropout)
277
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
278
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
279
+ out_channels=block_in,
280
+ temb_channels=self.temb_ch,
281
+ dropout=dropout)
282
+
283
+ # upsampling
284
+ self.up = nn.ModuleList()
285
+ for i_level in reversed(range(self.num_resolutions)):
286
+ block = nn.ModuleList()
287
+ attn = nn.ModuleList()
288
+ block_out = ch*ch_mult[i_level]
289
+ skip_in = ch*ch_mult[i_level]
290
+ for i_block in range(self.num_res_blocks+1):
291
+ if i_block == self.num_res_blocks:
292
+ skip_in = ch*in_ch_mult[i_level]
293
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
294
+ out_channels=block_out,
295
+ temb_channels=self.temb_ch,
296
+ dropout=dropout))
297
+ block_in = block_out
298
+ if curr_res in attn_resolutions:
299
+ attn.append(make_attn(block_in, attn_type=attn_type))
300
+ up = nn.Module()
301
+ up.block = block
302
+ up.attn = attn
303
+ if i_level != 0:
304
+ up.upsample = Upsample(block_in, resamp_with_conv)
305
+ curr_res = curr_res * 2
306
+ self.up.insert(0, up) # prepend to get consistent order
307
+
308
+ # end
309
+ self.norm_out = Normalize(block_in)
310
+ self.conv_out = torch.nn.Conv2d(block_in,
311
+ out_ch,
312
+ kernel_size=3,
313
+ stride=1,
314
+ padding=1)
315
+
316
+ def forward(self, x, t=None, context=None):
317
+ #assert x.shape[2] == x.shape[3] == self.resolution
318
+ if context is not None:
319
+ # assume aligned context, cat along channel axis
320
+ x = torch.cat((x, context), dim=1)
321
+ if self.use_timestep:
322
+ # timestep embedding
323
+ assert t is not None
324
+ temb = get_timestep_embedding(t, self.ch)
325
+ temb = self.temb.dense[0](temb)
326
+ temb = nonlinearity(temb)
327
+ temb = self.temb.dense[1](temb)
328
+ else:
329
+ temb = None
330
+
331
+ # downsampling
332
+ hs = [self.conv_in(x)]
333
+ for i_level in range(self.num_resolutions):
334
+ for i_block in range(self.num_res_blocks):
335
+ h = self.down[i_level].block[i_block](hs[-1], temb)
336
+ if len(self.down[i_level].attn) > 0:
337
+ h = self.down[i_level].attn[i_block](h)
338
+ hs.append(h)
339
+ if i_level != self.num_resolutions-1:
340
+ hs.append(self.down[i_level].downsample(hs[-1]))
341
+
342
+ # middle
343
+ h = hs[-1]
344
+ h = self.mid.block_1(h, temb)
345
+ h = self.mid.attn_1(h)
346
+ h = self.mid.block_2(h, temb)
347
+
348
+ # upsampling
349
+ for i_level in reversed(range(self.num_resolutions)):
350
+ for i_block in range(self.num_res_blocks+1):
351
+ h = self.up[i_level].block[i_block](
352
+ torch.cat([h, hs.pop()], dim=1), temb)
353
+ if len(self.up[i_level].attn) > 0:
354
+ h = self.up[i_level].attn[i_block](h)
355
+ if i_level != 0:
356
+ h = self.up[i_level].upsample(h)
357
+
358
+ # end
359
+ h = self.norm_out(h)
360
+ h = nonlinearity(h)
361
+ h = self.conv_out(h)
362
+ return h
363
+
364
+ def get_last_layer(self):
365
+ return self.conv_out.weight
366
+
367
+
368
+ class Encoder(nn.Module):
369
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
370
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
371
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
372
+ **ignore_kwargs):
373
+ super().__init__()
374
+ if use_linear_attn: attn_type = "linear"
375
+ self.ch = ch
376
+ self.temb_ch = 0
377
+ self.num_resolutions = len(ch_mult)
378
+ self.num_res_blocks = num_res_blocks
379
+ self.resolution = resolution
380
+ self.in_channels = in_channels
381
+
382
+ # downsampling
383
+ self.conv_in = torch.nn.Conv2d(in_channels,
384
+ self.ch,
385
+ kernel_size=3,
386
+ stride=1,
387
+ padding=1)
388
+
389
+ curr_res = resolution
390
+ in_ch_mult = (1,)+tuple(ch_mult)
391
+ self.in_ch_mult = in_ch_mult
392
+ self.down = nn.ModuleList()
393
+ for i_level in range(self.num_resolutions):
394
+ block = nn.ModuleList()
395
+ attn = nn.ModuleList()
396
+ block_in = ch*in_ch_mult[i_level]
397
+ block_out = ch*ch_mult[i_level]
398
+ for i_block in range(self.num_res_blocks):
399
+ block.append(ResnetBlock(in_channels=block_in,
400
+ out_channels=block_out,
401
+ temb_channels=self.temb_ch,
402
+ dropout=dropout))
403
+ block_in = block_out
404
+ if curr_res in attn_resolutions:
405
+ attn.append(make_attn(block_in, attn_type=attn_type))
406
+ down = nn.Module()
407
+ down.block = block
408
+ down.attn = attn
409
+ if i_level != self.num_resolutions-1:
410
+ down.downsample = Downsample(block_in, resamp_with_conv)
411
+ curr_res = curr_res // 2
412
+ self.down.append(down)
413
+
414
+ # middle
415
+ self.mid = nn.Module()
416
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
417
+ out_channels=block_in,
418
+ temb_channels=self.temb_ch,
419
+ dropout=dropout)
420
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
421
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
422
+ out_channels=block_in,
423
+ temb_channels=self.temb_ch,
424
+ dropout=dropout)
425
+
426
+ # end
427
+ self.norm_out = Normalize(block_in)
428
+ self.conv_out = torch.nn.Conv2d(block_in,
429
+ 2*z_channels if double_z else z_channels,
430
+ kernel_size=3,
431
+ stride=1,
432
+ padding=1)
433
+
434
+ def forward(self, x):
435
+ # timestep embedding
436
+ temb = None
437
+
438
+ # downsampling
439
+ hs = [self.conv_in(x)]
440
+ for i_level in range(self.num_resolutions):
441
+ for i_block in range(self.num_res_blocks):
442
+ h = self.down[i_level].block[i_block](hs[-1], temb)
443
+ if len(self.down[i_level].attn) > 0:
444
+ h = self.down[i_level].attn[i_block](h)
445
+ hs.append(h)
446
+ if i_level != self.num_resolutions-1:
447
+ hs.append(self.down[i_level].downsample(hs[-1]))
448
+
449
+ # middle
450
+ h = hs[-1]
451
+ h = self.mid.block_1(h, temb)
452
+ h = self.mid.attn_1(h)
453
+ h = self.mid.block_2(h, temb)
454
+
455
+ # end
456
+ h = self.norm_out(h)
457
+ h = nonlinearity(h)
458
+ h = self.conv_out(h)
459
+ return h
460
+
461
+
462
+ class Decoder(nn.Module):
463
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
464
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
465
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
466
+ attn_type="vanilla", **ignorekwargs):
467
+ super().__init__()
468
+ if use_linear_attn: attn_type = "linear"
469
+ self.ch = ch
470
+ self.temb_ch = 0
471
+ self.num_resolutions = len(ch_mult)
472
+ self.num_res_blocks = num_res_blocks
473
+ self.resolution = resolution
474
+ self.in_channels = in_channels
475
+ self.give_pre_end = give_pre_end
476
+ self.tanh_out = tanh_out
477
+
478
+ # compute in_ch_mult, block_in and curr_res at lowest res
479
+ in_ch_mult = (1,)+tuple(ch_mult)
480
+ block_in = ch*ch_mult[self.num_resolutions-1]
481
+ curr_res = resolution // 2**(self.num_resolutions-1)
482
+ self.z_shape = (1,z_channels,curr_res,curr_res)
483
+ print("Working with z of shape {} = {} dimensions.".format(
484
+ self.z_shape, np.prod(self.z_shape)))
485
+
486
+ # z to block_in
487
+ self.conv_in = torch.nn.Conv2d(z_channels,
488
+ block_in,
489
+ kernel_size=3,
490
+ stride=1,
491
+ padding=1)
492
+
493
+ # middle
494
+ self.mid = nn.Module()
495
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
496
+ out_channels=block_in,
497
+ temb_channels=self.temb_ch,
498
+ dropout=dropout)
499
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
500
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
501
+ out_channels=block_in,
502
+ temb_channels=self.temb_ch,
503
+ dropout=dropout)
504
+
505
+ # upsampling
506
+ self.up = nn.ModuleList()
507
+ for i_level in reversed(range(self.num_resolutions)):
508
+ block = nn.ModuleList()
509
+ attn = nn.ModuleList()
510
+ block_out = ch*ch_mult[i_level]
511
+ for i_block in range(self.num_res_blocks+1):
512
+ block.append(ResnetBlock(in_channels=block_in,
513
+ out_channels=block_out,
514
+ temb_channels=self.temb_ch,
515
+ dropout=dropout))
516
+ block_in = block_out
517
+ if curr_res in attn_resolutions:
518
+ attn.append(make_attn(block_in, attn_type=attn_type))
519
+ up = nn.Module()
520
+ up.block = block
521
+ up.attn = attn
522
+ if i_level != 0:
523
+ up.upsample = Upsample(block_in, resamp_with_conv)
524
+ curr_res = curr_res * 2
525
+ self.up.insert(0, up) # prepend to get consistent order
526
+
527
+ # end
528
+ self.norm_out = Normalize(block_in)
529
+ self.conv_out = torch.nn.Conv2d(block_in,
530
+ out_ch,
531
+ kernel_size=3,
532
+ stride=1,
533
+ padding=1)
534
+
535
+ def forward(self, z):
536
+ #assert z.shape[1:] == self.z_shape[1:]
537
+ self.last_z_shape = z.shape
538
+
539
+ # timestep embedding
540
+ temb = None
541
+
542
+ # z to block_in
543
+ h = self.conv_in(z)
544
+
545
+ # middle
546
+ h = self.mid.block_1(h, temb)
547
+ h = self.mid.attn_1(h)
548
+ h = self.mid.block_2(h, temb)
549
+
550
+ # upsampling
551
+ for i_level in reversed(range(self.num_resolutions)):
552
+ for i_block in range(self.num_res_blocks+1):
553
+ h = self.up[i_level].block[i_block](h, temb)
554
+ if len(self.up[i_level].attn) > 0:
555
+ h = self.up[i_level].attn[i_block](h)
556
+ if i_level != 0:
557
+ h = self.up[i_level].upsample(h)
558
+
559
+ # end
560
+ if self.give_pre_end:
561
+ return h
562
+
563
+ h = self.norm_out(h)
564
+ h = nonlinearity(h)
565
+ h = self.conv_out(h)
566
+ if self.tanh_out:
567
+ h = torch.tanh(h)
568
+ return h
569
+
570
+
571
+ class SimpleDecoder(nn.Module):
572
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
573
+ super().__init__()
574
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
575
+ ResnetBlock(in_channels=in_channels,
576
+ out_channels=2 * in_channels,
577
+ temb_channels=0, dropout=0.0),
578
+ ResnetBlock(in_channels=2 * in_channels,
579
+ out_channels=4 * in_channels,
580
+ temb_channels=0, dropout=0.0),
581
+ ResnetBlock(in_channels=4 * in_channels,
582
+ out_channels=2 * in_channels,
583
+ temb_channels=0, dropout=0.0),
584
+ nn.Conv2d(2*in_channels, in_channels, 1),
585
+ Upsample(in_channels, with_conv=True)])
586
+ # end
587
+ self.norm_out = Normalize(in_channels)
588
+ self.conv_out = torch.nn.Conv2d(in_channels,
589
+ out_channels,
590
+ kernel_size=3,
591
+ stride=1,
592
+ padding=1)
593
+
594
+ def forward(self, x):
595
+ for i, layer in enumerate(self.model):
596
+ if i in [1,2,3]:
597
+ x = layer(x, None)
598
+ else:
599
+ x = layer(x)
600
+
601
+ h = self.norm_out(x)
602
+ h = nonlinearity(h)
603
+ x = self.conv_out(h)
604
+ return x
605
+
606
+
607
+ class UpsampleDecoder(nn.Module):
608
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
609
+ ch_mult=(2,2), dropout=0.0):
610
+ super().__init__()
611
+ # upsampling
612
+ self.temb_ch = 0
613
+ self.num_resolutions = len(ch_mult)
614
+ self.num_res_blocks = num_res_blocks
615
+ block_in = in_channels
616
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
617
+ self.res_blocks = nn.ModuleList()
618
+ self.upsample_blocks = nn.ModuleList()
619
+ for i_level in range(self.num_resolutions):
620
+ res_block = []
621
+ block_out = ch * ch_mult[i_level]
622
+ for i_block in range(self.num_res_blocks + 1):
623
+ res_block.append(ResnetBlock(in_channels=block_in,
624
+ out_channels=block_out,
625
+ temb_channels=self.temb_ch,
626
+ dropout=dropout))
627
+ block_in = block_out
628
+ self.res_blocks.append(nn.ModuleList(res_block))
629
+ if i_level != self.num_resolutions - 1:
630
+ self.upsample_blocks.append(Upsample(block_in, True))
631
+ curr_res = curr_res * 2
632
+
633
+ # end
634
+ self.norm_out = Normalize(block_in)
635
+ self.conv_out = torch.nn.Conv2d(block_in,
636
+ out_channels,
637
+ kernel_size=3,
638
+ stride=1,
639
+ padding=1)
640
+
641
+ def forward(self, x):
642
+ # upsampling
643
+ h = x
644
+ for k, i_level in enumerate(range(self.num_resolutions)):
645
+ for i_block in range(self.num_res_blocks + 1):
646
+ h = self.res_blocks[i_level][i_block](h, None)
647
+ if i_level != self.num_resolutions - 1:
648
+ h = self.upsample_blocks[k](h)
649
+ h = self.norm_out(h)
650
+ h = nonlinearity(h)
651
+ h = self.conv_out(h)
652
+ return h
653
+
654
+
655
+ class LatentRescaler(nn.Module):
656
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
657
+ super().__init__()
658
+ # residual block, interpolate, residual block
659
+ self.factor = factor
660
+ self.conv_in = nn.Conv2d(in_channels,
661
+ mid_channels,
662
+ kernel_size=3,
663
+ stride=1,
664
+ padding=1)
665
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
666
+ out_channels=mid_channels,
667
+ temb_channels=0,
668
+ dropout=0.0) for _ in range(depth)])
669
+ self.attn = AttnBlock(mid_channels)
670
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
671
+ out_channels=mid_channels,
672
+ temb_channels=0,
673
+ dropout=0.0) for _ in range(depth)])
674
+
675
+ self.conv_out = nn.Conv2d(mid_channels,
676
+ out_channels,
677
+ kernel_size=1,
678
+ )
679
+
680
+ def forward(self, x):
681
+ x = self.conv_in(x)
682
+ for block in self.res_block1:
683
+ x = block(x, None)
684
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
685
+ x = self.attn(x)
686
+ for block in self.res_block2:
687
+ x = block(x, None)
688
+ x = self.conv_out(x)
689
+ return x
690
+
691
+
692
+ class MergedRescaleEncoder(nn.Module):
693
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
694
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
695
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
696
+ super().__init__()
697
+ intermediate_chn = ch * ch_mult[-1]
698
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
699
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
700
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
701
+ out_ch=None)
702
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
703
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
704
+
705
+ def forward(self, x):
706
+ x = self.encoder(x)
707
+ x = self.rescaler(x)
708
+ return x
709
+
710
+
711
+ class MergedRescaleDecoder(nn.Module):
712
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
713
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
714
+ super().__init__()
715
+ tmp_chn = z_channels*ch_mult[-1]
716
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
717
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
718
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
719
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
720
+ out_channels=tmp_chn, depth=rescale_module_depth)
721
+
722
+ def forward(self, x):
723
+ x = self.rescaler(x)
724
+ x = self.decoder(x)
725
+ return x
726
+
727
+
728
+ class Upsampler(nn.Module):
729
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
730
+ super().__init__()
731
+ assert out_size >= in_size
732
+ num_blocks = int(np.log2(out_size//in_size))+1
733
+ factor_up = 1.+ (out_size % in_size)
734
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
735
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
736
+ out_channels=in_channels)
737
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
738
+ attn_resolutions=[], in_channels=None, ch=in_channels,
739
+ ch_mult=[ch_mult for _ in range(num_blocks)])
740
+
741
+ def forward(self, x):
742
+ x = self.rescaler(x)
743
+ x = self.decoder(x)
744
+ return x
745
+
746
+
747
+ class Resize(nn.Module):
748
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
749
+ super().__init__()
750
+ self.with_conv = learned
751
+ self.mode = mode
752
+ if self.with_conv:
753
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
754
+ raise NotImplementedError()
755
+ assert in_channels is not None
756
+ # no asymmetric padding in torch conv, must do it ourselves
757
+ self.conv = torch.nn.Conv2d(in_channels,
758
+ in_channels,
759
+ kernel_size=4,
760
+ stride=2,
761
+ padding=1)
762
+
763
+ def forward(self, x, scale_factor=1.0):
764
+ if scale_factor==1.0:
765
+ return x
766
+ else:
767
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
768
+ return x
769
+
770
+ class FirstStagePostProcessor(nn.Module):
771
+
772
+ def __init__(self, ch_mult:list, in_channels,
773
+ pretrained_model:nn.Module=None,
774
+ reshape=False,
775
+ n_channels=None,
776
+ dropout=0.,
777
+ pretrained_config=None):
778
+ super().__init__()
779
+ if pretrained_config is None:
780
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
781
+ self.pretrained_model = pretrained_model
782
+ else:
783
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
784
+ self.instantiate_pretrained(pretrained_config)
785
+
786
+ self.do_reshape = reshape
787
+
788
+ if n_channels is None:
789
+ n_channels = self.pretrained_model.encoder.ch
790
+
791
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
792
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
793
+ stride=1,padding=1)
794
+
795
+ blocks = []
796
+ downs = []
797
+ ch_in = n_channels
798
+ for m in ch_mult:
799
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
800
+ ch_in = m * n_channels
801
+ downs.append(Downsample(ch_in, with_conv=False))
802
+
803
+ self.model = nn.ModuleList(blocks)
804
+ self.downsampler = nn.ModuleList(downs)
805
+
806
+
807
+ def instantiate_pretrained(self, config):
808
+ model = instantiate_from_config(config)
809
+ self.pretrained_model = model.eval()
810
+ # self.pretrained_model.train = False
811
+ for param in self.pretrained_model.parameters():
812
+ param.requires_grad = False
813
+
814
+
815
+ @torch.no_grad()
816
+ def encode_with_pretrained(self,x):
817
+ c = self.pretrained_model.encode(x)
818
+ if isinstance(c, DiagonalGaussianDistribution):
819
+ c = c.mode()
820
+ return c
821
+
822
+ def forward(self,x):
823
+ z_fs = self.encode_with_pretrained(x)
824
+ z = self.proj_norm(z_fs)
825
+ z = self.proj(z)
826
+ z = nonlinearity(z)
827
+
828
+ for submodel, downmodel in zip(self.model,self.downsampler):
829
+ z = submodel(z,temb=None)
830
+ z = downmodel(z)
831
+
832
+ if self.do_reshape:
833
+ z = rearrange(z,'b c h w -> b (h w) c')
834
+ return z
835
+
lib/model_zoo/autokl_utils.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import functools
4
+
5
+ class ActNorm(nn.Module):
6
+ def __init__(self, num_features, logdet=False, affine=True,
7
+ allow_reverse_init=False):
8
+ assert affine
9
+ super().__init__()
10
+ self.logdet = logdet
11
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
12
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
13
+ self.allow_reverse_init = allow_reverse_init
14
+
15
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
16
+
17
+ def initialize(self, input):
18
+ with torch.no_grad():
19
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
20
+ mean = (
21
+ flatten.mean(1)
22
+ .unsqueeze(1)
23
+ .unsqueeze(2)
24
+ .unsqueeze(3)
25
+ .permute(1, 0, 2, 3)
26
+ )
27
+ std = (
28
+ flatten.std(1)
29
+ .unsqueeze(1)
30
+ .unsqueeze(2)
31
+ .unsqueeze(3)
32
+ .permute(1, 0, 2, 3)
33
+ )
34
+
35
+ self.loc.data.copy_(-mean)
36
+ self.scale.data.copy_(1 / (std + 1e-6))
37
+
38
+ def forward(self, input, reverse=False):
39
+ if reverse:
40
+ return self.reverse(input)
41
+ if len(input.shape) == 2:
42
+ input = input[:,:,None,None]
43
+ squeeze = True
44
+ else:
45
+ squeeze = False
46
+
47
+ _, _, height, width = input.shape
48
+
49
+ if self.training and self.initialized.item() == 0:
50
+ self.initialize(input)
51
+ self.initialized.fill_(1)
52
+
53
+ h = self.scale * (input + self.loc)
54
+
55
+ if squeeze:
56
+ h = h.squeeze(-1).squeeze(-1)
57
+
58
+ if self.logdet:
59
+ log_abs = torch.log(torch.abs(self.scale))
60
+ logdet = height*width*torch.sum(log_abs)
61
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
62
+ return h, logdet
63
+
64
+ return h
65
+
66
+ def reverse(self, output):
67
+ if self.training and self.initialized.item() == 0:
68
+ if not self.allow_reverse_init:
69
+ raise RuntimeError(
70
+ "Initializing ActNorm in reverse direction is "
71
+ "disabled by default. Use allow_reverse_init=True to enable."
72
+ )
73
+ else:
74
+ self.initialize(output)
75
+ self.initialized.fill_(1)
76
+
77
+ if len(output.shape) == 2:
78
+ output = output[:,:,None,None]
79
+ squeeze = True
80
+ else:
81
+ squeeze = False
82
+
83
+ h = output / self.scale - self.loc
84
+
85
+ if squeeze:
86
+ h = h.squeeze(-1).squeeze(-1)
87
+ return h
88
+
89
+ #################
90
+ # Discriminator #
91
+ #################
92
+
93
+ def weights_init(m):
94
+ classname = m.__class__.__name__
95
+ if classname.find('Conv') != -1:
96
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
97
+ elif classname.find('BatchNorm') != -1:
98
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
99
+ nn.init.constant_(m.bias.data, 0)
100
+
101
+ class NLayerDiscriminator(nn.Module):
102
+ """Defines a PatchGAN discriminator as in Pix2Pix
103
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
104
+ """
105
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
106
+ """Construct a PatchGAN discriminator
107
+ Parameters:
108
+ input_nc (int) -- the number of channels in input images
109
+ ndf (int) -- the number of filters in the last conv layer
110
+ n_layers (int) -- the number of conv layers in the discriminator
111
+ norm_layer -- normalization layer
112
+ """
113
+ super(NLayerDiscriminator, self).__init__()
114
+ if not use_actnorm:
115
+ norm_layer = nn.BatchNorm2d
116
+ else:
117
+ norm_layer = ActNorm
118
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
119
+ use_bias = norm_layer.func != nn.BatchNorm2d
120
+ else:
121
+ use_bias = norm_layer != nn.BatchNorm2d
122
+
123
+ kw = 4
124
+ padw = 1
125
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
126
+ nf_mult = 1
127
+ nf_mult_prev = 1
128
+ for n in range(1, n_layers): # gradually increase the number of filters
129
+ nf_mult_prev = nf_mult
130
+ nf_mult = min(2 ** n, 8)
131
+ sequence += [
132
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
133
+ norm_layer(ndf * nf_mult),
134
+ nn.LeakyReLU(0.2, True)
135
+ ]
136
+
137
+ nf_mult_prev = nf_mult
138
+ nf_mult = min(2 ** n_layers, 8)
139
+ sequence += [
140
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
141
+ norm_layer(ndf * nf_mult),
142
+ nn.LeakyReLU(0.2, True)
143
+ ]
144
+
145
+ sequence += [
146
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
147
+ self.main = nn.Sequential(*sequence)
148
+
149
+ def forward(self, input):
150
+ """Standard forward."""
151
+ return self.main(input)
152
+
153
+ #########
154
+ # LPIPS #
155
+ #########
156
+
157
+ class ScalingLayer(nn.Module):
158
+ def __init__(self):
159
+ super(ScalingLayer, self).__init__()
160
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
161
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
162
+
163
+ def forward(self, inp):
164
+ return (inp - self.shift) / self.scale
165
+
166
+ class NetLinLayer(nn.Module):
167
+ """ A single linear layer which does a 1x1 conv """
168
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
169
+ super(NetLinLayer, self).__init__()
170
+ layers = [nn.Dropout(), ] if (use_dropout) else []
171
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
172
+ self.model = nn.Sequential(*layers)
173
+
174
+ from collections import namedtuple
175
+ from torchvision import models
176
+ from torchvision.models import VGG16_Weights
177
+
178
+ class vgg16(torch.nn.Module):
179
+ def __init__(self, requires_grad=False, pretrained=True):
180
+ super(vgg16, self).__init__()
181
+ if pretrained:
182
+ vgg_pretrained_features = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
183
+ self.slice1 = torch.nn.Sequential()
184
+ self.slice2 = torch.nn.Sequential()
185
+ self.slice3 = torch.nn.Sequential()
186
+ self.slice4 = torch.nn.Sequential()
187
+ self.slice5 = torch.nn.Sequential()
188
+ self.N_slices = 5
189
+ for x in range(4):
190
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
191
+ for x in range(4, 9):
192
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
193
+ for x in range(9, 16):
194
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
195
+ for x in range(16, 23):
196
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
197
+ for x in range(23, 30):
198
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
199
+ if not requires_grad:
200
+ for param in self.parameters():
201
+ param.requires_grad = False
202
+
203
+ def forward(self, X):
204
+ h = self.slice1(X)
205
+ h_relu1_2 = h
206
+ h = self.slice2(h)
207
+ h_relu2_2 = h
208
+ h = self.slice3(h)
209
+ h_relu3_3 = h
210
+ h = self.slice4(h)
211
+ h_relu4_3 = h
212
+ h = self.slice5(h)
213
+ h_relu5_3 = h
214
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
215
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
216
+ return out
217
+
218
+ def normalize_tensor(x,eps=1e-10):
219
+ norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
220
+ return x/(norm_factor+eps)
221
+
222
+ def spatial_average(x, keepdim=True):
223
+ return x.mean([2,3],keepdim=keepdim)
224
+
225
+ def get_ckpt_path(*args, **kwargs):
226
+ return 'pretrained/lpips.pth'
227
+
228
+ class LPIPS(nn.Module):
229
+ # Learned perceptual metric
230
+ def __init__(self, use_dropout=True):
231
+ super().__init__()
232
+ self.scaling_layer = ScalingLayer()
233
+ self.chns = [64, 128, 256, 512, 512] # vg16 features
234
+ self.net = vgg16(pretrained=True, requires_grad=False)
235
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
236
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
237
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
238
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
239
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
240
+ self.load_from_pretrained()
241
+ for param in self.parameters():
242
+ param.requires_grad = False
243
+
244
+ def load_from_pretrained(self, name="vgg_lpips"):
245
+ ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
246
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
247
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
248
+
249
+ @classmethod
250
+ def from_pretrained(cls, name="vgg_lpips"):
251
+ if name != "vgg_lpips":
252
+ raise NotImplementedError
253
+ model = cls()
254
+ ckpt = get_ckpt_path(name)
255
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
256
+ return model
257
+
258
+ def forward(self, input, target):
259
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
260
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
261
+ feats0, feats1, diffs = {}, {}, {}
262
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
263
+ for kk in range(len(self.chns)):
264
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
265
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
266
+
267
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
268
+ val = res[0]
269
+ for l in range(1, len(self.chns)):
270
+ val += res[l]
271
+ return val
272
+
273
+ ############
274
+ # The loss #
275
+ ############
276
+
277
+ def adopt_weight(weight, global_step, threshold=0, value=0.):
278
+ if global_step < threshold:
279
+ weight = value
280
+ return weight
281
+
282
+ def hinge_d_loss(logits_real, logits_fake):
283
+ loss_real = torch.mean(F.relu(1. - logits_real))
284
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
285
+ d_loss = 0.5 * (loss_real + loss_fake)
286
+ return d_loss
287
+
288
+ def vanilla_d_loss(logits_real, logits_fake):
289
+ d_loss = 0.5 * (
290
+ torch.mean(torch.nn.functional.softplus(-logits_real)) +
291
+ torch.mean(torch.nn.functional.softplus(logits_fake)))
292
+ return d_loss
293
+
294
+ class LPIPSWithDiscriminator(nn.Module):
295
+ def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
296
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
297
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
298
+ disc_loss="hinge"):
299
+
300
+ super().__init__()
301
+ assert disc_loss in ["hinge", "vanilla"]
302
+ self.kl_weight = kl_weight
303
+ self.pixel_weight = pixelloss_weight
304
+ self.perceptual_loss = LPIPS().eval()
305
+ self.perceptual_weight = perceptual_weight
306
+ # output log variance
307
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
308
+
309
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
310
+ n_layers=disc_num_layers,
311
+ use_actnorm=use_actnorm
312
+ ).apply(weights_init)
313
+ self.discriminator_iter_start = disc_start
314
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
315
+ self.disc_factor = disc_factor
316
+ self.discriminator_weight = disc_weight
317
+ self.disc_conditional = disc_conditional
318
+
319
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
320
+ if last_layer is not None:
321
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
322
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
323
+ else:
324
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
325
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
326
+
327
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
328
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
329
+ d_weight = d_weight * self.discriminator_weight
330
+ return d_weight
331
+
332
+ def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
333
+ global_step, last_layer=None, cond=None, split="train",
334
+ weights=None):
335
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
336
+ if self.perceptual_weight > 0:
337
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
338
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
339
+
340
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
341
+ weighted_nll_loss = nll_loss
342
+ if weights is not None:
343
+ weighted_nll_loss = weights*nll_loss
344
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
345
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
346
+ kl_loss = posteriors.kl()
347
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
348
+
349
+ # now the GAN part
350
+ if optimizer_idx == 0:
351
+ # generator update
352
+ if cond is None:
353
+ assert not self.disc_conditional
354
+ logits_fake = self.discriminator(reconstructions.contiguous())
355
+ else:
356
+ assert self.disc_conditional
357
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
358
+ g_loss = -torch.mean(logits_fake)
359
+
360
+ if self.disc_factor > 0.0:
361
+ try:
362
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
363
+ except RuntimeError:
364
+ assert not self.training
365
+ d_weight = torch.tensor(0.0)
366
+ else:
367
+ d_weight = torch.tensor(0.0)
368
+
369
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
370
+ loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
371
+
372
+ log = {"Loss": loss.clone().detach().mean(),
373
+ "logvar": self.logvar.detach(),
374
+ "loss_kl": kl_loss.detach().mean(),
375
+ "loss_nll": nll_loss.detach().mean(),
376
+ "loss_rec": rec_loss.detach().mean(),
377
+ "d_weight": d_weight.detach(),
378
+ "disc_factor": torch.tensor(disc_factor),
379
+ "loss_g": g_loss.detach().mean(),
380
+ }
381
+ return loss, log
382
+
383
+ if optimizer_idx == 1:
384
+ # second pass for discriminator update
385
+ if cond is None:
386
+ logits_real = self.discriminator(inputs.contiguous().detach())
387
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
388
+ else:
389
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
390
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
391
+
392
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
393
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
394
+
395
+ log = {"Loss": d_loss.clone().detach().mean(),
396
+ "loss_disc": d_loss.clone().detach().mean(),
397
+ "logits_real": logits_real.detach().mean(),
398
+ "logits_fake": logits_fake.detach().mean()
399
+ }
400
+ return d_loss, log