RobotChao commited on
Commit
64ba4c2
·
verified ·
1 Parent(s): 118a39f
Files changed (1) hide show
  1. app.py +1114 -0
app.py ADDED
@@ -0,0 +1,1114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ # Copyright (C) 2023 Xingqian Xu - All Rights Reserved #
3
+ # #
4
+ # Please visit Versatile Diffusion's arXiv paper for more details, link at #
5
+ # arxiv.org/abs/2211.08332 #
6
+ # #
7
+ # Besides, this work is also inspired by many established techniques including:#
8
+ # Denoising Diffusion Probabilistic Model; Denoising Diffusion Implicit Model; #
9
+ # Latent Diffusion Model; Stable Diffusion; Stable Diffusion - Img2Img; Stable #
10
+ # Diffusion - Variation; ImageMixer; DreamBooth; Stable Diffusion - Lora; More #
11
+ # Control for Free; Prompt-to-Prompt; #
12
+ # #
13
+ ################################################################################
14
+
15
+ import gradio as gr
16
+ import os
17
+ import PIL
18
+ from PIL import Image
19
+ from pathlib import Path
20
+ import numpy as np
21
+ import numpy.random as npr
22
+ from contextlib import nullcontext
23
+ import types
24
+
25
+ import torch
26
+ import torchvision.transforms as tvtrans
27
+ from lib.cfg_helper import model_cfg_bank
28
+ from lib.model_zoo import get_model
29
+ from cusomized_gradio_blocks import create_myexamples, customized_as_example, customized_postprocess
30
+
31
# Number of images produced per request; stored on vd_inference as
# self.n_sample_image and used as the grid width of the result galleries.
n_sample_image = 2
# Number of text samples produced per request; stored on vd_inference as
# self.n_sample_text.
n_sample_text = 4
# Passed as ``cache_examples`` to every gr.Examples(...) call in this file.
cache_examples = True
34
+
35
+ from lib.model_zoo.ddim import DDIMSampler
36
+
37
+ ##########
38
+ # helper #
39
+ ##########
40
+
41
def highlight_print(info):
    """Print ``info`` framed by rows of ``#``, with a blank line before and after.

    Args:
        info: The string to display; the frame is ``len(info) + 4`` wide.
    """
    border = '#' * (len(info) + 4)
    print('')
    print(border)
    print('# ' + info + ' #')
    print(border)
    print('')
47
+
48
def decompose(x, q=20, niter=100):
    """Split a batched matrix into mean, rank-``q`` factors and a residual.

    The mean over the last axis is removed, ``torch.pca_lowrank`` is run on
    the centred tensor (with ``center=False``), and whatever the rank-``q``
    product ``u @ diag(s) @ v^T`` does not reproduce is returned as the
    residual, so ``u diag(s) v^T + x_mean + x_remain`` equals ``x``.

    Args:
        x: 3-D tensor (the code uses ``torch.bmm``, so a batch axis is required).
        q: Rank requested from ``torch.pca_lowrank``.
        niter: Iteration count forwarded to ``torch.pca_lowrank``.

    Returns:
        Tuple ``(u, s, v, x_mean, x_remain)``.
    """
    x_mean = x.mean(-1, keepdim=True)
    centered = x - x_mean
    u, s, v = torch.pca_lowrank(centered, q=q, center=False, niter=niter)
    # One diagonal matrix per batch entry, built from that entry's singular values.
    s_diag = torch.diag_embed(s)
    v_t = v.transpose(1, 2)
    x_lowrank = torch.bmm(torch.bmm(u, s_diag), v_t)
    return u, s, v, x_mean, centered - x_lowrank
56
+
57
class adjust_rank(object):
    """Rescale selected singular values of a batched feature matrix.

    Calling the object with a tensor ``x`` and a level ``lvl`` in ``[0, 1]``:

    * ``lvl == 0.5`` returns ``x`` untouched.
    * ``lvl < 0.5`` scales singular values ``0 .. max_semantic_drop_rank``
      by factors in ``[0, 1]`` (all zero at ``lvl == 0``).
    * ``lvl > 0.5`` scales singular values ``max_style_drop_rank .. q-1`` by
      factors in ``[0, 1]`` and discards the residual outside the rank-``q``
      approximation.

    The result is rescaled so each batch entry keeps the standard deviation
    it had on input.

    Args:
        max_drop_rank: Pair ``(max_semantic_drop_rank, max_style_drop_rank)``.
        q: Rank of the low-rank decomposition.
    """

    def __init__(self, max_drop_rank=(1, 5), q=20):
        # A tuple default avoids sharing one mutable list between instances;
        # callers may still pass a list.
        self.max_semantic_drop_rank = max_drop_rank[0]
        self.max_style_drop_rank = max_drop_rank[1]
        self.q = q

        # Semantic side: y0 runs from -max_semantic_drop_rank at lvl=0 up to 1 at lvl=0.5.
        t0, y00 = np.exp((0 - 0.5) * 2), -self.max_semantic_drop_rank
        t1, y01 = np.exp((0.5 - 0.5) * 2), 1
        self.t2y0_semf = self._exp_ramp(t0, y00, t1, y01)
        self.x2y_semf = self._linear_ramp(0, self.max_semantic_drop_rank + 1, 1)

        # Style side: y0 runs from -(q - max_style_drop_rank) at lvl=1 up to 1 at lvl=0.5.
        t0, y00 = np.exp((1 - 0.5) * 2), -(q - self.max_style_drop_rank)
        t1, y01 = np.exp((0.5 - 0.5) * 2), 1
        self.t2y0_styf = self._exp_ramp(t0, y00, t1, y01)
        self.x2y_styf = self._linear_ramp(q - 1, self.max_style_drop_rank - 1, 1)

    @staticmethod
    def _exp_ramp(t0, y00, t1, y01):
        """Return f(t) mapping exp((t-0.5)*2) linearly from (t0, y00) to (t1, y01)."""
        return lambda t: (np.exp((t - 0.5) * 2) - t0) / (t1 - t0) * (y01 - y00) + y00

    @staticmethod
    def _linear_ramp(x0, x1, y1):
        """Return f(x, y0): the line through (x0, y0) and (x1, y1), evaluated at x."""
        return lambda x, y0: (x - x0) / (x1 - x0) * (y1 - y0) + y0

    def __call__(self, x, lvl):
        """Apply the rank adjustment.

        Args:
            x: 3-D tensor accepted by ``decompose``.
            lvl: Focus level in ``[0, 1]``; ``0.5`` is a no-op.

        Returns:
            Tensor shaped like ``x`` in the input dtype (``x`` itself when
            ``lvl == 0.5``).

        Raises:
            AssertionError: If ``lvl`` is below 0 or above 1.
        """
        if lvl == 0.5:
            return x

        # The decomposition runs in float32; half-precision inputs are cast
        # up here and cast back before returning.
        orig_dtype = x.dtype
        low_precision = orig_dtype in (torch.float16, torch.bfloat16)
        if low_precision:
            x = x.float()

        # keepdim=True keeps the statistic at (B, 1, 1) so it broadcasts per
        # batch entry; a (B,) shape would line up with the last axis of x and
        # give wrong results (or fail) for batch sizes other than 1.
        std_save = x.std(dim=(-2, -1), keepdim=True)

        u, s, v, x_mean, x_remain = decompose(x, q=self.q)

        if lvl < 0.5:
            assert lvl >= 0
            y0 = self.t2y0_semf(lvl)  # independent of the rank index
            for xi in range(0, self.max_semantic_drop_rank + 1):
                yi = self.x2y_semf(xi, y0)
                yi = 0 if yi < 0 else yi
                s[:, xi] *= yi

        elif lvl > 0.5:
            assert lvl <= 1
            y0 = self.t2y0_styf(lvl)  # independent of the rank index
            for xi in range(self.max_style_drop_rank, self.q):
                yi = self.x2y_styf(xi, y0)
                yi = 0 if yi < 0 else yi
                s[:, xi] *= yi
            # On the style side the part outside the rank-q approximation is dropped.
            x_remain = 0

        x_lowrank = torch.bmm(torch.bmm(u, torch.diag_embed(s)), v.transpose(1, 2))
        x_new = x_lowrank + x_mean + x_remain

        std_new = x_new.std(dim=(-2, -1), keepdim=True)
        x_new = x_new / std_new * std_save

        if low_precision:
            x_new = x_new.to(orig_dtype)

        return x_new
128
+
129
def remove_duplicate_word(tx):
    """Collapse immediately repeated words and phrases in ``tx``.

    The text is split on single spaces; leading ``([{`` and trailing
    ``?!.,:;}])`` characters become separate tokens joined to their word by
    a ``<puncnext>`` marker. Adjacent duplicate n-grams are then removed for
    n = 1, 2, 3, ... until the whole text is a single n-gram, and the markers
    are stripped again.

    Args:
        tx: Input string.

    Returns:
        The de-duplicated string. ``''`` is returned for empty input and for
        input consisting only of spaces.
    """
    marker = '<puncnext>'

    def split_and_puncsplit(text):
        """Tokenise ``text`` into words, punctuation and glue markers."""
        tokens = []
        for word in text.split(" "):
            if not word:
                # Consecutive, leading or trailing spaces produce empty
                # pieces; indexing word[0] on them would raise IndexError.
                continue
            while word and word[0] in '([{':
                tokens.extend([word[0], marker])
                word = word[1:]
            trailing = []
            while word and word[-1] in '?!.,:;}])':
                trailing = [marker, word[-1]] + trailing
                word = word[:-1]
            if word:
                tokens.append(word)
            tokens.extend(trailing)
        return tokens

    def remove_duplicates(grams, length):
        """Delete the second copy of any n-gram immediately repeated."""
        while True:
            for i in range(len(grams) - length):
                if grams[i] == grams[i + length]:
                    # Entries i+1 .. i+length are the overlapping n-grams
                    # that make up the repeated copy.
                    del grams[i + 1:i + length + 1]
                    break  # list length changed: rescan from the start
            else:
                return grams

    def combine_words(grams, length):
        """Extend every n-gram by the last word of its right neighbour."""
        combined = [
            grams[i] + " " + grams[i + 1].split(" ")[length - 1]
            for i in range(len(grams) - 1)
        ]
        return combined, length + 1

    if tx == '':
        return tx

    splitted_input = split_and_puncsplit(tx)
    if not splitted_input:
        # Whitespace-only input leaves no tokens to work with.
        return ''

    word_length = 1
    while len(splitted_input) > 1:
        splitted_input = remove_duplicates(splitted_input, word_length)
        if len(splitted_input) > 1:
            splitted_input, word_length = combine_words(splitted_input, word_length)

    return splitted_input[0].replace(' <puncnext> ', '')
202
+
203
def get_instruction(mode):
    """Return the one-line description shown above the tab for ``mode``.

    Args:
        mode: One of "Text-to-Image", "Image-Variation", "Image-to-Text",
            "Text-Variation", "Dual-Context", "Triple-Context" or
            "Multi-Context".

    Returns:
        The instruction lines for that mode joined with newlines.

    Raises:
        AssertionError: If ``mode`` is not one of the names above.
    """
    instruction_table = (
        ("Text-to-Image", ["Generate image from text prompt."]),
        ("Image-Variation", ["Generate image conditioned on reference image."]),
        ("Image-to-Text", ["Generate text from reference image. "]),
        ("Text-Variation", ["Generate text from reference text prompt. "]),
        ("Dual-Context", ["Generate image conditioned on both text and image."]),
        ("Triple-Context", ["Generate image conditioned on text and up to two images."]),
        ("Multi-Context", ["Generate image from multiple contexts."]),
    )
    for known_mode, lines in instruction_table:
        if mode == known_mode:
            return '\n'.join(lines)
    assert False
228
+
229
+ ########
230
+ # main #
231
+ ########
232
class vd_dummy(object):
    """Model-free stand-in exposing the same inference method names.

    Every method accepts any arguments and does nothing; ``inference_mcg``
    returns a ``(None, None)`` pair, the others return ``None``.
    """

    def __init__(self, *args, **kwarg):
        self.which = 'Vdummy'

    def inference_t2i(self, *args, **kwarg):
        return None

    def inference_i2i(self, *args, **kwarg):
        return None

    def inference_i2t(self, *args, **kwarg):
        return None

    def inference_t2t(self, *args, **kwarg):
        return None

    def inference_dcg(self, *args, **kwarg):
        return None

    def inference_tcg(self, *args, **kwarg):
        return None

    def inference_mcg(self, *args, **kwarg):
        # Two outputs: one per gallery wired to this call.
        return (None, None)
243
+
244
class vd_inference(object):
    """Loads the Versatile Diffusion network and exposes one method per demo tab.

    Each ``inference_*`` method encodes its context(s) with
    ``self.net.ctx_encode``, seeds NumPy with ``seed`` and torch with
    ``seed + 100``, runs the DDIM sampler, and decodes the result with
    ``self.net.vae_decode``.
    """

    def __init__(self, fp16=False, which='v2.0'):
        """Build the network, load weights and set sampling defaults.

        Args:
            fp16: When true, the network is converted with ``.half()`` and
                ``self.dtype`` is ``torch.float16``; otherwise ``torch.float32``.
            which: Model version. Only ``'v1.0'`` is handled below.

        Raises:
            AssertionError: For any ``which`` other than ``'v1.0'``.
                NOTE(review): this includes the default ``'v2.0'``.
        """
        highlight_print(which)
        self.which = which

        if self.which == 'v1.0':
            cfgm = model_cfg_bank()('vd_four_flow_v1-0')
        else:
            assert False, 'Model type not supported'
        net = get_model()(cfgm)

        if fp16:
            highlight_print('Running in FP16')
            if self.which == 'v1.0':
                # Flag both context encoders before converting the network.
                net.ctx['text'].fp16 = True
                net.ctx['image'].fp16 = True
            net = net.half()
            self.dtype = torch.float16
        else:
            self.dtype = torch.float32

        if self.which == 'v1.0':
            # Earlier local-file loading, kept for reference:
            # if fp16:
            #     sd = torch.load('pretrained/vd-four-flow-v1-0-fp16.pth', map_location='cpu')
            # else:
            #     sd = torch.load('pretrained/vd-four-flow-v1-0.pth', map_location='cpu')
            from huggingface_hub import hf_hub_download
            # The checkpoint is fetched from the Hugging Face Hub; the file
            # name depends on the precision.
            if fp16:
                temppath = hf_hub_download('shi-labs/versatile-diffusion-model', 'pretrained_pth/vd-four-flow-v1-0-fp16.pth')
            else:
                temppath = hf_hub_download('shi-labs/versatile-diffusion-model', 'pretrained_pth/vd-four-flow-v1-0.pth')
            # NOTE(review): torch.load unpickles the downloaded file.
            sd = torch.load(temppath, map_location='cpu')

        # strict=False: key mismatches between checkpoint and model are tolerated.
        net.load_state_dict(sd, strict=False)

        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            net.to('cuda')
        self.net = net
        self.sampler = DDIMSampler(net)

        # Sampling defaults shared by all inference methods.
        self.output_dim = [512, 512]
        self.n_sample_image = n_sample_image
        self.n_sample_text = n_sample_text
        self.ddim_steps = 50
        self.ddim_eta = 0.0
        self.scale_textto = 7.5
        self.image_latent_dim = 4
        self.text_latent_dim = 768
        self.text_temperature = 1

        if which == 'v1.0':
            self.adjust_rank_f = adjust_rank(max_drop_rank=[1, 5], q=20)
            self.scale_imgto = 7.5
            self.disentanglement_noglobal = True

    def inference_t2i(self, text, seed):
        """Generate ``self.n_sample_image`` images from a text prompt.

        Args:
            text: Prompt string.
            seed: Integer seed (NumPy gets ``seed``, torch ``seed + 100``).

        Returns:
            List of PIL images.
        """
        n_samples = self.n_sample_image
        scale = self.scale_textto
        sampler = self.sampler
        h, w = self.output_dim
        # u: encoding of the empty prompt, used as the unconditional context.
        u = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1)
        c = self.net.ctx_encode([text], which='text').repeat(n_samples, 1, 1)
        shape = [n_samples, self.image_latent_dim, h//8, w//8]
        np.random.seed(seed)
        torch.manual_seed(seed + 100)
        x, _ = sampler.sample(
            steps=self.ddim_steps,
            x_info={'type':'image'},
            c_info={'type':'text', 'conditioning':c, 'unconditional_conditioning':u,
                    'unconditional_guidance_scale':scale},
            shape=shape,
            verbose=False,
            eta=self.ddim_eta)
        im = self.net.vae_decode(x, which='image')
        im = [tvtrans.ToPILImage()(i) for i in im]
        return im

    def inference_i2i(self, im, fid_lvl, fcs_lvl, clr_adj, seed):
        """Generate image variations of a reference image.

        Args:
            im: Reference PIL image; resized to ``self.output_dim``.
            fid_lvl: Fidelity in [0, 1]. ``1`` returns copies of the resized
                input; any other non-zero value starts sampling from the
                encoded input with ``int(ddim_steps * (1 - fid_lvl))``
                forward timesteps; ``0`` samples without an ``x0``.
            fcs_lvl: Focus level passed to ``self.adjust_rank_f``.
            clr_adj: ``'Simple'`` matches each output's per-channel mean/std
                to the input's; any other value leaves outputs unchanged.
            seed: Integer seed.

        Returns:
            List of PIL images.
        """
        n_samples = self.n_sample_image
        scale = self.scale_imgto
        sampler = self.sampler
        h, w = self.output_dim
        device = self.net.device

        BICUBIC = PIL.Image.Resampling.BICUBIC
        im = im.resize([w, h], resample=BICUBIC)

        if fid_lvl == 1:
            return [im]*n_samples

        cx = tvtrans.ToTensor()(im)[None].to(device).to(self.dtype)

        c = self.net.ctx_encode(cx, which='image')
        if self.disentanglement_noglobal:
            # Only the entries after index 0 along dim 1 are rank-adjusted;
            # the first entry is passed through unchanged.
            c_glb = c[:, 0:1]
            c_loc = c[:, 1: ]
            c_loc = self.adjust_rank_f(c_loc, fcs_lvl)
            c = torch.cat([c_glb, c_loc], dim=1).repeat(n_samples, 1, 1)
        else:
            c = self.adjust_rank_f(c, fcs_lvl).repeat(n_samples, 1, 1)
        # Unconditional context here is all zeros rather than an encoded blank.
        u = torch.zeros_like(c)

        shape = [n_samples, self.image_latent_dim, h//8, w//8]
        np.random.seed(seed)
        torch.manual_seed(seed + 100)
        if fid_lvl!=0:
            x0 = self.net.vae_encode(cx, which='image').repeat(n_samples, 1, 1, 1)
            step = int(self.ddim_steps * (1-fid_lvl))
            x, _ = sampler.sample(
                steps=self.ddim_steps,
                x_info={'type':'image', 'x0':x0, 'x0_forward_timesteps':step},
                c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u,
                        'unconditional_guidance_scale':scale},
                shape=shape,
                verbose=False,
                eta=self.ddim_eta)
        else:
            x, _ = sampler.sample(
                steps=self.ddim_steps,
                x_info={'type':'image',},
                c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u,
                        'unconditional_guidance_scale':scale},
                shape=shape,
                verbose=False,
                eta=self.ddim_eta)

        imout = self.net.vae_decode(x, which='image')

        if clr_adj == 'Simple':
            # Normalise each output per row of a (3, -1) view, then apply the
            # input's statistics and clamp to [0, 1].
            # NOTE(review): view(3, -1) assumes a single 3-channel image.
            cx_mean = cx.view(3, -1).mean(-1)[:, None, None]
            cx_std = cx.view(3, -1).std(-1)[:, None, None]
            imout_mean = [imouti.view(3, -1).mean(-1)[:, None, None] for imouti in imout]
            imout_std = [imouti.view(3, -1).std(-1)[:, None, None] for imouti in imout]
            imout = [(ii-mi)/si*cx_std+cx_mean for ii, mi, si in zip(imout, imout_mean, imout_std)]
            imout = [torch.clamp(ii, 0, 1) for ii in imout]

        imout = [tvtrans.ToPILImage()(i) for i in imout]
        return imout

    def inference_i2t(self, im, seed):
        """Generate ``self.n_sample_text`` captions for a reference image.

        Args:
            im: Reference PIL image; resized to ``self.output_dim``.
            seed: Integer seed.

        Returns:
            The de-duplicated captions joined with newlines.
        """
        n_samples = self.n_sample_text
        scale = self.scale_imgto
        sampler = self.sampler
        h, w = self.output_dim
        device = self.net.device

        BICUBIC = PIL.Image.Resampling.BICUBIC
        im = im.resize([w, h], resample=BICUBIC)

        # NOTE(review): unlike the other image paths, cx is not cast to
        # self.dtype here — confirm this is intended when running in fp16.
        cx = tvtrans.ToTensor()(im)[None].to(device)
        c = self.net.ctx_encode(cx, which='image').repeat(n_samples, 1, 1)
        # Unconditional context: encoding of an all-zero image.
        u = self.net.ctx_encode(torch.zeros_like(cx), which='image').repeat(n_samples, 1, 1)

        shape = [n_samples, self.text_latent_dim]
        np.random.seed(seed)
        torch.manual_seed(seed + 100)
        x, _ = sampler.sample(
            steps=self.ddim_steps,
            x_info={'type':'text',},
            c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u,
                    'unconditional_guidance_scale':scale},
            shape=shape,
            verbose=False,
            eta=self.ddim_eta)
        tx = self.net.vae_decode(x, which='text', temperature=self.text_temperature)
        tx = [remove_duplicate_word(txi) for txi in tx]
        tx_combined = '\n'.join(tx)
        return tx_combined

    def inference_t2t(self, text, seed):
        """Generate ``self.n_sample_text`` text variations of a prompt.

        Args:
            text: Prompt string.
            seed: Integer seed.

        Returns:
            The de-duplicated texts joined with newlines.
        """
        n_samples = self.n_sample_text
        scale = self.scale_textto
        sampler = self.sampler
        u = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1)
        c = self.net.ctx_encode([text], which='text').repeat(n_samples, 1, 1)
        shape = [n_samples, self.text_latent_dim]
        np.random.seed(seed)
        torch.manual_seed(seed + 100)
        x, _ = sampler.sample(
            steps=self.ddim_steps,
            x_info={'type':'text',},
            c_info={'type':'text', 'conditioning':c, 'unconditional_conditioning':u,
                    'unconditional_guidance_scale':scale},
            shape=shape,
            verbose=False,
            eta=self.ddim_eta)
        tx = self.net.vae_decode(x, which='text', temperature=self.text_temperature)
        tx = [remove_duplicate_word(txi) for txi in tx]
        tx_combined = '\n'.join(tx)
        return tx_combined

    def inference_dcg(self, imctx, fcs_lvl, textctx, textstrength, seed):
        """Generate images conditioned on one image and an optional prompt.

        Args:
            imctx: Reference PIL image.
            fcs_lvl: Focus level passed to ``self.adjust_rank_f``.
            textctx: Prompt; ignored when ``None``, empty, or when
                ``textstrength`` is 0.
            textstrength: Weight of the text context in [0, 1]; the image
                context gets ``1 - textstrength``.
            seed: Integer seed.

        Returns:
            List of PIL images.
        """
        n_samples = self.n_sample_image
        sampler = self.sampler
        h, w = self.output_dim
        device = self.net.device

        c_info_list = []

        if (textctx is not None) and (textctx != "") and (textstrength != 0):
            ut = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1)
            ct = self.net.ctx_encode([textctx], which='text').repeat(n_samples, 1, 1)
            # Guidance scale is interpolated between the image and text scales.
            scale = self.scale_imgto*(1-textstrength) + self.scale_textto*textstrength

            c_info_list.append({
                'type':'text',
                'conditioning':ct,
                'unconditional_conditioning':ut,
                'unconditional_guidance_scale':scale,
                'ratio': textstrength, })
        else:
            scale = self.scale_imgto
            textstrength = 0

        BICUBIC = PIL.Image.Resampling.BICUBIC
        cx = imctx.resize([w, h], resample=BICUBIC)
        cx = tvtrans.ToTensor()(cx)[None].to(device).to(self.dtype)
        ci = self.net.ctx_encode(cx, which='image')

        if self.disentanglement_noglobal:
            # As in inference_i2i: the first entry along dim 1 bypasses the
            # rank adjustment.
            ci_glb = ci[:, 0:1]
            ci_loc = ci[:, 1: ]
            ci_loc = self.adjust_rank_f(ci_loc, fcs_lvl)
            ci = torch.cat([ci_glb, ci_loc], dim=1).repeat(n_samples, 1, 1)
        else:
            ci = self.adjust_rank_f(ci, fcs_lvl).repeat(n_samples, 1, 1)

        c_info_list.append({
            'type':'image',
            'conditioning':ci,
            'unconditional_conditioning':torch.zeros_like(ci),
            'unconditional_guidance_scale':scale,
            'ratio': (1-textstrength), })

        shape = [n_samples, self.image_latent_dim, h//8, w//8]
        np.random.seed(seed)
        torch.manual_seed(seed + 100)
        x, _ = sampler.sample_multicontext(
            steps=self.ddim_steps,
            x_info={'type':'image',},
            c_info_list=c_info_list,
            shape=shape,
            verbose=False,
            eta=self.ddim_eta)

        imout = self.net.vae_decode(x, which='image')
        imout = [tvtrans.ToPILImage()(i) for i in imout]
        return imout

    def inference_tcg(self, *args):
        """Two-image variant of ``inference_mcg``.

        The first 10 positional arguments are two image groups of 5 values
        each; 10 ``None`` values are appended so the call matches the
        four-group layout ``inference_mcg`` unpacks. The remaining arguments
        are forwarded unchanged.

        Returns:
            The ``(input_save, imout)`` pair from ``inference_mcg``.
        """
        args_imag = list(args[0:10]) + [None, None, None, None, None]*2
        args_rest = args[10:]
        imin, imout = self.inference_mcg(*args_imag, *args_rest)
        return imin, imout

    def inference_mcg(self, *args):
        """Generate images from up to four image contexts plus optional text.

        Args:
            *args: 20 values forming four groups of
                ``(im, imm, strength, fcs_lvl, use_mask)``, followed by
                ``textctx, textstrength, seed``.

        Returns:
            Tuple ``(input_save, imout)``: the (possibly masked) inputs as PIL
            images, and the generated PIL images.
        """
        imctx = [args[0:5], args[5:10], args[10:15], args[15:20]]
        textctx, textstrength, seed = args[20:]

        n_samples = self.n_sample_image
        sampler = self.sampler
        h, w = self.output_dim
        device = self.net.device

        c_info_list = []

        if (textctx is not None) and (textctx != "") and (textstrength != 0):
            ut = self.net.ctx_encode([""], which='text').repeat(n_samples, 1, 1)
            ct = self.net.ctx_encode([textctx], which='text').repeat(n_samples, 1, 1)
            scale = self.scale_imgto*(1-textstrength) + self.scale_textto*textstrength

            c_info_list.append({
                'type':'text',
                'conditioning':ct,
                'unconditional_conditioning':ut,
                'unconditional_guidance_scale':scale,
                'ratio': textstrength, })
        else:
            scale = self.scale_imgto
            textstrength = 0

        input_save = []
        imc = []
        for im, imm, strength, fcs_lvl, use_mask in imctx:
            # A slot is skipped only when both the plain and the masked image
            # are missing.
            if (im is None) and (imm is None):
                continue
            BILINEAR = PIL.Image.Resampling.BILINEAR
            BICUBIC = PIL.Image.Resampling.BICUBIC
            if use_mask:
                cx = imm['image'].resize([w, h], resample=BICUBIC)
                cx = tvtrans.ToTensor()(cx)[None].to(self.dtype).to(device)
                m = imm['mask'].resize([w, h], resample=BILINEAR)
                m = tvtrans.ToTensor()(m)[None, 0:1].to(self.dtype).to(device)
                # Invert so that painted regions become 0 in the mask.
                m = (1-m)
                cx_show = cx*m
                ci = self.net.ctx_encode(cx, which='image', masks=m)
            else:
                cx = im.resize([w, h], resample=BICUBIC)
                cx = tvtrans.ToTensor()(cx)[None].to(self.dtype).to(device)
                ci = self.net.ctx_encode(cx, which='image')
                cx_show = cx

            input_save.append(tvtrans.ToPILImage()(cx_show[0]))

            if self.disentanglement_noglobal:
                ci_glb = ci[:, 0:1]
                ci_loc = ci[:, 1: ]
                ci_loc = self.adjust_rank_f(ci_loc, fcs_lvl)
                ci = torch.cat([ci_glb, ci_loc], dim=1).repeat(n_samples, 1, 1)
            else:
                ci = self.adjust_rank_f(ci, fcs_lvl).repeat(n_samples, 1, 1)
            # Per-image weight is applied by scaling the encoded context.
            imc.append(ci * strength)

        # All image contexts are concatenated along dim 1 into one context.
        # NOTE(review): imc is empty when every slot was skipped — confirm
        # callers always supply at least one image.
        cis = torch.cat(imc, dim=1)
        c_info_list.append({
            'type':'image',
            'conditioning':cis,
            'unconditional_conditioning':torch.zeros_like(cis),
            'unconditional_guidance_scale':scale,
            'ratio': (1-textstrength), })

        shape = [n_samples, self.image_latent_dim, h//8, w//8]
        np.random.seed(seed)
        torch.manual_seed(seed + 100)
        x, _ = sampler.sample_multicontext(
            steps=self.ddim_steps,
            x_info={'type':'image',},
            c_info_list=c_info_list,
            shape=shape,
            verbose=False,
            eta=self.ddim_eta)

        imout = self.net.vae_decode(x, which='image')
        imout = [tvtrans.ToPILImage()(i) for i in imout]
        return input_save, imout
580
+
581
# Swap in the line below to build the UI without loading the model:
# vd_inference = vd_dummy()
# NOTE: this rebinds the module-level name ``vd_inference`` from the class to
# an instance of it, so the class is no longer reachable by name afterwards.
vd_inference = vd_inference(which='v1.0', fp16=True)
583
+
584
+ #################
585
+ # sub interface #
586
+ #################
587
+
588
def t2i_interface(with_example=False):
    """Build the Text-to-Image tab inside the active Gradio Blocks context.

    Args:
        with_example: When true, an Examples section wired to the same
            inference function is appended below the controls.
    """
    description = get_instruction("Text-to-Image")
    gr.HTML(f'<p id=myinst>&nbsp Description: {description}</p>')

    with gr.Row():
        with gr.Column():
            text = gr.Textbox(lines=4, placeholder="Input prompt...", label='Text Input')
            seed = gr.Number(20, label="Seed", precision=0)
            button = gr.Button("Run")
        with gr.Column():
            gallery = gr.Gallery(label="Image Result", elem_id='customized_imbox')
            img_output = gallery.style(grid=n_sample_image)

    run_fn = vd_inference.inference_t2i
    in_components = [text, seed]
    out_components = [img_output]
    button.click(run_fn, inputs=in_components, outputs=out_components)

    if not with_example:
        return

    gr.Examples(
        label='Examples',
        examples=get_example('Text-to-Image'),
        fn=run_fn,
        inputs=in_components,
        outputs=out_components,
        cache_examples=cache_examples)
611
+
612
def i2i_interface(with_example=False):
    """Build the Image-Variation tab inside the active Gradio Blocks context.

    The fidelity, focus and colour-adjustment controls (and their help text)
    start hidden and are toggled by the "Show Detail Controls" checkbox.

    Args:
        with_example: When true, an Examples section wired to the same
            inference function is appended below the controls.
    """
    description = get_instruction("Image-Variation")
    gr.HTML(f'<p id=myinst>&nbsp Description: {description}</p>')

    help_html = (
        '<p id=myinst>&nbsp Fidelity: How likely the output image looks like the referece image (0-dislike (default), 1-same).</p>'
        '<p id=myinst>&nbsp Focus: What the output image should focused on (0-semantic, 0.5-balanced (default), 1-style).</p>'
    )

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
            sim_flag = gr.Checkbox(label='Show Detail Controls')
            with gr.Row():
                fid_lvl = gr.Slider(label="Fidelity (Dislike -- Same)", minimum=0, maximum=1, value=0, step=0.02, visible=False)
                fcs_lvl = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02, visible=False)
            clr_adj = gr.Radio(label="Color Adjustment", choices=["None", "Simple"], value='Simple', visible=False)
            explain = gr.HTML(help_html, visible=False)
            seed = gr.Number(20, label="Seed", precision=0)
            button = gr.Button("Run")
        with gr.Column():
            gallery = gr.Gallery(label="Image Result", elem_id='customized_imbox')
            img_output = gallery.style(grid=n_sample_image)

    def toggle_detail_controls(show):
        # Returns a partial update: only the four detail widgets change.
        return {
            explain: gr.update(visible=show),
            fid_lvl: gr.update(visible=show),
            fcs_lvl: gr.update(visible=show),
            clr_adj: gr.update(visible=show),
        }

    sim_flag.change(
        fn=toggle_detail_controls,
        inputs=sim_flag,
        outputs=[explain, fid_lvl, fcs_lvl, clr_adj, seed])

    run_fn = vd_inference.inference_i2i
    in_components = [img_input, fid_lvl, fcs_lvl, clr_adj, seed]
    out_components = [img_output]
    button.click(run_fn, inputs=in_components, outputs=out_components)

    if not with_example:
        return

    gr.Examples(
        label='Examples',
        examples=get_example('Image-Variation'),
        fn=run_fn,
        inputs=in_components,
        outputs=out_components,
        cache_examples=cache_examples)
652
+
653
def i2t_interface(with_example=False):
    """Build the Image-to-Text tab inside the active Gradio Blocks context.

    Args:
        with_example: When true, an Examples section wired to the same
            inference function is appended below the controls.
    """
    description = get_instruction("Image-to-Text")
    gr.HTML(f'<p id=myinst>&nbsp Description: {description}</p>')

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
            seed = gr.Number(20, label="Seed", precision=0)
            button = gr.Button("Run")
        with gr.Column():
            txt_output = gr.Textbox(lines=4, label='Text Result')

    run_fn = vd_inference.inference_i2t
    in_components = [img_input, seed]
    out_components = [txt_output]
    button.click(run_fn, inputs=in_components, outputs=out_components)

    if not with_example:
        return

    gr.Examples(
        label='Examples',
        examples=get_example('Image-to-Text'),
        fn=run_fn,
        inputs=in_components,
        outputs=out_components,
        cache_examples=cache_examples)
676
+
677
def t2t_interface(with_example=False):
    """Build the Text-Variation tab inside the active Gradio Blocks context.

    Args:
        with_example: When true, an Examples section wired to the same
            inference function is appended below the controls.
    """
    description = get_instruction("Text-Variation")
    gr.HTML(f'<p id=myinst>&nbsp Description: {description}</p>')

    with gr.Row():
        with gr.Column():
            text = gr.Textbox(lines=4, placeholder="Input prompt...", label='Text Input')
            seed = gr.Number(20, label="Seed", precision=0)
            button = gr.Button("Run")
        with gr.Column():
            txt_output = gr.Textbox(lines=4, label='Text Result')

    run_fn = vd_inference.inference_t2t
    in_components = [text, seed]
    out_components = [txt_output]
    button.click(run_fn, inputs=in_components, outputs=out_components)

    if not with_example:
        return

    gr.Examples(
        label='Examples',
        examples=get_example('Text-Variation'),
        fn=run_fn,
        inputs=in_components,
        outputs=out_components,
        cache_examples=cache_examples)
700
+
701
class image_mimage_swap(object):
    """Callback that swaps between a plain image box and a masked image box.

    ``block0`` is the plain image component and ``block1`` the one with a
    mask. ``which_update`` selects the strategy used on the next call:
    ``'both'`` (visibility and value), ``'visible'`` (visibility only) or
    ``'visible_oneoff'`` (visibility only, then revert to ``'both'``).
    """

    def __init__(self, block0, block1):
        self.block0 = block0
        self.block1 = block1
        self.which_update = 'both'

    def __call__(self, x0, x1, flag):
        """Dispatch to the strategy named by ``self.which_update``.

        Args:
            x0: Current value of ``block0``.
            x1: Current value of ``block1`` (a dict with 'image'/'mask' keys,
                or ``None``).
            flag: Whether the masked box should be shown.

        Raises:
            AssertionError: If ``which_update`` is not a known strategy.
        """
        strategies = {
            'both': self.update_both,
            'visible': self.update_visible,
            'visible_oneoff': self.update_visible_oneoff,
        }
        handler = strategies.get(self.which_update)
        if handler is not None:
            return handler(x0, x1, flag)
        assert False

    def update_both(self, x0, x1, flag):
        """Toggle visibility and carry the image over to the box being shown."""
        if flag:
            # Showing the masked box: reuse the plain image, keep an existing mask.
            ug0 = gr.update(visible=False)
            if x0 is None:
                ug1 = gr.update(value=None, visible=True)
            else:
                has_mask = (x1 is not None) and ('mask' in x1)
                mask = x1['mask'] if has_mask else None
                ug1 = gr.update(value={'image':x0, 'mask':mask}, visible=True)
        else:
            # Showing the plain box: take the image out of the masked value.
            has_image = (x1 is not None) and ('image' in x1)
            value0 = x1['image'] if has_image else None
            ug0 = gr.update(value=value0, visible=True)
            ug1 = gr.update(visible=False)
        return {self.block0: ug0, self.block1: ug1}

    def update_visible(self, x0, x1, flag):
        """Toggle visibility only; component values are left untouched."""
        show_plain = not flag
        return {
            self.block0: gr.update(visible=show_plain),
            self.block1: gr.update(visible=flag),
        }

    def update_visible_oneoff(self, x0, x1, flag):
        """Toggle visibility once, then switch back to the 'both' strategy."""
        self.which_update = 'both'
        show_plain = not flag
        return {
            self.block0: gr.update(visible=show_plain),
            self.block1: gr.update(visible=flag),
        }
749
+
750
class example_visible_only_hack(object):
    """Arms swap functors for a visibility-only update when a checkbox differs.

    For every (checkbox, functor, incoming value) triple, the functor's
    ``which_update`` is set to ``'visible_oneoff'`` when the incoming value
    differs from the checkbox's ``value`` attribute.
    """

    def __init__(self, checkbox_list, functor_list):
        self.checkbox_list = checkbox_list
        self.functor_list = functor_list

    def __call__(self, *args):
        """Compare each incoming value with its checkbox and flag mismatches.

        Args:
            *args: Incoming values, positionally matched to ``checkbox_list``.
        """
        triples = zip(self.checkbox_list, self.functor_list, args)
        for checkbox, functor, incoming in triples:
            if checkbox.value == incoming:
                continue
            functor.which_update = 'visible_oneoff'
759
+
760
def dcg_interface(with_example=False):
    """Build the Dual-Context tab inside the active Gradio Blocks context.

    The tab takes one reference image with a focus slider plus a text prompt
    with a "Text Domination" slider, and shows the generated images.

    Args:
        with_example: When true, an Examples section wired to the same
            inference function is appended below the controls.
    """
    gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Dual-Context") + '</p>')
    with gr.Row():
        with gr.Column():
            img = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
            fcs = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02)
            gr.HTML('<p id=myinst>&nbsp Focus: Focus on what aspect of the image? (0-semantic, 0.5-balanced (default), 1-style).</p>')

            text = gr.Textbox(lines=2, placeholder="Input prompt...", label='Text Input')
            tstrength = gr.Slider(label="Text Domination (NoEffect -- TextOnly)", minimum=0, maximum=1, value=0, step=0.02)

            seed = gr.Number(20, label="Seed", precision=0)
            button = gr.Button("Run")

        with gr.Column():
            output_gallary = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image)

    # The unused input_session / input_list scaffolding was removed; the
    # inputs for this tab are listed explicitly.
    button.click(
        vd_inference.inference_dcg,
        inputs=[img, fcs, text, tstrength, seed],
        outputs=[output_gallary])

    if with_example:
        gr.Examples(
            label='Examples',
            examples=get_example('Dual-Context'),
            fn=vd_inference.inference_dcg,
            inputs=[img, fcs, text, tstrength, seed],
            outputs=[output_gallary],
            cache_examples=cache_examples)
794
+
795
def _tcg_image_column():
    """Build one reference-image column of the Triple-Context tab.

    Inside a fresh ``gr.Column`` this creates, in order: a plain image input,
    a hidden sketch-enabled image input, the "Weight" and "Focus" sliders and
    the "Use mask?" checkbox, then wires the checkbox's ``change`` event.

    Returns:
        list: ``[img, imgm, istrength, fcs, msk]`` -- one per-image group of
        components, in the order ``tcg_interface`` concatenates them into the
        inputs of ``vd_inference.inference_tcg``.
    """
    with gr.Column():
        img = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
        # Rebind the component's hooks to the project's customised versions.
        img.as_example = types.MethodType(customized_as_example, img)
        imgm = gr.Image(label='Image Input with Mask', type='pil', elem_id='customized_imbox', tool='sketch', source="upload", visible=False)
        imgm.postprocess = types.MethodType(customized_postprocess, imgm)
        imgm.as_example = types.MethodType(customized_as_example, imgm)
        istrength = gr.Slider(label="Weight", minimum=0, maximum=1, value=1, step=0.02)
        fcs = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02)
        msk = gr.Checkbox(label='Use mask?')
        # image_mimage_swap returns the callback for this image pair;
        # presumably it swaps which of the two widgets is shown -- TODO confirm.
        swapf = image_mimage_swap(img, imgm)

        msk.change(
            fn=swapf,
            inputs=[img, imgm, msk],
            outputs=[img, imgm],)
    return [img, imgm, istrength, fcs, msk]


def tcg_interface(with_example=False):
    """Build the Triple-Context tab: two reference images plus a text prompt.

    The "Run" button calls ``vd_inference.inference_tcg`` with the two image
    groups (image, masked image, weight, focus, use-mask flag each) followed
    by the prompt, the text-domination strength and the seed. Its two outputs
    fill the "Input Display" and "Image Result" galleries.

    Args:
        with_example: When true, also add the example widget created by
            ``create_myexamples`` from ``get_example('Triple-Context')``.
    """
    # NOTE(review): the nesting of the layout blocks below was reconstructed
    # from a listing without indentation -- confirm against the rendered page.
    gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Triple-Context") + '</p>')
    with gr.Row():
        input_session = []
        with gr.Column(min_width=940):
            with gr.Row():
                # Two identical image columns, side by side.
                for _ in range(2):
                    input_session.append(_tcg_image_column())

            gr.HTML('<p id=myinst>&nbsp Weight: The strength of the reference image. This weight is subject to <u>Text Domination</u>).</p>'+
                    '<p id=myinst>&nbsp Focus: Focus on what aspect of the image? (0-semantic, 0.5-balanced (default), 1-style).</p>'+
                    '<p id=myinst>&nbsp Mask: Remove regions on reference image so they will not influence the output.</p>',)

            text = gr.Textbox(lines=2, placeholder="Input prompt...", label='Text Input')
            tstrength = gr.Slider(label="Text Domination (NoEffect -- TextOnly)", minimum=0, maximum=1, value=0, step=0.02)

            seed = gr.Number(20, label="Seed", precision=0)
            button = gr.Button("Run")

        with gr.Column(min_width=470):
            input_gallary = gr.Gallery(label="Input Display", elem_id="customized_imbox").style(grid=2)
            output_gallary = gr.Gallery(label="Image Result", elem_id="customized_imbox").style(grid=n_sample_image)

    # Flatten the per-image groups, then append the shared text controls.
    input_list = [component for group in input_session for component in group]
    input_list += [text, tstrength, seed]
    button.click(
        vd_inference.inference_tcg,
        inputs=input_list,
        outputs=[input_gallary, output_gallary])

    if with_example:
        create_myexamples(
            label='Examples',
            examples=get_example('Triple-Context'),
            fn=vd_inference.inference_tcg,
            inputs=input_list,
            outputs=[input_gallary, output_gallary, ],
            cache_examples=cache_examples, )

    gr.HTML('<br><p id=myinst>&nbsp How to add mask: Please see the following instructions.</p><br>'+
            '<div id="maskinst">'+
            '<img src="file/assets/demo/misc/mask_inst1.gif">'+
            '<img src="file/assets/demo/misc/mask_inst2.gif">'+
            '<img src="file/assets/demo/misc/mask_inst3.gif">'+
            '</div>')
873
+
874
def mcg_interface(with_example=False):
    """Build the Multi-Context tab: four tabbed reference images plus a prompt.

    Each image tab holds a plain image input, a hidden sketch-enabled image
    input, "Weight" and "Focus" sliders and a "Use mask?" checkbox. The "Run"
    button calls ``vd_inference.inference_mcg`` with all four image groups
    followed by the prompt, the text-domination strength and the seed; its
    two outputs fill the "Input Display" and "Image Result" galleries.

    Args:
        with_example: When true, also add the example widget created by
            ``create_myexamples`` from ``get_example('Multi-Context')``.
    """
    # NOTE(review): the nesting of the layout blocks below was reconstructed
    # from a listing without indentation -- confirm against the rendered page.
    num_img_input = 4
    slider_help = (
        '<p id=myinst>&nbsp Weight: The strength of the reference image. This weight is subject to <u>Text Domination</u>).</p>'+
        '<p id=myinst>&nbsp Focus: Focus on what aspect of the image? (0-semantic, 0.5-balanced (default), 1-style).</p>'+
        '<p id=myinst>&nbsp Mask: Remove regions on reference image so they will not influence the output.</p>')
    mask_help = (
        '<br><p id=myinst>&nbsp How to add mask: Please see the following instructions.</p><br>'+
        '<div id="maskinst">'+
        '<img src="file/assets/demo/misc/mask_inst1.gif">'+
        '<img src="file/assets/demo/misc/mask_inst2.gif">'+
        '<img src="file/assets/demo/misc/mask_inst3.gif">'+
        '</div>')

    gr.HTML('<p id=myinst>&nbsp Description: ' + get_instruction("Multi-Context") + '</p>')
    with gr.Row():
        image_groups = []
        with gr.Column():
            for tab_number in range(1, num_img_input + 1):
                with gr.Tab('Image{}'.format(tab_number)):
                    ref_image = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
                    # Rebind the component's hooks to the project's customised versions.
                    ref_image.as_example = types.MethodType(customized_as_example, ref_image)
                    masked_image = gr.Image(label='Image Input with Mask', type='pil', elem_id='customized_imbox', tool='sketch', source="upload", visible=False)
                    masked_image.postprocess = types.MethodType(customized_postprocess, masked_image)
                    masked_image.as_example = types.MethodType(customized_as_example, masked_image)

                    with gr.Row():
                        weight = gr.Slider(label="Weight", minimum=0, maximum=1, value=1, step=0.02)
                        focus = gr.Slider(label="Focus (Semantic -- Style)", minimum=0, maximum=1, value=0.5, step=0.02)
                        use_mask = gr.Checkbox(label='Use mask?')
                    gr.HTML(slider_help,)

                    # The callback is created per tab, so each checkbox is
                    # bound to its own image pair.
                    use_mask.change(
                        fn=image_mimage_swap(ref_image, masked_image),
                        inputs=[ref_image, masked_image, use_mask],
                        outputs=[ref_image, masked_image],)
                    image_groups.append([ref_image, masked_image, weight, focus, use_mask])

            text = gr.Textbox(lines=2, placeholder="Input prompt...", label='Text Input')
            tstrength = gr.Slider(label="Text Domination (NoEffect -- TextOnly)", minimum=0, maximum=1, value=0, step=0.02)

            seed = gr.Number(20, label="Seed", precision=0)
            button = gr.Button("Run")

        with gr.Column():
            input_gallary = gr.Gallery(label="Input Display", elem_id='customized_imbox').style(grid=4)
            output_gallary = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image)

    # Flatten the per-image groups, then append the shared text controls.
    input_list = [component for group in image_groups for component in group]
    input_list.extend([text, tstrength, seed])
    button.click(
        vd_inference.inference_mcg,
        inputs=input_list,
        outputs=[input_gallary, output_gallary], )

    if with_example:
        create_myexamples(
            label='Examples',
            examples=get_example('Multi-Context'),
            fn=vd_inference.inference_mcg,
            inputs=input_list,
            outputs=[input_gallary, output_gallary],
            cache_examples=cache_examples, )

    gr.HTML(mask_help)
937
+
938
+ ###########
939
+ # Example #
940
+ ###########
941
+
942
def get_example(mode):
    """Return the example input rows for one demo tab.

    Args:
        mode: Tab key. One of ``'Text-to-Image'``, ``'Image-Variation'``,
            ``'Image-to-Text'``, ``'Text-Variation'``, ``'Dual-Context'``,
            ``'Triple-Context'`` or ``'Multi-Context'``.

    Returns:
        list[list]: One row per example. For the context tabs, each row lists
        its values in the same order as the ``inputs`` given to that tab's
        example widget (five values per reference image, then prompt,
        text-domination strength and seed for Triple-/Multi-Context).
        A fresh list is built on every call.

    Raises:
        ValueError: If ``mode`` is not one of the keys above.
    """
    examples = {
        'Text-to-Image': [
            ['a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ', 23],
            ['a beautiful landscape with mountains and rivers', 20],
        ],
        'Image-Variation': [
            ['assets/demo/reg_example/ghibli.jpg', 0, 0.5, 'None', 20],
            ['assets/demo/reg_example/ghibli.jpg', 0.5, 0.5, 'None', 20],
            ['assets/demo/reg_example/matisse.jpg', 0, 0, 'None', 20],
            ['assets/demo/reg_example/matisse.jpg', 0, 1, 'Simple', 20],
            ['assets/demo/reg_example/vermeer.jpg', 0.2, 0.3, 'None', 30],
        ],
        'Image-to-Text': [
            ['assets/demo/reg_example/house_by_lake.jpg', 20],
        ],
        'Text-Variation': [
            ['heavy arms gundam penguin mech', 20],
        ],
        'Dual-Context': [
            ['assets/demo/reg_example/benz.jpg', 0.5, 'cyberpunk 2077', 0.7, 22],
            ['assets/demo/reg_example/ghibli.jpg', 1, 'Red maple on a hill in golden Autumn.', 0.5, 21],
        ],
        'Triple-Context': [
            [
                'assets/demo/reg_example/night_light.jpg', None, 1, 0.5, False,
                'assets/demo/reg_example/paris.jpg', None, 0.94, 0.5, False,
                "snow on the street", 0.4, 28],
            [
                'assets/demo/tcg_example/e1i0.jpg', None, 1, 0.5, False,
                'assets/demo/tcg_example/e1i1.jpg', None, 0.94, 0.5, False,
                "a painting of an elegant woman in front of the moon", 0.2, 217],
            [
                'assets/demo/tcg_example/e2i0.jpg', None, 1, 0.5, False,
                'assets/demo/reg_example/paris.jpg', None, 1, 0.5, False,
                "", 0, 29],
            [
                'assets/demo/tcg_example/e0i0.jpg', None, 1, 0.5, False,
                'assets/demo/tcg_example/e0i1.jpg', None, 0.9, 0.5, False,
                "rose blooms on the tree", 0.2, 20],
            [
                'assets/demo/reg_example/ghibli.jpg', None, 1, 1, False,
                'assets/demo/reg_example/space.jpg', None, 0.88, 0.5, False,
                "", 0, 20],
            [
                'assets/demo/reg_example/train.jpg', None, 0.8, 0.5, False,
                'assets/demo/reg_example/matisse.jpg', None, 1, 1, False,
                "", 0, 20],
        ],
        'Multi-Context': [
            [
                'assets/demo/mcg_example/e0i0.jpg', None, 1, 0.5, False,
                'assets/demo/mcg_example/e0i1.jpg', None, 1, 0.5, False,
                'assets/demo/mcg_example/e0i2.jpg', None, 0.86, 0.5, False,
                None, None, 1, 0.5, False,
                "", 0, 20],
        ],
    }
    try:
        return examples[mode]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which previously also ended
        # in ValueError via the equality chain.
        raise ValueError(
            'Unknown example mode: {!r} (expected one of: {})'.format(
                mode, ', '.join(examples))) from None
1008
+
1009
+ #############
1010
+ # Interface #
1011
+ #############
1012
+
1013
# Stylesheet handed to gr.Blocks(css=...) when the page is assembled.
#   #customized_imbox -- elem_id set on the image inputs and galleries;
#                        gives them a 450px minimum height.
#   #myinst           -- id used on the small grey instruction paragraphs.
#   #maskinst         -- id of the <div> holding the three mask how-to GIFs,
#                        laid out inline and justified across the row.
css = """
#customized_imbox {
min-height: 450px;
}
#customized_imbox>div[data-testid="image"] {
min-height: 450px;
}
#customized_imbox>div[data-testid="image"]>div {
min-height: 450px;
}
#customized_imbox>div[data-testid="image"]>iframe {
min-height: 450px;
}
#customized_imbox>div.unpadded_box {
min-height: 450px;
}
#myinst {
font-size: 0.8rem;
margin: 0rem;
color: #6B7280;
}
#maskinst {
text-align: justify;
min-width: 1200px;
}
#maskinst>img {
min-width:399px;
max-width:450px;
vertical-align: top;
display: inline-block;
}
#maskinst:after {
content: "";
width: 100%;
display: inline-block;
}
"""
1050
+
1051
# Assemble the demo page: title header, one tab per task, then the footer
# with the model version and usage cautions.
with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
        <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
            Versatile Diffusion
        </h1>
        <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
        We built <b>Versatile Diffusion (VD), the first unified multi-flow multimodal diffusion framework</b>, as a step towards <b>Universal Generative AI</b>.
        VD can natively support image-to-text, image-variation, text-to-image, and text-variation,
        and can be further extended to other applications such as
        semantic-style disentanglement, image-text dual-guided generation, latent image-to-text-to-image editing, and more.
        Future versions will support more modalities such as speech, music, video and 3D.
        </h2>
        <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
            Xingqian Xu, Atlas Wang, Eric Zhang, Kai Wang,
            and <a href="https://www.humphreyshi.com/home">Humphrey Shi</a>
            [<a href="https://arxiv.org/abs/2211.08332" style="color:blue;">arXiv</a>]
            [<a href="https://github.com/SHI-Labs/Versatile-Diffusion" style="color:blue;">GitHub</a>]
        </h3>
        </div>
        """)

    # (tab title, builder) pairs, in display order. Every builder is called
    # inside its own tab with examples enabled.
    _tab_builders = (
        ('Text-to-Image', t2i_interface),
        ('Image-Variation', i2i_interface),
        ('Image-to-Text', i2t_interface),
        ('Text-Variation', t2t_interface),
        ('Dual-Context Image-Generation', dcg_interface),
        ('Triple-Context Image-Blender', tcg_interface),
        ('Multi-Context Image-Blender', mcg_interface),
    )
    for _tab_title, _build_tab in _tab_builders:
        with gr.Tab(_tab_title):
            _build_tab(with_example=True)

    # The single "{}" placeholder is filled with the model identifier
    # reported by vd_inference.which.
    gr.HTML(
        """
        <div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
        <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
        <b>Version</b>: {}
        </h3>
        <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
        <b>Caution</b>:
        We would like to raise the awareness of users of this demo of its potential issues and concerns.
        Like previous large foundation models, Versatile Diffusion could be problematic in some cases, partially due to the imperfect training data and pretrained network (VAEs / context encoders) with limited scope.
        In its future research phase, VD may do better on tasks such as text-to-image, image-to-text, etc., with the help of more powerful VAEs, more sophisticated network designs, and more cleaned data.
        So far, we keep all features available for research testing both to show the great potential of the VD framework and to collect important feedback to improve the model in the future.
        We welcome researchers and users to report issues with the HuggingFace community discussion feature or email the authors.
        </h3>
        <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
        <b>Biases and content acknowledgement</b>:
        Beware that VD may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence.
        VD was trained on the LAION-2B dataset, which scraped non-curated online images and text, and may contain unintended exceptions as we removed illegal content.
        VD in this demo is meant only for research purposes.
        </h3>
        </div>
        """.format(' '+vd_inference.which))

# Start the app in debug mode. Pass share=True instead to request a public
# share link.
demo.launch(debug=True)