bumsika and songweig committed
Commit 2dee308 (0 parents)

Duplicate from songweig/rich-text-to-image


Co-authored-by: Songwei Ge <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ venv
+ __pycache__/
+ *.pyc
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Rich Text To Image
+ emoji: 🌍
+ colorFrom: indigo
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 3.27.0
+ app_file: app.py
+ pinned: false
+ duplicated_from: songweig/rich-text-to-image
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,557 @@
1
+ import math
2
+ import random
3
+ import os
4
+ import json
5
+ import time
6
+ import argparse
7
+ import torch
8
+ import numpy as np
9
+ from torchvision import transforms
10
+
11
+ from models.region_diffusion import RegionDiffusion
12
+ from utils.attention_utils import get_token_maps
13
+ from utils.richtext_utils import seed_everything, parse_json, get_region_diffusion_input,\
14
+ get_attention_control_input, get_gradient_guidance_input
15
+
16
+
17
+ import gradio as gr
18
+ from PIL import Image, ImageOps
19
+ from share_btn import community_icon_html, loading_icon_html, share_js, css
20
+
21
+
22
+ help_text = """
23
+ If you are encountering an error or not achieving your desired outcome, here are some potential reasons and recommendations to consider:
24
+ 1. If you format only a portion of a word rather than the complete word, an error may occur.
25
+ 2. If you use font color and get completely corrupted results, you may consider decreasing the color weight lambda.
26
+ 3. Consider using a different seed.
27
+ """
28
+
29
+
30
+ canvas_html = """<iframe id='rich-text-root' style='width:100%' height='360px' src='file=rich-text-to-json-iframe.html' frameborder='0' scrolling='no'></iframe>"""
31
+ get_js_data = """
32
+ async (text_input, negative_prompt, num_segments, segment_threshold, inject_interval, inject_background, seed, color_guidance_weight, rich_text_input, height, width, steps, guidance_weights) => {
33
+ const richEl = document.getElementById("rich-text-root");
34
+ const data = richEl? richEl.contentDocument.body._data : {};
35
+ return [text_input, negative_prompt, num_segments, segment_threshold, inject_interval, inject_background, seed, color_guidance_weight, JSON.stringify(data), height, width, steps, guidance_weights];
36
+ }
37
+ """
38
+ set_js_data = """
39
+ async (text_input) => {
40
+ const richEl = document.getElementById("rich-text-root");
41
+ const data = text_input ? JSON.parse(text_input) : null;
42
+ if (richEl && data) richEl.contentDocument.body.setQuillContents(data);
43
+ }
44
+ """
45
+
46
+ get_window_url_params = """
47
+ async (url_params) => {
48
+ const params = new URLSearchParams(window.location.search);
49
+ url_params = Object.fromEntries(params);
50
+ return [url_params];
51
+ }
52
+ """
53
+
54
+
55
+ def load_url_params(url_params):
56
+ if 'prompt' in url_params:
57
+ return gr.update(visible=True), url_params
58
+ else:
59
+ return gr.update(visible=False), url_params
60
+
61
+
62
+ def main():
63
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
64
+ model = RegionDiffusion(device)
65
+
66
+ def generate(
67
+ text_input: str,
68
+ negative_text: str,
69
+ num_segments: int,
70
+ segment_threshold: float,
71
+ inject_interval: float,
72
+ inject_background: float,
73
+ seed: int,
74
+ color_guidance_weight: float,
75
+ rich_text_input: str,
76
+ height: int,
77
+ width: int,
78
+ steps: int,
79
+ guidance_weight: float,
80
+ ):
81
+ run_dir = 'results/'
82
+ os.makedirs(run_dir, exist_ok=True)
83
+ # Load region diffusion model.
84
+ height = int(height) if height else 512
85
+ width = int(width) if width else 512
86
+ steps = 41 if not steps else steps
87
+ guidance_weight = 8.5 if not guidance_weight else guidance_weight
88
+ text_input = rich_text_input if rich_text_input != '' and rich_text_input is not None else text_input
89
+ print('text_input', text_input, width, height, steps, guidance_weight, num_segments, segment_threshold, inject_interval, inject_background, color_guidance_weight, negative_text)
90
+ if (text_input == '' or rich_text_input == ''):
91
+ raise gr.Error("Please enter some text.")
92
+ # parse json to span attributes
93
+ base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
94
+ color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance = parse_json(
95
+ json.loads(text_input))
96
+
97
+ # create control input for region diffusion
98
+ region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
99
+ model, base_text_prompt, style_text_prompts, footnote_text_prompts,
100
+ footnote_target_tokens, color_text_prompts, color_names)
101
+
102
+ # create control input for cross attention
103
+ text_format_dict = get_attention_control_input(
104
+ model, base_tokens, size_text_prompts_and_sizes)
105
+
106
+ # create control input for region guidance
107
+ text_format_dict, color_target_token_ids = get_gradient_guidance_input(
108
+ model, base_tokens, color_text_prompts, color_rgbs, text_format_dict, color_guidance_weight=color_guidance_weight)
109
+
110
+ seed_everything(seed)
111
+
112
+ # get token maps from plain text to image generation.
113
+ begin_time = time.time()
114
+ if model.selfattn_maps is None and model.crossattn_maps is None:
115
+ model.remove_tokenmap_hooks()
116
+ model.register_tokenmap_hooks()
117
+ else:
118
+ model.reset_attention_maps()
119
+ model.remove_tokenmap_hooks()
120
+ plain_img = model.produce_attn_maps([base_text_prompt], [negative_text],
121
+ height=height, width=width, num_inference_steps=steps,
122
+ guidance_scale=guidance_weight)
123
+ print('time elapsed to get attention maps: %.4f' %
124
+ (time.time()-begin_time))
125
+ seed_everything(seed)
126
+ color_obj_masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
127
+ 512//8, 512//8, color_target_token_ids[:-1], seed,
128
+ base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
129
+ return_vis=True)
130
+ seed_everything(seed)
131
+ model.masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
132
+ 512//8, 512//8, region_target_token_ids[:-1], seed,
133
+ base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
134
+ return_vis=True)
135
+ color_obj_atten_all = torch.zeros_like(color_obj_masks[-1])
136
+ for obj_mask in color_obj_masks[:-1]:
137
+ color_obj_atten_all += obj_mask
138
+ color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
139
+ interpolation=transforms.InterpolationMode.BICUBIC,
140
+ antialias=True)
141
+ for color_obj_mask in color_obj_masks]
142
+ text_format_dict['color_obj_atten'] = color_obj_masks
143
+ text_format_dict['color_obj_atten_all'] = color_obj_atten_all
144
+ model.remove_tokenmap_hooks()
145
+
146
+ # generate image from rich text
147
+ begin_time = time.time()
148
+ seed_everything(seed)
149
+ rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
150
+ height=height, width=width, num_inference_steps=steps,
151
+ guidance_scale=guidance_weight, use_guidance=use_grad_guidance,
152
+ text_format_dict=text_format_dict, inject_selfattn=inject_interval,
153
+ inject_background=inject_background)
154
+ print('time elapsed to generate image from rich text: %.4f' %
155
+ (time.time()-begin_time))
156
+ return [plain_img[0], rich_img[0], segments_vis, token_maps]
157
+
158
+ with gr.Blocks(css=css) as demo:
159
+ url_params = gr.JSON({}, visible=False, label="URL Params")
160
+ gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">Expressive Text-to-Image Generation with Rich Text</h1>
161
+ <p> <a href="https://songweige.github.io/">Songwei Ge</a>, <a href="https://taesung.me/">Taesung Park</a>, <a href="https://www.cs.cmu.edu/~junyanz/">Jun-Yan Zhu</a>, <a href="https://jbhuang0604.github.io/">Jia-Bin Huang</a> </p>
162
+ <p> UMD, Adobe, CMU </p>
163
+ <p> <a href="https://huggingface.co/spaces/songweig/rich-text-to-image?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="display:inline;" alt="Duplicate Space"></a> | <a href="https://rich-text-to-image.github.io">[Website]</a> | <a href="https://github.com/SongweiGe/rich-text-to-image">[Code]</a> | <a href="https://arxiv.org/abs/2304.06720">[Paper]</a></p>
164
+ <p> For faster inference without waiting in the queue, you may duplicate the Space and upgrade to a GPU in the settings.</p>""")
165
+ with gr.Row():
166
+ with gr.Column():
167
+ rich_text_el = gr.HTML(canvas_html, elem_id="canvas_html")
168
+ rich_text_input = gr.Textbox(value="", visible=False)
169
+ text_input = gr.Textbox(
170
+ label='Rich-text JSON Input',
171
+ visible=False,
172
+ max_lines=1,
173
+ placeholder='Example: \'{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#b26b00"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background.\n"}]}\'',
174
+ elem_id="text_input"
175
+ )
176
+ negative_prompt = gr.Textbox(
177
+ label='Negative Prompt',
178
+ max_lines=1,
179
+ placeholder='Example: poor quality, blurry, dark, low resolution, low quality, worst quality',
180
+ elem_id="negative_prompt"
181
+ )
182
+ segment_threshold = gr.Slider(label='Token map threshold',
183
+ info='(See less area in token maps? Decrease this. See too much area? Increase this.)',
184
+ minimum=0,
185
+ maximum=1,
186
+ step=0.01,
187
+ value=0.25)
188
+ inject_interval = gr.Slider(label='Detail preservation',
189
+ info='(To preserve more structure from plain-text generation, increase this. To see more rich-text attributes, decrease this.)',
190
+ minimum=0,
191
+ maximum=1,
192
+ step=0.01,
193
+ value=0.)
194
+ inject_background = gr.Slider(label='Unformatted token preservation',
195
+ info='(To leave tokens without any rich-text attributes less affected, increase this.)',
196
+ minimum=0,
197
+ maximum=1,
198
+ step=0.01,
199
+ value=0.3)
200
+ color_guidance_weight = gr.Slider(label='Color weight',
201
+ info='(To obtain more precise colors, increase this; too large a value may cause artifacts.)',
202
+ minimum=0,
203
+ maximum=2,
204
+ step=0.1,
205
+ value=0.5)
206
+ num_segments = gr.Slider(label='Number of segments',
207
+ minimum=2,
208
+ maximum=20,
209
+ step=1,
210
+ value=9)
211
+ seed = gr.Slider(label='Seed',
212
+ minimum=0,
213
+ maximum=100000,
214
+ step=1,
215
+ value=6,
216
+ elem_id="seed"
217
+ )
218
+ with gr.Accordion('Other Parameters', open=False):
219
+ steps = gr.Slider(label='Number of Steps',
220
+ minimum=0,
221
+ maximum=500,
222
+ step=1,
223
+ value=41)
224
+ guidance_weight = gr.Slider(label='CFG weight',
225
+ minimum=0,
226
+ maximum=50,
227
+ step=0.1,
228
+ value=8.5)
229
+ width = gr.Dropdown(choices=[512],
230
+ value=512,
231
+ label='Width',
232
+ visible=True)
233
+ height = gr.Dropdown(choices=[512],
234
+ value=512,
235
+ label='Height',
236
+ visible=True)
237
+
238
+ with gr.Row():
239
+ with gr.Column(scale=1, min_width=100):
240
+ generate_button = gr.Button("Generate")
241
+ load_params_button = gr.Button(
242
+ "Load from URL Params", visible=True)
243
+ with gr.Column():
244
+ richtext_result = gr.Image(
245
+ label='Rich-text', elem_id="rich-text-image")
246
+ richtext_result.style(height=512)
247
+ with gr.Row():
248
+ plaintext_result = gr.Image(
249
+ label='Plain-text', elem_id="plain-text-image")
250
+ segments = gr.Image(label='Segmentation')
251
+ with gr.Row():
252
+ token_map = gr.Image(label='Token Maps')
253
+ with gr.Row(visible=False) as share_row:
254
+ with gr.Group(elem_id="share-btn-container"):
255
+ community_icon = gr.HTML(community_icon_html)
256
+ loading_icon = gr.HTML(loading_icon_html)
257
+ share_button = gr.Button(
258
+ "Share to community", elem_id="share-btn")
259
+ share_button.click(None, [], [], _js=share_js)
260
+ with gr.Row():
261
+ gr.Markdown(help_text)
262
+
263
+ with gr.Row():
264
+ footnote_examples = [
265
+ [
266
+ '{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. Palm trees in the background."}]}',
267
+ '',
268
+ 5,
269
+ 0.3,
270
+ 0,
271
+ 0.5,
272
+ 6,
273
+ 0,
274
+ None,
275
+ ],
276
+ [
277
+ '{"ops":[{"insert":"A "},{"attributes":{"link":"Thor Kitchen 30 Inch Wide Freestanding Gas Range with Automatic Re-Ignition System"},"insert":"kitchen island"},{"insert":" next to a "},{"attributes":{"link":"an open refrigerator stocked with fresh produce, dairy products, and beverages. "},"insert":"refrigerator"},{"insert":", by James McDonald and Joarc Architects, home, interior, octane render, deviantart, cinematic, key art, hyperrealism, sun light, sunrays, canon eos c 300, ƒ 1.8, 35 mm, 8k, medium - format print"}]}',
278
+ '',
279
+ 7,
280
+ 0.5,
281
+ 0,
282
+ 0.5,
283
+ 6,
284
+ 0,
285
+ None,
286
+ ],
287
+ [
288
+ '{"ops":[{"insert":"A "},{"attributes":{"link":"Happy Kung fu panda art, elder, asian art, volumetric lighting, dramatic scene, ultra detailed, realism, chinese"},"insert":"panda"},{"insert":" standing on a cliff by a waterfall, wildlife photography, photograph, high quality, wildlife, f 1.8, soft focus, 8k, national geographic, award - winning photograph by nick nichols"}]}',
289
+ '',
290
+ 5,
291
+ 0.3,
292
+ 0,
293
+ 0.1,
294
+ 4,
295
+ 0,
296
+ None,
297
+ ],
298
+ ]
299
+
300
+ gr.Examples(examples=footnote_examples,
301
+ label='Footnote examples',
302
+ inputs=[
303
+ text_input,
304
+ negative_prompt,
305
+ num_segments,
306
+ segment_threshold,
307
+ inject_interval,
308
+ inject_background,
309
+ seed,
310
+ color_guidance_weight,
311
+ rich_text_input,
312
+ ],
313
+ outputs=[
314
+ plaintext_result,
315
+ richtext_result,
316
+ segments,
317
+ token_map,
318
+ ],
319
+ fn=generate,
320
+ cache_examples=True,
321
+ examples_per_page=20)
322
+ with gr.Row():
323
+ color_examples = [
324
+ [
325
+ '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#04a704"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
326
+ 'lowres, had anatomy, bad hands, cropped, worst quality',
327
+ 11,
328
+ 0.3,
329
+ 0.3,
330
+ 0.3,
331
+ 6,
332
+ 0.5,
333
+ None,
334
+ ],
335
+ [
336
+ '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#999999"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
337
+ 'lowres, had anatomy, bad hands, cropped, worst quality',
338
+ 11,
339
+ 0.3,
340
+ 0.3,
341
+ 0.3,
342
+ 6,
343
+ 0.5,
344
+ None,
345
+ ],
346
+ [
347
+ '{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
348
+ '',
349
+ 10,
350
+ 0.4,
351
+ 0.5,
352
+ 0.3,
353
+ 6,
354
+ 0.5,
355
+ None,
356
+ ],
357
+ [
358
+ '{"ops":[{"insert":"A mesmerizing sight that captures the beauty of a "},{"attributes":{"color":"#4775fc"},"insert":"rose"},{"insert":" blooming, close up"}]}',
359
+ '',
360
+ 3,
361
+ 0.3,
362
+ 0,
363
+ 0,
364
+ 9,
365
+ 1,
366
+ None,
367
+ ],
368
+ [
369
+ '{"ops":[{"insert":"A "},{"attributes":{"color":"#FFD700"},"insert":"marble statue of a wolf\'s head and shoulder"},{"insert":", surrounded by colorful flowers michelangelo, detailed, intricate, full of color, led lighting, trending on artstation, 4 k, hyperrealistic, 3 5 mm, focused, extreme details, unreal engine 5, masterpiece "}]}',
370
+ '',
371
+ 5,
372
+ 0.4,
373
+ 0.3,
374
+ 0.3,
375
+ 5,
376
+ 0.6,
377
+ None,
378
+ ],
379
+ ]
380
+ gr.Examples(examples=color_examples,
381
+ label='Font color examples',
382
+ inputs=[
383
+ text_input,
384
+ negative_prompt,
385
+ num_segments,
386
+ segment_threshold,
387
+ inject_interval,
388
+ inject_background,
389
+ seed,
390
+ color_guidance_weight,
391
+ rich_text_input,
392
+ ],
393
+ outputs=[
394
+ plaintext_result,
395
+ richtext_result,
396
+ segments,
397
+ token_map,
398
+ ],
399
+ fn=generate,
400
+ cache_examples=True,
401
+ examples_per_page=20)
402
+
403
+ with gr.Row():
404
+ style_examples = [
405
+ [
406
+ '{"ops":[{"insert":"a "},{"attributes":{"font":"mirza"},"insert":"beautiful garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain in the background"},{"insert":""}]}',
407
+ '',
408
+ 10,
409
+ 0.4,
410
+ 0,
411
+ 0.2,
412
+ 3,
413
+ 0,
414
+ None,
415
+ ],
416
+ [
417
+ '{"ops":[{"attributes":{"link":"the awe-inspiring sky and ocean in the style of J.M.W. Turner"},"insert":"the awe-inspiring sky and sea"},{"insert":" by "},{"attributes":{"font":"mirza"},"insert":"a coast with flowers and grasses in spring"}]}',
418
+ 'worst quality, dark, poor quality',
419
+ 5,
420
+ 0.3,
421
+ 0,
422
+ 0,
423
+ 9,
424
+ 0.5,
425
+ None,
426
+ ],
427
+ [
428
+ '{"ops":[{"insert":"a "},{"attributes":{"font":"slabo"},"insert":"night sky filled with stars"},{"insert":" above a "},{"attributes":{"font":"roboto"},"insert":"turbulent sea with giant waves"}]}',
429
+ '',
430
+ 2,
431
+ 0.35,
432
+ 0,
433
+ 0,
434
+ 6,
435
+ 0.5,
436
+ None,
437
+ ],
438
+ ]
439
+ gr.Examples(examples=style_examples,
440
+ label='Font style examples',
441
+ inputs=[
442
+ text_input,
443
+ negative_prompt,
444
+ num_segments,
445
+ segment_threshold,
446
+ inject_interval,
447
+ inject_background,
448
+ seed,
449
+ color_guidance_weight,
450
+ rich_text_input,
451
+ ],
452
+ outputs=[
453
+ plaintext_result,
454
+ richtext_result,
455
+ segments,
456
+ token_map,
457
+ ],
458
+ fn=generate,
459
+ cache_examples=True,
460
+ examples_per_page=20)
461
+
462
+ with gr.Row():
463
+ size_examples = [
464
+ [
465
+ '{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "60px"}, "insert": "pineapple"}, {"insert": ", pepperoni, and mushroom on the top, 4k, photorealistic"}]}',
466
+ 'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
467
+ 5,
468
+ 0.3,
469
+ 0,
470
+ 0,
471
+ 13,
472
+ 1,
473
+ None,
474
+ ],
475
+ [
476
+ '{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "20px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top, 4k, photorealistic"}]}',
477
+ 'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
478
+ 5,
479
+ 0.3,
480
+ 0,
481
+ 0,
482
+ 13,
483
+ 1,
484
+ None,
485
+ ],
486
+ [
487
+ '{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "70px"}, "insert": "mushroom"}, {"insert": " on the top, 4k, photorealistic"}]}',
488
+ 'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
489
+ 5,
490
+ 0.3,
491
+ 0,
492
+ 0,
493
+ 13,
494
+ 1,
495
+ None,
496
+ ],
497
+ ]
498
+ gr.Examples(examples=size_examples,
499
+ label='Font size examples',
500
+ inputs=[
501
+ text_input,
502
+ negative_prompt,
503
+ num_segments,
504
+ segment_threshold,
505
+ inject_interval,
506
+ inject_background,
507
+ seed,
508
+ color_guidance_weight,
509
+ rich_text_input,
510
+ ],
511
+ outputs=[
512
+ plaintext_result,
513
+ richtext_result,
514
+ segments,
515
+ token_map,
516
+ ],
517
+ fn=generate,
518
+ cache_examples=True,
519
+ examples_per_page=20)
520
+ generate_button.click(fn=lambda: gr.update(visible=False), inputs=None, outputs=share_row, queue=False).then(
521
+ fn=generate,
522
+ inputs=[
523
+ text_input,
524
+ negative_prompt,
525
+ num_segments,
526
+ segment_threshold,
527
+ inject_interval,
528
+ inject_background,
529
+ seed,
530
+ color_guidance_weight,
531
+ rich_text_input,
532
+ height,
533
+ width,
534
+ steps,
535
+ guidance_weight,
536
+ ],
537
+ outputs=[plaintext_result, richtext_result, segments, token_map],
538
+ _js=get_js_data
539
+ ).then(
540
+ fn=lambda: gr.update(visible=True), inputs=None, outputs=share_row, queue=False)
541
+ text_input.change(
542
+ fn=None, inputs=[text_input], outputs=None, _js=set_js_data, queue=False)
543
+ # load url param prompt to textinput
544
+ load_params_button.click(fn=lambda x: x['prompt'], inputs=[
545
+ url_params], outputs=[text_input], queue=False)
546
+ demo.load(
547
+ fn=load_url_params,
548
+ inputs=[url_params],
549
+ outputs=[load_params_button, url_params],
550
+ _js=get_window_url_params
551
+ )
552
+ demo.queue(concurrency_count=1)
553
+ demo.launch(share=False)
554
+
555
+
556
+ if __name__ == "__main__":
557
+ main()
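
The 'Rich-text JSON Input' that `generate` consumes above is a Quill Delta document, as shown in the footnote, font color, font style, and font size examples. The snippet below is a minimal sketch of how such a payload can be split into a plain base prompt plus the spans that carry rich-text attributes; `split_rich_text` is a hypothetical helper written for illustration only and is not the `parse_json` function imported from `utils.richtext_utils`, which is not part of this commit.

import json

def split_rich_text(delta_json: str):
    # Hypothetical helper for illustration; the app itself uses
    # utils.richtext_utils.parse_json, which is not shown in this diff.
    ops = json.loads(delta_json)["ops"]
    # Concatenating every "insert" op recovers the plain base prompt.
    base_prompt = "".join(op["insert"] for op in ops).strip()
    # Ops that carry an "attributes" dict (color, font, size, link) mark
    # the spans that receive special treatment during generation.
    attributed_spans = [(op["insert"].strip(), op["attributes"])
                        for op in ops if op.get("attributes")]
    return base_prompt, attributed_spans

example = '{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in the sunset."}]}'
prompt, spans = split_rich_text(example)
print(prompt)  # a Gothic church in the sunset.
print(spans)   # [('church', {'color': '#FD6C9E'})]
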
models/attention.py ADDED
@@ -0,0 +1,904 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ import warnings
16
+ from dataclasses import dataclass
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from torch import nn
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.models.modeling_utils import ModelMixin
25
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
26
+ from diffusers.utils import BaseOutput
27
+ from diffusers.utils.import_utils import is_xformers_available
28
+
29
+
30
+ @dataclass
31
+ class Transformer2DModelOutput(BaseOutput):
32
+ """
33
+ Args:
34
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
35
+ Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
36
+ for the unnoised latent pixels.
37
+ """
38
+
39
+ sample: torch.FloatTensor
40
+
41
+
42
+ if is_xformers_available():
43
+ import xformers
44
+ import xformers.ops
45
+ else:
46
+ xformers = None
47
+
48
+
49
+ class Transformer2DModel(ModelMixin, ConfigMixin):
50
+ """
51
+ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
52
+ embeddings) inputs.
53
+
54
+ When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
55
+ transformer action. Finally, reshape to image.
56
+
57
+ When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
58
+ embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
59
+ classes of unnoised image.
60
+
61
+ Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
62
+ image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
63
+
64
+ Parameters:
65
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
66
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
67
+ in_channels (`int`, *optional*):
68
+ Pass if the input is continuous. The number of channels in the input and output.
69
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
70
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
71
+ cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
72
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
73
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
74
+ `ImagePositionalEmbeddings`.
75
+ num_vector_embeds (`int`, *optional*):
76
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
77
+ Includes the class for the masked latent pixel.
78
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
79
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
80
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
81
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
82
+ up to, but not more than, `num_embeds_ada_norm` steps.
83
+ attention_bias (`bool`, *optional*):
84
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
85
+ """
86
+
87
+ @register_to_config
88
+ def __init__(
89
+ self,
90
+ num_attention_heads: int = 16,
91
+ attention_head_dim: int = 88,
92
+ in_channels: Optional[int] = None,
93
+ num_layers: int = 1,
94
+ dropout: float = 0.0,
95
+ norm_num_groups: int = 32,
96
+ cross_attention_dim: Optional[int] = None,
97
+ attention_bias: bool = False,
98
+ sample_size: Optional[int] = None,
99
+ num_vector_embeds: Optional[int] = None,
100
+ activation_fn: str = "geglu",
101
+ num_embeds_ada_norm: Optional[int] = None,
102
+ use_linear_projection: bool = False,
103
+ only_cross_attention: bool = False,
104
+ ):
105
+ super().__init__()
106
+ self.use_linear_projection = use_linear_projection
107
+ self.num_attention_heads = num_attention_heads
108
+ self.attention_head_dim = attention_head_dim
109
+ inner_dim = num_attention_heads * attention_head_dim
110
+
111
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
112
+ # Define whether input is continuous or discrete depending on configuration
113
+ self.is_input_continuous = in_channels is not None
114
+ self.is_input_vectorized = num_vector_embeds is not None
115
+
116
+ if self.is_input_continuous and self.is_input_vectorized:
117
+ raise ValueError(
118
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
119
+ " sure that either `in_channels` or `num_vector_embeds` is None."
120
+ )
121
+ elif not self.is_input_continuous and not self.is_input_vectorized:
122
+ raise ValueError(
123
+ f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
124
+ " sure that either `in_channels` or `num_vector_embeds` is not None."
125
+ )
126
+
127
+ # 2. Define input layers
128
+ if self.is_input_continuous:
129
+ self.in_channels = in_channels
130
+
131
+ self.norm = torch.nn.GroupNorm(
132
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
133
+ if use_linear_projection:
134
+ self.proj_in = nn.Linear(in_channels, inner_dim)
135
+ else:
136
+ self.proj_in = nn.Conv2d(
137
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
138
+ elif self.is_input_vectorized:
139
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
140
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
141
+
142
+ self.height = sample_size
143
+ self.width = sample_size
144
+ self.num_vector_embeds = num_vector_embeds
145
+ self.num_latent_pixels = self.height * self.width
146
+
147
+ self.latent_image_embedding = ImagePositionalEmbeddings(
148
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
149
+ )
150
+
151
+ # 3. Define transformers blocks
152
+ self.transformer_blocks = nn.ModuleList(
153
+ [
154
+ BasicTransformerBlock(
155
+ inner_dim,
156
+ num_attention_heads,
157
+ attention_head_dim,
158
+ dropout=dropout,
159
+ cross_attention_dim=cross_attention_dim,
160
+ activation_fn=activation_fn,
161
+ num_embeds_ada_norm=num_embeds_ada_norm,
162
+ attention_bias=attention_bias,
163
+ only_cross_attention=only_cross_attention,
164
+ )
165
+ for d in range(num_layers)
166
+ ]
167
+ )
168
+
169
+ # 4. Define output layers
170
+ if self.is_input_continuous:
171
+ if use_linear_projection:
172
+ self.proj_out = nn.Linear(in_channels, inner_dim)
173
+ else:
174
+ self.proj_out = nn.Conv2d(
175
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
176
+ elif self.is_input_vectorized:
177
+ self.norm_out = nn.LayerNorm(inner_dim)
178
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
179
+
180
+ def _set_attention_slice(self, slice_size):
181
+ for block in self.transformer_blocks:
182
+ block._set_attention_slice(slice_size)
183
+
184
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None,
185
+ text_format_dict={}, return_dict: bool = True):
186
+ """
187
+ Args:
188
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
189
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
190
+ hidden_states
191
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
192
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
193
+ self-attention.
194
+ timestep ( `torch.long`, *optional*):
195
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
196
+ return_dict (`bool`, *optional*, defaults to `True`):
197
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
198
+
199
+ Returns:
200
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
201
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
202
+ tensor.
203
+ """
204
+ # 1. Input
205
+ if self.is_input_continuous:
206
+ batch, channel, height, weight = hidden_states.shape
207
+ residual = hidden_states
208
+
209
+ hidden_states = self.norm(hidden_states)
210
+ if not self.use_linear_projection:
211
+ hidden_states = self.proj_in(hidden_states)
212
+ inner_dim = hidden_states.shape[1]
213
+ hidden_states = hidden_states.permute(
214
+ 0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
215
+ else:
216
+ inner_dim = hidden_states.shape[1]
217
+ hidden_states = hidden_states.permute(
218
+ 0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
219
+ hidden_states = self.proj_in(hidden_states)
220
+ elif self.is_input_vectorized:
221
+ hidden_states = self.latent_image_embedding(hidden_states)
222
+
223
+ # 2. Blocks
224
+ for block in self.transformer_blocks:
225
+ hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep,
226
+ text_format_dict=text_format_dict)
227
+
228
+ # 3. Output
229
+ if self.is_input_continuous:
230
+ if not self.use_linear_projection:
231
+ hidden_states = (
232
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(
233
+ 0, 3, 1, 2).contiguous()
234
+ )
235
+ hidden_states = self.proj_out(hidden_states)
236
+ else:
237
+ hidden_states = self.proj_out(hidden_states)
238
+ hidden_states = (
239
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(
240
+ 0, 3, 1, 2).contiguous()
241
+ )
242
+
243
+ output = hidden_states + residual
244
+ elif self.is_input_vectorized:
245
+ hidden_states = self.norm_out(hidden_states)
246
+ logits = self.out(hidden_states)
247
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
248
+ logits = logits.permute(0, 2, 1)
249
+
250
+ # log(p(x_0))
251
+ output = F.log_softmax(logits.double(), dim=1).float()
252
+
253
+ if not return_dict:
254
+ return (output,)
255
+
256
+ return Transformer2DModelOutput(sample=output)
257
+
258
+ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
259
+ for block in self.transformer_blocks:
260
+ block._set_use_memory_efficient_attention_xformers(
261
+ use_memory_efficient_attention_xformers)
262
+
263
+
264
+ class AttentionBlock(nn.Module):
265
+ """
266
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
267
+ to the N-d case.
268
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
269
+ Uses three q, k, v linear layers to compute attention.
270
+
271
+ Parameters:
272
+ channels (`int`): The number of channels in the input and output.
273
+ num_head_channels (`int`, *optional*):
274
+ The number of channels in each head. If None, then `num_heads` = 1.
275
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
276
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
277
+ eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
278
+ """
279
+
280
+ def __init__(
281
+ self,
282
+ channels: int,
283
+ num_head_channels: Optional[int] = None,
284
+ norm_num_groups: int = 32,
285
+ rescale_output_factor: float = 1.0,
286
+ eps: float = 1e-5,
287
+ ):
288
+ super().__init__()
289
+ self.channels = channels
290
+
291
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
292
+ self.num_head_size = num_head_channels
293
+ self.group_norm = nn.GroupNorm(
294
+ num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
295
+
296
+ # define q,k,v as linear layers
297
+ self.query = nn.Linear(channels, channels)
298
+ self.key = nn.Linear(channels, channels)
299
+ self.value = nn.Linear(channels, channels)
300
+
301
+ self.rescale_output_factor = rescale_output_factor
302
+ self.proj_attn = nn.Linear(channels, channels, 1)
303
+
304
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
305
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
306
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
307
+ new_projection = projection.view(
308
+ new_projection_shape).permute(0, 2, 1, 3)
309
+ return new_projection
310
+
311
+ def forward(self, hidden_states):
312
+ residual = hidden_states
313
+ batch, channel, height, width = hidden_states.shape
314
+
315
+ # norm
316
+ hidden_states = self.group_norm(hidden_states)
317
+
318
+ hidden_states = hidden_states.view(
319
+ batch, channel, height * width).transpose(1, 2)
320
+
321
+ # proj to q, k, v
322
+ query_proj = self.query(hidden_states)
323
+ key_proj = self.key(hidden_states)
324
+ value_proj = self.value(hidden_states)
325
+
326
+ scale = 1 / math.sqrt(self.channels / self.num_heads)
327
+
328
+ # get scores
329
+ if self.num_heads > 1:
330
+ query_states = self.transpose_for_scores(query_proj)
331
+ key_states = self.transpose_for_scores(key_proj)
332
+ value_states = self.transpose_for_scores(value_proj)
333
+
334
+ # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
335
+ # or reformulate this into a 3D problem?
336
+ # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
337
+ # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
338
+ # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
339
+ attention_scores = torch.matmul(
340
+ query_states, key_states.transpose(-1, -2)) * scale
341
+ else:
342
+ query_states, key_states, value_states = query_proj, key_proj, value_proj
343
+
344
+ attention_scores = torch.baddbmm(
345
+ torch.empty(
346
+ query_states.shape[0],
347
+ query_states.shape[1],
348
+ key_states.shape[1],
349
+ dtype=query_states.dtype,
350
+ device=query_states.device,
351
+ ),
352
+ query_states,
353
+ key_states.transpose(-1, -2),
354
+ beta=0,
355
+ alpha=scale,
356
+ )
357
+
358
+ attention_probs = torch.softmax(
359
+ attention_scores.float(), dim=-1).type(attention_scores.dtype)
360
+
361
+ # compute attention output
362
+ if self.num_heads > 1:
363
+ # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
364
+ # or reformulate this into a 3D problem?
365
+ # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
366
+ # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
367
+ # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
368
+ hidden_states = torch.matmul(attention_probs, value_states)
369
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
370
+ new_hidden_states_shape = hidden_states.size()[
371
+ :-2] + (self.channels,)
372
+ hidden_states = hidden_states.view(new_hidden_states_shape)
373
+ else:
374
+ hidden_states = torch.bmm(attention_probs, value_states)
375
+
376
+ # compute next hidden_states
377
+ hidden_states = self.proj_attn(hidden_states)
378
+ hidden_states = hidden_states.transpose(
379
+ -1, -2).reshape(batch, channel, height, width)
380
+
381
+ # res connect and rescale
382
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
383
+ return hidden_states
384
+
385
+
386
+ class BasicTransformerBlock(nn.Module):
387
+ r"""
388
+ A basic Transformer block.
389
+
390
+ Parameters:
391
+ dim (`int`): The number of channels in the input and output.
392
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
393
+ attention_head_dim (`int`): The number of channels in each head.
394
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
395
+ cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
396
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
397
+ num_embeds_ada_norm (:
398
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
399
+ attention_bias (:
400
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
401
+ """
402
+
403
+ def __init__(
404
+ self,
405
+ dim: int,
406
+ num_attention_heads: int,
407
+ attention_head_dim: int,
408
+ dropout=0.0,
409
+ cross_attention_dim: Optional[int] = None,
410
+ activation_fn: str = "geglu",
411
+ num_embeds_ada_norm: Optional[int] = None,
412
+ attention_bias: bool = False,
413
+ only_cross_attention: bool = False,
414
+ ):
415
+ super().__init__()
416
+ self.only_cross_attention = only_cross_attention
417
+ self.attn1 = CrossAttention(
418
+ query_dim=dim,
419
+ heads=num_attention_heads,
420
+ dim_head=attention_head_dim,
421
+ dropout=dropout,
422
+ bias=attention_bias,
423
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
424
+ ) # is a self-attention
425
+ self.ff = FeedForward(dim, dropout=dropout,
426
+ activation_fn=activation_fn)
427
+ self.attn2 = CrossAttention(
428
+ query_dim=dim,
429
+ cross_attention_dim=cross_attention_dim,
430
+ heads=num_attention_heads,
431
+ dim_head=attention_head_dim,
432
+ dropout=dropout,
433
+ bias=attention_bias,
434
+ ) # is self-attn if context is none
435
+
436
+ # layer norms
437
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
438
+ if self.use_ada_layer_norm:
439
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
440
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
441
+ else:
442
+ self.norm1 = nn.LayerNorm(dim)
443
+ self.norm2 = nn.LayerNorm(dim)
444
+ self.norm3 = nn.LayerNorm(dim)
445
+
446
+ # if xformers is installed try to use memory_efficient_attention by default
447
+ if is_xformers_available():
448
+ try:
449
+ self._set_use_memory_efficient_attention_xformers(True)
450
+ except Exception as e:
451
+ warnings.warn(
452
+ "Could not enable memory efficient attention. Make sure xformers is installed"
453
+ f" correctly and a GPU is available: {e}"
454
+ )
455
+
456
+ def _set_attention_slice(self, slice_size):
457
+ self.attn1._slice_size = slice_size
458
+ self.attn2._slice_size = slice_size
459
+
460
+ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
461
+ if not is_xformers_available():
462
+ print("Here is how to install it")
463
+ raise ModuleNotFoundError(
464
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
465
+ " xformers",
466
+ name="xformers",
467
+ )
468
+ elif not torch.cuda.is_available():
469
+ raise ValueError(
470
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
471
+ " available for GPU "
472
+ )
473
+ else:
474
+ try:
475
+ # Make sure we can run the memory efficient attention
476
+ _ = xformers.ops.memory_efficient_attention(
477
+ torch.randn((1, 2, 40), device="cuda"),
478
+ torch.randn((1, 2, 40), device="cuda"),
479
+ torch.randn((1, 2, 40), device="cuda"),
480
+ )
481
+ except Exception as e:
482
+ raise e
483
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
484
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
485
+
486
+ def forward(self, hidden_states, context=None, timestep=None, text_format_dict={}):
487
+ # 1. Self-Attention
488
+ norm_hidden_states = (
489
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(
490
+ hidden_states)
491
+ )
492
+
493
+ if self.only_cross_attention:
494
+ attn_out, _ = self.attn1(
495
+ norm_hidden_states, context=context, text_format_dict=text_format_dict)
496
+ hidden_states = attn_out + hidden_states
497
+ else:
498
+ attn_out, _ = self.attn1(norm_hidden_states)
499
+ hidden_states = attn_out + hidden_states
500
+
501
+ # 2. Cross-Attention
502
+ norm_hidden_states = (
503
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(
504
+ hidden_states)
505
+ )
506
+ attn_out, _ = self.attn2(
507
+ norm_hidden_states, context=context, text_format_dict=text_format_dict)
508
+ hidden_states = attn_out + hidden_states
509
+
510
+ # 3. Feed-forward
511
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
512
+
513
+ return hidden_states
514
+
515
+
516
+ class CrossAttention(nn.Module):
517
+ r"""
518
+ A cross attention layer.
519
+
520
+ Parameters:
521
+ query_dim (`int`): The number of channels in the query.
522
+ cross_attention_dim (`int`, *optional*):
523
+ The number of channels in the context. If not given, defaults to `query_dim`.
524
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
525
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
526
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
527
+ bias (`bool`, *optional*, defaults to False):
528
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
529
+ """
530
+
531
+ def __init__(
532
+ self,
533
+ query_dim: int,
534
+ cross_attention_dim: Optional[int] = None,
535
+ heads: int = 8,
536
+ dim_head: int = 64,
537
+ dropout: float = 0.0,
538
+ bias=False,
539
+ ):
540
+ super().__init__()
541
+ inner_dim = dim_head * heads
542
+ self.is_cross_attn = cross_attention_dim is not None
543
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
544
+
545
+ self.scale = dim_head**-0.5
546
+ self.heads = heads
547
+ # for slice_size > 0 the attention score computation
548
+ # is split across the batch axis to save memory
549
+ # You can set slice_size with `set_attention_slice`
550
+ self._slice_size = None
551
+ self._use_memory_efficient_attention_xformers = False
552
+
553
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
554
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
555
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
556
+
557
+ self.to_out = nn.ModuleList([])
558
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
559
+ self.to_out.append(nn.Dropout(dropout))
560
+
561
+ def reshape_heads_to_batch_dim(self, tensor):
562
+ batch_size, seq_len, dim = tensor.shape
563
+ head_size = self.heads
564
+ tensor = tensor.reshape(batch_size, seq_len,
565
+ head_size, dim // head_size)
566
+ tensor = tensor.permute(0, 2, 1, 3).reshape(
567
+ batch_size * head_size, seq_len, dim // head_size)
568
+ return tensor
569
+
570
+ def reshape_batch_dim_to_heads(self, tensor):
571
+ batch_size, seq_len, dim = tensor.shape
572
+ head_size = self.heads
573
+ tensor = tensor.reshape(batch_size // head_size,
574
+ head_size, seq_len, dim)
575
+ tensor = tensor.permute(0, 2, 1, 3).reshape(
576
+ batch_size // head_size, seq_len, dim * head_size)
577
+ return tensor
578
+
579
+ def reshape_batch_dim_to_heads_and_average(self, tensor):
580
+ batch_size, seq_len, seq_len2 = tensor.shape
581
+ head_size = self.heads
582
+ tensor = tensor.reshape(batch_size // head_size,
583
+ head_size, seq_len, seq_len2)
584
+ return tensor.mean(1)
585
+
586
+ def forward(self, hidden_states, real_attn_probs=None, context=None, mask=None, text_format_dict={}):
587
+ batch_size, sequence_length, _ = hidden_states.shape
588
+
589
+ query = self.to_q(hidden_states)
590
+ context = context if context is not None else hidden_states
591
+ key = self.to_k(context)
592
+ value = self.to_v(context)
593
+
594
+ dim = query.shape[-1]
595
+
596
+ query = self.reshape_heads_to_batch_dim(query)
597
+ key = self.reshape_heads_to_batch_dim(key)
598
+ value = self.reshape_heads_to_batch_dim(value)
599
+
600
+ # attention, what we cannot get enough of
601
+ if self._use_memory_efficient_attention_xformers:
602
+ hidden_states = self._memory_efficient_attention_xformers(
603
+ query, key, value)
604
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
605
+ hidden_states = hidden_states.to(query.dtype)
606
+ else:
607
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
608
+ # only this attention function is used
609
+ hidden_states, attn_probs = self._attention(
610
+ query, key, value, real_attn_probs, **text_format_dict)
611
+
612
+ # linear proj
613
+ hidden_states = self.to_out[0](hidden_states)
614
+ # dropout
615
+ hidden_states = self.to_out[1](hidden_states)
616
+ return hidden_states, attn_probs
617
+
618
+ def _qk(self, query, key):
619
+ return torch.baddbmm(
620
+ torch.empty(query.shape[0], query.shape[1], key.shape[1],
621
+ dtype=query.dtype, device=query.device),
622
+ query,
623
+ key.transpose(-1, -2),
624
+ beta=0,
625
+ alpha=self.scale,
626
+ )
627
+
628
+ def _attention(self, query, key, value, real_attn_probs=None, word_pos=None, font_size=None,
629
+ **kwargs):
630
+ attention_scores = self._qk(query, key)
631
+
632
+ # Font size V2:
633
+ if self.is_cross_attn and word_pos is not None and font_size is not None:
634
+ assert key.shape[1] == 77
635
+ attention_score_exp = attention_scores.exp()
636
+ font_size_abs, font_size_sign = font_size.abs(), font_size.sign()
637
+ attention_score_exp[:, :, word_pos] = attention_score_exp[:, :, word_pos].clone(
638
+ )*font_size_abs
639
+ attention_probs = attention_score_exp / \
640
+ attention_score_exp.sum(-1, True)
641
+ attention_probs[:, :, word_pos] *= font_size_sign
642
+ else:
643
+ attention_probs = attention_scores.softmax(dim=-1)
644
+
645
+ # compute attention output
646
+ if real_attn_probs is None:
647
+ hidden_states = torch.bmm(attention_probs, value)
648
+ else:
649
+ if isinstance(real_attn_probs, dict):
650
+ for pos1, pos2 in zip(real_attn_probs['inject_pos'][0], real_attn_probs['inject_pos'][1]):
651
+ attention_probs[:, :,
652
+ pos2] = real_attn_probs['reference'][:, :, pos1]
653
+ hidden_states = torch.bmm(attention_probs, value)
654
+ else:
655
+ hidden_states = torch.bmm(real_attn_probs, value)
656
+
657
+ # reshape hidden_states
658
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
659
+
660
+ # we also return the map averaged over heads to save memory footprint
661
+ attention_probs_avg = self.reshape_batch_dim_to_heads_and_average(
662
+ attention_probs)
663
+ return hidden_states, [attention_probs_avg, attention_probs]
664
+
665
+ def _memory_efficient_attention_xformers(self, query, key, value):
666
+ query = query.contiguous()
667
+ key = key.contiguous()
668
+ value = value.contiguous()
669
+ hidden_states = xformers.ops.memory_efficient_attention(
670
+ query, key, value, attn_bias=None)
671
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
672
+ return hidden_states
673
+
674
+
675
+ class FeedForward(nn.Module):
676
+ r"""
677
+ A feed-forward layer.
678
+
679
+ Parameters:
680
+ dim (`int`): The number of channels in the input.
681
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
682
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
683
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
684
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
685
+ """
686
+
687
+ def __init__(
688
+ self,
689
+ dim: int,
690
+ dim_out: Optional[int] = None,
691
+ mult: int = 4,
692
+ dropout: float = 0.0,
693
+ activation_fn: str = "geglu",
694
+ ):
695
+ super().__init__()
696
+ inner_dim = int(dim * mult)
697
+ dim_out = dim_out if dim_out is not None else dim
698
+
699
+ if activation_fn == "geglu":
700
+ geglu = GEGLU(dim, inner_dim)
701
+ elif activation_fn == "geglu-approximate":
702
+ geglu = ApproximateGELU(dim, inner_dim)
703
+
704
+ self.net = nn.ModuleList([])
705
+ # project in
706
+ self.net.append(geglu)
707
+ # project dropout
708
+ self.net.append(nn.Dropout(dropout))
709
+ # project out
710
+ self.net.append(nn.Linear(inner_dim, dim_out))
711
+
712
+ def forward(self, hidden_states):
713
+ for module in self.net:
714
+ hidden_states = module(hidden_states)
715
+ return hidden_states
716
+
717
+
718
+ # feedforward
719
+ class GEGLU(nn.Module):
720
+ r"""
721
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
722
+
723
+ Parameters:
724
+ dim_in (`int`): The number of channels in the input.
725
+ dim_out (`int`): The number of channels in the output.
726
+ """
727
+
728
+ def __init__(self, dim_in: int, dim_out: int):
729
+ super().__init__()
730
+ self.proj = nn.Linear(dim_in, dim_out * 2)
731
+
732
+ def gelu(self, gate):
733
+ if gate.device.type != "mps":
734
+ return F.gelu(gate)
735
+ # mps: gelu is not implemented for float16
736
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
737
+
738
+ def forward(self, hidden_states):
739
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
740
+ return hidden_states * self.gelu(gate)
741
+
742
+
743
+ class ApproximateGELU(nn.Module):
744
+ """
745
+ The approximate form of Gaussian Error Linear Unit (GELU)
746
+
747
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
748
+ """
749
+
750
+ def __init__(self, dim_in: int, dim_out: int):
751
+ super().__init__()
752
+ self.proj = nn.Linear(dim_in, dim_out)
753
+
754
+ def forward(self, x):
755
+ x = self.proj(x)
756
+ return x * torch.sigmoid(1.702 * x)
757
+
758
+
759
+ class AdaLayerNorm(nn.Module):
760
+ """
761
+ Norm layer modified to incorporate timestep embeddings.
762
+ """
763
+
764
+ def __init__(self, embedding_dim, num_embeddings):
765
+ super().__init__()
766
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
767
+ self.silu = nn.SiLU()
768
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
769
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
770
+
771
+ def forward(self, x, timestep):
772
+ emb = self.linear(self.silu(self.emb(timestep)))
773
+ scale, shift = torch.chunk(emb, 2)
774
+ x = self.norm(x) * (1 + scale) + shift
775
+ return x
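AdaLayerNorm predicts a per-timestep scale and shift from an embedding and applies them after a LayerNorm with no learned affine parameters. A minimal standalone sketch of the same conditioning pattern, assuming only PyTorch (sizes and the per-sample broadcasting are illustrative, not the class above verbatim):

import torch
import torch.nn as nn
import torch.nn.functional as F

embedding_dim, num_timesteps = 64, 1000
emb = nn.Embedding(num_timesteps, embedding_dim)
to_scale_shift = nn.Linear(embedding_dim, embedding_dim * 2)
norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

x = torch.randn(2, 10, embedding_dim)           # (batch, tokens, channels)
t = torch.tensor([10, 500])                     # one timestep per sample
scale, shift = to_scale_shift(F.silu(emb(t))).chunk(2, dim=-1)
out = norm(x) * (1 + scale[:, None]) + shift[:, None]
assert out.shape == x.shape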
776
+
777
+
778
+ class DualTransformer2DModel(nn.Module):
779
+ """
780
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
781
+
782
+ Parameters:
783
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
784
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
785
+ in_channels (`int`, *optional*):
786
+ Pass if the input is continuous. The number of channels in the input and output.
787
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
788
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
789
+ cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
790
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
791
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
792
+ `ImagePositionalEmbeddings`.
793
+ num_vector_embeds (`int`, *optional*):
794
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
795
+ Includes the class for the masked latent pixel.
796
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
797
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
798
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
799
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
800
+ up to, but not more than, `num_embeds_ada_norm` steps.
801
+ attention_bias (`bool`, *optional*):
802
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
803
+ """
804
+
805
+ def __init__(
806
+ self,
807
+ num_attention_heads: int = 16,
808
+ attention_head_dim: int = 88,
809
+ in_channels: Optional[int] = None,
810
+ num_layers: int = 1,
811
+ dropout: float = 0.0,
812
+ norm_num_groups: int = 32,
813
+ cross_attention_dim: Optional[int] = None,
814
+ attention_bias: bool = False,
815
+ sample_size: Optional[int] = None,
816
+ num_vector_embeds: Optional[int] = None,
817
+ activation_fn: str = "geglu",
818
+ num_embeds_ada_norm: Optional[int] = None,
819
+ ):
820
+ super().__init__()
821
+ self.transformers = nn.ModuleList(
822
+ [
823
+ Transformer2DModel(
824
+ num_attention_heads=num_attention_heads,
825
+ attention_head_dim=attention_head_dim,
826
+ in_channels=in_channels,
827
+ num_layers=num_layers,
828
+ dropout=dropout,
829
+ norm_num_groups=norm_num_groups,
830
+ cross_attention_dim=cross_attention_dim,
831
+ attention_bias=attention_bias,
832
+ sample_size=sample_size,
833
+ num_vector_embeds=num_vector_embeds,
834
+ activation_fn=activation_fn,
835
+ num_embeds_ada_norm=num_embeds_ada_norm,
836
+ )
837
+ for _ in range(2)
838
+ ]
839
+ )
840
+
841
+ # Variables that can be set by a pipeline:
842
+
843
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
844
+ self.mix_ratio = 0.5
845
+
846
+ # The shape of `encoder_hidden_states` is expected to be
847
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
848
+ self.condition_lengths = [77, 257]
849
+
850
+ # Which transformer to use to encode which condition.
851
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
852
+ self.transformer_index_for_condition = [1, 0]
853
+
854
+ def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
855
+ """
856
+ Args:
857
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
858
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
859
+ hidden_states
860
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
861
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
862
+ self-attention.
863
+ timestep ( `torch.long`, *optional*):
864
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
865
+ return_dict (`bool`, *optional*, defaults to `True`):
866
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
867
+
868
+ Returns:
869
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
870
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
871
+ tensor.
872
+ """
873
+ input_states = hidden_states
874
+
875
+ encoded_states = []
876
+ tokens_start = 0
877
+ for i in range(2):
878
+ # for each of the two transformers, pass the corresponding condition tokens
879
+ condition_state = encoder_hidden_states[:,
880
+ tokens_start: tokens_start + self.condition_lengths[i]]
881
+ transformer_index = self.transformer_index_for_condition[i]
882
+ encoded_state = self.transformers[transformer_index](input_states, condition_state, timestep, return_dict)[
883
+ 0
884
+ ]
885
+ encoded_states.append(encoded_state - input_states)
886
+ tokens_start += self.condition_lengths[i]
887
+
888
+ output_states = encoded_states[0] * self.mix_ratio + \
889
+ encoded_states[1] * (1 - self.mix_ratio)
890
+ output_states = output_states + input_states
891
+
892
+ if not return_dict:
893
+ return (output_states,)
894
+
895
+ return Transformer2DModelOutput(sample=output_states)
896
+
897
+ def _set_attention_slice(self, slice_size):
898
+ for transformer in self.transformers:
899
+ transformer._set_attention_slice(slice_size)
900
+
901
+ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
902
+ for transformer in self.transformers:
903
+ transformer._set_use_memory_efficient_attention_xformers(
904
+ use_memory_efficient_attention_xformers)
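DualTransformer2DModel splits the concatenated conditioning along the token axis (77 text tokens plus 257 image-embedding tokens by default), runs each slice through its own transformer, and blends the two residuals with `mix_ratio`. A toy sketch of that blending logic; `blend_dual` and the stand-in branches are hypothetical helpers so the snippet runs without diffusers:

import torch

def blend_dual(hidden, cond, transformers, condition_lengths=(77, 257),
               transformer_index_for_condition=(1, 0), mix_ratio=0.5):
    """Mixes the residuals of two per-condition branches, mirroring DualTransformer2DModel.forward."""
    encoded, start = [], 0
    for i, length in enumerate(condition_lengths):
        branch = transformers[transformer_index_for_condition[i]]
        out = branch(hidden, cond[:, start:start + length])
        encoded.append(out - hidden)            # keep only the residual of each branch
        start += length
    mixed = encoded[0] * mix_ratio + encoded[1] * (1 - mix_ratio)
    return mixed + hidden

# stand-in branches: any callable (hidden, condition) -> hidden-shaped tensor
fake = [lambda h, c: h + c.mean(dim=1, keepdim=True), lambda h, c: h * 0.9]
hidden = torch.randn(1, 16, 8)
cond = torch.randn(1, 77 + 257, 8)
print(blend_dual(hidden, cond, fake).shape)     # torch.Size([1, 16, 8])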
models/region_diffusion.py ADDED
@@ -0,0 +1,461 @@
1
+ import os
2
+ import torch
3
+ import collections
4
+ import torch.nn as nn
5
+ from functools import partial
6
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
7
+ from diffusers import AutoencoderKL, PNDMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
8
+ from models.unet_2d_condition import UNet2DConditionModel
9
+ from utils.attention_utils import CrossAttentionLayers, SelfAttentionLayers
10
+
11
+ # suppress partial model loading warning
12
+ logging.set_verbosity_error()
13
+
14
+
15
+ class RegionDiffusion(nn.Module):
16
+ def __init__(self, device):
17
+ super().__init__()
18
+
19
+ self.device = device
20
+ self.num_train_timesteps = 1000
21
+ self.clip_gradient = False
22
+
23
+ print(f'[INFO] loading stable diffusion...')
24
+ model_id = 'runwayml/stable-diffusion-v1-5'
25
+
26
+ self.vae = AutoencoderKL.from_pretrained(
27
+ model_id, subfolder="vae").to(self.device)
28
+ self.tokenizer = CLIPTokenizer.from_pretrained(
29
+ model_id, subfolder='tokenizer')
30
+ self.text_encoder = CLIPTextModel.from_pretrained(
31
+ model_id, subfolder='text_encoder').to(self.device)
32
+ self.unet = UNet2DConditionModel.from_pretrained(
33
+ model_id, subfolder="unet").to(self.device)
34
+
35
+ self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
36
+ num_train_timesteps=self.num_train_timesteps, skip_prk_steps=True, steps_offset=1)
37
+ self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
38
+
39
+ self.masks = []
40
+ self.attention_maps = None
41
+ self.selfattn_maps = None
42
+ self.crossattn_maps = None
43
+ self.color_loss = torch.nn.functional.mse_loss
44
+ self.forward_hooks = []
45
+ self.forward_replacement_hooks = []
46
+
47
+ print(f'[INFO] loaded stable diffusion!')
48
+
49
+ def get_text_embeds(self, prompt, negative_prompt):
50
+ # prompt, negative_prompt: [str]
51
+
52
+ # Tokenize text and get embeddings
53
+ text_input = self.tokenizer(
54
+ prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
55
+
56
+ with torch.no_grad():
57
+ text_embeddings = self.text_encoder(
58
+ text_input.input_ids.to(self.device))[0]
59
+
60
+ # Do the same for unconditional embeddings
61
+ uncond_input = self.tokenizer(negative_prompt, padding='max_length',
62
+ max_length=self.tokenizer.model_max_length, return_tensors='pt')
63
+
64
+ with torch.no_grad():
65
+ uncond_embeddings = self.text_encoder(
66
+ uncond_input.input_ids.to(self.device))[0]
67
+
68
+ # Cat for final embeddings
69
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
70
+ return text_embeddings
71
+
72
+ def get_text_embeds_list(self, prompts):
73
+ # prompts: [list]
74
+ text_embeddings = []
75
+ for prompt in prompts:
76
+ # Tokenize text and get embeddings
77
+ text_input = self.tokenizer(
78
+ [prompt], padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
79
+
80
+ with torch.no_grad():
81
+ text_embeddings.append(self.text_encoder(
82
+ text_input.input_ids.to(self.device))[0])
83
+
84
+ return text_embeddings
85
+
86
+ def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5,
87
+ latents=None, use_guidance=False, text_format_dict={}, inject_selfattn=0, inject_background=0):
88
+
89
+ if latents is None:
90
+ latents = torch.randn(
91
+ (1, self.unet.in_channels, height // 8, width // 8), device=self.device)
92
+
93
+ if inject_selfattn > 0 or inject_background > 0:
94
+ latents_reference = latents.clone().detach()
95
+ self.scheduler.set_timesteps(num_inference_steps)
96
+ n_styles = text_embeddings.shape[0]-1
97
+ assert n_styles == len(self.masks)
98
+ with torch.autocast('cuda'):
99
+ for i, t in enumerate(self.scheduler.timesteps):
100
+
101
+ # predict the noise residual
102
+ with torch.no_grad():
103
+ # tokens without any attributes
104
+ feat_inject_step = t > (1-inject_selfattn) * 1000
105
+ background_inject_step = i == int(inject_background * len(self.scheduler.timesteps)) and inject_background > 0
106
+ noise_pred_uncond_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[:1],
107
+ text_format_dict={})['sample']
108
+ noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[-1:],
109
+ text_format_dict=text_format_dict)['sample']
110
+ if inject_selfattn > 0 or inject_background > 0:
111
+ noise_pred_uncond_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[:1],
112
+ text_format_dict={})['sample']
113
+ self.register_selfattn_hooks(feat_inject_step)
114
+ noise_pred_text_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[-1:],
115
+ text_format_dict={})['sample']
116
+ self.remove_selfattn_hooks()
117
+ noise_pred_uncond = noise_pred_uncond_cur * self.masks[-1]
118
+ noise_pred_text = noise_pred_text_cur * self.masks[-1]
119
+ # tokens with attributes
120
+ for style_i, mask in enumerate(self.masks[:-1]):
121
+ self.register_replacement_hooks(feat_inject_step)
122
+ noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[style_i+1:style_i+2],
123
+ text_format_dict={})['sample']
124
+ self.remove_replacement_hooks()
125
+ noise_pred_uncond = noise_pred_uncond + noise_pred_uncond_cur*mask
126
+ noise_pred_text = noise_pred_text + noise_pred_text_cur*mask
127
+
128
+ # perform classifier-free guidance
129
+ noise_pred = noise_pred_uncond + guidance_scale * \
130
+ (noise_pred_text - noise_pred_uncond)
131
+
132
+ if inject_selfattn > 0 or inject_background > 0:
133
+ noise_pred_refer = noise_pred_uncond_refer + guidance_scale * \
134
+ (noise_pred_text_refer - noise_pred_uncond_refer)
135
+
136
+ # compute the previous noisy sample x_t -> x_t-1
137
+ latents_reference = self.scheduler.step(torch.cat([noise_pred, noise_pred_refer]), t,
138
+ torch.cat([latents, latents_reference]))[
139
+ 'prev_sample']
140
+ latents, latents_reference = torch.chunk(
141
+ latents_reference, 2, dim=0)
142
+
143
+ else:
144
+ # compute the previous noisy sample x_t -> x_t-1
145
+ latents = self.scheduler.step(noise_pred, t, latents)[
146
+ 'prev_sample']
147
+
148
+ # apply guidance
149
+ if use_guidance and t < text_format_dict['guidance_start_step']:
150
+ with torch.enable_grad():
151
+ if not latents.requires_grad:
152
+ latents.requires_grad = True
153
+ latents_0 = self.predict_x0(latents, noise_pred, t)
154
+ latents_inp = 1 / 0.18215 * latents_0
155
+ imgs = self.vae.decode(latents_inp).sample
156
+ imgs = (imgs / 2 + 0.5).clamp(0, 1)
157
+ loss_total = 0.
158
+ for attn_map, rgb_val in zip(text_format_dict['color_obj_atten'], text_format_dict['target_RGB']):
159
+ avg_rgb = (
160
+ imgs*attn_map[:, 0]).sum(2).sum(2)/attn_map[:, 0].sum()
161
+ loss = self.color_loss(
162
+ avg_rgb, rgb_val[:, :, 0, 0])*100
163
+ loss_total += loss
164
+ loss_total.backward()
165
+ latents = (
166
+ latents - latents.grad * text_format_dict['color_guidance_weight'] * text_format_dict['color_obj_atten_all']).detach().clone()
167
+
168
+ # apply background injection
169
+ if background_inject_step:
170
+ latents = latents_reference * self.masks[-1] + latents * \
171
+ (1-self.masks[-1])
172
+ return latents
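produce_latents runs one unconditional and one fully attributed text pass, then re-runs the U-Net once per attribute region and composites the noise predictions with the stored spatial masks before applying classifier-free guidance. A stripped-down sketch of the masked compositing and guidance step; the mask layout and `region_cfg` helper are illustrative, and the real method also composites the unconditional branch per region and handles self-attention/background injection:

import torch

def region_cfg(noise_uncond, noise_text_regions, masks, guidance_scale=7.5):
    """Composites per-region noise predictions with spatial masks, then applies classifier-free guidance."""
    # masks should sum to one at every pixel so the composite covers the whole latent
    noise_text = sum(n * m for n, m in zip(noise_text_regions, masks))
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)

shape = (1, 4, 64, 64)                          # latent shape for a 512x512 image
left = torch.zeros(shape); left[..., :32] = 1.0
masks = [left, 1.0 - left]                      # two complementary regions
noise_uncond = torch.randn(shape)
noise_text_regions = [torch.randn(shape), torch.randn(shape)]
print(region_cfg(noise_uncond, noise_text_regions, masks).shape)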
173
+
174
+ def predict_x0(self, x_t, eps_t, t):
175
+ alpha_t = self.scheduler.alphas_cumprod[t]
176
+ return (x_t - eps_t * torch.sqrt(1-alpha_t)) / torch.sqrt(alpha_t)
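predict_x0 inverts the forward diffusion relation x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A quick numeric check, assuming only PyTorch, that the inversion recovers x_0 exactly when the true noise is known (the alpha_bar value is arbitrary):

import torch

alpha_bar = torch.tensor(0.7)                   # illustrative cumulative alpha
x0 = torch.randn(1, 4, 8, 8)
eps = torch.randn_like(x0)
x_t = torch.sqrt(alpha_bar) * x0 + torch.sqrt(1 - alpha_bar) * eps   # forward process
x0_hat = (x_t - eps * torch.sqrt(1 - alpha_bar)) / torch.sqrt(alpha_bar)
assert torch.allclose(x0_hat, x0, atol=1e-6)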
177
+
178
+ def produce_attn_maps(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
179
+ guidance_scale=7.5, latents=None):
180
+
181
+ if isinstance(prompts, str):
182
+ prompts = [prompts]
183
+
184
+ if isinstance(negative_prompts, str):
185
+ negative_prompts = [negative_prompts]
186
+
187
+ # Prompts -> text embeds
188
+ text_embeddings = self.get_text_embeds(
189
+ prompts, negative_prompts) # [2, 77, 768]
190
+ if latents is None:
191
+ latents = torch.randn(
192
+ (text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)
193
+
194
+ self.scheduler.set_timesteps(num_inference_steps)
195
+ self.remove_replacement_hooks()
196
+
197
+ with torch.autocast('cuda'):
198
+ for i, t in enumerate(self.scheduler.timesteps):
199
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
200
+ latent_model_input = torch.cat([latents] * 2)
201
+
202
+ # predict the noise residual
203
+ with torch.no_grad():
204
+ noise_pred = self.unet(
205
+ latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
206
+
207
+ # perform guidance
208
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
209
+ noise_pred = noise_pred_uncond + guidance_scale * \
210
+ (noise_pred_text - noise_pred_uncond)
211
+
212
+ # compute the previous noisy sample x_t -> x_t-1
213
+ latents = self.scheduler.step(noise_pred, t, latents)[
214
+ 'prev_sample']
215
+
216
+ # Img latents -> imgs
217
+ imgs = self.decode_latents(latents) # [1, 3, 512, 512]
218
+
219
+ # Img to Numpy
220
+ imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
221
+ imgs = (imgs * 255).round().astype('uint8')
222
+
223
+ return imgs
224
+
225
+ def decode_latents(self, latents):
226
+
227
+ latents = 1 / 0.18215 * latents
228
+
229
+ with torch.no_grad():
230
+ imgs = self.vae.decode(latents).sample
231
+
232
+ imgs = (imgs / 2 + 0.5).clamp(0, 1)
233
+
234
+ return imgs
235
+
236
+ def encode_imgs(self, imgs):
237
+ # imgs: [B, 3, H, W]
238
+
239
+ imgs = 2 * imgs - 1
240
+
241
+ posterior = self.vae.encode(imgs).latent_dist
242
+ latents = posterior.sample() * 0.18215
243
+
244
+ return latents
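encode_imgs and decode_latents are inverses up to the VAE's reconstruction error: images are mapped to [-1, 1], encoded, and scaled by the Stable Diffusion v1 latent factor 0.18215; decoding divides the factor back out and maps to [0, 1]. A sketch of the scaling bookkeeping alone, with the VAE replaced by identity stand-ins so it runs anywhere:

import torch

SCALE = 0.18215                                 # Stable Diffusion v1 latent scaling factor

def encode(imgs, vae_encode=lambda x: x):       # vae_encode stands in for vae.encode(...).latent_dist.sample()
    return vae_encode(2 * imgs - 1) * SCALE

def decode(latents, vae_decode=lambda x: x):    # vae_decode stands in for vae.decode(...).sample
    imgs = vae_decode(latents / SCALE)
    return (imgs / 2 + 0.5).clamp(0, 1)

imgs = torch.rand(1, 3, 512, 512)               # pixels in [0, 1]
assert torch.allclose(decode(encode(imgs)), imgs, atol=1e-6)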
245
+
246
+ def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
247
+ guidance_scale=7.5, latents=None, text_format_dict={}, use_guidance=False, inject_selfattn=0, inject_background=0):
248
+
249
+ if isinstance(prompts, str):
250
+ prompts = [prompts]
251
+
252
+ if isinstance(negative_prompts, str):
253
+ negative_prompts = [negative_prompts]
254
+
255
+ # Prompts -> text embeds
256
+ text_embeds = self.get_text_embeds(
257
+ prompts, negative_prompts) # [2, 77, 768]
258
+
259
+ # else:
260
+ latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents,
261
+ num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
262
+ use_guidance=use_guidance, text_format_dict=text_format_dict,
263
+ inject_selfattn=inject_selfattn, inject_background=inject_background) # [1, 4, 64, 64]
264
+ # Img latents -> imgs
265
+ imgs = self.decode_latents(latents) # [1, 3, 512, 512]
266
+
267
+ # Img to Numpy
268
+ imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
269
+ imgs = (imgs * 255).round().astype('uint8')
270
+
271
+ return imgs
272
+
273
+ def reset_attention_maps(self):
274
+ r"""Function to reset attention maps.
275
+ We reset attention maps because we append them while getting hooks
276
+ to visualize attention maps for every step.
277
+ """
278
+ for key in self.selfattn_maps:
279
+ self.selfattn_maps[key] = []
280
+ for key in self.crossattn_maps:
281
+ self.crossattn_maps[key] = []
282
+
283
+ def register_evaluation_hooks(self):
284
+ r"""Function for registering hooks during evaluation.
285
+ We mainly store activation maps averaged over queries.
286
+ """
287
+ self.forward_hooks = []
288
+
289
+ def save_activations(activations, name, module, inp, out):
290
+ r"""
291
+ PyTorch Forward hook to save outputs at each forward pass.
292
+ """
293
+ # out[0] - final output of attention layer
294
+ # out[1] - attention probability matrix
295
+ if 'attn2' in name:
296
+ assert out[1].shape[-1] == 77
297
+ activations[name].append(out[1].detach().cpu())
298
+ else:
299
+ assert out[1].shape[-1] != 77
300
+ attention_dict = collections.defaultdict(list)
301
+ for name, module in self.unet.named_modules():
302
+ leaf_name = name.split('.')[-1]
303
+ if 'attn' in leaf_name:
304
+ # Register hook to obtain outputs at every attention layer.
305
+ self.forward_hooks.append(module.register_forward_hook(
306
+ partial(save_activations, attention_dict, name)
307
+ ))
308
+ # attention_dict is a dictionary containing attention maps for every attention layer
309
+ self.attention_maps = attention_dict
310
+
311
+ def register_selfattn_hooks(self, feat_inject_step=False):
312
+ r"""Function for registering hooks during evaluation.
313
+ We store per-layer self-attention maps (and one residual feature) to be injected into later forward passes.
314
+ """
315
+ self.selfattn_forward_hooks = []
316
+
317
+ def save_activations(activations, name, module, inp, out):
318
+ r"""
319
+ PyTorch Forward hook to save outputs at each forward pass.
320
+ """
321
+ # out[0] - final output of attention layer
322
+ # out[1] - attention probability matrix
323
+ if 'attn2' in name:
324
+ assert out[1][1].shape[-1] == 77
325
+ # cross attention injection
326
+ # activations[name] = out[1][1].detach()
327
+ else:
328
+ assert out[1][1].shape[-1] != 77
329
+ activations[name] = out[1][1].detach()
330
+
331
+ def save_resnet_activations(activations, name, module, inp, out):
332
+ r"""
333
+ PyTorch Forward hook to save outputs at each forward pass.
334
+ """
335
+ # out[0] - final output of residual layer
336
+ # out[1] - residual hidden feature
337
+ assert out[1].shape[-1] == 16
338
+ activations[name] = out[1].detach()
339
+ attention_dict = collections.defaultdict(list)
340
+ for name, module in self.unet.named_modules():
341
+ leaf_name = name.split('.')[-1]
342
+ if 'attn' in leaf_name and feat_inject_step:
343
+ # Register hook to obtain outputs at every attention layer.
344
+ self.selfattn_forward_hooks.append(module.register_forward_hook(
345
+ partial(save_activations, attention_dict, name)
346
+ ))
347
+ if name == 'up_blocks.1.resnets.1' and feat_inject_step:
348
+ self.selfattn_forward_hooks.append(module.register_forward_hook(
349
+ partial(save_resnet_activations, attention_dict, name)
350
+ ))
351
+ # attention_dict is a dictionary containing attention maps for every attention layer
352
+ self.self_attention_maps_cur = attention_dict
353
+
354
+ def register_replacement_hooks(self, feat_inject_step=False):
355
+ r"""Function for registering hooks to replace self attention.
356
+ """
357
+ self.forward_replacement_hooks = []
358
+
359
+ def replace_activations(name, module, args):
360
+ r"""
361
+ PyTorch forward pre-hook that passes the saved reference self-attention maps into the layer inputs for injection.
362
+ """
363
+ if 'attn1' in name:
364
+ modified_args = (args[0], self.self_attention_maps_cur[name])
365
+ return modified_args
366
+ # cross attention injection
367
+ # elif 'attn2' in name:
368
+ # modified_map = {
369
+ # 'reference': self.self_attention_maps_cur[name],
370
+ # 'inject_pos': self.inject_pos,
371
+ # }
372
+ # modified_args = (args[0], modified_map)
373
+ # return modified_args
374
+
375
+ def replace_resnet_activations(name, module, args):
376
+ r"""
377
+ PyTorch forward pre-hook that passes the saved reference residual features into the ResNet block inputs for injection.
378
+ """
379
+ modified_args = (args[0], args[1],
380
+ self.self_attention_maps_cur[name])
381
+ return modified_args
382
+ for name, module in self.unet.named_modules():
383
+ leaf_name = name.split('.')[-1]
384
+ if 'attn' in leaf_name and feat_inject_step:
385
+ # Register hook to obtain outputs at every attention layer.
386
+ self.forward_replacement_hooks.append(module.register_forward_pre_hook(
387
+ partial(replace_activations, name)
388
+ ))
389
+ if name == 'up_blocks.1.resnets.1' and feat_inject_step:
390
+ # Register hook to obtain outputs at every attention layer.
391
+ self.forward_replacement_hooks.append(module.register_forward_pre_hook(
392
+ partial(replace_resnet_activations, name)
393
+ ))
394
+
395
+ def register_tokenmap_hooks(self):
396
+ r"""Function for registering hooks during evaluation.
397
+ We mainly store activation maps averaged over queries.
398
+ """
399
+ self.forward_hooks = []
400
+
401
+ def save_activations(selfattn_maps, crossattn_maps, n_maps, name, module, inp, out):
402
+ r"""
403
+ PyTorch Forward hook to save outputs at each forward pass.
404
+ """
405
+ # out[0] - final output of attention layer
406
+ # out[1] - attention probability matrices
407
+ if name in n_maps:
408
+ n_maps[name] += 1
409
+ else:
410
+ n_maps[name] = 1
411
+ if 'attn2' in name:
412
+ assert out[1][0].shape[-1] == 77
413
+ if name in CrossAttentionLayers and n_maps[name] > 10:
414
+ if name in crossattn_maps:
415
+ crossattn_maps[name] += out[1][0].detach().cpu()[1:2]
416
+ else:
417
+ crossattn_maps[name] = out[1][0].detach().cpu()[1:2]
418
+ else:
419
+ assert out[1][0].shape[-1] != 77
420
+ if name in SelfAttentionLayers and n_maps[name] > 10:
421
+ if name in selfattn_maps:
422
+ selfattn_maps[name] += out[1][0].detach().cpu()[1:2]
423
+ else:
424
+ selfattn_maps[name] = out[1][0].detach().cpu()[1:2]
425
+
426
+ selfattn_maps = collections.defaultdict(list)
427
+ crossattn_maps = collections.defaultdict(list)
428
+ n_maps = collections.defaultdict(list)
429
+
430
+ for name, module in self.unet.named_modules():
431
+ leaf_name = name.split('.')[-1]
432
+ if 'attn' in leaf_name:
433
+ # Register hook to obtain outputs at every attention layer.
434
+ self.forward_hooks.append(module.register_forward_hook(
435
+ partial(save_activations, selfattn_maps,
436
+ crossattn_maps, n_maps, name)
437
+ ))
438
+ # attention_dict is a dictionary containing attention maps for every attention layer
439
+ self.selfattn_maps = selfattn_maps
440
+ self.crossattn_maps = crossattn_maps
441
+ self.n_maps = n_maps
442
+
443
+ def remove_tokenmap_hooks(self):
444
+ for hook in self.forward_hooks:
445
+ hook.remove()
446
+ self.selfattn_maps = None
447
+ self.crossattn_maps = None
448
+ self.n_maps = None
449
+
450
+ def remove_evaluation_hooks(self):
451
+ for hook in self.forward_hooks:
452
+ hook.remove()
453
+ self.attention_maps = None
454
+
455
+ def remove_replacement_hooks(self):
456
+ for hook in self.forward_replacement_hooks:
457
+ hook.remove()
458
+
459
+ def remove_selfattn_hooks(self):
460
+ for hook in self.selfattn_forward_hooks:
461
+ hook.remove()
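All of the register_*_hooks methods above follow the same PyTorch pattern: a register_forward_hook captures layer outputs (the attention probability maps), while a register_forward_pre_hook rewrites layer inputs so saved reference activations can be injected on a later pass. A minimal self-contained illustration of both hook types on a plain linear layer (the layer and names are hypothetical):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
captured = {}

def record_output(module, inputs, output):
    captured['out'] = output.detach()           # forward hook: record outputs; returning None keeps them unchanged

def zero_inputs(module, inputs):
    return (torch.zeros_like(inputs[0]),)       # forward pre-hook: replace the layer's input before it runs

h1 = layer.register_forward_hook(record_output)
h2 = layer.register_forward_pre_hook(zero_inputs)

y = layer(torch.randn(2, 4))
assert torch.allclose(y, layer.bias.expand(2, 4))   # input was zeroed by the pre-hook, so only the bias remains
assert 'out' in captured
h1.remove(); h2.remove()                        # always remove hooks when done, as the methods above do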
models/unet_2d_blocks.py ADDED
@@ -0,0 +1,1855 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+ import torch
16
+ from torch import nn
17
+
18
+ from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
19
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, Upsample2D
20
+
21
+
22
+ def get_down_block(
23
+ down_block_type,
24
+ num_layers,
25
+ in_channels,
26
+ out_channels,
27
+ temb_channels,
28
+ add_downsample,
29
+ resnet_eps,
30
+ resnet_act_fn,
31
+ attn_num_head_channels,
32
+ resnet_groups=None,
33
+ cross_attention_dim=None,
34
+ downsample_padding=None,
35
+ dual_cross_attention=False,
36
+ use_linear_projection=False,
37
+ only_cross_attention=False,
38
+ ):
39
+ down_block_type = down_block_type[7:] if down_block_type.startswith(
40
+ "UNetRes") else down_block_type
41
+ if down_block_type == "DownBlock2D":
42
+ return DownBlock2D(
43
+ num_layers=num_layers,
44
+ in_channels=in_channels,
45
+ out_channels=out_channels,
46
+ temb_channels=temb_channels,
47
+ add_downsample=add_downsample,
48
+ resnet_eps=resnet_eps,
49
+ resnet_act_fn=resnet_act_fn,
50
+ resnet_groups=resnet_groups,
51
+ downsample_padding=downsample_padding,
52
+ )
53
+ elif down_block_type == "AttnDownBlock2D":
54
+ return AttnDownBlock2D(
55
+ num_layers=num_layers,
56
+ in_channels=in_channels,
57
+ out_channels=out_channels,
58
+ temb_channels=temb_channels,
59
+ add_downsample=add_downsample,
60
+ resnet_eps=resnet_eps,
61
+ resnet_act_fn=resnet_act_fn,
62
+ resnet_groups=resnet_groups,
63
+ downsample_padding=downsample_padding,
64
+ attn_num_head_channels=attn_num_head_channels,
65
+ )
66
+ elif down_block_type == "CrossAttnDownBlock2D":
67
+ if cross_attention_dim is None:
68
+ raise ValueError(
69
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D")
70
+ return CrossAttnDownBlock2D(
71
+ num_layers=num_layers,
72
+ in_channels=in_channels,
73
+ out_channels=out_channels,
74
+ temb_channels=temb_channels,
75
+ add_downsample=add_downsample,
76
+ resnet_eps=resnet_eps,
77
+ resnet_act_fn=resnet_act_fn,
78
+ resnet_groups=resnet_groups,
79
+ downsample_padding=downsample_padding,
80
+ cross_attention_dim=cross_attention_dim,
81
+ attn_num_head_channels=attn_num_head_channels,
82
+ dual_cross_attention=dual_cross_attention,
83
+ use_linear_projection=use_linear_projection,
84
+ only_cross_attention=only_cross_attention,
85
+ )
86
+ elif down_block_type == "SkipDownBlock2D":
87
+ return SkipDownBlock2D(
88
+ num_layers=num_layers,
89
+ in_channels=in_channels,
90
+ out_channels=out_channels,
91
+ temb_channels=temb_channels,
92
+ add_downsample=add_downsample,
93
+ resnet_eps=resnet_eps,
94
+ resnet_act_fn=resnet_act_fn,
95
+ downsample_padding=downsample_padding,
96
+ )
97
+ elif down_block_type == "AttnSkipDownBlock2D":
98
+ return AttnSkipDownBlock2D(
99
+ num_layers=num_layers,
100
+ in_channels=in_channels,
101
+ out_channels=out_channels,
102
+ temb_channels=temb_channels,
103
+ add_downsample=add_downsample,
104
+ resnet_eps=resnet_eps,
105
+ resnet_act_fn=resnet_act_fn,
106
+ downsample_padding=downsample_padding,
107
+ attn_num_head_channels=attn_num_head_channels,
108
+ )
109
+ elif down_block_type == "DownEncoderBlock2D":
110
+ return DownEncoderBlock2D(
111
+ num_layers=num_layers,
112
+ in_channels=in_channels,
113
+ out_channels=out_channels,
114
+ add_downsample=add_downsample,
115
+ resnet_eps=resnet_eps,
116
+ resnet_act_fn=resnet_act_fn,
117
+ resnet_groups=resnet_groups,
118
+ downsample_padding=downsample_padding,
119
+ )
120
+ elif down_block_type == "AttnDownEncoderBlock2D":
121
+ return AttnDownEncoderBlock2D(
122
+ num_layers=num_layers,
123
+ in_channels=in_channels,
124
+ out_channels=out_channels,
125
+ add_downsample=add_downsample,
126
+ resnet_eps=resnet_eps,
127
+ resnet_act_fn=resnet_act_fn,
128
+ resnet_groups=resnet_groups,
129
+ downsample_padding=downsample_padding,
130
+ attn_num_head_channels=attn_num_head_channels,
131
+ )
132
+ raise ValueError(f"{down_block_type} does not exist.")
133
+
134
+
135
+ def get_up_block(
136
+ up_block_type,
137
+ num_layers,
138
+ in_channels,
139
+ out_channels,
140
+ prev_output_channel,
141
+ temb_channels,
142
+ add_upsample,
143
+ resnet_eps,
144
+ resnet_act_fn,
145
+ attn_num_head_channels,
146
+ resnet_groups=None,
147
+ cross_attention_dim=None,
148
+ dual_cross_attention=False,
149
+ use_linear_projection=False,
150
+ only_cross_attention=False,
151
+ ):
152
+ up_block_type = up_block_type[7:] if up_block_type.startswith(
153
+ "UNetRes") else up_block_type
154
+ if up_block_type == "UpBlock2D":
155
+ return UpBlock2D(
156
+ num_layers=num_layers,
157
+ in_channels=in_channels,
158
+ out_channels=out_channels,
159
+ prev_output_channel=prev_output_channel,
160
+ temb_channels=temb_channels,
161
+ add_upsample=add_upsample,
162
+ resnet_eps=resnet_eps,
163
+ resnet_act_fn=resnet_act_fn,
164
+ resnet_groups=resnet_groups,
165
+ )
166
+ elif up_block_type == "CrossAttnUpBlock2D":
167
+ if cross_attention_dim is None:
168
+ raise ValueError(
169
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D")
170
+ return CrossAttnUpBlock2D(
171
+ num_layers=num_layers,
172
+ in_channels=in_channels,
173
+ out_channels=out_channels,
174
+ prev_output_channel=prev_output_channel,
175
+ temb_channels=temb_channels,
176
+ add_upsample=add_upsample,
177
+ resnet_eps=resnet_eps,
178
+ resnet_act_fn=resnet_act_fn,
179
+ resnet_groups=resnet_groups,
180
+ cross_attention_dim=cross_attention_dim,
181
+ attn_num_head_channels=attn_num_head_channels,
182
+ dual_cross_attention=dual_cross_attention,
183
+ use_linear_projection=use_linear_projection,
184
+ only_cross_attention=only_cross_attention,
185
+ )
186
+ elif up_block_type == "AttnUpBlock2D":
187
+ return AttnUpBlock2D(
188
+ num_layers=num_layers,
189
+ in_channels=in_channels,
190
+ out_channels=out_channels,
191
+ prev_output_channel=prev_output_channel,
192
+ temb_channels=temb_channels,
193
+ add_upsample=add_upsample,
194
+ resnet_eps=resnet_eps,
195
+ resnet_act_fn=resnet_act_fn,
196
+ resnet_groups=resnet_groups,
197
+ attn_num_head_channels=attn_num_head_channels,
198
+ )
199
+ elif up_block_type == "SkipUpBlock2D":
200
+ return SkipUpBlock2D(
201
+ num_layers=num_layers,
202
+ in_channels=in_channels,
203
+ out_channels=out_channels,
204
+ prev_output_channel=prev_output_channel,
205
+ temb_channels=temb_channels,
206
+ add_upsample=add_upsample,
207
+ resnet_eps=resnet_eps,
208
+ resnet_act_fn=resnet_act_fn,
209
+ )
210
+ elif up_block_type == "AttnSkipUpBlock2D":
211
+ return AttnSkipUpBlock2D(
212
+ num_layers=num_layers,
213
+ in_channels=in_channels,
214
+ out_channels=out_channels,
215
+ prev_output_channel=prev_output_channel,
216
+ temb_channels=temb_channels,
217
+ add_upsample=add_upsample,
218
+ resnet_eps=resnet_eps,
219
+ resnet_act_fn=resnet_act_fn,
220
+ attn_num_head_channels=attn_num_head_channels,
221
+ )
222
+ elif up_block_type == "UpDecoderBlock2D":
223
+ return UpDecoderBlock2D(
224
+ num_layers=num_layers,
225
+ in_channels=in_channels,
226
+ out_channels=out_channels,
227
+ add_upsample=add_upsample,
228
+ resnet_eps=resnet_eps,
229
+ resnet_act_fn=resnet_act_fn,
230
+ resnet_groups=resnet_groups,
231
+ )
232
+ elif up_block_type == "AttnUpDecoderBlock2D":
233
+ return AttnUpDecoderBlock2D(
234
+ num_layers=num_layers,
235
+ in_channels=in_channels,
236
+ out_channels=out_channels,
237
+ add_upsample=add_upsample,
238
+ resnet_eps=resnet_eps,
239
+ resnet_act_fn=resnet_act_fn,
240
+ resnet_groups=resnet_groups,
241
+ attn_num_head_channels=attn_num_head_channels,
242
+ )
243
+ raise ValueError(f"{up_block_type} does not exist.")
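get_down_block and get_up_block are string-dispatch factories: the UNet config names block types (e.g. "CrossAttnDownBlock2D"), optionally prefixed with "UNetRes", and the factory strips the prefix and instantiates the matching class. A tiny sketch of the same dispatch pattern with placeholder classes (the registry and DummyDown are hypothetical, not part of this repository):

def make_block(block_type, registry, **kwargs):
    """Strips an optional 'UNetRes' prefix and looks the block type up in a registry."""
    block_type = block_type[7:] if block_type.startswith("UNetRes") else block_type
    if block_type not in registry:
        raise ValueError(f"{block_type} does not exist.")
    return registry[block_type](**kwargs)

class DummyDown:                                # placeholder standing in for DownBlock2D and friends
    def __init__(self, **kwargs):
        self.kwargs = kwargs

registry = {"DownBlock2D": DummyDown, "CrossAttnDownBlock2D": DummyDown}
block = make_block("UNetResDownBlock2D", registry, num_layers=2)
print(type(block).__name__, block.kwargs)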
244
+
245
+
246
+ class UNetMidBlock2D(nn.Module):
247
+ def __init__(
248
+ self,
249
+ in_channels: int,
250
+ temb_channels: int,
251
+ dropout: float = 0.0,
252
+ num_layers: int = 1,
253
+ resnet_eps: float = 1e-6,
254
+ resnet_time_scale_shift: str = "default",
255
+ resnet_act_fn: str = "swish",
256
+ resnet_groups: int = 32,
257
+ resnet_pre_norm: bool = True,
258
+ attn_num_head_channels=1,
259
+ attention_type="default",
260
+ output_scale_factor=1.0,
261
+ ):
262
+ super().__init__()
263
+
264
+ self.attention_type = attention_type
265
+ resnet_groups = resnet_groups if resnet_groups is not None else min(
266
+ in_channels // 4, 32)
267
+
268
+ # there is always at least one resnet
269
+ resnets = [
270
+ ResnetBlock2D(
271
+ in_channels=in_channels,
272
+ out_channels=in_channels,
273
+ temb_channels=temb_channels,
274
+ eps=resnet_eps,
275
+ groups=resnet_groups,
276
+ dropout=dropout,
277
+ time_embedding_norm=resnet_time_scale_shift,
278
+ non_linearity=resnet_act_fn,
279
+ output_scale_factor=output_scale_factor,
280
+ pre_norm=resnet_pre_norm,
281
+ )
282
+ ]
283
+ attentions = []
284
+
285
+ for _ in range(num_layers):
286
+ attentions.append(
287
+ AttentionBlock(
288
+ in_channels,
289
+ num_head_channels=attn_num_head_channels,
290
+ rescale_output_factor=output_scale_factor,
291
+ eps=resnet_eps,
292
+ norm_num_groups=resnet_groups,
293
+ )
294
+ )
295
+ resnets.append(
296
+ ResnetBlock2D(
297
+ in_channels=in_channels,
298
+ out_channels=in_channels,
299
+ temb_channels=temb_channels,
300
+ eps=resnet_eps,
301
+ groups=resnet_groups,
302
+ dropout=dropout,
303
+ time_embedding_norm=resnet_time_scale_shift,
304
+ non_linearity=resnet_act_fn,
305
+ output_scale_factor=output_scale_factor,
306
+ pre_norm=resnet_pre_norm,
307
+ )
308
+ )
309
+
310
+ self.attentions = nn.ModuleList(attentions)
311
+ self.resnets = nn.ModuleList(resnets)
312
+
313
+ def forward(self, hidden_states, temb=None, encoder_states=None):
314
+ hidden_states = self.resnets[0](hidden_states, temb)
315
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
316
+ if self.attention_type == "default":
317
+ hidden_states = attn(hidden_states)
318
+ else:
319
+ hidden_states = attn(hidden_states, encoder_states)
320
+ hidden_states, _ = resnet(hidden_states, temb)
321
+
322
+ return hidden_states
323
+
324
+
325
+ class UNetMidBlock2DCrossAttn(nn.Module):
326
+ def __init__(
327
+ self,
328
+ in_channels: int,
329
+ temb_channels: int,
330
+ dropout: float = 0.0,
331
+ num_layers: int = 1,
332
+ resnet_eps: float = 1e-6,
333
+ resnet_time_scale_shift: str = "default",
334
+ resnet_act_fn: str = "swish",
335
+ resnet_groups: int = 32,
336
+ resnet_pre_norm: bool = True,
337
+ attn_num_head_channels=1,
338
+ attention_type="default",
339
+ output_scale_factor=1.0,
340
+ cross_attention_dim=1280,
341
+ dual_cross_attention=False,
342
+ use_linear_projection=False,
343
+ ):
344
+ super().__init__()
345
+
346
+ self.attention_type = attention_type
347
+ self.attn_num_head_channels = attn_num_head_channels
348
+ resnet_groups = resnet_groups if resnet_groups is not None else min(
349
+ in_channels // 4, 32)
350
+
351
+ # there is always at least one resnet
352
+ resnets = [
353
+ ResnetBlock2D(
354
+ in_channels=in_channels,
355
+ out_channels=in_channels,
356
+ temb_channels=temb_channels,
357
+ eps=resnet_eps,
358
+ groups=resnet_groups,
359
+ dropout=dropout,
360
+ time_embedding_norm=resnet_time_scale_shift,
361
+ non_linearity=resnet_act_fn,
362
+ output_scale_factor=output_scale_factor,
363
+ pre_norm=resnet_pre_norm,
364
+ )
365
+ ]
366
+ attentions = []
367
+
368
+ for _ in range(num_layers):
369
+ if not dual_cross_attention:
370
+ attentions.append(
371
+ Transformer2DModel(
372
+ attn_num_head_channels,
373
+ in_channels // attn_num_head_channels,
374
+ in_channels=in_channels,
375
+ num_layers=1,
376
+ cross_attention_dim=cross_attention_dim,
377
+ norm_num_groups=resnet_groups,
378
+ use_linear_projection=use_linear_projection,
379
+ )
380
+ )
381
+ else:
382
+ attentions.append(
383
+ DualTransformer2DModel(
384
+ attn_num_head_channels,
385
+ in_channels // attn_num_head_channels,
386
+ in_channels=in_channels,
387
+ num_layers=1,
388
+ cross_attention_dim=cross_attention_dim,
389
+ norm_num_groups=resnet_groups,
390
+ )
391
+ )
392
+ resnets.append(
393
+ ResnetBlock2D(
394
+ in_channels=in_channels,
395
+ out_channels=in_channels,
396
+ temb_channels=temb_channels,
397
+ eps=resnet_eps,
398
+ groups=resnet_groups,
399
+ dropout=dropout,
400
+ time_embedding_norm=resnet_time_scale_shift,
401
+ non_linearity=resnet_act_fn,
402
+ output_scale_factor=output_scale_factor,
403
+ pre_norm=resnet_pre_norm,
404
+ )
405
+ )
406
+
407
+ self.attentions = nn.ModuleList(attentions)
408
+ self.resnets = nn.ModuleList(resnets)
409
+
410
+ def set_attention_slice(self, slice_size):
411
+ head_dims = self.attn_num_head_channels
412
+ head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
413
+ if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
414
+ raise ValueError(
415
+ f"Make sure slice_size {slice_size} is a common divisor of "
416
+ f"the number of heads used in cross_attention: {head_dims}"
417
+ )
418
+ if slice_size is not None and slice_size > min(head_dims):
419
+ raise ValueError(
420
+ f"slice_size {slice_size} has to be smaller or equal to "
421
+ f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
422
+ )
423
+
424
+ for attn in self.attentions:
425
+ attn._set_attention_slice(slice_size)
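set_attention_slice only accepts a slice size that divides every head count, so sliced attention can split the head dimension into equal chunks. A small check of that validation logic in isolation; `validate_slice_size` and the head counts are illustrative:

def validate_slice_size(slice_size, head_dims):
    head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
    if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
        raise ValueError(f"slice_size {slice_size} must divide every head count in {head_dims}")
    if slice_size is not None and slice_size > min(head_dims):
        raise ValueError(f"slice_size {slice_size} must be <= min({head_dims}) = {min(head_dims)}")

validate_slice_size(4, [8, 16])                 # fine: 4 divides both 8 and 16
try:
    validate_slice_size(3, [8, 16])             # rejected: 3 does not divide 8
except ValueError as err:
    print(err)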
426
+
427
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
428
+ for attn in self.attentions:
429
+ attn._set_use_memory_efficient_attention_xformers(
430
+ use_memory_efficient_attention_xformers)
431
+
432
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None,
433
+ text_format_dict={}):
434
+ hidden_states, _ = self.resnets[0](hidden_states, temb)
435
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
436
+ hidden_states = attn(hidden_states, encoder_hidden_states,
437
+ text_format_dict).sample
438
+ hidden_states, _ = resnet(hidden_states, temb)
439
+
440
+ return hidden_states
441
+
442
+
443
+ class AttnDownBlock2D(nn.Module):
444
+ def __init__(
445
+ self,
446
+ in_channels: int,
447
+ out_channels: int,
448
+ temb_channels: int,
449
+ dropout: float = 0.0,
450
+ num_layers: int = 1,
451
+ resnet_eps: float = 1e-6,
452
+ resnet_time_scale_shift: str = "default",
453
+ resnet_act_fn: str = "swish",
454
+ resnet_groups: int = 32,
455
+ resnet_pre_norm: bool = True,
456
+ attn_num_head_channels=1,
457
+ attention_type="default",
458
+ output_scale_factor=1.0,
459
+ downsample_padding=1,
460
+ add_downsample=True,
461
+ ):
462
+ super().__init__()
463
+ resnets = []
464
+ attentions = []
465
+
466
+ self.attention_type = attention_type
467
+
468
+ for i in range(num_layers):
469
+ in_channels = in_channels if i == 0 else out_channels
470
+ resnets.append(
471
+ ResnetBlock2D(
472
+ in_channels=in_channels,
473
+ out_channels=out_channels,
474
+ temb_channels=temb_channels,
475
+ eps=resnet_eps,
476
+ groups=resnet_groups,
477
+ dropout=dropout,
478
+ time_embedding_norm=resnet_time_scale_shift,
479
+ non_linearity=resnet_act_fn,
480
+ output_scale_factor=output_scale_factor,
481
+ pre_norm=resnet_pre_norm,
482
+ )
483
+ )
484
+ attentions.append(
485
+ AttentionBlock(
486
+ out_channels,
487
+ num_head_channels=attn_num_head_channels,
488
+ rescale_output_factor=output_scale_factor,
489
+ eps=resnet_eps,
490
+ norm_num_groups=resnet_groups,
491
+ )
492
+ )
493
+
494
+ self.attentions = nn.ModuleList(attentions)
495
+ self.resnets = nn.ModuleList(resnets)
496
+
497
+ if add_downsample:
498
+ self.downsamplers = nn.ModuleList(
499
+ [
500
+ Downsample2D(
501
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
502
+ )
503
+ ]
504
+ )
505
+ else:
506
+ self.downsamplers = None
507
+
508
+ def forward(self, hidden_states, temb=None):
509
+ output_states = ()
510
+
511
+ for resnet, attn in zip(self.resnets, self.attentions):
512
+ hidden_states, _ = resnet(hidden_states, temb)
513
+ hidden_states = attn(hidden_states)
514
+ output_states += (hidden_states,)
515
+
516
+ if self.downsamplers is not None:
517
+ for downsampler in self.downsamplers:
518
+ hidden_states = downsampler(hidden_states)
519
+
520
+ output_states += (hidden_states,)
521
+
522
+ return hidden_states, output_states
523
+
524
+
525
+ class CrossAttnDownBlock2D(nn.Module):
526
+ def __init__(
527
+ self,
528
+ in_channels: int,
529
+ out_channels: int,
530
+ temb_channels: int,
531
+ dropout: float = 0.0,
532
+ num_layers: int = 1,
533
+ resnet_eps: float = 1e-6,
534
+ resnet_time_scale_shift: str = "default",
535
+ resnet_act_fn: str = "swish",
536
+ resnet_groups: int = 32,
537
+ resnet_pre_norm: bool = True,
538
+ attn_num_head_channels=1,
539
+ cross_attention_dim=1280,
540
+ attention_type="default",
541
+ output_scale_factor=1.0,
542
+ downsample_padding=1,
543
+ add_downsample=True,
544
+ dual_cross_attention=False,
545
+ use_linear_projection=False,
546
+ only_cross_attention=False,
547
+ ):
548
+ super().__init__()
549
+ resnets = []
550
+ attentions = []
551
+
552
+ self.attention_type = attention_type
553
+ self.attn_num_head_channels = attn_num_head_channels
554
+
555
+ for i in range(num_layers):
556
+ in_channels = in_channels if i == 0 else out_channels
557
+ resnets.append(
558
+ ResnetBlock2D(
559
+ in_channels=in_channels,
560
+ out_channels=out_channels,
561
+ temb_channels=temb_channels,
562
+ eps=resnet_eps,
563
+ groups=resnet_groups,
564
+ dropout=dropout,
565
+ time_embedding_norm=resnet_time_scale_shift,
566
+ non_linearity=resnet_act_fn,
567
+ output_scale_factor=output_scale_factor,
568
+ pre_norm=resnet_pre_norm,
569
+ )
570
+ )
571
+ if not dual_cross_attention:
572
+ attentions.append(
573
+ Transformer2DModel(
574
+ attn_num_head_channels,
575
+ out_channels // attn_num_head_channels,
576
+ in_channels=out_channels,
577
+ num_layers=1,
578
+ cross_attention_dim=cross_attention_dim,
579
+ norm_num_groups=resnet_groups,
580
+ use_linear_projection=use_linear_projection,
581
+ only_cross_attention=only_cross_attention,
582
+ )
583
+ )
584
+ else:
585
+ attentions.append(
586
+ DualTransformer2DModel(
587
+ attn_num_head_channels,
588
+ out_channels // attn_num_head_channels,
589
+ in_channels=out_channels,
590
+ num_layers=1,
591
+ cross_attention_dim=cross_attention_dim,
592
+ norm_num_groups=resnet_groups,
593
+ )
594
+ )
595
+ self.attentions = nn.ModuleList(attentions)
596
+ self.resnets = nn.ModuleList(resnets)
597
+
598
+ if add_downsample:
599
+ self.downsamplers = nn.ModuleList(
600
+ [
601
+ Downsample2D(
602
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
603
+ )
604
+ ]
605
+ )
606
+ else:
607
+ self.downsamplers = None
608
+
609
+ self.gradient_checkpointing = False
610
+
611
+ def set_attention_slice(self, slice_size):
612
+ head_dims = self.attn_num_head_channels
613
+ head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
614
+ if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
615
+ raise ValueError(
616
+ f"Make sure slice_size {slice_size} is a common divisor of "
617
+ f"the number of heads used in cross_attention: {head_dims}"
618
+ )
619
+ if slice_size is not None and slice_size > min(head_dims):
620
+ raise ValueError(
621
+ f"slice_size {slice_size} has to be smaller or equal to "
622
+ f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
623
+ )
624
+
625
+ for attn in self.attentions:
626
+ attn._set_attention_slice(slice_size)
627
+
628
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
629
+ for attn in self.attentions:
630
+ attn._set_use_memory_efficient_attention_xformers(
631
+ use_memory_efficient_attention_xformers)
632
+
633
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None,
634
+ text_format_dict={}):
635
+ output_states = ()
636
+
637
+ for resnet, attn in zip(self.resnets, self.attentions):
638
+ if self.training and self.gradient_checkpointing:
639
+
640
+ def create_custom_forward(module, return_dict=None):
641
+ def custom_forward(*inputs):
642
+ if return_dict is not None:
643
+ return module(*inputs, return_dict=return_dict)
644
+ else:
645
+ return module(*inputs)
646
+
647
+ return custom_forward
648
+
649
+ hidden_states = torch.utils.checkpoint.checkpoint(
650
+ create_custom_forward(resnet), hidden_states, temb)
651
+ hidden_states = torch.utils.checkpoint.checkpoint(
652
+ create_custom_forward(
653
+ attn, return_dict=False), hidden_states, encoder_hidden_states,
654
+ text_format_dict
655
+ )[0]
656
+ else:
657
+ hidden_states, _ = resnet(hidden_states, temb)
658
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
659
+ text_format_dict=text_format_dict).sample
660
+
661
+ output_states += (hidden_states,)
662
+
663
+ if self.downsamplers is not None:
664
+ for downsampler in self.downsamplers:
665
+ hidden_states = downsampler(hidden_states)
666
+
667
+ output_states += (hidden_states,)
668
+
669
+ return hidden_states, output_states
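When gradient_checkpointing is enabled, the forward pass above wraps each resnet/attention pair in torch.utils.checkpoint.checkpoint, trading recomputation during backward for lower activation memory. A minimal standalone example of that trade on a stack of linear layers (sizes are arbitrary):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layers = nn.ModuleList(nn.Linear(64, 64) for _ in range(4))

def forward(x, use_checkpointing):
    for layer in layers:
        if use_checkpointing:
            # activations inside `layer` are recomputed during backward instead of being stored
            x = checkpoint(layer, x)
        else:
            x = layer(x)
    return x

x = torch.randn(8, 64, requires_grad=True)
forward(x, use_checkpointing=True).sum().backward()
print(x.grad.shape)                             # gradients match the non-checkpointed path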
670
+
671
+
672
+ class DownBlock2D(nn.Module):
673
+ def __init__(
674
+ self,
675
+ in_channels: int,
676
+ out_channels: int,
677
+ temb_channels: int,
678
+ dropout: float = 0.0,
679
+ num_layers: int = 1,
680
+ resnet_eps: float = 1e-6,
681
+ resnet_time_scale_shift: str = "default",
682
+ resnet_act_fn: str = "swish",
683
+ resnet_groups: int = 32,
684
+ resnet_pre_norm: bool = True,
685
+ output_scale_factor=1.0,
686
+ add_downsample=True,
687
+ downsample_padding=1,
688
+ ):
689
+ super().__init__()
690
+ resnets = []
691
+
692
+ for i in range(num_layers):
693
+ in_channels = in_channels if i == 0 else out_channels
694
+ resnets.append(
695
+ ResnetBlock2D(
696
+ in_channels=in_channels,
697
+ out_channels=out_channels,
698
+ temb_channels=temb_channels,
699
+ eps=resnet_eps,
700
+ groups=resnet_groups,
701
+ dropout=dropout,
702
+ time_embedding_norm=resnet_time_scale_shift,
703
+ non_linearity=resnet_act_fn,
704
+ output_scale_factor=output_scale_factor,
705
+ pre_norm=resnet_pre_norm,
706
+ )
707
+ )
708
+
709
+ self.resnets = nn.ModuleList(resnets)
710
+
711
+ if add_downsample:
712
+ self.downsamplers = nn.ModuleList(
713
+ [
714
+ Downsample2D(
715
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
716
+ )
717
+ ]
718
+ )
719
+ else:
720
+ self.downsamplers = None
721
+
722
+ self.gradient_checkpointing = False
723
+
724
+ def forward(self, hidden_states, temb=None):
725
+ output_states = ()
726
+
727
+ for resnet in self.resnets:
728
+ if self.training and self.gradient_checkpointing:
729
+
730
+ def create_custom_forward(module):
731
+ def custom_forward(*inputs):
732
+ return module(*inputs)
733
+
734
+ return custom_forward
735
+
736
+ hidden_states = torch.utils.checkpoint.checkpoint(
737
+ create_custom_forward(resnet), hidden_states, temb)
738
+ else:
739
+ hidden_states, _ = resnet(hidden_states, temb)
740
+
741
+ output_states += (hidden_states,)
742
+
743
+ if self.downsamplers is not None:
744
+ for downsampler in self.downsamplers:
745
+ hidden_states = downsampler(hidden_states)
746
+
747
+ output_states += (hidden_states,)
748
+
749
+ return hidden_states, output_states
750
+
751
+
752
+ class DownEncoderBlock2D(nn.Module):
753
+ def __init__(
754
+ self,
755
+ in_channels: int,
756
+ out_channels: int,
757
+ dropout: float = 0.0,
758
+ num_layers: int = 1,
759
+ resnet_eps: float = 1e-6,
760
+ resnet_time_scale_shift: str = "default",
761
+ resnet_act_fn: str = "swish",
762
+ resnet_groups: int = 32,
763
+ resnet_pre_norm: bool = True,
764
+ output_scale_factor=1.0,
765
+ add_downsample=True,
766
+ downsample_padding=1,
767
+ ):
768
+ super().__init__()
769
+ resnets = []
770
+
771
+ for i in range(num_layers):
772
+ in_channels = in_channels if i == 0 else out_channels
773
+ resnets.append(
774
+ ResnetBlock2D(
775
+ in_channels=in_channels,
776
+ out_channels=out_channels,
777
+ temb_channels=None,
778
+ eps=resnet_eps,
779
+ groups=resnet_groups,
780
+ dropout=dropout,
781
+ time_embedding_norm=resnet_time_scale_shift,
782
+ non_linearity=resnet_act_fn,
783
+ output_scale_factor=output_scale_factor,
784
+ pre_norm=resnet_pre_norm,
785
+ )
786
+ )
787
+
788
+ self.resnets = nn.ModuleList(resnets)
789
+
790
+ if add_downsample:
791
+ self.downsamplers = nn.ModuleList(
792
+ [
793
+ Downsample2D(
794
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
795
+ )
796
+ ]
797
+ )
798
+ else:
799
+ self.downsamplers = None
800
+
801
+ def forward(self, hidden_states):
802
+ for resnet in self.resnets:
803
+ hidden_states, _ = resnet(hidden_states, temb=None)
804
+
805
+ if self.downsamplers is not None:
806
+ for downsampler in self.downsamplers:
807
+ hidden_states = downsampler(hidden_states)
808
+
809
+ return hidden_states
810
+
811
+
812
+ class AttnDownEncoderBlock2D(nn.Module):
813
+ def __init__(
814
+ self,
815
+ in_channels: int,
816
+ out_channels: int,
817
+ dropout: float = 0.0,
818
+ num_layers: int = 1,
819
+ resnet_eps: float = 1e-6,
820
+ resnet_time_scale_shift: str = "default",
821
+ resnet_act_fn: str = "swish",
822
+ resnet_groups: int = 32,
823
+ resnet_pre_norm: bool = True,
824
+ attn_num_head_channels=1,
825
+ output_scale_factor=1.0,
826
+ add_downsample=True,
827
+ downsample_padding=1,
828
+ ):
829
+ super().__init__()
830
+ resnets = []
831
+ attentions = []
832
+
833
+ for i in range(num_layers):
834
+ in_channels = in_channels if i == 0 else out_channels
835
+ resnets.append(
836
+ ResnetBlock2D(
837
+ in_channels=in_channels,
838
+ out_channels=out_channels,
839
+ temb_channels=None,
840
+ eps=resnet_eps,
841
+ groups=resnet_groups,
842
+ dropout=dropout,
843
+ time_embedding_norm=resnet_time_scale_shift,
844
+ non_linearity=resnet_act_fn,
845
+ output_scale_factor=output_scale_factor,
846
+ pre_norm=resnet_pre_norm,
847
+ )
848
+ )
849
+ attentions.append(
850
+ AttentionBlock(
851
+ out_channels,
852
+ num_head_channels=attn_num_head_channels,
853
+ rescale_output_factor=output_scale_factor,
854
+ eps=resnet_eps,
855
+ norm_num_groups=resnet_groups,
856
+ )
857
+ )
858
+
859
+ self.attentions = nn.ModuleList(attentions)
860
+ self.resnets = nn.ModuleList(resnets)
861
+
862
+ if add_downsample:
863
+ self.downsamplers = nn.ModuleList(
864
+ [
865
+ Downsample2D(
866
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
867
+ )
868
+ ]
869
+ )
870
+ else:
871
+ self.downsamplers = None
872
+
873
+ def forward(self, hidden_states):
874
+ for resnet, attn in zip(self.resnets, self.attentions):
875
+ hidden_states, _ = resnet(hidden_states, temb=None)
876
+ hidden_states = attn(hidden_states)
877
+
878
+ if self.downsamplers is not None:
879
+ for downsampler in self.downsamplers:
880
+ hidden_states = downsampler(hidden_states)
881
+
882
+ return hidden_states
883
+
884
+
885
+ class AttnSkipDownBlock2D(nn.Module):
886
+ def __init__(
887
+ self,
888
+ in_channels: int,
889
+ out_channels: int,
890
+ temb_channels: int,
891
+ dropout: float = 0.0,
892
+ num_layers: int = 1,
893
+ resnet_eps: float = 1e-6,
894
+ resnet_time_scale_shift: str = "default",
895
+ resnet_act_fn: str = "swish",
896
+ resnet_pre_norm: bool = True,
897
+ attn_num_head_channels=1,
898
+ attention_type="default",
899
+ output_scale_factor=np.sqrt(2.0),
900
+ downsample_padding=1,
901
+ add_downsample=True,
902
+ ):
903
+ super().__init__()
904
+ self.attentions = nn.ModuleList([])
905
+ self.resnets = nn.ModuleList([])
906
+
907
+ self.attention_type = attention_type
908
+
909
+ for i in range(num_layers):
910
+ in_channels = in_channels if i == 0 else out_channels
911
+ self.resnets.append(
912
+ ResnetBlock2D(
913
+ in_channels=in_channels,
914
+ out_channels=out_channels,
915
+ temb_channels=temb_channels,
916
+ eps=resnet_eps,
917
+ groups=min(in_channels // 4, 32),
918
+ groups_out=min(out_channels // 4, 32),
919
+ dropout=dropout,
920
+ time_embedding_norm=resnet_time_scale_shift,
921
+ non_linearity=resnet_act_fn,
922
+ output_scale_factor=output_scale_factor,
923
+ pre_norm=resnet_pre_norm,
924
+ )
925
+ )
926
+ self.attentions.append(
927
+ AttentionBlock(
928
+ out_channels,
929
+ num_head_channels=attn_num_head_channels,
930
+ rescale_output_factor=output_scale_factor,
931
+ eps=resnet_eps,
932
+ )
933
+ )
934
+
935
+ if add_downsample:
936
+ self.resnet_down = ResnetBlock2D(
937
+ in_channels=out_channels,
938
+ out_channels=out_channels,
939
+ temb_channels=temb_channels,
940
+ eps=resnet_eps,
941
+ groups=min(out_channels // 4, 32),
942
+ dropout=dropout,
943
+ time_embedding_norm=resnet_time_scale_shift,
944
+ non_linearity=resnet_act_fn,
945
+ output_scale_factor=output_scale_factor,
946
+ pre_norm=resnet_pre_norm,
947
+ use_in_shortcut=True,
948
+ down=True,
949
+ kernel="fir",
950
+ )
951
+ self.downsamplers = nn.ModuleList(
952
+ [FirDownsample2D(out_channels, out_channels=out_channels)])
953
+ self.skip_conv = nn.Conv2d(
954
+ 3, out_channels, kernel_size=(1, 1), stride=(1, 1))
955
+ else:
956
+ self.resnet_down = None
957
+ self.downsamplers = None
958
+ self.skip_conv = None
959
+
960
+ def forward(self, hidden_states, temb=None, skip_sample=None):
961
+ output_states = ()
962
+
963
+ for resnet, attn in zip(self.resnets, self.attentions):
964
+ hidden_states, _ = resnet(hidden_states, temb)
965
+ hidden_states = attn(hidden_states)
966
+ output_states += (hidden_states,)
967
+
968
+ if self.downsamplers is not None:
969
+ hidden_states = self.resnet_down(hidden_states, temb)
970
+ for downsampler in self.downsamplers:
971
+ skip_sample = downsampler(skip_sample)
972
+
973
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
974
+
975
+ output_states += (hidden_states,)
976
+
977
+ return hidden_states, output_states, skip_sample
978
+
979
+
980
+ class SkipDownBlock2D(nn.Module):
981
+ def __init__(
982
+ self,
983
+ in_channels: int,
984
+ out_channels: int,
985
+ temb_channels: int,
986
+ dropout: float = 0.0,
987
+ num_layers: int = 1,
988
+ resnet_eps: float = 1e-6,
989
+ resnet_time_scale_shift: str = "default",
990
+ resnet_act_fn: str = "swish",
991
+ resnet_pre_norm: bool = True,
992
+ output_scale_factor=np.sqrt(2.0),
993
+ add_downsample=True,
994
+ downsample_padding=1,
995
+ ):
996
+ super().__init__()
997
+ self.resnets = nn.ModuleList([])
998
+
999
+ for i in range(num_layers):
1000
+ in_channels = in_channels if i == 0 else out_channels
1001
+ self.resnets.append(
1002
+ ResnetBlock2D(
1003
+ in_channels=in_channels,
1004
+ out_channels=out_channels,
1005
+ temb_channels=temb_channels,
1006
+ eps=resnet_eps,
1007
+ groups=min(in_channels // 4, 32),
1008
+ groups_out=min(out_channels // 4, 32),
1009
+ dropout=dropout,
1010
+ time_embedding_norm=resnet_time_scale_shift,
1011
+ non_linearity=resnet_act_fn,
1012
+ output_scale_factor=output_scale_factor,
1013
+ pre_norm=resnet_pre_norm,
1014
+ )
1015
+ )
1016
+
1017
+ if add_downsample:
1018
+ self.resnet_down = ResnetBlock2D(
1019
+ in_channels=out_channels,
1020
+ out_channels=out_channels,
1021
+ temb_channels=temb_channels,
1022
+ eps=resnet_eps,
1023
+ groups=min(out_channels // 4, 32),
1024
+ dropout=dropout,
1025
+ time_embedding_norm=resnet_time_scale_shift,
1026
+ non_linearity=resnet_act_fn,
1027
+ output_scale_factor=output_scale_factor,
1028
+ pre_norm=resnet_pre_norm,
1029
+ use_in_shortcut=True,
1030
+ down=True,
1031
+ kernel="fir",
1032
+ )
1033
+ self.downsamplers = nn.ModuleList(
1034
+ [FirDownsample2D(out_channels, out_channels=out_channels)])
1035
+ self.skip_conv = nn.Conv2d(
1036
+ 3, out_channels, kernel_size=(1, 1), stride=(1, 1))
1037
+ else:
1038
+ self.resnet_down = None
1039
+ self.downsamplers = None
1040
+ self.skip_conv = None
1041
+
1042
+ def forward(self, hidden_states, temb=None, skip_sample=None):
1043
+ output_states = ()
1044
+
1045
+ for resnet in self.resnets:
1046
+ hidden_states, _ = resnet(hidden_states, temb)
1047
+ output_states += (hidden_states,)
1048
+
1049
+ if self.downsamplers is not None:
1050
+ hidden_states = self.resnet_down(hidden_states, temb)
1051
+ for downsampler in self.downsamplers:
1052
+ skip_sample = downsampler(skip_sample)
1053
+
1054
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
1055
+
1056
+ output_states += (hidden_states,)
1057
+
1058
+ return hidden_states, output_states, skip_sample
1059
+
1060
+
1061
+ class AttnUpBlock2D(nn.Module):
1062
+ def __init__(
1063
+ self,
1064
+ in_channels: int,
1065
+ prev_output_channel: int,
1066
+ out_channels: int,
1067
+ temb_channels: int,
1068
+ dropout: float = 0.0,
1069
+ num_layers: int = 1,
1070
+ resnet_eps: float = 1e-6,
1071
+ resnet_time_scale_shift: str = "default",
1072
+ resnet_act_fn: str = "swish",
1073
+ resnet_groups: int = 32,
1074
+ resnet_pre_norm: bool = True,
1075
+ attention_type="default",
1076
+ attn_num_head_channels=1,
1077
+ output_scale_factor=1.0,
1078
+ add_upsample=True,
1079
+ ):
1080
+ super().__init__()
1081
+ resnets = []
1082
+ attentions = []
1083
+
1084
+ self.attention_type = attention_type
1085
+
1086
+ for i in range(num_layers):
1087
+ res_skip_channels = in_channels if (
1088
+ i == num_layers - 1) else out_channels
1089
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1090
+
1091
+ resnets.append(
1092
+ ResnetBlock2D(
1093
+ in_channels=resnet_in_channels + res_skip_channels,
1094
+ out_channels=out_channels,
1095
+ temb_channels=temb_channels,
1096
+ eps=resnet_eps,
1097
+ groups=resnet_groups,
1098
+ dropout=dropout,
1099
+ time_embedding_norm=resnet_time_scale_shift,
1100
+ non_linearity=resnet_act_fn,
1101
+ output_scale_factor=output_scale_factor,
1102
+ pre_norm=resnet_pre_norm,
1103
+ )
1104
+ )
1105
+ attentions.append(
1106
+ AttentionBlock(
1107
+ out_channels,
1108
+ num_head_channels=attn_num_head_channels,
1109
+ rescale_output_factor=output_scale_factor,
1110
+ eps=resnet_eps,
1111
+ norm_num_groups=resnet_groups,
1112
+ )
1113
+ )
1114
+
1115
+ self.attentions = nn.ModuleList(attentions)
1116
+ self.resnets = nn.ModuleList(resnets)
1117
+
1118
+ if add_upsample:
1119
+ self.upsamplers = nn.ModuleList(
1120
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1121
+ else:
1122
+ self.upsamplers = None
1123
+
1124
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
1125
+ for resnet, attn in zip(self.resnets, self.attentions):
1126
+ # pop res hidden states
1127
+ res_hidden_states = res_hidden_states_tuple[-1]
1128
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1129
+ hidden_states = torch.cat(
1130
+ [hidden_states, res_hidden_states], dim=1)
1131
+
1132
+ hidden_states, _ = resnet(hidden_states, temb)
1133
+ hidden_states = attn(hidden_states)
1134
+
1135
+ if self.upsamplers is not None:
1136
+ for upsampler in self.upsamplers:
1137
+ hidden_states = upsampler(hidden_states)
1138
+
1139
+ return hidden_states
1140
+
1141
+
1142
+ class CrossAttnUpBlock2D(nn.Module):
1143
+ def __init__(
1144
+ self,
1145
+ in_channels: int,
1146
+ out_channels: int,
1147
+ prev_output_channel: int,
1148
+ temb_channels: int,
1149
+ dropout: float = 0.0,
1150
+ num_layers: int = 1,
1151
+ resnet_eps: float = 1e-6,
1152
+ resnet_time_scale_shift: str = "default",
1153
+ resnet_act_fn: str = "swish",
1154
+ resnet_groups: int = 32,
1155
+ resnet_pre_norm: bool = True,
1156
+ attn_num_head_channels=1,
1157
+ cross_attention_dim=1280,
1158
+ attention_type="default",
1159
+ output_scale_factor=1.0,
1160
+ add_upsample=True,
1161
+ dual_cross_attention=False,
1162
+ use_linear_projection=False,
1163
+ only_cross_attention=False,
1164
+ ):
1165
+ super().__init__()
1166
+ resnets = []
1167
+ attentions = []
1168
+
1169
+ self.attention_type = attention_type
1170
+ self.attn_num_head_channels = attn_num_head_channels
1171
+
1172
+ for i in range(num_layers):
1173
+ res_skip_channels = in_channels if (
1174
+ i == num_layers - 1) else out_channels
1175
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1176
+
1177
+ resnets.append(
1178
+ ResnetBlock2D(
1179
+ in_channels=resnet_in_channels + res_skip_channels,
1180
+ out_channels=out_channels,
1181
+ temb_channels=temb_channels,
1182
+ eps=resnet_eps,
1183
+ groups=resnet_groups,
1184
+ dropout=dropout,
1185
+ time_embedding_norm=resnet_time_scale_shift,
1186
+ non_linearity=resnet_act_fn,
1187
+ output_scale_factor=output_scale_factor,
1188
+ pre_norm=resnet_pre_norm,
1189
+ )
1190
+ )
1191
+ if not dual_cross_attention:
1192
+ attentions.append(
1193
+ Transformer2DModel(
1194
+ attn_num_head_channels,
1195
+ out_channels // attn_num_head_channels,
1196
+ in_channels=out_channels,
1197
+ num_layers=1,
1198
+ cross_attention_dim=cross_attention_dim,
1199
+ norm_num_groups=resnet_groups,
1200
+ use_linear_projection=use_linear_projection,
1201
+ only_cross_attention=only_cross_attention,
1202
+ )
1203
+ )
1204
+ else:
1205
+ attentions.append(
1206
+ DualTransformer2DModel(
1207
+ attn_num_head_channels,
1208
+ out_channels // attn_num_head_channels,
1209
+ in_channels=out_channels,
1210
+ num_layers=1,
1211
+ cross_attention_dim=cross_attention_dim,
1212
+ norm_num_groups=resnet_groups,
1213
+ )
1214
+ )
1215
+ self.attentions = nn.ModuleList(attentions)
1216
+ self.resnets = nn.ModuleList(resnets)
1217
+
1218
+ if add_upsample:
1219
+ self.upsamplers = nn.ModuleList(
1220
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1221
+ else:
1222
+ self.upsamplers = None
1223
+
1224
+ self.gradient_checkpointing = False
1225
+
1226
+ def set_attention_slice(self, slice_size):
1227
+ head_dims = self.attn_num_head_channels
1228
+ head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
1229
+ if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
1230
+ raise ValueError(
1231
+ f"Make sure slice_size {slice_size} is a common divisor of "
1232
+ f"the number of heads used in cross_attention: {head_dims}"
1233
+ )
1234
+ if slice_size is not None and slice_size > min(head_dims):
1235
+ raise ValueError(
1236
+ f"slice_size {slice_size} has to be smaller or equal to "
1237
+ f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
1238
+ )
1239
+
1240
+ for attn in self.attentions:
1241
+ attn._set_attention_slice(slice_size)
1242
+
1243
+ self.gradient_checkpointing = False
1244
+
1245
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
1246
+ for attn in self.attentions:
1247
+ attn._set_use_memory_efficient_attention_xformers(
1248
+ use_memory_efficient_attention_xformers)
1249
+
1250
+ def forward(
1251
+ self,
1252
+ hidden_states,
1253
+ res_hidden_states_tuple,
1254
+ temb=None,
1255
+ encoder_hidden_states=None,
1256
+ upsample_size=None,
1257
+ text_format_dict={}
1258
+ ):
1259
+ for resnet, attn in zip(self.resnets, self.attentions):
1260
+ # pop res hidden states
1261
+ res_hidden_states = res_hidden_states_tuple[-1]
1262
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1263
+ hidden_states = torch.cat(
1264
+ [hidden_states, res_hidden_states], dim=1)
1265
+
1266
+ if self.training and self.gradient_checkpointing:
1267
+
1268
+ def create_custom_forward(module, return_dict=None):
1269
+ def custom_forward(*inputs):
1270
+ if return_dict is not None:
1271
+ return module(*inputs, return_dict=return_dict)
1272
+ else:
1273
+ return module(*inputs)
1274
+
1275
+ return custom_forward
1276
+
1277
+ hidden_states = torch.utils.checkpoint.checkpoint(
1278
+ create_custom_forward(resnet), hidden_states, temb)
1279
+ hidden_states = torch.utils.checkpoint.checkpoint(
1280
+ create_custom_forward(
1281
+ attn, return_dict=False), hidden_states, encoder_hidden_states,
1282
+ text_format_dict
1283
+ )[0]
1284
+ else:
1285
+ hidden_states, _ = resnet(hidden_states, temb)
1286
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
1287
+ text_format_dict=text_format_dict).sample
1288
+
1289
+ if self.upsamplers is not None:
1290
+ for upsampler in self.upsamplers:
1291
+ hidden_states = upsampler(hidden_states, upsample_size)
1292
+
1293
+ return hidden_states
1294
+
1295
+
1296
+ class UpBlock2D(nn.Module):
1297
+ def __init__(
1298
+ self,
1299
+ in_channels: int,
1300
+ prev_output_channel: int,
1301
+ out_channels: int,
1302
+ temb_channels: int,
1303
+ dropout: float = 0.0,
1304
+ num_layers: int = 1,
1305
+ resnet_eps: float = 1e-6,
1306
+ resnet_time_scale_shift: str = "default",
1307
+ resnet_act_fn: str = "swish",
1308
+ resnet_groups: int = 32,
1309
+ resnet_pre_norm: bool = True,
1310
+ output_scale_factor=1.0,
1311
+ add_upsample=True,
1312
+ ):
1313
+ super().__init__()
1314
+ resnets = []
1315
+
1316
+ for i in range(num_layers):
1317
+ res_skip_channels = in_channels if (
1318
+ i == num_layers - 1) else out_channels
1319
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1320
+
1321
+ resnets.append(
1322
+ ResnetBlock2D(
1323
+ in_channels=resnet_in_channels + res_skip_channels,
1324
+ out_channels=out_channels,
1325
+ temb_channels=temb_channels,
1326
+ eps=resnet_eps,
1327
+ groups=resnet_groups,
1328
+ dropout=dropout,
1329
+ time_embedding_norm=resnet_time_scale_shift,
1330
+ non_linearity=resnet_act_fn,
1331
+ output_scale_factor=output_scale_factor,
1332
+ pre_norm=resnet_pre_norm,
1333
+ )
1334
+ )
1335
+
1336
+ self.resnets = nn.ModuleList(resnets)
1337
+
1338
+ if add_upsample:
1339
+ self.upsamplers = nn.ModuleList(
1340
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1341
+ else:
1342
+ self.upsamplers = None
1343
+
1344
+ self.gradient_checkpointing = False
1345
+
1346
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
1347
+ for resnet in self.resnets:
1348
+ # pop res hidden states
1349
+ res_hidden_states = res_hidden_states_tuple[-1]
1350
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1351
+ hidden_states = torch.cat(
1352
+ [hidden_states, res_hidden_states], dim=1)
1353
+
1354
+ if self.training and self.gradient_checkpointing:
1355
+
1356
+ def create_custom_forward(module):
1357
+ def custom_forward(*inputs):
1358
+ return module(*inputs)
1359
+
1360
+ return custom_forward
1361
+
1362
+ hidden_states = torch.utils.checkpoint.checkpoint(
1363
+ create_custom_forward(resnet), hidden_states, temb)
1364
+ else:
1365
+ hidden_states, _ = resnet(hidden_states, temb)
1366
+
1367
+ if self.upsamplers is not None:
1368
+ for upsampler in self.upsamplers:
1369
+ hidden_states = upsampler(hidden_states, upsample_size)
1370
+
1371
+ return hidden_states
1372
+
1373
+
1374
+ class UpDecoderBlock2D(nn.Module):
1375
+ def __init__(
1376
+ self,
1377
+ in_channels: int,
1378
+ out_channels: int,
1379
+ dropout: float = 0.0,
1380
+ num_layers: int = 1,
1381
+ resnet_eps: float = 1e-6,
1382
+ resnet_time_scale_shift: str = "default",
1383
+ resnet_act_fn: str = "swish",
1384
+ resnet_groups: int = 32,
1385
+ resnet_pre_norm: bool = True,
1386
+ output_scale_factor=1.0,
1387
+ add_upsample=True,
1388
+ ):
1389
+ super().__init__()
1390
+ resnets = []
1391
+
1392
+ for i in range(num_layers):
1393
+ input_channels = in_channels if i == 0 else out_channels
1394
+
1395
+ resnets.append(
1396
+ ResnetBlock2D(
1397
+ in_channels=input_channels,
1398
+ out_channels=out_channels,
1399
+ temb_channels=None,
1400
+ eps=resnet_eps,
1401
+ groups=resnet_groups,
1402
+ dropout=dropout,
1403
+ time_embedding_norm=resnet_time_scale_shift,
1404
+ non_linearity=resnet_act_fn,
1405
+ output_scale_factor=output_scale_factor,
1406
+ pre_norm=resnet_pre_norm,
1407
+ )
1408
+ )
1409
+
1410
+ self.resnets = nn.ModuleList(resnets)
1411
+
1412
+ if add_upsample:
1413
+ self.upsamplers = nn.ModuleList(
1414
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1415
+ else:
1416
+ self.upsamplers = None
1417
+
1418
+ def forward(self, hidden_states):
1419
+ for resnet in self.resnets:
1420
+ hidden_states, _ = resnet(hidden_states, temb=None)
1421
+
1422
+ if self.upsamplers is not None:
1423
+ for upsampler in self.upsamplers:
1424
+ hidden_states = upsampler(hidden_states)
1425
+
1426
+ return hidden_states
1427
+
1428
+
1429
+ class AttnUpDecoderBlock2D(nn.Module):
1430
+ def __init__(
1431
+ self,
1432
+ in_channels: int,
1433
+ out_channels: int,
1434
+ dropout: float = 0.0,
1435
+ num_layers: int = 1,
1436
+ resnet_eps: float = 1e-6,
1437
+ resnet_time_scale_shift: str = "default",
1438
+ resnet_act_fn: str = "swish",
1439
+ resnet_groups: int = 32,
1440
+ resnet_pre_norm: bool = True,
1441
+ attn_num_head_channels=1,
1442
+ output_scale_factor=1.0,
1443
+ add_upsample=True,
1444
+ ):
1445
+ super().__init__()
1446
+ resnets = []
1447
+ attentions = []
1448
+
1449
+ for i in range(num_layers):
1450
+ input_channels = in_channels if i == 0 else out_channels
1451
+
1452
+ resnets.append(
1453
+ ResnetBlock2D(
1454
+ in_channels=input_channels,
1455
+ out_channels=out_channels,
1456
+ temb_channels=None,
1457
+ eps=resnet_eps,
1458
+ groups=resnet_groups,
1459
+ dropout=dropout,
1460
+ time_embedding_norm=resnet_time_scale_shift,
1461
+ non_linearity=resnet_act_fn,
1462
+ output_scale_factor=output_scale_factor,
1463
+ pre_norm=resnet_pre_norm,
1464
+ )
1465
+ )
1466
+ attentions.append(
1467
+ AttentionBlock(
1468
+ out_channels,
1469
+ num_head_channels=attn_num_head_channels,
1470
+ rescale_output_factor=output_scale_factor,
1471
+ eps=resnet_eps,
1472
+ norm_num_groups=resnet_groups,
1473
+ )
1474
+ )
1475
+
1476
+ self.attentions = nn.ModuleList(attentions)
1477
+ self.resnets = nn.ModuleList(resnets)
1478
+
1479
+ if add_upsample:
1480
+ self.upsamplers = nn.ModuleList(
1481
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1482
+ else:
1483
+ self.upsamplers = None
1484
+
1485
+ def forward(self, hidden_states):
1486
+ for resnet, attn in zip(self.resnets, self.attentions):
1487
+ hidden_states, _ = resnet(hidden_states, temb=None)
1488
+ hidden_states = attn(hidden_states)
1489
+
1490
+ if self.upsamplers is not None:
1491
+ for upsampler in self.upsamplers:
1492
+ hidden_states = upsampler(hidden_states)
1493
+
1494
+ return hidden_states
1495
+
1496
+
1497
+ class AttnSkipUpBlock2D(nn.Module):
1498
+ def __init__(
1499
+ self,
1500
+ in_channels: int,
1501
+ prev_output_channel: int,
1502
+ out_channels: int,
1503
+ temb_channels: int,
1504
+ dropout: float = 0.0,
1505
+ num_layers: int = 1,
1506
+ resnet_eps: float = 1e-6,
1507
+ resnet_time_scale_shift: str = "default",
1508
+ resnet_act_fn: str = "swish",
1509
+ resnet_pre_norm: bool = True,
1510
+ attn_num_head_channels=1,
1511
+ attention_type="default",
1512
+ output_scale_factor=np.sqrt(2.0),
1513
+ upsample_padding=1,
1514
+ add_upsample=True,
1515
+ ):
1516
+ super().__init__()
1517
+ self.attentions = nn.ModuleList([])
1518
+ self.resnets = nn.ModuleList([])
1519
+
1520
+ self.attention_type = attention_type
1521
+
1522
+ for i in range(num_layers):
1523
+ res_skip_channels = in_channels if (
1524
+ i == num_layers - 1) else out_channels
1525
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1526
+
1527
+ self.resnets.append(
1528
+ ResnetBlock2D(
1529
+ in_channels=resnet_in_channels + res_skip_channels,
1530
+ out_channels=out_channels,
1531
+ temb_channels=temb_channels,
1532
+ eps=resnet_eps,
1533
+ groups=min((resnet_in_channels +
1534
+ res_skip_channels) // 4, 32),
1535
+ groups_out=min(out_channels // 4, 32),
1536
+ dropout=dropout,
1537
+ time_embedding_norm=resnet_time_scale_shift,
1538
+ non_linearity=resnet_act_fn,
1539
+ output_scale_factor=output_scale_factor,
1540
+ pre_norm=resnet_pre_norm,
1541
+ )
1542
+ )
1543
+
1544
+ self.attentions.append(
1545
+ AttentionBlock(
1546
+ out_channels,
1547
+ num_head_channels=attn_num_head_channels,
1548
+ rescale_output_factor=output_scale_factor,
1549
+ eps=resnet_eps,
1550
+ )
1551
+ )
1552
+
1553
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1554
+ if add_upsample:
1555
+ self.resnet_up = ResnetBlock2D(
1556
+ in_channels=out_channels,
1557
+ out_channels=out_channels,
1558
+ temb_channels=temb_channels,
1559
+ eps=resnet_eps,
1560
+ groups=min(out_channels // 4, 32),
1561
+ groups_out=min(out_channels // 4, 32),
1562
+ dropout=dropout,
1563
+ time_embedding_norm=resnet_time_scale_shift,
1564
+ non_linearity=resnet_act_fn,
1565
+ output_scale_factor=output_scale_factor,
1566
+ pre_norm=resnet_pre_norm,
1567
+ use_in_shortcut=True,
1568
+ up=True,
1569
+ kernel="fir",
1570
+ )
1571
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(
1572
+ 3, 3), stride=(1, 1), padding=(1, 1))
1573
+ self.skip_norm = torch.nn.GroupNorm(
1574
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1575
+ )
1576
+ self.act = nn.SiLU()
1577
+ else:
1578
+ self.resnet_up = None
1579
+ self.skip_conv = None
1580
+ self.skip_norm = None
1581
+ self.act = None
1582
+
1583
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1584
+ for resnet in self.resnets:
1585
+ # pop res hidden states
1586
+ res_hidden_states = res_hidden_states_tuple[-1]
1587
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1588
+ hidden_states = torch.cat(
1589
+ [hidden_states, res_hidden_states], dim=1)
1590
+
1591
+ hidden_states, _ = resnet(hidden_states, temb)
1592
+
1593
+ hidden_states = self.attentions[0](hidden_states)
1594
+
1595
+ if skip_sample is not None:
1596
+ skip_sample = self.upsampler(skip_sample)
1597
+ else:
1598
+ skip_sample = 0
1599
+
1600
+ if self.resnet_up is not None:
1601
+ skip_sample_states = self.skip_norm(hidden_states)
1602
+ skip_sample_states = self.act(skip_sample_states)
1603
+ skip_sample_states = self.skip_conv(skip_sample_states)
1604
+
1605
+ skip_sample = skip_sample + skip_sample_states
1606
+
1607
+ hidden_states = self.resnet_up(hidden_states, temb)
1608
+
1609
+ return hidden_states, skip_sample
1610
+
1611
+
1612
+ class SkipUpBlock2D(nn.Module):
1613
+ def __init__(
1614
+ self,
1615
+ in_channels: int,
1616
+ prev_output_channel: int,
1617
+ out_channels: int,
1618
+ temb_channels: int,
1619
+ dropout: float = 0.0,
1620
+ num_layers: int = 1,
1621
+ resnet_eps: float = 1e-6,
1622
+ resnet_time_scale_shift: str = "default",
1623
+ resnet_act_fn: str = "swish",
1624
+ resnet_pre_norm: bool = True,
1625
+ output_scale_factor=np.sqrt(2.0),
1626
+ add_upsample=True,
1627
+ upsample_padding=1,
1628
+ ):
1629
+ super().__init__()
1630
+ self.resnets = nn.ModuleList([])
1631
+
1632
+ for i in range(num_layers):
1633
+ res_skip_channels = in_channels if (
1634
+ i == num_layers - 1) else out_channels
1635
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1636
+
1637
+ self.resnets.append(
1638
+ ResnetBlock2D(
1639
+ in_channels=resnet_in_channels + res_skip_channels,
1640
+ out_channels=out_channels,
1641
+ temb_channels=temb_channels,
1642
+ eps=resnet_eps,
1643
+ groups=min(
1644
+ (resnet_in_channels + res_skip_channels) // 4, 32),
1645
+ groups_out=min(out_channels // 4, 32),
1646
+ dropout=dropout,
1647
+ time_embedding_norm=resnet_time_scale_shift,
1648
+ non_linearity=resnet_act_fn,
1649
+ output_scale_factor=output_scale_factor,
1650
+ pre_norm=resnet_pre_norm,
1651
+ )
1652
+ )
1653
+
1654
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1655
+ if add_upsample:
1656
+ self.resnet_up = ResnetBlock2D(
1657
+ in_channels=out_channels,
1658
+ out_channels=out_channels,
1659
+ temb_channels=temb_channels,
1660
+ eps=resnet_eps,
1661
+ groups=min(out_channels // 4, 32),
1662
+ groups_out=min(out_channels // 4, 32),
1663
+ dropout=dropout,
1664
+ time_embedding_norm=resnet_time_scale_shift,
1665
+ non_linearity=resnet_act_fn,
1666
+ output_scale_factor=output_scale_factor,
1667
+ pre_norm=resnet_pre_norm,
1668
+ use_in_shortcut=True,
1669
+ up=True,
1670
+ kernel="fir",
1671
+ )
1672
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(
1673
+ 3, 3), stride=(1, 1), padding=(1, 1))
1674
+ self.skip_norm = torch.nn.GroupNorm(
1675
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1676
+ )
1677
+ self.act = nn.SiLU()
1678
+ else:
1679
+ self.resnet_up = None
1680
+ self.skip_conv = None
1681
+ self.skip_norm = None
1682
+ self.act = None
1683
+
1684
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1685
+ for resnet in self.resnets:
1686
+ # pop res hidden states
1687
+ res_hidden_states = res_hidden_states_tuple[-1]
1688
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1689
+ hidden_states = torch.cat(
1690
+ [hidden_states, res_hidden_states], dim=1)
1691
+
1692
+ hidden_states, _ = resnet(hidden_states, temb)
1693
+
1694
+ if skip_sample is not None:
1695
+ skip_sample = self.upsampler(skip_sample)
1696
+ else:
1697
+ skip_sample = 0
1698
+
1699
+ if self.resnet_up is not None:
1700
+ skip_sample_states = self.skip_norm(hidden_states)
1701
+ skip_sample_states = self.act(skip_sample_states)
1702
+ skip_sample_states = self.skip_conv(skip_sample_states)
1703
+
1704
+ skip_sample = skip_sample + skip_sample_states
1705
+
1706
+ hidden_states = self.resnet_up(hidden_states, temb)
1707
+
1708
+ return hidden_states, skip_sample
1709
+
1710
+
1711
+ class ResnetBlock2D(nn.Module):
1712
+ def __init__(
1713
+ self,
1714
+ *,
1715
+ in_channels,
1716
+ out_channels=None,
1717
+ conv_shortcut=False,
1718
+ dropout=0.0,
1719
+ temb_channels=512,
1720
+ groups=32,
1721
+ groups_out=None,
1722
+ pre_norm=True,
1723
+ eps=1e-6,
1724
+ non_linearity="swish",
1725
+ time_embedding_norm="default",
1726
+ kernel=None,
1727
+ output_scale_factor=1.0,
1728
+ use_in_shortcut=None,
1729
+ up=False,
1730
+ down=False,
1731
+ ):
1732
+ super().__init__()
1733
+ self.pre_norm = pre_norm
1734
+ self.pre_norm = True
1735
+ self.in_channels = in_channels
1736
+ out_channels = in_channels if out_channels is None else out_channels
1737
+ self.out_channels = out_channels
1738
+ self.use_conv_shortcut = conv_shortcut
1739
+ self.time_embedding_norm = time_embedding_norm
1740
+ self.up = up
1741
+ self.down = down
1742
+ self.output_scale_factor = output_scale_factor
1743
+
1744
+ if groups_out is None:
1745
+ groups_out = groups
1746
+
1747
+ self.norm1 = torch.nn.GroupNorm(
1748
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
1749
+
1750
+ self.conv1 = torch.nn.Conv2d(
1751
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1)
1752
+
1753
+ if temb_channels is not None:
1754
+ if self.time_embedding_norm == "default":
1755
+ time_emb_proj_out_channels = out_channels
1756
+ elif self.time_embedding_norm == "scale_shift":
1757
+ time_emb_proj_out_channels = out_channels * 2
1758
+ else:
1759
+ raise ValueError(
1760
+ f"unknown time_embedding_norm : {self.time_embedding_norm} ")
1761
+
1762
+ self.time_emb_proj = torch.nn.Linear(
1763
+ temb_channels, time_emb_proj_out_channels)
1764
+ else:
1765
+ self.time_emb_proj = None
1766
+
1767
+ self.norm2 = torch.nn.GroupNorm(
1768
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
1769
+ self.dropout = torch.nn.Dropout(dropout)
1770
+ self.conv2 = torch.nn.Conv2d(
1771
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1)
1772
+
1773
+ if non_linearity == "swish":
1774
+ self.nonlinearity = lambda x: F.silu(x)
1775
+ elif non_linearity == "mish":
1776
+ self.nonlinearity = Mish()
1777
+ elif non_linearity == "silu":
1778
+ self.nonlinearity = nn.SiLU()
1779
+
1780
+ self.upsample = self.downsample = None
1781
+ if self.up:
1782
+ if kernel == "fir":
1783
+ fir_kernel = (1, 3, 3, 1)
1784
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
1785
+ elif kernel == "sde_vp":
1786
+ self.upsample = partial(
1787
+ F.interpolate, scale_factor=2.0, mode="nearest")
1788
+ else:
1789
+ self.upsample = Upsample2D(in_channels, use_conv=False)
1790
+ elif self.down:
1791
+ if kernel == "fir":
1792
+ fir_kernel = (1, 3, 3, 1)
1793
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
1794
+ elif kernel == "sde_vp":
1795
+ self.downsample = partial(
1796
+ F.avg_pool2d, kernel_size=2, stride=2)
1797
+ else:
1798
+ self.downsample = Downsample2D(
1799
+ in_channels, use_conv=False, padding=1, name="op")
1800
+
1801
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
1802
+
1803
+ self.conv_shortcut = None
1804
+ if self.use_in_shortcut:
1805
+ self.conv_shortcut = torch.nn.Conv2d(
1806
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0)
1807
+
1808
+ def forward(self, input_tensor, temb, inject_states=None):
1809
+ hidden_states = input_tensor
1810
+
1811
+ hidden_states = self.norm1(hidden_states)
1812
+ hidden_states = self.nonlinearity(hidden_states)
1813
+
1814
+ if self.upsample is not None:
1815
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
1816
+ if hidden_states.shape[0] >= 64:
1817
+ input_tensor = input_tensor.contiguous()
1818
+ hidden_states = hidden_states.contiguous()
1819
+ input_tensor = self.upsample(input_tensor)
1820
+ hidden_states = self.upsample(hidden_states)
1821
+ elif self.downsample is not None:
1822
+ input_tensor = self.downsample(input_tensor)
1823
+ hidden_states = self.downsample(hidden_states)
1824
+
1825
+ hidden_states = self.conv1(hidden_states)
1826
+
1827
+ if temb is not None:
1828
+ temb = self.time_emb_proj(self.nonlinearity(temb))[
1829
+ :, :, None, None]
1830
+
1831
+ if temb is not None and self.time_embedding_norm == "default":
1832
+ hidden_states = hidden_states + temb
1833
+
1834
+ hidden_states = self.norm2(hidden_states)
1835
+
1836
+ if temb is not None and self.time_embedding_norm == "scale_shift":
1837
+ scale, shift = torch.chunk(temb, 2, dim=1)
1838
+ hidden_states = hidden_states * (1 + scale) + shift
1839
+
1840
+ hidden_states = self.nonlinearity(hidden_states)
1841
+
1842
+ hidden_states = self.dropout(hidden_states)
1843
+ hidden_states = self.conv2(hidden_states)
1844
+
1845
+ if self.conv_shortcut is not None:
1846
+ input_tensor = self.conv_shortcut(input_tensor)
1847
+
1848
+ if inject_states is not None:
1849
+ output_tensor = (input_tensor + inject_states) / \
1850
+ self.output_scale_factor
1851
+ else:
1852
+ output_tensor = (input_tensor + hidden_states) / \
1853
+ self.output_scale_factor
1854
+
1855
+ return output_tensor, hidden_states
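
Note on the calling convention used throughout the blocks above: this repo's ResnetBlock2D returns an (output, pre-shortcut hidden_states) pair and accepts an optional inject_states tensor, unlike the stock diffusers block, which is why every block unpacks hidden_states, _ = resnet(...). A minimal sketch of that interface, assuming the repo's models package is importable; channel counts and shapes below are illustrative only:

import torch
from models.unet_2d_blocks import ResnetBlock2D

block = ResnetBlock2D(in_channels=64, out_channels=64, temb_channels=128)
x = torch.randn(2, 64, 32, 32)      # (batch, channels, height, width)
temb = torch.randn(2, 128)          # time embedding consumed by time_emb_proj

out, hidden = block(x, temb)        # both tensors are (2, 64, 32, 32)

# inject_states replaces the residual branch with externally supplied features
# while keeping the same shortcut-plus-scale arithmetic.
out_injected, _ = block(x, temb, inject_states=torch.zeros_like(hidden))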
models/unet_2d_condition.py ADDED
@@ -0,0 +1,411 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.modeling_utils import ModelMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
25
+ from .unet_2d_blocks import (
26
+ CrossAttnDownBlock2D,
27
+ CrossAttnUpBlock2D,
28
+ DownBlock2D,
29
+ UNetMidBlock2DCrossAttn,
30
+ UpBlock2D,
31
+ get_down_block,
32
+ get_up_block,
33
+ )
34
+
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+
39
+ @dataclass
40
+ class UNet2DConditionOutput(BaseOutput):
41
+ """
42
+ Args:
43
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
44
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
45
+ """
46
+
47
+ sample: torch.FloatTensor
48
+
49
+
50
+ class UNet2DConditionModel(ModelMixin, ConfigMixin):
51
+ r"""
52
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
53
+ and returns a sample-shaped output.
54
+
55
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
56
+ implements for all the models (such as downloading or saving, etc.)
57
+
58
+ Parameters:
59
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
60
+ Height and width of input/output sample.
61
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
62
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
63
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
64
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
65
+ Whether to flip the sin to cos in the time embedding.
66
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
67
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
68
+ The tuple of downsample blocks to use.
69
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
70
+ The tuple of upsample blocks to use.
71
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
72
+ The tuple of output channels for each block.
73
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
74
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
75
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
76
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
77
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
78
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
79
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
80
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
81
+ """
82
+
83
+ _supports_gradient_checkpointing = True
84
+
85
+ @register_to_config
86
+ def __init__(
87
+ self,
88
+ sample_size: Optional[int] = None,
89
+ in_channels: int = 4,
90
+ out_channels: int = 4,
91
+ center_input_sample: bool = False,
92
+ flip_sin_to_cos: bool = True,
93
+ freq_shift: int = 0,
94
+ down_block_types: Tuple[str] = (
95
+ "CrossAttnDownBlock2D",
96
+ "CrossAttnDownBlock2D",
97
+ "CrossAttnDownBlock2D",
98
+ "DownBlock2D",
99
+ ),
100
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
101
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
102
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
103
+ layers_per_block: int = 2,
104
+ downsample_padding: int = 1,
105
+ mid_block_scale_factor: float = 1,
106
+ act_fn: str = "silu",
107
+ norm_num_groups: int = 32,
108
+ norm_eps: float = 1e-5,
109
+ cross_attention_dim: int = 1280,
110
+ attention_head_dim: Union[int, Tuple[int]] = 8,
111
+ dual_cross_attention: bool = False,
112
+ use_linear_projection: bool = False,
113
+ num_class_embeds: Optional[int] = None,
114
+ ):
115
+ super().__init__()
116
+
117
+ self.sample_size = sample_size
118
+ time_embed_dim = block_out_channels[0] * 4
119
+ # import ipdb;ipdb.set_trace()
120
+
121
+ # input
122
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
123
+
124
+ # time
125
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
126
+ timestep_input_dim = block_out_channels[0]
127
+
128
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
129
+
130
+ # class embedding
131
+ if num_class_embeds is not None:
132
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
133
+
134
+ self.down_blocks = nn.ModuleList([])
135
+ self.mid_block = None
136
+ self.up_blocks = nn.ModuleList([])
137
+
138
+ if isinstance(only_cross_attention, bool):
139
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
140
+
141
+ if isinstance(attention_head_dim, int):
142
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
143
+
144
+ # down
145
+ output_channel = block_out_channels[0]
146
+ for i, down_block_type in enumerate(down_block_types):
147
+ input_channel = output_channel
148
+ output_channel = block_out_channels[i]
149
+ is_final_block = i == len(block_out_channels) - 1
150
+
151
+ down_block = get_down_block(
152
+ down_block_type,
153
+ num_layers=layers_per_block,
154
+ in_channels=input_channel,
155
+ out_channels=output_channel,
156
+ temb_channels=time_embed_dim,
157
+ add_downsample=not is_final_block,
158
+ resnet_eps=norm_eps,
159
+ resnet_act_fn=act_fn,
160
+ resnet_groups=norm_num_groups,
161
+ cross_attention_dim=cross_attention_dim,
162
+ attn_num_head_channels=attention_head_dim[i],
163
+ downsample_padding=downsample_padding,
164
+ dual_cross_attention=dual_cross_attention,
165
+ use_linear_projection=use_linear_projection,
166
+ only_cross_attention=only_cross_attention[i],
167
+ )
168
+ self.down_blocks.append(down_block)
169
+
170
+ # mid
171
+ self.mid_block = UNetMidBlock2DCrossAttn(
172
+ in_channels=block_out_channels[-1],
173
+ temb_channels=time_embed_dim,
174
+ resnet_eps=norm_eps,
175
+ resnet_act_fn=act_fn,
176
+ output_scale_factor=mid_block_scale_factor,
177
+ resnet_time_scale_shift="default",
178
+ cross_attention_dim=cross_attention_dim,
179
+ attn_num_head_channels=attention_head_dim[-1],
180
+ resnet_groups=norm_num_groups,
181
+ dual_cross_attention=dual_cross_attention,
182
+ use_linear_projection=use_linear_projection,
183
+ )
184
+
185
+ # count how many layers upsample the images
186
+ self.num_upsamplers = 0
187
+
188
+ # up
189
+ reversed_block_out_channels = list(reversed(block_out_channels))
190
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
191
+ only_cross_attention = list(reversed(only_cross_attention))
192
+ output_channel = reversed_block_out_channels[0]
193
+ for i, up_block_type in enumerate(up_block_types):
194
+ is_final_block = i == len(block_out_channels) - 1
195
+
196
+ prev_output_channel = output_channel
197
+ output_channel = reversed_block_out_channels[i]
198
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
199
+
200
+ # add upsample block for all BUT final layer
201
+ if not is_final_block:
202
+ add_upsample = True
203
+ self.num_upsamplers += 1
204
+ else:
205
+ add_upsample = False
206
+
207
+ up_block = get_up_block(
208
+ up_block_type,
209
+ num_layers=layers_per_block + 1,
210
+ in_channels=input_channel,
211
+ out_channels=output_channel,
212
+ prev_output_channel=prev_output_channel,
213
+ temb_channels=time_embed_dim,
214
+ add_upsample=add_upsample,
215
+ resnet_eps=norm_eps,
216
+ resnet_act_fn=act_fn,
217
+ resnet_groups=norm_num_groups,
218
+ cross_attention_dim=cross_attention_dim,
219
+ attn_num_head_channels=reversed_attention_head_dim[i],
220
+ dual_cross_attention=dual_cross_attention,
221
+ use_linear_projection=use_linear_projection,
222
+ only_cross_attention=only_cross_attention[i],
223
+ )
224
+ self.up_blocks.append(up_block)
225
+ prev_output_channel = output_channel
226
+
227
+ # out
228
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
229
+ self.conv_act = nn.SiLU()
230
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
231
+
232
+ def set_attention_slice(self, slice_size):
233
+ head_dims = self.config.attention_head_dim
234
+ head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
235
+ if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
236
+ raise ValueError(
237
+ f"Make sure slice_size {slice_size} is a common divisor of "
238
+ f"the number of heads used in cross_attention: {head_dims}"
239
+ )
240
+ if slice_size is not None and slice_size > min(head_dims):
241
+ raise ValueError(
242
+ f"slice_size {slice_size} has to be smaller or equal to "
243
+ f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
244
+ )
245
+
246
+ for block in self.down_blocks:
247
+ if hasattr(block, "attentions") and block.attentions is not None:
248
+ block.set_attention_slice(slice_size)
249
+
250
+ self.mid_block.set_attention_slice(slice_size)
251
+
252
+ for block in self.up_blocks:
253
+ if hasattr(block, "attentions") and block.attentions is not None:
254
+ block.set_attention_slice(slice_size)
255
+
256
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
257
+ for block in self.down_blocks:
258
+ if hasattr(block, "attentions") and block.attentions is not None:
259
+ block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
260
+
261
+ self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
262
+
263
+ for block in self.up_blocks:
264
+ if hasattr(block, "attentions") and block.attentions is not None:
265
+ block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
266
+
267
+ def _set_gradient_checkpointing(self, module, value=False):
268
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
269
+ module.gradient_checkpointing = value
270
+
271
+ def forward(
272
+ self,
273
+ sample: torch.FloatTensor,
274
+ timestep: Union[torch.Tensor, float, int],
275
+ encoder_hidden_states: torch.Tensor,
276
+ class_labels: Optional[torch.Tensor] = None,
277
+ text_format_dict = {},
278
+ return_dict: bool = True,
279
+ ) -> Union[UNet2DConditionOutput, Tuple]:
280
+ r"""
281
+ Args:
282
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
283
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
284
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, cross_attention_dim) encoder hidden states
285
+ return_dict (`bool`, *optional*, defaults to `True`):
286
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
287
+
288
+ Returns:
289
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
290
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
291
+ returning a tuple, the first element is the sample tensor.
292
+ """
293
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
294
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
295
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
296
+ # on the fly if necessary.
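+ # (Illustration: the default 4-block config has 3 upsampling blocks, so the
+ # factor is 2**3 = 8; 64x64 latents need no special handling, while e.g.
+ # 66x66 would force forward_upsample_size below.)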
297
+ default_overall_up_factor = 2**self.num_upsamplers
298
+
299
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
300
+ forward_upsample_size = False
301
+ upsample_size = None
302
+
303
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
304
+ logger.info("Forward upsample size to force interpolation output size.")
305
+ forward_upsample_size = True
306
+
307
+ # 0. center input if necessary
308
+ if self.config.center_input_sample:
309
+ sample = 2 * sample - 1.0
310
+
311
+ # 1. time
312
+ timesteps = timestep
313
+ if not torch.is_tensor(timesteps):
314
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
315
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
316
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
317
+ timesteps = timesteps[None].to(sample.device)
318
+
319
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
320
+ timesteps = timesteps.expand(sample.shape[0])
321
+
322
+ t_emb = self.time_proj(timesteps)
323
+
324
+ # timesteps does not contain any weights and will always return f32 tensors
325
+ # but time_embedding might actually be running in fp16. so we need to cast here.
326
+ # there might be better ways to encapsulate this.
327
+ t_emb = t_emb.to(dtype=self.dtype)
328
+ emb = self.time_embedding(t_emb)
329
+
330
+ if self.config.num_class_embeds is not None:
331
+ if class_labels is None:
332
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
333
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
334
+ emb = emb + class_emb
335
+
336
+ # 2. pre-process
337
+ sample = self.conv_in(sample)
338
+
339
+ # 3. down
340
+ down_block_res_samples = (sample,)
341
+ for downsample_block in self.down_blocks:
342
+ if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
343
+ if isinstance(downsample_block, CrossAttnDownBlock2D):
344
+ sample, res_samples = downsample_block(
345
+ hidden_states=sample,
346
+ temb=emb,
347
+ encoder_hidden_states=encoder_hidden_states,
348
+ text_format_dict=text_format_dict
349
+ )
350
+ else:
351
+ sample, res_samples = downsample_block(
352
+ hidden_states=sample,
353
+ temb=emb,
354
+ encoder_hidden_states=encoder_hidden_states,
355
+ )
356
+ else:
357
+ if isinstance(downsample_block, CrossAttnDownBlock2D):
358
+ raise RuntimeError("CrossAttnDownBlock2D is expected to expose attentions")
359
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
360
+ down_block_res_samples += res_samples
361
+
362
+ # 4. mid
363
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states,
364
+ text_format_dict=text_format_dict)
365
+
366
+ # 5. up
367
+ for i, upsample_block in enumerate(self.up_blocks):
368
+ is_final_block = i == len(self.up_blocks) - 1
369
+
370
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
371
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
372
+
373
+ # if we have not reached the final block and need to forward the
374
+ # upsample size, we do it here
375
+ if not is_final_block and forward_upsample_size:
376
+ upsample_size = down_block_res_samples[-1].shape[2:]
377
+
378
+ if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
379
+ if isinstance(upsample_block, CrossAttnUpBlock2D):
380
+ sample = upsample_block(
381
+ hidden_states=sample,
382
+ temb=emb,
383
+ res_hidden_states_tuple=res_samples,
384
+ encoder_hidden_states=encoder_hidden_states,
385
+ upsample_size=upsample_size,
386
+ text_format_dict=text_format_dict
387
+ )
388
+ else:
389
+ sample = upsample_block(
390
+ hidden_states=sample,
391
+ temb=emb,
392
+ res_hidden_states_tuple=res_samples,
393
+ encoder_hidden_states=encoder_hidden_states,
394
+ upsample_size=upsample_size,
395
+ )
396
+ else:
397
+ if isinstance(upsample_block, CrossAttnUpBlock2D):
398
+ raise RuntimeError(
399
+ "CrossAttnUpBlock2D is expected to expose attentions")
400
+ sample = upsample_block(
401
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
402
+ )
403
+ # 6. post-process
404
+ sample = self.conv_norm_out(sample)
405
+ sample = self.conv_act(sample)
406
+ sample = self.conv_out(sample)
407
+
408
+ if not return_dict:
409
+ return (sample,)
410
+
411
+ return UNet2DConditionOutput(sample=sample)
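
Relative to the stock diffusers UNet2DConditionModel, the only interface change above is the extra text_format_dict argument, which is threaded through the cross-attention down, mid, and up blocks (an empty dict is the declared default). A minimal end-to-end sketch with a deliberately tiny, made-up configuration; the shapes and channel counts are illustrative, and it assumes the repo's attention modules handle the empty-dict default:

import torch
from models.unet_2d_condition import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=64,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    norm_num_groups=16,
    cross_attention_dim=32,
    attention_head_dim=8,
)

latents = torch.randn(1, 4, 64, 64)        # noisy latent sample
timestep = torch.tensor([10])
text_emb = torch.randn(1, 77, 32)          # (batch, tokens, cross_attention_dim)

out = unet(latents, timestep,
           encoder_hidden_states=text_emb,
           text_format_dict={}).sample     # {} is the declared default; rich-text attributes would go here
print(out.shape)                           # torch.Size([1, 4, 64, 64])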
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu117
2
+ torch==1.13.1
3
+ torchvision==0.14.1
4
+ diffusers==0.12.1
5
+ transformers==4.26.0
6
+ numpy==1.24.2
7
+ seaborn==0.12.2
8
+ accelerate==0.16.0
9
+ scikit-learn==0.24.1
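
The pinned stack above targets CUDA 11.7 wheels via the extra PyTorch index. An optional sanity check that an installed environment matches the pins (version strings copied from the file above; torch's wheel tag carries the CUDA suffix, so it is only printed):

import diffusers, torch, transformers

assert diffusers.__version__ == "0.12.1"
assert transformers.__version__ == "4.26.0"
print(torch.__version__)   # expected 1.13.1+cu117 with the extra index above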
rich-text-to-json-iframe.html ADDED
@@ -0,0 +1,341 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <title>Rich Text to JSON</title>
6
+ <link rel="stylesheet" href="https://cdn.quilljs.com/1.3.6/quill.snow.css">
7
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
8
+ <link rel="stylesheet" type="text/css"
9
+ href="https://cdnjs.cloudflare.com/ajax/libs/spectrum/1.8.0/spectrum.min.css">
10
+ <link rel="stylesheet"
11
+ href='https://fonts.googleapis.com/css?family=Mirza|Roboto|Slabo+27px|Sofia|Inconsolata|Ubuntu|Akronim|Monoton&display=swap'>
12
+ <style>
13
+ html,
14
+ body {
15
+ background-color: white;
16
+ margin: 0;
17
+ }
18
+
19
+ /* Set default font-family */
20
+ .ql-snow .ql-tooltip::before {
21
+ content: "Footnote";
22
+ line-height: 26px;
23
+ margin-right: 8px;
24
+ }
25
+
26
+ .ql-snow .ql-tooltip[data-mode=link]::before {
27
+ content: "Enter footnote:";
28
+ }
29
+
30
+ .row {
31
+ margin-top: 15px;
32
+ margin-left: 0px;
33
+ margin-bottom: 15px;
34
+ }
35
+
36
+ .btn-primary {
37
+ color: #ffffff;
38
+ background-color: #2780e3;
39
+ border-color: #2780e3;
40
+ }
41
+
42
+ .btn-primary:hover {
43
+ color: #ffffff;
44
+ background-color: #1967be;
45
+ border-color: #1862b5;
46
+ }
47
+
48
+ .btn {
49
+ display: inline-block;
50
+ margin-bottom: 0;
51
+ font-weight: normal;
52
+ text-align: center;
53
+ vertical-align: middle;
54
+ touch-action: manipulation;
55
+ cursor: pointer;
56
+ background-image: none;
57
+ border: 1px solid transparent;
58
+ white-space: nowrap;
59
+ padding: 10px 18px;
60
+ font-size: 15px;
61
+ line-height: 1.42857143;
62
+ border-radius: 0;
63
+ user-select: none;
64
+ }
65
+
66
+ #standalone-container {
67
+ width: 100%;
68
+ background-color: #ffffff;
69
+ }
70
+
71
+ #editor-container {
72
+ font-family: "Aref Ruqaa";
73
+ font-size: 18px;
74
+ height: 250px;
75
+ width: 100%;
76
+ }
77
+
78
+ #toolbar-container {
79
+ font-family: "Aref Ruqaa";
80
+ display: flex;
81
+ flex-wrap: wrap;
82
+ }
83
+
84
+ #json-container {
85
+ max-width: 720px;
86
+ }
87
+
88
+ /* Set dropdown font-families */
89
+ #toolbar-container .ql-font span[data-label="Base"]::before {
90
+ font-family: "Aref Ruqaa";
91
+ }
92
+
93
+ #toolbar-container .ql-font span[data-label="Claude Monet"]::before {
94
+ font-family: "Mirza";
95
+ }
96
+
97
+ #toolbar-container .ql-font span[data-label="Ukiyoe"]::before {
98
+ font-family: "Roboto";
99
+ }
100
+
101
+ #toolbar-container .ql-font span[data-label="Cyber Punk"]::before {
102
+ font-family: "Comic Sans MS";
103
+ }
104
+
105
+ #toolbar-container .ql-font span[data-label="Pop Art"]::before {
106
+ font-family: "sofia";
107
+ }
108
+
109
+ #toolbar-container .ql-font span[data-label="Van Gogh"]::before {
110
+ font-family: "slabo 27px";
111
+ }
112
+
113
+ #toolbar-container .ql-font span[data-label="Pixel Art"]::before {
114
+ font-family: "inconsolata";
115
+ }
116
+
117
+ #toolbar-container .ql-font span[data-label="Rembrandt"]::before {
118
+ font-family: "ubuntu";
119
+ }
120
+
121
+ #toolbar-container .ql-font span[data-label="Cubism"]::before {
122
+ font-family: "Akronim";
123
+ }
124
+
125
+ #toolbar-container .ql-font span[data-label="Neon Art"]::before {
126
+ font-family: "Monoton";
127
+ }
128
+
129
+ /* Set content font-families */
130
+ .ql-font-mirza {
131
+ font-family: "Mirza";
132
+ }
133
+
134
+ .ql-font-roboto {
135
+ font-family: "Roboto";
136
+ }
137
+
138
+ .ql-font-cursive {
139
+ font-family: "Comic Sans MS";
140
+ }
141
+
142
+ .ql-font-sofia {
143
+ font-family: "sofia";
144
+ }
145
+
146
+ .ql-font-slabo {
147
+ font-family: "slabo 27px";
148
+ }
149
+
150
+ .ql-font-inconsolata {
151
+ font-family: "inconsolata";
152
+ }
153
+
154
+ .ql-font-ubuntu {
155
+ font-family: "ubuntu";
156
+ }
157
+
158
+ .ql-font-Akronim {
159
+ font-family: "Akronim";
160
+ }
161
+
162
+ .ql-font-Monoton {
163
+ font-family: "Monoton";
164
+ }
165
+
166
+ .ql-color .ql-picker-options [data-value=Color-Picker] {
167
+ background: none !important;
168
+ width: 100% !important;
169
+ height: 20px !important;
170
+ text-align: center;
171
+ }
172
+
173
+ .ql-color .ql-picker-options [data-value=Color-Picker]:before {
174
+ content: 'Color Picker';
175
+ }
176
+
177
+ .ql-color .ql-picker-options [data-value=Color-Picker]:hover {
178
+ border-color: transparent !important;
179
+ }
180
+ </style>
181
+ </head>
182
+
183
+ <body>
184
+ <div id="standalone-container">
185
+ <div id="toolbar-container">
186
+ <span class="ql-formats">
187
+ <select class="ql-font">
188
+ <option selected>Base</option>
189
+ <option value="mirza">Claude Monet</option>
190
+ <option value="roboto">Ukiyoe</option>
191
+ <option value="cursive">Cyber Punk</option>
192
+ <option value="sofia">Pop Art</option>
193
+ <option value="slabo">Van Gogh</option>
194
+ <option value="inconsolata">Pixel Art</option>
195
+ <option value="ubuntu">Rembrandt</option>
196
+ <option value="Akronim">Cubism</option>
197
+ <option value="Monoton">Neon Art</option>
198
+ </select>
199
+ <select class="ql-size">
200
+ <option value="18px">Small</option>
201
+ <option selected>Normal</option>
202
+ <option value="32px">Large</option>
203
+ <option value="50px">Huge</option>
204
+ </select>
205
+ </span>
206
+ <span class="ql-formats">
207
+ <button class="ql-strike"></button>
208
+ </span>
209
+ <!-- <span class="ql-formats">
210
+ <button class="ql-bold"></button>
211
+ <button class="ql-italic"></button>
212
+ <button class="ql-underline"></button>
213
+ </span> -->
214
+ <span class="ql-formats">
215
+ <select class="ql-color">
216
+ <option value="Color-Picker"></option>
217
+ </select>
218
+ <!-- <select class="ql-background"></select> -->
219
+ </span>
220
+ <!-- <span class="ql-formats">
221
+ <button class="ql-script" value="sub"></button>
222
+ <button class="ql-script" value="super"></button>
223
+ </span>
224
+ <span class="ql-formats">
225
+ <button class="ql-header" value="1"></button>
226
+ <button class="ql-header" value="2"></button>
227
+ <button class="ql-blockquote"></button>
228
+ <button class="ql-code-block"></button>
229
+ </span>
230
+ <span class="ql-formats">
231
+ <button class="ql-list" value="ordered"></button>
232
+ <button class="ql-list" value="bullet"></button>
233
+ <button class="ql-indent" value="-1"></button>
234
+ <button class="ql-indent" value="+1"></button>
235
+ </span>
236
+ <span class="ql-formats">
237
+ <button class="ql-direction" value="rtl"></button>
238
+ <select class="ql-align"></select>
239
+ </span>
240
+ <span class="ql-formats">
241
+ <button class="ql-link"></button>
242
+ <button class="ql-image"></button>
243
+ <button class="ql-video"></button>
244
+ <button class="ql-formula"></button>
245
+ </span> -->
246
+ <span class="ql-formats">
247
+ <button class="ql-link"></button>
248
+ </span>
249
+ <span class="ql-formats">
250
+ <button class="ql-clean"></button>
251
+ </span>
252
+ </div>
253
+ <div id="editor-container" style="height:300px;"></div>
254
+ </div>
255
+ <script src="https://cdn.quilljs.com/1.3.6/quill.min.js"></script>
256
+ <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.1.0/jquery.min.js"></script>
257
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/spectrum/1.8.0/spectrum.min.js"></script>
258
+ <script>
259
+
260
+ // Register the custom formats with Quill
261
+ const Font = Quill.import('formats/font');
262
+ Font.whitelist = ['mirza', 'roboto', 'sofia', 'slabo', 'inconsolata', 'ubuntu', 'cursive', 'Akronim', 'Monoton'];
263
+ const Link = Quill.import('formats/link');
264
+ Link.sanitize = function (url) {
265
+ // modify url if desired
266
+ return url;
267
+ }
268
+ const SizeStyle = Quill.import('attributors/style/size');
269
+ SizeStyle.whitelist = ['10px', '18px', '20px', '32px', '50px', '60px', '64px', '70px'];
270
+ Quill.register(SizeStyle, true);
271
+ Quill.register(Link, true);
272
+ Quill.register(Font, true);
273
+ const icons = Quill.import('ui/icons');
274
+ icons['link'] = `<svg xmlns="http://www.w3.org/2000/svg" width="17" viewBox="0 0 512 512" xml:space="preserve"><path fill="#010101" d="M276.75 1c4.51 3.23 9.2 6.04 12.97 9.77 29.7 29.45 59.15 59.14 88.85 88.6 4.98 4.93 7.13 10.37 7.12 17.32-.1 125.8-.09 251.6-.01 377.4 0 7.94-1.96 14.46-9.62 18.57-121.41.34-242.77.34-364.76.05A288.3 288.3 0 0 1 1 502c0-163.02 0-326.04.34-489.62C3.84 6.53 8.04 3.38 13 1c23.35 0 46.7 0 70.82.3 2.07.43 3.38.68 4.69.68h127.98c18.44.01 36.41.04 54.39-.03 1.7 0 3.41-.62 5.12-.95h.75M33.03 122.5v359.05h320.22V129.18h-76.18c-14.22-.01-19.8-5.68-19.8-20.09V33.31H33.02v89.19m256.29-27.36c.72.66 1.44 1.9 2.17 1.9 12.73.12 25.46.08 37.55.08L289.3 57.45v37.7z"/><path fill="#020202" d="M513 375.53c-4.68 7.99-11.52 10.51-20.21 10.25-13.15-.4-26.32-.1-39.48-.1h-5.58c5.49 8.28 10.7 15.74 15.46 23.47 6.06 9.82 1.14 21.65-9.96 24.27-6.7 1.59-12.45-.64-16.23-6.15a2608.6 2608.6 0 0 1-32.97-49.36c-3.57-5.48-3.39-11.54.17-16.98a3122.5 3122.5 0 0 1 32.39-48.56c5.22-7.65 14.67-9.35 21.95-4.45 7.63 5.12 9.6 14.26 4.5 22.33-4.75 7.54-9.8 14.9-15.11 22.95h33.64V225.19h-5.24c-19.49 0-38.97.11-58.46-.05-12.74-.1-20.12-13.15-13.84-24.14 3.12-5.46 8.14-7.71 14.18-7.73 26.15-.06 52.3-.04 78.45 0 7.1 0 12.47 3.05 16.01 9.64.33 57.44.33 114.8.33 172.62z"/><path fill="#111" d="M216.03 1.97C173.52 1.98 131 2 88.5 1.98a16 16 0 0 1-4.22-.68c43.4-.3 87.09-.3 131.24-.06.48.25.5.73.5.73z"/><path fill="#232323" d="M216.5 1.98c-.47 0-.5-.5-.5-.74C235.7 1 255.38 1 275.53 1c-1.24.33-2.94.95-4.65.95-17.98.07-35.95.04-54.39.03z"/><path fill="#040404" d="M148 321.42h153.5c14.25 0 19.96 5.71 19.96 19.97.01 19.17.03 38.33 0 57.5-.03 12.6-6.16 18.78-18.66 18.78H99.81c-12.42 0-18.75-6.34-18.76-18.73-.01-19.83-.02-39.66 0-59.5.02-11.47 6.4-17.93 17.95-18 16.17-.08 32.33-.02 49-.02m40.5 32.15h-75.16v31.84h175.7v-31.84H188.5z"/><path fill="#030303" d="m110 225.33 178.89-.03c11.98 0 19.25 9.95 15.74 21.44-2.05 6.71-7.5 10.57-15.14 10.57-63.63 0-127.25-.01-190.88-.07-12.03-.02-19.17-8.62-16.7-19.84 1.6-7.21 7.17-11.74 15.1-12.04 4.17-.16 8.33-.03 13-.03zm-24.12-36.19c-5.28-6.2-6.3-12.76-2.85-19.73 3.22-6.49 9.13-8.24 15.86-8.24 25.64.01 51.27-.06 76.91.04 13.07.04 20.66 10.44 16.33 22.08-2.25 6.06-6.63 9.76-13.08 9.8-27.97.18-55.94.2-83.9-.07-3.01-.03-6-2.36-9.27-3.88z"/></svg>`
275
+ const quill = new Quill('#editor-container', {
276
+ modules: {
277
+ toolbar: {
278
+ container: '#toolbar-container',
279
+ },
280
+ },
281
+ theme: 'snow'
282
+ });
283
+ var toolbar = quill.getModule('toolbar');
284
+ $(toolbar.container).find('.ql-color').spectrum({
285
+ preferredFormat: "rgb",
286
+ showInput: true,
287
+ showInitial: true,
288
+ showPalette: true,
289
+ showSelectionPalette: true,
290
+ palette: [
291
+ ["#000", "#444", "#666", "#999", "#ccc", "#eee", "#f3f3f3", "#fff"],
292
+ ["#f00", "#f90", "#ff0", "#0f0", "#0ff", "#00f", "#90f", "#f0f"],
293
+ ["#ea9999", "#f9cb9c", "#ffe599", "#b6d7a8", "#a2c4c9", "#9fc5e8", "#b4a7d6", "#d5a6bd"],
294
+ ["#e06666", "#f6b26b", "#ffd966", "#93c47d", "#76a5af", "#6fa8dc", "#8e7cc3", "#c27ba0"],
295
+ ["#c00", "#e69138", "#f1c232", "#6aa84f", "#45818e", "#3d85c6", "#674ea7", "#a64d79"],
296
+ ["#900", "#b45f06", "#bf9000", "#38761d", "#134f5c", "#0b5394", "#351c75", "#741b47"],
297
+ ["#600", "#783f04", "#7f6000", "#274e13", "#0c343d", "#073763", "#20124d", "#4c1130"]
298
+ ],
299
+ change: function (color) {
300
+ var value = color.toHexString();
301
+ quill.format('color', value);
302
+ }
303
+ });
304
+
305
+ quill.on('text-change', () => {
306
+ // keep the Quill contents inside _data to communicate with Gradio
307
+ document.body._data = quill.getContents()
308
+ })
309
+ function setQuillContents(content) {
310
+ quill.setContents(content);
311
+ document.body._data = quill.getContents();
312
+ }
313
+ document.body.setQuillContents = setQuillContents
314
+ </script>
315
+ <script src="https://unpkg.com/@popperjs/core@2/dist/umd/popper.min.js"></script>
316
+ <script src="https://unpkg.com/tippy.js@6/dist/tippy-bundle.umd.js"></script>
317
+ <script>
318
+ // With the above scripts loaded, you can call `tippy()` with a CSS
319
+ // selector and a `content` prop:
320
+ tippy('.ql-font', {
321
+ content: 'Add a style to the token',
322
+ });
323
+ tippy('.ql-size', {
324
+ content: 'Reweight the token',
325
+ });
326
+ tippy('.ql-color', {
327
+ content: 'Pick a color for the token',
328
+ });
329
+ tippy('.ql-link', {
330
+ content: 'Clarify the token',
331
+ });
332
+ tippy('.ql-strike', {
333
+ content: 'Change the token weight to be negative',
334
+ });
335
+ tippy('.ql-clean', {
336
+ content: 'Remove all the formats',
337
+ });
338
+ </script>
339
+ </body>
340
+
341
+ </html>
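The standalone editor page above stores its state as a Quill Delta in `document.body._data` on every `text-change`, and that object is what the Gradio app ultimately parses. A sketch of such a Delta is shown below as a Python literal; the attribute names (`font`, `size`, `color`, `link`, `strike`) are Quill's standard format keys as registered above, while the prompt text and concrete values are made up for illustration.

# Illustrative only: roughly what quill.getContents() yields after using the
# style, reweight, color and footnote controls in the toolbar above.
example_delta = {
    "ops": [
        {"insert": "a "},
        {"insert": "church", "attributes": {"font": "slabo"}},            # "Van Gogh" style
        {"insert": " beside a "},
        {"insert": "lake", "attributes": {"link": "a calm lake reflecting the sky"}},  # footnote
        {"insert": " at "},
        {"insert": "sunset", "attributes": {"color": "#fd6c9e"}},          # color picker
        {"insert": ", detailed "},
        {"insert": "clouds", "attributes": {"size": "32px"}},              # reweighted token
        {"insert": "\n"},
    ]
}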
rich-text-to-json.js ADDED
@@ -0,0 +1,349 @@
1
+ class RichTextEditor extends HTMLElement {
2
+ constructor() {
3
+ super();
4
+ this.loadExternalScripts();
5
+ this.attachShadow({ mode: 'open' });
6
+ this.shadowRoot.innerHTML = `
7
+ ${RichTextEditor.header()}
8
+ ${RichTextEditor.template()}
9
+ `;
10
+ }
11
+ connectedCallback() {
12
+ this.myQuill = this.mountQuill();
13
+ }
14
+ loadExternalScripts() {
15
+ const links = ["https://cdn.quilljs.com/1.3.6/quill.snow.css", "https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css", "https://fonts.googleapis.com/css?family=Mirza|Roboto|Slabo+27px|Sofia|Inconsolata|Ubuntu|Akronim|Monoton&display=swap"]
16
+ links.forEach(link => {
17
+ const css = document.createElement("link");
18
+ css.href = link;
19
+ css.rel = "stylesheet"
20
+ document.head.appendChild(css);
21
+ })
22
+
23
+ }
24
+ static template() {
25
+ return `
26
+ <div id="standalone-container">
27
+ <div id="toolbar-container">
28
+ <span class="ql-formats">
29
+ <select class="ql-font">
30
+ <option selected>Base</option>
31
+ <option value="mirza">Claude Monet</option>
32
+ <option value="roboto">Ukiyoe</option>
33
+ <option value="cursive">Cyber Punk</option>
34
+ <option value="sofia">Pop Art</option>
35
+ <option value="slabo">Van Gogh</option>
36
+ <option value="inconsolata">Pixel Art</option>
37
+ <option value="ubuntu">Rembrandt</option>
38
+ <option value="Akronim">Cubism</option>
39
+ <option value="Monoton">Neon Art</option>
40
+ </select>
41
+ <select class="ql-size">
42
+ <option value="18px">Small</option>
43
+ <option selected>Normal</option>
44
+ <option value="32px">Large</option>
45
+ <option value="50px">Huge</option>
46
+ </select>
47
+ </span>
48
+ <span class="ql-formats">
49
+ <button class="ql-strike"></button>
50
+ </span>
51
+ <!-- <span class="ql-formats">
52
+ <button class="ql-bold"></button>
53
+ <button class="ql-italic"></button>
54
+ <button class="ql-underline"></button>
55
+ </span> -->
56
+ <span class="ql-formats">
57
+ <select class="ql-color"></select>
58
+ <!-- <select class="ql-background"></select> -->
59
+ </span>
60
+ <!-- <span class="ql-formats">
61
+ <button class="ql-script" value="sub"></button>
62
+ <button class="ql-script" value="super"></button>
63
+ </span>
64
+ <span class="ql-formats">
65
+ <button class="ql-header" value="1"></button>
66
+ <button class="ql-header" value="2"></button>
67
+ <button class="ql-blockquote"></button>
68
+ <button class="ql-code-block"></button>
69
+ </span>
70
+ <span class="ql-formats">
71
+ <button class="ql-list" value="ordered"></button>
72
+ <button class="ql-list" value="bullet"></button>
73
+ <button class="ql-indent" value="-1"></button>
74
+ <button class="ql-indent" value="+1"></button>
75
+ </span>
76
+ <span class="ql-formats">
77
+ <button class="ql-direction" value="rtl"></button>
78
+ <select class="ql-align"></select>
79
+ </span>
80
+ <span class="ql-formats">
81
+ <button class="ql-link"></button>
82
+ <button class="ql-image"></button>
83
+ <button class="ql-video"></button>
84
+ <button class="ql-formula"></button>
85
+ </span> -->
86
+ <span class="ql-formats">
87
+ <button class="ql-link"></button>
88
+ </span>
89
+ <span class="ql-formats">
90
+ <button class="ql-clean"></button>
91
+ </span>
92
+ </div>
93
+ <div id="editor-container"></div>
94
+ </div>
95
+ `;
96
+ }
97
+
98
+ static header() {
99
+ return `
100
+ <link rel="stylesheet" href="https://cdn.quilljs.com/1.3.6/quill.snow.css">
101
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
102
+ <style>
103
+ /* Set default font-family */
104
+ .ql-snow .ql-tooltip::before {
105
+ content: "Footnote";
106
+ line-height: 26px;
107
+ margin-right: 8px;
108
+ }
109
+
110
+ .ql-snow .ql-tooltip[data-mode=link]::before {
111
+ content: "Enter footnote:";
112
+ }
113
+
114
+ .row {
115
+ margin-top: 15px;
116
+ margin-left: 0px;
117
+ margin-bottom: 15px;
118
+ }
119
+
120
+ .btn-primary {
121
+ color: #ffffff;
122
+ background-color: #2780e3;
123
+ border-color: #2780e3;
124
+ }
125
+
126
+ .btn-primary:hover {
127
+ color: #ffffff;
128
+ background-color: #1967be;
129
+ border-color: #1862b5;
130
+ }
131
+
132
+ .btn {
133
+ display: inline-block;
134
+ margin-bottom: 0;
135
+ font-weight: normal;
136
+ text-align: center;
137
+ vertical-align: middle;
138
+ touch-action: manipulation;
139
+ cursor: pointer;
140
+ background-image: none;
141
+ border: 1px solid transparent;
142
+ white-space: nowrap;
143
+ padding: 10px 18px;
144
+ font-size: 15px;
145
+ line-height: 1.42857143;
146
+ border-radius: 0;
147
+ user-select: none;
148
+ }
149
+
150
+ #standalone-container {
151
+ position: relative;
152
+ max-width: 720px;
153
+ background-color: #ffffff;
154
+ color: black !important;
155
+ z-index: 1000;
156
+ }
157
+
158
+ #editor-container {
159
+ font-family: "Aref Ruqaa";
160
+ font-size: 18px;
161
+ height: 250px;
162
+ }
163
+
164
+ #toolbar-container {
165
+ font-family: "Aref Ruqaa";
166
+ display: flex;
167
+ flex-wrap: wrap;
168
+ }
169
+
170
+ #json-container {
171
+ max-width: 720px;
172
+ }
173
+
174
+ /* Set dropdown font-families */
175
+ #toolbar-container .ql-font span[data-label="Base"]::before {
176
+ font-family: "Aref Ruqaa";
177
+ }
178
+
179
+ #toolbar-container .ql-font span[data-label="Claude Monet"]::before {
180
+ font-family: "Mirza";
181
+ }
182
+
183
+ #toolbar-container .ql-font span[data-label="Ukiyoe"]::before {
184
+ font-family: "Roboto";
185
+ }
186
+
187
+ #toolbar-container .ql-font span[data-label="Cyber Punk"]::before {
188
+ font-family: "Comic Sans MS";
189
+ }
190
+
191
+ #toolbar-container .ql-font span[data-label="Pop Art"]::before {
192
+ font-family: "sofia";
193
+ }
194
+
195
+ #toolbar-container .ql-font span[data-label="Van Gogh"]::before {
196
+ font-family: "slabo 27px";
197
+ }
198
+
199
+ #toolbar-container .ql-font span[data-label="Pixel Art"]::before {
200
+ font-family: "inconsolata";
201
+ }
202
+
203
+ #toolbar-container .ql-font span[data-label="Rembrandt"]::before {
204
+ font-family: "ubuntu";
205
+ }
206
+
207
+ #toolbar-container .ql-font span[data-label="Cubism"]::before {
208
+ font-family: "Akronim";
209
+ }
210
+
211
+ #toolbar-container .ql-font span[data-label="Neon Art"]::before {
212
+ font-family: "Monoton";
213
+ }
214
+
215
+ /* Set content font-families */
216
+ .ql-font-mirza {
217
+ font-family: "Mirza";
218
+ }
219
+
220
+ .ql-font-roboto {
221
+ font-family: "Roboto";
222
+ }
223
+
224
+ .ql-font-cursive {
225
+ font-family: "Comic Sans MS";
226
+ }
227
+
228
+ .ql-font-sofia {
229
+ font-family: "sofia";
230
+ }
231
+
232
+ .ql-font-slabo {
233
+ font-family: "slabo 27px";
234
+ }
235
+
236
+ .ql-font-inconsolata {
237
+ font-family: "inconsolata";
238
+ }
239
+
240
+ .ql-font-ubuntu {
241
+ font-family: "ubuntu";
242
+ }
243
+
244
+ .ql-font-Akronim {
245
+ font-family: "Akronim";
246
+ }
247
+
248
+ .ql-font-Monoton {
249
+ font-family: "Monoton";
250
+ }
251
+ </style>
252
+ `;
253
+ }
254
+ async mountQuill() {
255
+ // Register the custom formats with Quill
256
+ const lib = await import("https://cdn.jsdelivr.net/npm/shadow-selection-polyfill");
257
+ const getRange = lib.getRange;
258
+
259
+ const Font = Quill.import('formats/font');
260
+ Font.whitelist = ['mirza', 'roboto', 'sofia', 'slabo', 'inconsolata', 'ubuntu', 'cursive', 'Akronim', 'Monoton'];
261
+ const Link = Quill.import('formats/link');
262
+ Link.sanitize = function (url) {
263
+ // modify url if desired
264
+ return url;
265
+ }
266
+ const SizeStyle = Quill.import('attributors/style/size');
267
+ SizeStyle.whitelist = ['10px', '18px', '32px', '50px', '64px'];
268
+ Quill.register(SizeStyle, true);
269
+ Quill.register(Link, true);
270
+ Quill.register(Font, true);
271
+ const icons = Quill.import('ui/icons');
272
+ const icon = `<svg xmlns="http://www.w3.org/2000/svg" width="17" viewBox="0 0 512 512" xml:space="preserve"><path fill="#010101" d="M276.75 1c4.51 3.23 9.2 6.04 12.97 9.77 29.7 29.45 59.15 59.14 88.85 88.6 4.98 4.93 7.13 10.37 7.12 17.32-.1 125.8-.09 251.6-.01 377.4 0 7.94-1.96 14.46-9.62 18.57-121.41.34-242.77.34-364.76.05A288.3 288.3 0 0 1 1 502c0-163.02 0-326.04.34-489.62C3.84 6.53 8.04 3.38 13 1c23.35 0 46.7 0 70.82.3 2.07.43 3.38.68 4.69.68h127.98c18.44.01 36.41.04 54.39-.03 1.7 0 3.41-.62 5.12-.95h.75M33.03 122.5v359.05h320.22V129.18h-76.18c-14.22-.01-19.8-5.68-19.8-20.09V33.31H33.02v89.19m256.29-27.36c.72.66 1.44 1.9 2.17 1.9 12.73.12 25.46.08 37.55.08L289.3 57.45v37.7z"/><path fill="#020202" d="M513 375.53c-4.68 7.99-11.52 10.51-20.21 10.25-13.15-.4-26.32-.1-39.48-.1h-5.58c5.49 8.28 10.7 15.74 15.46 23.47 6.06 9.82 1.14 21.65-9.96 24.27-6.7 1.59-12.45-.64-16.23-6.15a2608.6 2608.6 0 0 1-32.97-49.36c-3.57-5.48-3.39-11.54.17-16.98a3122.5 3122.5 0 0 1 32.39-48.56c5.22-7.65 14.67-9.35 21.95-4.45 7.63 5.12 9.6 14.26 4.5 22.33-4.75 7.54-9.8 14.9-15.11 22.95h33.64V225.19h-5.24c-19.49 0-38.97.11-58.46-.05-12.74-.1-20.12-13.15-13.84-24.14 3.12-5.46 8.14-7.71 14.18-7.73 26.15-.06 52.3-.04 78.45 0 7.1 0 12.47 3.05 16.01 9.64.33 57.44.33 114.8.33 172.62z"/><path fill="#111" d="M216.03 1.97C173.52 1.98 131 2 88.5 1.98a16 16 0 0 1-4.22-.68c43.4-.3 87.09-.3 131.24-.06.48.25.5.73.5.73z"/><path fill="#232323" d="M216.5 1.98c-.47 0-.5-.5-.5-.74C235.7 1 255.38 1 275.53 1c-1.24.33-2.94.95-4.65.95-17.98.07-35.95.04-54.39.03z"/><path fill="#040404" d="M148 321.42h153.5c14.25 0 19.96 5.71 19.96 19.97.01 19.17.03 38.33 0 57.5-.03 12.6-6.16 18.78-18.66 18.78H99.81c-12.42 0-18.75-6.34-18.76-18.73-.01-19.83-.02-39.66 0-59.5.02-11.47 6.4-17.93 17.95-18 16.17-.08 32.33-.02 49-.02m40.5 32.15h-75.16v31.84h175.7v-31.84H188.5z"/><path fill="#030303" d="m110 225.33 178.89-.03c11.98 0 19.25 9.95 15.74 21.44-2.05 6.71-7.5 10.57-15.14 10.57-63.63 0-127.25-.01-190.88-.07-12.03-.02-19.17-8.62-16.7-19.84 1.6-7.21 7.17-11.74 15.1-12.04 4.17-.16 8.33-.03 13-.03zm-24.12-36.19c-5.28-6.2-6.3-12.76-2.85-19.73 3.22-6.49 9.13-8.24 15.86-8.24 25.64.01 51.27-.06 76.91.04 13.07.04 20.66 10.44 16.33 22.08-2.25 6.06-6.63 9.76-13.08 9.8-27.97.18-55.94.2-83.9-.07-3.01-.03-6-2.36-9.27-3.88z"/></svg>`
273
+ icons['link'] = icon;
274
+ const editorContainer = this.shadowRoot.querySelector('#editor-container')
275
+ const toolbarContainer = this.shadowRoot.querySelector('#toolbar-container')
276
+ const myQuill = new Quill(editorContainer, {
277
+ modules: {
278
+ toolbar: {
279
+ container: toolbarContainer,
280
+ },
281
+ },
282
+ theme: 'snow'
283
+ });
284
+ const normalizeNative = (nativeRange) => {
285
+
286
+ if (nativeRange) {
287
+ const range = nativeRange;
288
+
289
+ if (range.baseNode) {
290
+ range.startContainer = nativeRange.baseNode;
291
+ range.endContainer = nativeRange.focusNode;
292
+ range.startOffset = nativeRange.baseOffset;
293
+ range.endOffset = nativeRange.focusOffset;
294
+
295
+ if (range.endOffset < range.startOffset) {
296
+ range.startContainer = nativeRange.focusNode;
297
+ range.endContainer = nativeRange.baseNode;
298
+ range.startOffset = nativeRange.focusOffset;
299
+ range.endOffset = nativeRange.baseOffset;
300
+ }
301
+ }
302
+
303
+ if (range.startContainer) {
304
+ return {
305
+ start: { node: range.startContainer, offset: range.startOffset },
306
+ end: { node: range.endContainer, offset: range.endOffset },
307
+ native: range
308
+ };
309
+ }
310
+ }
311
+
312
+ return null
313
+ };
314
+
315
+ myQuill.selection.getNativeRange = () => {
316
+
317
+ const dom = myQuill.root.getRootNode();
318
+ const selection = getRange(dom);
319
+ const range = normalizeNative(selection);
320
+
321
+ return range;
322
+ };
323
+ let fromEditor = false;
324
+ editorContainer.addEventListener("pointerup", (e) => {
325
+ fromEditor = false;
326
+ });
327
+ editorContainer.addEventListener("pointerout", (e) => {
328
+ fromEditor = false;
329
+ });
330
+ editorContainer.addEventListener("pointerdown", (e) => {
331
+ fromEditor = true;
332
+ });
333
+
334
+ document.addEventListener("selectionchange", () => {
335
+ if (fromEditor) {
336
+ myQuill.selection.update()
337
+ }
338
+ });
339
+
340
+
341
+ myQuill.on('text-change', () => {
342
+ // keep the Quill contents inside _data to communicate with Gradio
343
+ document.querySelector("#rich-text-root")._data = myQuill.getContents()
344
+ })
345
+ return myQuill
346
+ }
347
+ }
348
+
349
+ customElements.define('rich-text-editor', RichTextEditor);
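rich-text-to-json.js wraps the same editor in a `<rich-text-editor>` web component and mirrors the Delta onto the element with id `rich-text-root`, which is where the Gradio app and the share button below read it back as a JSON string. A hypothetical Python-side helper (the real handler lives in app.py, which is not part of this hunk) might decode it like this, mirroring what share_js does with `ops.map(e => e.insert).join('')`:

import json

def decode_rich_text(text_input: str):
    """Hypothetical helper: turn the serialized Quill Delta coming from the
    rich-text editor into (delta dict, plain prompt)."""
    delta = json.loads(text_input)                                  # {"ops": [...]}
    plain_prompt = ''.join(op['insert'] for op in delta['ops'])     # drop all formatting
    return delta, plain_prompt.rstrip('\n')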
share_btn.py ADDED
@@ -0,0 +1,116 @@
1
+ community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
2
+ <path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
3
+ <path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
4
+ </svg>"""
5
+
6
+ loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin" style="color: #ffffff;" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
7
+
8
+ share_js = """async () => {
9
+ async function uploadFile(file){
10
+ const UPLOAD_URL = 'https://huggingface.co/uploads';
11
+ const response = await fetch(UPLOAD_URL, {
12
+ method: 'POST',
13
+ headers: {
14
+ 'Content-Type': file.type,
15
+ 'X-Requested-With': 'XMLHttpRequest',
16
+ },
17
+ body: file, /// <- File inherits from Blob
18
+ });
19
+ const url = await response.text();
20
+ return url;
21
+ }
22
+ async function getInputImageFile(imageEl){
23
+ const res = await fetch(imageEl.src);
24
+ const blob = await res.blob();
25
+ const imageId = Date.now();
26
+ const fileName = `rich-text-image-${imageId}.png`;
27
+ return new File([blob], fileName, { type: 'image/png'});
28
+ }
29
+ const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
30
+ const richEl = document.getElementById("rich-text-root");
31
+ const data = richEl? richEl.contentDocument.body._data : {};
32
+ const text_input = JSON.stringify(data);
33
+ const negative_prompt = gradioEl.querySelector('#negative_prompt input').value;
34
+ const seed = gradioEl.querySelector('#seed input').value;
35
+ const richTextImg = gradioEl.querySelector('#rich-text-image img');
36
+ const plainTextImg = gradioEl.querySelector('#plain-text-image img');
37
+ const text_input_obj = JSON.parse(text_input);
38
+ const plain_prompt = text_input_obj.ops.map(e=> e.insert).join('');
39
+ const linkSrc = `https://huggingface.co/spaces/songweig/rich-text-to-image?prompt=${encodeURIComponent(text_input)}`;
40
+
41
+ const titleTxt = `RT2I: ${plain_prompt.slice(0, 50)}...`;
42
+ const shareBtnEl = gradioEl.querySelector('#share-btn');
43
+ const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
44
+ const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
45
+ if(!richTextImg){
46
+ return;
47
+ };
48
+ shareBtnEl.style.pointerEvents = 'none';
49
+ shareIconEl.style.display = 'none';
50
+ loadingIconEl.style.removeProperty('display');
51
+
52
+ const richImgFile = await getInputImageFile(richTextImg);
53
+ const plainImgFile = await getInputImageFile(plainTextImg);
54
+ const richImgURL = await uploadFile(richImgFile);
55
+ const plainImgURL = await uploadFile(plainImgFile);
56
+
57
+ const descriptionMd = `
58
+ ### Plain Prompt
59
+ ${plain_prompt}
60
+
61
+ 🔗 Shareable Link + Params: [here](${linkSrc})
62
+
63
+ ### Rich Text Image
64
+ <img src="${richImgURL}">
65
+
66
+ ### Plain Text Image
67
+ <img src="${plainImgURL}">
68
+
69
+ `;
70
+ const params = new URLSearchParams({
71
+ title: titleTxt,
72
+ description: descriptionMd,
73
+ });
74
+ const paramsStr = params.toString();
75
+ window.open(`https://huggingface.co/spaces/songweig/rich-text-to-image/discussions/new?${paramsStr}`, '_blank');
76
+ shareBtnEl.style.removeProperty('pointer-events');
77
+ shareIconEl.style.removeProperty('display');
78
+ loadingIconEl.style.display = 'none';
79
+ }"""
80
+
81
+ css = """
82
+ #share-btn-container {
83
+ display: flex;
84
+ padding-left: 0.5rem !important;
85
+ padding-right: 0.5rem !important;
86
+ background-color: #000000;
87
+ justify-content: center;
88
+ align-items: center;
89
+ border-radius: 9999px !important;
90
+ width: 13rem;
91
+ margin-top: 10px;
92
+ margin-left: auto;
93
+ flex: unset !important;
94
+ }
95
+ #share-btn {
96
+ all: initial;
97
+ color: #ffffff;
98
+ font-weight: 600;
99
+ cursor: pointer;
100
+ font-family: 'IBM Plex Sans', sans-serif;
101
+ margin-left: 0.5rem !important;
102
+ padding-top: 0.25rem !important;
103
+ padding-bottom: 0.25rem !important;
104
+ right:0;
105
+ }
106
+ #share-btn * {
107
+ all: unset !important;
108
+ }
109
+ #share-btn-container div:nth-child(-n+2){
110
+ width: auto !important;
111
+ min-height: 0px !important;
112
+ }
113
+ #share-btn-container .wrap {
114
+ display: none !important;
115
+ }
116
+ """
utils/.DS_Store ADDED
Binary file (6.15 kB)
utils/attention_utils.py ADDED
@@ -0,0 +1,318 @@
1
+ import numpy as np
2
+ import os
3
+ import matplotlib as mpl
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ import torch
7
+ import torchvision
8
+
9
+ from utils.richtext_utils import seed_everything
10
+ from sklearn.cluster import SpectralClustering
11
+
12
+ SelfAttentionLayers = [
13
+ 'down_blocks.0.attentions.0.transformer_blocks.0.attn1',
14
+ 'down_blocks.0.attentions.1.transformer_blocks.0.attn1',
15
+ 'down_blocks.1.attentions.0.transformer_blocks.0.attn1',
16
+ 'down_blocks.1.attentions.1.transformer_blocks.0.attn1',
17
+ 'down_blocks.2.attentions.0.transformer_blocks.0.attn1',
18
+ 'down_blocks.2.attentions.1.transformer_blocks.0.attn1',
19
+ 'mid_block.attentions.0.transformer_blocks.0.attn1',
20
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn1',
21
+ 'up_blocks.1.attentions.1.transformer_blocks.0.attn1',
22
+ 'up_blocks.1.attentions.2.transformer_blocks.0.attn1',
23
+ 'up_blocks.2.attentions.0.transformer_blocks.0.attn1',
24
+ 'up_blocks.2.attentions.1.transformer_blocks.0.attn1',
25
+ 'up_blocks.2.attentions.2.transformer_blocks.0.attn1',
26
+ 'up_blocks.3.attentions.0.transformer_blocks.0.attn1',
27
+ 'up_blocks.3.attentions.1.transformer_blocks.0.attn1',
28
+ 'up_blocks.3.attentions.2.transformer_blocks.0.attn1',
29
+ ]
30
+
31
+
32
+ CrossAttentionLayers = [
33
+ # 'down_blocks.0.attentions.0.transformer_blocks.0.attn2',
34
+ # 'down_blocks.0.attentions.1.transformer_blocks.0.attn2',
35
+ 'down_blocks.1.attentions.0.transformer_blocks.0.attn2',
36
+ # 'down_blocks.1.attentions.1.transformer_blocks.0.attn2',
37
+ 'down_blocks.2.attentions.0.transformer_blocks.0.attn2',
38
+ 'down_blocks.2.attentions.1.transformer_blocks.0.attn2',
39
+ 'mid_block.attentions.0.transformer_blocks.0.attn2',
40
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn2',
41
+ 'up_blocks.1.attentions.1.transformer_blocks.0.attn2',
42
+ 'up_blocks.1.attentions.2.transformer_blocks.0.attn2',
43
+ # 'up_blocks.2.attentions.0.transformer_blocks.0.attn2',
44
+ 'up_blocks.2.attentions.1.transformer_blocks.0.attn2',
45
+ # 'up_blocks.2.attentions.2.transformer_blocks.0.attn2',
46
+ # 'up_blocks.3.attentions.0.transformer_blocks.0.attn2',
47
+ # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2',
48
+ # 'up_blocks.3.attentions.2.transformer_blocks.0.attn2'
49
+ ]
50
+
51
+
52
+ def split_attention_maps_over_steps(attention_maps):
53
+ r"""Function for splitting attention maps over steps.
54
+ Args:
55
+ attention_maps (dict): Dictionary of attention maps collected from the
57
+ UNet hooks, keyed by layer name, one entry per denoising step.
57
+ """
58
+ # This function splits attention maps into unconditional and conditional score and over steps
59
+
60
+ attention_maps_cond = dict() # Maps corresponding to conditional score
61
+ attention_maps_uncond = dict() # Maps corresponding to unconditional score
62
+
63
+ for layer in attention_maps.keys():
64
+
65
+ for step_num in range(len(attention_maps[layer])):
66
+ if step_num not in attention_maps_cond:
67
+ attention_maps_cond[step_num] = dict()
68
+ attention_maps_uncond[step_num] = dict()
69
+
70
+ attention_maps_uncond[step_num].update(
71
+ {layer: attention_maps[layer][step_num][:1]})
72
+ attention_maps_cond[step_num].update(
73
+ {layer: attention_maps[layer][step_num][1:2]})
74
+
75
+ return attention_maps_cond, attention_maps_uncond
76
+
77
+
78
+ def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=None):
79
+ atten_names = ['presoftmax', 'postsoftmax', 'postsoftmax_erosion']
80
+ for i, attn_map in enumerate(atten_map_list):
81
+ n_obj = len(attn_map)
82
+ plt.figure()
83
+ plt.clf()
84
+
85
+ fig, axs = plt.subplots(
86
+ ncols=n_obj+1, gridspec_kw=dict(width_ratios=[1 for _ in range(n_obj)]+[0.1]))
87
+
88
+ fig.set_figheight(3)
89
+ fig.set_figwidth(3*n_obj+0.1)
90
+
91
+ cmap = plt.get_cmap('OrRd')
92
+
93
+ vmax = 0
94
+ vmin = 1
95
+ for tid in range(n_obj):
96
+ attention_map_cur = attn_map[tid]
97
+ vmax = max(vmax, float(attention_map_cur.max()))
98
+ vmin = min(vmin, float(attention_map_cur.min()))
99
+
100
+ for tid in range(n_obj):
101
+ sns.heatmap(
102
+ attn_map[tid][0], annot=False, cbar=False, ax=axs[tid],
103
+ cmap=cmap, vmin=vmin, vmax=vmax
104
+ )
105
+ axs[tid].set_axis_off()
106
+
107
+ if tokens_vis is not None:
108
+ if tid == n_obj-1:
109
+ axs_xlabel = 'other tokens'
110
+ else:
111
+ axs_xlabel = ''
112
+ for token_id in obj_tokens[tid]:
113
+ axs_xlabel += ' ' + tokens_vis[token_id.item() -
114
+ 1][:-len('</w>')]
115
+ axs[tid].set_title(axs_xlabel)
116
+
117
+ norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
118
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
119
+ fig.colorbar(sm, cax=axs[-1])
120
+ canvas = fig.canvas
121
+ canvas.draw()
122
+ width, height = canvas.get_width_height()
123
+ img = np.frombuffer(canvas.tostring_rgb(),
124
+ dtype='uint8').reshape((height, width, 3))
125
+
126
+ fig.tight_layout()
127
+ plt.close()
128
+ return img
129
+
130
+
131
+ def get_token_maps_deprecated(attention_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None):
132
+ r"""Function to visualize attention maps.
133
+ Args:
134
+ attention_maps (dict): Cross-attention maps collected over steps.
135
+ save_dir (str): Path to save attention maps.
136
+ obj_tokens (list): Token indices (1-based) for each region's span.
137
+ """
138
+
139
+ # Split attention maps over steps
140
+ attention_maps_cond, _ = split_attention_maps_over_steps(
141
+ attention_maps
142
+ )
143
+
144
+ nsteps = len(attention_maps_cond)
145
+ hw_ori = width * height
146
+
147
+ attention_maps = []
148
+ for obj_token in obj_tokens:
149
+ attention_maps.append([])
150
+
151
+ for step_num in range(nsteps):
152
+ attention_maps_cur = attention_maps_cond[step_num]
153
+
154
+ for layer in attention_maps_cur.keys():
155
+ if step_num < 10 or layer not in CrossAttentionLayers:
156
+ continue
157
+
158
+ attention_ind = attention_maps_cur[layer].cpu()
159
+
160
+ # Attention maps are of shape [batch_size, nkeys, 77]
161
+ # since they are averaged out while collecting from hooks to save memory.
162
+ # Now split the heads from batch dimension
163
+ bs, hw, nclip = attention_ind.shape
164
+ down_ratio = np.sqrt(hw_ori // hw)
165
+ width_cur = int(width // down_ratio)
166
+ height_cur = int(height // down_ratio)
167
+ attention_ind = attention_ind.reshape(
168
+ bs, height_cur, width_cur, nclip)
169
+ for obj_id, obj_token in enumerate(obj_tokens):
170
+ if obj_token[0] == -1:
171
+ attention_map_prev = torch.stack(
172
+ [attention_maps[i][-1] for i in range(obj_id)]).sum(0)
173
+ attention_maps[obj_id].append(
174
+ attention_map_prev.max()-attention_map_prev)
175
+ else:
176
+ obj_attention_map = attention_ind[:, :, :, obj_token].max(-1, True)[
177
+ 0].permute([3, 0, 1, 2])
178
+ obj_attention_map = torchvision.transforms.functional.resize(obj_attention_map, (height, width),
179
+ interpolation=torchvision.transforms.InterpolationMode.BICUBIC, antialias=True)
180
+ attention_maps[obj_id].append(obj_attention_map)
181
+
182
+ # average attention maps over steps
183
+ attention_maps_averaged = []
184
+ for obj_id, obj_token in enumerate(obj_tokens):
185
+ if obj_id == len(obj_tokens) - 1:
186
+ attention_maps_averaged.append(
187
+ torch.cat(attention_maps[obj_id]).mean(0))
188
+ else:
189
+ attention_maps_averaged.append(
190
+ torch.cat(attention_maps[obj_id]).mean(0))
191
+
192
+ # normalize attention maps into [0, 1]
193
+ attention_maps_averaged_normalized = []
194
+ attention_maps_averaged_sum = torch.cat(attention_maps_averaged).sum(0)
195
+ for obj_id, obj_token in enumerate(obj_tokens):
196
+ attention_maps_averaged_normalized.append(
197
+ attention_maps_averaged[obj_id]/attention_maps_averaged_sum)
198
+
199
+ # softmax
200
+ attention_maps_averaged_normalized = (
201
+ torch.cat(attention_maps_averaged)/0.001).softmax(0)
202
+ attention_maps_averaged_normalized = [
203
+ attention_maps_averaged_normalized[i:i+1] for i in range(attention_maps_averaged_normalized.shape[0])]
204
+
205
+ token_maps_vis = plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
206
+ obj_tokens, save_dir, seed, tokens_vis)
207
+ attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
208
+ [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]
209
+ return attention_maps_averaged_normalized, token_maps_vis
210
+
211
+
212
+ def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None,
213
+ preprocess=False, segment_threshold=0.3, num_segments=5, return_vis=False, save_attn=False):
214
+ r"""Function to visualize attention maps.
215
+ Args:
216
+ save_dir (str): Path to save attention maps
217
+ batch_size (int): Batch size
218
+ sampler_order (int): Sampler order
219
+ """
220
+
221
+ # create the segmentation mask using self-attention maps
222
+ resolution = 32
223
+ attn_maps_1024 = {8: [], 16: [], 32: [], 64: []}
224
+ for attn_map in selfattn_maps.values():
225
+ resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
226
+ if resolution_map != resolution:
227
+ continue
228
+ attn_map = attn_map.reshape(
229
+ 1, resolution_map, resolution_map, resolution_map**2).permute([3, 0, 1, 2])
230
+ attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
231
+ mode='bicubic', antialias=True)
232
+ attn_maps_1024[resolution_map].append(attn_map.permute([1, 2, 3, 0]).reshape(
233
+ 1, resolution**2, resolution_map**2))
234
+ attn_maps_1024 = torch.cat([torch.cat(v).mean(0).cpu()
235
+ for v in attn_maps_1024.values() if len(v) > 0], -1).numpy()
236
+ if save_attn:
237
+ print('saving self-attention maps...', attn_maps_1024.shape)
238
+ torch.save(torch.from_numpy(attn_maps_1024),
239
+ 'results/maps/selfattn_maps.pth')
240
+ seed_everything(seed)
241
+ sc = SpectralClustering(num_segments, affinity='precomputed', n_init=100,
242
+ assign_labels='kmeans')
243
+ clusters = sc.fit_predict(attn_maps_1024)
244
+ clusters = clusters.reshape(resolution, resolution)
245
+ fig = plt.figure()
246
+ plt.imshow(clusters)
247
+ plt.axis('off')
248
+ if return_vis:
249
+ canvas = fig.canvas
250
+ canvas.draw()
251
+ cav_width, cav_height = canvas.get_width_height()
252
+ segments_vis = np.frombuffer(canvas.tostring_rgb(),
253
+ dtype='uint8').reshape((cav_height, cav_width, 3))
254
+
255
+ plt.close()
256
+
257
+ # label the segmentation mask using cross-attention maps
258
+ cross_attn_maps_1024 = []
259
+ for attn_map in crossattn_maps.values():
260
+ resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
261
+ attn_map = attn_map.reshape(
262
+ 1, resolution_map, resolution_map, -1).permute([0, 3, 1, 2])
263
+ attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
264
+ mode='bicubic', antialias=True)
265
+ cross_attn_maps_1024.append(attn_map.permute([0, 2, 3, 1]))
266
+
267
+ cross_attn_maps_1024 = torch.cat(
268
+ cross_attn_maps_1024).mean(0).cpu().numpy()
269
+ if save_attn:
270
+ print('saving cross-attention maps...', cross_attn_maps_1024.shape)
271
+ torch.save(torch.from_numpy(cross_attn_maps_1024),
272
+ 'results/maps/crossattn_maps.pth')
273
+ normalized_span_maps = []
274
+ for token_ids in obj_tokens:
275
+ span_token_maps = cross_attn_maps_1024[:, :, token_ids.numpy()]
276
+ normalized_span_map = np.zeros_like(span_token_maps)
277
+ for i in range(span_token_maps.shape[-1]):
278
+ curr_noun_map = span_token_maps[:, :, i]
279
+ normalized_span_map[:, :, i] = (
280
+ curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max()
281
+ normalized_span_maps.append(normalized_span_map)
282
+ foreground_token_maps = [np.zeros([clusters.shape[0], clusters.shape[1]]).squeeze(
283
+ ) for normalized_span_map in normalized_span_maps]
284
+ background_map = np.zeros([clusters.shape[0], clusters.shape[1]]).squeeze()
285
+ for c in range(num_segments):
286
+ cluster_mask = np.zeros_like(clusters)
287
+ cluster_mask[clusters == c] = 1.
288
+ is_foreground = False
289
+ for normalized_span_map, foreground_nouns_map, token_ids in zip(normalized_span_maps, foreground_token_maps, obj_tokens):
290
+ score_maps = [cluster_mask * normalized_span_map[:, :, i]
291
+ for i in range(len(token_ids))]
292
+ scores = [score_map.sum() / cluster_mask.sum()
293
+ for score_map in score_maps]
294
+ if max(scores) > segment_threshold:
295
+ foreground_nouns_map += cluster_mask
296
+ is_foreground = True
297
+ if not is_foreground:
298
+ background_map += cluster_mask
299
+ foreground_token_maps.append(background_map)
300
+
301
+ # resize the token maps and visualization
302
+ resized_token_maps = torch.cat([torch.nn.functional.interpolate(torch.from_numpy(token_map).unsqueeze(0).unsqueeze(
303
+ 0), (height, width), mode='bicubic', antialias=True)[0] for token_map in foreground_token_maps]).clamp(0, 1)
304
+
305
+ resized_token_maps = resized_token_maps / \
306
+ (resized_token_maps.sum(0, True)+1e-8)
307
+ resized_token_maps = [token_map.unsqueeze(
308
+ 0) for token_map in resized_token_maps]
309
+ foreground_token_maps = [token_map[None, :, :]
310
+ for token_map in foreground_token_maps]
311
+ token_maps_vis = plot_attention_maps([foreground_token_maps, resized_token_maps], obj_tokens,
312
+ save_dir, seed, tokens_vis)
313
+ resized_token_maps = [token_map.unsqueeze(1).repeat(
314
+ [1, 4, 1, 1]).to(attn_map.dtype).cuda() for token_map in resized_token_maps]
315
+ if return_vis:
316
+ return resized_token_maps, segments_vis, token_maps_vis
317
+ else:
318
+ return resized_token_maps
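For readers following the tensor shapes, here is a synthetic call to get_token_maps. The layer names, random attention maps, token indices and latent size are made up for illustration; a CUDA device and the Space's requirements (torch, scikit-learn, seaborn) are assumed, since the function moves the resulting masks to the GPU, and the self-attention maps must be at the 32x32 resolution that the clustering step keeps.

# Synthetic sketch: shapes follow the reshapes inside get_token_maps.
import torch
from utils.attention_utils import get_token_maps

res = 32                                                      # clustering resolution used above
selfattn = {'up_blocks.1.attentions.0.transformer_blocks.0.attn1':
            torch.rand(1, res * res, res * res).softmax(-1)}  # [1, HW, HW], row-stochastic
crossattn = {'up_blocks.1.attentions.0.transformer_blocks.0.attn2':
             torch.rand(1, res * res, 77)}                    # [1, HW, 77 text tokens]
obj_tokens = [torch.LongTensor([1, 2]),                       # tokens of region 0
              torch.LongTensor([3])]                          # tokens of region 1

masks = get_token_maps(selfattn, crossattn, n_maps=len(obj_tokens) + 1,
                       save_dir='results', width=64, height=64,  # latent resolution for 512px
                       obj_tokens=obj_tokens, seed=0,
                       segment_threshold=0.3, num_segments=5)
# masks: one [1, 4, 64, 64] soft mask per region plus a background mask,
# normalized so the masks sum to roughly 1 at every latent position.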
utils/richtext_utils.py ADDED
@@ -0,0 +1,234 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import random
5
+ import numpy as np
6
+
7
+ COLORS = {
8
+ 'brown': [165, 42, 42],
9
+ 'red': [255, 0, 0],
10
+ 'pink': [253, 108, 158],
11
+ 'orange': [255, 165, 0],
12
+ 'yellow': [255, 255, 0],
13
+ 'purple': [128, 0, 128],
14
+ 'green': [0, 128, 0],
15
+ 'blue': [0, 0, 255],
16
+ 'white': [255, 255, 255],
17
+ 'gray': [128, 128, 128],
18
+ 'black': [0, 0, 0],
19
+ }
20
+
21
+
22
+ def seed_everything(seed):
23
+ random.seed(seed)
24
+ os.environ['PYTHONHASHSEED'] = str(seed)
25
+ np.random.seed(seed)
26
+ torch.manual_seed(seed)
27
+ torch.cuda.manual_seed(seed)
28
+
29
+
30
+ def hex_to_rgb(hex_string, return_nearest_color=False):
31
+ r"""
32
+ Convert a hex triplet to an RGB triplet.
33
+ """
34
+ # Remove '#' symbol if present
35
+ hex_string = hex_string.lstrip('#')
36
+ # Convert hex values to integers
37
+ red = int(hex_string[0:2], 16)
38
+ green = int(hex_string[2:4], 16)
39
+ blue = int(hex_string[4:6], 16)
40
+ rgb = torch.FloatTensor((red, green, blue))[None, :, None, None]/255.
41
+ if return_nearest_color:
42
+ nearest_color = find_nearest_color(rgb)
43
+ return rgb.cuda(), nearest_color
44
+ return rgb.cuda()
45
+
46
+
47
+ def find_nearest_color(rgb):
48
+ r"""
49
+ Find the nearest neighbor color given the RGB value.
50
+ """
51
+ if isinstance(rgb, list) or isinstance(rgb, tuple):
52
+ rgb = torch.FloatTensor(rgb)[None, :, None, None]/255.
53
+ color_distance = torch.FloatTensor([np.linalg.norm(
54
+ rgb - torch.FloatTensor(COLORS[color])[None, :, None, None]/255.) for color in COLORS.keys()])
55
+ nearest_color = list(COLORS.keys())[torch.argmin(color_distance).item()]
56
+ return nearest_color
57
+
58
+
59
+ def font2style(font):
60
+ r"""
61
+ Convert the font name to the style name.
62
+ """
63
+ return {'mirza': 'Claude Monet, impressionism, oil on canvas',
64
+ 'roboto': 'Ukiyoe',
65
+ 'cursive': 'Cyber Punk, futuristic, blade runner, william gibson, trending on artstation hq',
66
+ 'sofia': 'Pop Art, masterpiece, andy warhol',
67
+ 'slabo': 'Vincent Van Gogh',
68
+ 'inconsolata': 'Pixel Art, 8 bits, 16 bits',
69
+ 'ubuntu': 'Rembrandt',
70
+ 'Monoton': 'neon art, colorful light, highly detailed, octane render',
71
+ 'Akronim': 'Abstract Cubism, Pablo Picasso', }[font]
72
+
73
+
74
+ def parse_json(json_str):
75
+ r"""
76
+ Convert the rich-text JSON (a Quill Delta dict) into per-region attributes.
77
+ """
78
+ # initialize region-based attributes.
79
+ base_text_prompt = ''
80
+ style_text_prompts = []
81
+ footnote_text_prompts = []
82
+ footnote_target_tokens = []
83
+ color_text_prompts = []
84
+ color_rgbs = []
85
+ color_names = []
86
+ size_text_prompts_and_sizes = []
87
+
88
+ # parse the attributes from JSON.
89
+ prev_style = None
90
+ prev_color_rgb = None
91
+ use_grad_guidance = False
92
+ for span in json_str['ops']:
93
+ text_prompt = span['insert'].rstrip('\n')
94
+ base_text_prompt += span['insert'].rstrip('\n')
95
+ if text_prompt == ' ':
96
+ continue
97
+ if 'attributes' in span:
98
+ if 'font' in span['attributes']:
99
+ style = font2style(span['attributes']['font'])
100
+ if prev_style == style:
101
+ prev_text_prompt = style_text_prompts[-1].split('in the style of')[
102
+ 0]
103
+ style_text_prompts[-1] = prev_text_prompt + \
104
+ ' ' + text_prompt + f' in the style of {style}'
105
+ else:
106
+ style_text_prompts.append(
107
+ text_prompt + f' in the style of {style}')
108
+ prev_style = style
109
+ else:
110
+ prev_style = None
111
+ if 'link' in span['attributes']:
112
+ footnote_text_prompts.append(span['attributes']['link'])
113
+ footnote_target_tokens.append(text_prompt)
114
+ font_size = 1
115
+ if 'size' in span['attributes'] and 'strike' not in span['attributes']:
116
+ font_size = float(span['attributes']['size'][:-2])/3.
117
+ elif 'size' in span['attributes'] and 'strike' in span['attributes']:
118
+ font_size = -float(span['attributes']['size'][:-2])/3.
119
+ elif 'size' not in span['attributes'] and 'strike' not in span['attributes']:
120
+ font_size = 1
121
+ if 'color' in span['attributes']:
122
+ use_grad_guidance = True
123
+ color_rgb, nearest_color = hex_to_rgb(
124
+ span['attributes']['color'], True)
125
+ if prev_color_rgb == color_rgb:
126
+ prev_text_prompt = color_text_prompts[-1]
127
+ color_text_prompts[-1] = prev_text_prompt + \
128
+ ' ' + text_prompt
129
+ else:
130
+ color_rgbs.append(color_rgb)
131
+ color_names.append(nearest_color)
132
+ color_text_prompts.append(text_prompt)
133
+ if font_size != 1:
134
+ size_text_prompts_and_sizes.append([text_prompt, font_size])
135
+ return base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
136
+ color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance
137
+
138
+
139
+ def get_region_diffusion_input(model, base_text_prompt, style_text_prompts, footnote_text_prompts,
140
+ footnote_target_tokens, color_text_prompts, color_names):
141
+ r"""
142
+ Algorithm 1 in the paper.
143
+ """
144
+ region_text_prompts = []
145
+ region_target_token_ids = []
146
+ base_tokens = model.tokenizer._tokenize(base_text_prompt)
147
+ # process the style text prompt
148
+ for text_prompt in style_text_prompts:
149
+ region_text_prompts.append(text_prompt)
150
+ region_target_token_ids.append([])
151
+ style_tokens = model.tokenizer._tokenize(
152
+ text_prompt.split('in the style of')[0])
153
+ for style_token in style_tokens:
154
+ region_target_token_ids[-1].append(
155
+ base_tokens.index(style_token)+1)
156
+
157
+ # process the complementary text prompt
158
+ for footnote_text_prompt, text_prompt in zip(footnote_text_prompts, footnote_target_tokens):
159
+ region_target_token_ids.append([])
160
+ region_text_prompts.append(footnote_text_prompt)
161
+ style_tokens = model.tokenizer._tokenize(text_prompt)
162
+ for style_token in style_tokens:
163
+ region_target_token_ids[-1].append(
164
+ base_tokens.index(style_token)+1)
165
+
166
+ # process the color text prompt
167
+ for color_text_prompt, color_name in zip(color_text_prompts, color_names):
168
+ region_target_token_ids.append([])
169
+ region_text_prompts.append(color_name+' '+color_text_prompt)
170
+ style_tokens = model.tokenizer._tokenize(color_text_prompt)
171
+ for style_token in style_tokens:
172
+ region_target_token_ids[-1].append(
173
+ base_tokens.index(style_token)+1)
174
+
175
+ # process the remaining tokens without any attributes
176
+ region_text_prompts.append(base_text_prompt)
177
+ region_target_token_ids_all = [
178
+ id for ids in region_target_token_ids for id in ids]
179
+ target_token_ids_rest = [id for id in range(
180
+ 1, len(base_tokens)+1) if id not in region_target_token_ids_all]
181
+ region_target_token_ids.append(target_token_ids_rest)
182
+
183
+ region_target_token_ids = [torch.LongTensor(
184
+ obj_token_id) for obj_token_id in region_target_token_ids]
185
+ return region_text_prompts, region_target_token_ids, base_tokens
186
+
187
+
188
+ def get_attention_control_input(model, base_tokens, size_text_prompts_and_sizes):
189
+ r"""
190
+ Control the token impact using font sizes.
191
+ """
192
+ word_pos = []
193
+ font_sizes = []
194
+ for text_prompt, font_size in size_text_prompts_and_sizes:
195
+ size_tokens = model.tokenizer._tokenize(text_prompt)
196
+ for size_token in size_tokens:
197
+ word_pos.append(base_tokens.index(size_token)+1)
198
+ font_sizes.append(font_size)
199
+ if len(word_pos) > 0:
200
+ word_pos = torch.LongTensor(word_pos).cuda()
201
+ font_sizes = torch.FloatTensor(font_sizes).cuda()
202
+ else:
203
+ word_pos = None
204
+ font_sizes = None
205
+ text_format_dict = {
206
+ 'word_pos': word_pos,
207
+ 'font_size': font_sizes,
208
+ }
209
+ return text_format_dict
210
+
211
+
212
+ def get_gradient_guidance_input(model, base_tokens, color_text_prompts, color_rgbs, text_format_dict,
213
+ guidance_start_step=999, color_guidance_weight=1):
214
+ r"""
215
+ Control the token color using gradient guidance.
216
+ """
217
+ color_target_token_ids = []
218
+ for text_prompt in color_text_prompts:
219
+ color_target_token_ids.append([])
220
+ color_tokens = model.tokenizer._tokenize(text_prompt)
221
+ for color_token in color_tokens:
222
+ color_target_token_ids[-1].append(base_tokens.index(color_token)+1)
223
+ color_target_token_ids_all = [
224
+ id for ids in color_target_token_ids for id in ids]
225
+ color_target_token_ids_rest = [id for id in range(
226
+ 1, len(base_tokens)+1) if id not in color_target_token_ids_all]
227
+ color_target_token_ids.append(color_target_token_ids_rest)
228
+ color_target_token_ids = [torch.LongTensor(
229
+ obj_token_id) for obj_token_id in color_target_token_ids]
230
+
231
+ text_format_dict['target_RGB'] = color_rgbs
232
+ text_format_dict['guidance_start_step'] = guidance_start_step
233
+ text_format_dict['color_guidance_weight'] = color_guidance_weight
234
+ return text_format_dict, color_target_token_ids
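To make the data flow concrete, a small worked example of parse_json on a Delta like the one produced by the editor follows; the prompt is made up, and the expected outputs can be read off the code above (no color span is used, so nothing is moved to the GPU).

# Worked example with an illustrative prompt.
from utils.richtext_utils import parse_json

delta = {'ops': [
    {'insert': 'a dragon', 'attributes': {'font': 'slabo'}},  # styled span -> Van Gogh
    {'insert': ' flying over the meadow\n'},                  # plain span
]}

(base_prompt, style_prompts, footnotes, footnote_tokens,
 color_prompts, color_names, color_rgbs, size_prompts, use_grad) = parse_json(delta)

# base_prompt   == 'a dragon flying over the meadow'
# style_prompts == ['a dragon in the style of Vincent Van Gogh']
# use_grad      == False  (no color attribute, so gradient guidance stays off)
# every other list is empty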