laceymac songweig committed on
Commit
ff4a299
·
0 Parent(s):

Duplicate from songweig/rich-text-to-image


Co-authored-by: Songwei Ge <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
+ venv
+ __pycache__/
+ *.pyc
+ *.png
+ *.jpg
+ gradio_cached_examples/
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Rich Text To Image
+ emoji: 🌍
+ colorFrom: indigo
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 3.27.0
+ app_file: app.py
+ pinned: false
+ duplicated_from: songweig/rich-text-to-image
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,514 @@
1
+ import math
2
+ import random
3
+ import os
4
+ import json
5
+ import time
6
+ import argparse
7
+ import torch
8
+ import numpy as np
9
+ from torchvision import transforms
10
+
11
+ from models.region_diffusion_xl import RegionDiffusionXL
12
+ from utils.attention_utils import get_token_maps
13
+ from utils.richtext_utils import seed_everything, parse_json, get_region_diffusion_input,\
14
+ get_attention_control_input, get_gradient_guidance_input
15
+
16
+
17
+ import gradio as gr
18
+ from PIL import Image, ImageOps
19
+ from share_btn import community_icon_html, loading_icon_html, share_js, css
20
+
21
+
22
+ help_text = """
23
+ If you are encountering an error or not achieving your desired outcome, here are some potential reasons and recommendations to consider:
24
+ 1. If you format only a portion of a word rather than the complete word, an error may occur.
25
+ 2. If you use font color and get completely corrupted results, you may consider decreasing the color weight lambda.
26
+ 3. Consider using a different seed.
27
+ """
28
+
29
+
30
+ canvas_html = """<iframe id='rich-text-root' style='width:100%' height='360px' src='file=rich-text-to-json-iframe.html' frameborder='0' scrolling='no'></iframe>"""
31
+ get_js_data = """
32
+ async (text_input, negative_prompt, num_segments, segment_threshold, inject_interval, inject_background, seed, color_guidance_weight, rich_text_input, height, width, steps, guidance_weights) => {
33
+ const richEl = document.getElementById("rich-text-root");
34
+ const data = richEl? richEl.contentDocument.body._data : {};
35
+ return [text_input, negative_prompt, num_segments, segment_threshold, inject_interval, inject_background, seed, color_guidance_weight, JSON.stringify(data), height, width, steps, guidance_weights];
36
+ }
37
+ """
38
+ set_js_data = """
39
+ async (text_input) => {
40
+ const richEl = document.getElementById("rich-text-root");
41
+ const data = text_input ? JSON.parse(text_input) : null;
42
+ if (richEl && data) richEl.contentDocument.body.setQuillContents(data);
43
+ }
44
+ """
45
+
46
+ get_window_url_params = """
47
+ async (url_params) => {
48
+ const params = new URLSearchParams(window.location.search);
49
+ url_params = Object.fromEntries(params);
50
+ return [url_params];
51
+ }
52
+ """
53
+
54
+
55
+ def load_url_params(url_params):
56
+ if 'prompt' in url_params:
57
+ return gr.update(visible=True), url_params
58
+ else:
59
+ return gr.update(visible=False), url_params
60
+
61
+
62
+ def main():
63
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
64
+ model = RegionDiffusionXL()
65
+
66
+ def generate(
67
+ text_input: str,
68
+ negative_text: str,
69
+ num_segments: int,
70
+ segment_threshold: float,
71
+ inject_interval: float,
72
+ inject_background: float,
73
+ seed: int,
74
+ color_guidance_weight: float,
75
+ rich_text_input: str,
76
+ height: int,
77
+ width: int,
78
+ steps: int,
79
+ guidance_weight: float,
80
+ ):
81
+ run_dir = 'results/'
82
+ os.makedirs(run_dir, exist_ok=True)
83
+ # Load region diffusion model.
84
+ height = int(height) if height else 1024
85
+ width = int(width) if width else 1024
86
+ steps = 41 if not steps else steps
87
+ guidance_weight = 8.5 if not guidance_weight else guidance_weight
88
+ text_input = rich_text_input if rich_text_input != '' and rich_text_input is not None else text_input
89
+ print('text_input', text_input, width, height, steps, guidance_weight, num_segments, segment_threshold, inject_interval, inject_background, color_guidance_weight, negative_text)
90
+ if (text_input == '' or rich_text_input == ''):
91
+ raise gr.Error("Please enter some text.")
92
+ # parse json to span attributes
93
+ base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
94
+ color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance = parse_json(
95
+ json.loads(text_input))
96
+
97
+ # create control input for region diffusion
98
+ region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
99
+ model, base_text_prompt, style_text_prompts, footnote_text_prompts,
100
+ footnote_target_tokens, color_text_prompts, color_names)
101
+
102
+ # create control input for cross attention
103
+ text_format_dict = get_attention_control_input(
104
+ model, base_tokens, size_text_prompts_and_sizes)
105
+
106
+ # create control input for region guidance
107
+ text_format_dict, color_target_token_ids = get_gradient_guidance_input(
108
+ model, base_tokens, color_text_prompts, color_rgbs, text_format_dict, color_guidance_weight=color_guidance_weight)
109
+
110
+ seed_everything(seed)
111
+
112
+ # get token maps from plain text to image generation.
113
+ begin_time = time.time()
114
+ if model.selfattn_maps is None and model.crossattn_maps is None:
115
+ model.remove_tokenmap_hooks()
116
+ model.register_tokenmap_hooks()
117
+ else:
118
+ model.remove_tokenmap_hooks()
119
+ model.remove_tokenmap_hooks()
120
+ plain_img = model.sample([base_text_prompt], negative_prompt=[negative_text],
121
+ height=height, width=width, num_inference_steps=steps,
122
+ guidance_scale=guidance_weight, run_rich_text=False)
123
+ print('time elapsed to get attention maps: %.4f' %
124
+ (time.time()-begin_time))
125
+ seed_everything(seed)
126
+ color_obj_masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
127
+ 1024//8, 1024//8, color_target_token_ids[:-1], seed,
128
+ base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
129
+ return_vis=True)
130
+ seed_everything(seed)
131
+ model.masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
132
+ 1024//8, 1024//8, region_target_token_ids[:-1], seed,
133
+ base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
134
+ return_vis=True)
135
+ color_obj_atten_all = torch.zeros_like(color_obj_masks[-1])
136
+ for obj_mask in color_obj_masks[:-1]:
137
+ color_obj_atten_all += obj_mask
138
+ color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
139
+ interpolation=transforms.InterpolationMode.BICUBIC,
140
+ antialias=True)
141
+ for color_obj_mask in color_obj_masks]
142
+ text_format_dict['color_obj_atten'] = color_obj_masks
143
+ text_format_dict['color_obj_atten_all'] = color_obj_atten_all
144
+ model.remove_tokenmap_hooks()
145
+
146
+ # generate image from rich text
147
+ begin_time = time.time()
148
+ seed_everything(seed)
149
+ rich_img = model.sample(region_text_prompts, negative_prompt=[negative_text],
150
+ height=height, width=width, num_inference_steps=steps,
151
+ guidance_scale=guidance_weight, use_guidance=use_grad_guidance,
152
+ text_format_dict=text_format_dict, inject_selfattn=inject_interval,
153
+ inject_background=inject_background, run_rich_text=True)
154
+ print('time elapsed to generate image from rich text: %.4f' %
155
+ (time.time()-begin_time))
156
+ return [plain_img.images[0], rich_img.images[0], segments_vis, token_maps]
157
+
158
+ with gr.Blocks(css=css) as demo:
159
+ url_params = gr.JSON({}, visible=False, label="URL Params")
160
+ gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">Expressive Text-to-Image Generation with Rich Text</h1>
161
+ <p> <a href="https://songweige.github.io/">Songwei Ge</a>, <a href="https://taesung.me/">Taesung Park</a>, <a href="https://www.cs.cmu.edu/~junyanz/">Jun-Yan Zhu</a>, <a href="https://jbhuang0604.github.io/">Jia-Bin Huang</a> <p/>
162
+ <p> UMD, Adobe, CMU <p/>
163
+ <p> ICCV, 2023 <p/>
164
+ <p> <a href="https://huggingface.co/spaces/songweig/rich-text-to-image?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="display:inline;"alt="Duplicate Space"></a> | <a href="https://rich-text-to-image.github.io">[Website]</a> | <a href="https://github.com/SongweiGe/rich-text-to-image">[Code]</a> | <a href="https://arxiv.org/abs/2304.06720">[Paper]</a><p/>
165
+ <p> Our method is now using Stable Diffusion XL. For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.""")
166
+ with gr.Row():
167
+ with gr.Column():
168
+ rich_text_el = gr.HTML(canvas_html, elem_id="canvas_html")
169
+ rich_text_input = gr.Textbox(value="", visible=False)
170
+ text_input = gr.Textbox(
171
+ label='Rich-text JSON Input',
172
+ visible=False,
173
+ max_lines=1,
174
+ placeholder='Example: \'{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#b26b00"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background.\n"}]}\'',
175
+ elem_id="text_input"
176
+ )
177
+ negative_prompt = gr.Textbox(
178
+ label='Negative Prompt',
179
+ max_lines=1,
180
+ placeholder='Example: poor quality, blurry, dark, low resolution, low quality, worst quality',
181
+ elem_id="negative_prompt"
182
+ )
183
+ segment_threshold = gr.Slider(label='Token map threshold',
184
+ info='(See less area in token maps? Decrease this. See too much area? Increase this.)',
185
+ minimum=0,
186
+ maximum=1,
187
+ step=0.01,
188
+ value=0.25)
189
+ inject_interval = gr.Slider(label='Detail preservation',
190
+ info='(To preserve more structure from plain-text generation, increase this. To see more rich-text attributes, decrease this.)',
191
+ minimum=0,
192
+ maximum=1,
193
+ step=0.01,
194
+ value=0.)
195
+ inject_background = gr.Slider(label='Unformatted token preservation',
196
+ info='(To leave tokens without any rich-text attributes less affected, increase this.)',
197
+ minimum=0,
198
+ maximum=1,
199
+ step=0.01,
200
+ value=0.3)
201
+ color_guidance_weight = gr.Slider(label='Color weight',
202
+ info='(To obtain more precise color, increase this; too large a value may cause artifacts.)',
203
+ minimum=0,
204
+ maximum=2,
205
+ step=0.1,
206
+ value=0.5)
207
+ num_segments = gr.Slider(label='Number of segments',
208
+ minimum=2,
209
+ maximum=20,
210
+ step=1,
211
+ value=9)
212
+ seed = gr.Slider(label='Seed',
213
+ minimum=0,
214
+ maximum=100000,
215
+ step=1,
216
+ value=6,
217
+ elem_id="seed"
218
+ )
219
+ with gr.Accordion('Other Parameters', open=False):
220
+ steps = gr.Slider(label='Number of Steps',
221
+ minimum=0,
222
+ maximum=500,
223
+ step=1,
224
+ value=41)
225
+ guidance_weight = gr.Slider(label='CFG weight',
226
+ minimum=0,
227
+ maximum=50,
228
+ step=0.1,
229
+ value=8.5)
230
+ width = gr.Dropdown(choices=[1024],
231
+ value=1024,
232
+ label='Width',
233
+ visible=True)
234
+ height = gr.Dropdown(choices=[1024],
235
+ value=1024,
236
+ label='Height',
237
+ visible=True)
238
+
239
+ with gr.Row():
240
+ with gr.Column(scale=1, min_width=100):
241
+ generate_button = gr.Button("Generate")
242
+ load_params_button = gr.Button(
243
+ "Load from URL Params", visible=True)
244
+ with gr.Column():
245
+ richtext_result = gr.Image(
246
+ label='Rich-text', elem_id="rich-text-image")
247
+ richtext_result.style(height=784)
248
+ with gr.Row():
249
+ plaintext_result = gr.Image(
250
+ label='Plain-text', elem_id="plain-text-image")
251
+ segments = gr.Image(label='Segmentation')
252
+ with gr.Row():
253
+ token_map = gr.Image(label='Token Maps')
254
+ with gr.Row(visible=False) as share_row:
255
+ with gr.Group(elem_id="share-btn-container"):
256
+ community_icon = gr.HTML(community_icon_html)
257
+ loading_icon = gr.HTML(loading_icon_html)
258
+ share_button = gr.Button(
259
+ "Share to community", elem_id="share-btn")
260
+ share_button.click(None, [], [], _js=share_js)
261
+ # with gr.Row():
262
+ # gr.Markdown(help_text)
263
+
264
+ with gr.Row():
265
+ footnote_examples = [
266
+ [
267
+ '{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. Palm trees in the background."}]}',
268
+ '',
269
+ 9,
270
+ 0.3,
271
+ 0.3,
272
+ 0.5,
273
+ 3,
274
+ 0,
275
+ None,
276
+ ],
277
+ [
278
+ '{"ops":[{"insert":"A cozy "},{"attributes":{"link":"A charming wooden cabin with Christmas decoration, warm light coming out from the windows."},"insert":"cabin"},{"insert":" nestled in a "},{"attributes":{"link":"Towering evergreen trees covered in a thick layer of pristine snow."},"insert":"snowy forest"},{"insert":", and a "},{"attributes":{"link":"A cute snowman wearing a carrot nose, coal eyes, and a colorful scarf, welcoming visitors with a cheerful vibe."},"insert":"snowman"},{"insert":" stands in the yard."}]}',
279
+ '',
280
+ 12,
281
+ 0.4,
282
+ 0.3,
283
+ 0.5,
284
+ 3,
285
+ 0,
286
+ None,
287
+ ],
288
+ [
289
+ '{"ops":[{"insert":"A "},{"attributes":{"link":"Happy Kung fu panda art, elder, asian art, volumetric lighting, dramatic scene, ultra detailed, realism, chinese"},"insert":"panda"},{"insert":" standing on a cliff by a waterfall, wildlife photography, photograph, high quality, wildlife, f 1.8, soft focus, 8k, national geographic, award - winning photograph by nick nichols"}]}',
290
+ '',
291
+ 5,
292
+ 0.3,
293
+ 0,
294
+ 0.1,
295
+ 4,
296
+ 0,
297
+ None,
298
+ ],
299
+ ]
300
+
301
+ gr.Examples(examples=footnote_examples,
302
+ label='Footnote examples',
303
+ inputs=[
304
+ text_input,
305
+ negative_prompt,
306
+ num_segments,
307
+ segment_threshold,
308
+ inject_interval,
309
+ inject_background,
310
+ seed,
311
+ color_guidance_weight,
312
+ rich_text_input,
313
+ ],
314
+ outputs=[
315
+ plaintext_result,
316
+ richtext_result,
317
+ segments,
318
+ token_map,
319
+ ],
320
+ fn=generate,
321
+ cache_examples=True,
322
+ examples_per_page=20)
323
+ with gr.Row():
324
+ color_examples = [
325
+ [
326
+ '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#04a704"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
327
+ 'lowres, had anatomy, bad hands, cropped, worst quality',
328
+ 11,
329
+ 0.5,
330
+ 0.3,
331
+ 0.3,
332
+ 6,
333
+ 0.5,
334
+ None,
335
+ ],
336
+ [
337
+ '{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
338
+ '',
339
+ 10,
340
+ 0.5,
341
+ 0.5,
342
+ 0.3,
343
+ 7,
344
+ 0.5,
345
+ None,
346
+ ],
347
+ ]
348
+ gr.Examples(examples=color_examples,
349
+ label='Font color examples',
350
+ inputs=[
351
+ text_input,
352
+ negative_prompt,
353
+ num_segments,
354
+ segment_threshold,
355
+ inject_interval,
356
+ inject_background,
357
+ seed,
358
+ color_guidance_weight,
359
+ rich_text_input,
360
+ ],
361
+ outputs=[
362
+ plaintext_result,
363
+ richtext_result,
364
+ segments,
365
+ token_map,
366
+ ],
367
+ fn=generate,
368
+ cache_examples=True,
369
+ examples_per_page=20)
370
+
371
+ with gr.Row():
372
+ style_examples = [
373
+ [
374
+ '{"ops":[{"insert":"a beautiful"},{"attributes":{"font":"mirza"},"insert":" garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain"},{"insert":" in the background"}]}',
375
+ '',
376
+ 10,
377
+ 0.6,
378
+ 0,
379
+ 0.4,
380
+ 5,
381
+ 0,
382
+ None,
383
+ ],
384
+ [
385
+ '{"ops":[{"insert":"a night"},{"attributes":{"font":"slabo"},"insert":" sky"},{"insert":" filled with stars above a turbulent"},{"attributes":{"font":"roboto"},"insert":" sea"},{"insert":" with giant waves"}]}',
386
+ '',
387
+ 2,
388
+ 0.6,
389
+ 0,
390
+ 0,
391
+ 6,
392
+ 0.5,
393
+ None,
394
+ ],
395
+ ]
396
+ gr.Examples(examples=style_examples,
397
+ label='Font style examples',
398
+ inputs=[
399
+ text_input,
400
+ negative_prompt,
401
+ num_segments,
402
+ segment_threshold,
403
+ inject_interval,
404
+ inject_background,
405
+ seed,
406
+ color_guidance_weight,
407
+ rich_text_input,
408
+ ],
409
+ outputs=[
410
+ plaintext_result,
411
+ richtext_result,
412
+ segments,
413
+ token_map,
414
+ ],
415
+ fn=generate,
416
+ cache_examples=True,
417
+ examples_per_page=20)
418
+
419
+ with gr.Row():
420
+ size_examples = [
421
+ [
422
+ '{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "60px"}, "insert": "pineapple"}, {"insert": " pepperoni, and mushroom on the top"}]}',
423
+ '',
424
+ 5,
425
+ 0.3,
426
+ 0,
427
+ 0,
428
+ 3,
429
+ 1,
430
+ None,
431
+ ],
432
+ [
433
+ '{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "60px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top"}]}',
434
+ '',
435
+ 5,
436
+ 0.3,
437
+ 0,
438
+ 0,
439
+ 3,
440
+ 1,
441
+ None,
442
+ ],
443
+ [
444
+ '{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "60px"}, "insert": "mushroom"}, {"insert": " on the top"}]}',
445
+ '',
446
+ 5,
447
+ 0.3,
448
+ 0,
449
+ 0,
450
+ 3,
451
+ 1,
452
+ None,
453
+ ],
454
+ ]
455
+ gr.Examples(examples=size_examples,
456
+ label='Font size examples',
457
+ inputs=[
458
+ text_input,
459
+ negative_prompt,
460
+ num_segments,
461
+ segment_threshold,
462
+ inject_interval,
463
+ inject_background,
464
+ seed,
465
+ color_guidance_weight,
466
+ rich_text_input,
467
+ ],
468
+ outputs=[
469
+ plaintext_result,
470
+ richtext_result,
471
+ segments,
472
+ token_map,
473
+ ],
474
+ fn=generate,
475
+ cache_examples=True,
476
+ examples_per_page=20)
477
+ generate_button.click(fn=lambda: gr.update(visible=False), inputs=None, outputs=share_row, queue=False).then(
478
+ fn=generate,
479
+ inputs=[
480
+ text_input,
481
+ negative_prompt,
482
+ num_segments,
483
+ segment_threshold,
484
+ inject_interval,
485
+ inject_background,
486
+ seed,
487
+ color_guidance_weight,
488
+ rich_text_input,
489
+ height,
490
+ width,
491
+ steps,
492
+ guidance_weight,
493
+ ],
494
+ outputs=[plaintext_result, richtext_result, segments, token_map],
495
+ _js=get_js_data
496
+ ).then(
497
+ fn=lambda: gr.update(visible=True), inputs=None, outputs=share_row, queue=False)
498
+ text_input.change(
499
+ fn=None, inputs=[text_input], outputs=None, _js=set_js_data, queue=False)
500
+ # load url param prompt to textinput
501
+ load_params_button.click(fn=lambda x: x['prompt'], inputs=[
502
+ url_params], outputs=[text_input], queue=False)
503
+ demo.load(
504
+ fn=load_url_params,
505
+ inputs=[url_params],
506
+ outputs=[load_params_button, url_params],
507
+ _js=get_window_url_params
508
+ )
509
+ demo.queue(concurrency_count=1)
510
+ demo.launch(share=False)
511
+
512
+
513
+ if __name__ == "__main__":
514
+ main()
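
Note on the file above: the `text_input` consumed by `generate()` is a Quill Delta JSON string, as the placeholder and `gr.Examples` entries show. The following minimal sketch (not part of the commit) builds such a string programmatically; the attribute names (`color`, `link`, `font`, `size`) are taken from the examples in this file, and the prompt itself is hypothetical.

import json

# Hypothetical rich-text prompt: the word "church" carries a font-color attribute.
delta = {"ops": [
    {"insert": "a Gothic "},
    {"attributes": {"color": "#b26b00"}, "insert": "church"},
    {"insert": " in the sunset with a beautiful landscape in the background.\n"},
]}
text_input = json.dumps(delta)  # generate() later calls parse_json(json.loads(text_input))
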
app_sd.py ADDED
@@ -0,0 +1,557 @@
1
+ import math
2
+ import random
3
+ import os
4
+ import json
5
+ import time
6
+ import argparse
7
+ import torch
8
+ import numpy as np
9
+ from torchvision import transforms
10
+
11
+ from models.region_diffusion import RegionDiffusion
12
+ from utils.attention_utils import get_token_maps
13
+ from utils.richtext_utils import seed_everything, parse_json, get_region_diffusion_input,\
14
+ get_attention_control_input, get_gradient_guidance_input
15
+
16
+
17
+ import gradio as gr
18
+ from PIL import Image, ImageOps
19
+ from share_btn import community_icon_html, loading_icon_html, share_js, css
20
+
21
+
22
+ help_text = """
23
+ If you are encountering an error or not achieving your desired outcome, here are some potential reasons and recommendations to consider:
24
+ 1. If you format only a portion of a word rather than the complete word, an error may occur.
25
+ 2. If you use font color and get completely corrupted results, you may consider decreasing the color weight lambda.
26
+ 3. Consider using a different seed.
27
+ """
28
+
29
+
30
+ canvas_html = """<iframe id='rich-text-root' style='width:100%' height='360px' src='file=rich-text-to-json-iframe.html' frameborder='0' scrolling='no'></iframe>"""
31
+ get_js_data = """
32
+ async (text_input, negative_prompt, height, width, seed, steps, num_segments, segment_threshold, inject_interval, guidance_weight, color_guidance_weight, rich_text_input, background_aug) => {
33
+ const richEl = document.getElementById("rich-text-root");
34
+ const data = richEl? richEl.contentDocument.body._data : {};
35
+ return [text_input, negative_prompt, height, width, seed, steps, num_segments, segment_threshold, inject_interval, guidance_weight, color_guidance_weight, JSON.stringify(data), background_aug];
36
+ }
37
+ """
38
+ set_js_data = """
39
+ async (text_input) => {
40
+ const richEl = document.getElementById("rich-text-root");
41
+ const data = text_input ? JSON.parse(text_input) : null;
42
+ if (richEl && data) richEl.contentDocument.body.setQuillContents(data);
43
+ }
44
+ """
45
+
46
+ get_window_url_params = """
47
+ async (url_params) => {
48
+ const params = new URLSearchParams(window.location.search);
49
+ url_params = Object.fromEntries(params);
50
+ return [url_params];
51
+ }
52
+ """
53
+
54
+
55
+ def load_url_params(url_params):
56
+ if 'prompt' in url_params:
57
+ return gr.update(visible=True), url_params
58
+ else:
59
+ return gr.update(visible=False), url_params
60
+
61
+
62
+ def main():
63
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
64
+ model = RegionDiffusion(device)
65
+
66
+ def generate(
67
+ text_input: str,
68
+ negative_text: str,
69
+ height: int,
70
+ width: int,
71
+ seed: int,
72
+ steps: int,
73
+ num_segments: int,
74
+ segment_threshold: float,
75
+ inject_interval: float,
76
+ guidance_weight: float,
77
+ color_guidance_weight: float,
78
+ rich_text_input: str,
79
+ background_aug: bool,
80
+ ):
81
+ run_dir = 'results/'
82
+ os.makedirs(run_dir, exist_ok=True)
83
+ # Load region diffusion model.
84
+ height = int(height)
85
+ width = int(width)
86
+ steps = 41 if not steps else steps
87
+ guidance_weight = 8.5 if not guidance_weight else guidance_weight
88
+ text_input = rich_text_input if rich_text_input != '' else text_input
89
+ print('text_input', text_input)
90
+ if (text_input == '' or rich_text_input == ''):
91
+ raise gr.Error("Please enter some text.")
92
+ # parse json to span attributes
93
+ base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
94
+ color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance = parse_json(
95
+ json.loads(text_input))
96
+
97
+ # create control input for region diffusion
98
+ region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
99
+ model, base_text_prompt, style_text_prompts, footnote_text_prompts,
100
+ footnote_target_tokens, color_text_prompts, color_names)
101
+
102
+ # create control input for cross attention
103
+ text_format_dict = get_attention_control_input(
104
+ model, base_tokens, size_text_prompts_and_sizes)
105
+
106
+ # create control input for region guidance
107
+ text_format_dict, color_target_token_ids = get_gradient_guidance_input(
108
+ model, base_tokens, color_text_prompts, color_rgbs, text_format_dict, color_guidance_weight=color_guidance_weight)
109
+
110
+ seed_everything(seed)
111
+
112
+ # get token maps from plain text to image generation.
113
+ begin_time = time.time()
114
+ if model.selfattn_maps is None and model.crossattn_maps is None:
115
+ model.remove_tokenmap_hooks()
116
+ model.register_tokenmap_hooks()
117
+ else:
118
+ model.reset_attention_maps()
119
+ model.remove_tokenmap_hooks()
120
+ plain_img = model.produce_attn_maps([base_text_prompt], [negative_text],
121
+ height=height, width=width, num_inference_steps=steps,
122
+ guidance_scale=guidance_weight)
123
+ print('time elapsed to get attention maps: %.4f' %
124
+ (time.time()-begin_time))
125
+ seed_everything(seed)
126
+ color_obj_masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
127
+ 512//8, 512//8, color_target_token_ids[:-1], seed,
128
+ base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
129
+ return_vis=True)
130
+ seed_everything(seed)
131
+ model.masks, segments_vis, token_maps = get_token_maps(model.selfattn_maps, model.crossattn_maps, model.n_maps, run_dir,
132
+ 512//8, 512//8, region_target_token_ids[:-1], seed,
133
+ base_tokens, segment_threshold=segment_threshold, num_segments=num_segments,
134
+ return_vis=True)
135
+ color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
136
+ interpolation=transforms.InterpolationMode.BICUBIC,
137
+ antialias=True)
138
+ for color_obj_mask in color_obj_masks]
139
+ text_format_dict['color_obj_atten'] = color_obj_masks
140
+ model.remove_tokenmap_hooks()
141
+
142
+ # generate image from rich text
143
+ begin_time = time.time()
144
+ seed_everything(seed)
145
+ if background_aug:
146
+ bg_aug_end = 500
147
+ else:
148
+ bg_aug_end = 1000
149
+ rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
150
+ height=height, width=width, num_inference_steps=steps,
151
+ guidance_scale=guidance_weight, use_guidance=use_grad_guidance,
152
+ text_format_dict=text_format_dict, inject_selfattn=inject_interval,
153
+ bg_aug_end=bg_aug_end)
154
+ print('time elapsed to generate image from rich text: %.4f' %
155
+ (time.time()-begin_time))
156
+ return [plain_img[0], rich_img[0], segments_vis, token_maps]
157
+
158
+ with gr.Blocks(css=css) as demo:
159
+ url_params = gr.JSON({}, visible=False, label="URL Params")
160
+ gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">Expressive Text-to-Image Generation with Rich Text</h1>
161
+ <p> <a href="https://songweige.github.io/">Songwei Ge</a>, <a href="https://taesung.me/">Taesung Park</a>, <a href="https://www.cs.cmu.edu/~junyanz/">Jun-Yan Zhu</a>, <a href="https://jbhuang0604.github.io/">Jia-Bin Huang</a> <p/>
162
+ <p> UMD, Adobe, CMU <p/>
163
+ <p> <a href="https://huggingface.co/spaces/songweig/rich-text-to-image?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="display:inline;"alt="Duplicate Space"></a> | <a href="https://rich-text-to-image.github.io">[Website]</a> | <a href="https://github.com/SongweiGe/rich-text-to-image">[Code]</a> | <a href="https://arxiv.org/abs/2304.06720">[Paper]</a><p/>
164
+ <p> For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.""")
165
+ with gr.Row():
166
+ with gr.Column():
167
+ rich_text_el = gr.HTML(canvas_html, elem_id="canvas_html")
168
+ rich_text_input = gr.Textbox(value="", visible=False)
169
+ text_input = gr.Textbox(
170
+ label='Rich-text JSON Input',
171
+ visible=False,
172
+ max_lines=1,
173
+ placeholder='Example: \'{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#b26b00"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background.\n"}]}\'',
174
+ elem_id="text_input"
175
+ )
176
+ negative_prompt = gr.Textbox(
177
+ label='Negative Prompt',
178
+ max_lines=1,
179
+ placeholder='Example: poor quality, blurry, dark, low resolution, low quality, worst quality',
180
+ elem_id="negative_prompt"
181
+ )
182
+ segment_threshold = gr.Slider(label='Token map threshold',
183
+ info='(See less area in token maps? Decrease this. See too much area? Increase this.)',
184
+ minimum=0,
185
+ maximum=1,
186
+ step=0.01,
187
+ value=0.25)
188
+ inject_interval = gr.Slider(label='Detail preservation',
189
+ info='(To preserve more structure from plain-text generation, increase this. To see more rich-text attributes, decrease this.)',
190
+ minimum=0,
191
+ maximum=1,
192
+ step=0.01,
193
+ value=0.)
194
+ color_guidance_weight = gr.Slider(label='Color weight',
195
+ info='(To obtain more precise color, increase this; too large a value may cause artifacts.)',
196
+ minimum=0,
197
+ maximum=2,
198
+ step=0.1,
199
+ value=0.5)
200
+ num_segments = gr.Slider(label='Number of segments',
201
+ minimum=2,
202
+ maximum=20,
203
+ step=1,
204
+ value=9)
205
+ seed = gr.Slider(label='Seed',
206
+ minimum=0,
207
+ maximum=100000,
208
+ step=1,
209
+ value=6,
210
+ elem_id="seed"
211
+ )
212
+ background_aug = gr.Checkbox(
213
+ label='Precise region alignment',
214
+ info='(For strict region alignment, select this option, but beware of potential artifacts when using with style.)',
215
+ value=True)
216
+ with gr.Accordion('Other Parameters', open=False):
217
+ steps = gr.Slider(label='Number of Steps',
218
+ minimum=0,
219
+ maximum=500,
220
+ step=1,
221
+ value=41)
222
+ guidance_weight = gr.Slider(label='CFG weight',
223
+ minimum=0,
224
+ maximum=50,
225
+ step=0.1,
226
+ value=8.5)
227
+ width = gr.Dropdown(choices=[512],
228
+ value=512,
229
+ label='Width',
230
+ visible=True)
231
+ height = gr.Dropdown(choices=[512],
232
+ value=512,
233
+ label='Height',
234
+ visible=True)
235
+
236
+ with gr.Row():
237
+ with gr.Column(scale=1, min_width=100):
238
+ generate_button = gr.Button("Generate")
239
+ load_params_button = gr.Button(
240
+ "Load from URL Params", visible=True)
241
+ with gr.Column():
242
+ richtext_result = gr.Image(
243
+ label='Rich-text', elem_id="rich-text-image")
244
+ richtext_result.style(height=512)
245
+ with gr.Row():
246
+ plaintext_result = gr.Image(
247
+ label='Plain-text', elem_id="plain-text-image")
248
+ segments = gr.Image(label='Segmentation')
249
+ with gr.Row():
250
+ token_map = gr.Image(label='Token Maps')
251
+ with gr.Row(visible=False) as share_row:
252
+ with gr.Group(elem_id="share-btn-container"):
253
+ community_icon = gr.HTML(community_icon_html)
254
+ loading_icon = gr.HTML(loading_icon_html)
255
+ share_button = gr.Button(
256
+ "Share to community", elem_id="share-btn")
257
+ share_button.click(None, [], [], _js=share_js)
258
+ with gr.Row():
259
+ gr.Markdown(help_text)
260
+
261
+ with gr.Row():
262
+ footnote_examples = [
263
+ [
264
+ '{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. Palm trees in the background."}]}',
265
+ '',
266
+ 5,
267
+ 0.3,
268
+ 0,
269
+ 6,
270
+ 1,
271
+ None,
272
+ True
273
+ ],
274
+ [
275
+ '{"ops":[{"insert":"A "},{"attributes":{"link":"kitchen island with a stove with gas burners and a built-in oven "},"insert":"kitchen island"},{"insert":" next to a "},{"attributes":{"link":"an open refrigerator stocked with fresh produce, dairy products, and beverages. "},"insert":"refrigerator"},{"insert":", by James McDonald and Joarc Architects, home, interior, octane render, deviantart, cinematic, key art, hyperrealism, sun light, sunrays, canon eos c 300, ƒ 1.8, 35 mm, 8k, medium - format print"}]}',
276
+ '',
277
+ 6,
278
+ 0.5,
279
+ 0,
280
+ 6,
281
+ 1,
282
+ None,
283
+ True
284
+ ],
285
+ [
286
+ '{"ops":[{"insert":"A "},{"attributes":{"link":"Happy Kung fu panda art, elder, asian art, volumetric lighting, dramatic scene, ultra detailed, realism, chinese"},"insert":"panda"},{"insert":" standing on a cliff by a waterfall, wildlife photography, photograph, high quality, wildlife, f 1.8, soft focus, 8k, national geographic, award - winning photograph by nick nichols"}]}',
287
+ '',
288
+ 4,
289
+ 0.3,
290
+ 0,
291
+ 4,
292
+ 1,
293
+ None,
294
+ True
295
+ ],
296
+ ]
297
+
298
+ gr.Examples(examples=footnote_examples,
299
+ label='Footnote examples',
300
+ inputs=[
301
+ text_input,
302
+ negative_prompt,
303
+ num_segments,
304
+ segment_threshold,
305
+ inject_interval,
306
+ seed,
307
+ color_guidance_weight,
308
+ rich_text_input,
309
+ background_aug,
310
+ ],
311
+ outputs=[
312
+ plaintext_result,
313
+ richtext_result,
314
+ segments,
315
+ token_map,
316
+ ],
317
+ fn=generate,
318
+ # cache_examples=True,
319
+ examples_per_page=20)
320
+ with gr.Row():
321
+ color_examples = [
322
+ [
323
+ '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#00ffff"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
324
+ 'lowres, had anatomy, bad hands, cropped, worst quality',
325
+ 9,
326
+ 0.25,
327
+ 0.3,
328
+ 6,
329
+ 0.5,
330
+ None,
331
+ True
332
+ ],
333
+ [
334
+ '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#eeeeee"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
335
+ 'lowres, had anatomy, bad hands, cropped, worst quality',
336
+ 9,
337
+ 0.25,
338
+ 0.3,
339
+ 6,
340
+ 0.1,
341
+ None,
342
+ True
343
+ ],
344
+ [
345
+ '{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
346
+ '',
347
+ 5,
348
+ 0.3,
349
+ 0.5,
350
+ 6,
351
+ 0.5,
352
+ None,
353
+ False
354
+ ],
355
+ [
356
+ '{"ops":[{"insert":"A mesmerizing sight that captures the beauty of a "},{"attributes":{"color":"#4775fc"},"insert":"rose"},{"insert":" blooming, close up"}]}',
357
+ '',
358
+ 3,
359
+ 0.3,
360
+ 0,
361
+ 9,
362
+ 1,
363
+ None,
364
+ False
365
+ ],
366
+ [
367
+ '{"ops":[{"insert":"A "},{"attributes":{"color":"#FFD700"},"insert":"marble statue of a wolf\'s head and shoulder"},{"insert":", surrounded by colorful flowers michelangelo, detailed, intricate, full of color, led lighting, trending on artstation, 4 k, hyperrealistic, 3 5 mm, focused, extreme details, unreal engine 5, masterpiece "}]}',
368
+ '',
369
+ 5,
370
+ 0.3,
371
+ 0,
372
+ 5,
373
+ 0.6,
374
+ None,
375
+ False
376
+ ],
377
+ ]
378
+ gr.Examples(examples=color_examples,
379
+ label='Font color examples',
380
+ inputs=[
381
+ text_input,
382
+ negative_prompt,
383
+ num_segments,
384
+ segment_threshold,
385
+ inject_interval,
386
+ seed,
387
+ color_guidance_weight,
388
+ rich_text_input,
389
+ background_aug,
390
+ ],
391
+ outputs=[
392
+ plaintext_result,
393
+ richtext_result,
394
+ segments,
395
+ token_map,
396
+ ],
397
+ fn=generate,
398
+ # cache_examples=True,
399
+ examples_per_page=20)
400
+
401
+ with gr.Row():
402
+ style_examples = [
403
+ [
404
+ '{"ops":[{"insert":"a "},{"attributes":{"font":"mirza"},"insert":"beautiful garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain in the background"},{"insert":""}]}',
405
+ '',
406
+ 10,
407
+ 0.45,
408
+ 0,
409
+ 0.2,
410
+ 3,
411
+ 0.5,
412
+ None,
413
+ False
414
+ ],
415
+ [
416
+ '{"ops":[{"attributes":{"link":"the awe-inspiring sky and ocean in the style of J.M.W. Turner"},"insert":"the awe-inspiring sky and sea"},{"insert":" by "},{"attributes":{"font":"mirza"},"insert":"a coast with flowers and grasses in spring"}]}',
417
+ 'worst quality, dark, poor quality',
418
+ 2,
419
+ 0.45,
420
+ 0,
421
+ 9,
422
+ 0.5,
423
+ None,
424
+ False
425
+ ],
426
+ [
427
+ '{"ops":[{"insert":"a "},{"attributes":{"font":"slabo"},"insert":"night sky filled with stars"},{"insert":" above a "},{"attributes":{"font":"roboto"},"insert":"turbulent sea with giant waves"}]}',
428
+ '',
429
+ 2,
430
+ 0.45,
431
+ 0,
432
+ 0,
433
+ 6,
434
+ 0.5,
435
+ None,
436
+ False
437
+ ],
438
+ ]
439
+ gr.Examples(examples=style_examples,
440
+ label='Font style examples',
441
+ inputs=[
442
+ text_input,
443
+ negative_prompt,
444
+ num_segments,
445
+ segment_threshold,
446
+ inject_interval,
447
+ seed,
448
+ color_guidance_weight,
449
+ rich_text_input,
450
+ background_aug,
451
+ ],
452
+ outputs=[
453
+ plaintext_result,
454
+ richtext_result,
455
+ segments,
456
+ token_map,
457
+ ],
458
+ fn=generate,
459
+ # cache_examples=True,
460
+ examples_per_page=20)
461
+
462
+ with gr.Row():
463
+ size_examples = [
464
+ [
465
+ '{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "60px"}, "insert": "pineapple"}, {"insert": ", pepperoni, and mushroom on the top, 4k, photorealistic"}]}',
466
+ 'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
467
+ 5,
468
+ 0.3,
469
+ 0,
470
+ 13,
471
+ 1,
472
+ None,
473
+ False
474
+ ],
475
+ [
476
+ '{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "20px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top, 4k, photorealistic"}]}',
477
+ 'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
478
+ 5,
479
+ 0.3,
480
+ 0,
481
+ 13,
482
+ 1,
483
+ None,
484
+ False
485
+ ],
486
+ [
487
+ '{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "70px"}, "insert": "mushroom"}, {"insert": " on the top, 4k, photorealistic"}]}',
488
+ 'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
489
+ 5,
490
+ 0.3,
491
+ 0,
492
+ 13,
493
+ 1,
494
+ None,
495
+ False
496
+ ],
497
+ ]
498
+ gr.Examples(examples=size_examples,
499
+ label='Font size examples',
500
+ inputs=[
501
+ text_input,
502
+ negative_prompt,
503
+ num_segments,
504
+ segment_threshold,
505
+ inject_interval,
506
+ seed,
507
+ color_guidance_weight,
508
+ rich_text_input,
509
+ background_aug,
510
+ ],
511
+ outputs=[
512
+ plaintext_result,
513
+ richtext_result,
514
+ segments,
515
+ token_map,
516
+ ],
517
+ fn=generate,
518
+ # cache_examples=True,
519
+ examples_per_page=20)
520
+ generate_button.click(fn=lambda: gr.update(visible=False), inputs=None, outputs=share_row, queue=False).then(
521
+ fn=generate,
522
+ inputs=[
523
+ text_input,
524
+ negative_prompt,
525
+ height,
526
+ width,
527
+ seed,
528
+ steps,
529
+ num_segments,
530
+ segment_threshold,
531
+ inject_interval,
532
+ guidance_weight,
533
+ color_guidance_weight,
534
+ rich_text_input,
535
+ background_aug
536
+ ],
537
+ outputs=[plaintext_result, richtext_result, segments, token_map],
538
+ _js=get_js_data
539
+ ).then(
540
+ fn=lambda: gr.update(visible=True), inputs=None, outputs=share_row, queue=False)
541
+ text_input.change(
542
+ fn=None, inputs=[text_input], outputs=None, _js=set_js_data, queue=False)
543
+ # load url param prompt to textinput
544
+ load_params_button.click(fn=lambda x: x['prompt'], inputs=[
545
+ url_params], outputs=[text_input], queue=False)
546
+ demo.load(
547
+ fn=load_url_params,
548
+ inputs=[url_params],
549
+ outputs=[load_params_button, url_params],
550
+ _js=get_window_url_params
551
+ )
552
+ demo.queue(concurrency_count=1)
553
+ demo.launch(share=False)
554
+
555
+
556
+ if __name__ == "__main__":
557
+ main()
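
Note on the file above: the SD (512x512) app drives the same two-pass flow as the XL app, a plain-text pass that records attention maps followed by a rich-text pass conditioned on the derived region masks. Below is a minimal sketch of the first pass, mirroring the calls made inside `generate()`; it is not part of the commit and assumes a CUDA GPU is available and that the repository's RegionDiffusion weights can be loaded. The prompt is hypothetical.

import torch
from models.region_diffusion import RegionDiffusion

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RegionDiffusion(device)

# Plain-text pass that also records self-/cross-attention maps via hooks,
# exactly as generate() does before calling get_token_maps().
model.register_tokenmap_hooks()
plain_img = model.produce_attn_maps(['a Gothic church in the sunset'], [''],
                                    height=512, width=512,
                                    num_inference_steps=41, guidance_scale=8.5)
model.remove_tokenmap_hooks()
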
models/attention.py ADDED
@@ -0,0 +1,391 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import maybe_allow_in_graph
21
+ from diffusers.models.activations import get_activation
22
+ from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
23
+
24
+ from models.attention_processor import Attention
25
+
26
+ @maybe_allow_in_graph
27
+ class BasicTransformerBlock(nn.Module):
28
+ r"""
29
+ A basic Transformer block.
30
+
31
+ Parameters:
32
+ dim (`int`): The number of channels in the input and output.
33
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
34
+ attention_head_dim (`int`): The number of channels in each head.
35
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
36
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
37
+ only_cross_attention (`bool`, *optional*):
38
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
39
+ double_self_attention (`bool`, *optional*):
40
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
41
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
42
+ num_embeds_ada_norm (:
43
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
44
+ attention_bias (:
45
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ dim: int,
51
+ num_attention_heads: int,
52
+ attention_head_dim: int,
53
+ dropout=0.0,
54
+ cross_attention_dim: Optional[int] = None,
55
+ activation_fn: str = "geglu",
56
+ num_embeds_ada_norm: Optional[int] = None,
57
+ attention_bias: bool = False,
58
+ only_cross_attention: bool = False,
59
+ double_self_attention: bool = False,
60
+ upcast_attention: bool = False,
61
+ norm_elementwise_affine: bool = True,
62
+ norm_type: str = "layer_norm",
63
+ final_dropout: bool = False,
64
+ ):
65
+ super().__init__()
66
+ self.only_cross_attention = only_cross_attention
67
+
68
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
69
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
70
+
71
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
72
+ raise ValueError(
73
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
74
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
75
+ )
76
+
77
+ # Define 3 blocks. Each block has its own normalization layer.
78
+ # 1. Self-Attn
79
+ if self.use_ada_layer_norm:
80
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
81
+ elif self.use_ada_layer_norm_zero:
82
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
83
+ else:
84
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
85
+ self.attn1 = Attention(
86
+ query_dim=dim,
87
+ heads=num_attention_heads,
88
+ dim_head=attention_head_dim,
89
+ dropout=dropout,
90
+ bias=attention_bias,
91
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
92
+ upcast_attention=upcast_attention,
93
+ )
94
+
95
+ # 2. Cross-Attn
96
+ if cross_attention_dim is not None or double_self_attention:
97
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
98
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
99
+ # the second cross attention block.
100
+ self.norm2 = (
101
+ AdaLayerNorm(dim, num_embeds_ada_norm)
102
+ if self.use_ada_layer_norm
103
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
104
+ )
105
+ self.attn2 = Attention(
106
+ query_dim=dim,
107
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
108
+ heads=num_attention_heads,
109
+ dim_head=attention_head_dim,
110
+ dropout=dropout,
111
+ bias=attention_bias,
112
+ upcast_attention=upcast_attention,
113
+ ) # is self-attn if encoder_hidden_states is none
114
+ else:
115
+ self.norm2 = None
116
+ self.attn2 = None
117
+
118
+ # 3. Feed-forward
119
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
120
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
121
+
122
+ # let chunk size default to None
123
+ self._chunk_size = None
124
+ self._chunk_dim = 0
125
+
126
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
127
+ # Sets chunk feed-forward
128
+ self._chunk_size = chunk_size
129
+ self._chunk_dim = dim
130
+
131
+ def forward(
132
+ self,
133
+ hidden_states: torch.FloatTensor,
134
+ attention_mask: Optional[torch.FloatTensor] = None,
135
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
136
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
137
+ timestep: Optional[torch.LongTensor] = None,
138
+ cross_attention_kwargs: Dict[str, Any] = None,
139
+ class_labels: Optional[torch.LongTensor] = None,
140
+ ):
141
+ # Notice that normalization is always applied before the real computation in the following blocks.
142
+ # 1. Self-Attention
143
+ if self.use_ada_layer_norm:
144
+ norm_hidden_states = self.norm1(hidden_states, timestep)
145
+ elif self.use_ada_layer_norm_zero:
146
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
147
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
148
+ )
149
+ else:
150
+ norm_hidden_states = self.norm1(hidden_states)
151
+
152
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
153
+
154
+ # Rich-Text: ignore the attention probs
155
+ attn_output, _ = self.attn1(
156
+ norm_hidden_states,
157
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
158
+ attention_mask=attention_mask,
159
+ **cross_attention_kwargs,
160
+ )
161
+ if self.use_ada_layer_norm_zero:
162
+ attn_output = gate_msa.unsqueeze(1) * attn_output
163
+ hidden_states = attn_output + hidden_states
164
+
165
+ # 2. Cross-Attention
166
+ if self.attn2 is not None:
167
+ norm_hidden_states = (
168
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
169
+ )
170
+
171
+ # Rich-Text: ignore the attention probs
172
+ attn_output, _ = self.attn2(
173
+ norm_hidden_states,
174
+ encoder_hidden_states=encoder_hidden_states,
175
+ attention_mask=encoder_attention_mask,
176
+ **cross_attention_kwargs,
177
+ )
178
+ hidden_states = attn_output + hidden_states
179
+
180
+ # 3. Feed-forward
181
+ norm_hidden_states = self.norm3(hidden_states)
182
+
183
+ if self.use_ada_layer_norm_zero:
184
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
185
+
186
+ if self._chunk_size is not None:
187
+ # "feed_forward_chunk_size" can be used to save memory
188
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
189
+ raise ValueError(
190
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
191
+ )
192
+
193
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
194
+ ff_output = torch.cat(
195
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
196
+ dim=self._chunk_dim,
197
+ )
198
+ else:
199
+ ff_output = self.ff(norm_hidden_states)
200
+
201
+ if self.use_ada_layer_norm_zero:
202
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
203
+
204
+ hidden_states = ff_output + hidden_states
205
+
206
+ return hidden_states
207
+
208
+
209
+ class FeedForward(nn.Module):
210
+ r"""
211
+ A feed-forward layer.
212
+
213
+ Parameters:
214
+ dim (`int`): The number of channels in the input.
215
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
216
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
217
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
218
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
219
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
220
+ """
221
+
222
+ def __init__(
223
+ self,
224
+ dim: int,
225
+ dim_out: Optional[int] = None,
226
+ mult: int = 4,
227
+ dropout: float = 0.0,
228
+ activation_fn: str = "geglu",
229
+ final_dropout: bool = False,
230
+ ):
231
+ super().__init__()
232
+ inner_dim = int(dim * mult)
233
+ dim_out = dim_out if dim_out is not None else dim
234
+
235
+ if activation_fn == "gelu":
236
+ act_fn = GELU(dim, inner_dim)
237
+ if activation_fn == "gelu-approximate":
238
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
239
+ elif activation_fn == "geglu":
240
+ act_fn = GEGLU(dim, inner_dim)
241
+ elif activation_fn == "geglu-approximate":
242
+ act_fn = ApproximateGELU(dim, inner_dim)
243
+
244
+ self.net = nn.ModuleList([])
245
+ # project in
246
+ self.net.append(act_fn)
247
+ # project dropout
248
+ self.net.append(nn.Dropout(dropout))
249
+ # project out
250
+ self.net.append(nn.Linear(inner_dim, dim_out))
251
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
252
+ if final_dropout:
253
+ self.net.append(nn.Dropout(dropout))
254
+
255
+ def forward(self, hidden_states):
256
+ for module in self.net:
257
+ hidden_states = module(hidden_states)
258
+ return hidden_states
259
+
260
+
261
+ class GELU(nn.Module):
262
+ r"""
263
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
264
+ """
265
+
266
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
267
+ super().__init__()
268
+ self.proj = nn.Linear(dim_in, dim_out)
269
+ self.approximate = approximate
270
+
271
+ def gelu(self, gate):
272
+ if gate.device.type != "mps":
273
+ return F.gelu(gate, approximate=self.approximate)
274
+ # mps: gelu is not implemented for float16
275
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
276
+
277
+ def forward(self, hidden_states):
278
+ hidden_states = self.proj(hidden_states)
279
+ hidden_states = self.gelu(hidden_states)
280
+ return hidden_states
281
+
282
+
283
+ class GEGLU(nn.Module):
284
+ r"""
285
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
286
+
287
+ Parameters:
288
+ dim_in (`int`): The number of channels in the input.
289
+ dim_out (`int`): The number of channels in the output.
290
+ """
291
+
292
+ def __init__(self, dim_in: int, dim_out: int):
293
+ super().__init__()
294
+ self.proj = nn.Linear(dim_in, dim_out * 2)
295
+
296
+ def gelu(self, gate):
297
+ if gate.device.type != "mps":
298
+ return F.gelu(gate)
299
+ # mps: gelu is not implemented for float16
300
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
301
+
302
+ def forward(self, hidden_states):
303
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
304
+ return hidden_states * self.gelu(gate)
305
+
306
+
307
+ class ApproximateGELU(nn.Module):
308
+ """
309
+ The approximate form of Gaussian Error Linear Unit (GELU)
310
+
311
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
312
+ """
313
+
314
+ def __init__(self, dim_in: int, dim_out: int):
315
+ super().__init__()
316
+ self.proj = nn.Linear(dim_in, dim_out)
317
+
318
+ def forward(self, x):
319
+ x = self.proj(x)
320
+ return x * torch.sigmoid(1.702 * x)
321
+
322
+
323
+ class AdaLayerNorm(nn.Module):
324
+ """
325
+ Norm layer modified to incorporate timestep embeddings.
326
+ """
327
+
328
+ def __init__(self, embedding_dim, num_embeddings):
329
+ super().__init__()
330
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
331
+ self.silu = nn.SiLU()
332
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
333
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
334
+
335
+ def forward(self, x, timestep):
336
+ emb = self.linear(self.silu(self.emb(timestep)))
337
+ scale, shift = torch.chunk(emb, 2)
338
+ x = self.norm(x) * (1 + scale) + shift
339
+ return x
340
+
341
+
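A conceptual sketch of the scale/shift modulation that AdaLayerNorm applies: the timestep embedding is mapped to a scale and a shift that modulate the normalized features. The shapes here are assumptions for illustration and the embedding lookup is bypassed.

import torch

norm = torch.nn.LayerNorm(8, elementwise_affine=False)
x = torch.randn(2, 5, 8)
scale = torch.randn(8)                # would come from linear(silu(emb(timestep)))
shift = torch.randn(8)
out = norm(x) * (1 + scale) + shift   # timestep-conditioned affine transform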
342
+ class AdaLayerNormZero(nn.Module):
343
+ """
344
+ Adaptive layer norm zero (adaLN-Zero): a norm layer whose scale, shift, and gate values are predicted from timestep and class embeddings.
345
+ """
346
+
347
+ def __init__(self, embedding_dim, num_embeddings):
348
+ super().__init__()
349
+
350
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
351
+
352
+ self.silu = nn.SiLU()
353
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
354
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
355
+
356
+ def forward(self, x, timestep, class_labels, hidden_dtype=None):
357
+ emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
358
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
359
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
360
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
361
+
362
+
363
+ class AdaGroupNorm(nn.Module):
364
+ """
365
+ GroupNorm layer modified to incorporate timestep embeddings.
366
+ """
367
+
368
+ def __init__(
369
+ self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
370
+ ):
371
+ super().__init__()
372
+ self.num_groups = num_groups
373
+ self.eps = eps
374
+
375
+ if act_fn is None:
376
+ self.act = None
377
+ else:
378
+ self.act = get_activation(act_fn)
379
+
380
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
381
+
382
+ def forward(self, x, emb):
383
+ if self.act:
384
+ emb = self.act(emb)
385
+ emb = self.linear(emb)
386
+ emb = emb[:, :, None, None]
387
+ scale, shift = emb.chunk(2, dim=1)
388
+
389
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
390
+ x = x * (1 + scale) + shift
391
+ return x
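A shape sketch of the AdaGroupNorm modulation above, with assumed channel and embedding sizes.

import torch
import torch.nn.functional as F

x = torch.randn(2, 64, 16, 16)        # (batch, channels, h, w)
emb = torch.randn(2, 128)             # output of the `out_dim * 2` linear layer
emb = emb[:, :, None, None]
scale, shift = emb.chunk(2, dim=1)    # each (2, 64, 1, 1)
out = F.group_norm(x, 8) * (1 + scale) + shift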
models/attention_processor.py ADDED
@@ -0,0 +1,1687 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Callable, Optional, Union
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import deprecate, logging, maybe_allow_in_graph
21
+ from diffusers.utils.import_utils import is_xformers_available
22
+
23
+
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
+
26
+
27
+ if is_xformers_available():
28
+ import xformers
29
+ import xformers.ops
30
+ else:
31
+ xformers = None
32
+
33
+
34
+ @maybe_allow_in_graph
35
+ class Attention(nn.Module):
36
+ r"""
37
+ A cross attention layer.
38
+
39
+ Parameters:
40
+ query_dim (`int`): The number of channels in the query.
41
+ cross_attention_dim (`int`, *optional*):
42
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
43
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
44
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
45
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
46
+ bias (`bool`, *optional*, defaults to False):
47
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ query_dim: int,
53
+ cross_attention_dim: Optional[int] = None,
54
+ heads: int = 8,
55
+ dim_head: int = 64,
56
+ dropout: float = 0.0,
57
+ bias=False,
58
+ upcast_attention: bool = False,
59
+ upcast_softmax: bool = False,
60
+ cross_attention_norm: Optional[str] = None,
61
+ cross_attention_norm_num_groups: int = 32,
62
+ added_kv_proj_dim: Optional[int] = None,
63
+ norm_num_groups: Optional[int] = None,
64
+ spatial_norm_dim: Optional[int] = None,
65
+ out_bias: bool = True,
66
+ scale_qk: bool = True,
67
+ only_cross_attention: bool = False,
68
+ eps: float = 1e-5,
69
+ rescale_output_factor: float = 1.0,
70
+ residual_connection: bool = False,
71
+ _from_deprecated_attn_block=False,
72
+ processor: Optional["AttnProcessor"] = None,
73
+ ):
74
+ super().__init__()
75
+ inner_dim = dim_head * heads
76
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
77
+ self.upcast_attention = upcast_attention
78
+ self.upcast_softmax = upcast_softmax
79
+ self.rescale_output_factor = rescale_output_factor
80
+ self.residual_connection = residual_connection
81
+ self.dropout = dropout
82
+
83
+ # we make use of this private variable to know whether this class is loaded
84
+ # with a deprecated state dict so that we can convert it on the fly
85
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
86
+
87
+ self.scale_qk = scale_qk
88
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
89
+
90
+ self.heads = heads
91
+ # for slice_size > 0 the attention score computation
92
+ # is split across the batch axis to save memory
93
+ # You can set slice_size with `set_attention_slice`
94
+ self.sliceable_head_dim = heads
95
+
96
+ self.added_kv_proj_dim = added_kv_proj_dim
97
+ self.only_cross_attention = only_cross_attention
98
+
99
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
100
+ raise ValueError(
101
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
102
+ )
103
+
104
+ if norm_num_groups is not None:
105
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
106
+ else:
107
+ self.group_norm = None
108
+
109
+ if spatial_norm_dim is not None:
110
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
111
+ else:
112
+ self.spatial_norm = None
113
+
114
+ if cross_attention_norm is None:
115
+ self.norm_cross = None
116
+ elif cross_attention_norm == "layer_norm":
117
+ self.norm_cross = nn.LayerNorm(cross_attention_dim)
118
+ elif cross_attention_norm == "group_norm":
119
+ if self.added_kv_proj_dim is not None:
120
+ # The given `encoder_hidden_states` are initially of shape
121
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
122
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
123
+ # before the projection, so we need to use `added_kv_proj_dim` as
124
+ # the number of channels for the group norm.
125
+ norm_cross_num_channels = added_kv_proj_dim
126
+ else:
127
+ norm_cross_num_channels = cross_attention_dim
128
+
129
+ self.norm_cross = nn.GroupNorm(
130
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
131
+ )
132
+ else:
133
+ raise ValueError(
134
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
135
+ )
136
+
137
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
138
+
139
+ if not self.only_cross_attention:
140
+ # only relevant for the `AddedKVProcessor` classes
141
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
142
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
143
+ else:
144
+ self.to_k = None
145
+ self.to_v = None
146
+
147
+ if self.added_kv_proj_dim is not None:
148
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
149
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
150
+
151
+ self.to_out = nn.ModuleList([])
152
+ self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
153
+ self.to_out.append(nn.Dropout(dropout))
154
+
155
+ # set attention processor
156
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
157
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
158
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
159
+ if processor is None:
160
+ processor = (
161
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
162
+ )
163
+ self.set_processor(processor)
164
+
165
+ # Rich-Text: util function for averaging over attention heads
166
+ def reshape_batch_dim_to_heads_and_average(self, tensor):
167
+ batch_size, seq_len, seq_len2 = tensor.shape
168
+ head_size = self.heads
169
+ tensor = tensor.reshape(batch_size // head_size,
170
+ head_size, seq_len, seq_len2)
171
+ return tensor.mean(1)
172
+
173
+ def set_use_memory_efficient_attention_xformers(
174
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
175
+ ):
176
+ is_lora = hasattr(self, "processor") and isinstance(
177
+ self.processor,
178
+ (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
179
+ )
180
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
181
+ self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
182
+ )
183
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
184
+ self.processor,
185
+ (
186
+ AttnAddedKVProcessor,
187
+ AttnAddedKVProcessor2_0,
188
+ SlicedAttnAddedKVProcessor,
189
+ XFormersAttnAddedKVProcessor,
190
+ LoRAAttnAddedKVProcessor,
191
+ ),
192
+ )
193
+
194
+ if use_memory_efficient_attention_xformers:
195
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
196
+ raise NotImplementedError(
197
+ f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
198
+ )
199
+ if not is_xformers_available():
200
+ raise ModuleNotFoundError(
201
+ (
202
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
203
+ " xformers"
204
+ ),
205
+ name="xformers",
206
+ )
207
+ elif not torch.cuda.is_available():
208
+ raise ValueError(
209
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
210
+ " only available for GPU "
211
+ )
212
+ else:
213
+ try:
214
+ # Make sure we can run the memory efficient attention
215
+ _ = xformers.ops.memory_efficient_attention(
216
+ torch.randn((1, 2, 40), device="cuda"),
217
+ torch.randn((1, 2, 40), device="cuda"),
218
+ torch.randn((1, 2, 40), device="cuda"),
219
+ )
220
+ except Exception as e:
221
+ raise e
222
+
223
+ if is_lora:
224
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
225
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
226
+ processor = LoRAXFormersAttnProcessor(
227
+ hidden_size=self.processor.hidden_size,
228
+ cross_attention_dim=self.processor.cross_attention_dim,
229
+ rank=self.processor.rank,
230
+ attention_op=attention_op,
231
+ )
232
+ processor.load_state_dict(self.processor.state_dict())
233
+ processor.to(self.processor.to_q_lora.up.weight.device)
234
+ elif is_custom_diffusion:
235
+ processor = CustomDiffusionXFormersAttnProcessor(
236
+ train_kv=self.processor.train_kv,
237
+ train_q_out=self.processor.train_q_out,
238
+ hidden_size=self.processor.hidden_size,
239
+ cross_attention_dim=self.processor.cross_attention_dim,
240
+ attention_op=attention_op,
241
+ )
242
+ processor.load_state_dict(self.processor.state_dict())
243
+ if hasattr(self.processor, "to_k_custom_diffusion"):
244
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
245
+ elif is_added_kv_processor:
246
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
247
+ # which uses this type of cross attention ONLY because the attention mask of format
248
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
249
+ # throw warning
250
+ logger.info(
251
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
252
+ )
253
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
254
+ else:
255
+ processor = XFormersAttnProcessor(attention_op=attention_op)
256
+ else:
257
+ if is_lora:
258
+ attn_processor_class = (
259
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
260
+ )
261
+ processor = attn_processor_class(
262
+ hidden_size=self.processor.hidden_size,
263
+ cross_attention_dim=self.processor.cross_attention_dim,
264
+ rank=self.processor.rank,
265
+ )
266
+ processor.load_state_dict(self.processor.state_dict())
267
+ processor.to(self.processor.to_q_lora.up.weight.device)
268
+ elif is_custom_diffusion:
269
+ processor = CustomDiffusionAttnProcessor(
270
+ train_kv=self.processor.train_kv,
271
+ train_q_out=self.processor.train_q_out,
272
+ hidden_size=self.processor.hidden_size,
273
+ cross_attention_dim=self.processor.cross_attention_dim,
274
+ )
275
+ processor.load_state_dict(self.processor.state_dict())
276
+ if hasattr(self.processor, "to_k_custom_diffusion"):
277
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
278
+ else:
279
+ # set attention processor
280
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
281
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
282
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
283
+ processor = (
284
+ AttnProcessor2_0()
285
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
286
+ else AttnProcessor()
287
+ )
288
+
289
+ self.set_processor(processor)
290
+
291
+ def set_attention_slice(self, slice_size):
292
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
293
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
294
+
295
+ if slice_size is not None and self.added_kv_proj_dim is not None:
296
+ processor = SlicedAttnAddedKVProcessor(slice_size)
297
+ elif slice_size is not None:
298
+ processor = SlicedAttnProcessor(slice_size)
299
+ elif self.added_kv_proj_dim is not None:
300
+ processor = AttnAddedKVProcessor()
301
+ else:
302
+ # set attention processor
303
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
304
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
305
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
306
+ processor = (
307
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
308
+ )
309
+
310
+ self.set_processor(processor)
311
+
312
+ def set_processor(self, processor: "AttnProcessor"):
313
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
314
+ # pop `processor` from `self._modules`
315
+ if (
316
+ hasattr(self, "processor")
317
+ and isinstance(self.processor, torch.nn.Module)
318
+ and not isinstance(processor, torch.nn.Module)
319
+ ):
320
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
321
+ self._modules.pop("processor")
322
+
323
+ self.processor = processor
324
+
325
+ # Rich-Text: inject self-attention maps
326
+ def forward(self, hidden_states, real_attn_probs=None, attn_weights=None, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
327
+ # The `Attention` class can call different attention processors / attention functions
328
+ # here we simply pass along all tensors to the selected processor class
329
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
330
+ return self.processor(
331
+ self,
332
+ hidden_states,
333
+ real_attn_probs=real_attn_probs,
334
+ attn_weights=attn_weights,
335
+ encoder_hidden_states=encoder_hidden_states,
336
+ attention_mask=attention_mask,
337
+ **cross_attention_kwargs,
338
+ )
339
+
340
+ def batch_to_head_dim(self, tensor):
341
+ head_size = self.heads
342
+ batch_size, seq_len, dim = tensor.shape
343
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
344
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
345
+ return tensor
346
+
347
+ def head_to_batch_dim(self, tensor, out_dim=3):
348
+ head_size = self.heads
349
+ batch_size, seq_len, dim = tensor.shape
350
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
351
+ tensor = tensor.permute(0, 2, 1, 3)
352
+
353
+ if out_dim == 3:
354
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
355
+
356
+ return tensor
357
+
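A round-trip sketch of `head_to_batch_dim` and `batch_to_head_dim` with assumed sizes, showing that the two reshapes are inverses of each other.

import torch

batch, seq, dim, heads = 2, 64, 320, 8
t = torch.randn(batch, seq, dim)
flat = t.reshape(batch, seq, heads, dim // heads).permute(0, 2, 1, 3) \
        .reshape(batch * heads, seq, dim // heads)           # head_to_batch_dim
back = flat.reshape(batch, heads, seq, dim // heads).permute(0, 2, 1, 3) \
           .reshape(batch, seq, dim)                          # batch_to_head_dim
assert torch.allclose(back, t)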
358
+ # Rich-Text: return attention scores
359
+ def get_attention_scores(self, query, key, attention_mask=None, attn_weights=None):
360
+ dtype = query.dtype
361
+ if self.upcast_attention:
362
+ query = query.float()
363
+ key = key.float()
364
+
365
+ if attention_mask is None:
366
+ baddbmm_input = torch.empty(
367
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
368
+ )
369
+ beta = 0
370
+ else:
371
+ baddbmm_input = attention_mask
372
+ beta = 1
373
+
374
+ attention_scores = torch.baddbmm(
375
+ baddbmm_input,
376
+ query,
377
+ key.transpose(-1, -2),
378
+ beta=beta,
379
+ alpha=self.scale,
380
+ )
381
+ del baddbmm_input
382
+
383
+ if self.upcast_softmax:
384
+ attention_scores = attention_scores.float()
385
+
386
+ # Rich-Text: font size
387
+ if attn_weights is not None:
388
+ assert key.shape[1] == 77
389
+ attention_scores_stable = attention_scores - attention_scores.max(-1, True)[0]
390
+ attention_score_exp = attention_scores_stable.float().exp()
391
+ # attention_score_exp = attention_scores.float().exp()
392
+ font_size_abs, font_size_sign = attn_weights['font_size'].abs(), attn_weights['font_size'].sign()
393
+ attention_score_exp[:, :, attn_weights['word_pos']] = attention_score_exp[:, :, attn_weights['word_pos']].clone(
394
+ )*font_size_abs
395
+ attention_probs = attention_score_exp / attention_score_exp.sum(-1, True)
396
+ attention_probs[:, :, attn_weights['word_pos']] *= font_size_sign
397
+ # import ipdb; ipdb.set_trace()
398
+ if attention_probs.isnan().any():
399
+ import ipdb; ipdb.set_trace()
400
+ else:
401
+ attention_probs = attention_scores.softmax(dim=-1)
402
+
403
+ del attention_scores
404
+
405
+ attention_probs = attention_probs.to(dtype)
406
+
407
+ return attention_probs
408
+
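A sketch of the font-size reweighting performed in `get_attention_scores` above: the exponentiated scores of selected token positions are scaled before renormalizing, which boosts (or suppresses) those tokens in every query row. The positions and weight are made up, and the sign handling for negative weights is omitted here.

import torch

scores = torch.randn(16, 4096, 77)               # (batch*heads, query_tokens, 77 text tokens)
word_pos = torch.tensor([3, 4])                  # token positions to emphasize
weight = 2.0                                     # |font_size| factor: >1 boosts, <1 suppresses
exp = (scores - scores.max(-1, keepdim=True)[0]).exp()
exp[:, :, word_pos] = exp[:, :, word_pos] * weight
probs = exp / exp.sum(-1, keepdim=True)          # each query row still sums to 1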
409
+ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
410
+ if batch_size is None:
411
+ deprecate(
412
+ "batch_size=None",
413
+ "0.0.15",
414
+ (
415
+ "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
416
+ " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
417
+ " `prepare_attention_mask` when preparing the attention_mask."
418
+ ),
419
+ )
420
+ batch_size = 1
421
+
422
+ head_size = self.heads
423
+ if attention_mask is None:
424
+ return attention_mask
425
+
426
+ current_length: int = attention_mask.shape[-1]
427
+ if current_length != target_length:
428
+ if attention_mask.device.type == "mps":
429
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
430
+ # Instead, we can manually construct the padding tensor.
431
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
432
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
433
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
434
+ else:
435
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
436
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
437
+ # remaining_length: int = target_length - current_length
438
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
439
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
440
+
441
+ if out_dim == 3:
442
+ if attention_mask.shape[0] < batch_size * head_size:
443
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
444
+ elif out_dim == 4:
445
+ attention_mask = attention_mask.unsqueeze(1)
446
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
447
+
448
+ return attention_mask
449
+
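A sketch of the mask preparation path with assumed lengths and head count: the additive mask is padded along the key dimension and repeated per head so it lines up with the (batch*heads, ...) score tensor. For clarity this pads to the key length, whereas the code above pads by `target_length`, as flagged in its own TODO.

import torch
import torch.nn.functional as F

mask = torch.zeros(2, 1, 64)                 # 0 keeps a token; large negatives drop it
mask = F.pad(mask, (0, 77 - 64), value=0.0)  # extend to the key length
mask = mask.repeat_interleave(8, dim=0)      # (2 * 8, 1, 77) for 8 heads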
450
+ def norm_encoder_hidden_states(self, encoder_hidden_states):
451
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
452
+
453
+ if isinstance(self.norm_cross, nn.LayerNorm):
454
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
455
+ elif isinstance(self.norm_cross, nn.GroupNorm):
456
+ # Group norm norms along the channels dimension and expects
457
+ # input to be in the shape of (N, C, *). In this case, we want
458
+ # to norm along the hidden dimension, so we need to move
459
+ # (batch_size, sequence_length, hidden_size) ->
460
+ # (batch_size, hidden_size, sequence_length)
461
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
462
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
463
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
464
+ else:
465
+ assert False
466
+
467
+ return encoder_hidden_states
468
+
469
+
470
+ class AttnProcessor:
471
+ r"""
472
+ Default processor for performing attention-related computations.
473
+ """
474
+
475
+ # Rich-Text: inject self-attention maps
476
+ def __call__(
477
+ self,
478
+ attn: Attention,
479
+ hidden_states,
480
+ real_attn_probs=None,
481
+ attn_weights=None,
482
+ encoder_hidden_states=None,
483
+ attention_mask=None,
484
+ temb=None,
485
+ ):
486
+ residual = hidden_states
487
+
488
+ if attn.spatial_norm is not None:
489
+ hidden_states = attn.spatial_norm(hidden_states, temb)
490
+
491
+ input_ndim = hidden_states.ndim
492
+
493
+ if input_ndim == 4:
494
+ batch_size, channel, height, width = hidden_states.shape
495
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
496
+
497
+ batch_size, sequence_length, _ = (
498
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
499
+ )
500
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
501
+
502
+ if attn.group_norm is not None:
503
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
504
+
505
+ query = attn.to_q(hidden_states)
506
+
507
+ if encoder_hidden_states is None:
508
+ encoder_hidden_states = hidden_states
509
+ elif attn.norm_cross:
510
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
511
+
512
+ key = attn.to_k(encoder_hidden_states)
513
+ value = attn.to_v(encoder_hidden_states)
514
+
515
+ query = attn.head_to_batch_dim(query)
516
+ key = attn.head_to_batch_dim(key)
517
+ value = attn.head_to_batch_dim(value)
518
+
519
+ if real_attn_probs is None:
520
+ # Rich-Text: font size
521
+ attention_probs = attn.get_attention_scores(query, key, attention_mask, attn_weights=attn_weights)
522
+ else:
523
+ # Rich-Text: inject self-attention maps
524
+ attention_probs = real_attn_probs
525
+ hidden_states = torch.bmm(attention_probs, value)
526
+ hidden_states = attn.batch_to_head_dim(hidden_states)
527
+
528
+ # linear proj
529
+ hidden_states = attn.to_out[0](hidden_states)
530
+ # dropout
531
+ hidden_states = attn.to_out[1](hidden_states)
532
+
533
+ if input_ndim == 4:
534
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
535
+
536
+ if attn.residual_connection:
537
+ hidden_states = hidden_states + residual
538
+
539
+ hidden_states = hidden_states / attn.rescale_output_factor
540
+
541
+ # Rich-Text Modified: return attn probs
542
+ # We return the map averaged over heads to reduce the memory footprint
543
+ attention_probs_avg = attn.reshape_batch_dim_to_heads_and_average(
544
+ attention_probs)
545
+ return hidden_states, [attention_probs_avg, attention_probs]
546
+
547
+
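A usage sketch of the Attention block with the plain AttnProcessor above, showing the extra head-averaged attention map this modified processor returns; all sizes are illustrative.

import torch

attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40,
                 processor=AttnProcessor())
latents = torch.randn(2, 4096, 320)      # flattened 64x64 feature map
text = torch.randn(2, 77, 768)           # text-encoder hidden states
out, (probs_avg, probs_full) = attn(latents, encoder_hidden_states=text)
# out: (2, 4096, 320); probs_avg: (2, 4096, 77) cross-attention map averaged over heads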
548
+ class LoRALinearLayer(nn.Module):
549
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None):
550
+ super().__init__()
551
+
552
+ if rank > min(in_features, out_features):
553
+ raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
554
+
555
+ self.down = nn.Linear(in_features, rank, bias=False)
556
+ self.up = nn.Linear(rank, out_features, bias=False)
557
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
558
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
559
+ self.network_alpha = network_alpha
560
+ self.rank = rank
561
+
562
+ nn.init.normal_(self.down.weight, std=1 / rank)
563
+ nn.init.zeros_(self.up.weight)
564
+
565
+ def forward(self, hidden_states):
566
+ orig_dtype = hidden_states.dtype
567
+ dtype = self.down.weight.dtype
568
+
569
+ down_hidden_states = self.down(hidden_states.to(dtype))
570
+ up_hidden_states = self.up(down_hidden_states)
571
+
572
+ if self.network_alpha is not None:
573
+ up_hidden_states *= self.network_alpha / self.rank
574
+
575
+ return up_hidden_states.to(orig_dtype)
576
+
577
+
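A sketch of how the LoRA layers below combine with a frozen base projection: the low-rank branch contributes a scaled delta on top of the base output. Dimensions are assumptions; note that the `up` weight is zero-initialized, so the delta is exactly zero before training.

import torch

base = torch.nn.Linear(320, 320)          # stands in for a frozen to_q / to_k / to_v / to_out
lora = LoRALinearLayer(320, 320, rank=4)
x = torch.randn(2, 64, 320)
out = base(x) + 1.0 * lora(x)             # scale * low-rank delta, as in the processors below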
578
+ class LoRAAttnProcessor(nn.Module):
579
+ r"""
580
+ Processor for implementing the LoRA attention mechanism.
581
+
582
+ Args:
583
+ hidden_size (`int`, *optional*):
584
+ The hidden size of the attention layer.
585
+ cross_attention_dim (`int`, *optional*):
586
+ The number of channels in the `encoder_hidden_states`.
587
+ rank (`int`, defaults to 4):
588
+ The dimension of the LoRA update matrices.
589
+ network_alpha (`int`, *optional*):
590
+ Equivalent to `alpha`, but its usage is specific to Kohya (A1111) style LoRAs.
591
+ """
592
+
593
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
594
+ super().__init__()
595
+
596
+ self.hidden_size = hidden_size
597
+ self.cross_attention_dim = cross_attention_dim
598
+ self.rank = rank
599
+
600
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
601
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
602
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
603
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
604
+
605
+ def __call__(
606
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
607
+ ):
608
+ residual = hidden_states
609
+
610
+ if attn.spatial_norm is not None:
611
+ hidden_states = attn.spatial_norm(hidden_states, temb)
612
+
613
+ input_ndim = hidden_states.ndim
614
+
615
+ if input_ndim == 4:
616
+ batch_size, channel, height, width = hidden_states.shape
617
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
618
+
619
+ batch_size, sequence_length, _ = (
620
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
621
+ )
622
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
623
+
624
+ if attn.group_norm is not None:
625
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
626
+
627
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
628
+ query = attn.head_to_batch_dim(query)
629
+
630
+ if encoder_hidden_states is None:
631
+ encoder_hidden_states = hidden_states
632
+ elif attn.norm_cross:
633
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
634
+
635
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
636
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
637
+
638
+ key = attn.head_to_batch_dim(key)
639
+ value = attn.head_to_batch_dim(value)
640
+
641
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
642
+ hidden_states = torch.bmm(attention_probs, value)
643
+ hidden_states = attn.batch_to_head_dim(hidden_states)
644
+
645
+ # linear proj
646
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
647
+ # dropout
648
+ hidden_states = attn.to_out[1](hidden_states)
649
+
650
+ if input_ndim == 4:
651
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
652
+
653
+ if attn.residual_connection:
654
+ hidden_states = hidden_states + residual
655
+
656
+ hidden_states = hidden_states / attn.rescale_output_factor
657
+
658
+ return hidden_states
659
+
660
+
661
+ class CustomDiffusionAttnProcessor(nn.Module):
662
+ r"""
663
+ Processor for implementing attention for the Custom Diffusion method.
664
+
665
+ Args:
666
+ train_kv (`bool`, defaults to `True`):
667
+ Whether to newly train the key and value matrices corresponding to the text features.
668
+ train_q_out (`bool`, defaults to `True`):
669
+ Whether to newly train query matrices corresponding to the latent image features.
670
+ hidden_size (`int`, *optional*, defaults to `None`):
671
+ The hidden size of the attention layer.
672
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
673
+ The number of channels in the `encoder_hidden_states`.
674
+ out_bias (`bool`, defaults to `True`):
675
+ Whether to include the bias parameter in `train_q_out`.
676
+ dropout (`float`, *optional*, defaults to 0.0):
677
+ The dropout probability to use.
678
+ """
679
+
680
+ def __init__(
681
+ self,
682
+ train_kv=True,
683
+ train_q_out=True,
684
+ hidden_size=None,
685
+ cross_attention_dim=None,
686
+ out_bias=True,
687
+ dropout=0.0,
688
+ ):
689
+ super().__init__()
690
+ self.train_kv = train_kv
691
+ self.train_q_out = train_q_out
692
+
693
+ self.hidden_size = hidden_size
694
+ self.cross_attention_dim = cross_attention_dim
695
+
696
+ # `_custom_diffusion` id for easy serialization and loading.
697
+ if self.train_kv:
698
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
699
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
700
+ if self.train_q_out:
701
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
702
+ self.to_out_custom_diffusion = nn.ModuleList([])
703
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
704
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
705
+
706
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
707
+ batch_size, sequence_length, _ = hidden_states.shape
708
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
709
+ if self.train_q_out:
710
+ query = self.to_q_custom_diffusion(hidden_states)
711
+ else:
712
+ query = attn.to_q(hidden_states)
713
+
714
+ if encoder_hidden_states is None:
715
+ crossattn = False
716
+ encoder_hidden_states = hidden_states
717
+ else:
718
+ crossattn = True
719
+ if attn.norm_cross:
720
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
721
+
722
+ if self.train_kv:
723
+ key = self.to_k_custom_diffusion(encoder_hidden_states)
724
+ value = self.to_v_custom_diffusion(encoder_hidden_states)
725
+ else:
726
+ key = attn.to_k(encoder_hidden_states)
727
+ value = attn.to_v(encoder_hidden_states)
728
+
729
+ if crossattn:
730
+ detach = torch.ones_like(key)
731
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
732
+ key = detach * key + (1 - detach) * key.detach()
733
+ value = detach * value + (1 - detach) * value.detach()
734
+
735
+ query = attn.head_to_batch_dim(query)
736
+ key = attn.head_to_batch_dim(key)
737
+ value = attn.head_to_batch_dim(value)
738
+
739
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
740
+ hidden_states = torch.bmm(attention_probs, value)
741
+ hidden_states = attn.batch_to_head_dim(hidden_states)
742
+
743
+ if self.train_q_out:
744
+ # linear proj
745
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
746
+ # dropout
747
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
748
+ else:
749
+ # linear proj
750
+ hidden_states = attn.to_out[0](hidden_states)
751
+ # dropout
752
+ hidden_states = attn.to_out[1](hidden_states)
753
+
754
+ return hidden_states
755
+
756
+
757
+ class AttnAddedKVProcessor:
758
+ r"""
759
+ Processor for performing attention-related computations with extra learnable key and value matrices for the text
760
+ encoder.
761
+ """
762
+
763
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
764
+ residual = hidden_states
765
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
766
+ batch_size, sequence_length, _ = hidden_states.shape
767
+
768
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
769
+
770
+ if encoder_hidden_states is None:
771
+ encoder_hidden_states = hidden_states
772
+ elif attn.norm_cross:
773
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
774
+
775
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
776
+
777
+ query = attn.to_q(hidden_states)
778
+ query = attn.head_to_batch_dim(query)
779
+
780
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
781
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
782
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
783
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
784
+
785
+ if not attn.only_cross_attention:
786
+ key = attn.to_k(hidden_states)
787
+ value = attn.to_v(hidden_states)
788
+ key = attn.head_to_batch_dim(key)
789
+ value = attn.head_to_batch_dim(value)
790
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
791
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
792
+ else:
793
+ key = encoder_hidden_states_key_proj
794
+ value = encoder_hidden_states_value_proj
795
+
796
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
797
+ hidden_states = torch.bmm(attention_probs, value)
798
+ hidden_states = attn.batch_to_head_dim(hidden_states)
799
+
800
+ # linear proj
801
+ hidden_states = attn.to_out[0](hidden_states)
802
+ # dropout
803
+ hidden_states = attn.to_out[1](hidden_states)
804
+
805
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
806
+ hidden_states = hidden_states + residual
807
+
808
+ return hidden_states
809
+
810
+
811
+ class AttnAddedKVProcessor2_0:
812
+ r"""
813
+ Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
814
+ learnable key and value matrices for the text encoder.
815
+ """
816
+
817
+ def __init__(self):
818
+ if not hasattr(F, "scaled_dot_product_attention"):
819
+ raise ImportError(
820
+ "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
821
+ )
822
+
823
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
824
+ residual = hidden_states
825
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
826
+ batch_size, sequence_length, _ = hidden_states.shape
827
+
828
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
829
+
830
+ if encoder_hidden_states is None:
831
+ encoder_hidden_states = hidden_states
832
+ elif attn.norm_cross:
833
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
834
+
835
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
836
+
837
+ query = attn.to_q(hidden_states)
838
+ query = attn.head_to_batch_dim(query, out_dim=4)
839
+
840
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
841
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
842
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
843
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
844
+
845
+ if not attn.only_cross_attention:
846
+ key = attn.to_k(hidden_states)
847
+ value = attn.to_v(hidden_states)
848
+ key = attn.head_to_batch_dim(key, out_dim=4)
849
+ value = attn.head_to_batch_dim(value, out_dim=4)
850
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
851
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
852
+ else:
853
+ key = encoder_hidden_states_key_proj
854
+ value = encoder_hidden_states_value_proj
855
+
856
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
857
+ # TODO: add support for attn.scale when we move to Torch 2.1
858
+ hidden_states = F.scaled_dot_product_attention(
859
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
860
+ )
861
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
862
+
863
+ # linear proj
864
+ hidden_states = attn.to_out[0](hidden_states)
865
+ # dropout
866
+ hidden_states = attn.to_out[1](hidden_states)
867
+
868
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
869
+ hidden_states = hidden_states + residual
870
+
871
+ return hidden_states
872
+
873
+
874
+ class LoRAAttnAddedKVProcessor(nn.Module):
875
+ r"""
876
+ Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
877
+ encoder.
878
+
879
+ Args:
880
+ hidden_size (`int`, *optional*):
881
+ The hidden size of the attention layer.
882
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
883
+ The number of channels in the `encoder_hidden_states`.
884
+ rank (`int`, defaults to 4):
885
+ The dimension of the LoRA update matrices.
886
+
887
+ """
888
+
889
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
890
+ super().__init__()
891
+
892
+ self.hidden_size = hidden_size
893
+ self.cross_attention_dim = cross_attention_dim
894
+ self.rank = rank
895
+
896
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
897
+ self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
898
+ self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
899
+ self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
900
+ self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
901
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
902
+
903
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
904
+ residual = hidden_states
905
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
906
+ batch_size, sequence_length, _ = hidden_states.shape
907
+
908
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
909
+
910
+ if encoder_hidden_states is None:
911
+ encoder_hidden_states = hidden_states
912
+ elif attn.norm_cross:
913
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
914
+
915
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
916
+
917
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
918
+ query = attn.head_to_batch_dim(query)
919
+
920
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
921
+ encoder_hidden_states
922
+ )
923
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
924
+ encoder_hidden_states
925
+ )
926
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
927
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
928
+
929
+ if not attn.only_cross_attention:
930
+ key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
931
+ value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
932
+ key = attn.head_to_batch_dim(key)
933
+ value = attn.head_to_batch_dim(value)
934
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
935
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
936
+ else:
937
+ key = encoder_hidden_states_key_proj
938
+ value = encoder_hidden_states_value_proj
939
+
940
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
941
+ hidden_states = torch.bmm(attention_probs, value)
942
+ hidden_states = attn.batch_to_head_dim(hidden_states)
943
+
944
+ # linear proj
945
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
946
+ # dropout
947
+ hidden_states = attn.to_out[1](hidden_states)
948
+
949
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
950
+ hidden_states = hidden_states + residual
951
+
952
+ return hidden_states
953
+
954
+
955
+ class XFormersAttnAddedKVProcessor:
956
+ r"""
957
+ Processor for implementing memory efficient attention using xFormers.
958
+
959
+ Args:
960
+ attention_op (`Callable`, *optional*, defaults to `None`):
961
+ The base
962
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
963
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
964
+ operator.
965
+ """
966
+
967
+ def __init__(self, attention_op: Optional[Callable] = None):
968
+ self.attention_op = attention_op
969
+
970
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
971
+ residual = hidden_states
972
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
973
+ batch_size, sequence_length, _ = hidden_states.shape
974
+
975
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
976
+
977
+ if encoder_hidden_states is None:
978
+ encoder_hidden_states = hidden_states
979
+ elif attn.norm_cross:
980
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
981
+
982
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
983
+
984
+ query = attn.to_q(hidden_states)
985
+ query = attn.head_to_batch_dim(query)
986
+
987
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
988
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
989
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
990
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
991
+
992
+ if not attn.only_cross_attention:
993
+ key = attn.to_k(hidden_states)
994
+ value = attn.to_v(hidden_states)
995
+ key = attn.head_to_batch_dim(key)
996
+ value = attn.head_to_batch_dim(value)
997
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
998
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
999
+ else:
1000
+ key = encoder_hidden_states_key_proj
1001
+ value = encoder_hidden_states_value_proj
1002
+
1003
+ hidden_states = xformers.ops.memory_efficient_attention(
1004
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1005
+ )
1006
+ hidden_states = hidden_states.to(query.dtype)
1007
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1008
+
1009
+ # linear proj
1010
+ hidden_states = attn.to_out[0](hidden_states)
1011
+ # dropout
1012
+ hidden_states = attn.to_out[1](hidden_states)
1013
+
1014
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1015
+ hidden_states = hidden_states + residual
1016
+
1017
+ return hidden_states
1018
+
1019
+
1020
+ class XFormersAttnProcessor:
1021
+ r"""
1022
+ Processor for implementing memory efficient attention using xFormers.
1023
+
1024
+ Args:
1025
+ attention_op (`Callable`, *optional*, defaults to `None`):
1026
+ The base
1027
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1028
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1029
+ operator.
1030
+ """
1031
+
1032
+ def __init__(self, attention_op: Optional[Callable] = None):
1033
+ self.attention_op = attention_op
1034
+
1035
+ def __call__(
1036
+ self,
1037
+ attn: Attention,
1038
+ hidden_states: torch.FloatTensor,
1039
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1040
+ attention_mask: Optional[torch.FloatTensor] = None,
1041
+ temb: Optional[torch.FloatTensor] = None,
1042
+ ):
1043
+ residual = hidden_states
1044
+
1045
+ if attn.spatial_norm is not None:
1046
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1047
+
1048
+ input_ndim = hidden_states.ndim
1049
+
1050
+ if input_ndim == 4:
1051
+ batch_size, channel, height, width = hidden_states.shape
1052
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1053
+
1054
+ batch_size, key_tokens, _ = (
1055
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1056
+ )
1057
+
1058
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
1059
+ if attention_mask is not None:
1060
+ # expand our mask's singleton query_tokens dimension:
1061
+ # [batch*heads, 1, key_tokens] ->
1062
+ # [batch*heads, query_tokens, key_tokens]
1063
+ # so that it can be added as a bias onto the attention scores that xformers computes:
1064
+ # [batch*heads, query_tokens, key_tokens]
1065
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
1066
+ _, query_tokens, _ = hidden_states.shape
1067
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
1068
+
1069
+ if attn.group_norm is not None:
1070
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1071
+
1072
+ query = attn.to_q(hidden_states)
1073
+
1074
+ if encoder_hidden_states is None:
1075
+ encoder_hidden_states = hidden_states
1076
+ elif attn.norm_cross:
1077
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1078
+
1079
+ key = attn.to_k(encoder_hidden_states)
1080
+ value = attn.to_v(encoder_hidden_states)
1081
+
1082
+ query = attn.head_to_batch_dim(query).contiguous()
1083
+ key = attn.head_to_batch_dim(key).contiguous()
1084
+ value = attn.head_to_batch_dim(value).contiguous()
1085
+
1086
+ hidden_states = xformers.ops.memory_efficient_attention(
1087
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1088
+ )
1089
+ hidden_states = hidden_states.to(query.dtype)
1090
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1091
+
1092
+ # linear proj
1093
+ hidden_states = attn.to_out[0](hidden_states)
1094
+ # dropout
1095
+ hidden_states = attn.to_out[1](hidden_states)
1096
+
1097
+ if input_ndim == 4:
1098
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1099
+
1100
+ if attn.residual_connection:
1101
+ hidden_states = hidden_states + residual
1102
+
1103
+ hidden_states = hidden_states / attn.rescale_output_factor
1104
+
1105
+ return hidden_states
1106
+
1107
+
1108
+ class AttnProcessor2_0:
1109
+ r"""
1110
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
1111
+ """
1112
+
1113
+ def __init__(self):
1114
+ if not hasattr(F, "scaled_dot_product_attention"):
1115
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1116
+
1117
+ def __call__(
1118
+ self,
1119
+ attn: Attention,
1120
+ hidden_states,
1121
+ encoder_hidden_states=None,
1122
+ attention_mask=None,
1123
+ temb=None,
1124
+ ):
1125
+ residual = hidden_states
1126
+
1127
+ if attn.spatial_norm is not None:
1128
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1129
+
1130
+ input_ndim = hidden_states.ndim
1131
+
1132
+ if input_ndim == 4:
1133
+ batch_size, channel, height, width = hidden_states.shape
1134
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1135
+
1136
+ batch_size, sequence_length, _ = (
1137
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1138
+ )
1139
+ inner_dim = hidden_states.shape[-1]
1140
+
1141
+ if attention_mask is not None:
1142
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1143
+ # scaled_dot_product_attention expects attention_mask shape to be
1144
+ # (batch, heads, source_length, target_length)
1145
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1146
+
1147
+ if attn.group_norm is not None:
1148
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1149
+
1150
+ query = attn.to_q(hidden_states)
1151
+
1152
+ if encoder_hidden_states is None:
1153
+ encoder_hidden_states = hidden_states
1154
+ elif attn.norm_cross:
1155
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1156
+
1157
+ key = attn.to_k(encoder_hidden_states)
1158
+ value = attn.to_v(encoder_hidden_states)
1159
+
1160
+ head_dim = inner_dim // attn.heads
1161
+
1162
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1163
+
1164
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1165
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1166
+
1167
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1168
+ # TODO: add support for attn.scale when we move to Torch 2.1
1169
+ hidden_states = F.scaled_dot_product_attention(
1170
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1171
+ )
1172
+
1173
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1174
+ hidden_states = hidden_states.to(query.dtype)
1175
+
1176
+ # linear proj
1177
+ hidden_states = attn.to_out[0](hidden_states)
1178
+ # dropout
1179
+ hidden_states = attn.to_out[1](hidden_states)
1180
+
1181
+ if input_ndim == 4:
1182
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1183
+
1184
+ if attn.residual_connection:
1185
+ hidden_states = hidden_states + residual
1186
+
1187
+ hidden_states = hidden_states / attn.rescale_output_factor
1188
+
1189
+ return hidden_states
1190
+
1191
+
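A shape sketch for the PyTorch 2.0 path above: `F.scaled_dot_product_attention` takes (batch, heads, seq, head_dim) tensors, and the result is transposed and flattened back to (batch, tokens, inner_dim). Sizes are assumptions and the call requires PyTorch >= 2.0.

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 4096, 40)
k = torch.randn(2, 8, 77, 40)
v = torch.randn(2, 8, 77, 40)
out = F.scaled_dot_product_attention(q, k, v)         # (2, 8, 4096, 40)
out = out.transpose(1, 2).reshape(2, -1, 8 * 40)      # (2, 4096, 320)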
1192
+ class LoRAXFormersAttnProcessor(nn.Module):
1193
+ r"""
1194
+ Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
1195
+
1196
+ Args:
1197
+ hidden_size (`int`, *optional*):
1198
+ The hidden size of the attention layer.
1199
+ cross_attention_dim (`int`, *optional*):
1200
+ The number of channels in the `encoder_hidden_states`.
1201
+ rank (`int`, defaults to 4):
1202
+ The dimension of the LoRA update matrices.
1203
+ attention_op (`Callable`, *optional*, defaults to `None`):
1204
+ The base
1205
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1206
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1207
+ operator.
1208
+ network_alpha (`int`, *optional*):
1209
+ Equivalent to `alpha`, but its usage is specific to Kohya (A1111) style LoRAs.
1210
+
1211
+ """
1212
+
1213
+ def __init__(
1214
+ self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
1215
+ ):
1216
+ super().__init__()
1217
+
1218
+ self.hidden_size = hidden_size
1219
+ self.cross_attention_dim = cross_attention_dim
1220
+ self.rank = rank
1221
+ self.attention_op = attention_op
1222
+
1223
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1224
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1225
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1226
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1227
+
1228
+ def __call__(
1229
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
1230
+ ):
1231
+ residual = hidden_states
1232
+
1233
+ if attn.spatial_norm is not None:
1234
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1235
+
1236
+ input_ndim = hidden_states.ndim
1237
+
1238
+ if input_ndim == 4:
1239
+ batch_size, channel, height, width = hidden_states.shape
1240
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1241
+
1242
+ batch_size, sequence_length, _ = (
1243
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1244
+ )
1245
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1246
+
1247
+ if attn.group_norm is not None:
1248
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1249
+
1250
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
1251
+ query = attn.head_to_batch_dim(query).contiguous()
1252
+
1253
+ if encoder_hidden_states is None:
1254
+ encoder_hidden_states = hidden_states
1255
+ elif attn.norm_cross:
1256
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1257
+
1258
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
1259
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
1260
+
1261
+ key = attn.head_to_batch_dim(key).contiguous()
1262
+ value = attn.head_to_batch_dim(value).contiguous()
1263
+
1264
+ hidden_states = xformers.ops.memory_efficient_attention(
1265
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1266
+ )
1267
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1268
+
1269
+ # linear proj
1270
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
1271
+ # dropout
1272
+ hidden_states = attn.to_out[1](hidden_states)
1273
+
1274
+ if input_ndim == 4:
1275
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1276
+
1277
+ if attn.residual_connection:
1278
+ hidden_states = hidden_states + residual
1279
+
1280
+ hidden_states = hidden_states / attn.rescale_output_factor
1281
+
1282
+ return hidden_states
1283
+
1284
+
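Every LoRA processor in this file follows the same recipe: each frozen projection is augmented with a low-rank update, so the query becomes attn.to_q(x) + scale * to_q_lora(x). The LoRALinearLayer used above is defined elsewhere in diffusers; the sketch below uses a simplified stand-in (names, sizes and the init scheme are illustrative, not the library implementation) just to show the low-rank update:

import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    # simplified stand-in for diffusers' LoRALinearLayer: out = up(down(x))
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1.0 / rank)
        nn.init.zeros_(self.up.weight)  # the update starts as a no-op

    def forward(self, x):
        return self.up(self.down(x))

hidden_size, scale = 320, 1.0
to_q = nn.Linear(hidden_size, hidden_size)              # frozen base projection
to_q_lora = TinyLoRALinear(hidden_size, hidden_size, rank=4)

x = torch.randn(2, 16, hidden_size)
query = to_q(x) + scale * to_q_lora(x)                  # same shape as the base projection
print(query.shape)  # torch.Size([2, 16, 320])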
1285
+ class LoRAAttnProcessor2_0(nn.Module):
1286
+ r"""
1287
+ Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
1288
+ attention.
1289
+
1290
+ Args:
1291
+ hidden_size (`int`):
1292
+ The hidden size of the attention layer.
1293
+ cross_attention_dim (`int`, *optional*):
1294
+ The number of channels in the `encoder_hidden_states`.
1295
+ rank (`int`, defaults to 4):
1296
+ The dimension of the LoRA update matrices.
1297
+ network_alpha (`int`, *optional*):
1298
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
1299
+ """
1300
+
1301
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
1302
+ super().__init__()
1303
+ if not hasattr(F, "scaled_dot_product_attention"):
1304
+ raise ImportError("LoRAAttnProcessor2_0 requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0.")
1305
+
1306
+ self.hidden_size = hidden_size
1307
+ self.cross_attention_dim = cross_attention_dim
1308
+ self.rank = rank
1309
+
1310
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1311
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1312
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1313
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1314
+
1315
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
1316
+ residual = hidden_states
1317
+
1318
+ input_ndim = hidden_states.ndim
1319
+
1320
+ if input_ndim == 4:
1321
+ batch_size, channel, height, width = hidden_states.shape
1322
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1323
+
1324
+ batch_size, sequence_length, _ = (
1325
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1326
+ )
1327
+ inner_dim = hidden_states.shape[-1]
1328
+
1329
+ if attention_mask is not None:
1330
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1331
+ # scaled_dot_product_attention expects attention_mask shape to be
1332
+ # (batch, heads, source_length, target_length)
1333
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1334
+
1335
+ if attn.group_norm is not None:
1336
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1337
+
1338
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
1339
+
1340
+ if encoder_hidden_states is None:
1341
+ encoder_hidden_states = hidden_states
1342
+ elif attn.norm_cross:
1343
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1344
+
1345
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
1346
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
1347
+
1348
+ head_dim = inner_dim // attn.heads
1349
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1350
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1351
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1352
+
1353
+ # TODO: add support for attn.scale when we move to Torch 2.1
1354
+ hidden_states = F.scaled_dot_product_attention(
1355
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1356
+ )
1357
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1358
+ hidden_states = hidden_states.to(query.dtype)
1359
+
1360
+ # linear proj
1361
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
1362
+ # dropout
1363
+ hidden_states = attn.to_out[1](hidden_states)
1364
+
1365
+ if input_ndim == 4:
1366
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1367
+
1368
+ if attn.residual_connection:
1369
+ hidden_states = hidden_states + residual
1370
+
1371
+ hidden_states = hidden_states / attn.rescale_output_factor
1372
+
1373
+ return hidden_states
1374
+
1375
+
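Processors like these are not called directly; in diffusers they are installed on a UNet, which then routes every attention layer through them. A hedged sketch of that wiring with the stock diffusers UNet (this downloads the SD 1.5 UNet weights, and the exact processor API can differ between diffusers versions):

from diffusers.models import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# route every attention layer through PyTorch 2.0 fused attention ...
unet.set_attn_processor(AttnProcessor2_0())
# ... or, with the xformers package installed, memory-efficient attention:
# unet.set_attn_processor(XFormersAttnProcessor())

print(len(unet.attn_processors))  # one entry per attention layer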
1376
+ class CustomDiffusionXFormersAttnProcessor(nn.Module):
1377
+ r"""
1378
+ Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
1379
+
1380
+ Args:
1381
+ train_kv (`bool`, defaults to `True`):
1382
+ Whether to newly train the key and value matrices corresponding to the text features.
1383
+ train_q_out (`bool`, defaults to `True`):
1384
+ Whether to newly train query matrices corresponding to the latent image features.
1385
+ hidden_size (`int`, *optional*, defaults to `None`):
1386
+ The hidden size of the attention layer.
1387
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1388
+ The number of channels in the `encoder_hidden_states`.
1389
+ out_bias (`bool`, defaults to `True`):
1390
+ Whether to include the bias parameter in `train_q_out`.
1391
+ dropout (`float`, *optional*, defaults to 0.0):
1392
+ The dropout probability to use.
1393
+ attention_op (`Callable`, *optional*, defaults to `None`):
1394
+ The base
1395
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
1396
+ as the attention operator. It is recommended to set this to `None` and allow xFormers to choose the best operator.
1397
+ """
1398
+
1399
+ def __init__(
1400
+ self,
1401
+ train_kv=True,
1402
+ train_q_out=False,
1403
+ hidden_size=None,
1404
+ cross_attention_dim=None,
1405
+ out_bias=True,
1406
+ dropout=0.0,
1407
+ attention_op: Optional[Callable] = None,
1408
+ ):
1409
+ super().__init__()
1410
+ self.train_kv = train_kv
1411
+ self.train_q_out = train_q_out
1412
+
1413
+ self.hidden_size = hidden_size
1414
+ self.cross_attention_dim = cross_attention_dim
1415
+ self.attention_op = attention_op
1416
+
1417
+ # `_custom_diffusion` id for easy serialization and loading.
1418
+ if self.train_kv:
1419
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1420
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1421
+ if self.train_q_out:
1422
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
1423
+ self.to_out_custom_diffusion = nn.ModuleList([])
1424
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
1425
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
1426
+
1427
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1428
+ batch_size, sequence_length, _ = (
1429
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1430
+ )
1431
+
1432
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1433
+
1434
+ if self.train_q_out:
1435
+ query = self.to_q_custom_diffusion(hidden_states)
1436
+ else:
1437
+ query = attn.to_q(hidden_states)
1438
+
1439
+ if encoder_hidden_states is None:
1440
+ crossattn = False
1441
+ encoder_hidden_states = hidden_states
1442
+ else:
1443
+ crossattn = True
1444
+ if attn.norm_cross:
1445
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1446
+
1447
+ if self.train_kv:
1448
+ key = self.to_k_custom_diffusion(encoder_hidden_states)
1449
+ value = self.to_v_custom_diffusion(encoder_hidden_states)
1450
+ else:
1451
+ key = attn.to_k(encoder_hidden_states)
1452
+ value = attn.to_v(encoder_hidden_states)
1453
+
1454
+ if crossattn:
1455
+ detach = torch.ones_like(key)
1456
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
1457
+ key = detach * key + (1 - detach) * key.detach()
1458
+ value = detach * value + (1 - detach) * value.detach()
1459
+
1460
+ query = attn.head_to_batch_dim(query).contiguous()
1461
+ key = attn.head_to_batch_dim(key).contiguous()
1462
+ value = attn.head_to_batch_dim(value).contiguous()
1463
+
1464
+ hidden_states = xformers.ops.memory_efficient_attention(
1465
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1466
+ )
1467
+ hidden_states = hidden_states.to(query.dtype)
1468
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1469
+
1470
+ if self.train_q_out:
1471
+ # linear proj
1472
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
1473
+ # dropout
1474
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
1475
+ else:
1476
+ # linear proj
1477
+ hidden_states = attn.to_out[0](hidden_states)
1478
+ # dropout
1479
+ hidden_states = attn.to_out[1](hidden_states)
1480
+ return hidden_states
1481
+
1482
+
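The `detach` trick above zeroes the mask for the first text token so that, during Custom Diffusion training, gradients flow through the newly trained key/value projections only for the remaining tokens. A small standalone gradient check illustrating the idea, with a toy tensor in place of the projected keys:

import torch

key = torch.randn(1, 4, 8, requires_grad=True)    # (batch, tokens, channels)

detach = torch.ones_like(key)
detach[:, :1, :] = detach[:, :1, :] * 0.0          # token 0 uses the detached copy
mixed = detach * key + (1 - detach) * key.detach()

mixed.sum().backward()
print(key.grad[:, 0].abs().sum().item())  # 0.0 -> no gradient through token 0
print(key.grad[:, 1].abs().sum().item())  # 8.0 -> full gradient elsewhere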
1483
+ class SlicedAttnProcessor:
1484
+ r"""
1485
+ Processor for implementing sliced attention.
1486
+
1487
+ Args:
1488
+ slice_size (`int`, *optional*):
1489
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1490
+ `attention_head_dim` must be a multiple of the `slice_size`.
1491
+ """
1492
+
1493
+ def __init__(self, slice_size):
1494
+ self.slice_size = slice_size
1495
+
1496
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1497
+ residual = hidden_states
1498
+
1499
+ input_ndim = hidden_states.ndim
1500
+
1501
+ if input_ndim == 4:
1502
+ batch_size, channel, height, width = hidden_states.shape
1503
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1504
+
1505
+ batch_size, sequence_length, _ = (
1506
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1507
+ )
1508
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1509
+
1510
+ if attn.group_norm is not None:
1511
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1512
+
1513
+ query = attn.to_q(hidden_states)
1514
+ dim = query.shape[-1]
1515
+ query = attn.head_to_batch_dim(query)
1516
+
1517
+ if encoder_hidden_states is None:
1518
+ encoder_hidden_states = hidden_states
1519
+ elif attn.norm_cross:
1520
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1521
+
1522
+ key = attn.to_k(encoder_hidden_states)
1523
+ value = attn.to_v(encoder_hidden_states)
1524
+ key = attn.head_to_batch_dim(key)
1525
+ value = attn.head_to_batch_dim(value)
1526
+
1527
+ batch_size_attention, query_tokens, _ = query.shape
1528
+ hidden_states = torch.zeros(
1529
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1530
+ )
1531
+
1532
+ for i in range(batch_size_attention // self.slice_size):
1533
+ start_idx = i * self.slice_size
1534
+ end_idx = (i + 1) * self.slice_size
1535
+
1536
+ query_slice = query[start_idx:end_idx]
1537
+ key_slice = key[start_idx:end_idx]
1538
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1539
+
1540
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1541
+
1542
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1543
+
1544
+ hidden_states[start_idx:end_idx] = attn_slice
1545
+
1546
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1547
+
1548
+ # linear proj
1549
+ hidden_states = attn.to_out[0](hidden_states)
1550
+ # dropout
1551
+ hidden_states = attn.to_out[1](hidden_states)
1552
+
1553
+ if input_ndim == 4:
1554
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1555
+
1556
+ if attn.residual_connection:
1557
+ hidden_states = hidden_states + residual
1558
+
1559
+ hidden_states = hidden_states / attn.rescale_output_factor
1560
+
1561
+ return hidden_states
1562
+
1563
+
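Slicing trades peak memory for extra kernel launches: the (batch * heads) dimension is processed in chunks of slice_size, and the concatenated result equals full attention. A small numerical check of that equivalence (toy shapes, with plain softmax attention standing in for attn.get_attention_scores):

import torch

bh, q_len, kv_len, dim = 8, 16, 16, 32     # bh = batch * heads
slice_size = 2
q = torch.randn(bh, q_len, dim)
k = torch.randn(bh, kv_len, dim)
v = torch.randn(bh, kv_len, dim)

def attention(q, k, v):
    scores = torch.softmax(q @ k.transpose(-1, -2) / dim ** 0.5, dim=-1)
    return scores @ v

full = attention(q, k, v)

sliced = torch.zeros_like(full)
for i in range(bh // slice_size):
    s, e = i * slice_size, (i + 1) * slice_size
    sliced[s:e] = attention(q[s:e], k[s:e], v[s:e])

print(torch.allclose(full, sliced, atol=1e-6))  # True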
1564
+ class SlicedAttnAddedKVProcessor:
1565
+ r"""
1566
+ Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
1567
+
1568
+ Args:
1569
+ slice_size (`int`, *optional*):
1570
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1571
+ `attention_head_dim` must be a multiple of the `slice_size`.
1572
+ """
1573
+
1574
+ def __init__(self, slice_size):
1575
+ self.slice_size = slice_size
1576
+
1577
+ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
1578
+ residual = hidden_states
1579
+
1580
+ if attn.spatial_norm is not None:
1581
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1582
+
1583
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1584
+
1585
+ batch_size, sequence_length, _ = hidden_states.shape
1586
+
1587
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1588
+
1589
+ if encoder_hidden_states is None:
1590
+ encoder_hidden_states = hidden_states
1591
+ elif attn.norm_cross:
1592
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1593
+
1594
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1595
+
1596
+ query = attn.to_q(hidden_states)
1597
+ dim = query.shape[-1]
1598
+ query = attn.head_to_batch_dim(query)
1599
+
1600
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1601
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1602
+
1603
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1604
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1605
+
1606
+ if not attn.only_cross_attention:
1607
+ key = attn.to_k(hidden_states)
1608
+ value = attn.to_v(hidden_states)
1609
+ key = attn.head_to_batch_dim(key)
1610
+ value = attn.head_to_batch_dim(value)
1611
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1612
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1613
+ else:
1614
+ key = encoder_hidden_states_key_proj
1615
+ value = encoder_hidden_states_value_proj
1616
+
1617
+ batch_size_attention, query_tokens, _ = query.shape
1618
+ hidden_states = torch.zeros(
1619
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1620
+ )
1621
+
1622
+ for i in range(batch_size_attention // self.slice_size):
1623
+ start_idx = i * self.slice_size
1624
+ end_idx = (i + 1) * self.slice_size
1625
+
1626
+ query_slice = query[start_idx:end_idx]
1627
+ key_slice = key[start_idx:end_idx]
1628
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1629
+
1630
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1631
+
1632
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1633
+
1634
+ hidden_states[start_idx:end_idx] = attn_slice
1635
+
1636
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1637
+
1638
+ # linear proj
1639
+ hidden_states = attn.to_out[0](hidden_states)
1640
+ # dropout
1641
+ hidden_states = attn.to_out[1](hidden_states)
1642
+
1643
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1644
+ hidden_states = hidden_states + residual
1645
+
1646
+ return hidden_states
1647
+
1648
+
1649
+ AttentionProcessor = Union[
1650
+ AttnProcessor,
1651
+ AttnProcessor2_0,
1652
+ XFormersAttnProcessor,
1653
+ SlicedAttnProcessor,
1654
+ AttnAddedKVProcessor,
1655
+ SlicedAttnAddedKVProcessor,
1656
+ AttnAddedKVProcessor2_0,
1657
+ XFormersAttnAddedKVProcessor,
1658
+ LoRAAttnProcessor,
1659
+ LoRAXFormersAttnProcessor,
1660
+ LoRAAttnProcessor2_0,
1661
+ LoRAAttnAddedKVProcessor,
1662
+ CustomDiffusionAttnProcessor,
1663
+ CustomDiffusionXFormersAttnProcessor,
1664
+ ]
1665
+
1666
+
1667
+ class SpatialNorm(nn.Module):
1668
+ """
1669
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
1670
+ """
1671
+
1672
+ def __init__(
1673
+ self,
1674
+ f_channels,
1675
+ zq_channels,
1676
+ ):
1677
+ super().__init__()
1678
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
1679
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1680
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1681
+
1682
+ def forward(self, f, zq):
1683
+ f_size = f.shape[-2:]
1684
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
1685
+ norm_f = self.norm_layer(f)
1686
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
1687
+ return new_f
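SpatialNorm modulates a GroupNorm'ed feature map with per-pixel scale and shift maps predicted from a conditioning tensor zq that is resized to the feature resolution. A shape-level sketch, with the SpatialNorm class defined above in scope (the channel counts are illustrative only):

import torch

norm = SpatialNorm(f_channels=64, zq_channels=16)

f = torch.randn(1, 64, 32, 32)    # feature map to normalize
zq = torch.randn(1, 16, 8, 8)     # spatial conditioning at a coarser resolution
out = norm(f, zq)                 # zq is upsampled to 32x32 internally
print(out.shape)                  # torch.Size([1, 64, 32, 32])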
models/dual_transformer_2d.py ADDED
@@ -0,0 +1,151 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional
15
+
16
+ from torch import nn
17
+
18
+ from models.transformer_2d import Transformer2DModel, Transformer2DModelOutput
19
+
20
+
21
+ class DualTransformer2DModel(nn.Module):
22
+ """
23
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
24
+
25
+ Parameters:
26
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
27
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
28
+ in_channels (`int`, *optional*):
29
+ Pass if the input is continuous. The number of channels in the input and output.
30
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
31
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
32
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
33
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
34
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
35
+ `ImagePositionalEmbeddings`.
36
+ num_vector_embeds (`int`, *optional*):
37
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
38
+ Includes the class for the masked latent pixel.
39
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
41
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
42
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
43
+ up to but not more steps than `num_embeds_ada_norm`.
44
+ attention_bias (`bool`, *optional*):
45
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ num_attention_heads: int = 16,
51
+ attention_head_dim: int = 88,
52
+ in_channels: Optional[int] = None,
53
+ num_layers: int = 1,
54
+ dropout: float = 0.0,
55
+ norm_num_groups: int = 32,
56
+ cross_attention_dim: Optional[int] = None,
57
+ attention_bias: bool = False,
58
+ sample_size: Optional[int] = None,
59
+ num_vector_embeds: Optional[int] = None,
60
+ activation_fn: str = "geglu",
61
+ num_embeds_ada_norm: Optional[int] = None,
62
+ ):
63
+ super().__init__()
64
+ self.transformers = nn.ModuleList(
65
+ [
66
+ Transformer2DModel(
67
+ num_attention_heads=num_attention_heads,
68
+ attention_head_dim=attention_head_dim,
69
+ in_channels=in_channels,
70
+ num_layers=num_layers,
71
+ dropout=dropout,
72
+ norm_num_groups=norm_num_groups,
73
+ cross_attention_dim=cross_attention_dim,
74
+ attention_bias=attention_bias,
75
+ sample_size=sample_size,
76
+ num_vector_embeds=num_vector_embeds,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ )
80
+ for _ in range(2)
81
+ ]
82
+ )
83
+
84
+ # Variables that can be set by a pipeline:
85
+
86
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
87
+ self.mix_ratio = 0.5
88
+
89
+ # The shape of `encoder_hidden_states` is expected to be
90
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
91
+ self.condition_lengths = [77, 257]
92
+
93
+ # Which transformer to use to encode which condition.
94
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
95
+ self.transformer_index_for_condition = [1, 0]
96
+
97
+ def forward(
98
+ self,
99
+ hidden_states,
100
+ encoder_hidden_states,
101
+ timestep=None,
102
+ attention_mask=None,
103
+ cross_attention_kwargs=None,
104
+ return_dict: bool = True,
105
+ ):
106
+ """
107
+ Args:
108
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
109
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
110
+ hidden_states
111
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
112
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
113
+ self-attention.
114
+ timestep ( `torch.long`, *optional*):
115
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
116
+ attention_mask (`torch.FloatTensor`, *optional*):
117
+ Optional attention mask to be applied in Attention
118
+ return_dict (`bool`, *optional*, defaults to `True`):
119
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
120
+
121
+ Returns:
122
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
123
+ [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
124
+ returning a tuple, the first element is the sample tensor.
125
+ """
126
+ input_states = hidden_states
127
+
128
+ encoded_states = []
129
+ tokens_start = 0
130
+ # attention_mask is not used yet
131
+ for i in range(2):
132
+ # for each of the two transformers, pass the corresponding condition tokens
133
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
134
+ transformer_index = self.transformer_index_for_condition[i]
135
+ encoded_state = self.transformers[transformer_index](
136
+ input_states,
137
+ encoder_hidden_states=condition_state,
138
+ timestep=timestep,
139
+ cross_attention_kwargs=cross_attention_kwargs,
140
+ return_dict=False,
141
+ )[0]
142
+ encoded_states.append(encoded_state - input_states)
143
+ tokens_start += self.condition_lengths[i]
144
+
145
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
146
+ output_states = output_states + input_states
147
+
148
+ if not return_dict:
149
+ return (output_states,)
150
+
151
+ return Transformer2DModelOutput(sample=output_states)
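The forward pass above splits the concatenated conditioning (77 CLIP text tokens followed by 257 image-embedding tokens), runs each chunk through the transformer selected by transformer_index_for_condition, and blends the two residuals with mix_ratio. A stripped-down sketch of just that routing and blending arithmetic, with stub callables standing in for the real Transformer2DModels:

import torch

condition_lengths = [77, 257]
transformer_index_for_condition = [1, 0]
mix_ratio = 0.5

hidden_states = torch.randn(1, 4096, 320)                           # flattened latent tokens (illustrative)
encoder_hidden_states = torch.randn(1, sum(condition_lengths), 768)

# identity-ish stubs in place of the two Transformer2DModels
transformers = [lambda h, c: h + 0.1 * h.mean(), lambda h, c: h + 0.2 * h.mean()]

encoded_states, tokens_start = [], 0
for i in range(2):
    condition = encoder_hidden_states[:, tokens_start: tokens_start + condition_lengths[i]]
    out = transformers[transformer_index_for_condition[i]](hidden_states, condition)
    encoded_states.append(out - hidden_states)                      # keep only the residual
    tokens_start += condition_lengths[i]

output = encoded_states[0] * mix_ratio + encoded_states[1] * (1 - mix_ratio) + hidden_states
print(output.shape)  # torch.Size([1, 4096, 320])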
models/region_diffusion.py ADDED
@@ -0,0 +1,521 @@
1
+ import os
2
+ import torch
3
+ import collections
4
+ import torch.nn as nn
5
+ from functools import partial
6
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
7
+ from diffusers import AutoencoderKL, PNDMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
8
+ from models.unet_2d_condition import UNet2DConditionModel
9
+ from utils.attention_utils import CrossAttentionLayers, SelfAttentionLayers
10
+
11
+ # suppress partial model loading warning
12
+ logging.set_verbosity_error()
13
+
14
+
15
+ class RegionDiffusion(nn.Module):
16
+ def __init__(self, device):
17
+ super().__init__()
18
+
19
+ self.device = device
20
+ self.num_train_timesteps = 1000
21
+ self.clip_gradient = False
22
+
23
+ print(f'[INFO] loading stable diffusion...')
24
+ model_id = 'runwayml/stable-diffusion-v1-5'
25
+
26
+ self.vae = AutoencoderKL.from_pretrained(
27
+ model_id, subfolder="vae").to(self.device)
28
+ self.tokenizer = CLIPTokenizer.from_pretrained(
29
+ model_id, subfolder='tokenizer')
30
+ self.text_encoder = CLIPTextModel.from_pretrained(
31
+ model_id, subfolder='text_encoder').to(self.device)
32
+ self.unet = UNet2DConditionModel.from_pretrained(
33
+ model_id, subfolder="unet").to(self.device)
34
+
35
+ self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
36
+ num_train_timesteps=self.num_train_timesteps, skip_prk_steps=True, steps_offset=1)
37
+ self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
38
+
39
+ self.masks = []
40
+ self.attention_maps = None
41
+ self.selfattn_maps = None
42
+ self.crossattn_maps = None
43
+ self.color_loss = torch.nn.functional.mse_loss
44
+ self.forward_hooks = []
45
+ self.forward_replacement_hooks = []
46
+
47
+ print(f'[INFO] loaded stable diffusion!')
48
+
49
+ def get_text_embeds(self, prompt, negative_prompt):
50
+ # prompt, negative_prompt: [str]
51
+
52
+ # Tokenize text and get embeddings
53
+ text_input = self.tokenizer(
54
+ prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
55
+
56
+ with torch.no_grad():
57
+ text_embeddings = self.text_encoder(
58
+ text_input.input_ids.to(self.device))[0]
59
+
60
+ # Do the same for unconditional embeddings
61
+ uncond_input = self.tokenizer(negative_prompt, padding='max_length',
62
+ max_length=self.tokenizer.model_max_length, return_tensors='pt')
63
+
64
+ with torch.no_grad():
65
+ uncond_embeddings = self.text_encoder(
66
+ uncond_input.input_ids.to(self.device))[0]
67
+
68
+ # Cat for final embeddings
69
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
70
+ return text_embeddings
71
+
72
+ def get_text_embeds_list(self, prompts):
73
+ # prompts: [list]
74
+ text_embeddings = []
75
+ for prompt in prompts:
76
+ # Tokenize text and get embeddings
77
+ text_input = self.tokenizer(
78
+ [prompt], padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
79
+
80
+ with torch.no_grad():
81
+ text_embeddings.append(self.text_encoder(
82
+ text_input.input_ids.to(self.device))[0])
83
+
84
+ return text_embeddings
85
+
86
+ def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5,
87
+ latents=None, use_guidance=False, text_format_dict={}, inject_selfattn=0, bg_aug_end=1000):
88
+
89
+ if latents is None:
90
+ latents = torch.randn(
91
+ (1, self.unet.in_channels, height // 8, width // 8), device=self.device)
92
+
93
+ if inject_selfattn > 0:
94
+ latents_reference = latents.clone().detach()
95
+ self.scheduler.set_timesteps(num_inference_steps)
96
+ n_styles = text_embeddings.shape[0]-1
97
+ print(n_styles, len(self.masks))
98
+ assert n_styles == len(self.masks)
99
+
100
+ with torch.autocast('cuda'):
101
+ for i, t in enumerate(self.scheduler.timesteps):
102
+
103
+ # predict the noise residual
104
+ with torch.no_grad():
105
+ # tokens without any attributes
106
+ feat_inject_step = t > (1-inject_selfattn) * 1000
107
+ noise_pred_uncond_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[:1],
108
+ # text_format_dict={})['sample']
109
+ )['sample']
110
+ # tokens without any style or footnote
111
+ self.register_fontsize_hooks(text_format_dict)
112
+ noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[-1:],
113
+ # text_format_dict=text_format_dict)['sample']
114
+ )['sample']
115
+ self.remove_fontsize_hooks()
116
+ if inject_selfattn > 0:
117
+ noise_pred_uncond_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[:1],
118
+ # text_format_dict={})['sample']
119
+ )['sample']
120
+ self.register_selfattn_hooks(feat_inject_step)
121
+ noise_pred_text_refer = self.unet(latents_reference, t, encoder_hidden_states=text_embeddings[-1:],
122
+ # text_format_dict={})['sample']
123
+ )['sample']
124
+ self.remove_selfattn_hooks()
125
+ noise_pred_uncond = noise_pred_uncond_cur * self.masks[-1]
126
+ noise_pred_text = noise_pred_text_cur * self.masks[-1]
127
+ # tokens with attributes
128
+ for style_i, mask in enumerate(self.masks[:-1]):
129
+ if t > bg_aug_end:
130
+ rand_rgb = torch.rand([1, 3, 1, 1]).cuda()
131
+ black_background = torch.ones(
132
+ [1, 3, height, width]).cuda()*rand_rgb
133
+ black_latent = self.encode_imgs(
134
+ black_background)
135
+ noise = torch.randn_like(black_latent)
136
+ black_latent_noisy = self.scheduler.add_noise(
137
+ black_latent, noise, t)
138
+ masked_latent = (
139
+ mask > 0.001) * latents + (mask < 0.001) * black_latent_noisy
140
+ noise_pred_uncond_cur = self.unet(masked_latent, t, encoder_hidden_states=text_embeddings[:1],
141
+ text_format_dict={})['sample']
142
+ else:
143
+ masked_latent = latents
144
+ self.register_replacement_hooks(feat_inject_step)
145
+ noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[style_i+1:style_i+2],
146
+ # text_format_dict={})['sample']
147
+ )['sample']
148
+ self.remove_replacement_hooks()
149
+ noise_pred_uncond = noise_pred_uncond + noise_pred_uncond_cur*mask
150
+ noise_pred_text = noise_pred_text + noise_pred_text_cur*mask
151
+
152
+ # perform guidance
153
+ noise_pred = noise_pred_uncond + guidance_scale * \
154
+ (noise_pred_text - noise_pred_uncond)
155
+
156
+ if inject_selfattn > 0:
157
+ noise_pred_refer = noise_pred_uncond_refer + guidance_scale * \
158
+ (noise_pred_text_refer - noise_pred_uncond_refer)
159
+
160
+ # compute the previous noisy sample x_t -> x_t-1
161
+ latents_reference = self.scheduler.step(torch.cat([noise_pred, noise_pred_refer]), t,
162
+ torch.cat([latents, latents_reference]))[
163
+ 'prev_sample']
164
+ latents, latents_reference = torch.chunk(
165
+ latents_reference, 2, dim=0)
166
+
167
+ else:
168
+ # compute the previous noisy sample x_t -> x_t-1
169
+ latents = self.scheduler.step(noise_pred, t, latents)[
170
+ 'prev_sample']
171
+
172
+ # apply guidance
173
+ if use_guidance and t < text_format_dict['guidance_start_step']:
174
+ with torch.enable_grad():
175
+ if not latents.requires_grad:
176
+ latents.requires_grad = True
177
+ latents_0 = self.predict_x0(latents, noise_pred, t)
178
+ latents_inp = 1 / 0.18215 * latents_0
179
+ imgs = self.vae.decode(latents_inp).sample
180
+ imgs = (imgs / 2 + 0.5).clamp(0, 1)
181
+ # save_path = 'results/font_color/20230425/church_process/orange/'
182
+ # os.makedirs(save_path, exist_ok=True)
183
+ # torchvision.utils.save_image(
184
+ # imgs, os.path.join(save_path, 'step%d.png' % t))
185
+ # loss = (((imgs - text_format_dict['target_RGB'])*text_format_dict['color_obj_atten'][:, 0])**2).mean()*100
186
+ loss_total = 0.
187
+ for attn_map, rgb_val in zip(text_format_dict['color_obj_atten'], text_format_dict['target_RGB']):
188
+ # loss = self.color_loss(
189
+ # imgs*attn_map[:, 0], rgb_val*attn_map[:, 0])*100
190
+ avg_rgb = (
191
+ imgs*attn_map[:, 0]).sum(2).sum(2)/attn_map[:, 0].sum()
192
+ loss = self.color_loss(
193
+ avg_rgb, rgb_val[:, :, 0, 0])*100
194
+ # print(loss)
195
+ loss_total += loss
196
+ loss_total.backward()
197
+ latents = (
198
+ latents - latents.grad * text_format_dict['color_guidance_weight'] * self.masks[0]).detach().clone()
199
+
200
+ return latents
201
+
202
+ def predict_x0(self, x_t, eps_t, t):
203
+ alpha_t = self.scheduler.alphas_cumprod[t]
204
+ return (x_t - eps_t * torch.sqrt(1-alpha_t)) / torch.sqrt(alpha_t)
205
+
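predict_x0 simply inverts the forward noising relation x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A quick numerical round trip confirming the algebra (standalone, with a made-up alphas_cumprod schedule):

import torch

alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)   # toy schedule, illustration only
t = 500
x0 = torch.randn(1, 4, 64, 64)
eps = torch.randn_like(x0)

a = alphas_cumprod[t]
x_t = torch.sqrt(a) * x0 + torch.sqrt(1 - a) * eps         # forward noising
x0_hat = (x_t - eps * torch.sqrt(1 - a)) / torch.sqrt(a)   # same algebra as predict_x0

print(torch.allclose(x0, x0_hat, atol=1e-5))  # True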
206
+ def produce_attn_maps(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
207
+ guidance_scale=7.5, latents=None):
208
+
209
+ if isinstance(prompts, str):
210
+ prompts = [prompts]
211
+
212
+ if isinstance(negative_prompts, str):
213
+ negative_prompts = [negative_prompts]
214
+
215
+ # Prompts -> text embeds
216
+ text_embeddings = self.get_text_embeds(
217
+ prompts, negative_prompts) # [2, 77, 768]
218
+ if latents is None:
219
+ latents = torch.randn(
220
+ (text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)
221
+
222
+ self.scheduler.set_timesteps(num_inference_steps)
223
+ self.remove_replacement_hooks()
224
+
225
+ with torch.autocast('cuda'):
226
+ for i, t in enumerate(self.scheduler.timesteps):
227
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
228
+ latent_model_input = torch.cat([latents] * 2)
229
+
230
+ # predict the noise residual
231
+ with torch.no_grad():
232
+ noise_pred = self.unet(
233
+ latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
234
+
235
+ # perform guidance
236
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
237
+ noise_pred = noise_pred_uncond + guidance_scale * \
238
+ (noise_pred_text - noise_pred_uncond)
239
+
240
+ # compute the previous noisy sample x_t -> x_t-1
241
+ latents = self.scheduler.step(noise_pred, t, latents)[
242
+ 'prev_sample']
243
+
244
+ # Img latents -> imgs
245
+ imgs = self.decode_latents(latents) # [1, 3, 512, 512]
246
+
247
+ # Img to Numpy
248
+ imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
249
+ imgs = (imgs * 255).round().astype('uint8')
250
+
251
+ return imgs
252
+
253
+ def decode_latents(self, latents):
254
+
255
+ latents = 1 / 0.18215 * latents
256
+
257
+ with torch.no_grad():
258
+ imgs = self.vae.decode(latents).sample
259
+
260
+ imgs = (imgs / 2 + 0.5).clamp(0, 1)
261
+
262
+ return imgs
263
+
264
+ def encode_imgs(self, imgs):
265
+ # imgs: [B, 3, H, W]
266
+
267
+ imgs = 2 * imgs - 1
268
+
269
+ posterior = self.vae.encode(imgs).latent_dist
270
+ latents = posterior.sample() * 0.18215
271
+
272
+ return latents
273
+
274
+ def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
275
+ guidance_scale=7.5, latents=None, text_format_dict={}, use_guidance=False, inject_selfattn=0, bg_aug_end=1000):
276
+
277
+ if isinstance(prompts, str):
278
+ prompts = [prompts]
279
+
280
+ if isinstance(negative_prompts, str):
281
+ negative_prompts = [negative_prompts]
282
+
283
+ # Prompts -> text embeds
284
+ text_embeds = self.get_text_embeds(
285
+ prompts, negative_prompts) # [2, 77, 768]
286
+
287
+ # else:
288
+ latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents,
289
+ num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
290
+ use_guidance=use_guidance, text_format_dict=text_format_dict,
291
+ inject_selfattn=inject_selfattn, bg_aug_end=bg_aug_end) # [1, 4, 64, 64]
292
+ # Img latents -> imgs
293
+ imgs = self.decode_latents(latents) # [1, 3, 512, 512]
294
+
295
+ # Img to Numpy
296
+ imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
297
+ imgs = (imgs * 255).round().astype('uint8')
298
+
299
+ return imgs
300
+
301
+ def reset_attention_maps(self):
302
+ r"""Function to reset attention maps.
303
+ We reset attention maps because we append them while getting hooks
304
+ to visualize attention maps for every step.
305
+ """
306
+ for key in self.selfattn_maps:
307
+ self.selfattn_maps[key] = []
308
+ for key in self.crossattn_maps:
309
+ self.crossattn_maps[key] = []
310
+
311
+ def register_evaluation_hooks(self):
312
+ r"""Function for registering hooks during evaluation.
313
+ We mainly store activation maps averaged over queries.
314
+ """
315
+ self.forward_hooks = []
316
+
317
+ def save_activations(activations, name, module, inp, out):
318
+ r"""
319
+ PyTorch Forward hook to save outputs at each forward pass.
320
+ """
321
+ # out[0] - final output of attention layer
322
+ # out[1] - attention probability matrix
323
+ if 'attn2' in name:
324
+ assert out[1].shape[-1] == 77
325
+ activations[name].append(out[1].detach().cpu())
326
+ else:
327
+ assert out[1].shape[-1] != 77
328
+ attention_dict = collections.defaultdict(list)
329
+ for name, module in self.unet.named_modules():
330
+ leaf_name = name.split('.')[-1]
331
+ if 'attn' in leaf_name:
332
+ # Register hook to obtain outputs at every attention layer.
333
+ self.forward_hooks.append(module.register_forward_hook(
334
+ partial(save_activations, attention_dict, name)
335
+ ))
336
+ # attention_dict is a dictionary containing attention maps for every attention layer
337
+ self.attention_maps = attention_dict
338
+
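All of the register_* methods in this class rely on the same mechanism: functools.partial binds a layer name (and a shared store) to a callback, and register_forward_hook / register_forward_pre_hook attach it to the matching UNet submodules. A minimal standalone illustration of the pattern on a toy model:

import collections
from functools import partial

import torch
import torch.nn as nn

def save_activations(store, name, module, inp, out):
    # forward hook: record this module's output under its name
    store[name].append(out.detach())

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
store = collections.defaultdict(list)

hooks = [m.register_forward_hook(partial(save_activations, store, name))
         for name, m in model.named_modules() if isinstance(m, nn.Linear)]

model(torch.randn(4, 8))
print({k: v[0].shape for k, v in store.items()})   # outputs recorded for layers '0' and '2'

for h in hooks:        # analogous to remove_evaluation_hooks()
    h.remove()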
339
+ def register_selfattn_hooks(self, feat_inject_step=False):
340
+ r"""Function for registering hooks during evaluation.
341
+ We mainly store activation maps averaged over queries.
342
+ """
343
+ self.selfattn_forward_hooks = []
344
+
345
+ def save_activations(activations, name, module, inp, out):
346
+ r"""
347
+ PyTorch Forward hook to save outputs at each forward pass.
348
+ """
349
+ # out[0] - final output of attention layer
350
+ # out[1] - attention probability matrix
351
+ if 'attn2' in name:
352
+ assert out[1][1].shape[-1] == 77
353
+ # cross attention injection
354
+ # activations[name] = out[1][1].detach()
355
+ else:
356
+ assert out[1][1].shape[-1] != 77
357
+ activations[name] = out[1][1].detach()
358
+
359
+ def save_resnet_activations(activations, name, module, inp, out):
360
+ r"""
361
+ PyTorch Forward hook to save outputs at each forward pass.
362
+ """
363
+ # out[0] - final output of residual layer
364
+ # out[1] - residual hidden feature
365
+ # import ipdb
366
+ # ipdb.set_trace()
367
+ assert out[1].shape[-1] == 16
368
+ activations[name] = out[1].detach()
369
+ attention_dict = collections.defaultdict(list)
370
+ for name, module in self.unet.named_modules():
371
+ leaf_name = name.split('.')[-1]
372
+ if 'attn' in leaf_name and feat_inject_step:
373
+ # Register hook to obtain outputs at every attention layer.
374
+ self.selfattn_forward_hooks.append(module.register_forward_hook(
375
+ partial(save_activations, attention_dict, name)
376
+ ))
377
+ if name == 'up_blocks.1.resnets.1' and feat_inject_step:
378
+ self.selfattn_forward_hooks.append(module.register_forward_hook(
379
+ partial(save_resnet_activations, attention_dict, name)
380
+ ))
381
+ # attention_dict is a dictionary containing attention maps for every attention layer
382
+ self.self_attention_maps_cur = attention_dict
383
+
384
+ def register_replacement_hooks(self, feat_inject_step=False):
385
+ r"""Function for registering hooks to replace self attention.
386
+ """
387
+ self.forward_replacement_hooks = []
388
+
389
+ def replace_activations(name, module, args):
390
+ r"""
391
+ PyTorch Forward hook to save outputs at each forward pass.
392
+ """
393
+ if 'attn1' in name:
394
+ modified_args = (args[0], self.self_attention_maps_cur[name])
395
+ return modified_args
396
+ # cross attention injection
397
+ # elif 'attn2' in name:
398
+ # modified_map = {
399
+ # 'reference': self.self_attention_maps_cur[name],
400
+ # 'inject_pos': self.inject_pos,
401
+ # }
402
+ # modified_args = (args[0], modified_map)
403
+ # return modified_args
404
+
405
+ def replace_resnet_activations(name, module, args):
406
+ r"""
407
+ PyTorch Forward hook to save outputs at each forward pass.
408
+ """
409
+ modified_args = (args[0], args[1],
410
+ self.self_attention_maps_cur[name])
411
+ return modified_args
412
+ for name, module in self.unet.named_modules():
413
+ leaf_name = name.split('.')[-1]
414
+ if 'attn' in leaf_name and feat_inject_step:
415
+ # Register hook to obtain outputs at every attention layer.
416
+ self.forward_replacement_hooks.append(module.register_forward_pre_hook(
417
+ partial(replace_activations, name)
418
+ ))
419
+ if name == 'up_blocks.1.resnets.1' and feat_inject_step:
420
+ # Register hook to obtain outputs at every attention layer.
421
+ self.forward_replacement_hooks.append(module.register_forward_pre_hook(
422
+ partial(replace_resnet_activations, name)
423
+ ))
424
+
425
+ def register_tokenmap_hooks(self):
426
+ r"""Function for registering hooks during evaluation.
427
+ We mainly store activation maps averaged over queries.
428
+ """
429
+ self.forward_hooks = []
430
+
431
+ def save_activations(selfattn_maps, crossattn_maps, n_maps, name, module, inp, out):
432
+ r"""
433
+ PyTorch Forward hook to save outputs at each forward pass.
434
+ """
435
+ # out[0] - final output of attention layer
436
+ # out[1] - attention probability matrices
437
+ if name in n_maps:
438
+ n_maps[name] += 1
439
+ else:
440
+ n_maps[name] = 1
441
+ if 'attn2' in name:
442
+ assert out[1][0].shape[-1] == 77
443
+ if name in CrossAttentionLayers and n_maps[name] > 10:
444
+ if name in crossattn_maps:
445
+ crossattn_maps[name] += out[1][0].detach().cpu()[1:2]
446
+ else:
447
+ crossattn_maps[name] = out[1][0].detach().cpu()[1:2]
448
+ else:
449
+ assert out[1][0].shape[-1] != 77
450
+ if name in SelfAttentionLayers and n_maps[name] > 10:
451
+ if name in selfattn_maps:
452
+ selfattn_maps[name] += out[1][0].detach().cpu()[1:2]
453
+ else:
454
+ selfattn_maps[name] = out[1][0].detach().cpu()[1:2]
455
+
456
+ selfattn_maps = collections.defaultdict(list)
457
+ crossattn_maps = collections.defaultdict(list)
458
+ n_maps = collections.defaultdict(list)
459
+
460
+ for name, module in self.unet.named_modules():
461
+ leaf_name = name.split('.')[-1]
462
+ if 'attn' in leaf_name:
463
+ # Register hook to obtain outputs at every attention layer.
464
+ self.forward_hooks.append(module.register_forward_hook(
465
+ partial(save_activations, selfattn_maps,
466
+ crossattn_maps, n_maps, name)
467
+ ))
468
+ # attention_dict is a dictionary containing attention maps for every attention layer
469
+ self.selfattn_maps = selfattn_maps
470
+ self.crossattn_maps = crossattn_maps
471
+ self.n_maps = n_maps
472
+
473
+ def remove_tokenmap_hooks(self):
474
+ for hook in self.forward_hooks:
475
+ hook.remove()
476
+ self.selfattn_maps = None
477
+ self.crossattn_maps = None
478
+ self.n_maps = None
479
+
480
+ def remove_evaluation_hooks(self):
481
+ for hook in self.forward_hooks:
482
+ hook.remove()
483
+ self.attention_maps = None
484
+
485
+ def remove_replacement_hooks(self):
486
+ for hook in self.forward_replacement_hooks:
487
+ hook.remove()
488
+
489
+ def remove_selfattn_hooks(self):
490
+ for hook in self.selfattn_forward_hooks:
491
+ hook.remove()
492
+
493
+ def register_fontsize_hooks(self, text_format_dict={}):
494
+ r"""Function for registering hooks to replace self attention.
495
+ """
496
+ self.forward_fontsize_hooks = []
497
+
498
+ def adjust_attn_weights(name, module, args):
499
+ r"""
500
+ PyTorch Forward hook to save outputs at each forward pass.
501
+ """
502
+ if 'attn2' in name:
503
+ modified_args = (args[0], None, attn_weights)
504
+ return modified_args
505
+
506
+ if text_format_dict['word_pos'] is not None and text_format_dict['font_size'] is not None:
507
+ attn_weights = {'word_pos': text_format_dict['word_pos'], 'font_size': text_format_dict['font_size']}
508
+ else:
509
+ attn_weights = None
510
+
511
+ for name, module in self.unet.named_modules():
512
+ leaf_name = name.split('.')[-1]
513
+ if 'attn' in leaf_name and attn_weights is not None:
514
+ # Register hook to obtain outputs at every attention layer.
515
+ self.forward_fontsize_hooks.append(module.register_forward_pre_hook(
516
+ partial(adjust_attn_weights, name)
517
+ ))
518
+
519
+ def remove_fontsize_hooks(self):
520
+ for hook in self.forward_fontsize_hooks:
521
+ hook.remove()
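A hedged usage sketch of the class above: produce_attn_maps only needs a prompt (the region masks consumed by produce_latents are filled in elsewhere by the app), but it does assume a CUDA device and will download the Stable Diffusion 1.5 weights on first use:

import torch
from PIL import Image
from models.region_diffusion import RegionDiffusion

device = torch.device('cuda')          # the class uses torch.autocast('cuda') internally
sd = RegionDiffusion(device)

imgs = sd.produce_attn_maps('a cat wearing a red hat', negative_prompts='',
                            height=512, width=512, num_inference_steps=50,
                            guidance_scale=7.5)       # uint8 array, shape (1, 512, 512, 3)
Image.fromarray(imgs[0]).save('sample.png')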
models/region_diffusion_xl.py ADDED
@@ -0,0 +1,1143 @@
1
+ # Adapted from diffusers.pipelines.stable_diffusion.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.py
2
+
3
+ import inspect
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
8
+
9
+ from diffusers.image_processor import VaeImageProcessor
10
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
11
+ # from diffusers.models import AutoencoderKL, UNet2DConditionModel
12
+ from diffusers.models import AutoencoderKL
13
+
14
+ from diffusers.models.attention_processor import (
15
+ AttnProcessor2_0,
16
+ LoRAAttnProcessor2_0,
17
+ LoRAXFormersAttnProcessor,
18
+ XFormersAttnProcessor,
19
+ )
20
+ from diffusers.schedulers import EulerDiscreteScheduler
21
+ from diffusers.utils import (
22
+ is_accelerate_available,
23
+ is_accelerate_version,
24
+ logging,
25
+ randn_tensor,
26
+ replace_example_docstring,
27
+ )
28
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
29
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
30
+
31
+ ### customized modules
32
+ import collections
33
+ from functools import partial
34
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
35
+
36
+ from models.unet_2d_condition import UNet2DConditionModel
37
+ from utils.attention_utils import CrossAttentionLayers_XL
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
43
+ """
44
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
45
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
46
+ """
47
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
48
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
49
+ # rescale the results from guidance (fixes overexposure)
50
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
51
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
52
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
53
+ return noise_cfg
54
+
55
+
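rescale_noise_cfg shrinks the classifier-free-guidance output toward the per-sample standard deviation of the text branch and then blends linearly with the original: guidance_rescale=0 returns the input unchanged, while guidance_rescale=1 matches the text std exactly. A small standalone check of that behaviour (the function is repeated here only so the snippet is self-contained):

import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg

noise_pred_text = torch.randn(2, 4, 8, 8)
noise_cfg = 3.0 * torch.randn(2, 4, 8, 8)          # over-amplified CFG output

out = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0)
print(out.std(dim=[1, 2, 3]))                      # matches the per-sample std below
print(noise_pred_text.std(dim=[1, 2, 3]))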
56
+ class RegionDiffusionXL(DiffusionPipeline, FromSingleFileMixin):
57
+ r"""
58
+ Pipeline for text-to-image generation using Stable Diffusion.
59
+
60
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
61
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
62
+
63
+ In addition the pipeline inherits the following loading methods:
64
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
65
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
66
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
67
+
68
+ as well as the following saving methods:
69
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
70
+
71
+ Args:
72
+ vae ([`AutoencoderKL`]):
73
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
74
+ text_encoder ([`CLIPTextModel`]):
75
+ Frozen text-encoder. Stable Diffusion uses the text portion of
76
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
77
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
78
+ tokenizer (`CLIPTokenizer`):
79
+ Tokenizer of class
80
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
81
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
82
+ scheduler ([`SchedulerMixin`]):
83
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
84
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ load_path: str = "stabilityai/stable-diffusion-xl-base-1.0",
90
+ device: str = "cuda",
91
+ force_zeros_for_empty_prompt: bool = True,
92
+ ):
93
+ super().__init__()
94
+
95
+ # self.register_modules(
96
+ # vae=vae,
97
+ # text_encoder=text_encoder,
98
+ # text_encoder_2=text_encoder_2,
99
+ # tokenizer=tokenizer,
100
+ # tokenizer_2=tokenizer_2,
101
+ # unet=unet,
102
+ # scheduler=scheduler,
103
+ # )
104
+
105
+ # 1. Load the autoencoder model which will be used to decode the latents into image space.
106
+ self.vae = AutoencoderKL.from_pretrained(load_path, subfolder="vae", torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
107
+
108
+ # 2. Load the tokenizer and text encoder to tokenize and encode the text.
109
+ self.tokenizer = CLIPTokenizer.from_pretrained(load_path, subfolder='tokenizer')
110
+ self.tokenizer_2 = CLIPTokenizer.from_pretrained(load_path, subfolder='tokenizer_2')
111
+ self.text_encoder = CLIPTextModel.from_pretrained(load_path, subfolder='text_encoder', torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
112
+ self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(load_path, subfolder='text_encoder_2', torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
113
+
114
+ # 3. The UNet model for generating the latents.
115
+ self.unet = UNet2DConditionModel.from_pretrained(load_path, subfolder="unet", torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
116
+
117
+ # 4. Scheduler.
118
+ self.scheduler = EulerDiscreteScheduler.from_pretrained(load_path, subfolder="scheduler")
119
+
120
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
121
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
122
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
123
+ self.default_sample_size = self.unet.config.sample_size
124
+
125
+ self.watermark = StableDiffusionXLWatermarker()
126
+
127
+ self.device_type = device
128
+
129
+ self.masks = []
130
+ self.attention_maps = None
131
+ self.selfattn_maps = None
132
+ self.crossattn_maps = None
133
+ self.color_loss = torch.nn.functional.mse_loss
134
+ self.forward_hooks = []
135
+ self.forward_replacement_hooks = []
136
+
137
+ # Overwriting the method from diffusers.pipelines.diffusion_pipeline.DiffusionPipeline
138
+ @property
139
+ def device(self) -> torch.device:
140
+ r"""
141
+ Returns:
142
+ `torch.device`: The torch device on which the pipeline is located.
143
+ """
144
+
145
+ return torch.device(self.device_type)
146
+
147
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
148
+ def enable_vae_slicing(self):
149
+ r"""
150
+ Enable sliced VAE decoding.
151
+
152
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
153
+ steps. This is useful to save some memory and allow larger batch sizes.
154
+ """
155
+ self.vae.enable_slicing()
156
+
157
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
158
+ def disable_vae_slicing(self):
159
+ r"""
160
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
161
+ computing decoding in one step.
162
+ """
163
+ self.vae.disable_slicing()
164
+
165
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
166
+ def enable_vae_tiling(self):
167
+ r"""
168
+ Enable tiled VAE decoding.
169
+
170
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
171
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
172
+ """
173
+ self.vae.enable_tiling()
174
+
175
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
176
+ def disable_vae_tiling(self):
177
+ r"""
178
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
179
+ computing decoding in one step.
180
+ """
181
+ self.vae.disable_tiling()
182
+
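A minimal usage sketch of the VAE memory helpers above. It assumes `pipe` is an instance of this pipeline class; the variable name and the call site are illustrative, not taken from the repository:

# Trade a little speed for a smaller peak-memory footprint during VAE decoding.
pipe.enable_vae_slicing()   # decode the batch slice by slice
pipe.enable_vae_tiling()    # decode/encode in spatial tiles (helps for very large images)
# ... run generation ...
pipe.disable_vae_slicing()  # back to single-pass decoding
pipe.disable_vae_tiling()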
183
+ def enable_sequential_cpu_offload(self, gpu_id=0):
184
+ r"""
185
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
186
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
187
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
188
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
189
+ `enable_model_cpu_offload`, but performance is lower.
190
+ """
191
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
192
+ from accelerate import cpu_offload
193
+ else:
194
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
195
+
196
+ device = torch.device(f"cuda:{gpu_id}")
197
+
198
+ if self.device.type != "cpu":
199
+ self.to("cpu", silence_dtype_warnings=True)
200
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
201
+
202
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.text_encoder_2, self.vae]:
203
+ cpu_offload(cpu_offloaded_model, device)
204
+
205
+ def enable_model_cpu_offload(self, gpu_id=0):
206
+ r"""
207
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
208
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
209
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
210
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
211
+ """
212
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
213
+ from accelerate import cpu_offload_with_hook
214
+ else:
215
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
216
+
217
+ device = torch.device(f"cuda:{gpu_id}")
218
+
219
+ if self.device.type != "cpu":
220
+ self.to("cpu", silence_dtype_warnings=True)
221
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
222
+
223
+ model_sequence = (
224
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
225
+ )
226
+ model_sequence.extend([self.unet, self.vae])
227
+
228
+ hook = None
229
+ for cpu_offloaded_model in model_sequence:
230
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
231
+
232
+ # We'll offload the last model manually.
233
+ self.final_offload_hook = hook
234
+
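A sketch of the trade-off between the two offloading helpers above (assumes `pipe` is an instance of this class and that `accelerate` is installed; not taken from the repository):

# Lower memory, small speed cost: one whole sub-model lives on the GPU at a time.
pipe.enable_model_cpu_offload()
# Lowest memory, largest speed cost: submodule-by-submodule offloading.
# pipe.enable_sequential_cpu_offload()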
235
+ @property
236
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
237
+ def _execution_device(self):
238
+ r"""
239
+ Returns the device on which the pipeline's models will be executed. After calling
240
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
241
+ hooks.
242
+ """
243
+ if not hasattr(self.unet, "_hf_hook"):
244
+ return self.device
245
+ for module in self.unet.modules():
246
+ if (
247
+ hasattr(module, "_hf_hook")
248
+ and hasattr(module._hf_hook, "execution_device")
249
+ and module._hf_hook.execution_device is not None
250
+ ):
251
+ return torch.device(module._hf_hook.execution_device)
252
+ return self.device
253
+
254
+ def encode_prompt(
255
+ self,
256
+ prompt,
257
+ device: Optional[torch.device] = None,
258
+ num_images_per_prompt: int = 1,
259
+ do_classifier_free_guidance: bool = True,
260
+ negative_prompt=None,
261
+ prompt_embeds: Optional[torch.FloatTensor] = None,
262
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
263
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
264
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
265
+ lora_scale: Optional[float] = None,
266
+ ):
267
+ r"""
268
+ Encodes the prompt into text encoder hidden states.
269
+
270
+ Args:
271
+ prompt (`str` or `List[str]`, *optional*):
272
+ prompt to be encoded
273
+ device: (`torch.device`):
274
+ torch device
275
+ num_images_per_prompt (`int`):
276
+ number of images that should be generated per prompt
277
+ do_classifier_free_guidance (`bool`):
278
+ whether to use classifier free guidance or not
279
+ negative_prompt (`str` or `List[str]`, *optional*):
280
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
281
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
282
+ less than `1`).
283
+ prompt_embeds (`torch.FloatTensor`, *optional*):
284
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
285
+ provided, text embeddings will be generated from `prompt` input argument.
286
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
287
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
288
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
289
+ argument.
290
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
291
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
292
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
293
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
294
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
295
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
296
+ input argument.
297
+ lora_scale (`float`, *optional*):
298
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
299
+ """
300
+ device = device or self._execution_device
301
+
302
+ # set lora scale so that monkey patched LoRA
303
+ # function of text encoder can correctly access it
304
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
305
+ self._lora_scale = lora_scale
306
+
307
+ if prompt is not None and isinstance(prompt, str):
308
+ batch_size = 1
309
+ elif prompt is not None and isinstance(prompt, list):
310
+ batch_size = len(prompt)
311
+ batch_size_neg = len(negative_prompt)
312
+ else:
313
+ batch_size = prompt_embeds.shape[0]
314
+
315
+ # Define tokenizers and text encoders
316
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
317
+ text_encoders = (
318
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
319
+ )
320
+
321
+ if prompt_embeds is None:
322
+ # textual inversion: process multi-vector tokens if necessary
323
+ prompt_embeds_list = []
324
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
325
+ if isinstance(self, TextualInversionLoaderMixin):
326
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
327
+
328
+ text_inputs = tokenizer(
329
+ prompt,
330
+ padding="max_length",
331
+ max_length=tokenizer.model_max_length,
332
+ truncation=True,
333
+ return_tensors="pt",
334
+ )
335
+ text_input_ids = text_inputs.input_ids
336
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
337
+
338
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
339
+ text_input_ids, untruncated_ids
340
+ ):
341
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
342
+ logger.warning(
343
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
344
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
345
+ )
346
+
347
+ prompt_embeds = text_encoder(
348
+ text_input_ids.to(device),
349
+ output_hidden_states=True,
350
+ )
351
+
352
+ # We are always only interested in the pooled output of the final text encoder
353
+ pooled_prompt_embeds = prompt_embeds[0]
354
+ prompt_embeds = prompt_embeds.hidden_states[-2]
355
+
356
+ bs_embed, seq_len, _ = prompt_embeds.shape
357
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
358
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
359
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
360
+
361
+ prompt_embeds_list.append(prompt_embeds)
362
+
363
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
364
+
365
+ # get unconditional embeddings for classifier free guidance
366
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
367
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
368
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
369
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
370
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
371
+ negative_prompt = negative_prompt or ""
372
+ uncond_tokens: List[str]
373
+ if prompt is not None and type(prompt) is not type(negative_prompt):
374
+ raise TypeError(
375
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
376
+ f" {type(prompt)}."
377
+ )
378
+ elif isinstance(negative_prompt, str):
379
+ uncond_tokens = [negative_prompt]
380
+ # elif batch_size != len(negative_prompt):
381
+ # raise ValueError(
382
+ # f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
383
+ # f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
384
+ # " the batch size of `prompt`."
385
+ # )
386
+ else:
387
+ uncond_tokens = negative_prompt
388
+
389
+ negative_prompt_embeds_list = []
390
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
391
+ # textual inversion: process multi-vector tokens if necessary
392
+ if isinstance(self, TextualInversionLoaderMixin):
393
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
394
+
395
+ max_length = prompt_embeds.shape[1]
396
+ uncond_input = tokenizer(
397
+ uncond_tokens,
398
+ padding="max_length",
399
+ max_length=max_length,
400
+ truncation=True,
401
+ return_tensors="pt",
402
+ )
403
+
404
+ negative_prompt_embeds = text_encoder(
405
+ uncond_input.input_ids.to(device),
406
+ output_hidden_states=True,
407
+ )
408
+ # We are always only interested in the pooled output of the final text encoder
409
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
410
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
411
+
412
+ if do_classifier_free_guidance:
413
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
414
+ seq_len = negative_prompt_embeds.shape[1]
415
+
416
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
417
+
418
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
419
+ negative_prompt_embeds = negative_prompt_embeds.view(
420
+ batch_size_neg * num_images_per_prompt, seq_len, -1
421
+ )
422
+
423
+ # For classifier free guidance, we need to do two forward passes.
424
+ # Here we concatenate the unconditional and text embeddings into a single batch
425
+ # to avoid doing two forward passes
426
+
427
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
428
+
429
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
430
+
431
+ bs_embed = pooled_prompt_embeds.shape[0]
432
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
433
+ bs_embed * num_images_per_prompt, -1
434
+ )
435
+ bs_embed = negative_pooled_prompt_embeds.shape[0]
436
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
437
+ bs_embed * num_images_per_prompt, -1
438
+ )
439
+
440
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
441
+
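A sketch of calling `encode_prompt` directly (assumes `pipe` is an instance of this class; the shape comments assume the default SDXL base checkpoint loaded in `__init__`):

prompt_embeds, neg_embeds, pooled, neg_pooled = pipe.encode_prompt(
    prompt="a cat wearing a crown",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt=None,  # zeros are used because force_zeros_for_empty_prompt=True
)
# prompt_embeds: [1, 77, 2048]  (768-dim and 1280-dim hidden states concatenated)
# pooled:        [1, 1280]      (pooled projection from text_encoder_2)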
442
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
443
+ def prepare_extra_step_kwargs(self, generator, eta):
444
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
445
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
447
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
447
+ # and should be between [0, 1]
448
+
449
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
450
+ extra_step_kwargs = {}
451
+ if accepts_eta:
452
+ extra_step_kwargs["eta"] = eta
453
+
454
+ # check if the scheduler accepts generator
455
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
456
+ if accepts_generator:
457
+ extra_step_kwargs["generator"] = generator
458
+ return extra_step_kwargs
459
+
460
+ def check_inputs(
461
+ self,
462
+ prompt,
463
+ height,
464
+ width,
465
+ callback_steps,
466
+ negative_prompt=None,
467
+ prompt_embeds=None,
468
+ negative_prompt_embeds=None,
469
+ pooled_prompt_embeds=None,
470
+ negative_pooled_prompt_embeds=None,
471
+ ):
472
+ if height % 8 != 0 or width % 8 != 0:
473
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
474
+
475
+ if (callback_steps is None) or (
476
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
477
+ ):
478
+ raise ValueError(
479
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
480
+ f" {type(callback_steps)}."
481
+ )
482
+
483
+ if prompt is not None and prompt_embeds is not None:
484
+ raise ValueError(
485
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
486
+ " only forward one of the two."
487
+ )
488
+ elif prompt is None and prompt_embeds is None:
489
+ raise ValueError(
490
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
491
+ )
492
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
493
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
494
+
495
+ if negative_prompt is not None and negative_prompt_embeds is not None:
496
+ raise ValueError(
497
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
498
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
499
+ )
500
+
501
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
502
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
503
+ raise ValueError(
504
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
505
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
506
+ f" {negative_prompt_embeds.shape}."
507
+ )
508
+
509
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
510
+ raise ValueError(
511
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
512
+ )
513
+
514
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
515
+ raise ValueError(
516
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
517
+ )
518
+
519
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
520
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
521
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
522
+ if isinstance(generator, list) and len(generator) != batch_size:
523
+ raise ValueError(
524
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
525
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
526
+ )
527
+
528
+ if latents is None:
529
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
530
+ else:
531
+ latents = latents.to(device)
532
+
533
+ # scale the initial noise by the standard deviation required by the scheduler
534
+ latents = latents * self.scheduler.init_noise_sigma
535
+ return latents
536
+
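A worked example of the latent shape produced by `prepare_latents`, assuming the default SDXL configuration (4 latent channels, VAE scale factor 8):

batch_size, num_channels_latents = 1, 4
height = width = 1024
vae_scale_factor = 8  # 2 ** (len(vae.config.block_out_channels) - 1) for the SDXL VAE
shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
print(shape)  # (1, 4, 128, 128)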
537
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
538
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
539
+
540
+ passed_add_embed_dim = (
541
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
542
+ )
543
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
544
+
545
+ if expected_add_embed_dim != passed_add_embed_dim:
546
+ raise ValueError(
547
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
548
+ )
549
+
550
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
551
+ return add_time_ids
552
+
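A worked example of the dimension check in `_get_add_time_ids`, using values that are typical for the SDXL base checkpoint (the config numbers are assumptions, not read from this code):

addition_time_embed_dim = 256   # unet.config.addition_time_embed_dim
projection_dim = 1280           # text_encoder_2.config.projection_dim
add_time_ids = list((1024, 1024) + (0, 0) + (1024, 1024))  # original_size + crop + target_size
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + projection_dim
print(len(add_time_ids), passed_add_embed_dim)  # 6 2816 -> must equal unet.add_embedding.linear_1.in_features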
553
+ @torch.no_grad()
554
+ def sample(
555
+ self,
556
+ prompt: Union[str, List[str]] = None,
557
+ height: Optional[int] = None,
558
+ width: Optional[int] = None,
559
+ num_inference_steps: int = 50,
560
+ guidance_scale: float = 5.0,
561
+ negative_prompt: Optional[Union[str, List[str]]] = None,
562
+ num_images_per_prompt: Optional[int] = 1,
563
+ eta: float = 0.0,
564
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
565
+ latents: Optional[torch.FloatTensor] = None,
566
+ prompt_embeds: Optional[torch.FloatTensor] = None,
567
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
568
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
569
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
570
+ output_type: Optional[str] = "pil",
571
+ return_dict: bool = True,
572
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
573
+ callback_steps: int = 1,
574
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
575
+ guidance_rescale: float = 0.0,
576
+ original_size: Optional[Tuple[int, int]] = None,
577
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
578
+ target_size: Optional[Tuple[int, int]] = None,
579
+ # Rich-Text args
580
+ use_guidance: bool = False,
581
+ inject_selfattn: float = 0.0,
582
+ inject_background: float = 0.0,
583
+ text_format_dict: Optional[dict] = None,
584
+ run_rich_text: bool = False,
585
+ ):
586
+ r"""
587
+ Function invoked when calling the pipeline for generation.
588
+
589
+ Args:
590
+ prompt (`str` or `List[str]`, *optional*):
591
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
592
+ instead.
593
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
594
+ The height in pixels of the generated image.
595
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
596
+ The width in pixels of the generated image.
597
+ num_inference_steps (`int`, *optional*, defaults to 50):
598
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
599
+ expense of slower inference.
600
+ guidance_scale (`float`, *optional*, defaults to 5.0):
601
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
602
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
603
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
604
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
605
+ usually at the expense of lower image quality.
606
+ negative_prompt (`str` or `List[str]`, *optional*):
607
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
608
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
609
+ less than `1`).
610
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
611
+ The number of images to generate per prompt.
612
+ eta (`float`, *optional*, defaults to 0.0):
613
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
614
+ [`schedulers.DDIMScheduler`], will be ignored for others.
615
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
616
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
617
+ to make generation deterministic.
618
+ latents (`torch.FloatTensor`, *optional*):
619
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
620
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
621
+ tensor will be generated by sampling using the supplied random `generator`.
622
+ prompt_embeds (`torch.FloatTensor`, *optional*):
623
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
624
+ provided, text embeddings will be generated from `prompt` input argument.
625
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
626
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
627
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
628
+ argument.
629
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
630
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
631
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
632
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
633
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
634
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
635
+ input argument.
636
+ output_type (`str`, *optional*, defaults to `"pil"`):
637
+ The output format of the generated image. Choose between
638
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
639
+ return_dict (`bool`, *optional*, defaults to `True`):
640
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
641
+ plain tuple.
642
+ callback (`Callable`, *optional*):
643
+ A function that will be called every `callback_steps` steps during inference. The function will be
644
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
645
+ callback_steps (`int`, *optional*, defaults to 1):
646
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
647
+ called at every step.
648
+ cross_attention_kwargs (`dict`, *optional*):
649
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
650
+ `self.processor` in
651
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
652
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
653
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
654
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
655
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
656
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
657
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
658
+ The original image resolution used for SDXL's size micro-conditioning; defaults to `(height, width)` when not given.
659
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
660
+ The top-left crop coordinates used for SDXL's crop micro-conditioning; `(0, 0)` corresponds to a well-centered image.
661
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
662
+ The target output resolution used for SDXL's size micro-conditioning; usually equal to `(height, width)`.
663
+
664
+ Examples:
665
+
666
+ Returns:
667
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
668
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
669
+ `tuple`. When returning a tuple, the first element is a list with the generated images. Unlike the
670
+ base Stable Diffusion pipeline, this pipeline has no `safety_checker`, so no nsfw flags are
671
+ returned.
672
+ """
673
+ # 0. Default height and width to unet
674
+ height = height or self.default_sample_size * self.vae_scale_factor
675
+ width = width or self.default_sample_size * self.vae_scale_factor
676
+
677
+ original_size = original_size or (height, width)
678
+ target_size = target_size or (height, width)
679
+
680
+ # 1. Check inputs. Raise error if not correct
681
+ self.check_inputs(
682
+ prompt,
683
+ height,
684
+ width,
685
+ callback_steps,
686
+ negative_prompt,
687
+ prompt_embeds,
688
+ negative_prompt_embeds,
689
+ pooled_prompt_embeds,
690
+ negative_pooled_prompt_embeds,
691
+ )
692
+
693
+ # 2. Define call parameters
694
+ if prompt is not None and isinstance(prompt, str):
695
+ batch_size = 1
696
+ elif prompt is not None and isinstance(prompt, list):
697
+ # TODO: support batched prompts
698
+ batch_size = 1
699
+ # batch_size = len(prompt)
700
+ else:
701
+ batch_size = prompt_embeds.shape[0]
702
+
703
+ device = self._execution_device
704
+
705
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
706
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
707
+ # corresponds to doing no classifier free guidance.
708
+ do_classifier_free_guidance = guidance_scale > 1.0
709
+
710
+ # 3. Encode input prompt
711
+ text_encoder_lora_scale = (
712
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
713
+ )
714
+ (
715
+ prompt_embeds,
716
+ negative_prompt_embeds,
717
+ pooled_prompt_embeds,
718
+ negative_pooled_prompt_embeds,
719
+ ) = self.encode_prompt(
720
+ prompt,
721
+ device,
722
+ num_images_per_prompt,
723
+ do_classifier_free_guidance,
724
+ negative_prompt,
725
+ prompt_embeds=prompt_embeds,
726
+ negative_prompt_embeds=negative_prompt_embeds,
727
+ pooled_prompt_embeds=pooled_prompt_embeds,
728
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
729
+ lora_scale=text_encoder_lora_scale,
730
+ )
731
+
732
+ # 4. Prepare timesteps
733
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
734
+
735
+ timesteps = self.scheduler.timesteps
736
+
737
+ # 5. Prepare latent variables
738
+ num_channels_latents = self.unet.config.in_channels
739
+ latents = self.prepare_latents(
740
+ batch_size * num_images_per_prompt,
741
+ num_channels_latents,
742
+ height,
743
+ width,
744
+ prompt_embeds.dtype,
745
+ device,
746
+ generator,
747
+ latents,
748
+ )
749
+
750
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
751
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
752
+
753
+ # 7. Prepare added time ids & embeddings
754
+ add_text_embeds = pooled_prompt_embeds
755
+ add_time_ids = self._get_add_time_ids(
756
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
757
+ )
758
+
759
+ if do_classifier_free_guidance:
760
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
761
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
762
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
763
+
764
+ prompt_embeds = prompt_embeds.to(device)
765
+ add_text_embeds = add_text_embeds.to(device)
766
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
767
+
768
+ # 8. Denoising loop
769
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
770
+ if run_rich_text:
771
+ if inject_selfattn > 0 or inject_background > 0:
772
+ latents_reference = latents.clone().detach()
773
+ n_styles = prompt_embeds.shape[0]-1
774
+ self.masks = [mask.to(dtype=prompt_embeds.dtype) for mask in self.masks]
775
+ print(n_styles, len(self.masks))
776
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
777
+ for i, t in enumerate(self.scheduler.timesteps):
778
+ # predict the noise residual
779
+ with torch.no_grad():
780
+ feat_inject_step = t > (1-inject_selfattn) * 1000
781
+ background_inject_step = i < inject_background * len(self.scheduler.timesteps)
782
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
783
+ # import ipdb;ipdb.set_trace()
784
+ # unconditional prediction
785
+ noise_pred_uncond_cur = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds[:1],
786
+ cross_attention_kwargs=cross_attention_kwargs,
787
+ added_cond_kwargs={"text_embeds": add_text_embeds[:1], "time_ids": add_time_ids[:1]}
788
+ )['sample']
789
+ # tokens without any style or footnote
790
+ self.register_fontsize_hooks(text_format_dict)
791
+ noise_pred_text_cur = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds[-1:],
792
+ cross_attention_kwargs=cross_attention_kwargs,
793
+ added_cond_kwargs={"text_embeds": add_text_embeds[-1:], "time_ids": add_time_ids[:1]}
794
+ )['sample']
795
+ self.remove_fontsize_hooks()
796
+ if inject_selfattn > 0 or inject_background > 0:
797
+ latent_reference_model_input = self.scheduler.scale_model_input(latents_reference, t)
798
+ noise_pred_uncond_refer = self.unet(latent_reference_model_input, t, encoder_hidden_states=prompt_embeds[:1],
799
+ cross_attention_kwargs=cross_attention_kwargs,
800
+ added_cond_kwargs={"text_embeds": add_text_embeds[:1], "time_ids": add_time_ids[:1]}
801
+ )['sample']
802
+ self.register_selfattn_hooks(feat_inject_step)
803
+ noise_pred_text_refer = self.unet(latent_reference_model_input, t, encoder_hidden_states=prompt_embeds[-1:],
804
+ cross_attention_kwargs=cross_attention_kwargs,
805
+ added_cond_kwargs={"text_embeds": add_text_embeds[-1:], "time_ids": add_time_ids[:1]}
806
+ )['sample']
807
+ self.remove_selfattn_hooks()
808
+ noise_pred_uncond = noise_pred_uncond_cur * self.masks[-1]
809
+ noise_pred_text = noise_pred_text_cur * self.masks[-1]
810
+ # tokens with style or footnote
811
+ for style_i, mask in enumerate(self.masks[:-1]):
812
+ self.register_replacement_hooks(feat_inject_step)
813
+ noise_pred_text_cur = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds[style_i+1:style_i+2],
814
+ cross_attention_kwargs=cross_attention_kwargs,
815
+ added_cond_kwargs={"text_embeds": add_text_embeds[style_i+1:style_i+2], "time_ids": add_time_ids[:1]}
816
+ )['sample']
817
+ self.remove_replacement_hooks()
818
+ noise_pred_uncond = noise_pred_uncond + noise_pred_uncond_cur*mask
819
+ noise_pred_text = noise_pred_text + noise_pred_text_cur*mask
820
+
821
+ # perform guidance
822
+ noise_pred = noise_pred_uncond + guidance_scale * \
823
+ (noise_pred_text - noise_pred_uncond)
824
+
825
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
826
+ # TODO: Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
827
+ # noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
828
+ raise NotImplementedError
829
+
830
+ if inject_selfattn > 0 or background_inject_step > 0:
831
+ noise_pred_refer = noise_pred_uncond_refer + guidance_scale * \
832
+ (noise_pred_text_refer - noise_pred_uncond_refer)
833
+
834
+ # compute the previous noisy sample x_t -> x_t-1
835
+ latents_reference = self.scheduler.step(torch.cat([noise_pred, noise_pred_refer]), t,
836
+ torch.cat([latents, latents_reference]))[
837
+ 'prev_sample']
838
+ latents, latents_reference = torch.chunk(
839
+ latents_reference, 2, dim=0)
840
+
841
+ else:
842
+ # compute the previous noisy sample x_t -> x_t-1
843
+ latents = self.scheduler.step(noise_pred, t, latents)[
844
+ 'prev_sample']
845
+
846
+ # apply guidance
847
+ if use_guidance and t < text_format_dict['guidance_start_step']:
848
+ with torch.enable_grad():
849
+ self.unet.to(device='cpu')
850
+ torch.cuda.empty_cache()
851
+ if not latents.requires_grad:
852
+ latents.requires_grad = True
853
+ # import ipdb;ipdb.set_trace()
854
+ # latents_0 = self.predict_x0(latents, noise_pred, t).to(dtype=latents.dtype)
855
+ latents_0 = self.predict_x0(latents, noise_pred, t).to(dtype=torch.bfloat16)
856
+ latents_inp = latents_0 / self.vae.config.scaling_factor
857
+ imgs = self.vae.to(dtype=latents_inp.dtype).decode(latents_inp).sample
858
+ # imgs = self.vae.decode(latents_inp.to(dtype=torch.float32)).sample
859
+ imgs = (imgs / 2 + 0.5).clamp(0, 1)
860
+ loss_total = 0.
861
+ for attn_map, rgb_val in zip(text_format_dict['color_obj_atten'], text_format_dict['target_RGB']):
862
+ avg_rgb = (
863
+ imgs*attn_map[:, 0]).sum(2).sum(2)/attn_map[:, 0].sum()
864
+ loss = self.color_loss(
865
+ avg_rgb, rgb_val[:, :, 0, 0])*100
866
+ loss_total += loss
867
+ loss_total.backward()
868
+ latents = (
869
+ latents - latents.grad * text_format_dict['color_guidance_weight'] * text_format_dict['color_obj_atten_all']).detach().clone().to(dtype=prompt_embeds.dtype)
870
+ self.unet.to(device=latents.device)
871
+
872
+ # apply background injection
873
+ if i == int(inject_background * len(self.scheduler.timesteps)) and inject_background > 0:
874
+ latents = latents_reference * self.masks[-1] + latents * \
875
+ (1-self.masks[-1])
876
+
877
+ # call the callback, if provided
878
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
879
+ progress_bar.update()
880
+ if callback is not None and i % callback_steps == 0:
881
+ callback(i, t, latents)
882
+ else:
883
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
884
+ for i, t in enumerate(timesteps):
885
+ # expand the latents if we are doing classifier free guidance
886
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
887
+
888
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
889
+
890
+ # predict the noise residual
891
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
892
+ noise_pred = self.unet(
893
+ latent_model_input,
894
+ t,
895
+ encoder_hidden_states=prompt_embeds,
896
+ cross_attention_kwargs=cross_attention_kwargs,
897
+ added_cond_kwargs=added_cond_kwargs,
898
+ return_dict=False,
899
+ )[0]
900
+
901
+ # perform guidance
902
+ if do_classifier_free_guidance:
903
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
904
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
905
+
906
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
907
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
908
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
909
+
910
+ # compute the previous noisy sample x_t -> x_t-1
911
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
912
+
913
+ # call the callback, if provided
914
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
915
+ progress_bar.update()
916
+ if callback is not None and i % callback_steps == 0:
917
+ callback(i, t, latents)
918
+
919
+ # make sure the VAE is in float32 mode, as it overflows in float16
920
+ self.vae.to(dtype=torch.float32)
921
+
922
+ use_torch_2_0_or_xformers = isinstance(
923
+ self.vae.decoder.mid_block.attentions[0].processor,
924
+ (
925
+ AttnProcessor2_0,
926
+ XFormersAttnProcessor,
927
+ LoRAXFormersAttnProcessor,
928
+ LoRAAttnProcessor2_0,
929
+ ),
930
+ )
931
+ # if xformers or torch_2_0 is used attention block does not need
932
+ # to be in float32 which can save lots of memory
933
+ if use_torch_2_0_or_xformers:
934
+ self.vae.post_quant_conv.to(latents.dtype)
935
+ self.vae.decoder.conv_in.to(latents.dtype)
936
+ self.vae.decoder.mid_block.to(latents.dtype)
937
+ else:
938
+ latents = latents.float()
939
+
940
+ if not output_type == "latent":
941
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
942
+ else:
943
+ image = latents
944
+ return StableDiffusionXLPipelineOutput(images=image)
945
+
946
+ image = self.watermark.apply_watermark(image)
947
+ image = self.image_processor.postprocess(image, output_type=output_type)
948
+
949
+ # Offload last model to CPU
950
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
951
+ self.final_offload_hook.offload()
952
+
953
+ if not return_dict:
954
+ return (image,)
955
+
956
+ return StableDiffusionXLPipelineOutput(images=image)
957
+
958
+ def predict_x0(self, x_t, eps_t, t):
959
+ alpha_t = self.scheduler.alphas_cumprod[t.cpu().long().item()]
960
+ return (x_t - eps_t * torch.sqrt(1-alpha_t)) / torch.sqrt(alpha_t)
961
+
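`predict_x0` inverts the forward-noising relation x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A standalone sanity check of that algebra (it does not touch the pipeline or its scheduler):

import torch

alpha_bar_t = torch.tensor(0.7)
x0 = torch.randn(1, 4, 8, 8)
eps = torch.randn_like(x0)
x_t = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * eps
x0_hat = (x_t - eps * torch.sqrt(1 - alpha_bar_t)) / torch.sqrt(alpha_bar_t)
assert torch.allclose(x0_hat, x0, atol=1e-5)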
962
+ def register_tokenmap_hooks(self):
963
+ r"""Function for registering hooks during evaluation.
964
+ We mainly store activation maps averaged over queries.
965
+ """
966
+ self.forward_hooks = []
967
+
968
+ def save_activations(selfattn_maps, crossattn_maps, n_maps, name, module, inp, out):
969
+ r"""
970
+ PyTorch Forward hook to save outputs at each forward pass.
971
+ """
972
+ # out[0] - final output of attention layer
973
+ # out[1] - attention probability matrices
974
+ if name in n_maps:
975
+ n_maps[name] += 1
976
+ else:
977
+ n_maps[name] = 1
978
+ if 'attn2' in name:
979
+ assert out[1][0].shape[-1] == 77
980
+ if name in CrossAttentionLayers_XL and n_maps[name] > 10:
981
+ # if n_maps[name] > 10:
982
+ if name in crossattn_maps:
983
+ crossattn_maps[name] += out[1][0].detach().cpu()[1:2]
984
+ else:
985
+ crossattn_maps[name] = out[1][0].detach().cpu()[1:2]
986
+ # For visualization
987
+ # crossattn_maps[name].append(out[1][0].detach().cpu()[1:2])
988
+ else:
989
+ assert out[1][0].shape[-1] != 77
990
+ # if name in SelfAttentionLayers and n_maps[name] > 10:
991
+ if n_maps[name] > 10:
992
+ if name in selfattn_maps:
993
+ selfattn_maps[name] += out[1][0].detach().cpu()[1:2]
994
+ else:
995
+ selfattn_maps[name] = out[1][0].detach().cpu()[1:2]
996
+
997
+ selfattn_maps = collections.defaultdict(list)
998
+ crossattn_maps = collections.defaultdict(list)
999
+ n_maps = collections.defaultdict(list)
1000
+
1001
+ for name, module in self.unet.named_modules():
1002
+ leaf_name = name.split('.')[-1]
1003
+ if 'attn' in leaf_name:
1004
+ # Register hook to obtain outputs at every attention layer.
1005
+ self.forward_hooks.append(module.register_forward_hook(
1006
+ partial(save_activations, selfattn_maps,
1007
+ crossattn_maps, n_maps, name)
1008
+ ))
1009
+ # attention_dict is a dictionary containing attention maps for every attention layer
1010
+ self.selfattn_maps = selfattn_maps
1011
+ self.crossattn_maps = crossattn_maps
1012
+ self.n_maps = n_maps
1013
+
1014
+ def remove_tokenmap_hooks(self):
1015
+ for hook in self.forward_hooks:
1016
+ hook.remove()
1017
+ self.selfattn_maps = None
1018
+ self.crossattn_maps = None
1019
+ self.n_maps = None
1020
+
1021
+ def register_replacement_hooks(self, feat_inject_step=False):
1022
+ r"""Function for registering hooks to replace self attention.
1023
+ """
1024
+ self.forward_replacement_hooks = []
1025
+
1026
+ def replace_activations(name, module, args):
1027
+ r"""
1028
+ PyTorch forward pre-hook that replaces the self-attention input with the stored reference maps.
1029
+ """
1030
+ if 'attn1' in name:
1031
+ modified_args = (args[0], self.self_attention_maps_cur[name].to(args[0].device))
1032
+ return modified_args
1033
+ # cross attention injection
1034
+ # elif 'attn2' in name:
1035
+ # modified_map = {
1036
+ # 'reference': self.self_attention_maps_cur[name],
1037
+ # 'inject_pos': self.inject_pos,
1038
+ # }
1039
+ # modified_args = (args[0], modified_map)
1040
+ # return modified_args
1041
+
1042
+ def replace_resnet_activations(name, module, args):
1043
+ r"""
1044
+ PyTorch forward pre-hook that injects the stored reference ResNet features.
1045
+ """
1046
+ modified_args = (args[0], args[1],
1047
+ self.self_attention_maps_cur[name].to(args[0].device))
1048
+ return modified_args
1049
+ for name, module in self.unet.named_modules():
1050
+ leaf_name = name.split('.')[-1]
1051
+ if 'attn' in leaf_name and feat_inject_step:
1052
+ # Register hook to obtain outputs at every attention layer.
1053
+ self.forward_replacement_hooks.append(module.register_forward_pre_hook(
1054
+ partial(replace_activations, name)
1055
+ ))
1056
+ if name == 'up_blocks.1.resnets.1' and feat_inject_step:
1057
+ # Register hook to obtain outputs at every attention layer.
1058
+ self.forward_replacement_hooks.append(module.register_forward_pre_hook(
1059
+ partial(replace_resnet_activations, name)
1060
+ ))
1061
+
1062
+ def remove_replacement_hooks(self):
1063
+ for hook in self.forward_replacement_hooks:
1064
+ hook.remove()
1065
+
1066
+
1067
+ def register_selfattn_hooks(self, feat_inject_step=False):
1068
+ r"""Function for registering hooks during evaluation.
1069
+ We mainly store activation maps averaged over queries.
1070
+ """
1071
+ self.selfattn_forward_hooks = []
1072
+
1073
+ def save_activations(activations, name, module, inp, out):
1074
+ r"""
1075
+ PyTorch Forward hook to save outputs at each forward pass.
1076
+ """
1077
+ # out[0] - final output of attention layer
1078
+ # out[1] - attention probability matrix
1079
+ if 'attn2' in name:
1080
+ assert out[1][1].shape[-1] == 77
1081
+ # cross attention injection
1082
+ # activations[name] = out[1][1].detach()
1083
+ else:
1084
+ assert out[1][1].shape[-1] != 77
1085
+ activations[name] = out[1][1].detach().cpu()
1086
+
1087
+ def save_resnet_activations(activations, name, module, inp, out):
1088
+ r"""
1089
+ PyTorch Forward hook to save outputs at each forward pass.
1090
+ """
1091
+ # out[0] - final output of residual layer
1092
+ # out[1] - residual hidden feature
1093
+ # import ipdb;ipdb.set_trace()
1094
+ assert out[1].shape[-1] == 64
1095
+ activations[name] = out[1].detach().cpu()
1096
+ attention_dict = collections.defaultdict(list)
1097
+ for name, module in self.unet.named_modules():
1098
+ leaf_name = name.split('.')[-1]
1099
+ if 'attn' in leaf_name and feat_inject_step:
1100
+ # Register hook to obtain outputs at every attention layer.
1101
+ self.selfattn_forward_hooks.append(module.register_forward_hook(
1102
+ partial(save_activations, attention_dict, name)
1103
+ ))
1104
+ if name == 'up_blocks.1.resnets.1' and feat_inject_step:
1105
+ self.selfattn_forward_hooks.append(module.register_forward_hook(
1106
+ partial(save_resnet_activations, attention_dict, name)
1107
+ ))
1108
+ # attention_dict is a dictionary containing attention maps for every attention layer
1109
+ self.self_attention_maps_cur = attention_dict
1110
+
1111
+ def remove_selfattn_hooks(self):
1112
+ for hook in self.selfattn_forward_hooks:
1113
+ hook.remove()
1114
+
1115
+ def register_fontsize_hooks(self, text_format_dict={}):
1116
+ r"""Function for registering hooks that reweight cross-attention according to font size.
1117
+ """
1118
+ self.forward_fontsize_hooks = []
1119
+
1120
+ def adjust_attn_weights(name, module, args):
1121
+ r"""
1122
+ PyTorch forward pre-hook that passes font-size attention weights to the cross-attention layers.
1123
+ """
1124
+ if 'attn2' in name:
1125
+ modified_args = (args[0], None, attn_weights)
1126
+ return modified_args
1127
+
1128
+ if text_format_dict['word_pos'] is not None and text_format_dict['font_size'] is not None:
1129
+ attn_weights = {'word_pos': text_format_dict['word_pos'], 'font_size': text_format_dict['font_size']}
1130
+ else:
1131
+ attn_weights = None
1132
+
1133
+ for name, module in self.unet.named_modules():
1134
+ leaf_name = name.split('.')[-1]
1135
+ if 'attn' in leaf_name and attn_weights is not None:
1136
+ # Register hook to obtain outputs at every attention layer.
1137
+ self.forward_fontsize_hooks.append(module.register_forward_pre_hook(
1138
+ partial(adjust_attn_weights, name)
1139
+ ))
1140
+
1141
+ def remove_fontsize_hooks(self):
1142
+ for hook in self.forward_fontsize_hooks:
1143
+ hook.remove()
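A rough usage sketch for the plain (non-rich-text) path of `sample`; the rich-text path additionally needs `self.masks` and `text_format_dict` to be prepared by the caller, which app.py does elsewhere. `pipe` stands for an instance of this pipeline class on a CUDA device; the prompt and file name are made up:

out = pipe.sample(
    "a cat wearing a crown, best quality",
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=5.0,
)
out.images[0].save("crowned_cat.png")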
models/resnet.py ADDED
@@ -0,0 +1,882 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusers.models.activations import get_activation
24
+ from diffusers.models.attention import AdaGroupNorm
25
+ from models.attention_processor import SpatialNorm
26
+
27
+
28
+ class Upsample1D(nn.Module):
29
+ """A 1D upsampling layer with an optional convolution.
30
+
31
+ Parameters:
32
+ channels (`int`):
33
+ number of channels in the inputs and outputs.
34
+ use_conv (`bool`, default `False`):
35
+ option to use a convolution.
36
+ use_conv_transpose (`bool`, default `False`):
37
+ option to use a convolution transpose.
38
+ out_channels (`int`, optional):
39
+ number of output channels. Defaults to `channels`.
40
+ """
41
+
42
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
43
+ super().__init__()
44
+ self.channels = channels
45
+ self.out_channels = out_channels or channels
46
+ self.use_conv = use_conv
47
+ self.use_conv_transpose = use_conv_transpose
48
+ self.name = name
49
+
50
+ self.conv = None
51
+ if use_conv_transpose:
52
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
53
+ elif use_conv:
54
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
55
+
56
+ def forward(self, inputs):
57
+ assert inputs.shape[1] == self.channels
58
+ if self.use_conv_transpose:
59
+ return self.conv(inputs)
60
+
61
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
62
+
63
+ if self.use_conv:
64
+ outputs = self.conv(outputs)
65
+
66
+ return outputs
67
+
68
+
69
+ class Downsample1D(nn.Module):
70
+ """A 1D downsampling layer with an optional convolution.
71
+
72
+ Parameters:
73
+ channels (`int`):
74
+ number of channels in the inputs and outputs.
75
+ use_conv (`bool`, default `False`):
76
+ option to use a convolution.
77
+ out_channels (`int`, optional):
78
+ number of output channels. Defaults to `channels`.
79
+ padding (`int`, default `1`):
80
+ padding for the convolution.
81
+ """
82
+
83
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
84
+ super().__init__()
85
+ self.channels = channels
86
+ self.out_channels = out_channels or channels
87
+ self.use_conv = use_conv
88
+ self.padding = padding
89
+ stride = 2
90
+ self.name = name
91
+
92
+ if use_conv:
93
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
94
+ else:
95
+ assert self.channels == self.out_channels
96
+ self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
97
+
98
+ def forward(self, inputs):
99
+ assert inputs.shape[1] == self.channels
100
+ return self.conv(inputs)
101
+
102
+
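A quick shape check for the 1D resampling layers above (standalone; the shapes follow from `scale_factor=2.0` and the stride-2 convolution):

import torch

x = torch.randn(2, 32, 16)                      # (batch, channels, length)
up = Upsample1D(channels=32, use_conv=True)
down = Downsample1D(channels=32, use_conv=True)
print(up(x).shape)                              # torch.Size([2, 32, 32])
print(down(up(x)).shape)                        # torch.Size([2, 32, 16])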
103
+ class Upsample2D(nn.Module):
104
+ """A 2D upsampling layer with an optional convolution.
105
+
106
+ Parameters:
107
+ channels (`int`):
108
+ number of channels in the inputs and outputs.
109
+ use_conv (`bool`, default `False`):
110
+ option to use a convolution.
111
+ use_conv_transpose (`bool`, default `False`):
112
+ option to use a convolution transpose.
113
+ out_channels (`int`, optional):
114
+ number of output channels. Defaults to `channels`.
115
+ """
116
+
117
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
118
+ super().__init__()
119
+ self.channels = channels
120
+ self.out_channels = out_channels or channels
121
+ self.use_conv = use_conv
122
+ self.use_conv_transpose = use_conv_transpose
123
+ self.name = name
124
+
125
+ conv = None
126
+ if use_conv_transpose:
127
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
128
+ elif use_conv:
129
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
130
+
131
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
132
+ if name == "conv":
133
+ self.conv = conv
134
+ else:
135
+ self.Conv2d_0 = conv
136
+
137
+ def forward(self, hidden_states, output_size=None):
138
+ assert hidden_states.shape[1] == self.channels
139
+
140
+ if self.use_conv_transpose:
141
+ return self.conv(hidden_states)
142
+
143
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
144
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
145
+ # https://github.com/pytorch/pytorch/issues/86679
146
+ dtype = hidden_states.dtype
147
+ if dtype == torch.bfloat16:
148
+ hidden_states = hidden_states.to(torch.float32)
149
+
150
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
151
+ if hidden_states.shape[0] >= 64:
152
+ hidden_states = hidden_states.contiguous()
153
+
154
+ # if `output_size` is passed we force the interpolation output
155
+ # size and do not make use of `scale_factor=2`
156
+ if output_size is None:
157
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
158
+ else:
159
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
160
+
161
+ # If the input is bfloat16, we cast back to bfloat16
162
+ if dtype == torch.bfloat16:
163
+ hidden_states = hidden_states.to(dtype)
164
+
165
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
166
+ if self.use_conv:
167
+ if self.name == "conv":
168
+ hidden_states = self.conv(hidden_states)
169
+ else:
170
+ hidden_states = self.Conv2d_0(hidden_states)
171
+
172
+ return hidden_states
173
+
174
+
175
+ class Downsample2D(nn.Module):
176
+ """A 2D downsampling layer with an optional convolution.
177
+
178
+ Parameters:
179
+ channels (`int`):
180
+ number of channels in the inputs and outputs.
181
+ use_conv (`bool`, default `False`):
182
+ option to use a convolution.
183
+ out_channels (`int`, optional):
184
+ number of output channels. Defaults to `channels`.
185
+ padding (`int`, default `1`):
186
+ padding for the convolution.
187
+ """
188
+
189
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
190
+ super().__init__()
191
+ self.channels = channels
192
+ self.out_channels = out_channels or channels
193
+ self.use_conv = use_conv
194
+ self.padding = padding
195
+ stride = 2
196
+ self.name = name
197
+
198
+ if use_conv:
199
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
200
+ else:
201
+ assert self.channels == self.out_channels
202
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
203
+
204
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
205
+ if name == "conv":
206
+ self.Conv2d_0 = conv
207
+ self.conv = conv
208
+ elif name == "Conv2d_0":
209
+ self.conv = conv
210
+ else:
211
+ self.conv = conv
212
+
213
+ def forward(self, hidden_states):
214
+ assert hidden_states.shape[1] == self.channels
215
+ if self.use_conv and self.padding == 0:
216
+ pad = (0, 1, 0, 1)
217
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
218
+
219
+ assert hidden_states.shape[1] == self.channels
220
+ hidden_states = self.conv(hidden_states)
221
+
222
+ return hidden_states
223
+
224
+
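The 2D variants behave the same way on the spatial dimensions; a quick standalone shape check:

import torch

x = torch.randn(1, 64, 32, 32)                  # (batch, channels, height, width)
up = Upsample2D(channels=64, use_conv=True)
down = Downsample2D(channels=64, use_conv=True)
print(up(x).shape)                              # torch.Size([1, 64, 64, 64])
print(down(up(x)).shape)                        # torch.Size([1, 64, 32, 32])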
225
+ class FirUpsample2D(nn.Module):
226
+ """A 2D FIR upsampling layer with an optional convolution.
227
+
228
+ Parameters:
229
+ channels (`int`):
230
+ number of channels in the inputs and outputs.
231
+ use_conv (`bool`, default `False`):
232
+ option to use a convolution.
233
+ out_channels (`int`, optional):
234
+ number of output channels. Defaults to `channels`.
235
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
236
+ kernel for the FIR filter.
237
+ """
238
+
239
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
240
+ super().__init__()
241
+ out_channels = out_channels if out_channels else channels
242
+ if use_conv:
243
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
244
+ self.use_conv = use_conv
245
+ self.fir_kernel = fir_kernel
246
+ self.out_channels = out_channels
247
+
248
+ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
249
+ """Fused `upsample_2d()` followed by `Conv2d()`.
250
+
251
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
252
+ efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
253
+ arbitrary order.
254
+
255
+ Args:
256
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
257
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
258
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
259
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
260
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
261
+ factor: Integer upsampling factor (default: 2).
262
+ gain: Scaling factor for signal magnitude (default: 1.0).
263
+
264
+ Returns:
265
+ output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
266
+ datatype as `hidden_states`.
267
+ """
268
+
269
+ assert isinstance(factor, int) and factor >= 1
270
+
271
+ # Setup filter kernel.
272
+ if kernel is None:
273
+ kernel = [1] * factor
274
+
275
+ # setup kernel
276
+ kernel = torch.tensor(kernel, dtype=torch.float32)
277
+ if kernel.ndim == 1:
278
+ kernel = torch.outer(kernel, kernel)
279
+ kernel /= torch.sum(kernel)
280
+
281
+ kernel = kernel * (gain * (factor**2))
282
+
283
+ if self.use_conv:
284
+ convH = weight.shape[2]
285
+ convW = weight.shape[3]
286
+ inC = weight.shape[1]
287
+
288
+ pad_value = (kernel.shape[0] - factor) - (convW - 1)
289
+
290
+ stride = (factor, factor)
291
+ # Determine data dimensions.
292
+ output_shape = (
293
+ (hidden_states.shape[2] - 1) * factor + convH,
294
+ (hidden_states.shape[3] - 1) * factor + convW,
295
+ )
296
+ output_padding = (
297
+ output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
298
+ output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
299
+ )
300
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
301
+ num_groups = hidden_states.shape[1] // inC
302
+
303
+ # Transpose weights.
304
+ weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
305
+ weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
306
+ weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
307
+
308
+ inverse_conv = F.conv_transpose2d(
309
+ hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
310
+ )
311
+
312
+ output = upfirdn2d_native(
313
+ inverse_conv,
314
+ torch.tensor(kernel, device=inverse_conv.device),
315
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
316
+ )
317
+ else:
318
+ pad_value = kernel.shape[0] - factor
319
+ output = upfirdn2d_native(
320
+ hidden_states,
321
+ torch.tensor(kernel, device=hidden_states.device),
322
+ up=factor,
323
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
324
+ )
325
+
326
+ return output
327
+
328
+ def forward(self, hidden_states):
329
+ if self.use_conv:
330
+ height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
331
+ height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
332
+ else:
333
+ height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
334
+
335
+ return height
336
+
337
+
338
+ class FirDownsample2D(nn.Module):
339
+ """A 2D FIR downsampling layer with an optional convolution.
340
+
341
+ Parameters:
342
+ channels (`int`):
343
+ number of channels in the inputs and outputs.
344
+ use_conv (`bool`, default `False`):
345
+ option to use a convolution.
346
+ out_channels (`int`, optional):
347
+ number of output channels. Defaults to `channels`.
348
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
349
+ kernel for the FIR filter.
350
+ """
351
+
352
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
353
+ super().__init__()
354
+ out_channels = out_channels if out_channels else channels
355
+ if use_conv:
356
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
357
+ self.fir_kernel = fir_kernel
358
+ self.use_conv = use_conv
359
+ self.out_channels = out_channels
360
+
361
+ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
362
+ """Fused `Conv2d()` followed by `downsample_2d()`.
363
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
364
+ efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
365
+ arbitrary order.
366
+
367
+ Args:
368
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
369
+ weight:
370
+ Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
371
+ performed by `inChannels = x.shape[0] // numGroups`.
372
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
373
+ factor`, which corresponds to average pooling.
374
+ factor: Integer downsampling factor (default: 2).
375
+ gain: Scaling factor for signal magnitude (default: 1.0).
376
+
377
+ Returns:
378
+ output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
379
+ same datatype as `x`.
380
+ """
381
+
382
+ assert isinstance(factor, int) and factor >= 1
383
+ if kernel is None:
384
+ kernel = [1] * factor
385
+
386
+ # setup kernel
387
+ kernel = torch.tensor(kernel, dtype=torch.float32)
388
+ if kernel.ndim == 1:
389
+ kernel = torch.outer(kernel, kernel)
390
+ kernel /= torch.sum(kernel)
391
+
392
+ kernel = kernel * gain
393
+
394
+ if self.use_conv:
395
+ _, _, convH, convW = weight.shape
396
+ pad_value = (kernel.shape[0] - factor) + (convW - 1)
397
+ stride_value = [factor, factor]
398
+ upfirdn_input = upfirdn2d_native(
399
+ hidden_states,
400
+ torch.tensor(kernel, device=hidden_states.device),
401
+ pad=((pad_value + 1) // 2, pad_value // 2),
402
+ )
403
+ output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
404
+ else:
405
+ pad_value = kernel.shape[0] - factor
406
+ output = upfirdn2d_native(
407
+ hidden_states,
408
+ torch.tensor(kernel, device=hidden_states.device),
409
+ down=factor,
410
+ pad=((pad_value + 1) // 2, pad_value // 2),
411
+ )
412
+
413
+ return output
414
+
415
+ def forward(self, hidden_states):
416
+ if self.use_conv:
417
+ downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
418
+ hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
419
+ else:
420
+ hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
421
+
422
+ return hidden_states
423
+
424
+
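# Illustrative usage sketch (not part of this commit): without a convolution the FIR
# layers above only filter-and-resample, so a 16x16 input is doubled or halved in
# spatial size while the channel count is preserved.
def _example_fir_resample():
    import torch

    x = torch.randn(1, 4, 16, 16)
    up = FirUpsample2D(channels=4)      # use_conv=False -> pure FIR upsampling
    down = FirDownsample2D(channels=4)  # use_conv=False -> pure FIR downsampling
    assert up(x).shape == (1, 4, 32, 32)
    assert down(x).shape == (1, 4, 8, 8)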
425
+ # downsample/upsample layers used in the k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
426
+ class KDownsample2D(nn.Module):
427
+ def __init__(self, pad_mode="reflect"):
428
+ super().__init__()
429
+ self.pad_mode = pad_mode
430
+ kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
431
+ self.pad = kernel_1d.shape[1] // 2 - 1
432
+ self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
433
+
434
+ def forward(self, inputs):
435
+ inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
436
+ weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
437
+ indices = torch.arange(inputs.shape[1], device=inputs.device)
438
+ kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
439
+ weight[indices, indices] = kernel
440
+ return F.conv2d(inputs, weight, stride=2)
441
+
442
+
443
+ class KUpsample2D(nn.Module):
444
+ def __init__(self, pad_mode="reflect"):
445
+ super().__init__()
446
+ self.pad_mode = pad_mode
447
+ kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
448
+ self.pad = kernel_1d.shape[1] // 2 - 1
449
+ self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
450
+
451
+ def forward(self, inputs):
452
+ inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
453
+ weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
454
+ indices = torch.arange(inputs.shape[1], device=inputs.device)
455
+ kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
456
+ weight[indices, indices] = kernel
457
+ return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
458
+
459
+
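# Illustrative usage sketch (not part of this commit): the k-diffusion style resamplers
# above build a per-channel weight on the fly from the registered 4x4 blur kernel, so
# they need no parameters and simply halve or double the spatial size.
def _example_k_resample():
    import torch

    x = torch.randn(1, 4, 16, 16)
    assert KDownsample2D()(x).shape == (1, 4, 8, 8)
    assert KUpsample2D()(x).shape == (1, 4, 32, 32)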
460
+ class ResnetBlock2D(nn.Module):
461
+ r"""
462
+ A Resnet block.
463
+
464
+ Parameters:
465
+ in_channels (`int`): The number of channels in the input.
466
+ out_channels (`int`, *optional*, default to be `None`):
467
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
468
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
469
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
470
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
471
+ groups_out (`int`, *optional*, default to None):
472
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
473
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
474
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
475
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
476
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
477
+ "ada_group" for a stronger conditioning with scale and shift.
478
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
479
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
480
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
481
+ use_in_shortcut (`bool`, *optional*, default to `True`):
482
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
483
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
484
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
485
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
486
+ `conv_shortcut` output.
487
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
488
+ If None, same as `out_channels`.
489
+ """
490
+
491
+ def __init__(
492
+ self,
493
+ *,
494
+ in_channels,
495
+ out_channels=None,
496
+ conv_shortcut=False,
497
+ dropout=0.0,
498
+ temb_channels=512,
499
+ groups=32,
500
+ groups_out=None,
501
+ pre_norm=True,
502
+ eps=1e-6,
503
+ non_linearity="swish",
504
+ skip_time_act=False,
505
+ time_embedding_norm="default", # default, scale_shift, ada_group, spatial
506
+ kernel=None,
507
+ output_scale_factor=1.0,
508
+ use_in_shortcut=None,
509
+ up=False,
510
+ down=False,
511
+ conv_shortcut_bias: bool = True,
512
+ conv_2d_out_channels: Optional[int] = None,
513
+ ):
514
+ super().__init__()
515
+ self.pre_norm = pre_norm
516
+ self.pre_norm = True
517
+ self.in_channels = in_channels
518
+ out_channels = in_channels if out_channels is None else out_channels
519
+ self.out_channels = out_channels
520
+ self.use_conv_shortcut = conv_shortcut
521
+ self.up = up
522
+ self.down = down
523
+ self.output_scale_factor = output_scale_factor
524
+ self.time_embedding_norm = time_embedding_norm
525
+ self.skip_time_act = skip_time_act
526
+
527
+ if groups_out is None:
528
+ groups_out = groups
529
+
530
+ if self.time_embedding_norm == "ada_group":
531
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
532
+ elif self.time_embedding_norm == "spatial":
533
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
534
+ else:
535
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
536
+
537
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
538
+
539
+ if temb_channels is not None:
540
+ if self.time_embedding_norm == "default":
541
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
542
+ elif self.time_embedding_norm == "scale_shift":
543
+ self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
544
+ elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
545
+ self.time_emb_proj = None
546
+ else:
547
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
548
+ else:
549
+ self.time_emb_proj = None
550
+
551
+ if self.time_embedding_norm == "ada_group":
552
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
553
+ elif self.time_embedding_norm == "spatial":
554
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
555
+ else:
556
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
557
+
558
+ self.dropout = torch.nn.Dropout(dropout)
559
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
560
+ self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
561
+
562
+ self.nonlinearity = get_activation(non_linearity)
563
+
564
+ self.upsample = self.downsample = None
565
+ if self.up:
566
+ if kernel == "fir":
567
+ fir_kernel = (1, 3, 3, 1)
568
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
569
+ elif kernel == "sde_vp":
570
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
571
+ else:
572
+ self.upsample = Upsample2D(in_channels, use_conv=False)
573
+ elif self.down:
574
+ if kernel == "fir":
575
+ fir_kernel = (1, 3, 3, 1)
576
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
577
+ elif kernel == "sde_vp":
578
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
579
+ else:
580
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
581
+
582
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
583
+
584
+ self.conv_shortcut = None
585
+ if self.use_in_shortcut:
586
+ self.conv_shortcut = torch.nn.Conv2d(
587
+ in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
588
+ )
589
+
590
+ # Rich-Text: feature injection
591
+ def forward(self, input_tensor, temb, inject_states=None):
592
+ hidden_states = input_tensor
593
+
594
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
595
+ hidden_states = self.norm1(hidden_states, temb)
596
+ else:
597
+ hidden_states = self.norm1(hidden_states)
598
+
599
+ hidden_states = self.nonlinearity(hidden_states)
600
+
601
+ if self.upsample is not None:
602
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
603
+ if hidden_states.shape[0] >= 64:
604
+ input_tensor = input_tensor.contiguous()
605
+ hidden_states = hidden_states.contiguous()
606
+ input_tensor = self.upsample(input_tensor)
607
+ hidden_states = self.upsample(hidden_states)
608
+ elif self.downsample is not None:
609
+ input_tensor = self.downsample(input_tensor)
610
+ hidden_states = self.downsample(hidden_states)
611
+
612
+ hidden_states = self.conv1(hidden_states)
613
+
614
+ if self.time_emb_proj is not None:
615
+ if not self.skip_time_act:
616
+ temb = self.nonlinearity(temb)
617
+ temb = self.time_emb_proj(temb)[:, :, None, None]
618
+
619
+ if temb is not None and self.time_embedding_norm == "default":
620
+ hidden_states = hidden_states + temb
621
+
622
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
623
+ hidden_states = self.norm2(hidden_states, temb)
624
+ else:
625
+ hidden_states = self.norm2(hidden_states)
626
+
627
+ if temb is not None and self.time_embedding_norm == "scale_shift":
628
+ scale, shift = torch.chunk(temb, 2, dim=1)
629
+ hidden_states = hidden_states * (1 + scale) + shift
630
+
631
+ hidden_states = self.nonlinearity(hidden_states)
632
+
633
+ hidden_states = self.dropout(hidden_states)
634
+ hidden_states = self.conv2(hidden_states)
635
+
636
+ if self.conv_shortcut is not None:
637
+ input_tensor = self.conv_shortcut(input_tensor)
638
+
639
+ # Rich-Text: feature injection
640
+ if inject_states is not None:
641
+ output_tensor = (input_tensor + inject_states) / self.output_scale_factor
642
+ else:
643
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
644
+
645
+ return output_tensor, hidden_states
646
+
647
+
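# Illustrative usage sketch (not part of this commit): unlike the stock diffusers block,
# this forward returns an (output, hidden_states) pair and accepts `inject_states`,
# which replaces the block's own residual branch (the Rich-Text feature-injection hook).
# Injecting the states produced by the same input reproduces the regular output.
def _example_resnet_block_injection():
    import torch

    block = ResnetBlock2D(in_channels=32, temb_channels=128)
    x = torch.randn(1, 32, 16, 16)
    temb = torch.randn(1, 128)
    out, hidden = block(x, temb)                       # regular pass
    out_inj, _ = block(x, temb, inject_states=hidden)  # inject the same features
    assert torch.allclose(out, out_inj)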
648
+ # unet_rl.py
649
+ def rearrange_dims(tensor):
650
+ if len(tensor.shape) == 2:
651
+ return tensor[:, :, None]
652
+ if len(tensor.shape) == 3:
653
+ return tensor[:, :, None, :]
654
+ elif len(tensor.shape) == 4:
655
+ return tensor[:, :, 0, :]
656
+ else:
657
+ raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")
658
+
659
+
660
+ class Conv1dBlock(nn.Module):
661
+ """
662
+ Conv1d --> GroupNorm --> Mish
663
+ """
664
+
665
+ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
666
+ super().__init__()
667
+
668
+ self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
669
+ self.group_norm = nn.GroupNorm(n_groups, out_channels)
670
+ self.mish = nn.Mish()
671
+
672
+ def forward(self, inputs):
673
+ intermediate_repr = self.conv1d(inputs)
674
+ intermediate_repr = rearrange_dims(intermediate_repr)
675
+ intermediate_repr = self.group_norm(intermediate_repr)
676
+ intermediate_repr = rearrange_dims(intermediate_repr)
677
+ output = self.mish(intermediate_repr)
678
+ return output
679
+
680
+
681
+ # unet_rl.py
682
+ class ResidualTemporalBlock1D(nn.Module):
683
+ def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
684
+ super().__init__()
685
+ self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
686
+ self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
687
+
688
+ self.time_emb_act = nn.Mish()
689
+ self.time_emb = nn.Linear(embed_dim, out_channels)
690
+
691
+ self.residual_conv = (
692
+ nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
693
+ )
694
+
695
+ def forward(self, inputs, t):
696
+ """
697
+ Args:
698
+ inputs : [ batch_size x inp_channels x horizon ]
699
+ t : [ batch_size x embed_dim ]
700
+
701
+ returns:
702
+ out : [ batch_size x out_channels x horizon ]
703
+ """
704
+ t = self.time_emb_act(t)
705
+ t = self.time_emb(t)
706
+ out = self.conv_in(inputs) + rearrange_dims(t)
707
+ out = self.conv_out(out)
708
+ return out + self.residual_conv(inputs)
709
+
710
+
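# Illustrative usage sketch (not part of this commit): the 1D temporal block above maps
# [batch, inp_channels, horizon] to [batch, out_channels, horizon], conditioned on a
# time embedding of size `embed_dim`.
def _example_residual_temporal_block():
    import torch

    block = ResidualTemporalBlock1D(inp_channels=8, out_channels=16, embed_dim=32)
    x = torch.randn(2, 8, 24)  # [batch, inp_channels, horizon]
    t = torch.randn(2, 32)     # [batch, embed_dim]
    assert block(x, t).shape == (2, 16, 24)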
711
+ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
712
+ r"""Upsample a batch of 2D images with the given filter.
713
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
714
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
715
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
716
+ a: multiple of the upsampling factor.
717
+
718
+ Args:
719
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
720
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
721
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
722
+ factor: Integer upsampling factor (default: 2).
723
+ gain: Scaling factor for signal magnitude (default: 1.0).
724
+
725
+ Returns:
726
+ output: Tensor of the shape `[N, C, H * factor, W * factor]`
727
+ """
728
+ assert isinstance(factor, int) and factor >= 1
729
+ if kernel is None:
730
+ kernel = [1] * factor
731
+
732
+ kernel = torch.tensor(kernel, dtype=torch.float32)
733
+ if kernel.ndim == 1:
734
+ kernel = torch.outer(kernel, kernel)
735
+ kernel /= torch.sum(kernel)
736
+
737
+ kernel = kernel * (gain * (factor**2))
738
+ pad_value = kernel.shape[0] - factor
739
+ output = upfirdn2d_native(
740
+ hidden_states,
741
+ kernel.to(device=hidden_states.device),
742
+ up=factor,
743
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
744
+ )
745
+ return output
746
+
747
+
748
+ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
749
+ r"""Downsample a batch of 2D images with the given filter.
750
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
751
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
752
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
753
+ shape is a multiple of the downsampling factor.
754
+
755
+ Args:
756
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
757
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
758
+ (separable). The default is `[1] * factor`, which corresponds to average pooling.
759
+ factor: Integer downsampling factor (default: 2).
760
+ gain: Scaling factor for signal magnitude (default: 1.0).
761
+
762
+ Returns:
763
+ output: Tensor of the shape `[N, C, H // factor, W // factor]`
764
+ """
765
+
766
+ assert isinstance(factor, int) and factor >= 1
767
+ if kernel is None:
768
+ kernel = [1] * factor
769
+
770
+ kernel = torch.tensor(kernel, dtype=torch.float32)
771
+ if kernel.ndim == 1:
772
+ kernel = torch.outer(kernel, kernel)
773
+ kernel /= torch.sum(kernel)
774
+
775
+ kernel = kernel * gain
776
+ pad_value = kernel.shape[0] - factor
777
+ output = upfirdn2d_native(
778
+ hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
779
+ )
780
+ return output
781
+
782
+
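# Illustrative usage sketch (not part of this commit): with the default `[1] * factor`
# kernel, `upsample_2d` reduces to nearest-neighbor upsampling and `downsample_2d` to
# 2x2 average pooling, which the asserts below verify against the plain torch ops.
def _example_functional_resample():
    import torch
    import torch.nn.functional as F

    x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
    up = upsample_2d(x)
    down = downsample_2d(x)
    assert torch.allclose(up, F.interpolate(x, scale_factor=2.0, mode="nearest"))
    assert torch.allclose(down, F.avg_pool2d(x, 2))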
783
+ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
784
+ up_x = up_y = up
785
+ down_x = down_y = down
786
+ pad_x0 = pad_y0 = pad[0]
787
+ pad_x1 = pad_y1 = pad[1]
788
+
789
+ _, channel, in_h, in_w = tensor.shape
790
+ tensor = tensor.reshape(-1, in_h, in_w, 1)
791
+
792
+ _, in_h, in_w, minor = tensor.shape
793
+ kernel_h, kernel_w = kernel.shape
794
+
795
+ out = tensor.view(-1, in_h, 1, in_w, 1, minor)
796
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
797
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
798
+
799
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
800
+ out = out.to(tensor.device) # Move back to mps if necessary
801
+ out = out[
802
+ :,
803
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
804
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
805
+ :,
806
+ ]
807
+
808
+ out = out.permute(0, 3, 1, 2)
809
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
810
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
811
+ out = F.conv2d(out, w)
812
+ out = out.reshape(
813
+ -1,
814
+ minor,
815
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
816
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
817
+ )
818
+ out = out.permute(0, 2, 3, 1)
819
+ out = out[:, ::down_y, ::down_x, :]
820
+
821
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
822
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
823
+
824
+ return out.view(-1, channel, out_h, out_w)
825
+
826
+
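# Illustrative usage sketch (not part of this commit): `upfirdn2d_native` zero-upsamples
# by `up`, pads, convolves with the flipped 2D kernel, then keeps every `down`-th sample;
# each spatial dim becomes (H * up + pad[0] + pad[1] - kernel_h) // down + 1.
def _example_upfirdn2d():
    import torch

    x = torch.randn(1, 3, 8, 8)
    kernel = torch.ones(2, 2) / 4.0  # simple 2x2 box filter
    y = upfirdn2d_native(x, kernel, pad=(1, 0))
    assert y.shape == (1, 3, 8, 8)   # (8 * 1 + 1 + 0 - 2) // 1 + 1 = 8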
827
+ class TemporalConvLayer(nn.Module):
828
+ """
829
+ Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
830
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
831
+ """
832
+
833
+ def __init__(self, in_dim, out_dim=None, dropout=0.0):
834
+ super().__init__()
835
+ out_dim = out_dim or in_dim
836
+ self.in_dim = in_dim
837
+ self.out_dim = out_dim
838
+
839
+ # conv layers
840
+ self.conv1 = nn.Sequential(
841
+ nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
842
+ )
843
+ self.conv2 = nn.Sequential(
844
+ nn.GroupNorm(32, out_dim),
845
+ nn.SiLU(),
846
+ nn.Dropout(dropout),
847
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
848
+ )
849
+ self.conv3 = nn.Sequential(
850
+ nn.GroupNorm(32, out_dim),
851
+ nn.SiLU(),
852
+ nn.Dropout(dropout),
853
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
854
+ )
855
+ self.conv4 = nn.Sequential(
856
+ nn.GroupNorm(32, out_dim),
857
+ nn.SiLU(),
858
+ nn.Dropout(dropout),
859
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
860
+ )
861
+
862
+ # zero out the last layer params so that the conv block is the identity at initialization
863
+ nn.init.zeros_(self.conv4[-1].weight)
864
+ nn.init.zeros_(self.conv4[-1].bias)
865
+
866
+ def forward(self, hidden_states, num_frames=1):
867
+ hidden_states = (
868
+ hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
869
+ )
870
+
871
+ identity = hidden_states
872
+ hidden_states = self.conv1(hidden_states)
873
+ hidden_states = self.conv2(hidden_states)
874
+ hidden_states = self.conv3(hidden_states)
875
+ hidden_states = self.conv4(hidden_states)
876
+
877
+ hidden_states = identity + hidden_states
878
+
879
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
880
+ (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
881
+ )
882
+ return hidden_states
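# Illustrative usage sketch (not part of this commit): the temporal layer expects video
# frames flattened into the batch dimension, and because `conv4` is zero-initialized the
# whole block starts out as an identity mapping.
def _example_temporal_conv():
    import torch

    layer = TemporalConvLayer(in_dim=32)
    x = torch.randn(8, 32, 16, 16)  # 2 clips x 4 frames, flattened into the batch dim
    y = layer(x, num_frames=4)
    assert y.shape == x.shape
    assert torch.allclose(y, x)     # identity at initialization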
models/transformer_2d.py ADDED
@@ -0,0 +1,341 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ from diffusers.models.embeddings import PatchEmbed
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+
27
+ from models.attention import BasicTransformerBlock
28
+
29
+ @dataclass
30
+ class Transformer2DModelOutput(BaseOutput):
31
+ """
32
+ The output of [`Transformer2DModel`].
33
+
34
+ Args:
35
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
36
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
37
+ distributions for the unnoised latent pixels.
38
+ """
39
+
40
+ sample: torch.FloatTensor
41
+
42
+
43
+ class Transformer2DModel(ModelMixin, ConfigMixin):
44
+ """
45
+ A 2D Transformer model for image-like data.
46
+
47
+ Parameters:
48
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
49
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
50
+ in_channels (`int`, *optional*):
51
+ The number of channels in the input and output (specify if the input is **continuous**).
52
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
53
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
54
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
55
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
56
+ This is fixed during training since it is used to learn a number of position embeddings.
57
+ num_vector_embeds (`int`, *optional*):
58
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
59
+ Includes the class for the masked latent pixel.
60
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
61
+ num_embeds_ada_norm ( `int`, *optional*):
62
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
63
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
64
+ added to the hidden states.
65
+
66
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
67
+ attention_bias (`bool`, *optional*):
68
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
69
+ """
70
+
71
+ @register_to_config
72
+ def __init__(
73
+ self,
74
+ num_attention_heads: int = 16,
75
+ attention_head_dim: int = 88,
76
+ in_channels: Optional[int] = None,
77
+ out_channels: Optional[int] = None,
78
+ num_layers: int = 1,
79
+ dropout: float = 0.0,
80
+ norm_num_groups: int = 32,
81
+ cross_attention_dim: Optional[int] = None,
82
+ attention_bias: bool = False,
83
+ sample_size: Optional[int] = None,
84
+ num_vector_embeds: Optional[int] = None,
85
+ patch_size: Optional[int] = None,
86
+ activation_fn: str = "geglu",
87
+ num_embeds_ada_norm: Optional[int] = None,
88
+ use_linear_projection: bool = False,
89
+ only_cross_attention: bool = False,
90
+ upcast_attention: bool = False,
91
+ norm_type: str = "layer_norm",
92
+ norm_elementwise_affine: bool = True,
93
+ ):
94
+ super().__init__()
95
+ self.use_linear_projection = use_linear_projection
96
+ self.num_attention_heads = num_attention_heads
97
+ self.attention_head_dim = attention_head_dim
98
+ inner_dim = num_attention_heads * attention_head_dim
99
+
100
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, height, width)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
101
+ # Define whether input is continuous or discrete depending on configuration
102
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
103
+ self.is_input_vectorized = num_vector_embeds is not None
104
+ self.is_input_patches = in_channels is not None and patch_size is not None
105
+
106
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
107
+ deprecation_message = (
108
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
109
+ " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
110
+ " Please make sure to update the config accordingly as leaving `norm_type` might lead to incorrect"
111
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
112
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
113
+ )
114
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
115
+ norm_type = "ada_norm"
116
+
117
+ if self.is_input_continuous and self.is_input_vectorized:
118
+ raise ValueError(
119
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
120
+ " sure that either `in_channels` or `num_vector_embeds` is None."
121
+ )
122
+ elif self.is_input_vectorized and self.is_input_patches:
123
+ raise ValueError(
124
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
125
+ " sure that either `num_vector_embeds` or `num_patches` is None."
126
+ )
127
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
128
+ raise ValueError(
129
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
130
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
131
+ )
132
+
133
+ # 2. Define input layers
134
+ if self.is_input_continuous:
135
+ self.in_channels = in_channels
136
+
137
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
138
+ if use_linear_projection:
139
+ self.proj_in = nn.Linear(in_channels, inner_dim)
140
+ else:
141
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
142
+ elif self.is_input_vectorized:
143
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
144
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
145
+
146
+ self.height = sample_size
147
+ self.width = sample_size
148
+ self.num_vector_embeds = num_vector_embeds
149
+ self.num_latent_pixels = self.height * self.width
150
+
151
+ self.latent_image_embedding = ImagePositionalEmbeddings(
152
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
153
+ )
154
+ elif self.is_input_patches:
155
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
156
+
157
+ self.height = sample_size
158
+ self.width = sample_size
159
+
160
+ self.patch_size = patch_size
161
+ self.pos_embed = PatchEmbed(
162
+ height=sample_size,
163
+ width=sample_size,
164
+ patch_size=patch_size,
165
+ in_channels=in_channels,
166
+ embed_dim=inner_dim,
167
+ )
168
+
169
+ # 3. Define transformers blocks
170
+ self.transformer_blocks = nn.ModuleList(
171
+ [
172
+ BasicTransformerBlock(
173
+ inner_dim,
174
+ num_attention_heads,
175
+ attention_head_dim,
176
+ dropout=dropout,
177
+ cross_attention_dim=cross_attention_dim,
178
+ activation_fn=activation_fn,
179
+ num_embeds_ada_norm=num_embeds_ada_norm,
180
+ attention_bias=attention_bias,
181
+ only_cross_attention=only_cross_attention,
182
+ upcast_attention=upcast_attention,
183
+ norm_type=norm_type,
184
+ norm_elementwise_affine=norm_elementwise_affine,
185
+ )
186
+ for d in range(num_layers)
187
+ ]
188
+ )
189
+
190
+ # 4. Define output layers
191
+ self.out_channels = in_channels if out_channels is None else out_channels
192
+ if self.is_input_continuous:
193
+ # TODO: should use out_channels for continuous projections
194
+ if use_linear_projection:
195
+ self.proj_out = nn.Linear(inner_dim, in_channels)
196
+ else:
197
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
198
+ elif self.is_input_vectorized:
199
+ self.norm_out = nn.LayerNorm(inner_dim)
200
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
201
+ elif self.is_input_patches:
202
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
203
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
204
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
205
+
206
+ def forward(
207
+ self,
208
+ hidden_states: torch.Tensor,
209
+ encoder_hidden_states: Optional[torch.Tensor] = None,
210
+ timestep: Optional[torch.LongTensor] = None,
211
+ class_labels: Optional[torch.LongTensor] = None,
212
+ cross_attention_kwargs: Dict[str, Any] = None,
213
+ attention_mask: Optional[torch.Tensor] = None,
214
+ encoder_attention_mask: Optional[torch.Tensor] = None,
215
+ return_dict: bool = True,
216
+ ):
217
+ """
218
+ The [`Transformer2DModel`] forward method.
219
+
220
+ Args:
221
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
222
+ Input `hidden_states`.
223
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
224
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
225
+ self-attention.
226
+ timestep ( `torch.LongTensor`, *optional*):
227
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
228
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
229
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
230
+ `AdaLayerZeroNorm`.
231
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
232
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
233
+
234
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
235
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
236
+
237
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
238
+ above. This bias will be added to the cross-attention scores.
239
+ return_dict (`bool`, *optional*, defaults to `True`):
240
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
241
+ tuple.
242
+
243
+ Returns:
244
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
245
+ `tuple` where the first element is the sample tensor.
246
+ """
247
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
248
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
249
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
250
+ # expects mask of shape:
251
+ # [batch, key_tokens]
252
+ # adds singleton query_tokens dimension:
253
+ # [batch, 1, key_tokens]
254
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
255
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
256
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
257
+ if attention_mask is not None and attention_mask.ndim == 2:
258
+ # assume that mask is expressed as:
259
+ # (1 = keep, 0 = discard)
260
+ # convert mask into a bias that can be added to attention scores:
261
+ # (keep = +0, discard = -10000.0)
262
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
263
+ attention_mask = attention_mask.unsqueeze(1)
264
+
265
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
266
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
267
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
268
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
269
+
270
+ # 1. Input
271
+ if self.is_input_continuous:
272
+ batch, _, height, width = hidden_states.shape
273
+ residual = hidden_states
274
+
275
+ hidden_states = self.norm(hidden_states)
276
+ if not self.use_linear_projection:
277
+ hidden_states = self.proj_in(hidden_states)
278
+ inner_dim = hidden_states.shape[1]
279
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
280
+ else:
281
+ inner_dim = hidden_states.shape[1]
282
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
283
+ hidden_states = self.proj_in(hidden_states)
284
+ elif self.is_input_vectorized:
285
+ hidden_states = self.latent_image_embedding(hidden_states)
286
+ elif self.is_input_patches:
287
+ hidden_states = self.pos_embed(hidden_states)
288
+
289
+ # 2. Blocks
290
+ for block in self.transformer_blocks:
291
+ hidden_states = block(
292
+ hidden_states,
293
+ attention_mask=attention_mask,
294
+ encoder_hidden_states=encoder_hidden_states,
295
+ encoder_attention_mask=encoder_attention_mask,
296
+ timestep=timestep,
297
+ cross_attention_kwargs=cross_attention_kwargs,
298
+ class_labels=class_labels,
299
+ )
300
+
301
+ # 3. Output
302
+ if self.is_input_continuous:
303
+ if not self.use_linear_projection:
304
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
305
+ hidden_states = self.proj_out(hidden_states)
306
+ else:
307
+ hidden_states = self.proj_out(hidden_states)
308
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
309
+
310
+ output = hidden_states + residual
311
+ elif self.is_input_vectorized:
312
+ hidden_states = self.norm_out(hidden_states)
313
+ logits = self.out(hidden_states)
314
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
315
+ logits = logits.permute(0, 2, 1)
316
+
317
+ # log(p(x_0))
318
+ output = F.log_softmax(logits.double(), dim=1).float()
319
+ elif self.is_input_patches:
320
+ # TODO: cleanup!
321
+ conditioning = self.transformer_blocks[0].norm1.emb(
322
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
323
+ )
324
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
325
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
326
+ hidden_states = self.proj_out_2(hidden_states)
327
+
328
+ # unpatchify
329
+ height = width = int(hidden_states.shape[1] ** 0.5)
330
+ hidden_states = hidden_states.reshape(
331
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
332
+ )
333
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
334
+ output = hidden_states.reshape(
335
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
336
+ )
337
+
338
+ if not return_dict:
339
+ return (output,)
340
+
341
+ return Transformer2DModelOutput(sample=output)
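# Illustrative sketch (not part of this commit): how the 2D masks handled at the top of
# `forward` are turned into additive attention biases. A kept token contributes 0 and a
# masked token a large negative value, so it vanishes after the softmax.
def _example_mask_to_bias():
    import torch

    mask = torch.tensor([[1, 1, 0]])  # (batch, key_tokens); 1 = keep, 0 = discard
    bias = (1 - mask.float()) * -10000.0
    bias = bias.unsqueeze(1)          # (batch, 1, key_tokens)
    assert bias.tolist() == [[[0.0, 0.0, -10000.0]]]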
models/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
models/unet_2d_condition.py ADDED
@@ -0,0 +1,983 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import UNet2DConditionLoadersMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.activations import get_activation
25
+
26
+ from diffusers.models.embeddings import (
27
+ GaussianFourierProjection,
28
+ ImageHintTimeEmbedding,
29
+ ImageProjection,
30
+ ImageTimeEmbedding,
31
+ TextImageProjection,
32
+ TextImageTimeEmbedding,
33
+ TextTimeEmbedding,
34
+ TimestepEmbedding,
35
+ Timesteps,
36
+ )
37
+ from diffusers.models.modeling_utils import ModelMixin
38
+
39
+ from models.attention_processor import AttentionProcessor, AttnProcessor
40
+
41
+ from models.unet_2d_blocks import (
42
+ CrossAttnDownBlock2D,
43
+ CrossAttnUpBlock2D,
44
+ DownBlock2D,
45
+ UNetMidBlock2DCrossAttn,
46
+ UNetMidBlock2DSimpleCrossAttn,
47
+ UpBlock2D,
48
+ get_down_block,
49
+ get_up_block,
50
+ )
51
+
52
+
53
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
+
55
+
56
+ @dataclass
57
+ class UNet2DConditionOutput(BaseOutput):
58
+ """
59
+ The output of [`UNet2DConditionModel`].
60
+
61
+ Args:
62
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
63
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
64
+ """
65
+
66
+ sample: torch.FloatTensor = None
67
+
68
+
69
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
70
+ r"""
71
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
72
+ shaped output.
73
+
74
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
75
+ for all models (such as downloading or saving).
76
+
77
+ Parameters:
78
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
79
+ Height and width of input/output sample.
80
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
81
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
82
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
83
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
84
+ Whether to flip the sin to cos in the time embedding.
85
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
86
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
87
+ The tuple of downsample blocks to use.
88
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
89
+ Block type for the middle of the UNet; it can be either `UNetMidBlock2DCrossAttn` or
90
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
91
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
92
+ The tuple of upsample blocks to use.
93
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
94
+ Whether to include self-attention in the basic transformer blocks, see
95
+ [`~models.attention.BasicTransformerBlock`].
96
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
97
+ The tuple of output channels for each block.
98
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
99
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
100
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
101
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
102
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
103
+ If `None`, normalization and activation layers are skipped in post-processing.
104
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
105
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
106
+ The dimension of the cross attention features.
107
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
108
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
109
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
110
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
111
+ encoder_hid_dim (`int`, *optional*, defaults to None):
112
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
113
+ dimension to `cross_attention_dim`.
114
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
115
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
116
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
117
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
118
+ num_attention_heads (`int`, *optional*):
119
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
120
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
121
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
122
+ class_embed_type (`str`, *optional*, defaults to `None`):
123
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
124
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
125
+ addition_embed_type (`str`, *optional*, defaults to `None`):
126
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
127
+ "text". "text" will use the `TextTimeEmbedding` layer.
128
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
129
+ Dimension for the timestep embeddings.
130
+ num_class_embeds (`int`, *optional*, defaults to `None`):
131
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
132
+ class conditioning with `class_embed_type` equal to `None`.
133
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
134
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
135
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
136
+ An optional override for the dimension of the projected time embedding.
137
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
138
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
139
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
140
+ timestep_post_act (`str`, *optional*, defaults to `None`):
141
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
142
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
143
+ The dimension of `cond_proj` layer in the timestep embedding.
144
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
145
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
146
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
147
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
148
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
149
+ embeddings with the class embeddings.
150
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
151
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
152
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
153
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
154
+ otherwise.
155
+ """
156
+
157
+ _supports_gradient_checkpointing = True
158
+
159
+ @register_to_config
160
+ def __init__(
161
+ self,
162
+ sample_size: Optional[int] = None,
163
+ in_channels: int = 4,
164
+ out_channels: int = 4,
165
+ center_input_sample: bool = False,
166
+ flip_sin_to_cos: bool = True,
167
+ freq_shift: int = 0,
168
+ down_block_types: Tuple[str] = (
169
+ "CrossAttnDownBlock2D",
170
+ "CrossAttnDownBlock2D",
171
+ "CrossAttnDownBlock2D",
172
+ "DownBlock2D",
173
+ ),
174
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
175
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
176
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
177
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
178
+ layers_per_block: Union[int, Tuple[int]] = 2,
179
+ downsample_padding: int = 1,
180
+ mid_block_scale_factor: float = 1,
181
+ act_fn: str = "silu",
182
+ norm_num_groups: Optional[int] = 32,
183
+ norm_eps: float = 1e-5,
184
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
185
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
186
+ encoder_hid_dim: Optional[int] = None,
187
+ encoder_hid_dim_type: Optional[str] = None,
188
+ attention_head_dim: Union[int, Tuple[int]] = 8,
189
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
190
+ dual_cross_attention: bool = False,
191
+ use_linear_projection: bool = False,
192
+ class_embed_type: Optional[str] = None,
193
+ addition_embed_type: Optional[str] = None,
194
+ addition_time_embed_dim: Optional[int] = None,
195
+ num_class_embeds: Optional[int] = None,
196
+ upcast_attention: bool = False,
197
+ resnet_time_scale_shift: str = "default",
198
+ resnet_skip_time_act: bool = False,
199
+ resnet_out_scale_factor: int = 1.0,
200
+ time_embedding_type: str = "positional",
201
+ time_embedding_dim: Optional[int] = None,
202
+ time_embedding_act_fn: Optional[str] = None,
203
+ timestep_post_act: Optional[str] = None,
204
+ time_cond_proj_dim: Optional[int] = None,
205
+ conv_in_kernel: int = 3,
206
+ conv_out_kernel: int = 3,
207
+ projection_class_embeddings_input_dim: Optional[int] = None,
208
+ class_embeddings_concat: bool = False,
209
+ mid_block_only_cross_attention: Optional[bool] = None,
210
+ cross_attention_norm: Optional[str] = None,
211
+ addition_embed_type_num_heads=64,
212
+ ):
213
+ super().__init__()
214
+
215
+ self.sample_size = sample_size
216
+
217
+ if num_attention_heads is not None:
218
+ raise ValueError(
219
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
220
+ )
221
+
222
+ # If `num_attention_heads` is not defined (which is the case for most models)
223
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
224
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
225
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
226
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
227
+ # which is why we correct for the naming here.
228
+ num_attention_heads = num_attention_heads or attention_head_dim
229
+
230
+ # Check inputs
231
+ if len(down_block_types) != len(up_block_types):
232
+ raise ValueError(
233
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
234
+ )
235
+
236
+ if len(block_out_channels) != len(down_block_types):
237
+ raise ValueError(
238
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
239
+ )
240
+
241
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
242
+ raise ValueError(
243
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
244
+ )
245
+
246
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
247
+ raise ValueError(
248
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
249
+ )
250
+
251
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
252
+ raise ValueError(
253
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
254
+ )
255
+
256
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
257
+ raise ValueError(
258
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
259
+ )
260
+
261
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
262
+ raise ValueError(
263
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
264
+ )
265
+
266
+ # input
267
+ conv_in_padding = (conv_in_kernel - 1) // 2
268
+ self.conv_in = nn.Conv2d(
269
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
270
+ )
271
+
272
+ # time
273
+ if time_embedding_type == "fourier":
274
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
275
+ if time_embed_dim % 2 != 0:
276
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
277
+ self.time_proj = GaussianFourierProjection(
278
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
279
+ )
280
+ timestep_input_dim = time_embed_dim
281
+ elif time_embedding_type == "positional":
282
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
283
+
284
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
285
+ timestep_input_dim = block_out_channels[0]
286
+ else:
287
+ raise ValueError(
288
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
289
+ )
290
+
291
+ self.time_embedding = TimestepEmbedding(
292
+ timestep_input_dim,
293
+ time_embed_dim,
294
+ act_fn=act_fn,
295
+ post_act_fn=timestep_post_act,
296
+ cond_proj_dim=time_cond_proj_dim,
297
+ )
298
+
299
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
300
+ encoder_hid_dim_type = "text_proj"
301
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
302
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
303
+
304
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
305
+ raise ValueError(
306
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
307
+ )
308
+
309
+ if encoder_hid_dim_type == "text_proj":
310
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
311
+ elif encoder_hid_dim_type == "text_image_proj":
312
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
313
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
314
+ # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
315
+ self.encoder_hid_proj = TextImageProjection(
316
+ text_embed_dim=encoder_hid_dim,
317
+ image_embed_dim=cross_attention_dim,
318
+ cross_attention_dim=cross_attention_dim,
319
+ )
320
+ elif encoder_hid_dim_type == "image_proj":
321
+ # Kandinsky 2.2
322
+ self.encoder_hid_proj = ImageProjection(
323
+ image_embed_dim=encoder_hid_dim,
324
+ cross_attention_dim=cross_attention_dim,
325
+ )
326
+ elif encoder_hid_dim_type is not None:
327
+ raise ValueError(
328
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
329
+ )
330
+ else:
331
+ self.encoder_hid_proj = None
332
+
333
+ # class embedding
334
+ if class_embed_type is None and num_class_embeds is not None:
335
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
336
+ elif class_embed_type == "timestep":
337
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
338
+ elif class_embed_type == "identity":
339
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
340
+ elif class_embed_type == "projection":
341
+ if projection_class_embeddings_input_dim is None:
342
+ raise ValueError(
343
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
344
+ )
345
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
346
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
347
+ # 2. it projects from an arbitrary input dimension.
348
+ #
349
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
350
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
351
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
352
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
353
+ elif class_embed_type == "simple_projection":
354
+ if projection_class_embeddings_input_dim is None:
355
+ raise ValueError(
356
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
357
+ )
358
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
359
+ else:
360
+ self.class_embedding = None
361
+
362
+ if addition_embed_type == "text":
363
+ if encoder_hid_dim is not None:
364
+ text_time_embedding_from_dim = encoder_hid_dim
365
+ else:
366
+ text_time_embedding_from_dim = cross_attention_dim
367
+
368
+ self.add_embedding = TextTimeEmbedding(
369
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
370
+ )
371
+ elif addition_embed_type == "text_image":
372
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
373
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
374
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
375
+ self.add_embedding = TextImageTimeEmbedding(
376
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
377
+ )
378
+ elif addition_embed_type == "text_time":
379
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
380
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
381
+ elif addition_embed_type == "image":
382
+ # Kandinsky 2.2
383
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
384
+ elif addition_embed_type == "image_hint":
385
+ # Kandinsky 2.2 ControlNet
386
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
387
+ elif addition_embed_type is not None:
388
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
389
+
390
+ if time_embedding_act_fn is None:
391
+ self.time_embed_act = None
392
+ else:
393
+ self.time_embed_act = get_activation(time_embedding_act_fn)
394
+
395
+ self.down_blocks = nn.ModuleList([])
396
+ self.up_blocks = nn.ModuleList([])
397
+
398
+ if isinstance(only_cross_attention, bool):
399
+ if mid_block_only_cross_attention is None:
400
+ mid_block_only_cross_attention = only_cross_attention
401
+
402
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
403
+
404
+ if mid_block_only_cross_attention is None:
405
+ mid_block_only_cross_attention = False
406
+
407
+ if isinstance(num_attention_heads, int):
408
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
409
+
410
+ if isinstance(attention_head_dim, int):
411
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
412
+
413
+ if isinstance(cross_attention_dim, int):
414
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
415
+
416
+ if isinstance(layers_per_block, int):
417
+ layers_per_block = [layers_per_block] * len(down_block_types)
418
+
419
+ if isinstance(transformer_layers_per_block, int):
420
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
421
+
422
+ if class_embeddings_concat:
423
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
424
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
425
+ # regular time embeddings
426
+ blocks_time_embed_dim = time_embed_dim * 2
427
+ else:
428
+ blocks_time_embed_dim = time_embed_dim
429
+
430
+ # down
431
+ output_channel = block_out_channels[0]
432
+ for i, down_block_type in enumerate(down_block_types):
433
+ input_channel = output_channel
434
+ output_channel = block_out_channels[i]
435
+ is_final_block = i == len(block_out_channels) - 1
436
+
437
+ down_block = get_down_block(
438
+ down_block_type,
439
+ num_layers=layers_per_block[i],
440
+ transformer_layers_per_block=transformer_layers_per_block[i],
441
+ in_channels=input_channel,
442
+ out_channels=output_channel,
443
+ temb_channels=blocks_time_embed_dim,
444
+ add_downsample=not is_final_block,
445
+ resnet_eps=norm_eps,
446
+ resnet_act_fn=act_fn,
447
+ resnet_groups=norm_num_groups,
448
+ cross_attention_dim=cross_attention_dim[i],
449
+ num_attention_heads=num_attention_heads[i],
450
+ downsample_padding=downsample_padding,
451
+ dual_cross_attention=dual_cross_attention,
452
+ use_linear_projection=use_linear_projection,
453
+ only_cross_attention=only_cross_attention[i],
454
+ upcast_attention=upcast_attention,
455
+ resnet_time_scale_shift=resnet_time_scale_shift,
456
+ resnet_skip_time_act=resnet_skip_time_act,
457
+ resnet_out_scale_factor=resnet_out_scale_factor,
458
+ cross_attention_norm=cross_attention_norm,
459
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
460
+ )
461
+ self.down_blocks.append(down_block)
462
+
463
+ # mid
464
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
465
+ self.mid_block = UNetMidBlock2DCrossAttn(
466
+ transformer_layers_per_block=transformer_layers_per_block[-1],
467
+ in_channels=block_out_channels[-1],
468
+ temb_channels=blocks_time_embed_dim,
469
+ resnet_eps=norm_eps,
470
+ resnet_act_fn=act_fn,
471
+ output_scale_factor=mid_block_scale_factor,
472
+ resnet_time_scale_shift=resnet_time_scale_shift,
473
+ cross_attention_dim=cross_attention_dim[-1],
474
+ num_attention_heads=num_attention_heads[-1],
475
+ resnet_groups=norm_num_groups,
476
+ dual_cross_attention=dual_cross_attention,
477
+ use_linear_projection=use_linear_projection,
478
+ upcast_attention=upcast_attention,
479
+ )
480
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
481
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
482
+ in_channels=block_out_channels[-1],
483
+ temb_channels=blocks_time_embed_dim,
484
+ resnet_eps=norm_eps,
485
+ resnet_act_fn=act_fn,
486
+ output_scale_factor=mid_block_scale_factor,
487
+ cross_attention_dim=cross_attention_dim[-1],
488
+ attention_head_dim=attention_head_dim[-1],
489
+ resnet_groups=norm_num_groups,
490
+ resnet_time_scale_shift=resnet_time_scale_shift,
491
+ skip_time_act=resnet_skip_time_act,
492
+ only_cross_attention=mid_block_only_cross_attention,
493
+ cross_attention_norm=cross_attention_norm,
494
+ )
495
+ elif mid_block_type is None:
496
+ self.mid_block = None
497
+ else:
498
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
499
+
500
+ # count how many layers upsample the images
501
+ self.num_upsamplers = 0
502
+
503
+ # up
504
+ reversed_block_out_channels = list(reversed(block_out_channels))
505
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
506
+ reversed_layers_per_block = list(reversed(layers_per_block))
507
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
508
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
509
+ only_cross_attention = list(reversed(only_cross_attention))
510
+
511
+ output_channel = reversed_block_out_channels[0]
512
+ for i, up_block_type in enumerate(up_block_types):
513
+ is_final_block = i == len(block_out_channels) - 1
514
+
515
+ prev_output_channel = output_channel
516
+ output_channel = reversed_block_out_channels[i]
517
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
518
+
519
+ # add upsample block for all BUT final layer
520
+ if not is_final_block:
521
+ add_upsample = True
522
+ self.num_upsamplers += 1
523
+ else:
524
+ add_upsample = False
525
+
526
+ up_block = get_up_block(
527
+ up_block_type,
528
+ num_layers=reversed_layers_per_block[i] + 1,
529
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
530
+ in_channels=input_channel,
531
+ out_channels=output_channel,
532
+ prev_output_channel=prev_output_channel,
533
+ temb_channels=blocks_time_embed_dim,
534
+ add_upsample=add_upsample,
535
+ resnet_eps=norm_eps,
536
+ resnet_act_fn=act_fn,
537
+ resnet_groups=norm_num_groups,
538
+ cross_attention_dim=reversed_cross_attention_dim[i],
539
+ num_attention_heads=reversed_num_attention_heads[i],
540
+ dual_cross_attention=dual_cross_attention,
541
+ use_linear_projection=use_linear_projection,
542
+ only_cross_attention=only_cross_attention[i],
543
+ upcast_attention=upcast_attention,
544
+ resnet_time_scale_shift=resnet_time_scale_shift,
545
+ resnet_skip_time_act=resnet_skip_time_act,
546
+ resnet_out_scale_factor=resnet_out_scale_factor,
547
+ cross_attention_norm=cross_attention_norm,
548
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
549
+ )
550
+ self.up_blocks.append(up_block)
551
+ prev_output_channel = output_channel
552
+
553
+ # out
554
+ if norm_num_groups is not None:
555
+ self.conv_norm_out = nn.GroupNorm(
556
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
557
+ )
558
+
559
+ self.conv_act = get_activation(act_fn)
560
+
561
+ else:
562
+ self.conv_norm_out = None
563
+ self.conv_act = None
564
+
565
+ conv_out_padding = (conv_out_kernel - 1) // 2
566
+ self.conv_out = nn.Conv2d(
567
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
568
+ )
569
+
570
+ @property
571
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
572
+ r"""
573
+ Returns:
574
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
575
+ indexed by their weight names.
576
+ """
577
+ # set recursively
578
+ processors = {}
579
+
580
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
581
+ if hasattr(module, "set_processor"):
582
+ processors[f"{name}.processor"] = module.processor
583
+
584
+ for sub_name, child in module.named_children():
585
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
586
+
587
+ return processors
588
+
589
+ for name, module in self.named_children():
590
+ fn_recursive_add_processors(name, module, processors)
591
+
592
+ return processors
593
+
594
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
595
+ r"""
596
+ Sets the attention processor to use to compute attention.
597
+
598
+ Parameters:
599
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
600
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
601
+ for **all** `Attention` layers.
602
+
603
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
604
+ processor. This is strongly recommended when setting trainable attention processors.
605
+
606
+ """
607
+ count = len(self.attn_processors.keys())
608
+
609
+ if isinstance(processor, dict) and len(processor) != count:
610
+ raise ValueError(
611
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
612
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
613
+ )
614
+
615
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
616
+ if hasattr(module, "set_processor"):
617
+ if not isinstance(processor, dict):
618
+ module.set_processor(processor)
619
+ else:
620
+ module.set_processor(processor.pop(f"{name}.processor"))
621
+
622
+ for sub_name, child in module.named_children():
623
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
624
+
625
+ for name, module in self.named_children():
626
+ fn_recursive_attn_processor(name, module, processor)
627
+
628
+ def set_default_attn_processor(self):
629
+ """
630
+ Disables custom attention processors and sets the default attention implementation.
631
+ """
632
+ self.set_attn_processor(AttnProcessor())
633
+
634
+ def set_attention_slice(self, slice_size):
635
+ r"""
636
+ Enable sliced attention computation.
637
+
638
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
639
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
640
+
641
+ Args:
642
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
643
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
644
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
645
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
646
+ must be a multiple of `slice_size`.
647
+ """
648
+ sliceable_head_dims = []
649
+
650
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
651
+ if hasattr(module, "set_attention_slice"):
652
+ sliceable_head_dims.append(module.sliceable_head_dim)
653
+
654
+ for child in module.children():
655
+ fn_recursive_retrieve_sliceable_dims(child)
656
+
657
+ # retrieve number of attention layers
658
+ for module in self.children():
659
+ fn_recursive_retrieve_sliceable_dims(module)
660
+
661
+ num_sliceable_layers = len(sliceable_head_dims)
662
+
663
+ if slice_size == "auto":
664
+ # half the attention head size is usually a good trade-off between
665
+ # speed and memory
666
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
667
+ elif slice_size == "max":
668
+ # make smallest slice possible
669
+ slice_size = num_sliceable_layers * [1]
670
+
671
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
672
+
673
+ if len(slice_size) != len(sliceable_head_dims):
674
+ raise ValueError(
675
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
676
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
677
+ )
678
+
679
+ for i in range(len(slice_size)):
680
+ size = slice_size[i]
681
+ dim = sliceable_head_dims[i]
682
+ if size is not None and size > dim:
683
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
684
+
685
+ # Recursively walk through all the children.
686
+ # Any children which exposes the set_attention_slice method
687
+ # gets the message
688
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
689
+ if hasattr(module, "set_attention_slice"):
690
+ module.set_attention_slice(slice_size.pop())
691
+
692
+ for child in module.children():
693
+ fn_recursive_set_attention_slice(child, slice_size)
694
+
695
+ reversed_slice_size = list(reversed(slice_size))
696
+ for module in self.children():
697
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
698
+
699
+ def _set_gradient_checkpointing(self, module, value=False):
700
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
701
+ module.gradient_checkpointing = value
702
+
703
+ def forward(
704
+ self,
705
+ sample: torch.FloatTensor,
706
+ timestep: Union[torch.Tensor, float, int],
707
+ encoder_hidden_states: torch.Tensor,
708
+ class_labels: Optional[torch.Tensor] = None,
709
+ timestep_cond: Optional[torch.Tensor] = None,
710
+ attention_mask: Optional[torch.Tensor] = None,
711
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
712
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
713
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
714
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
715
+ encoder_attention_mask: Optional[torch.Tensor] = None,
716
+ return_dict: bool = True,
717
+ ) -> Union[UNet2DConditionOutput, Tuple]:
718
+ r"""
719
+ The [`UNet2DConditionModel`] forward method.
720
+
721
+ Args:
722
+ sample (`torch.FloatTensor`):
723
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
724
+ timestep (`torch.FloatTensor` or `float` or `int`): The timestep at which to denoise the input.
725
+ encoder_hidden_states (`torch.FloatTensor`):
726
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
727
+ encoder_attention_mask (`torch.Tensor`):
728
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
729
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
730
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
731
+ return_dict (`bool`, *optional*, defaults to `True`):
732
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
733
+ tuple.
734
+ cross_attention_kwargs (`dict`, *optional*):
735
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
736
+ added_cond_kwargs: (`dict`, *optional*):
737
+ A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings that
738
+ are passed along to the UNet blocks.
739
+
740
+ Returns:
741
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
742
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
743
+ a `tuple` is returned where the first element is the sample tensor.
744
+ """
745
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
746
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
747
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
748
+ # on the fly if necessary.
749
+ default_overall_up_factor = 2**self.num_upsamplers
750
+
751
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
752
+ forward_upsample_size = False
753
+ upsample_size = None
754
+
755
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
756
+ logger.info("Forward upsample size to force interpolation output size.")
757
+ forward_upsample_size = True
758
+
759
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
760
+ # expects mask of shape:
761
+ # [batch, key_tokens]
762
+ # adds singleton query_tokens dimension:
763
+ # [batch, 1, key_tokens]
764
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
765
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
766
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
767
+ if attention_mask is not None:
768
+ # assume that mask is expressed as:
769
+ # (1 = keep, 0 = discard)
770
+ # convert mask into a bias that can be added to attention scores:
771
+ # (keep = +0, discard = -10000.0)
772
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
773
+ attention_mask = attention_mask.unsqueeze(1)
774
+
775
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
776
+ if encoder_attention_mask is not None:
777
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
778
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
779
+
780
+ # 0. center input if necessary
781
+ if self.config.center_input_sample:
782
+ sample = 2 * sample - 1.0
783
+
784
+ # 1. time
785
+ timesteps = timestep
786
+ if not torch.is_tensor(timesteps):
787
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
788
+ # This would be a good case for the `match` statement (Python 3.10+)
789
+ is_mps = sample.device.type == "mps"
790
+ if isinstance(timestep, float):
791
+ dtype = torch.float32 if is_mps else torch.float64
792
+ else:
793
+ dtype = torch.int32 if is_mps else torch.int64
794
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
795
+ elif len(timesteps.shape) == 0:
796
+ timesteps = timesteps[None].to(sample.device)
797
+
798
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
799
+ timesteps = timesteps.expand(sample.shape[0])
800
+
801
+ t_emb = self.time_proj(timesteps)
802
+
803
+ # `Timesteps` does not contain any weights and will always return f32 tensors
804
+ # but time_embedding might actually be running in fp16. so we need to cast here.
805
+ # there might be better ways to encapsulate this.
806
+ t_emb = t_emb.to(dtype=sample.dtype)
807
+
808
+ emb = self.time_embedding(t_emb, timestep_cond)
809
+ aug_emb = None
810
+
811
+ if self.class_embedding is not None:
812
+ if class_labels is None:
813
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
814
+
815
+ if self.config.class_embed_type == "timestep":
816
+ class_labels = self.time_proj(class_labels)
817
+
818
+ # `Timesteps` does not contain any weights and will always return f32 tensors
819
+ # there might be better ways to encapsulate this.
820
+ class_labels = class_labels.to(dtype=sample.dtype)
821
+
822
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
823
+
824
+ if self.config.class_embeddings_concat:
825
+ emb = torch.cat([emb, class_emb], dim=-1)
826
+ else:
827
+ emb = emb + class_emb
828
+
829
+ if self.config.addition_embed_type == "text":
830
+ aug_emb = self.add_embedding(encoder_hidden_states)
831
+ elif self.config.addition_embed_type == "text_image":
832
+ # Kandinsky 2.1 - style
833
+ if "image_embeds" not in added_cond_kwargs:
834
+ raise ValueError(
835
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
836
+ )
837
+
838
+ image_embs = added_cond_kwargs.get("image_embeds")
839
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
840
+ aug_emb = self.add_embedding(text_embs, image_embs)
841
+ elif self.config.addition_embed_type == "text_time":
842
+ if "text_embeds" not in added_cond_kwargs:
843
+ raise ValueError(
844
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
845
+ )
846
+ text_embeds = added_cond_kwargs.get("text_embeds")
847
+ if "time_ids" not in added_cond_kwargs:
848
+ raise ValueError(
849
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
850
+ )
851
+ time_ids = added_cond_kwargs.get("time_ids")
852
+ time_embeds = self.add_time_proj(time_ids.flatten())
853
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
854
+
855
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
856
+ add_embeds = add_embeds.to(emb.dtype)
857
+ aug_emb = self.add_embedding(add_embeds)
858
+ elif self.config.addition_embed_type == "image":
859
+ # Kandinsky 2.2 - style
860
+ if "image_embeds" not in added_cond_kwargs:
861
+ raise ValueError(
862
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
863
+ )
864
+ image_embs = added_cond_kwargs.get("image_embeds")
865
+ aug_emb = self.add_embedding(image_embs)
866
+ elif self.config.addition_embed_type == "image_hint":
867
+ # Kandinsky 2.2 - style
868
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
869
+ raise ValueError(
870
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
871
+ )
872
+ image_embs = added_cond_kwargs.get("image_embeds")
873
+ hint = added_cond_kwargs.get("hint")
874
+ aug_emb, hint = self.add_embedding(image_embs, hint)
875
+ sample = torch.cat([sample, hint], dim=1)
876
+
877
+ emb = emb + aug_emb if aug_emb is not None else emb
878
+
879
+ if self.time_embed_act is not None:
880
+ emb = self.time_embed_act(emb)
881
+
882
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
883
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
884
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
885
+ # Kandinsky 2.1 - style
886
+ if "image_embeds" not in added_cond_kwargs:
887
+ raise ValueError(
888
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
889
+ )
890
+
891
+ image_embeds = added_cond_kwargs.get("image_embeds")
892
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
893
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
894
+ # Kandinsky 2.2 - style
895
+ if "image_embeds" not in added_cond_kwargs:
896
+ raise ValueError(
897
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
898
+ )
899
+ image_embeds = added_cond_kwargs.get("image_embeds")
900
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
901
+ # 2. pre-process
902
+ sample = self.conv_in(sample)
903
+
904
+ # 3. down
905
+ down_block_res_samples = (sample,)
906
+ for downsample_block in self.down_blocks:
907
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
908
+ sample, res_samples = downsample_block(
909
+ hidden_states=sample,
910
+ temb=emb,
911
+ encoder_hidden_states=encoder_hidden_states,
912
+ attention_mask=attention_mask,
913
+ cross_attention_kwargs=cross_attention_kwargs,
914
+ encoder_attention_mask=encoder_attention_mask,
915
+ )
916
+ else:
917
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
918
+
919
+ down_block_res_samples += res_samples
920
+
921
+ if down_block_additional_residuals is not None:
922
+ new_down_block_res_samples = ()
923
+
924
+ for down_block_res_sample, down_block_additional_residual in zip(
925
+ down_block_res_samples, down_block_additional_residuals
926
+ ):
927
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
928
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
929
+
930
+ down_block_res_samples = new_down_block_res_samples
931
+
932
+ # 4. mid
933
+ if self.mid_block is not None:
934
+ sample = self.mid_block(
935
+ sample,
936
+ emb,
937
+ encoder_hidden_states=encoder_hidden_states,
938
+ attention_mask=attention_mask,
939
+ cross_attention_kwargs=cross_attention_kwargs,
940
+ encoder_attention_mask=encoder_attention_mask,
941
+ )
942
+
943
+ if mid_block_additional_residual is not None:
944
+ sample = sample + mid_block_additional_residual
945
+
946
+ # 5. up
947
+ for i, upsample_block in enumerate(self.up_blocks):
948
+ is_final_block = i == len(self.up_blocks) - 1
949
+
950
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
951
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
952
+
953
+ # if we have not reached the final block and need to forward the
954
+ # upsample size, we do it here
955
+ if not is_final_block and forward_upsample_size:
956
+ upsample_size = down_block_res_samples[-1].shape[2:]
957
+
958
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
959
+ sample = upsample_block(
960
+ hidden_states=sample,
961
+ temb=emb,
962
+ res_hidden_states_tuple=res_samples,
963
+ encoder_hidden_states=encoder_hidden_states,
964
+ cross_attention_kwargs=cross_attention_kwargs,
965
+ upsample_size=upsample_size,
966
+ attention_mask=attention_mask,
967
+ encoder_attention_mask=encoder_attention_mask,
968
+ )
969
+ else:
970
+ sample = upsample_block(
971
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
972
+ )
973
+
974
+ # 6. post-process
975
+ if self.conv_norm_out:
976
+ sample = self.conv_norm_out(sample)
977
+ sample = self.conv_act(sample)
978
+ sample = self.conv_out(sample)
979
+
980
+ if not return_dict:
981
+ return (sample,)
982
+
983
+ return UNet2DConditionOutput(sample=sample)
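A minimal sketch of how the `forward` method above can be exercised on its own, assuming a Stable-Diffusion-v1-style configuration (4 latent channels, 64x64 latents, 77-token text embeddings of width 768). The shapes and random inputs are illustrative only and are not taken from this Space's `app.py`:

```python
import torch

# Illustrative configuration; every other argument keeps the defaults from __init__ above.
unet = UNet2DConditionModel(
    sample_size=64,
    in_channels=4,
    out_channels=4,
    cross_attention_dim=768,
)
unet.set_attention_slice("auto")        # optional: trade a little speed for memory
print(len(unet.attn_processors))        # one processor entry per Attention layer

latents = torch.randn(1, 4, 64, 64)     # noisy latent `sample`
timestep = torch.tensor([10])           # current denoising step
text_emb = torch.randn(1, 77, 768)      # `encoder_hidden_states`, e.g. CLIP text features

with torch.no_grad():
    out = unet(latents, timestep, encoder_hidden_states=text_emb)

print(out.sample.shape)                 # torch.Size([1, 4, 64, 64])
```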
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu117
2
+ torch==1.13.1
3
+ torchvision==0.14.1
4
+ diffusers==0.18.2
5
+ transformers==4.27.0
6
+ safetensors==0.3.1
7
+ invisible_watermark==0.2.0
8
+ numpy==1.24.3
9
+ seaborn==0.12.2
10
+ accelerate==0.16.0
11
+ scikit-learn==1.1.3
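The `--extra-index-url` line points pip at the CUDA 11.7 wheel index for the pinned `torch`/`torchvision`. A quick post-install sanity check (a sketch; the expected version strings come straight from the pins above, and the last line simply depends on whether the Space runs on GPU hardware):

```python
import torch
import diffusers
import transformers

print(torch.__version__)          # expected 1.13.1 (+cu117 when the CUDA wheel resolves)
print(diffusers.__version__)      # expected 0.18.2
print(transformers.__version__)   # expected 4.27.0
print(torch.cuda.is_available())  # True only on GPU hardware
```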
rich-text-to-json-iframe.html ADDED
@@ -0,0 +1,341 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <title>Rich Text to JSON</title>
6
+ <link rel="stylesheet" href="https://cdn.quilljs.com/1.3.6/quill.snow.css">
7
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
8
+ <link rel="stylesheet" type="text/css"
9
+ href="https://cdnjs.cloudflare.com/ajax/libs/spectrum/1.8.0/spectrum.min.css">
10
+ <link rel="stylesheet"
11
+ href='https://fonts.googleapis.com/css?family=Mirza|Roboto|Slabo+27px|Sofia|Inconsolata|Ubuntu|Akronim|Monoton&display=swap'>
12
+ <style>
13
+ html,
14
+ body {
15
+ background-color: white;
16
+ margin: 0;
17
+ }
18
+
19
+ /* Set default font-family */
20
+ .ql-snow .ql-tooltip::before {
21
+ content: "Footnote";
22
+ line-height: 26px;
23
+ margin-right: 8px;
24
+ }
25
+
26
+ .ql-snow .ql-tooltip[data-mode=link]::before {
27
+ content: "Enter footnote:";
28
+ }
29
+
30
+ .row {
31
+ margin-top: 15px;
32
+ margin-left: 0px;
33
+ margin-bottom: 15px;
34
+ }
35
+
36
+ .btn-primary {
37
+ color: #ffffff;
38
+ background-color: #2780e3;
39
+ border-color: #2780e3;
40
+ }
41
+
42
+ .btn-primary:hover {
43
+ color: #ffffff;
44
+ background-color: #1967be;
45
+ border-color: #1862b5;
46
+ }
47
+
48
+ .btn {
49
+ display: inline-block;
50
+ margin-bottom: 0;
51
+ font-weight: normal;
52
+ text-align: center;
53
+ vertical-align: middle;
54
+ touch-action: manipulation;
55
+ cursor: pointer;
56
+ background-image: none;
57
+ border: 1px solid transparent;
58
+ white-space: nowrap;
59
+ padding: 10px 18px;
60
+ font-size: 15px;
61
+ line-height: 1.42857143;
62
+ border-radius: 0;
63
+ user-select: none;
64
+ }
65
+
66
+ #standalone-container {
67
+ width: 100%;
68
+ background-color: #ffffff;
69
+ }
70
+
71
+ #editor-container {
72
+ font-family: "Aref Ruqaa";
73
+ font-size: 18px;
74
+ height: 250px;
75
+ width: 100%;
76
+ }
77
+
78
+ #toolbar-container {
79
+ font-family: "Aref Ruqaa";
80
+ display: flex;
81
+ flex-wrap: wrap;
82
+ }
83
+
84
+ #json-container {
85
+ max-width: 720px;
86
+ }
87
+
88
+ /* Set dropdown font-families */
89
+ #toolbar-container .ql-font span[data-label="Base"]::before {
90
+ font-family: "Aref Ruqaa";
91
+ }
92
+
93
+ #toolbar-container .ql-font span[data-label="Claude Monet"]::before {
94
+ font-family: "Mirza";
95
+ }
96
+
97
+ #toolbar-container .ql-font span[data-label="Ukiyoe"]::before {
98
+ font-family: "Roboto";
99
+ }
100
+
101
+ #toolbar-container .ql-font span[data-label="Cyber Punk"]::before {
102
+ font-family: "Comic Sans MS";
103
+ }
104
+
105
+ #toolbar-container .ql-font span[data-label="Pop Art"]::before {
106
+ font-family: "sofia";
107
+ }
108
+
109
+ #toolbar-container .ql-font span[data-label="Van Gogh"]::before {
110
+ font-family: "slabo 27px";
111
+ }
112
+
113
+ #toolbar-container .ql-font span[data-label="Pixel Art"]::before {
114
+ font-family: "inconsolata";
115
+ }
116
+
117
+ #toolbar-container .ql-font span[data-label="Rembrandt"]::before {
118
+ font-family: "ubuntu";
119
+ }
120
+
121
+ #toolbar-container .ql-font span[data-label="Cubism"]::before {
122
+ font-family: "Akronim";
123
+ }
124
+
125
+ #toolbar-container .ql-font span[data-label="Neon Art"]::before {
126
+ font-family: "Monoton";
127
+ }
128
+
129
+ /* Set content font-families */
130
+ .ql-font-mirza {
131
+ font-family: "Mirza";
132
+ }
133
+
134
+ .ql-font-roboto {
135
+ font-family: "Roboto";
136
+ }
137
+
138
+ .ql-font-cursive {
139
+ font-family: "Comic Sans MS";
140
+ }
141
+
142
+ .ql-font-sofia {
143
+ font-family: "sofia";
144
+ }
145
+
146
+ .ql-font-slabo {
147
+ font-family: "slabo 27px";
148
+ }
149
+
150
+ .ql-font-inconsolata {
151
+ font-family: "inconsolata";
152
+ }
153
+
154
+ .ql-font-ubuntu {
155
+ font-family: "ubuntu";
156
+ }
157
+
158
+ .ql-font-Akronim {
159
+ font-family: "Akronim";
160
+ }
161
+
162
+ .ql-font-Monoton {
163
+ font-family: "Monoton";
164
+ }
165
+
166
+ .ql-color .ql-picker-options [data-value=Color-Picker] {
167
+ background: none !important;
168
+ width: 100% !important;
169
+ height: 20px !important;
170
+ text-align: center;
171
+ }
172
+
173
+ .ql-color .ql-picker-options [data-value=Color-Picker]:before {
174
+ content: 'Color Picker';
175
+ }
176
+
177
+ .ql-color .ql-picker-options [data-value=Color-Picker]:hover {
178
+ border-color: transparent !important;
179
+ }
180
+ </style>
181
+ </head>
182
+
183
+ <body>
184
+ <div id="standalone-container">
185
+ <div id="toolbar-container">
186
+ <span class="ql-formats">
187
+ <select class="ql-font">
188
+ <option selected>Base</option>
189
+ <option value="mirza">Claude Monet</option>
190
+ <option value="roboto">Ukiyoe</option>
191
+ <option value="cursive">Cyber Punk</option>
192
+ <option value="sofia">Pop Art</option>
193
+ <option value="slabo">Van Gogh</option>
194
+ <option value="inconsolata">Pixel Art</option>
195
+ <option value="ubuntu">Rembrandt</option>
196
+ <option value="Akronim">Cubism</option>
197
+ <option value="Monoton">Neon Art</option>
198
+ </select>
199
+ <select class="ql-size">
200
+ <option value="18px">Small</option>
201
+ <option selected>Normal</option>
202
+ <option value="32px">Large</option>
203
+ <option value="50px">Huge</option>
204
+ </select>
205
+ </span>
206
+ <span class="ql-formats">
207
+ <button class="ql-strike"></button>
208
+ </span>
209
+ <!-- <span class="ql-formats">
210
+ <button class="ql-bold"></button>
211
+ <button class="ql-italic"></button>
212
+ <button class="ql-underline"></button>
213
+ </span> -->
214
+ <span class="ql-formats">
215
+ <select class="ql-color">
216
+ <option value="Color-Picker"></option>
217
+ </select>
218
+ <!-- <select class="ql-background"></select> -->
219
+ </span>
220
+ <!-- <span class="ql-formats">
221
+ <button class="ql-script" value="sub"></button>
222
+ <button class="ql-script" value="super"></button>
223
+ </span>
224
+ <span class="ql-formats">
225
+ <button class="ql-header" value="1"></button>
226
+ <button class="ql-header" value="2"></button>
227
+ <button class="ql-blockquote"></button>
228
+ <button class="ql-code-block"></button>
229
+ </span>
230
+ <span class="ql-formats">
231
+ <button class="ql-list" value="ordered"></button>
232
+ <button class="ql-list" value="bullet"></button>
233
+ <button class="ql-indent" value="-1"></button>
234
+ <button class="ql-indent" value="+1"></button>
235
+ </span>
236
+ <span class="ql-formats">
237
+ <button class="ql-direction" value="rtl"></button>
238
+ <select class="ql-align"></select>
239
+ </span>
240
+ <span class="ql-formats">
241
+ <button class="ql-link"></button>
242
+ <button class="ql-image"></button>
243
+ <button class="ql-video"></button>
244
+ <button class="ql-formula"></button>
245
+ </span> -->
246
+ <span class="ql-formats">
247
+ <button class="ql-link"></button>
248
+ </span>
249
+ <span class="ql-formats">
250
+ <button class="ql-clean"></button>
251
+ </span>
252
+ </div>
253
+ <div id="editor-container" style="height:300px;"></div>
254
+ </div>
255
+ <script src="https://cdn.quilljs.com/1.3.6/quill.min.js"></script>
256
+ <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.1.0/jquery.min.js"></script>
257
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/spectrum/1.8.0/spectrum.min.js"></script>
258
+ <script>
259
+
260
+ // Register the custom formats with Quill
261
+ const Font = Quill.import('formats/font');
262
+ Font.whitelist = ['mirza', 'roboto', 'sofia', 'slabo', 'inconsolata', 'ubuntu', 'cursive', 'Akronim', 'Monoton'];
263
+ const Link = Quill.import('formats/link');
264
+ Link.sanitize = function (url) {
265
+ // modify url if desired
266
+ return url;
267
+ }
268
+ const SizeStyle = Quill.import('attributors/style/size');
269
+ SizeStyle.whitelist = ['10px', '18px', '20px', '32px', '50px', '60px', '64px', '70px'];
270
+ Quill.register(SizeStyle, true);
271
+ Quill.register(Link, true);
272
+ Quill.register(Font, true);
273
+ const icons = Quill.import('ui/icons');
274
+ icons['link'] = `<svg xmlns="http://www.w3.org/2000/svg" width="17" viewBox="0 0 512 512" xml:space="preserve"><path fill="#010101" d="M276.75 1c4.51 3.23 9.2 6.04 12.97 9.77 29.7 29.45 59.15 59.14 88.85 88.6 4.98 4.93 7.13 10.37 7.12 17.32-.1 125.8-.09 251.6-.01 377.4 0 7.94-1.96 14.46-9.62 18.57-121.41.34-242.77.34-364.76.05A288.3 288.3 0 0 1 1 502c0-163.02 0-326.04.34-489.62C3.84 6.53 8.04 3.38 13 1c23.35 0 46.7 0 70.82.3 2.07.43 3.38.68 4.69.68h127.98c18.44.01 36.41.04 54.39-.03 1.7 0 3.41-.62 5.12-.95h.75M33.03 122.5v359.05h320.22V129.18h-76.18c-14.22-.01-19.8-5.68-19.8-20.09V33.31H33.02v89.19m256.29-27.36c.72.66 1.44 1.9 2.17 1.9 12.73.12 25.46.08 37.55.08L289.3 57.45v37.7z"/><path fill="#020202" d="M513 375.53c-4.68 7.99-11.52 10.51-20.21 10.25-13.15-.4-26.32-.1-39.48-.1h-5.58c5.49 8.28 10.7 15.74 15.46 23.47 6.06 9.82 1.14 21.65-9.96 24.27-6.7 1.59-12.45-.64-16.23-6.15a2608.6 2608.6 0 0 1-32.97-49.36c-3.57-5.48-3.39-11.54.17-16.98a3122.5 3122.5 0 0 1 32.39-48.56c5.22-7.65 14.67-9.35 21.95-4.45 7.63 5.12 9.6 14.26 4.5 22.33-4.75 7.54-9.8 14.9-15.11 22.95h33.64V225.19h-5.24c-19.49 0-38.97.11-58.46-.05-12.74-.1-20.12-13.15-13.84-24.14 3.12-5.46 8.14-7.71 14.18-7.73 26.15-.06 52.3-.04 78.45 0 7.1 0 12.47 3.05 16.01 9.64.33 57.44.33 114.8.33 172.62z"/><path fill="#111" d="M216.03 1.97C173.52 1.98 131 2 88.5 1.98a16 16 0 0 1-4.22-.68c43.4-.3 87.09-.3 131.24-.06.48.25.5.73.5.73z"/><path fill="#232323" d="M216.5 1.98c-.47 0-.5-.5-.5-.74C235.7 1 255.38 1 275.53 1c-1.24.33-2.94.95-4.65.95-17.98.07-35.95.04-54.39.03z"/><path fill="#040404" d="M148 321.42h153.5c14.25 0 19.96 5.71 19.96 19.97.01 19.17.03 38.33 0 57.5-.03 12.6-6.16 18.78-18.66 18.78H99.81c-12.42 0-18.75-6.34-18.76-18.73-.01-19.83-.02-39.66 0-59.5.02-11.47 6.4-17.93 17.95-18 16.17-.08 32.33-.02 49-.02m40.5 32.15h-75.16v31.84h175.7v-31.84H188.5z"/><path fill="#030303" d="m110 225.33 178.89-.03c11.98 0 19.25 9.95 15.74 21.44-2.05 6.71-7.5 10.57-15.14 10.57-63.63 0-127.25-.01-190.88-.07-12.03-.02-19.17-8.62-16.7-19.84 1.6-7.21 7.17-11.74 15.1-12.04 4.17-.16 8.33-.03 13-.03zm-24.12-36.19c-5.28-6.2-6.3-12.76-2.85-19.73 3.22-6.49 9.13-8.24 15.86-8.24 25.64.01 51.27-.06 76.91.04 13.07.04 20.66 10.44 16.33 22.08-2.25 6.06-6.63 9.76-13.08 9.8-27.97.18-55.94.2-83.9-.07-3.01-.03-6-2.36-9.27-3.88z"/></svg>`
275
+ const quill = new Quill('#editor-container', {
276
+ modules: {
277
+ toolbar: {
278
+ container: '#toolbar-container',
279
+ },
280
+ },
281
+ theme: 'snow'
282
+ });
283
+ var toolbar = quill.getModule('toolbar');
284
+ $(toolbar.container).find('.ql-color').spectrum({
285
+ preferredFormat: "rgb",
286
+ showInput: true,
287
+ showInitial: true,
288
+ showPalette: true,
289
+ showSelectionPalette: true,
290
+ palette: [
291
+ ["#000", "#444", "#666", "#999", "#ccc", "#eee", "#f3f3f3", "#fff"],
292
+ ["#f00", "#f90", "#ff0", "#0f0", "#0ff", "#00f", "#90f", "#f0f"],
293
+ ["#ea9999", "#f9cb9c", "#ffe599", "#b6d7a8", "#a2c4c9", "#9fc5e8", "#b4a7d6", "#d5a6bd"],
294
+ ["#e06666", "#f6b26b", "#ffd966", "#93c47d", "#76a5af", "#6fa8dc", "#8e7cc3", "#c27ba0"],
295
+ ["#c00", "#e69138", "#f1c232", "#6aa84f", "#45818e", "#3d85c6", "#674ea7", "#a64d79"],
296
+ ["#900", "#b45f06", "#bf9000", "#38761d", "#134f5c", "#0b5394", "#351c75", "#741b47"],
297
+ ["#600", "#783f04", "#7f6000", "#274e13", "#0c343d", "#073763", "#20124d", "#4c1130"]
298
+ ],
299
+ change: function (color) {
300
+ var value = color.toHexString();
301
+ quill.format('color', value);
302
+ }
303
+ });
304
+
305
+ quill.on('text-change', () => {
306
+ // keep the Quill contents inside _data to communicate with Gradio
307
+ document.body._data = quill.getContents()
308
+ })
309
+ function setQuillContents(content) {
310
+ quill.setContents(content);
311
+ document.body._data = quill.getContents();
312
+ }
313
+ document.body.setQuillContents = setQuillContents
314
+ </script>
315
+ <script src="https://unpkg.com/@popperjs/core@2/dist/umd/popper.min.js"></script>
316
+ <script src="https://unpkg.com/tippy.js@6/dist/tippy-bundle.umd.js"></script>
317
+ <script>
318
+ // With the above scripts loaded, you can call `tippy()` with a CSS
319
+ // selector and a `content` prop:
320
+ tippy('.ql-font', {
321
+ content: 'Add a style to the token',
322
+ });
323
+ tippy('.ql-size', {
324
+ content: 'Reweight the token',
325
+ });
326
+ tippy('.ql-color', {
327
+ content: 'Pick a color for the token',
328
+ });
329
+ tippy('.ql-link', {
330
+ content: 'Clarify the token',
331
+ });
332
+ tippy('.ql-strike', {
333
+ content: 'Change the token weight to be negative',
334
+ });
335
+ tippy('.ql-clean', {
336
+ content: 'Remove all the formats',
337
+ });
338
+ </script>
339
+ </body>
340
+
341
+ </html>
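The editor above hands its state to Gradio as a Quill Delta (the value of `quill.getContents()`, kept on `document.body._data`): a JSON object whose `ops` list holds `{"insert": ..., "attributes": ...}` entries. The snippet below is only an illustration of that payload and of one way it could be flattened into per-span formats on the Python side; the Space's actual rich-text-to-JSON handling lives in its Python code and may differ:

```python
# Hypothetical Delta for the prompt "a garden with flowers", with one styled span.
delta = {
    "ops": [
        {"insert": "a "},
        {"insert": "garden", "attributes": {"color": "#00f", "font": "mirza", "size": "32px"}},
        {"insert": " with flowers\n"},
    ]
}

def flatten_ops(delta):
    """Yield (text, attributes) pairs, one per Delta op."""
    for op in delta["ops"]:
        yield op["insert"], op.get("attributes", {})

for text, attrs in flatten_ops(delta):
    print(repr(text), attrs)
# 'a ' {}
# 'garden' {'color': '#00f', 'font': 'mirza', 'size': '32px'}
# ' with flowers\n' {}
```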
rich-text-to-json.js ADDED
@@ -0,0 +1,349 @@
1
+ class RichTextEditor extends HTMLElement {
2
+ constructor() {
3
+ super();
4
+ this.loadExternalScripts();
5
+ this.attachShadow({ mode: 'open' });
6
+ this.shadowRoot.innerHTML = `
7
+ ${RichTextEditor.header()}
8
+ ${RichTextEditor.template()}
9
+ `;
10
+ }
11
+ connectedCallback() {
12
+ this.myQuill = this.mountQuill();
13
+ }
14
+ loadExternalScripts() {
15
+ const links = ["https://cdn.quilljs.com/1.3.6/quill.snow.css", "https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css", "https://fonts.googleapis.com/css?family=Mirza|Roboto|Slabo+27px|Sofia|Inconsolata|Ubuntu|Akronim|Monoton&display=swap"]
16
+ links.forEach(link => {
17
+ const css = document.createElement("link");
18
+ css.href = link;
19
+ css.rel = "stylesheet"
20
+ document.head.appendChild(css);
21
+ })
22
+
23
+ }
24
+ static template() {
25
+ return `
26
+ <div id="standalone-container">
27
+ <div id="toolbar-container">
28
+ <span class="ql-formats">
29
+ <select class="ql-font">
30
+ <option selected>Base</option>
31
+ <option value="mirza">Claude Monet</option>
32
+ <option value="roboto">Ukiyoe</option>
33
+ <option value="cursive">Cyber Punk</option>
34
+ <option value="sofia">Pop Art</option>
35
+ <option value="slabo">Van Gogh</option>
36
+ <option value="inconsolata">Pixel Art</option>
37
+ <option value="ubuntu">Rembrandt</option>
38
+ <option value="Akronim">Cubism</option>
39
+ <option value="Monoton">Neon Art</option>
40
+ </select>
41
+ <select class="ql-size">
42
+ <option value="18px">Small</option>
43
+ <option selected>Normal</option>
44
+ <option value="32px">Large</option>
45
+ <option value="50px">Huge</option>
46
+ </select>
47
+ </span>
48
+ <span class="ql-formats">
49
+ <button class="ql-strike"></button>
50
+ </span>
51
+ <!-- <span class="ql-formats">
52
+ <button class="ql-bold"></button>
53
+ <button class="ql-italic"></button>
54
+ <button class="ql-underline"></button>
55
+ </span> -->
56
+ <span class="ql-formats">
57
+ <select class="ql-color"></select>
58
+ <!-- <select class="ql-background"></select> -->
59
+ </span>
60
+ <!-- <span class="ql-formats">
61
+ <button class="ql-script" value="sub"></button>
62
+ <button class="ql-script" value="super"></button>
63
+ </span>
64
+ <span class="ql-formats">
65
+ <button class="ql-header" value="1"></button>
66
+ <button class="ql-header" value="2"></button>
67
+ <button class="ql-blockquote"></button>
68
+ <button class="ql-code-block"></button>
69
+ </span>
70
+ <span class="ql-formats">
71
+ <button class="ql-list" value="ordered"></button>
72
+ <button class="ql-list" value="bullet"></button>
73
+ <button class="ql-indent" value="-1"></button>
74
+ <button class="ql-indent" value="+1"></button>
75
+ </span>
76
+ <span class="ql-formats">
77
+ <button class="ql-direction" value="rtl"></button>
78
+ <select class="ql-align"></select>
79
+ </span>
80
+ <span class="ql-formats">
81
+ <button class="ql-link"></button>
82
+ <button class="ql-image"></button>
83
+ <button class="ql-video"></button>
84
+ <button class="ql-formula"></button>
85
+ </span> -->
86
+ <span class="ql-formats">
87
+ <button class="ql-link"></button>
88
+ </span>
89
+ <span class="ql-formats">
90
+ <button class="ql-clean"></button>
91
+ </span>
92
+ </div>
93
+ <div id="editor-container"></div>
94
+ </div>
95
+ `;
96
+ }
97
+
98
+ static header() {
99
+ return `
100
+ <link rel="stylesheet" href="https://cdn.quilljs.com/1.3.6/quill.snow.css">
101
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
102
+ <style>
103
+ /* Set default font-family */
104
+ .ql-snow .ql-tooltip::before {
105
+ content: "Footnote";
106
+ line-height: 26px;
107
+ margin-right: 8px;
108
+ }
109
+
110
+ .ql-snow .ql-tooltip[data-mode=link]::before {
111
+ content: "Enter footnote:";
112
+ }
113
+
114
+ .row {
115
+ margin-top: 15px;
116
+ margin-left: 0px;
117
+ margin-bottom: 15px;
118
+ }
119
+
120
+ .btn-primary {
121
+ color: #ffffff;
122
+ background-color: #2780e3;
123
+ border-color: #2780e3;
124
+ }
125
+
126
+ .btn-primary:hover {
127
+ color: #ffffff;
128
+ background-color: #1967be;
129
+ border-color: #1862b5;
130
+ }
131
+
132
+ .btn {
133
+ display: inline-block;
134
+ margin-bottom: 0;
135
+ font-weight: normal;
136
+ text-align: center;
137
+ vertical-align: middle;
138
+ touch-action: manipulation;
139
+ cursor: pointer;
140
+ background-image: none;
141
+ border: 1px solid transparent;
142
+ white-space: nowrap;
143
+ padding: 10px 18px;
144
+ font-size: 15px;
145
+ line-height: 1.42857143;
146
+ border-radius: 0;
147
+ user-select: none;
148
+ }
149
+
150
+ #standalone-container {
151
+ position: relative;
152
+ max-width: 720px;
153
+ background-color: #ffffff;
154
+ color: black !important;
155
+ z-index: 1000;
156
+ }
157
+
158
+ #editor-container {
159
+ font-family: "Aref Ruqaa";
160
+ font-size: 18px;
161
+ height: 250px;
162
+ }
163
+
164
+ #toolbar-container {
165
+ font-family: "Aref Ruqaa";
166
+ display: flex;
167
+ flex-wrap: wrap;
168
+ }
169
+
170
+ #json-container {
171
+ max-width: 720px;
172
+ }
173
+
174
+ /* Set dropdown font-families */
175
+ #toolbar-container .ql-font span[data-label="Base"]::before {
176
+ font-family: "Aref Ruqaa";
177
+ }
178
+
179
+ #toolbar-container .ql-font span[data-label="Claude Monet"]::before {
180
+ font-family: "Mirza";
181
+ }
182
+
183
+ #toolbar-container .ql-font span[data-label="Ukiyoe"]::before {
184
+ font-family: "Roboto";
185
+ }
186
+
187
+ #toolbar-container .ql-font span[data-label="Cyber Punk"]::before {
188
+ font-family: "Comic Sans MS";
189
+ }
190
+
191
+ #toolbar-container .ql-font span[data-label="Pop Art"]::before {
192
+ font-family: "sofia";
193
+ }
194
+
195
+ #toolbar-container .ql-font span[data-label="Van Gogh"]::before {
196
+ font-family: "slabo 27px";
197
+ }
198
+
199
+ #toolbar-container .ql-font span[data-label="Pixel Art"]::before {
200
+ font-family: "inconsolata";
201
+ }
202
+
203
+ #toolbar-container .ql-font span[data-label="Rembrandt"]::before {
204
+ font-family: "ubuntu";
205
+ }
206
+
207
+ #toolbar-container .ql-font span[data-label="Cubism"]::before {
208
+ font-family: "Akronim";
209
+ }
210
+
211
+ #toolbar-container .ql-font span[data-label="Neon Art"]::before {
212
+ font-family: "Monoton";
213
+ }
214
+
215
+ /* Set content font-families */
216
+ .ql-font-mirza {
217
+ font-family: "Mirza";
218
+ }
219
+
220
+ .ql-font-roboto {
221
+ font-family: "Roboto";
222
+ }
223
+
224
+ .ql-font-cursive {
225
+ font-family: "Comic Sans MS";
226
+ }
227
+
228
+ .ql-font-sofia {
229
+ font-family: "sofia";
230
+ }
231
+
232
+ .ql-font-slabo {
233
+ font-family: "slabo 27px";
234
+ }
235
+
236
+ .ql-font-inconsolata {
237
+ font-family: "inconsolata";
238
+ }
239
+
240
+ .ql-font-ubuntu {
241
+ font-family: "ubuntu";
242
+ }
243
+
244
+ .ql-font-Akronim {
245
+ font-family: "Akronim";
246
+ }
247
+
248
+ .ql-font-Monoton {
249
+ font-family: "Monoton";
250
+ }
251
+ </style>
252
+ `;
253
+ }
254
+ async mountQuill() {
255
+ // Register the custom formats with Quill
256
+ const lib = await import("https://cdn.jsdelivr.net/npm/shadow-selection-polyfill");
257
+ const getRange = lib.getRange;
258
+
259
+ const Font = Quill.import('formats/font');
260
+ Font.whitelist = ['mirza', 'roboto', 'sofia', 'slabo', 'inconsolata', 'ubuntu', 'cursive', 'Akronim', 'Monoton'];
261
+ const Link = Quill.import('formats/link');
262
+ Link.sanitize = function (url) {
263
+ // modify url if desired
264
+ return url;
265
+ }
266
+ const SizeStyle = Quill.import('attributors/style/size');
267
+ SizeStyle.whitelist = ['10px', '18px', '32px', '50px', '64px'];
268
+ Quill.register(SizeStyle, true);
269
+ Quill.register(Link, true);
270
+ Quill.register(Font, true);
271
+ const icons = Quill.import('ui/icons');
272
+ const icon = `<svg xmlns="http://www.w3.org/2000/svg" width="17" viewBox="0 0 512 512" xml:space="preserve"><path fill="#010101" d="M276.75 1c4.51 3.23 9.2 6.04 12.97 9.77 29.7 29.45 59.15 59.14 88.85 88.6 4.98 4.93 7.13 10.37 7.12 17.32-.1 125.8-.09 251.6-.01 377.4 0 7.94-1.96 14.46-9.62 18.57-121.41.34-242.77.34-364.76.05A288.3 288.3 0 0 1 1 502c0-163.02 0-326.04.34-489.62C3.84 6.53 8.04 3.38 13 1c23.35 0 46.7 0 70.82.3 2.07.43 3.38.68 4.69.68h127.98c18.44.01 36.41.04 54.39-.03 1.7 0 3.41-.62 5.12-.95h.75M33.03 122.5v359.05h320.22V129.18h-76.18c-14.22-.01-19.8-5.68-19.8-20.09V33.31H33.02v89.19m256.29-27.36c.72.66 1.44 1.9 2.17 1.9 12.73.12 25.46.08 37.55.08L289.3 57.45v37.7z"/><path fill="#020202" d="M513 375.53c-4.68 7.99-11.52 10.51-20.21 10.25-13.15-.4-26.32-.1-39.48-.1h-5.58c5.49 8.28 10.7 15.74 15.46 23.47 6.06 9.82 1.14 21.65-9.96 24.27-6.7 1.59-12.45-.64-16.23-6.15a2608.6 2608.6 0 0 1-32.97-49.36c-3.57-5.48-3.39-11.54.17-16.98a3122.5 3122.5 0 0 1 32.39-48.56c5.22-7.65 14.67-9.35 21.95-4.45 7.63 5.12 9.6 14.26 4.5 22.33-4.75 7.54-9.8 14.9-15.11 22.95h33.64V225.19h-5.24c-19.49 0-38.97.11-58.46-.05-12.74-.1-20.12-13.15-13.84-24.14 3.12-5.46 8.14-7.71 14.18-7.73 26.15-.06 52.3-.04 78.45 0 7.1 0 12.47 3.05 16.01 9.64.33 57.44.33 114.8.33 172.62z"/><path fill="#111" d="M216.03 1.97C173.52 1.98 131 2 88.5 1.98a16 16 0 0 1-4.22-.68c43.4-.3 87.09-.3 131.24-.06.48.25.5.73.5.73z"/><path fill="#232323" d="M216.5 1.98c-.47 0-.5-.5-.5-.74C235.7 1 255.38 1 275.53 1c-1.24.33-2.94.95-4.65.95-17.98.07-35.95.04-54.39.03z"/><path fill="#040404" d="M148 321.42h153.5c14.25 0 19.96 5.71 19.96 19.97.01 19.17.03 38.33 0 57.5-.03 12.6-6.16 18.78-18.66 18.78H99.81c-12.42 0-18.75-6.34-18.76-18.73-.01-19.83-.02-39.66 0-59.5.02-11.47 6.4-17.93 17.95-18 16.17-.08 32.33-.02 49-.02m40.5 32.15h-75.16v31.84h175.7v-31.84H188.5z"/><path fill="#030303" d="m110 225.33 178.89-.03c11.98 0 19.25 9.95 15.74 21.44-2.05 6.71-7.5 10.57-15.14 10.57-63.63 0-127.25-.01-190.88-.07-12.03-.02-19.17-8.62-16.7-19.84 1.6-7.21 7.17-11.74 15.1-12.04 4.17-.16 8.33-.03 13-.03zm-24.12-36.19c-5.28-6.2-6.3-12.76-2.85-19.73 3.22-6.49 9.13-8.24 15.86-8.24 25.64.01 51.27-.06 76.91.04 13.07.04 20.66 10.44 16.33 22.08-2.25 6.06-6.63 9.76-13.08 9.8-27.97.18-55.94.2-83.9-.07-3.01-.03-6-2.36-9.27-3.88z"/></svg>`
273
+ icons['link'] = icon;
274
+ const editorContainer = this.shadowRoot.querySelector('#editor-container')
275
+ const toolbarContainer = this.shadowRoot.querySelector('#toolbar-container')
276
+ const myQuill = new Quill(editorContainer, {
277
+ modules: {
278
+ toolbar: {
279
+ container: toolbarContainer,
280
+ },
281
+ },
282
+ theme: 'snow'
283
+ });
284
+ const normalizeNative = (nativeRange) => {
285
+
286
+ if (nativeRange) {
287
+ const range = nativeRange;
288
+
289
+ if (range.baseNode) {
290
+ range.startContainer = nativeRange.baseNode;
291
+ range.endContainer = nativeRange.focusNode;
292
+ range.startOffset = nativeRange.baseOffset;
293
+ range.endOffset = nativeRange.focusOffset;
294
+
295
+ if (range.endOffset < range.startOffset) {
296
+ range.startContainer = nativeRange.focusNode;
297
+ range.endContainer = nativeRange.baseNode;
298
+ range.startOffset = nativeRange.focusOffset;
299
+ range.endOffset = nativeRange.baseOffset;
300
+ }
301
+ }
302
+
303
+ if (range.startContainer) {
304
+ return {
305
+ start: { node: range.startContainer, offset: range.startOffset },
306
+ end: { node: range.endContainer, offset: range.endOffset },
307
+ native: range
308
+ };
309
+ }
310
+ }
311
+
312
+ return null
313
+ };
314
+
315
+ myQuill.selection.getNativeRange = () => {
316
+
317
+ const dom = myQuill.root.getRootNode();
318
+ const selection = getRange(dom);
319
+ const range = normalizeNative(selection);
320
+
321
+ return range;
322
+ };
323
+ let fromEditor = false;
324
+ editorContainer.addEventListener("pointerup", (e) => {
325
+ fromEditor = false;
326
+ });
327
+ editorContainer.addEventListener("pointerout", (e) => {
328
+ fromEditor = false;
329
+ });
330
+ editorContainer.addEventListener("pointerdown", (e) => {
331
+ fromEditor = true;
332
+ });
333
+
334
+ document.addEventListener("selectionchange", () => {
335
+ if (fromEditor) {
336
+ myQuill.selection.update()
337
+ }
338
+ });
339
+
340
+
341
+ myQuill.on('text-change', () => {
342
+ // keep Quill data inside _data to communicate with Gradio
343
+ document.querySelector("#rich-text-root")._data = myQuill.getContents()
344
+ })
345
+ return myQuill
346
+ }
347
+ }
348
+
349
+ customElements.define('rich-text-editor', RichTextEditor);
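For reference, the _data payload that the text-change handler above writes is a standard Quill Delta, which the Python side later consumes (see parse_json in utils/richtext_utils.py further down). A minimal sketch of the JSON as it reaches Python; the prompt text and attribute values here are made up for illustration:

    quill_delta = {
        "ops": [
            {"insert": "a night sky filled with "},
            {"insert": "stars", "attributes": {"color": "#FFD700"}},
            {"insert": " over a snowy mountain", "attributes": {"font": "slabo"}},
            {"insert": "\n"},
        ]
    }
    plain_prompt = "".join(op["insert"] for op in quill_delta["ops"])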
share_btn.py ADDED
@@ -0,0 +1,116 @@
1
+ community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
2
+ <path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
3
+ <path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
4
+ </svg>"""
5
+
6
+ loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin" style="color: #ffffff;" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
7
+
8
+ share_js = """async () => {
9
+ async function uploadFile(file){
10
+ const UPLOAD_URL = 'https://huggingface.co/uploads';
11
+ const response = await fetch(UPLOAD_URL, {
12
+ method: 'POST',
13
+ headers: {
14
+ 'Content-Type': file.type,
15
+ 'X-Requested-With': 'XMLHttpRequest',
16
+ },
17
+ body: file, /// <- File inherits from Blob
18
+ });
19
+ const url = await response.text();
20
+ return url;
21
+ }
22
+ async function getInputImageFile(imageEl){
23
+ const res = await fetch(imageEl.src);
24
+ const blob = await res.blob();
25
+ const imageId = Date.now();
26
+ const fileName = `rich-text-image-${imageId}.png`;
27
+ return new File([blob], fileName, { type: 'image/png'});
28
+ }
29
+ const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
30
+ const richEl = document.getElementById("rich-text-root");
31
+ const data = richEl? richEl.contentDocument.body._data : {};
32
+ const text_input = JSON.stringify(data);
33
+ const negative_prompt = gradioEl.querySelector('#negative_prompt input').value;
34
+ const seed = gradioEl.querySelector('#seed input').value;
35
+ const richTextImg = gradioEl.querySelector('#rich-text-image img');
36
+ const plainTextImg = gradioEl.querySelector('#plain-text-image img');
37
+ const text_input_obj = JSON.parse(text_input);
38
+ const plain_prompt = text_input_obj.ops.map(e=> e.insert).join('');
39
+ const linkSrc = `https://huggingface.co/spaces/songweig/rich-text-to-image?prompt=${encodeURIComponent(text_input)}`;
40
+
41
+ const titleTxt = `RT2I: ${plain_prompt.slice(0, 50)}...`;
42
+ const shareBtnEl = gradioEl.querySelector('#share-btn');
43
+ const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
44
+ const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
45
+ if(!richTextImg){
46
+ return;
47
+ };
48
+ shareBtnEl.style.pointerEvents = 'none';
49
+ shareIconEl.style.display = 'none';
50
+ loadingIconEl.style.removeProperty('display');
51
+
52
+ const richImgFile = await getInputImageFile(richTextImg);
53
+ const plainImgFile = await getInputImageFile(plainTextImg);
54
+ const richImgURL = await uploadFile(richImgFile);
55
+ const plainImgURL = await uploadFile(plainImgFile);
56
+
57
+ const descriptionMd = `
58
+ ### Plain Prompt
59
+ ${plain_prompt}
60
+
61
+ 🔗 Shareable Link + Params: [here](${linkSrc})
62
+
63
+ ### Rich Text Image
64
+ <img src="${richImgURL}">
65
+
66
+ ### Plain Text Image
67
+ <img src="${plainImgURL}">
68
+
69
+ `;
70
+ const params = new URLSearchParams({
71
+ title: titleTxt,
72
+ description: descriptionMd,
73
+ });
74
+ const paramsStr = params.toString();
75
+ window.open(`https://huggingface.co/spaces/songweig/rich-text-to-image/discussions/new?${paramsStr}`, '_blank');
76
+ shareBtnEl.style.removeProperty('pointer-events');
77
+ shareIconEl.style.removeProperty('display');
78
+ loadingIconEl.style.display = 'none';
79
+ }"""
80
+
81
+ css = """
82
+ #share-btn-container {
83
+ display: flex;
84
+ padding-left: 0.5rem !important;
85
+ padding-right: 0.5rem !important;
86
+ background-color: #000000;
87
+ justify-content: center;
88
+ align-items: center;
89
+ border-radius: 9999px !important;
90
+ width: 13rem;
91
+ margin-top: 10px;
92
+ margin-left: auto;
93
+ flex: unset !important;
94
+ }
95
+ #share-btn {
96
+ all: initial;
97
+ color: #ffffff;
98
+ font-weight: 600;
99
+ cursor: pointer;
100
+ font-family: 'IBM Plex Sans', sans-serif;
101
+ margin-left: 0.5rem !important;
102
+ padding-top: 0.25rem !important;
103
+ padding-bottom: 0.25rem !important;
104
+ right:0;
105
+ }
106
+ #share-btn * {
107
+ all: unset !important;
108
+ }
109
+ #share-btn-container div:nth-child(-n+2){
110
+ width: auto !important;
111
+ min-height: 0px !important;
112
+ }
113
+ #share-btn-container .wrap {
114
+ display: none !important;
115
+ }
116
+ """
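How these four exports are typically consumed: the css block styles the container, the two icon snippets live inside it, and share_js runs client-side when the button is clicked. A minimal sketch of the wiring for a Gradio 3.x Blocks app (the element ids must match the selectors used inside share_js; the exact layout in app.py may differ):

    import gradio as gr
    from share_btn import community_icon_html, loading_icon_html, share_js, css

    with gr.Blocks(css=css) as demo:
        with gr.Group(elem_id="share-btn-container"):
            community_icon = gr.HTML(community_icon_html)
            loading_icon = gr.HTML(loading_icon_html)
            share_button = gr.Button("Share to community", elem_id="share-btn")
        # In Gradio 3.x the _js argument runs the snippet in the browser
        # instead of calling a Python function.
        share_button.click(None, [], [], _js=share_js)

    demo.launch()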
utils/.DS_Store ADDED
Binary file (6.15 kB).
 
utils/attention_utils.py ADDED
@@ -0,0 +1,724 @@
1
+ import numpy as np
2
+ import os
3
+ import matplotlib as mpl
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ import torch
7
+ import torchvision
+ from pathlib import Path  # needed by save_attention_heatmaps
+ import skimage
+ from skimage.morphology import erosion, square  # needed by the optional preprocess branches
8
+
9
+ from utils.richtext_utils import seed_everything
10
+ from sklearn.cluster import KMeans, SpectralClustering
11
+
12
+ # SelfAttentionLayers = [
13
+ # # 'down_blocks.0.attentions.0.transformer_blocks.0.attn1',
14
+ # # 'down_blocks.0.attentions.1.transformer_blocks.0.attn1',
15
+ # 'down_blocks.1.attentions.0.transformer_blocks.0.attn1',
16
+ # # 'down_blocks.1.attentions.1.transformer_blocks.0.attn1',
17
+ # 'down_blocks.2.attentions.0.transformer_blocks.0.attn1',
18
+ # 'down_blocks.2.attentions.1.transformer_blocks.0.attn1',
19
+ # 'mid_block.attentions.0.transformer_blocks.0.attn1',
20
+ # 'up_blocks.1.attentions.0.transformer_blocks.0.attn1',
21
+ # 'up_blocks.1.attentions.1.transformer_blocks.0.attn1',
22
+ # 'up_blocks.1.attentions.2.transformer_blocks.0.attn1',
23
+ # # 'up_blocks.2.attentions.0.transformer_blocks.0.attn1',
24
+ # 'up_blocks.2.attentions.1.transformer_blocks.0.attn1',
25
+ # # 'up_blocks.2.attentions.2.transformer_blocks.0.attn1',
26
+ # # 'up_blocks.3.attentions.0.transformer_blocks.0.attn1',
27
+ # # 'up_blocks.3.attentions.1.transformer_blocks.0.attn1',
28
+ # # 'up_blocks.3.attentions.2.transformer_blocks.0.attn1',
29
+ # ]
30
+
31
+ SelfAttentionLayers = [
32
+ # 'down_blocks.0.attentions.0.transformer_blocks.0.attn1',
33
+ # 'down_blocks.0.attentions.1.transformer_blocks.0.attn1',
34
+ 'down_blocks.1.attentions.0.transformer_blocks.0.attn1',
35
+ # 'down_blocks.1.attentions.1.transformer_blocks.0.attn1',
36
+ 'down_blocks.2.attentions.0.transformer_blocks.0.attn1',
37
+ 'down_blocks.2.attentions.1.transformer_blocks.0.attn1',
38
+ 'mid_block.attentions.0.transformer_blocks.0.attn1',
39
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn1',
40
+ 'up_blocks.1.attentions.1.transformer_blocks.0.attn1',
41
+ 'up_blocks.1.attentions.2.transformer_blocks.0.attn1',
42
+ # 'up_blocks.2.attentions.0.transformer_blocks.0.attn1',
43
+ 'up_blocks.2.attentions.1.transformer_blocks.0.attn1',
44
+ # 'up_blocks.2.attentions.2.transformer_blocks.0.attn1',
45
+ # 'up_blocks.3.attentions.0.transformer_blocks.0.attn1',
46
+ # 'up_blocks.3.attentions.1.transformer_blocks.0.attn1',
47
+ # 'up_blocks.3.attentions.2.transformer_blocks.0.attn1',
48
+ ]
49
+
50
+
51
+ CrossAttentionLayers = [
52
+ # 'down_blocks.0.attentions.0.transformer_blocks.0.attn2',
53
+ # 'down_blocks.0.attentions.1.transformer_blocks.0.attn2',
54
+ 'down_blocks.1.attentions.0.transformer_blocks.0.attn2',
55
+ # 'down_blocks.1.attentions.1.transformer_blocks.0.attn2',
56
+ 'down_blocks.2.attentions.0.transformer_blocks.0.attn2',
57
+ 'down_blocks.2.attentions.1.transformer_blocks.0.attn2',
58
+ 'mid_block.attentions.0.transformer_blocks.0.attn2',
59
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn2',
60
+ 'up_blocks.1.attentions.1.transformer_blocks.0.attn2',
61
+ 'up_blocks.1.attentions.2.transformer_blocks.0.attn2',
62
+ # 'up_blocks.2.attentions.0.transformer_blocks.0.attn2',
63
+ 'up_blocks.2.attentions.1.transformer_blocks.0.attn2',
64
+ # 'up_blocks.2.attentions.2.transformer_blocks.0.attn2',
65
+ # 'up_blocks.3.attentions.0.transformer_blocks.0.attn2',
66
+ # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2',
67
+ # 'up_blocks.3.attentions.2.transformer_blocks.0.attn2'
68
+ ]
69
+
70
+ # CrossAttentionLayers = [
71
+ # 'down_blocks.0.attentions.0.transformer_blocks.0.attn2',
72
+ # 'down_blocks.0.attentions.1.transformer_blocks.0.attn2',
73
+ # 'down_blocks.1.attentions.0.transformer_blocks.0.attn2',
74
+ # 'down_blocks.1.attentions.1.transformer_blocks.0.attn2',
75
+ # 'down_blocks.2.attentions.0.transformer_blocks.0.attn2',
76
+ # 'down_blocks.2.attentions.1.transformer_blocks.0.attn2',
77
+ # 'mid_block.attentions.0.transformer_blocks.0.attn2',
78
+ # 'up_blocks.1.attentions.0.transformer_blocks.0.attn2',
79
+ # 'up_blocks.1.attentions.1.transformer_blocks.0.attn2',
80
+ # 'up_blocks.1.attentions.2.transformer_blocks.0.attn2',
81
+ # 'up_blocks.2.attentions.0.transformer_blocks.0.attn2',
82
+ # 'up_blocks.2.attentions.1.transformer_blocks.0.attn2',
83
+ # 'up_blocks.2.attentions.2.transformer_blocks.0.attn2',
84
+ # 'up_blocks.3.attentions.0.transformer_blocks.0.attn2',
85
+ # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2',
86
+ # 'up_blocks.3.attentions.2.transformer_blocks.0.attn2'
87
+ # ]
88
+
89
+ # CrossAttentionLayers_XL = [
90
+ # 'up_blocks.0.attentions.0.transformer_blocks.1.attn2',
91
+ # 'up_blocks.0.attentions.0.transformer_blocks.2.attn2',
92
+ # 'up_blocks.0.attentions.0.transformer_blocks.3.attn2',
93
+ # 'up_blocks.0.attentions.0.transformer_blocks.4.attn2',
94
+ # 'up_blocks.0.attentions.0.transformer_blocks.5.attn2',
95
+ # 'up_blocks.0.attentions.0.transformer_blocks.6.attn2',
96
+ # 'up_blocks.0.attentions.0.transformer_blocks.7.attn2',
97
+ # ]
98
+ CrossAttentionLayers_XL = [
99
+ 'down_blocks.2.attentions.1.transformer_blocks.3.attn2',
100
+ 'down_blocks.2.attentions.1.transformer_blocks.4.attn2',
101
+ 'mid_block.attentions.0.transformer_blocks.0.attn2',
102
+ 'mid_block.attentions.0.transformer_blocks.1.attn2',
103
+ 'mid_block.attentions.0.transformer_blocks.2.attn2',
104
+ 'mid_block.attentions.0.transformer_blocks.3.attn2',
105
+ 'up_blocks.0.attentions.0.transformer_blocks.1.attn2',
106
+ 'up_blocks.0.attentions.0.transformer_blocks.2.attn2',
107
+ 'up_blocks.0.attentions.0.transformer_blocks.3.attn2',
108
+ 'up_blocks.0.attentions.0.transformer_blocks.4.attn2',
109
+ 'up_blocks.0.attentions.0.transformer_blocks.5.attn2',
110
+ 'up_blocks.0.attentions.0.transformer_blocks.6.attn2',
111
+ 'up_blocks.0.attentions.0.transformer_blocks.7.attn2',
112
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn2'
113
+ ]
114
+
115
+ def split_attention_maps_over_steps(attention_maps):
116
+ r"""Function for splitting attention maps over steps.
117
+ Args:
118
+ attention_maps (dict): Dictionary of attention maps.
119
+ Each value is a list over denoising steps of stacked [uncond, cond] maps.
120
+ """
121
+ # This function splits attention maps into unconditional and conditional score and over steps
122
+
123
+ attention_maps_cond = dict() # Maps corresponding to conditional score
124
+ attention_maps_uncond = dict() # Maps corresponding to unconditional score
125
+
126
+ for layer in attention_maps.keys():
127
+
128
+ for step_num in range(len(attention_maps[layer])):
129
+ if step_num not in attention_maps_cond:
130
+ attention_maps_cond[step_num] = dict()
131
+ attention_maps_uncond[step_num] = dict()
132
+
133
+ attention_maps_uncond[step_num].update(
134
+ {layer: attention_maps[layer][step_num][:1]})
135
+ attention_maps_cond[step_num].update(
136
+ {layer: attention_maps[layer][step_num][1:2]})
137
+
138
+ return attention_maps_cond, attention_maps_uncond
139
+
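# Toy illustration (made-up tensors) of the layout this helper assumes: one
# entry per layer, one tensor per denoising step, with the batch dimension
# stacking [unconditional, conditional] scores.
toy_maps = {
    'mid_block.attentions.0.transformer_blocks.0.attn2':
        [torch.rand(2, 64, 77) for _ in range(3)],   # 3 steps, 8x8 keys, 77 tokens
}
cond, uncond = split_attention_maps_over_steps(toy_maps)
# cond[0] maps the layer name to a (1, 64, 77) slice; uncond holds the other half.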
140
+
141
+ def save_attention_heatmaps(attention_maps, tokens_vis, save_dir, prefix):
142
+ r"""Function to plot heatmaps for attention maps.
143
+
144
+ Args:
145
+ attention_maps (dict): Dictionary of attention maps per layer
146
+ save_dir (str): Directory to save attention maps
147
+ prefix (str): Filename prefix for html files
148
+
149
+ Returns:
150
+ Heatmaps, one per sample.
151
+ """
152
+
153
+ html_names = []
154
+
155
+ idx = 0
156
+ html_list = []
157
+
158
+ for layer in attention_maps.keys():
159
+ if idx == 0:
160
+ # import ipdb;ipdb.set_trace()
161
+ # create a set of html files.
162
+
163
+ batch_size = attention_maps[layer].shape[0]
164
+
165
+ for sample_num in range(batch_size):
166
+ # html path
167
+ html_rel_path = os.path.join('sample_{}'.format(
168
+ sample_num), '{}.html'.format(prefix))
169
+ html_names.append(html_rel_path)
170
+ html_path = os.path.join(save_dir, html_rel_path)
171
+ os.makedirs(os.path.dirname(html_path), exist_ok=True)
172
+ html_list.append(open(html_path, 'wt'))
173
+ html_list[sample_num].write(
174
+ '<html><head></head><body><table>\n')
175
+
176
+ for sample_num in range(batch_size):
177
+
178
+ save_path = os.path.join(save_dir, 'sample_{}'.format(sample_num),
179
+ prefix, 'layer_{}'.format(layer)) + '.jpg'
180
+ Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)
181
+
182
+ layer_name = 'layer_{}'.format(layer)
183
+ html_list[sample_num].write(
184
+ f'<tr><td><h1>{layer_name}</h1></td></tr>\n')
185
+
186
+ prefix_stem = prefix.split('/')[-1]
187
+ relative_image_path = os.path.join(
188
+ prefix_stem, 'layer_{}'.format(layer)) + '.jpg'
189
+ html_list[sample_num].write(
190
+ f'<tr><td><img src=\"{relative_image_path}\"></td></tr>\n')
191
+
192
+ plt.figure()
193
+ plt.clf()
194
+ nrows = 2
195
+ ncols = 7
196
+ fig, axs = plt.subplots(nrows=nrows, ncols=ncols)
197
+
198
+ fig.set_figheight(8)
199
+ fig.set_figwidth(28.5)
200
+
201
+ # axs[0].set_aspect('equal')
202
+ # axs[1].set_aspect('equal')
203
+ # axs[2].set_aspect('equal')
204
+ # axs[3].set_aspect('equal')
205
+ # axs[4].set_aspect('equal')
206
+ # axs[5].set_aspect('equal')
207
+
208
+ cmap = plt.get_cmap('YlOrRd')
209
+
210
+ for rid in range(nrows):
211
+ for cid in range(ncols):
212
+ tid = rid*ncols + cid
213
+ # import ipdb;ipdb.set_trace()
214
+ attention_map_cur = attention_maps[layer][sample_num, :, :, tid].numpy(
215
+ )
216
+ vmax = float(attention_map_cur.max())
217
+ vmin = float(attention_map_cur.min())
218
+ sns.heatmap(
219
+ attention_map_cur, annot=False, cbar=False, ax=axs[rid, cid],
220
+ cmap=cmap, vmin=vmin, vmax=vmax
221
+ )
222
+ axs[rid, cid].set_xlabel(tokens_vis[tid])
223
+
224
+ # axs[0].set_xlabel('Self attention')
225
+ # axs[1].set_xlabel('Temporal attention')
226
+ # axs[2].set_xlabel('T5 text attention')
227
+ # axs[3].set_xlabel('CLIP text attention')
228
+ # axs[4].set_xlabel('CLIP image attention')
229
+ # axs[5].set_xlabel('Null text token')
230
+
231
+ norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
232
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
233
+ # fig.colorbar(sm, cax=axs[6])
234
+
235
+ fig.tight_layout()
236
+ plt.savefig(save_path, dpi=64)
237
+ plt.close('all')
238
+
239
+ if idx == (len(attention_maps.keys()) - 1):
240
+ for sample_num in range(batch_size):
241
+ html_list[sample_num].write('</table></body></html>')
242
+ html_list[sample_num].close()
243
+
244
+ idx += 1
245
+
246
+ return html_names
247
+
248
+
249
+ def create_recursive_html_link(html_path, save_dir):
250
+ r"""Function for creating recursive html links.
251
+ If the path is dir1/dir2/dir3/*.html,
252
+ we create chained directories
253
+ -dir1
254
+ dir1.html (has links to all children)
255
+ -dir2
256
+ dir2.html (has links to all children)
257
+ -dir3
258
+ dir3.html
259
+
260
+ Args:
261
+ html_path (str): Path to html file.
262
+ save_dir (str): Save directory.
263
+ """
264
+
265
+ html_path_split = os.path.splitext(html_path)[0].split('/')
266
+ if len(html_path_split) == 1:
267
+ return
268
+
269
+ # First create the root directory
270
+ root_dir = html_path_split[0]
271
+ child_dir = html_path_split[1]
272
+
273
+ cur_html_path = os.path.join(save_dir, '{}.html'.format(root_dir))
274
+ if os.path.exists(cur_html_path):
275
+
276
+ fp = open(cur_html_path, 'r')
277
+ lines_written = fp.readlines()
278
+ fp.close()
279
+
280
+ fp = open(cur_html_path, 'a+')
281
+ child_path = os.path.join(root_dir, f'{child_dir}.html')
282
+ line_to_write = f'<tr><td><a href=\"{child_path}\">{child_dir}</a></td></tr>\n'
283
+
284
+ if line_to_write not in lines_written:
285
+ fp.write('<html><head></head><body><table>\n')
286
+ fp.write(line_to_write)
287
+ fp.write('</table></body></html>')
288
+ fp.close()
289
+
290
+ else:
291
+
292
+ fp = open(cur_html_path, 'w')
293
+
294
+ child_path = os.path.join(root_dir, f'{child_dir}.html')
295
+ line_to_write = f'<tr><td><a href=\"{child_path}\">{child_dir}</a></td></tr>\n'
296
+
297
+ fp.write('<html><head></head><body><table>\n')
298
+ fp.write(line_to_write)
299
+ fp.write('</table></body></html>')
300
+
301
+ fp.close()
302
+
303
+ child_path = '/'.join(html_path.split('/')[1:])
304
+ save_dir = os.path.join(save_dir, root_dir)
305
+ create_recursive_html_link(child_path, save_dir)
306
+
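# Illustration with hypothetical paths (directories are assumed to already
# exist, as they do after save_attention_heatmaps has run): a page saved at
#   /tmp/attn_vis/prompt/sample_0/step_0.html
# is registered with a single call,
#   create_recursive_html_link('prompt/sample_0/step_0.html', '/tmp/attn_vis')
# which creates or extends the chain of index pages
#   /tmp/attn_vis/prompt.html -> /tmp/attn_vis/prompt/sample_0.html -> .../step_0.html,
# each one listing links to its children.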
307
+
308
+ def visualize_attention_maps(attention_maps_all, save_dir, width, height, tokens_vis):
309
+ r"""Dump per-step, per-layer attention heatmaps as a set of linked html pages.
310
+ Args:
311
+ attention_maps_all (dict): attention maps collected from the hooked layers
312
+ save_dir (str): path to save attention maps
313
+ width, height (int): latent resolution; tokens_vis: decoded tokens used as axis labels
314
+ """
315
+
316
+ rand_name = list(attention_maps_all.keys())[0]
317
+ nsteps = len(attention_maps_all[rand_name])
318
+ hw_ori = width * height
319
+
320
+ # html_path = save_dir + '.html'
321
+ text_input = save_dir.split('/')[-1]
322
+ # f = open(html_path, 'wt')
323
+
324
+ all_html_paths = []
325
+
326
+ for step_num in range(0, nsteps, 5):
327
+
328
+ # if cond_id == 'cond':
329
+ # attention_maps_cur = attention_maps_cond[step_num]
330
+ # else:
331
+ # attention_maps_cur = attention_maps_uncond[step_num]
332
+
333
+ attention_maps = dict()
334
+
335
+ for layer in attention_maps_all.keys():
336
+
337
+ attention_ind = attention_maps_all[layer][step_num].cpu()
338
+
339
+ # Attention maps are of shape [batch_size, nkeys, 77]
340
+ # since they are averaged out while collecting from hooks to save memory.
341
+ # Now split the heads from batch dimension
342
+ bs, hw, nclip = attention_ind.shape
343
+ down_ratio = np.sqrt(hw_ori // hw)
344
+ width_cur = int(width // down_ratio)
345
+ height_cur = int(height // down_ratio)
346
+ attention_ind = attention_ind.reshape(
347
+ bs, height_cur, width_cur, nclip)
348
+
349
+ attention_maps[layer] = attention_ind
350
+
351
+ # Obtain heatmaps corresponding to random heads and individual heads
352
+
353
+ html_names = save_attention_heatmaps(
354
+ attention_maps, tokens_vis, save_dir=save_dir, prefix='step_{}/attention_maps_cond'.format(
355
+ step_num)
356
+ )
357
+
358
+ # Write the logic for recursively creating pages
359
+ for html_name_cur in html_names:
360
+ all_html_paths.append(os.path.join(text_input, html_name_cur))
361
+
362
+ save_dir_root = '/'.join(save_dir.split('/')[0:-1])
363
+ for html_pth in all_html_paths:
364
+ create_recursive_html_link(html_pth, save_dir_root)
365
+
366
+
367
+ def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=None):
368
+ for i, attn_map in enumerate(atten_map_list):
369
+ n_obj = len(attn_map)
370
+ plt.figure()
371
+ plt.clf()
372
+
373
+ fig, axs = plt.subplots(
374
+ ncols=n_obj+1, gridspec_kw=dict(width_ratios=[1 for _ in range(n_obj)]+[0.1]))
375
+
376
+ fig.set_figheight(3)
377
+ fig.set_figwidth(3*n_obj+0.1)
378
+
379
+ cmap = plt.get_cmap('YlOrRd')
380
+
381
+ vmax = 0
382
+ vmin = 1
383
+ for tid in range(n_obj):
384
+ attention_map_cur = attn_map[tid]
385
+ vmax = max(vmax, float(attention_map_cur.max()))
386
+ vmin = min(vmin, float(attention_map_cur.min()))
387
+
388
+ for tid in range(n_obj):
389
+ sns.heatmap(
390
+ attn_map[tid][0], annot=False, cbar=False, ax=axs[tid],
391
+ cmap=cmap, vmin=vmin, vmax=vmax
392
+ )
393
+ axs[tid].set_axis_off()
394
+
395
+ if tokens_vis is not None:
396
+ if tid == n_obj-1:
397
+ axs_xlabel = 'other tokens'
398
+ else:
399
+ axs_xlabel = ''
400
+ for token_id in obj_tokens[tid]:
401
+ axs_xlabel += ' ' + tokens_vis[token_id.item() -
402
+ 1][:-len('</w>')]
403
+ axs[tid].set_title(axs_xlabel)
404
+
405
+ norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
406
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
407
+ fig.colorbar(sm, cax=axs[-1])
408
+
409
+ fig.tight_layout()
410
+
411
+ canvas = fig.canvas
412
+ canvas.draw()
413
+ width, height = canvas.get_width_height()
414
+ img = np.frombuffer(canvas.tostring_rgb(),
415
+ dtype='uint8').reshape((height, width, 3))
416
+ plt.savefig(os.path.join(
417
+ save_dir, 'average_seed%d_attn%d.jpg' % (seed, i)), dpi=100)
418
+ plt.close('all')
419
+ return img
420
+
421
+
422
+ def get_average_attention_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None,
423
+ preprocess=False):
424
+ r"""Average cross-attention maps over steps and layers into per-region soft masks.
425
+ Args:
426
+ attention_maps (dict): attention maps collected from the hooked layers
427
+ obj_tokens (list): token-id groups, one per region prompt (-1 marks the background group)
428
+ preprocess (bool): if True, erode and binarize the averaged maps
429
+ """
430
+
431
+ # Split attention maps over steps
432
+ attention_maps_cond, _ = split_attention_maps_over_steps(
433
+ attention_maps
434
+ )
435
+
436
+ nsteps = len(attention_maps_cond)
437
+ hw_ori = width * height
438
+
439
+ attention_maps = []
440
+ for obj_token in obj_tokens:
441
+ attention_maps.append([])
442
+
443
+ for step_num in range(nsteps):
444
+ attention_maps_cur = attention_maps_cond[step_num]
445
+
446
+ for layer in attention_maps_cur.keys():
447
+ if step_num < 10 or layer not in CrossAttentionLayers:
448
+ continue
449
+
450
+ attention_ind = attention_maps_cur[layer].cpu()
451
+
452
+ # Attention maps are of shape [batch_size, nkeys, 77]
453
+ # since they are averaged out while collecting from hooks to save memory.
454
+ # Now split the heads from batch dimension
455
+ bs, hw, nclip = attention_ind.shape
456
+ down_ratio = np.sqrt(hw_ori // hw)
457
+ width_cur = int(width // down_ratio)
458
+ height_cur = int(height // down_ratio)
459
+ attention_ind = attention_ind.reshape(
460
+ bs, height_cur, width_cur, nclip)
461
+ for obj_id, obj_token in enumerate(obj_tokens):
462
+ if obj_token[0] == -1:
463
+ attention_map_prev = torch.stack(
464
+ [attention_maps[i][-1] for i in range(obj_id)]).sum(0)
465
+ attention_maps[obj_id].append(
466
+ attention_map_prev.max()-attention_map_prev)
467
+ else:
468
+ obj_attention_map = attention_ind[:, :, :, obj_token].max(-1, True)[
469
+ 0].permute([3, 0, 1, 2])
470
+ # obj_attention_map = attention_ind[:, :, :, obj_token].mean(-1, True).permute([3, 0, 1, 2])
471
+ obj_attention_map = torchvision.transforms.functional.resize(obj_attention_map, (height, width),
472
+ interpolation=torchvision.transforms.InterpolationMode.BICUBIC, antialias=True)
473
+ attention_maps[obj_id].append(obj_attention_map)
474
+
475
+ attention_maps_averaged = []
476
+ for obj_id, obj_token in enumerate(obj_tokens):
477
+ if obj_id == len(obj_tokens) - 1:
478
+ attention_maps_averaged.append(
479
+ torch.cat(attention_maps[obj_id]).mean(0))
480
+ else:
481
+ attention_maps_averaged.append(
482
+ torch.cat(attention_maps[obj_id]).mean(0))
483
+
484
+ attention_maps_averaged_normalized = []
485
+ attention_maps_averaged_sum = torch.cat(attention_maps_averaged).sum(0)
486
+ for obj_id, obj_token in enumerate(obj_tokens):
487
+ attention_maps_averaged_normalized.append(
488
+ attention_maps_averaged[obj_id]/attention_maps_averaged_sum)
489
+
490
+ if obj_tokens[-1][0] != -1:
491
+ attention_maps_averaged_normalized = (
492
+ torch.cat(attention_maps_averaged)/0.001).softmax(0)
493
+ attention_maps_averaged_normalized = [
494
+ attention_maps_averaged_normalized[i:i+1] for i in range(attention_maps_averaged_normalized.shape[0])]
495
+
496
+ if preprocess:
497
+ selem = square(5)
498
+ selem = square(3)
499
+ selem = square(1)
500
+ attention_maps_averaged_eroded = [erosion(skimage.img_as_float(
501
+ map[0].numpy()*255), selem) for map in attention_maps_averaged_normalized[:2]]
502
+ attention_maps_averaged_eroded = [(torch.from_numpy(map).unsqueeze(
503
+ 0)/255. > 0.8).float() for map in attention_maps_averaged_eroded]
504
+ attention_maps_averaged_eroded.append(
505
+ 1 - torch.cat(attention_maps_averaged_eroded).sum(0, True))
506
+ plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized,
507
+ attention_maps_averaged_eroded], obj_tokens, save_dir, seed, tokens_vis)
508
+ attention_maps_averaged_eroded = [attn_mask.unsqueeze(1).repeat(
509
+ [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_eroded]
510
+ return attention_maps_averaged_eroded
511
+ else:
512
+ plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
513
+ obj_tokens, save_dir, seed, tokens_vis)
514
+ attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
515
+ [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]
516
+ return attention_maps_averaged_normalized
517
+
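# Toy numeric sketch (made-up values) of the temperature-0.001 softmax used
# above when every region has explicit tokens: dividing by such a small
# temperature makes the per-pixel winner take essentially all of the mass.
maps = torch.tensor([[[0.30, 0.10]],
                     [[0.28, 0.40]]])      # (2 regions, 1, 2) toy maps
hard = (maps / 0.001).softmax(0)
# hard[:, 0, 0] ~ [1.0, 0.0]   region 0 wins the first pixel
# hard[:, 0, 1] ~ [0.0, 1.0]   region 1 wins the second pixel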
518
+
519
+ def get_average_attention_maps_threshold(attention_maps, save_dir, width, height, obj_tokens, seed=0, threshold=0.02):
520
+ r"""Average cross-attention maps over steps and layers, then binarize them with a fixed threshold.
521
+ Args:
522
+ save_dir (str): path to save attention maps
523
+ obj_tokens (list): token-id groups, one per region prompt (-1 marks the background group)
524
+ threshold (float): attention value above which a pixel counts as foreground
525
+ """
526
+
527
+ _EPS = 1e-8
528
+ # Split attention maps over steps
529
+ attention_maps_cond, _ = split_attention_maps_over_steps(
530
+ attention_maps
531
+ )
532
+
533
+ nsteps = len(attention_maps_cond)
534
+ hw_ori = width * height
535
+
536
+ attention_maps = []
537
+ for obj_token in obj_tokens:
538
+ attention_maps.append([])
539
+
540
+ # for each side prompt, get attention maps for all steps and all layers
541
+ for step_num in range(nsteps):
542
+ attention_maps_cur = attention_maps_cond[step_num]
543
+ for layer in attention_maps_cur.keys():
544
+ attention_ind = attention_maps_cur[layer].cpu()
545
+ bs, hw, nclip = attention_ind.shape
546
+ down_ratio = np.sqrt(hw_ori // hw)
547
+ width_cur = int(width // down_ratio)
548
+ height_cur = int(height // down_ratio)
549
+ attention_ind = attention_ind.reshape(
550
+ bs, height_cur, width_cur, nclip)
551
+ for obj_id, obj_token in enumerate(obj_tokens):
552
+ if attention_ind.shape[1] > width//2:
553
+ continue
554
+ if obj_token[0] != -1:
555
+ obj_attention_map = attention_ind[:, :, :,
556
+ obj_token].mean(-1, True).permute([3, 0, 1, 2])
557
+ obj_attention_map = torchvision.transforms.functional.resize(obj_attention_map, (height, width),
558
+ interpolation=torchvision.transforms.InterpolationMode.BICUBIC, antialias=True)
559
+ attention_maps[obj_id].append(obj_attention_map)
560
+
561
+ # average of all steps and layers, thresholding
562
+ attention_maps_thres = []
563
+ attention_maps_averaged = []
564
+ for obj_id, obj_token in enumerate(obj_tokens):
565
+ if obj_token[0] != -1:
566
+ average_map = torch.cat(attention_maps[obj_id]).mean(0)
567
+ attention_maps_averaged.append(average_map)
568
+ attention_maps_thres.append((average_map > threshold).float())
569
+
570
+ # get the remaining region except for the original prompt
571
+ attention_maps_averaged_normalized = []
572
+ attention_maps_averaged_sum = torch.cat(attention_maps_thres).sum(0) + _EPS
573
+ for obj_id, obj_token in enumerate(obj_tokens):
574
+ if obj_token[0] != -1:
575
+ attention_maps_averaged_normalized.append(
576
+ attention_maps_thres[obj_id]/attention_maps_averaged_sum)
577
+ else:
578
+ attention_map_prev = torch.stack(
579
+ attention_maps_averaged_normalized).sum(0)
580
+ attention_maps_averaged_normalized.append(1.-attention_map_prev)
581
+
582
+ plot_attention_maps(
583
+ [attention_maps_averaged, attention_maps_averaged_normalized], obj_tokens, save_dir, seed)
584
+
585
+ attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
586
+ [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]
587
+ # attention_maps_averaged_normalized = attention_maps_averaged_normalized.unsqueeze(1).repeat([1, 4, 1, 1]).cuda()
588
+ return attention_maps_averaged_normalized
589
+
590
+
591
+ def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, height, obj_tokens, kmeans_seed=0, tokens_vis=None,
592
+ preprocess=False, segment_threshold=0.3, num_segments=5, return_vis=False, save_attn=False):
593
+ r"""Cluster averaged self-attention into segments and label each segment with cross-attention scores to get region token maps.
594
+ Args:
595
+ selfattn_maps / crossattn_maps (dict): attention maps collected from the hooked layers
596
+ obj_tokens (list): token-id groups, one per region prompt
597
+ segment_threshold (float), num_segments (int): labeling threshold and number of spectral clusters
598
+ """
599
+
600
+ resolution = 32
601
+ # attn_maps_1024 = [attn_map for attn_map in selfattn_maps.values(
602
+ # ) if attn_map.shape[1] == resolution**2]
603
+ # attn_maps_1024 = torch.cat(attn_maps_1024).mean(0).cpu().numpy()
604
+ attn_maps_1024 = {8: [], 16: [], 32: [], 64: []}
605
+ for attn_map in selfattn_maps.values():
606
+ resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
607
+ if resolution_map != resolution:
608
+ continue
609
+ # attn_map = torch.nn.functional.interpolate(rearrange(attn_map, '1 c (h w) -> 1 c h w', h=resolution_map), (resolution, resolution),
610
+ # mode='bicubic', antialias=True)
611
+ # attn_map = rearrange(attn_map, '1 (h w) a b -> 1 (a b) h w', h=resolution_map)
612
+ attn_map = attn_map.reshape(
613
+ 1, resolution_map, resolution_map, resolution_map**2).permute([3, 0, 1, 2]).float()
614
+ attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
615
+ mode='bicubic', antialias=True)
616
+ attn_maps_1024[resolution_map].append(attn_map.permute([1, 2, 3, 0]).reshape(
617
+ 1, resolution**2, resolution_map**2))
618
+ attn_maps_1024 = torch.cat([torch.cat(v).mean(0).cpu()
619
+ for v in attn_maps_1024.values() if len(v) > 0], -1).numpy()
620
+ if save_attn:
621
+ print('saving self-attention maps...', attn_maps_1024.shape)
622
+ torch.save(torch.from_numpy(attn_maps_1024),
623
+ 'results/maps/selfattn_maps.pth')
624
+ seed_everything(kmeans_seed)
625
+ # import ipdb;ipdb.set_trace()
626
+ # kmeans = KMeans(n_clusters=num_segments,
627
+ # n_init=10).fit(attn_maps_1024)
628
+ # clusters = kmeans.labels_
629
+ # clusters = clusters.reshape(resolution, resolution)
630
+ # mesh = np.array(np.meshgrid(range(resolution), range(resolution), indexing='ij'), dtype=np.float32)/resolution
631
+ # dists = mesh.reshape(2, -1).T
632
+ # delta = 0.01
633
+ # spatial_sim = rbf_kernel(dists, dists)*delta
634
+ sc = SpectralClustering(num_segments, affinity='precomputed', n_init=100,
635
+ assign_labels='kmeans')
636
+ clusters = sc.fit_predict(attn_maps_1024)
637
+ clusters = clusters.reshape(resolution, resolution)
638
+ fig = plt.figure()
639
+ plt.imshow(clusters)
640
+ plt.axis('off')
641
+ plt.savefig(os.path.join(save_dir, 'segmentation_k%d_seed%d.jpg' % (num_segments, kmeans_seed)),
642
+ bbox_inches='tight', pad_inches=0)
643
+ if return_vis:
644
+ canvas = fig.canvas
645
+ canvas.draw()
646
+ cav_width, cav_height = canvas.get_width_height()
647
+ segments_vis = np.frombuffer(canvas.tostring_rgb(),
648
+ dtype='uint8').reshape((cav_height, cav_width, 3))
649
+
650
+ plt.close()
651
+
652
+ # label the segmentation mask using cross-attention maps
653
+ cross_attn_maps_1024 = []
654
+ for attn_map in crossattn_maps.values():
655
+ resolution_map = np.sqrt(attn_map.shape[1]).astype(int)
656
+ # if resolution_map != 16:
657
+ # continue
658
+ attn_map = attn_map.reshape(
659
+ 1, resolution_map, resolution_map, -1).permute([0, 3, 1, 2]).float()
660
+ attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
661
+ mode='bicubic', antialias=True)
662
+ cross_attn_maps_1024.append(attn_map.permute([0, 2, 3, 1]))
663
+
664
+ cross_attn_maps_1024 = torch.cat(
665
+ cross_attn_maps_1024).mean(0).cpu().numpy()
666
+ normalized_span_maps = []
667
+ for token_ids in obj_tokens:
668
+ token_ids = torch.clip(token_ids, 0, 76)
669
+ span_token_maps = cross_attn_maps_1024[:, :, token_ids.numpy()]
670
+ normalized_span_map = np.zeros_like(span_token_maps)
671
+ for i in range(span_token_maps.shape[-1]):
672
+ curr_noun_map = span_token_maps[:, :, i]
673
+ normalized_span_map[:, :, i] = (
674
+ # curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max()
675
+ curr_noun_map - np.abs(curr_noun_map.min())) / (curr_noun_map.max()-curr_noun_map.min())
676
+ normalized_span_maps.append(normalized_span_map)
677
+ foreground_token_maps = [np.zeros([clusters.shape[0], clusters.shape[1]]).squeeze(
678
+ ) for normalized_span_map in normalized_span_maps]
679
+ background_map = np.zeros([clusters.shape[0], clusters.shape[1]]).squeeze()
680
+ for c in range(num_segments):
681
+ cluster_mask = np.zeros_like(clusters)
682
+ cluster_mask[clusters == c] = 1.
683
+ is_foreground = False
684
+ for normalized_span_map, foreground_nouns_map, token_ids in zip(normalized_span_maps, foreground_token_maps, obj_tokens):
685
+ score_maps = [cluster_mask * normalized_span_map[:, :, i]
686
+ for i in range(len(token_ids))]
687
+ scores = [score_map.sum() / cluster_mask.sum()
688
+ for score_map in score_maps]
689
+ if max(scores) > segment_threshold:
690
+ foreground_nouns_map += cluster_mask
691
+ is_foreground = True
692
+ if not is_foreground:
693
+ background_map += cluster_mask
694
+ foreground_token_maps.append(background_map)
695
+
696
+ # resize the token maps and visualization
697
+ resized_token_maps = torch.cat([torch.nn.functional.interpolate(torch.from_numpy(token_map).unsqueeze(0).unsqueeze(
698
+ 0), (height, width), mode='bicubic', antialias=True)[0] for token_map in foreground_token_maps]).clamp(0, 1)
699
+
700
+ resized_token_maps = resized_token_maps / \
701
+ (resized_token_maps.sum(0, True)+1e-8)
702
+ resized_token_maps = [token_map.unsqueeze(
703
+ 0) for token_map in resized_token_maps]
704
+ foreground_token_maps = [token_map[None, :, :]
705
+ for token_map in foreground_token_maps]
706
+ if preprocess:
707
+ selem = square(5)
708
+ eroded_token_maps = torch.stack([torch.from_numpy(erosion(skimage.img_as_float(
709
+ map[0].numpy()*255), selem))/255. for map in resized_token_maps[:-1]]).clamp(0, 1)
710
+ # import ipdb; ipdb.set_trace()
711
+ eroded_background_maps = (1-eroded_token_maps.sum(0, True)).clamp(0, 1)
712
+ eroded_token_maps = torch.cat([eroded_token_maps, eroded_background_maps])
713
+ eroded_token_maps = eroded_token_maps / (eroded_token_maps.sum(0, True)+1e-8)
714
+ resized_token_maps = [token_map.unsqueeze(
715
+ 0) for token_map in eroded_token_maps]
716
+
717
+ token_maps_vis = plot_attention_maps([foreground_token_maps, resized_token_maps], obj_tokens,
718
+ save_dir, kmeans_seed, tokens_vis)
719
+ resized_token_maps = [token_map.unsqueeze(1).repeat(
720
+ [1, 4, 1, 1]).to(attn_map.dtype).cuda() for token_map in resized_token_maps]
721
+ if return_vis:
722
+ return resized_token_maps, segments_vis, token_maps_vis
723
+ else:
724
+ return resized_token_maps
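The labeling rule inside get_token_maps can be summarized independently of the diffusion model: each spectral-clustering segment is scored by the mean normalized cross-attention of a region's tokens inside that segment, and it is assigned to that region only if the best score clears segment_threshold, otherwise it falls back to the background map. A toy NumPy sketch with made-up numbers:

    import numpy as np

    clusters = np.array([[0, 0], [1, 1]])          # a 2x2 "image" with 2 segments
    span_map = np.array([[0.9, 0.8], [0.1, 0.2]])  # normalized cross-attn for one region
    segment_threshold = 0.3
    for c in range(2):
        mask = (clusters == c).astype(float)
        score = (mask * span_map).sum() / mask.sum()   # mean attention inside the segment
        label = 'region' if score > segment_threshold else 'background'
        # c=0 -> score 0.85 -> region; c=1 -> score 0.15 -> background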
utils/richtext_utils.py ADDED
@@ -0,0 +1,234 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import random
5
+ import numpy as np
6
+
7
+ COLORS = {
8
+ 'brown': [165, 42, 42],
9
+ 'red': [255, 0, 0],
10
+ 'pink': [253, 108, 158],
11
+ 'orange': [255, 165, 0],
12
+ 'yellow': [255, 255, 0],
13
+ 'purple': [128, 0, 128],
14
+ 'green': [0, 128, 0],
15
+ 'blue': [0, 0, 255],
16
+ 'white': [255, 255, 255],
17
+ 'gray': [128, 128, 128],
18
+ 'black': [0, 0, 0],
19
+ }
20
+
21
+
22
+ def seed_everything(seed):
23
+ random.seed(seed)
24
+ os.environ['PYTHONHASHSEED'] = str(seed)
25
+ np.random.seed(seed)
26
+ torch.manual_seed(seed)
27
+ torch.cuda.manual_seed(seed)
28
+
29
+
30
+ def hex_to_rgb(hex_string, return_nearest_color=False):
31
+ r"""
32
+ Convert a hex color string to a normalized RGB tensor.
33
+ """
34
+ # Remove '#' symbol if present
35
+ hex_string = hex_string.lstrip('#')
36
+ # Convert hex values to integers
37
+ red = int(hex_string[0:2], 16)
38
+ green = int(hex_string[2:4], 16)
39
+ blue = int(hex_string[4:6], 16)
40
+ rgb = torch.FloatTensor((red, green, blue))[None, :, None, None]/255.
41
+ if return_nearest_color:
42
+ nearest_color = find_nearest_color(rgb)
43
+ return rgb.cuda(), nearest_color
44
+ return rgb.cuda()
45
+
46
+
47
+ def find_nearest_color(rgb):
48
+ r"""
49
+ Find the nearest neighbor color given the RGB value.
50
+ """
51
+ if isinstance(rgb, list) or isinstance(rgb, tuple):
52
+ rgb = torch.FloatTensor(rgb)[None, :, None, None]/255.
53
+ color_distance = torch.FloatTensor([np.linalg.norm(
54
+ rgb - torch.FloatTensor(COLORS[color])[None, :, None, None]/255.) for color in COLORS.keys()])
55
+ nearest_color = list(COLORS.keys())[torch.argmin(color_distance).item()]
56
+ return nearest_color
57
+
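# Quick sanity checks of the palette snapping (CPU only; hex_to_rgb, by
# contrast, moves its output to CUDA):
find_nearest_color([250, 250, 10])   # -> 'yellow'
find_nearest_color((0, 0, 250))      # -> 'blue'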
58
+
59
+ def font2style(font):
60
+ r"""
61
+ Convert the font name to the style name.
62
+ """
63
+ return {'mirza': 'Claud Monet, impressionism, oil on canvas',
64
+ 'roboto': 'Ukiyoe',
65
+ 'cursive': 'Cyber Punk, futuristic, blade runner, william gibson, trending on artstation hq',
66
+ 'sofia': 'Pop Art, masterpiece, andy warhol',
67
+ 'slabo': 'Vincent Van Gogh',
68
+ 'inconsolata': 'Pixel Art, 8 bits, 16 bits',
69
+ 'ubuntu': 'Rembrandt',
70
+ 'Monoton': 'neon art, colorful light, highly details, octane render',
71
+ 'Akronim': 'Abstract Cubism, Pablo Picasso', }[font]
72
+
73
+
74
+ def parse_json(json_str):
75
+ r"""
76
+ Convert the JSON string to attributes.
77
+ """
78
+ # initialize region-based attributes.
79
+ base_text_prompt = ''
80
+ style_text_prompts = []
81
+ footnote_text_prompts = []
82
+ footnote_target_tokens = []
83
+ color_text_prompts = []
84
+ color_rgbs = []
85
+ color_names = []
86
+ size_text_prompts_and_sizes = []
87
+
88
+ # parse the attributes from JSON.
89
+ prev_style = None
90
+ prev_color_rgb = None
91
+ use_grad_guidance = False
92
+ for span in json_str['ops']:
93
+ text_prompt = span['insert'].rstrip('\n')
94
+ base_text_prompt += span['insert'].rstrip('\n')
95
+ if text_prompt == ' ':
96
+ continue
97
+ if 'attributes' in span:
98
+ if 'font' in span['attributes']:
99
+ style = font2style(span['attributes']['font'])
100
+ if prev_style == style:
101
+ prev_text_prompt = style_text_prompts[-1].split('in the style of')[
102
+ 0]
103
+ style_text_prompts[-1] = prev_text_prompt + \
104
+ ' ' + text_prompt + f' in the style of {style}'
105
+ else:
106
+ style_text_prompts.append(
107
+ text_prompt + f' in the style of {style}')
108
+ prev_style = style
109
+ else:
110
+ prev_style = None
111
+ if 'link' in span['attributes']:
112
+ footnote_text_prompts.append(span['attributes']['link'])
113
+ footnote_target_tokens.append(text_prompt)
114
+ font_size = 1
115
+ if 'size' in span['attributes'] and 'strike' not in span['attributes']:
116
+ font_size = float(span['attributes']['size'][:-2])/3.
117
+ elif 'size' in span['attributes'] and 'strike' in span['attributes']:
118
+ font_size = -float(span['attributes']['size'][:-2])/3.
119
+ elif 'size' not in span['attributes'] and 'strike' not in span['attributes']:
120
+ font_size = 1
121
+ if 'color' in span['attributes']:
122
+ use_grad_guidance = True
123
+ color_rgb, nearest_color = hex_to_rgb(
124
+ span['attributes']['color'], True)
125
+ if prev_color_rgb == color_rgb:
126
+ prev_text_prompt = color_text_prompts[-1]
127
+ color_text_prompts[-1] = prev_text_prompt + \
128
+ ' ' + text_prompt
129
+ else:
130
+ color_rgbs.append(color_rgb)
131
+ color_names.append(nearest_color)
132
+ color_text_prompts.append(text_prompt)
133
+ if font_size != 1:
134
+ size_text_prompts_and_sizes.append([text_prompt, font_size])
135
+ return base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
136
+ color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance
137
+
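# Hypothetical payload in the same Quill Delta format the editor emits; the
# outputs sketched in the comments assume a CUDA device, since hex_to_rgb
# moves the color tensor to the GPU.
rich_json = {"ops": [
    {"insert": "a beautiful "},
    {"insert": "garden", "attributes": {"color": "#FD6C9E"}},
    {"insert": " at night", "attributes": {"font": "slabo"}},
]}
(base, styles, footnotes, footnote_targets, color_prompts,
 color_names, color_rgbs, sizes, use_grad) = parse_json(rich_json)
# base          == 'a beautiful garden at night'
# styles        == [' at night in the style of Vincent Van Gogh']
# color_prompts == ['garden'], color_names == ['pink'], use_grad is True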
138
+
139
+ def get_region_diffusion_input(model, base_text_prompt, style_text_prompts, footnote_text_prompts,
140
+ footnote_target_tokens, color_text_prompts, color_names):
141
+ r"""
142
+ Algorithm 1 in the paper.
143
+ """
144
+ region_text_prompts = []
145
+ region_target_token_ids = []
146
+ base_tokens = model.tokenizer._tokenize(base_text_prompt)
147
+ # process the style text prompt
148
+ for text_prompt in style_text_prompts:
149
+ region_text_prompts.append(text_prompt)
150
+ region_target_token_ids.append([])
151
+ style_tokens = model.tokenizer._tokenize(
152
+ text_prompt.split('in the style of')[0])
153
+ for style_token in style_tokens:
154
+ region_target_token_ids[-1].append(
155
+ base_tokens.index(style_token)+1)
156
+
157
+ # process the complementary text prompt
158
+ for footnote_text_prompt, text_prompt in zip(footnote_text_prompts, footnote_target_tokens):
159
+ region_target_token_ids.append([])
160
+ region_text_prompts.append(footnote_text_prompt)
161
+ style_tokens = model.tokenizer._tokenize(text_prompt)
162
+ for style_token in style_tokens:
163
+ region_target_token_ids[-1].append(
164
+ base_tokens.index(style_token)+1)
165
+
166
+ # process the color text prompt
167
+ for color_text_prompt, color_name in zip(color_text_prompts, color_names):
168
+ region_target_token_ids.append([])
169
+ region_text_prompts.append(color_name+' '+color_text_prompt)
170
+ style_tokens = model.tokenizer._tokenize(color_text_prompt)
171
+ for style_token in style_tokens:
172
+ region_target_token_ids[-1].append(
173
+ base_tokens.index(style_token)+1)
174
+
175
+ # process the remaining tokens without any attributes
176
+ region_text_prompts.append(base_text_prompt)
177
+ region_target_token_ids_all = [
178
+ id for ids in region_target_token_ids for id in ids]
179
+ target_token_ids_rest = [id for id in range(
180
+ 1, len(base_tokens)+1) if id not in region_target_token_ids_all]
181
+ region_target_token_ids.append(target_token_ids_rest)
182
+
183
+ region_target_token_ids = [torch.LongTensor(
184
+ obj_token_id) for obj_token_id in region_target_token_ids]
185
+ return region_text_prompts, region_target_token_ids, base_tokens
186
+
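# Rough walk-through of Algorithm 1 on a hypothetical input (token strings
# are approximate CLIP BPE tokens):
#   base prompt : 'a garden at night'  -> base_tokens ~ ['a</w>', 'garden</w>', 'at</w>', 'night</w>']
#   style prompt: 'at night in the style of Vincent Van Gogh'
# The span 'at night' is matched against base_tokens, so the style region is
# attached to token positions [3, 4] (positions are 1-based because of the
# <|startoftext|> token), and the untouched positions [1, 2] become the final
# 'rest of the prompt' region:
#   region_text_prompts     ~ ['at night in the style of Vincent Van Gogh',
#                              'a garden at night']
#   region_target_token_ids ~ [tensor([3, 4]), tensor([1, 2])]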
187
+
188
+ def get_attention_control_input(model, base_tokens, size_text_prompts_and_sizes):
189
+ r"""
190
+ Control the token impact using font sizes.
191
+ """
192
+ word_pos = []
193
+ font_sizes = []
194
+ for text_prompt, font_size in size_text_prompts_and_sizes:
195
+ size_tokens = model.tokenizer._tokenize(text_prompt)
196
+ for size_token in size_tokens:
197
+ word_pos.append(base_tokens.index(size_token)+1)
198
+ font_sizes.append(font_size)
199
+ if len(word_pos) > 0:
200
+ word_pos = torch.LongTensor(word_pos).cuda()
201
+ font_sizes = torch.FloatTensor(font_sizes).cuda()
202
+ else:
203
+ word_pos = None
204
+ font_sizes = None
205
+ text_format_dict = {
206
+ 'word_pos': word_pos,
207
+ 'font_size': font_sizes,
208
+ }
209
+ return text_format_dict
210
+
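# Worked example of the size-to-weight convention (set up in parse_json and
# consumed here): a '50px' span becomes 50/3 ~ 16.7 (amplified), while a
# struck-through '18px' span becomes -18/3 = -6 (suppressed). For a prompt
# whose 3rd token carries the enlarged span, the returned dict looks roughly
# like {'word_pos': tensor([3]), 'font_size': tensor([16.67])}, both on CUDA.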
211
+
212
+ def get_gradient_guidance_input(model, base_tokens, color_text_prompts, color_rgbs, text_format_dict,
213
+ guidance_start_step=999, color_guidance_weight=1):
214
+ r"""
215
+ Collect the token ids and targets needed to guide the selected spans toward their target RGB colors.
216
+ """
217
+ color_target_token_ids = []
218
+ for text_prompt in color_text_prompts:
219
+ color_target_token_ids.append([])
220
+ color_tokens = model.tokenizer._tokenize(text_prompt)
221
+ for color_token in color_tokens:
222
+ color_target_token_ids[-1].append(base_tokens.index(color_token)+1)
223
+ color_target_token_ids_all = [
224
+ id for ids in color_target_token_ids for id in ids]
225
+ color_target_token_ids_rest = [id for id in range(
226
+ 1, len(base_tokens)+1) if id not in color_target_token_ids_all]
227
+ color_target_token_ids.append(color_target_token_ids_rest)
228
+ color_target_token_ids = [torch.LongTensor(
229
+ obj_token_id) for obj_token_id in color_target_token_ids]
230
+
231
+ text_format_dict['target_RGB'] = color_rgbs
232
+ text_format_dict['guidance_start_step'] = guidance_start_step
233
+ text_format_dict['color_guidance_weight'] = color_guidance_weight
234
+ return text_format_dict, color_target_token_ids
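Taken together, these helpers form the text-parsing front end of the pipeline. A minimal sketch of the call order, roughly as app.py wires it up (model stands for the Space's region-diffusion wrapper, which exposes a CLIP tokenizer and is not defined in this file; the guidance weight is an illustrative value):

    (base_prompt, style_prompts, footnote_prompts, footnote_targets,
     color_prompts, color_names, color_rgbs, size_prompts, use_grad) = parse_json(rich_json)
    region_prompts, region_token_ids, base_tokens = get_region_diffusion_input(
        model, base_prompt, style_prompts, footnote_prompts, footnote_targets,
        color_prompts, color_names)
    text_format_dict = get_attention_control_input(model, base_tokens, size_prompts)
    text_format_dict, color_token_ids = get_gradient_guidance_input(
        model, base_tokens, color_prompts, color_rgbs, text_format_dict,
        color_guidance_weight=0.5)
    # region_prompts / region_token_ids drive the regional denoising, while
    # text_format_dict carries the font-size and color controls.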