JingyeChen22 committed on
Commit 971203d · 1 Parent(s): 2758991

Update app.py

Files changed (1):
  1. app.py +397 -238
app.py CHANGED
@@ -3,9 +3,13 @@ import re
 import zipfile
 import torch
 import gradio as gr
 import time
 from transformers import CLIPTextModel, CLIPTokenizer
- from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DiffusionPipeline, LCMScheduler
 from tqdm import tqdm
 from PIL import Image
 from PIL import Image, ImageDraw, ImageFont
@@ -26,29 +30,9 @@ if not os.path.exists('images2'):
 # os.system('nvidia-smi')
 os.system('ls')
 
- #### import m1
- from fastchat.model import load_model, get_conversation_template
- from transformers import AutoTokenizer, AutoModelForCausalLM
- m1_model_path = 'JingyeChen22/textdiffuser2_layout_planner'
- # m1_model, m1_tokenizer = load_model(
- #     m1_model_path,
- #     'cuda',
- #     1,
- #     None,
- #     False,
- #     False,
- #     revision="main",
- #     debug=False,
- # )
-
- m1_tokenizer = AutoTokenizer.from_pretrained(m1_model_path, use_fast=False)
- m1_model = AutoModelForCausalLM.from_pretrained(
-     m1_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
- ).cuda()
-
 #### import diffusion models
 text_encoder = CLIPTextModel.from_pretrained(
-     'JingyeChen22/textdiffuser2-full-ft', subfolder="text_encoder"
 ).cuda().half()
 tokenizer = CLIPTokenizer.from_pretrained(
     'runwayml/stable-diffusion-v1-5', subfolder="tokenizer"
@@ -69,50 +53,49 @@ print('***************')
 
 vae = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="vae").half().cuda()
 unet = UNet2DConditionModel.from_pretrained(
-     'JingyeChen22/textdiffuser2-full-ft', subfolder="unet"
 ).half().cuda()
 text_encoder.resize_token_embeddings(len(tokenizer))
 
 
- #### load lcm components
- model_id = "lambdalabs/sd-pokemon-diffusers"
- lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
- pipe = DiffusionPipeline.from_pretrained(model_id, unet=copy.deepcopy(unet), tokenizer=tokenizer, text_encoder=copy.deepcopy(text_encoder), torch_dtype=torch.float16)
- pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
- pipe.load_lora_weights(lcm_lora_id)
- pipe.to(device="cuda")
 
 
- #### for interactive
- stack = []
- state = 0
- font = ImageFont.truetype("./Arial.ttf", 32)
-
- def skip_fun(i, t):
-     global state
-     state = 0
-
-
- def exe_undo(i, t):
-     global stack
-     global state
-     state = 0
-     stack = []
-     image = Image.open(f'./gray256.jpg')
-     print('stack', stack)
-     return image
 
 
- def exe_redo(i, t):
-     global state
-     state = 0
 
-     if len(stack) > 0:
-         stack.pop()
-     image = Image.open(f'./gray256.jpg')
     draw = ImageDraw.Draw(image)
 
-     for items in stack:
         # print('now', items)
         text_position, t = items
         if len(text_position) == 2:
@@ -133,57 +116,194 @@ def exe_redo(i, t):
             draw.ellipse((leftUpPoint,rightDownPoint), fill='red')
             draw.rectangle((x0,y0,x1,y1), outline=(255, 0, 0) )
 
-     print('stack', stack)
     return image
 
- def get_pixels(i, t, evt: gr.SelectData):
-     global state
 
-     text_position = evt.index
 
-     if state == 0:
-         stack.append(
-             (text_position, t)
-         )
-         print(text_position, stack)
-         state = 1
     else:
-
-         (_, t) = stack.pop()
-         x, y = _
-         stack.append(
-             ((x,y,text_position[0],text_position[1]), t)
-         )
-         state = 0
 
-     image = Image.open(f'./gray256.jpg')
-     draw = ImageDraw.Draw(image)
 
-     for items in stack:
-         # print('now', items)
-         text_position, t = items
-         if len(text_position) == 2:
-             x, y = text_position
-             text_color = (255, 0, 0)
-             draw.text((x+2, y), t, font=font, fill=text_color)
-             r = 4
-             leftUpPoint = (x-r, y-r)
-             rightDownPoint = (x+r, y+r)
-             draw.ellipse((leftUpPoint,rightDownPoint), fill='red')
-         elif len(text_position) == 4:
-             x0, y0, x1, y1 = text_position
-             text_color = (255, 0, 0)
-             draw.text((x0+2, y0), t, font=font, fill=text_color)
-             r = 4
-             leftUpPoint = (x0-r, y0-r)
-             rightDownPoint = (x0+r, y0+r)
-             draw.ellipse((leftUpPoint,rightDownPoint), fill='red')
-             draw.rectangle((x0,y0,x1,y1), outline=(255, 0, 0) )
 
-     print('stack', stack)
 
-     return image
 
 
 font_layout = ImageFont.truetype('./Arial.ttf', 16)
@@ -208,11 +328,30 @@ def get_layout_image(ocrs):
     return blank
 
 
- def text_to_image(prompt,keywords,positive_prompt,radio,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural):
 
-     global stack
-     global state
 
     if len(positive_prompt.strip()) != 0:
         prompt += positive_prompt
@@ -227,7 +366,7 @@ def text_to_image(prompt,keywords,positive_prompt,radio,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural):
         prompt = tokenizer.encode(user_prompt)
         layout_image = None
     else:
-         if len(stack) == 0:
 
             if len(keywords.strip()) == 0:
                 template = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. All keywords are included in the caption. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {user_prompt}'
@@ -308,19 +447,25 @@ def text_to_image(prompt,keywords,positive_prompt,radio,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural):
             composed_prompt = tokenizer.decode(prompt)
 
         else:
-             user_prompt += ' <|endoftext|>'
             layout_image = None
-
-             for items in stack:
                 position, text = items
 
                 if len(position) == 2:
                     x, y = position
                     x = x // 4
                     y = y // 4
                     text_str = ' '.join([f'[{c}]' for c in list(text)])
-                     user_prompt += f'<|startoftext|> l{x} t{y} {text_str} <|endoftext|>'
                 elif len(position) == 4:
                     x0, y0, x1, y1 = position
                     x0 = x0 // 4
@@ -328,9 +473,32 @@ def text_to_image(prompt,keywords,positive_prompt,radio,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural):
                     x1 = x1 // 4
                     y1 = y1 // 4
                     text_str = ' '.join([f'[{c}]' for c in list(text)])
-                     user_prompt += f'<|startoftext|> l{x0} t{y0} r{x1} b{y1} {text_str} <|endoftext|>'
 
-         # composed_prompt = user_prompt
         prompt = tokenizer.encode(user_prompt)
         composed_prompt = tokenizer.decode(prompt)
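Editor's note: the removed lines in this hunk are the clearest illustration of the "language model as layout encoder" idea, where each keyword box is flattened into coordinate tokens plus per-character tokens and appended to the caption. Gathered into one place, here is a minimal sketch of that composition; the standalone function name is illustrative only, and the demo actually inlines this logic and encodes the result with its extended CLIP tokenizer:

    def compose_layout_prompt(caption, boxes):
        # boxes: list of ((x0, y0, x1, y1), text) in 512x512 canvas coordinates
        prompt = caption + ' <|endoftext|>'
        for (x0, y0, x1, y1), text in boxes:
            # quantize from the 512px canvas to the 128-unit layout grid (// 4, as above)
            l, t, r, b = x0 // 4, y0 // 4, x1 // 4, y1 // 4
            chars = ' '.join(f'[{c}]' for c in text)
            prompt += f'<|startoftext|> l{l} t{t} r{r} b{b} {chars} <|endoftext|>'
        return prompt

    # e.g. compose_layout_prompt('A logo of superman', [((100, 60, 400, 120), 'superman')])
    # -> 'A logo of superman <|endoftext|><|startoftext|> l25 t15 r100 b30 [s] [u] [p] [e] [r] [m] [a] [n] <|endoftext|>'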
 
@@ -338,70 +506,67 @@ def text_to_image(prompt,keywords,positive_prompt,radio,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural):
     while len(prompt) < 77:
         prompt.append(tokenizer.pad_token_id)
 
-     if radio == 'TextDiffuser-2':
-
-         prompts_cond = prompt
-         prompts_nocond = [tokenizer.pad_token_id]*77
-
-         prompts_cond = [prompts_cond] * slider_batch
-         prompts_nocond = [prompts_nocond] * slider_batch
-
-         prompts_cond = torch.Tensor(prompts_cond).long().cuda()
-         prompts_nocond = torch.Tensor(prompts_nocond).long().cuda()
-
-         scheduler = DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="scheduler")
-         scheduler.set_timesteps(slider_step)
-         noise = torch.randn((slider_batch, 4, 64, 64)).to("cuda").half()
-         input = noise
-
-         encoder_hidden_states_cond = text_encoder(prompts_cond)[0].half()
-         encoder_hidden_states_nocond = text_encoder(prompts_nocond)[0].half()
-
-         for t in tqdm(scheduler.timesteps):
-             with torch.no_grad():  # classifier-free guidance
-                 noise_pred_cond = unet(sample=input, timestep=t, encoder_hidden_states=encoder_hidden_states_cond[:slider_batch]).sample  # b, 4, 64, 64
-                 noise_pred_uncond = unet(sample=input, timestep=t, encoder_hidden_states=encoder_hidden_states_nocond[:slider_batch]).sample  # b, 4, 64, 64
-                 noisy_residual = noise_pred_uncond + slider_guidance * (noise_pred_cond - noise_pred_uncond)  # b, 4, 64, 64
-                 input = scheduler.step(noisy_residual, t, input).prev_sample
-                 del noise_pred_cond
-                 del noise_pred_uncond
-                 torch.cuda.empty_cache()
-
-         # decode
-         input = 1 / vae.config.scaling_factor * input
-         images = vae.decode(input, return_dict=False)[0]
-         width, height = 512, 512
-         results = []
-         new_image = Image.new('RGB', (2*width, 2*height))
-         for index, image in enumerate(images.cpu().float()):
-             image = (image / 2 + 0.5).clamp(0, 1).unsqueeze(0)
-             image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
-             image = Image.fromarray((image * 255).round().astype("uint8")).convert('RGB')
-             results.append(image)
-             row = index // 2
-             col = index % 2
-             new_image.paste(image, (col*width, row*height))
-         # os.system('nvidia-smi')
-         torch.cuda.empty_cache()
-         # os.system('nvidia-smi')
-         return tuple(results), composed_prompt, layout_image
-
-     elif radio == 'TextDiffuser-2-LCM':
-         generator = torch.Generator(device=pipe.device).manual_seed(random.randint(0,1000))
-         image = pipe(
-             prompt=user_prompt,
-             generator=generator,
-             # negative_prompt=negative_prompt,
-             num_inference_steps=slider_step,
-             guidance_scale=1,
-             # num_images_per_prompt=slider_batch,
-         ).images
-         # os.system('nvidia-smi')
-         torch.cuda.empty_cache()
-         # os.system('nvidia-smi')
-         return tuple(image), composed_prompt, layout_image
 
 with gr.Blocks() as demo:
 
@@ -411,6 +576,9 @@ with gr.Blocks() as demo:
         <h2 style="font-weight: 900; font-size: 2.3rem; margin: 0rem">
            TextDiffuser-2: Unleashing the Power of Language Models for Text Rendering
         </h2>
         <h2 style="font-weight: 460; font-size: 1.1rem; margin: 0rem">
            <a href="https://jingyechen.github.io/">Jingye Chen</a>, <a href="https://hypjudy.github.io/website/">Yupan Huang</a>, <a href="https://scholar.google.com/citations?user=0LTZGhUAAAAJ&hl=en">Tengchao Lv</a>, <a href="https://www.microsoft.com/en-us/research/people/lecu/">Lei Cui</a>, <a href="https://cqf.io/">Qifeng Chen</a>, <a href="https://thegenerality.com/">Furu Wei</a>
         </h2>
@@ -420,96 +588,87 @@ with gr.Blocks() as demo:
         <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
            [<a href="https://arxiv.org/abs/2311.16465" style="color:blue;">arXiv</a>]
            [<a href="https://github.com/microsoft/unilm/tree/master/textdiffuser-2" style="color:blue;">Code</a>]
         </h3>
         <h2 style="text-align: left; font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
-            We propose <b>TextDiffuser-2</b>, aiming at unleashing the power of language models for text rendering. Specifically, we <b>tame a language model into a layout planner</b> to transform user prompt into a layout using the caption-OCR pairs. The language model demonstrates flexibility and automation by inferring keywords from user prompts or incorporating user-specified keywords to determine their positions. Secondly, we <b>leverage the language model in the diffusion model as the layout encoder</b> to represent the position and content of text at the line level. This approach enables diffusion models to generate text images with broader diversity.
         </h2>
         <h2 style="text-align: left; font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
-            👀 <b>Tips for using this demo</b>: <b>(1)</b> Please carefully read the disclaimer in the below. Current verison can only support English. <b>(2)</b> The specification of keywords is optional. If provided, the language model will do its best to plan layouts using the given keywords. <b>(3)</b> If a template is given, the layout planner (M1) is not used. <b>(4)</b> Three operations, including redo, undo, and skip are provided. When using skip, only the left-top point of a keyword will be recorded, resulting in more diversity but sometimes decreasing the accuracy. <b>(5)</b> The layout planner can produce different layouts. You can increase the temperature to enhance the diversity. <b>(6)</b> We also provide the experimental demo combining <b>TextDiffuser-2</b> and <b>LCM</b>. The inference is fast using less sampling steps, although the precision in text rendering might decrease.
         </h2>
-        <style>
-            .scaled-image {
-                transform: scale(1);
-            }
-        </style>
-
-        <img src="https://i.ibb.co/56JVg5j/architecture.jpg" alt="textdiffuser-2" class="scaled-image">
     </div>
     """)
 
-     with gr.Tab("Text-to-Image"):
         with gr.Row():
-             with gr.Column(scale=1):
-                 prompt = gr.Textbox(label="Prompt. You can let language model automatically identify keywords, or provide them below", placeholder="A beautiful city skyline stamp of Shanghai")
-                 keywords = gr.Textbox(label="(Optional) Keywords. Should be seperated by / (e.g., keyword1/keyword2/...)", placeholder="keyword1/keyword2")
-                 positive_prompt = gr.Textbox(label="(Optional) Positive prompt", value=", digital art, very detailed, fantasy, high definition, cinematic light, dnd, trending on artstation")
-
-                 with gr.Accordion("(Optional) Template - Click to paint", open=False):
-                     with gr.Row():
-                         with gr.Column(scale=1):
-                             i = gr.Image(label="Canvas", type='filepath', value=f'./gray256.jpg', height=256, width=256)
-                         with gr.Column(scale=1):
-                             t = gr.Textbox(label="Keyword", value='input_keyword')
-                             redo = gr.Button(value='Redo - Cancel the last keyword')
-                             undo = gr.Button(value='Undo - Clear the canvas')
-                             skip_button = gr.Button(value='Skip - Operate the next keyword')
-
-                 i.select(get_pixels,[i,t],[i])
-                 redo.click(exe_redo, [i,t],[i])
-                 undo.click(exe_undo, [i,t],[i])
-                 skip_button.click(skip_fun, [i,t])
-
-                 radio = gr.Radio(["TextDiffuser-2", "TextDiffuser-2-LCM"], label="Choice of models", value="TextDiffuser-2")
-                 slider_natural = gr.Checkbox(label="Natural image generation", value=False, info="The text position and content info will not be incorporated.")
-                 slider_step = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Sampling step", info="The sampling step for TextDiffuser-2. You may decease the step to 4 when using LCM.")
-                 slider_guidance = gr.Slider(minimum=1, maximum=13, value=7.5, step=0.5, label="Scale of classifier-free guidance", info="The scale of cfg and is set to 7.5 in default. When using LCM, cfg is set to 1.")
-                 slider_batch = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Batch size", info="The number of images to be sampled.")
-                 slider_temperature = gr.Slider(minimum=0.1, maximum=2, value=1.4, step=0.1, label="Temperature", info="Control the diversity of layout planner. Higher value indicates more diversity.")
                 # slider_seed = gr.Slider(minimum=1, maximum=10000, label="Seed", randomize=True)
                 button = gr.Button("Generate")
 
-             with gr.Column(scale=1):
-                 output = gr.Gallery(label='Generated image')
 
-                 with gr.Accordion("Intermediate results", open=False):
                     gr.Markdown("Composed prompt")
                     composed_prompt = gr.Textbox(label='')
-                     gr.Markdown("Layout visualization")
-                     layout = gr.Image(height=256, width=256)
 
-         button.click(text_to_image, inputs=[prompt,keywords,positive_prompt, radio,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural], outputs=[output, composed_prompt, layout])
 
-         gr.Markdown("## Prompt Examples")
-         gr.Examples(
             [
-                 ["A beautiful city skyline stamp of Shanghai", "", False],
-                 ["A logo of superman", "", False],
-                 ["A pencil sketch of a tree with the title nothing to tree here", "", False],
-                 ["handwritten signature of peter", "", False],
-                 ["Delicate greeting card of happy birthday to xyz", "", False],
-                 ["Book cover of good morning baby ", "", False],
-                 ["The handwritten words Hello World displayed on a wall in a neon light effect", "", False],
-                 ["Logo of winter in artistic font, made by snowflake", "", False],
-                 ["A book cover named summer vibe", "", False],
-                 ["Newspaper with the title Love Story", "", False],
-                 ["A logo for the company EcoGrow, where the letters look like plants", "EcoGrow", False],
-                 ["A poster titled 'Quails of North America', showing different kinds of quails.", "Quails/of/North/America", False],
-                 ["A detailed portrait of a fox guardian with a shield with Kung Fu written on it, by victo ngai and justin gerard, digital art, realistic painting", "kung/fu", False],
-                 ["A stamp of breath of the wild", "breath/of/the/wild", False],
-                 ["Poster of the incoming movie Transformers", "Transformers", False],
-                 ["Some apples are on a table", "", True],
-                 ["a hotdog with mustard and other toppings on it", "", True],
-                 ["a bathroom that has a slanted ceiling and a large bath tub", "", True],
-                 ["a man holding a tennis racquet on a tennis court", "", True],
-                 ["hamburger with bacon, lettuce, tomato and cheese| promotional image| hyperquality| products shot| full - color| extreme render| mouthwatering", "", True],
             ],
             [
-                 prompt,
-                 keywords,
-                 slider_natural
             ],
-             examples_per_page=20
         )
 
     gr.HTML(
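Editor's note: the description removed above mentions taming a language model into a layout planner (M1); this commit deletes the M1 loading code near the top of the file, and the code that actually queries M1 lies outside the shown hunks. Purely for orientation, here is a rough, hypothetical sketch of how such a planner can be prompted and its "keyword left, top, right, bottom" output parsed, based only on the template string and temperature slider visible in this diff (the helper name, generation settings, and parsing regex are assumptions, not the demo's code):

    import re
    import torch

    def plan_layout(m1_model, m1_tokenizer, user_prompt, temperature=1.4):
        # Same instruction template as in the diff, abbreviated here with "..."
        template = ('Given a prompt that will be used to generate an image, plan the layout of visual text '
                    'for the image. ... At each line, the format should be keyword left, top, right, bottom. '
                    f'So let us begin. Prompt: {user_prompt}')
        inputs = m1_tokenizer(template, return_tensors='pt').to(m1_model.device)
        with torch.no_grad():
            out = m1_model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=temperature)
        text = m1_tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
        boxes = []
        for line in text.splitlines():
            m = re.match(r'\s*(.+?)\s+(\d+),?\s*(\d+),?\s*(\d+),?\s*(\d+)', line)
            if m:  # keyword plus a box on the 128x128 layout grid
                boxes.append((m.group(1), tuple(int(v) for v in m.groups()[1:])))
        return boxes

The hunks that follow show the same regions of app.py from the new file's side.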
 
@@ -3,9 +3,13 @@ import re
 import zipfile
 import torch
 import gradio as gr
+
+ print('hello', gr.__version__)
+
+ import numpy as np
 import time
 from transformers import CLIPTextModel, CLIPTokenizer
+ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DiffusionPipeline
 from tqdm import tqdm
 from PIL import Image
 from PIL import Image, ImageDraw, ImageFont
 
@@ -26,29 +30,9 @@ if not os.path.exists('images2'):
 # os.system('nvidia-smi')
 os.system('ls')
 
 #### import diffusion models
 text_encoder = CLIPTextModel.from_pretrained(
+     'JingyeChen22/textdiffuser2-full-ft-inpainting', subfolder="text_encoder"
 ).cuda().half()
 tokenizer = CLIPTokenizer.from_pretrained(
     'runwayml/stable-diffusion-v1-5', subfolder="tokenizer"
 
@@ -69,50 +53,49 @@ print('***************')
 
 vae = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="vae").half().cuda()
 unet = UNet2DConditionModel.from_pretrained(
+     'JingyeChen22/textdiffuser2-full-ft-inpainting', subfolder="unet"
 ).half().cuda()
 text_encoder.resize_token_embeddings(len(tokenizer))
 
+ global_dict = {}
+ #### for interactive
+ # stack = []
+ # state = 0
+ font = ImageFont.truetype("./Arial.ttf", 20)
 
+ def skip_fun(i, t, guest_id):
+     global_dict[guest_id]['state'] = 0
+     # global state
+     # state = 0
 
 
+ def exe_undo(i, orig_i, t, guest_id):
+
+     global_dict[guest_id]['stack'] = []
+     global_dict[guest_id]['state'] = 0
+
+     return copy.deepcopy(orig_i)
+
+
+ def exe_redo(i, orig_i, t, guest_id):
+
+     print('redo ', orig_i)
+
+     if type(orig_i) == str:
+         orig_i = Image.open(orig_i)
+
+     # global state
+     # state = 0
+     global_dict[guest_id]['state'] = 0
+
+     if len(global_dict[guest_id]['stack']) > 0:
+         global_dict[guest_id]['stack'].pop()
+
+     image = copy.deepcopy(orig_i)
+
     draw = ImageDraw.Draw(image)
 
+     for items in global_dict[guest_id]['stack']:
         # print('now', items)
         text_position, t = items
         if len(text_position) == 2:
116
  draw.ellipse((leftUpPoint,rightDownPoint), fill='red')
117
  draw.rectangle((x0,y0,x1,y1), outline=(255, 0, 0) )
118
 
119
+ print('stack', global_dict[guest_id]['stack'])
120
  return image
121
 
122
+ def get_pixels(i, orig_i, radio, t, guest_id, evt: gr.SelectData):
 
123
 
124
+ print('hi1 ', i)
125
+ print('hi2 ', orig_i)
126
 
127
+ width, height = Image.open(i).size
128
+
129
+ # register
130
+ if guest_id == '-1': # register for the first time
131
+ seed = str(int(time.time()))
132
+ global_dict[str(seed)] = {
133
+ 'state': 0,
134
+ 'stack': [],
135
+ 'image_id': [list(Image.open(i).resize((512,512)).getdata())] # an image has been recorded
136
+ }
137
+ guest_id = str(seed)
138
+ else:
139
+ seed = guest_id
140
+
141
+ if type(i) == str:
142
+ i = Image.open(i)
143
+ i = i.resize((512,512))
144
+
145
+ images = global_dict[str(seed)]['image_id']
146
+ flag = False
147
+ for image in images:
148
+ if image == list(i.getdata()):
149
+ print('find it')
150
+ flag = True
151
+ break
152
+
153
+ if not flag:
154
+ global_dict[str(seed)]['image_id'] = [list(i.getdata())]
155
+ global_dict[str(seed)]['stack'] = []
156
+ global_dict[str(seed)]['state'] = 0
157
+ orig_i = i
158
  else:
 
 
 
 
 
 
 
159
 
160
+ if orig_i is not None:
161
+ orig_i = Image.open(orig_i)
162
+ orig_i = orig_i.resize((512,512))
163
+ else:
164
+ orig_i = i
165
+ global_dict[guest_id]['stack'] = []
166
+ global_dict[guest_id]['state'] = 0
167
 
168
+ text_position = evt.index
 
169
 
170
+ print('hello ', text_position)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
+ if radio == 'Two Points':
173
 
174
+ if global_dict[guest_id]['state'] == 0:
175
+ global_dict[guest_id]['stack'].append(
176
+ (text_position, t)
177
+ )
178
+ print(text_position, global_dict[guest_id]['stack'])
179
+ global_dict[guest_id]['state'] = 1
180
+ else:
181
+
182
+ (_, t) = global_dict[guest_id]['stack'].pop()
183
+ x, y = _
184
+ global_dict[guest_id]['stack'].append(
185
+ ((x,y,text_position[0],text_position[1]), t)
186
+ )
187
+ global_dict[guest_id]['state'] = 0
188
+
189
+ image = copy.deepcopy(orig_i)
190
+ draw = ImageDraw.Draw(image)
191
+
192
+ for items in global_dict[guest_id]['stack']:
193
+ text_position, t = items
194
+ if len(text_position) == 2:
195
+ x, y = text_position
196
+
197
+ x = int(512 * x / width)
198
+ y = int(512 * y / height)
199
+
200
+ text_color = (255, 0, 0)
201
+ draw.text((x+2, y), t, font=font, fill=text_color)
202
+ r = 4
203
+ leftUpPoint = (x-r, y-r)
204
+ rightDownPoint = (x+r, y+r)
205
+ draw.ellipse((leftUpPoint,rightDownPoint), fill='red')
206
+ elif len(text_position) == 4:
207
+ x0, y0, x1, y1 = text_position
208
+
209
+ x0 = int(512 * x0 / width)
210
+ x1 = int(512 * x1 / width)
211
+ y0 = int(512 * y0 / height)
212
+ y1 = int(512 * y1 / height)
213
+
214
+ text_color = (255, 0, 0)
215
+ draw.text((x0+2, y0), t, font=font, fill=text_color)
216
+ r = 4
217
+ leftUpPoint = (x0-r, y0-r)
218
+ rightDownPoint = (x0+r, y0+r)
219
+ draw.ellipse((leftUpPoint,rightDownPoint), fill='red')
220
+ draw.rectangle((x0,y0,x1,y1), outline=(255, 0, 0) )
221
+
222
+ elif radio == 'Four Points':
223
+
224
+ if global_dict[guest_id]['state'] == 0:
225
+ global_dict[guest_id]['stack'].append(
226
+ (text_position, t)
227
+ )
228
+ print(text_position, global_dict[guest_id]['stack'])
229
+ global_dict[guest_id]['state'] = 1
230
+ elif global_dict[guest_id]['state'] == 1:
231
+ (_, t) = global_dict[guest_id]['stack'].pop()
232
+ x, y = _
233
+ global_dict[guest_id]['stack'].append(
234
+ ((x,y,text_position[0],text_position[1]), t)
235
+ )
236
+ global_dict[guest_id]['state'] = 2
237
+ elif global_dict[guest_id]['state'] == 2:
238
+ (_, t) = global_dict[guest_id]['stack'].pop()
239
+ x0, y0, x1, y1 = _
240
+ global_dict[guest_id]['stack'].append(
241
+ ((x0, y0, x1, y1,text_position[0],text_position[1]), t)
242
+ )
243
+ global_dict[guest_id]['state'] = 3
244
+ elif global_dict[guest_id]['state'] == 3:
245
+ (_, t) = global_dict[guest_id]['stack'].pop()
246
+ x0, y0, x1, y1, x2, y2 = _
247
+ global_dict[guest_id]['stack'].append(
248
+ ((x0, y0, x1, y1, x2, y2,text_position[0],text_position[1]), t)
249
+ )
250
+ global_dict[guest_id]['state'] = 0
251
+
252
+ image = copy.deepcopy(orig_i)
253
+ draw = ImageDraw.Draw(image)
254
+
255
+ for items in global_dict[guest_id]['stack']:
256
+ text_position, t = items
257
+ if len(text_position) == 2:
258
+ x, y = text_position
259
+
260
+ x = int(512 * x / width)
261
+ y = int(512 * y / height)
262
+
263
+ text_color = (255, 0, 0)
264
+ draw.text((x+2, y), t, font=font, fill=text_color)
265
+ r = 4
266
+ leftUpPoint = (x-r, y-r)
267
+ rightDownPoint = (x+r, y+r)
268
+ draw.ellipse((leftUpPoint,rightDownPoint), fill='red')
269
+ elif len(text_position) == 4:
270
+ x0, y0, x1, y1 = text_position
271
+ text_color = (255, 0, 0)
272
+ draw.text((x0+2, y0), t, font=font, fill=text_color)
273
+ r = 4
274
+ leftUpPoint = (x0-r, y0-r)
275
+ rightDownPoint = (x0+r, y0+r)
276
+ draw.ellipse((leftUpPoint,rightDownPoint), fill='red')
277
+ draw.line(((x0,y0),(x1,y1)), fill=(255, 0, 0) )
278
+ elif len(text_position) == 6:
279
+ x0, y0, x1, y1, x2, y2 = text_position
280
+ text_color = (255, 0, 0)
281
+ draw.text((x0+2, y0), t, font=font, fill=text_color)
282
+ r = 4
283
+ leftUpPoint = (x0-r, y0-r)
284
+ rightDownPoint = (x0+r, y0+r)
285
+ draw.ellipse((leftUpPoint,rightDownPoint), fill='red')
286
+ draw.line(((x0,y0),(x1,y1)), fill=(255, 0, 0) )
287
+ draw.line(((x1,y1),(x2,y2)), fill=(255, 0, 0) )
288
+ elif len(text_position) == 8:
289
+ x0, y0, x1, y1, x2, y2, x3, y3 = text_position
290
+ text_color = (255, 0, 0)
291
+ draw.text((x0+2, y0), t, font=font, fill=text_color)
292
+ r = 4
293
+ leftUpPoint = (x0-r, y0-r)
294
+ rightDownPoint = (x0+r, y0+r)
295
+ draw.ellipse((leftUpPoint,rightDownPoint), fill='red')
296
+ draw.line(((x0,y0),(x1,y1)), fill=(255, 0, 0) )
297
+ draw.line(((x1,y1),(x2,y2)), fill=(255, 0, 0) )
298
+ draw.line(((x2,y2),(x3,y3)), fill=(255, 0, 0) )
299
+ draw.line(((x3,y3),(x0,y0)), fill=(255, 0, 0) )
300
+
301
+
302
+ print('stack', global_dict[guest_id]['stack'])
303
+
304
+ global_dict[str(seed)]['image_id'].append(list(image.getdata()))
305
+
306
+ return image, orig_i, seed
307
 
308
 
309
  font_layout = ImageFont.truetype('./Arial.ttf', 16)
 
@@ -208,11 +328,30 @@ def get_layout_image(ocrs):
     return blank
 
 
+ def to_tensor(image):
+     if isinstance(image, Image.Image):
+         image = np.array(image)
+     elif not isinstance(image, np.ndarray):
+         raise TypeError("Error")
+
+     image = image.astype(np.float32) / 255.0
+     image = np.transpose(image, (2, 0, 1))
+     tensor = torch.from_numpy(image)
+
+     return tensor
 
+ def test_fn(x,y):
+     print('hello')
 
+ def text_to_image(guest_id, i, orig_i, prompt,keywords,positive_prompt,radio,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural):
+
+     # print(type(i))
+     # exit(0)
+
+     print(f'[info] Prompt: {prompt} | Keywords: {keywords} | Radio: {radio} | Steps: {slider_step} | Guidance: {slider_guidance} | Natural: {slider_natural}')
+
+     # global stack
+     # global state
 
     if len(positive_prompt.strip()) != 0:
         prompt += positive_prompt
 
@@ -227,7 +366,7 @@ def text_to_image(prompt,keywords,positive_prompt,radio,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural):
         prompt = tokenizer.encode(user_prompt)
         layout_image = None
     else:
+         if guest_id not in global_dict or len(global_dict[guest_id]['stack']) == 0:
 
             if len(keywords.strip()) == 0:
                 template = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. All keywords are included in the caption. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {user_prompt}'
 
@@ -308,19 +447,25 @@ def text_to_image(prompt,keywords,positive_prompt,radio,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural):
             composed_prompt = tokenizer.decode(prompt)
 
         else:
+             user_prompt += ' <|endoftext|><|startoftext|>'
             layout_image = None
+
+             image_mask = Image.new('L', (512,512), 0)
+             draw = ImageDraw.Draw(image_mask)
+
+             for items in global_dict[guest_id]['stack']:
                 position, text = items
 
+                 # feature_mask
+                 # masked_feature
 
                 if len(position) == 2:
                     x, y = position
                     x = x // 4
                     y = y // 4
                     text_str = ' '.join([f'[{c}]' for c in list(text)])
+                     user_prompt += f' l{x} t{y} {text_str} <|endoftext|>'
+
                 elif len(position) == 4:
                     x0, y0, x1, y1 = position
                     x0 = x0 // 4
@@ -328,9 +473,32 @@ def text_to_image(prompt,keywords,positive_prompt,radio,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural):
                     x1 = x1 // 4
                     y1 = y1 // 4
                     text_str = ' '.join([f'[{c}]' for c in list(text)])
+                     user_prompt += f' l{x0} t{y0} r{x1} b{y1} {text_str} <|endoftext|>'
+
+                     draw.rectangle((x0*4, y0*4, x1*4, y1*4), fill=1)
+                     print('prompt ', user_prompt)
+
+                 elif len(position) == 8: # four points
+                     x0, y0, x1, y1, x2, y2, x3, y3 = position
+                     draw.polygon([(x0, y0), (x1, y1), (x2, y2), (x3, y3)], fill=1)
+                     x0 = x0 // 4
+                     y0 = y0 // 4
+                     x1 = x1 // 4
+                     y1 = y1 // 4
+                     x2 = x2 // 4
+                     y2 = y2 // 4
+                     x3 = x3 // 4
+                     y3 = y3 // 4
+                     xmin = min(x0, x1, x2, x3)
+                     ymin = min(y0, y1, y2, y3)
+                     xmax = max(x0, x1, x2, x3)
+                     ymax = max(y0, y1, y2, y3)
+                     text_str = ' '.join([f'[{c}]' for c in list(text)])
+                     user_prompt += f' l{xmin} t{ymin} r{xmax} b{ymax} {text_str} <|endoftext|>'
+
+             print('prompt ', user_prompt)
+
 
         prompt = tokenizer.encode(user_prompt)
         composed_prompt = tokenizer.decode(prompt)
 
 
@@ -338,70 +506,67 @@ def text_to_image(prompt,keywords,positive_prompt,radio,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural):
     while len(prompt) < 77:
         prompt.append(tokenizer.pad_token_id)
 
+     prompts_cond = prompt
+     prompts_nocond = [tokenizer.pad_token_id]*77
+
+     prompts_cond = [prompts_cond] * slider_batch
+     prompts_nocond = [prompts_nocond] * slider_batch
+
+     prompts_cond = torch.Tensor(prompts_cond).long().cuda()
+     prompts_nocond = torch.Tensor(prompts_nocond).long().cuda()
+
+     scheduler = DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="scheduler")
+     scheduler.set_timesteps(slider_step)
+     noise = torch.randn((slider_batch, 4, 64, 64)).to("cuda").half()
+     input = noise
+
+     encoder_hidden_states_cond = text_encoder(prompts_cond)[0].half()
+     encoder_hidden_states_nocond = text_encoder(prompts_nocond)[0].half()
+
+     image_mask = torch.Tensor(np.array(image_mask)).float().half().cuda()
+     image_mask = image_mask.unsqueeze(0).unsqueeze(0).repeat(slider_batch, 1, 1, 1)
+
+     image = Image.open(orig_i).resize((512,512))
+     image_tensor = to_tensor(image).unsqueeze(0).cuda().sub_(0.5).div_(0.5)
+     print(f'image_tensor.shape {image_tensor.shape}')
+     masked_image = image_tensor * (1-image_mask)
+     masked_feature = vae.encode(masked_image.half()).latent_dist.sample()
+     masked_feature = masked_feature * vae.config.scaling_factor
+     masked_feature = masked_feature.half()
+     print(f'masked_feature.shape {masked_feature.shape}')
+
+     feature_mask = torch.nn.functional.interpolate(image_mask, size=(64,64), mode='nearest').cuda()
+
+     for t in tqdm(scheduler.timesteps):
+         with torch.no_grad(): # classifier-free guidance
+
+             noise_pred_cond = unet(sample=input, timestep=t, encoder_hidden_states=encoder_hidden_states_cond[:slider_batch],feature_mask=feature_mask, masked_feature=masked_feature).sample # b, 4, 64, 64
+             noise_pred_uncond = unet(sample=input, timestep=t, encoder_hidden_states=encoder_hidden_states_nocond[:slider_batch],feature_mask=feature_mask, masked_feature=masked_feature).sample # b, 4, 64, 64
+             noisy_residual = noise_pred_uncond + slider_guidance * (noise_pred_cond - noise_pred_uncond) # b, 4, 64, 64
+             input = scheduler.step(noisy_residual, t, input).prev_sample
+             del noise_pred_cond
+             del noise_pred_uncond
+
+             torch.cuda.empty_cache()
+
+     # decode
+     input = 1 / vae.config.scaling_factor * input
+     images = vae.decode(input, return_dict=False)[0]
+     width, height = 512, 512
+     results = []
+     new_image = Image.new('RGB', (2*width, 2*height))
+     for index, image in enumerate(images.cpu().float()):
+         image = (image / 2 + 0.5).clamp(0, 1).unsqueeze(0)
+         image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
+         image = Image.fromarray((image * 255).round().astype("uint8")).convert('RGB')
+         results.append(image)
+         row = index // 2
+         col = index % 2
+         new_image.paste(image, (col*width, row*height))
+     # os.system('nvidia-smi')
+     torch.cuda.empty_cache()
+     # os.system('nvidia-smi')
+     return tuple(results), composed_prompt
 
 with gr.Blocks() as demo:
 
@@ -411,6 +576,9 @@ with gr.Blocks() as demo:
         <h2 style="font-weight: 900; font-size: 2.3rem; margin: 0rem">
            TextDiffuser-2: Unleashing the Power of Language Models for Text Rendering
         </h2>
+         <h2 style="font-weight: 900; font-size: 1.3rem; margin: 0rem">
+            (Demo for <b>Text Inpainting</b> 🖼️🖌️)
+         </h2>
         <h2 style="font-weight: 460; font-size: 1.1rem; margin: 0rem">
            <a href="https://jingyechen.github.io/">Jingye Chen</a>, <a href="https://hypjudy.github.io/website/">Yupan Huang</a>, <a href="https://scholar.google.com/citations?user=0LTZGhUAAAAJ&hl=en">Tengchao Lv</a>, <a href="https://www.microsoft.com/en-us/research/people/lecu/">Lei Cui</a>, <a href="https://cqf.io/">Qifeng Chen</a>, <a href="https://thegenerality.com/">Furu Wei</a>
         </h2>
 
@@ -420,96 +588,87 @@ with gr.Blocks() as demo:
         <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
            [<a href="https://arxiv.org/abs/2311.16465" style="color:blue;">arXiv</a>]
            [<a href="https://github.com/microsoft/unilm/tree/master/textdiffuser-2" style="color:blue;">Code</a>]
+            [<a href="https://jingyechen.github.io/textdiffuser2/" style="color:blue;">Project Page</a>]
+            [<a href="https://discord.gg/q7eHPupu" style="color:purple;">Discord</a>]
         </h3>
         <h2 style="text-align: left; font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
+            TextDiffuser-2 leverages language models to enhance text rendering, achieving greater flexibility. Different from text editing, the text inpainting task aims to add or modify text guided by users, ensuring that the inpainted text has a reasonable style (i.e., it need not exactly match the style of the original text being modified) and is coherent with the background. TextDiffuser-2 offers an <b>improved user experience</b>. Specifically, users only need to type the text they wish to inpaint into the provided input box and then select key points on the Canvas.
         </h2>
         <h2 style="text-align: left; font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
+            👀 <b>Tips for using this demo</b>: <b>(1)</b> Please carefully read the disclaimer below. The current version only supports English. <b>(2)</b> The <b>prompt is optional</b>. If provided, the generated image may be more accurate. <b>(3)</b> Redo is used to cancel the last keyword, and undo is used to clear all keywords. <b>(4)</b> The current version only supports input images with a resolution of 512x512. <b>(5)</b> You can use either two points or four points to specify the text box. Four points can better represent perspective boxes. <b>(6)</b> Leaving "Text to be inpainted" empty functions as text removal. <b>(7)</b> Classifier-free guidance is set to a small value (e.g., 1) by default. Note that a larger cfg may result in chromatic aberration against the background. <b>(8)</b> You can inpaint many text regions at one time. <b>(9)</b> Thanks for reading these tips, shall we start now?
         </h2>
+         <img src="https://raw.githubusercontent.com/JingyeChen/jingyechen.github.io/master/textdiffuser2/static/images/inpainting_blank.jpg" alt="textdiffuser-2">
     </div>
     """)
 
+     with gr.Tab("Text Inpainting"):
         with gr.Row():
+             with gr.Column():
+
+                 keywords = gr.Textbox(label="(Optional) Keywords. Should be separated by / (e.g., keyword1/keyword2/...)", placeholder="keyword1/keyword2", visible=False)
+                 positive_prompt = gr.Textbox(label="(Optional) Positive prompt", value="", visible=False)
+
+                 i = gr.Image(label="Image", type='filepath', value='https://raw.githubusercontent.com/JingyeChen/jingyechen.github.io/master/textdiffuser2/static/images/example11.jpg')
+                 orig_i = gr.Image(label="Placeholder", type='filepath', height=512, width=512, visible=False)
+
+                 radio = gr.Radio(["Two Points", "Four Points"], label="Number of points to represent the text box.", value="Two Points", visible=True)
+
+                 with gr.Row():
+                     t = gr.Textbox(label="Text to be inpainted", value='Test')
+                     prompt = gr.Textbox(label="(Optional) Prompt.")
+                 with gr.Row():
+                     redo = gr.Button(value='Redo - Cancel the last keyword')
+                     undo = gr.Button(value='Undo - Clear the canvas')
+                 # skip_button = gr.Button(value='Skip - Operate the next keyword')
+
+                 slider_natural = gr.Checkbox(label="Natural image generation", value=False, info="The text position and content info will not be incorporated.", visible=False)
+                 slider_step = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Sampling step", info="The sampling step for TextDiffuser-2.")
+                 slider_guidance = gr.Slider(minimum=1, maximum=13, value=1, step=0.5, label="Scale of classifier-free guidance", info="The scale of cfg, set to 1 by default. A smaller cfg produces more stable results.")
+                 slider_batch = gr.Slider(minimum=1, maximum=6, value=4, step=1, label="Batch size", info="The number of images to be sampled.")
+                 slider_temperature = gr.Slider(minimum=0.1, maximum=2, value=1.4, step=0.1, label="Temperature", info="Control the diversity of layout planner. Higher value indicates more diversity.", visible=False)
                 # slider_seed = gr.Slider(minimum=1, maximum=10000, label="Seed", randomize=True)
                 button = gr.Button("Generate")
+
+                 guest_id_box = gr.Textbox(label="guest_id", value=f"-1", visible=False)
+                 i.select(get_pixels,[i,orig_i,radio,t,guest_id_box],[i,orig_i,guest_id_box])
+                 redo.click(exe_redo, [i,orig_i,t,guest_id_box],[i])
+                 undo.click(exe_undo, [i,orig_i,t,guest_id_box],[i])
+                 # skip_button.click(skip_fun, [i,t,guest_id_box])
 
+             with gr.Column():
+                 output = gr.Gallery(label='Generated image', rows=2, height=768)
 
+                 with gr.Accordion("Intermediate results", open=False, visible=False):
                     gr.Markdown("Composed prompt")
                     composed_prompt = gr.Textbox(label='')
+                     # gr.Markdown("Layout visualization")
+                     # layout = gr.Image(height=256, width=256)
 
 
+         button.click(text_to_image, inputs=[guest_id_box, i, orig_i, prompt,keywords,positive_prompt, radio,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural], outputs=[output, composed_prompt])
 
+         gr.Markdown("## Image Examples")
+         template = None
+         gr.Examples(
             [
+                 ["https://raw.githubusercontent.com/JingyeChen/jingyechen.github.io/master/textdiffuser2/static/images/example1.jpg"],
+                 ["https://raw.githubusercontent.com/JingyeChen/jingyechen.github.io/master/textdiffuser2/static/images/example2.jpg"],
+                 ["https://raw.githubusercontent.com/JingyeChen/jingyechen.github.io/master/textdiffuser2/static/images/example3.jpg"],
+                 ["https://raw.githubusercontent.com/JingyeChen/jingyechen.github.io/master/textdiffuser2/static/images/example4.jpg"],
+                 ["https://raw.githubusercontent.com/JingyeChen/jingyechen.github.io/master/textdiffuser2/static/images/example5.jpg"],
+                 ["https://raw.githubusercontent.com/JingyeChen/jingyechen.github.io/master/textdiffuser2/static/images/example7.jpg"],
+                 ["https://raw.githubusercontent.com/JingyeChen/jingyechen.github.io/master/textdiffuser2/static/images/example8.jpg"],
+                 ["https://raw.githubusercontent.com/JingyeChen/jingyechen.github.io/master/textdiffuser2/static/images/example11.jpg"],
+                 ["https://raw.githubusercontent.com/JingyeChen/jingyechen.github.io/master/textdiffuser2/static/images/example12.jpg"],
+                 ["https://raw.githubusercontent.com/JingyeChen/jingyechen.github.io/master/textdiffuser2/static/images/example13.jpg"],
+                 ["https://raw.githubusercontent.com/JingyeChen/jingyechen.github.io/master/textdiffuser2/static/images/example14.jpg"],
+                 ["https://raw.githubusercontent.com/JingyeChen/jingyechen.github.io/master/textdiffuser2/static/images/example15.jpg"],
             ],
             [
+                 i
             ],
+             examples_per_page=25,
         )
 
     gr.HTML(