silentchen committed
Commit 67ae45a · 1 Parent(s): 4e5c108

Upload app_fast_api.py

Files changed (1)
  1. app_fast_api.py +562 -0
app_fast_api.py ADDED
@@ -0,0 +1,562 @@
+import gradio as gr
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer
+from diffusers import AutoencoderKL, LMSDiscreteScheduler
+from my_model import unet_2d_condition
+import json
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+from functools import partial
+import math
+from utils import compute_ca_loss
+from gradio import processing_utils
+from typing import Optional
+from fastapi import FastAPI
+
+import warnings
+
+import sys
+
+sys.tracebacklimit = 0
+
+class Blocks(gr.Blocks):
+
+    def __init__(
+            self,
+            theme: str = "default",
+            analytics_enabled: Optional[bool] = None,
+            mode: str = "blocks",
+            title: str = "Gradio",
+            css: Optional[str] = None,
+            **kwargs,
+    ):
+        self.extra_configs = {
+            'thumbnail': kwargs.pop('thumbnail', ''),
+            'url': kwargs.pop('url', 'https://gradio.app/'),
+            'creator': kwargs.pop('creator', '@teamGradio'),
+        }
+
+        super(Blocks, self).__init__(theme, analytics_enabled, mode, title, css, **kwargs)
+        warnings.filterwarnings("ignore")
+
+    def get_config_file(self):
+        config = super(Blocks, self).get_config_file()
+
+        for k, v in self.extra_configs.items():
+            config[k] = v
+
+        return config
+
+
+def draw_box(boxes=[], texts=[], img=None):
+    if len(boxes) == 0 and img is None:
+        return None
+
+    if img is None:
+        img = Image.new('RGB', (512, 512), (255, 255, 255))
+    colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
+    draw = ImageDraw.Draw(img)
+    font = ImageFont.truetype("DejaVuSansMono.ttf", size=18)
+    print(boxes)
+    for bid, box in enumerate(boxes):
+        draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors[bid % len(colors)], width=4)
+        anno_text = texts[bid]
+        draw.rectangle(
+            [box[0], box[3] - int(font.size * 1.2), box[0] + int((len(anno_text) + 0.8) * font.size * 0.6), box[3]],
+            outline=colors[bid % len(colors)], fill=colors[bid % len(colors)], width=4)
+        draw.text([box[0] + int(font.size * 0.2), box[3] - int(font.size * 1.2)], anno_text, font=font,
+                  fill=(255, 255, 255))
+    return img
+
+'''
+inference model
+'''
+
+def inference(device, unet, vae, tokenizer, text_encoder, prompt, bboxes, object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_index_step, rand_seed, guidance_scale):
+    uncond_input = tokenizer(
+        [""] * 1, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
+    )
+    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
+
+    input_ids = tokenizer(
+        prompt,
+        padding="max_length",
+        truncation=True,
+        max_length=tokenizer.model_max_length,
+        return_tensors="pt",
+    ).input_ids[0].unsqueeze(0).to(device)
+    # text_embeddings = text_encoder(input_ids)[0]
+    text_embeddings = torch.cat([uncond_embeddings, text_encoder(input_ids)[0]])
+    # text_embeddings[1, 1, :] = text_embeddings[1, 2, :]
+    generator = torch.manual_seed(rand_seed)  # Seed generator to create the initial latent noise
+
+    latents = torch.randn(
+        (batch_size, 4, 64, 64),
+        generator=generator,
+    ).to(device)
+
+    noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+
+    # generator = torch.Generator("cuda").manual_seed(1024)
+    noise_scheduler.set_timesteps(51)
+
+    latents = latents * noise_scheduler.init_noise_sigma
+
+    loss = torch.tensor(10000)
+
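+    # Each denoising step below has two phases: (1) backward guidance, which takes a few
+    # gradient-descent steps on the latents to lower the cross-attention loss for the
+    # target boxes, and (2) a regular classifier-free-guidance denoising step performed
+    # under torch.no_grad().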
+    for index, t in enumerate(noise_scheduler.timesteps):
+        iteration = 0
+
+        while loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < max_index_step:
+            latents = latents.requires_grad_(True)
+
+            # latent_model_input = torch.cat([latents] * 2)
+            latent_model_input = latents
+
+            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
+            noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
+                unet(latent_model_input, t, encoder_hidden_states=text_encoder(input_ids)[0])
+
+            # update latents with gradients from the cross-attention (layout) loss
+
+            loss = compute_ca_loss(attn_map_integrated_mid, attn_map_integrated_up, bboxes=bboxes,
+                                   object_positions=object_positions) * loss_scale
+
+            print(loss.item() / loss_scale)
+
+            grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
+
+            latents = latents - grad_cond * noise_scheduler.sigmas[index] ** 2
+            iteration += 1
+            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
+
+
+        with torch.no_grad():
+
+            latent_model_input = torch.cat([latents] * 2)
+
+            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
+            noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
+                unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
+
+            noise_pred = noise_pred.sample
+
+            # perform classifier-free guidance
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
+            torch.cuda.empty_cache()
+    # Decode image
+    with torch.no_grad():
+        # print("decode image")
+        latents = 1 / 0.18215 * latents
+        image = vae.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+        images = (image * 255).round().astype("uint8")
+        pil_images = [Image.fromarray(image) for image in images]
+        return pil_images
+
+def get_concat(ims):
+    if len(ims) == 1:
+        n_col = 1
+    else:
+        n_col = 2
+    n_row = math.ceil(len(ims) / 2)
+    dst = Image.new('RGB', (ims[0].width * n_col, ims[0].height * n_row), color="white")
+    for i, im in enumerate(ims):
+        row_id = i // n_col
+        col_id = i % n_col
+        dst.paste(im, (im.width * col_id, im.height * row_id))
+    return dst
+
+
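+# Gradio callback: reads the boxes drawn on the sketch pad from `state`, normalizes them
+# to [0, 1], maps each grounding phrase to the positions of its words in the text prompt,
+# and calls `inference` to produce the layout-guided images.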
+def generate(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
+             loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
+             state):
+    if 'boxes' not in state:
+        state['boxes'] = []
+    boxes = state['boxes']
+    grounding_texts = [x.strip() for x in grounding_texts.split(';')]
+    # assert len(boxes) == len(grounding_texts)
+    if len(boxes) != len(grounding_texts):
+        if len(boxes) < len(grounding_texts):
+            raise ValueError("""The number of boxes should be equal to the number of grounding objects.
+Number of boxes drawn: {}, number of grounding tokens: {}.
+Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
+        grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
+
+    boxes = (np.asarray(boxes) / 512).tolist()
+    boxes = [[box] for box in boxes]
+    grounding_instruction = json.dumps({obj: box for obj, box in zip(grounding_texts, boxes)})
+    language_instruction_list = language_instruction.strip('.').split(' ')
+    object_positions = []
+    for obj in grounding_texts:
+        obj_position = []
+        for word in obj.split(' '):
+            obj_first_index = language_instruction_list.index(word) + 1
+            obj_position.append(obj_first_index)
+        object_positions.append(obj_position)
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    gen_images = inference(device, unet, vae, tokenizer, text_encoder, language_instruction, boxes, object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_step, rand_seed, guidance_scale)
+
+    blank_samples = batch_size % 2 if batch_size > 1 else 0
+    gen_images = [gr.Image.update(value=x, visible=True) for i, x in enumerate(gen_images)] \
+                 + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
+                 + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
+
+    return gen_images + [state]
+
+
+def binarize(x):
+    return (x != 0).astype('uint8') * 255
+
+
+def sized_center_crop(img, cropx, cropy):
+    y, x = img.shape[:2]
+    startx = x // 2 - (cropx // 2)
+    starty = y // 2 - (cropy // 2)
+    return img[starty:starty + cropy, startx:startx + cropx]
+
+
+def sized_center_fill(img, fill, cropx, cropy):
+    y, x = img.shape[:2]
+    startx = x // 2 - (cropx // 2)
+    starty = y // 2 - (cropy // 2)
+    img[starty:starty + cropy, startx:startx + cropx] = fill
+    return img
+
+
+def sized_center_mask(img, cropx, cropy):
+    y, x = img.shape[:2]
+    startx = x // 2 - (cropx // 2)
+    starty = y // 2 - (cropy // 2)
+    center_region = img[starty:starty + cropy, startx:startx + cropx].copy()
+    img = (img * 0.2).astype('uint8')
+    img[starty:starty + cropy, startx:startx + cropx] = center_region
+    return img
+
+
+def center_crop(img, HW=None, tgt_size=(512, 512)):
+    if HW is None:
+        H, W = img.shape[:2]
+        HW = min(H, W)
+    img = sized_center_crop(img, HW, HW)
+    img = Image.fromarray(img)
+    img = img.resize(tgt_size)
+    return np.array(img)
+
+
+def draw(input, grounding_texts, new_image_trigger, state):
+    if type(input) == dict:
+        image = input['image']
+        mask = input['mask']
+    else:
+        mask = input
+    if mask.ndim == 3:
+        mask = 255 - mask[..., 0]
+
+    image_scale = 1.0
+
+    mask = binarize(mask)
+
+    if type(mask) != np.ndarray:
+        mask = np.array(mask)
+
+    if mask.sum() == 0:
+        state = {}
+
+    image = None
+
+    if 'boxes' not in state:
+        state['boxes'] = []
+
+    if 'masks' not in state or len(state['masks']) == 0:
+        state['masks'] = []
+        last_mask = np.zeros_like(mask)
+    else:
+        last_mask = state['masks'][-1]
+
+    if type(mask) == np.ndarray and mask.size > 1:
+        diff_mask = mask - last_mask
+    else:
+        diff_mask = np.zeros([])
+
+    if diff_mask.sum() > 0:
+        x1x2 = np.where(diff_mask.max(0) != 0)[0]
+        y1y2 = np.where(diff_mask.max(1) != 0)[0]
+        y1, y2 = y1y2.min(), y1y2.max()
+        x1, x2 = x1x2.min(), x1x2.max()
+
+        if (x2 - x1 > 5) and (y2 - y1 > 5):
+            state['masks'].append(mask.copy())
+            state['boxes'].append((x1, y1, x2, y2))
+
+    grounding_texts = [x.strip() for x in grounding_texts.split(';')]
+    grounding_texts = [x for x in grounding_texts if len(x) > 0]
+    if len(grounding_texts) < len(state['boxes']):
+        grounding_texts += [f'Obj. {bid + 1}' for bid in range(len(grounding_texts), len(state['boxes']))]
+    box_image = draw_box(state['boxes'], grounding_texts, image)
+
+    return [box_image, new_image_trigger, image_scale, state]
+
+
+def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
+    if task != 'Grounded Inpainting':
+        sketch_pad_trigger = sketch_pad_trigger + 1
+    blank_samples = batch_size % 2 if batch_size > 1 else 0
+    out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)]
+    # state = {}
+    return [None, sketch_pad_trigger, None, 1.0] + out_images + [{}]
+
+
+app = FastAPI()
+
+@app.get("/")
+async def root():
+    return {"message": "Hello World"}
+
+
+
+# def main():
+css = """
+#img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
+{
+    height: var(--height) !important;
+    max-height: var(--height) !important;
+    min-height: var(--height) !important;
+}
+#paper-info a {
+    color: #008AD7;
+    text-decoration: none;
+}
+#paper-info a:hover {
+    cursor: pointer;
+    text-decoration: none;
+}
+
+.tooltip {
+    color: #555;
+    position: relative;
+    display: inline-block;
+    cursor: pointer;
+}
+
+.tooltip .tooltiptext {
+    visibility: hidden;
+    width: 400px;
+    background-color: #555;
+    color: #fff;
+    text-align: center;
+    padding: 5px;
+    border-radius: 5px;
+    position: absolute;
+    z-index: 1; /* Set z-index to 1 */
+    left: 10px;
+    top: 100%;
+    opacity: 0;
+    transition: opacity 0.3s;
+}
+
+.tooltip:hover .tooltiptext {
+    visibility: visible;
+    opacity: 1;
+    z-index: 9999; /* Set a high z-index value when hovering */
+}
+
+
+"""
+
+rescale_js = """
+function(x) {
+    const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
+    let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
+    const image_width = root.querySelector('#img2img_image').clientWidth;
+    const target_height = parseInt(image_width * image_scale);
+    document.body.style.setProperty('--height', `${target_height}px`);
+    root.querySelectorAll('button.justify-center.rounded')[0].style.display='none';
+    root.querySelectorAll('button.justify-center.rounded')[1].style.display='none';
+    return x;
+}
+"""
+with open('./conf/unet/config.json') as f:
+    unet_config = json.load(f)
+
+unet = unet_2d_condition.UNet2DConditionModel(**unet_config).from_pretrained('runwayml/stable-diffusion-v1-5',
+                                                                             subfolder="unet")
+tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
+vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+unet.to(device)
+text_encoder.to(device)
+vae.to(device)
+
+
+with Blocks(
+        css=css,
+        analytics_enabled=False,
+        title="Layout-Guidance demo",
+        root_url='/Users/shil5883/Desktop/test'
+) as demo:
+    description = """<p style="text-align: center; font-weight: bold;">
+    <span style="font-size: 28px">Layout Guidance</span>
+    <br>
+    <span style="font-size: 18px" id="paper-info">
+    [<a href=" " target="_blank">Project Page</a>]
+    [<a href=" " target="_blank">Paper</a>]
+    [<a href=" " target="_blank">GitHub</a>]
+    </span>
+    </p>
+    """
+    gr.HTML(description)
+    with gr.Column():
+        language_instruction = gr.Textbox(
+            label="Text Prompt",
+        )
+        grounding_instruction = gr.Textbox(
+            label="Grounding instruction (Separated by semicolon)",
+        )
+        sketch_pad_trigger = gr.Number(value=0, visible=False)
+        sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
+        init_white_trigger = gr.Number(value=0, visible=False)
+        image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
+        new_image_trigger = gr.Number(value=0, visible=False)
+
+
+        with gr.Row():
+            sketch_pad = gr.Paint(label="Sketch Pad", elem_id="img2img_image", source='canvas', shape=(512, 512))
+            out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
+            out_gen_1 = gr.Image(type="pil", visible=True, label="Generated Image")
+
+        with gr.Row():
+            clear_btn = gr.Button(value='Clear')
+            gen_btn = gr.Button(value='Generate')
+
+        with gr.Accordion("Advanced Options", open=False):
+            with gr.Column():
+                description = """<div class="tooltip">Loss Scale Factor &#9432;
+                <span class="tooltiptext">The scale factor of the backward guidance loss. The larger it is, the better control we get, though it sometimes loses fidelity. </span>
+                </div>
+                <div class="tooltip">Guidance Scale &#9432;
+                <span class="tooltiptext">The scale factor of classifier-free guidance. </span>
+                </div>
+                <div class="tooltip">Max Iteration per Step &#9432;
+                <span class="tooltiptext">The maximum number of backward-guidance iterations in each diffusion denoising step.</span>
+                </div>
+                <div class="tooltip">Loss Threshold &#9432;
+                <span class="tooltiptext">The loss threshold. If the loss computed from the cross-attention maps is smaller than this threshold, backward guidance is stopped. </span>
+                </div>
+                <div class="tooltip">Max Step of Backward Guidance &#9432;
+                <span class="tooltiptext">The maximum number of denoising steps to which backward guidance is applied.</span>
+                </div>
+                """
+                gr.HTML(description)
+                Loss_scale = gr.Slider(minimum=0, maximum=500, step=5, value=30, label="Loss Scale Factor")
+                guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
+                batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number of Samples", visible=False)
+                max_iter = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Max Iteration per Step")
+                loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss Threshold")
+                max_step = gr.Slider(minimum=0, maximum=50, step=1, value=10, label="Max Step of Backward Guidance")
+                rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random Seed")
+
+    state = gr.State({})
+
+
+    class Controller:
+        def __init__(self):
+            self.calls = 0
+            self.tracks = 0
+            self.resizes = 0
+            self.scales = 0
+
+        def init_white(self, init_white_trigger):
+            self.calls += 1
+            return np.ones((512, 512), dtype='uint8') * 255, 1.0, init_white_trigger + 1
+
+        def change_n_samples(self, n_samples):
+            blank_samples = n_samples % 2 if n_samples > 1 else 0
+            return [gr.Image.update(visible=True) for _ in range(n_samples + blank_samples)] \
+                   + [gr.Image.update(visible=False) for _ in range(4 - n_samples - blank_samples)]
+
+
+    controller = Controller()
+    demo.load(
+        lambda x: x + 1,
+        inputs=sketch_pad_trigger,
+        outputs=sketch_pad_trigger,
+        queue=False)
+    sketch_pad.edit(
+        draw,
+        inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
+        outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
+        queue=False,
+    )
+    grounding_instruction.change(
+        draw,
+        inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
+        outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
+        queue=False,
+    )
+    clear_btn.click(
+        clear,
+        inputs=[sketch_pad_trigger, sketch_pad_trigger, batch_size, state],
+        outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, state],
+        queue=False)
+
+    sketch_pad_trigger.change(
+        controller.init_white,
+        inputs=[init_white_trigger],
+        outputs=[sketch_pad, image_scale, init_white_trigger],
+        queue=False)
+
+    gen_btn.click(
+        fn=partial(generate, unet, vae, tokenizer, text_encoder),
+        inputs=[
+            language_instruction, grounding_instruction, sketch_pad,
+            loss_threshold, guidance_scale, batch_size, rand_seed,
+            max_step,
+            Loss_scale, max_iter,
+            state,
+        ],
+        outputs=[out_gen_1, state],
+        queue=True
+    )
+    sketch_pad_resize_trigger.change(
+        None,
+        None,
+        sketch_pad_resize_trigger,
+        _js=rescale_js,
+        queue=False)
+    init_white_trigger.change(
+        None,
+        None,
+        init_white_trigger,
+        _js=rescale_js,
+        queue=False)
+
+
+    with gr.Column():
+        gr.Examples(
+            examples=[
+                [
+                    # "images/input.png",
+                    "A hello kitty toy is playing with a purple ball.",
+                    "hello kitty;ball",
+                    "./images/hello_kitty_results.png"
+                ],
+            ],
+            inputs=[language_instruction, grounding_instruction, out_gen_1],
+            outputs=None,
+            fn=None,
+            cache_examples=False,
+        )
+    description = """<p> The source code of this demo is adapted from the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GLIGEN demo</a>. Thanks! </p>"""
+    gr.HTML(description)
+
+
+
+demo.queue(concurrency_count=1, api_open=False)
+app = gr.mount_gradio_app(app, demo, path="/layout-guidance")
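
Since the Gradio Blocks is mounted onto the FastAPI instance instead of being launched directly, the app would typically be served with an ASGI server such as uvicorn (not part of this file; module name assumes the file is saved as app_fast_api.py):

    uvicorn app_fast_api:app

With that, the JSON root endpoint is reachable at / and the sketch-pad demo at /layout-guidance.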