Ahsen Khaliq committed
Commit 9a34aef · 1 Parent(s): 5f90927

Update app.py

Files changed (1)
  1. app.py +52 -104
app.py CHANGED

@@ -1,11 +1,13 @@
import torch
torch.hub.download_url_to_file('http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml', 'vqgan_imagenet_f16_16384.yaml')
torch.hub.download_url_to_file('http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt', 'vqgan_imagenet_f16_16384.ckpt')
+# import torch
+# torch.hub.download_url_to_file('http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml', 'vqgan_imagenet_f16_16384.yaml')
+# torch.hub.download_url_to_file('http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt', 'vqgan_imagenet_f16_16384.ckpt')
import argparse
import math
from pathlib import Path
import sys
-
sys.path.insert(1, './taming-transformers')
#from IPython import display
from base64 import b64encode

@@ -18,7 +20,6 @@ from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
-
from CLIP import clip
import kornia.augmentation as K
import numpy as np

@@ -26,17 +27,12 @@ import imageio
from PIL import ImageFile, Image
ImageFile.LOAD_TRUNCATED_IMAGES = True
import gradio as gr
-
def sinc(x):
    return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
-
-
def lanczos(x, a):
    cond = torch.logical_and(-a < x, x < a)
    out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))
    return out / out.sum()
-
-
def ramp(ratio, width):
    n = math.ceil(width / ratio + 1)
    out = torch.empty([n])

@@ -45,44 +41,31 @@ def ramp(ratio, width):
        out[i] = cur
        cur += ratio
    return torch.cat([-out[1:].flip([0]), out])[1:-1]
-
-
def resample(input, size, align_corners=True):
    n, c, h, w = input.shape
    dh, dw = size
-
    input = input.view([n * c, 1, h, w])
-
    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])
-
    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])
-
    input = input.view([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
-
-
class ReplaceGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward
-
    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)
-
-
replace_grad = ReplaceGrad.apply
-
-
class ClampWithGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, min, max):

@@ -90,51 +73,38 @@ class ClampWithGrad(torch.autograd.Function):
        ctx.max = max
        ctx.save_for_backward(input)
        return input.clamp(min, max)
-
    @staticmethod
    def backward(ctx, grad_in):
        input, = ctx.saved_tensors
        return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None
-
-
clamp_with_grad = ClampWithGrad.apply
-
-
def vector_quantize(x, codebook):
    d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
    indices = d.argmin(-1)
    x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
    return replace_grad(x_q, x)
-
-
class Prompt(nn.Module):
    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))
-
    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
-
-
def parse_prompt(prompt):
    vals = prompt.rsplit(':', 2)
    vals = vals + ['', '1', '-inf'][len(vals):]
    return vals[0], float(vals[1]), float(vals[2])
-
-
class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
-
        self.augs = nn.Sequential(
            # K.RandomHorizontalFlip(p=0.5),
            # K.RandomVerticalFlip(p=0.5),

@@ -151,7 +121,6 @@ class MakeCutouts(nn.Module):
        self.noise_fac = 0.1
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
-
    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)

@@ -159,13 +128,11 @@ class MakeCutouts(nn.Module):
        cutouts = []

        for _ in range(self.cutn):
-
            # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            # offsetx = torch.randint(0, sideX - size + 1, ())
            # offsety = torch.randint(0, sideY - size + 1, ())
            # cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            # cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
-
            # cutout = transforms.Resize(size=(self.cut_size, self.cut_size))(input)

            cutout = (self.av_pool(input) + self.max_pool(input))/2

@@ -175,8 +142,6 @@ class MakeCutouts(nn.Module):
            facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch
-
-
def load_vqgan_model(config_path, checkpoint_path):
    config = OmegaConf.load(config_path)
    if config.model.target == 'taming.models.vqgan.VQModel':

@@ -196,8 +161,6 @@ def load_vqgan_model(config_path, checkpoint_path):
        raise ValueError(f'unknown model type: {config.model.target}')
    del model.loss
    return model
-
-
def resize_image(image, out_size):
    ratio = image.size[0] / image.size[1]
    area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])

@@ -205,82 +168,66 @@ def resize_image(image, out_size):
    return image.resize(size, Image.LANCZOS)


+model_name = "vqgan_imagenet_f16_16384"
+images_interval = 50
+width = 256
+height = 256
+init_image = ""
+seed = 42
+args = argparse.Namespace(
+    noise_prompt_seeds=[],
+    noise_prompt_weights=[],
+    size=[width, height],
+    init_image=init_image,
+    init_weight=0.,
+    clip_model='ViT-B/32',
+    vqgan_config=f'{model_name}.yaml',
+    vqgan_checkpoint=f'{model_name}.ckpt',
+    step_size=0.1,
+    cutn=1,
+    cut_pow=1.,
+    display_freq=images_interval,
+    seed=seed,
+)
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+print('Using device:', device)
+model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
+perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)
+
def inference(text):
    texts = text
-    width = 128
-    height = 128
-    model = "vqgan_imagenet_f16_16384"
-    images_interval = 50
-    init_image = ""
    target_images = ""
-    seed = 42
    max_iterations = 100
-
    model_names={"vqgan_imagenet_f16_16384": 'ImageNet 16384',"vqgan_imagenet_f16_1024":"ImageNet 1024", 'vqgan_openimages_f16_8192':'OpenImages 8912',
                 "wikiart_1024":"WikiArt 1024", "wikiart_16384":"WikiArt 16384", "coco":"COCO-Stuff", "faceshq":"FacesHQ", "sflckr":"S-FLCKR"}
-    name_model = model_names[model]
-
-    if seed == -1:
-        seed = None
-    if init_image == "None":
-        init_image = None
+    name_model = model_names[model_name]
    if target_images == "None" or not target_images:
        target_images = []
    else:
        target_images = target_images.split("|")
        target_images = [image.strip() for image in target_images]
-
    texts = [phrase.strip() for phrase in texts.split("|")]
    if texts == ['']:
        texts = []
-
-
-    args = argparse.Namespace(
-        prompts=texts,
-        image_prompts=target_images,
-        noise_prompt_seeds=[],
-        noise_prompt_weights=[],
-        size=[width, height],
-        init_image=init_image,
-        init_weight=0.,
-        clip_model='ViT-B/32',
-        vqgan_config=f'{model}.yaml',
-        vqgan_checkpoint=f'{model}.ckpt',
-        step_size=0.1,
-        cutn=1,
-        cut_pow=1.,
-        display_freq=images_interval,
-        seed=seed,
-    )
    from urllib.request import urlopen
-
-    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-    print('Using device:', device)
    if texts:
        print('Using texts:', texts)
    if target_images:
        print('Using image prompts:', target_images)
-    if args.seed is None:
+    if args.seed is None or args.seed == -1:
        seed = torch.seed()
    else:
        seed = args.seed
    torch.manual_seed(seed)
    print('Using seed:', seed)
-
-    model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
-    perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)
    # clock=deepcopy(perceptor.visual.positional_embedding.data)
    # perceptor.visual.positional_embedding.data = clock/clock.max()
    # perceptor.visual.positional_embedding.data=clamp_with_grad(clock,0,1)
-
    cut_size = perceptor.visual.input_resolution
-
    f = 2**(model.decoder.num_resolutions - 1)
    make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)
-
    toksX, toksY = args.size[0] // f, args.size[1] // f
    sideX, sideY = toksX * f, toksY * f
-
    if args.vqgan_checkpoint == 'vqgan_openimages_f16_8192.ckpt':
        e_dim = 256
        n_toks = model.quantize.n_embed

@@ -293,10 +240,8 @@ def inference(text):
    z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
    # z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
    # z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
-
    # normalize_imagenet = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                                           std=[0.229, 0.224, 0.225])
-
    if args.init_image:
        if 'http' in args.init_image:
            img = Image.open(urlopen(args.init_image))

@@ -318,20 +263,14 @@ def inference(text):
    z_orig = z.clone()
    z.requires_grad_(True)
    opt = optim.Adam([z], lr=args.step_size)
-
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
-
-
-
    pMs = []
-
-    for prompt in args.prompts:
+    for prompt in texts:
        txt, weight, stop = parse_prompt(prompt)
        embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))
-
-    for prompt in args.image_prompts:
+    for prompt in target_images:
        path, weight, stop = parse_prompt(prompt)
        img = Image.open(path)
        pil_image = img.convert('RGB')

@@ -339,19 +278,16 @@ def inference(text):
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        embed = perceptor.encode_image(normalize(batch)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))
-
    for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
        gen = torch.Generator().manual_seed(seed)
        embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
        pMs.append(Prompt(embed, weight).to(device))
-
    def synth(z):
        if args.vqgan_checkpoint == 'vqgan_openimages_f16_8192.ckpt':
            z_q = vector_quantize(z.movedim(1, 3), model.quantize.embed.weight).movedim(3, 1)
        else:
            z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)
        return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)
-
    @torch.no_grad()
    def checkin(i, losses):
        losses_str = ', '.join(f'{loss.item():g}' for loss in losses)

@@ -359,14 +295,12 @@ def inference(text):
        out = synth(z)
        #TF.to_pil_image(out[0].cpu()).save('progress.png')
        #display.display(display.Image('progress.png'))
-
    def ascend_txt():
        # global i
        out = synth(z)
        iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

        result = []
-
        if args.init_weight:
            # result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)
            result.append(F.mse_loss(z, torch.zeros_like(z_orig)) * ((1/torch.tensor(i*2 + 1))*args.init_weight) / 2)

@@ -375,9 +309,7 @@ def inference(text):
        img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:]
        img = np.transpose(img, (1, 2, 0))
        #imageio.imwrite('./steps/' + str(i) + '.png', np.array(img))
-
        return result, np.array(img)
-
    def train(i):
        opt.zero_grad()
        lossAll, image = ascend_txt()

@@ -390,7 +322,6 @@ def inference(text):
        with torch.no_grad():
            z.copy_(z.maximum(z_min).minimum(z_max))
        return image
-
    i = 0
    try:
        with tqdm() as pbar:

@@ -404,12 +335,29 @@ def inference(text):
        pass
    return image

+inferences_running = 0
+
+def throttled_inference(text):
+    global inferences_running
+    current = inferences_running
+    if current >= 2:
+        print(f"Rejected inference when we already had {current} running")
+        return None
+
+    print(f"Inference starting when we already had {current} running")
+    inferences_running += 1
+    try:
+        return inference(text)
+    finally:
+        print("Inference finished")
+        inferences_running -= 1
+
+
title = "VQGAN + CLIP"
description = "Gradio demo for VQGAN + CLIP. To use it, simply add your text, or click one of the examples to load them. Read more at the links below. Please click submit only once"
article = "<p style='text-align: center'>Originally made by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). The original BigGAN+CLIP method was by https://twitter.com/advadnoun. Added some explanations and modifications by Eleiber#8347, pooling trick by Crimeacs#8222 (https://twitter.com/EarthML1) and the GUI was made with the help of Abulafia#3734. | <a href='https://colab.research.google.com/drive/1ZAus_gn2RhTZWzOWUpPERNC0Q8OhZRTZ'>Colab</a> | <a href='https://github.com/CompVis/taming-transformers'>Taming Transformers Github Repo</a> | <a href='https://github.com/openai/CLIP'>CLIP Github Repo</a></p>"
-
gr.Interface(
-    inference,
+    throttled_inference,
    gr.inputs.Textbox(label="Input"),
    gr.outputs.Image(type="numpy", label="Output"),
    title=title,
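
Note: the throttled_inference wrapper added by this commit limits concurrency with a bare global counter that is read and updated without synchronization. As an illustrative sketch only (not part of this commit), the same fail-fast guard could be written with a threading.BoundedSemaphore; the names MAX_CONCURRENT and run_throttled below are hypothetical placeholders.

import threading

MAX_CONCURRENT = 2  # mirrors the hard-coded limit of 2 in throttled_inference
_slots = threading.BoundedSemaphore(MAX_CONCURRENT)

def run_throttled(text, inference_fn):
    # Fail fast instead of queueing: give up immediately if all slots are taken.
    if not _slots.acquire(blocking=False):
        print("Rejected inference: too many concurrent runs")
        return None
    try:
        return inference_fn(text)
    finally:
        _slots.release()

Wired into the demo this would stand in for throttled_inference, e.g. gr.Interface(lambda text: run_throttled(text, inference), ...), assuming the Gradio app serves requests from multiple threads.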