Ahsen Khaliq committed
Commit 463d215 · 1 Parent(s): 1bfda02

Update app.py

Files changed (1)
  1. app.py +37 -36
app.py CHANGED
@@ -1,6 +1,8 @@
 import torch
 torch.hub.download_url_to_file('http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml', 'vqgan_imagenet_f16_16384.yaml')
 torch.hub.download_url_to_file('http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt', 'vqgan_imagenet_f16_16384.ckpt')
+torch.hub.download_url_to_file('http://batbot.tv/misc/coco_first_stage.yaml', 'coco_first_stage.yaml')
+torch.hub.download_url_to_file('http://batbot.tv/misc/coco_first_stage.ckpt', 'coco_first_stage.ckpt')
 import argparse
 import math
 from pathlib import Path
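
Note: the four downloads above run unconditionally on every startup, including the two checkpoints added by this commit. A small guard, not part of the commit and sketched here with the same URLs and filenames, would skip files already on disk:

import os
import torch

ASSETS = [
    ('http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml', 'vqgan_imagenet_f16_16384.yaml'),
    ('http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt', 'vqgan_imagenet_f16_16384.ckpt'),
    ('http://batbot.tv/misc/coco_first_stage.yaml', 'coco_first_stage.yaml'),
    ('http://batbot.tv/misc/coco_first_stage.ckpt', 'coco_first_stage.ckpt'),
]
for url, dst in ASSETS:
    if not os.path.exists(dst):  # skip the re-download on restart
        torch.hub.download_url_to_file(url, dst)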
@@ -25,12 +27,9 @@ from PIL import ImageFile, Image
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 import gradio as gr
 import nvidia_smi
-
 nvidia_smi.nvmlInit()
 handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
 # card id 0 hardcoded here, there is also a call to get all available card ids, so we could iterate
-
-
 torch.hub.download_url_to_file('https://i.imgur.com/WEHmKef.jpg', 'gpu.jpg')
 def sinc(x):
     return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
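
The comment in this hunk notes that card id 0 is hardcoded even though NVML can enumerate every card. A minimal sketch of that iteration, assuming the nvidia-ml-py bindings behind this nvidia_smi module also expose nvmlDeviceGetCount and nvmlDeviceGetMemoryInfo:

import nvidia_smi

nvidia_smi.nvmlInit()
for i in range(nvidia_smi.nvmlDeviceGetCount()):  # all available card ids
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
    mem = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU {i}: {mem.used / 2**20:.0f} / {mem.total / 2**20:.0f} MiB used")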
@@ -171,39 +170,40 @@ def resize_image(image, out_size):
     area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])
     size = round((area * ratio)**0.5), round((area / ratio)**0.5)
     return image.resize(size, Image.LANCZOS)
-model_name = "vqgan_imagenet_f16_16384"
-images_interval = 50
-width = 280
-height = 280
-init_image = ""
-seed = 42
-args = argparse.Namespace(
-    noise_prompt_seeds=[],
-    noise_prompt_weights=[],
-    size=[width, height],
-    init_image=init_image,
-    init_weight=0.,
-    clip_model='ViT-B/32',
-    vqgan_config=f'{model_name}.yaml',
-    vqgan_checkpoint=f'{model_name}.ckpt',
-    step_size=0.15,
-    cutn=4,
-    cut_pow=1.,
-    display_freq=images_interval,
-    seed=seed,
-)
-device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-print('Using device:', device)
-model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
-perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)
-def inference(text, seed, step_size, max_iterations, width, height):
-    size=[width, height]
+
+
+def inference(text, seed, step_size, max_iterations, width, height, model_name):
+    args = argparse.Namespace(
+        noise_prompt_seeds=[],
+        noise_prompt_weights=[],
+        size=[width, height],
+        init_image="",
+        init_weight=0.,
+        clip_model='ViT-B/32',
+        vqgan_config=f'{model_name}.yaml',
+        vqgan_checkpoint=f'{model_name}.ckpt',
+        step_size=step_size,
+        cutn=4,
+        cut_pow=1.,
+        display_freq=50,
+        seed=seed,
+    )
     texts = text
     target_images = ""
     max_iterations = max_iterations
+    model_name = model_name
     model_names={"vqgan_imagenet_f16_16384": 'ImageNet 16384',"vqgan_imagenet_f16_1024":"ImageNet 1024", 'vqgan_openimages_f16_8192':'OpenImages 8912',
-                 "wikiart_1024":"WikiArt 1024", "wikiart_16384":"WikiArt 16384", "coco":"COCO-Stuff", "faceshq":"FacesHQ", "sflckr":"S-FLCKR"}
+                 "wikiart_1024":"WikiArt 1024", "wikiart_16384":"WikiArt 16384", "coco_first_stage":"COCO-Stuff", "faceshq":"FacesHQ", "sflckr":"S-FLCKR"}
     name_model = model_names[model_name]
+    init_image = ""
+    size=[width, height]
+    seed=seed
+    step_size=step_size
+
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    print('Using device:', device)
+    model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
+    perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)
     if target_images == "None" or not target_images:
         target_images = []
     else:
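
Moving load_vqgan_model and clip.load inside inference is what makes the new model_name argument work, but it also means the VQGAN weights are re-read from disk on every request. If that cost matters, a per-name cache is a common fix; the sketch below is hypothetical (get_model is not in this commit) and assumes load_vqgan_model as defined elsewhere in app.py:

_model_cache = {}  # model_name -> loaded VQGAN, so repeat requests skip the reload

def get_model(model_name, device):
    if model_name not in _model_cache:
        _model_cache[model_name] = load_vqgan_model(
            f'{model_name}.yaml', f'{model_name}.ckpt').to(device)
    return _model_cache[model_name]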
@@ -345,7 +345,7 @@ def load_image( infilename ) :
     img.load()
     data = np.asarray( img, dtype="int32" )
     return data
-def throttled_inference(text, seed, step_size, max_iterations, width, height):
+def throttled_inference(text, seed, step_size, max_iterations, width, height, model_name):
     global inferences_running
     current = inferences_running
     if current >= 3:
@@ -354,7 +354,7 @@ def throttled_inference(text, seed, step_size, max_iterations, width, height):
     print(f"Inference starting when we already had {current} running")
     inferences_running += 1
     try:
-        return inference(text, seed, step_size, max_iterations, width, height)
+        return inference(text, seed, step_size, max_iterations, width, height, model_name)
     finally:
         print("Inference finished")
         inferences_running -= 1
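
Untouched by this commit, the check-then-increment on inferences_running is not atomic: two concurrent requests can both read current == 2 and push the count past the limit of 3. A bounded semaphore closes that race. This is a sketch, not the app's code, and the gpu.jpg fallback for rejected requests is an assumption based on the placeholder image downloaded earlier:

import threading

_slots = threading.BoundedSemaphore(3)  # at most 3 concurrent inferences

def throttled_inference(text, seed, step_size, max_iterations, width, height, model_name):
    if not _slots.acquire(blocking=False):
        print("Rejected inference: 3 already running")
        return load_image('./gpu.jpg')  # assumed fallback image for rejected requests
    try:
        return inference(text, seed, step_size, max_iterations, width, height, model_name)
    finally:
        _slots.release()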
@@ -369,14 +369,15 @@ gr.Interface(
     gr.inputs.Slider(minimum=25, maximum=150, default=80, label='max iterations', step=1),
     gr.inputs.Slider(minimum=200, maximum=280, default=256, label='width', step=1),
     gr.inputs.Slider(minimum=200, maximum=280, default=256, label='height', step=1),
+    gr.inputs.Dropdown(choices=["vqgan_imagenet_f16_16384", "coco_first_stage"], type="value", default="vqgan_imagenet_f16_16384", label="Model Name")
     ],
     gr.outputs.Image(type="numpy", label="Output"),
     title=title,
     description=description,
     article=article,
     examples=[
-        ['a garden by james gurney',42,0.16, 100, 256, 256],
-        ['coral reef city artstationHQ',1000,0.6, 110, 200, 200],
-        ['a cabin in the mountains unreal engine',98,0.3, 120, 280, 280]
+        ['a garden by james gurney',42,0.16, 100, 256, 256, "vqgan_imagenet_f16_16384"],
+        ['coral reef city artstationHQ',1000,0.6, 110, 200, 200, "vqgan_imagenet_f16_16384"],
+        ['a cabin in the mountains unreal engine',98,0.3, 120, 280, 280, "vqgan_imagenet_f16_16384"]
     ]
 ).launch(debug=True)
 